# Source code for archai.discrete_search.evaluators.nlp.transformer_flex_latency
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import contextlib
import copy
import pathlib
import shutil
import tempfile
import timeit
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from onnxruntime import InferenceSession
from overrides import overrides

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import ModelEvaluator
from archai.discrete_search.search_spaces.nlp.transformer_flex.search_space import (
    TransformerFlexSearchSpace,
)
from archai.onnx.config_utils.onnx_config_base import OnnxConfig
from archai.onnx.export import export_to_onnx
from archai.onnx.export_utils import prepare_model_for_onnx
from archai.onnx.onnx_loader import load_from_onnx
from archai.onnx.optimization import optimize_onnx

# Directory (relative to the current working directory) used by the evaluator
# to hold exported ONNX files while a model is being benchmarked.
TMP_FOLDER = pathlib.Path(".", "tmp")
class TransformerFlexOnnxLatency(ModelEvaluator):
    """Measure the average latency of models from the Transformer-Flex search space."""

    def __init__(
        self,
        search_space: TransformerFlexSearchSpace,
        providers: Optional[List[str]] = None,
        batch_size: Optional[int] = 1,
        seq_len: Optional[int] = 192,
        past_seq_len: Optional[int] = 0,
        n_trials: Optional[int] = 1,
        use_median: Optional[bool] = False,
        use_past: Optional[bool] = True,
        validate: Optional[bool] = True,
        share_weights: Optional[bool] = True,
        opset: Optional[int] = 11,
        optimize: Optional[bool] = True,
        only_ort: Optional[bool] = False,
    ) -> None:
        """Initialize the evaluator.

        This evaluator supports measuring in different ONNX Runtime providers. For measuring on
        GPUs, use `providers=["CUDAExecutionProvider"]` and make sure that `onnxruntime-gpu`
        package is installed.

        Args:
            search_space: The search space to use for loading the model. Its `arch_type`
                must be one of "codegen", "gpt2" or "gpt2-flex".
            providers: The list of ORT providers to use for benchmarking.
            batch_size: The batch size to use when benchmarking the model.
            seq_len: The sequence length to use when benchmarking the model.
            past_seq_len: The past sequence length to use when benchmarking the model.
            n_trials: The number of timing repeats. Each repeat executes the model `n_trials`
                times and is normalized to a per-run time, so `n_trials ** 2` timed runs are
                performed in total.
            use_median: Whether to use the median or the mean of the measured
                times as the result.
            use_past: Whether to include past key/values in the model.
            validate: Whether to validate the exported model.
            share_weights: Whether to share the embedding and softmax weights.
            opset: Set of operations to use with ONNX.
            optimize: Whether to optimize the ONNX model.
            only_ort: Whether to only apply ORT optimization.

        Raises:
            AssertionError: If `search_space.arch_type` is not supported (note that `assert`
                statements are skipped when Python runs with `-O`).
        """
        supported_arch_types = ["codegen", "gpt2", "gpt2-flex"]
        assert search_space.arch_type in supported_arch_types, (
            f"`arch_type` must be one of {supported_arch_types}, "
            f"but got '{search_space.arch_type}'."
        )

        self.search_space = search_space

        # Benchmark settings
        self.providers = providers
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.past_seq_len = past_seq_len
        self.n_trials = n_trials
        self.use_median = use_median
        self.use_past = use_past
        self.validate = validate
        self.share_weights = share_weights
        self.opset = opset
        self.optimize = optimize
        self.only_ort = only_ort

    def _load_and_prepare(self, config: Dict[str, Any]) -> torch.nn.Module:
        """Build the model described by `config` and prepare it for ONNX export.

        Args:
            config: Model configuration. It is deep-copied, so the caller's dict is not mutated.

        Returns:
            The model returned by `prepare_model_for_onnx` for the search space's `arch_type`.
        """
        config = copy.deepcopy(config)

        # Past key/values are only produced by the model when caching is enabled.
        if self.use_past:
            config["use_cache"] = True

        model = self.search_space._load_model_from_config(config)

        return prepare_model_for_onnx(model, self.search_space.arch_type)

    def _benchmark_model(self, session: InferenceSession, model_config: OnnxConfig) -> float:
        """Time `session` on dummy inputs and return the per-run latency in seconds.

        Args:
            session: ONNX Runtime session wrapping the exported model.
            model_config: ONNX configuration used to generate the dummy inputs.

        Returns:
            Median (if `use_median`) or mean of the `n_trials` per-run timings, in seconds.
        """
        inputs = model_config.generate_dummy_inputs(self.batch_size, self.seq_len, self.past_seq_len)

        if self.use_past:
            # Flatten the nested `past_key_values` entry into one input per layer
            # (`past_0`, `past_1`, ...). These presumably match the input names used by
            # the ONNX export; confirm against `export_to_onnx`.
            past_inputs = inputs.pop("past_key_values")
            for i, past in enumerate(past_inputs):
                inputs[f"past_{i}"] = past

        timer = timeit.Timer(
            stmt="onnx_model_session(None, inputs)",
            globals={"inputs": {k: v.numpy() for k, v in inputs.items()}, "onnx_model_session": session.run},
        )

        # Perform a quick warmup prior to the calculation
        _ = timer.timeit(number=max(int(self.n_trials // 100), 2))

        # Each repeat returns the total time of `n_trials` runs; normalize to per-run times
        runner = timer.repeat(repeat=self.n_trials, number=self.n_trials)
        runner = [r / self.n_trials for r in runner]

        return float(np.median(runner) if self.use_median else np.mean(runner))

    @overrides
    def evaluate(self, arch: ArchaiModel, budget: Optional[float] = None) -> float:
        """Export `arch` to ONNX, optionally optimize it, and measure its latency.

        Args:
            arch: Model to evaluate. `arch.metadata["config"]` is used to build the model.
            budget: Not used by this evaluator.

        Returns:
            Mean (or median, if `use_median`) latency of a single forward pass, in seconds.
        """
        model = self._load_and_prepare(arch.metadata["config"])

        # There is a bug for Python < 3.10 when using TemporaryFile with Windows,
        # thus, we opted to manually save and remove the temporary files. Each call
        # gets its own unique sub-directory of `TMP_FOLDER`. This means pre-existing
        # contents of that folder are never deleted, and concurrent evaluations cannot
        # overwrite each other's `model.onnx`.
        TMP_FOLDER.mkdir(parents=True, exist_ok=True)
        tmp_dir = pathlib.Path(tempfile.mkdtemp(prefix="onnx_latency_", dir=TMP_FOLDER))

        try:
            onnx_path = tmp_dir / "model.onnx"
            onnx_config = export_to_onnx(
                model,
                onnx_path.as_posix(),
                task="causal-lm",
                use_past=self.use_past,
                validate=self.validate,
                share_weights=self.share_weights,
                opset=self.opset,
            )

            if self.optimize:
                # NOTE(review): the optimized model is presumably written next to `onnx_path`
                # (inside `tmp_dir`), so it is removed below; confirm in `optimize_onnx`.
                onnx_path = optimize_onnx(onnx_path.as_posix(), onnx_config, opt_level=0, only_ort=self.only_ort)

            session = load_from_onnx(onnx_path, providers=self.providers)
            latency = self._benchmark_model(session, onnx_config)
        finally:
            # Clean up even if export/benchmarking fails. Errors are ignored so that a cleanup
            # failure neither masks the original exception nor discards a valid measurement.
            shutil.rmtree(tmp_dir, ignore_errors=True)

            # Remove the base folder too, but only when nothing else lives in it.
            with contextlib.suppress(OSError):
                TMP_FOLDER.rmdir()

        return latency