Source code for archai.discrete_search.evaluators.nlp.transformer_flex_memory

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import pathlib
import shutil
from typing import Any, Dict, Optional

import torch
from overrides import overrides

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import ModelEvaluator
from archai.discrete_search.search_spaces.nlp.transformer_flex.search_space import (
    TransformerFlexSearchSpace,
)
from archai.onnx.export import export_to_onnx
from archai.onnx.export_utils import prepare_model_for_onnx
from archai.onnx.optimization import optimize_onnx

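# Temporary output folder, managed manually (see the TemporaryFile note in evaluate())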
TMP_FOLDER = pathlib.Path("tmp")


class TransformerFlexOnnxMemory(ModelEvaluator):
    """Measure the memory usage of models from the Transformer-Flex search space."""

    def __init__(
        self,
        search_space: TransformerFlexSearchSpace,
        use_past: Optional[bool] = True,
        validate: Optional[bool] = True,
        share_weights: Optional[bool] = True,
        opset: Optional[int] = 11,
        optimize: Optional[bool] = True,
        only_ort: Optional[bool] = False,
    ) -> None:
        """Initialize the evaluator.

        Args:
            search_space: The search space used to load models from architecture configs.
            use_past: Whether to include past key/values in the model.
            validate: Whether to validate the exported model.
            share_weights: Whether to share the embedding and softmax weights.
            opset: ONNX opset version to use for the export.
            optimize: Whether to optimize the exported ONNX model.
            only_ort: Whether to apply only ONNX Runtime (ORT) optimizations.

        """

        assert search_space.arch_type in ["codegen", "gpt2", "gpt2-flex"]
        self.search_space = search_space

        # Benchmark settings
        self.use_past = use_past
        self.validate = validate
        self.share_weights = share_weights
        self.opset = opset
        self.optimize = optimize
        self.only_ort = only_ort

    def _load_and_prepare(self, config: Dict[str, Any]) -> torch.nn.Module:
        # Work on a copy so the architecture's original config is not mutated
        config = copy.deepcopy(config)
        if self.use_past:
            config["use_cache"] = True

        model = self.search_space._load_model_from_config(config)

        return prepare_model_for_onnx(model, self.search_space.arch_type)
    @overrides
    def evaluate(self, arch: ArchaiModel, budget: Optional[float] = None) -> float:
        model = self._load_and_prepare(arch.metadata["config"])

        # There is a bug in Python < 3.10 when using TemporaryFile on Windows,
        # so we manually save and remove the temporary file instead
        TMP_FOLDER.mkdir(parents=True, exist_ok=True)
        onnx_path = TMP_FOLDER / "model.onnx"

        onnx_config = export_to_onnx(
            model,
            onnx_path.as_posix(),
            task="causal-lm",
            use_past=self.use_past,
            validate=self.validate,
            share_weights=self.share_weights,
            opset=self.opset,
        )
        if self.optimize:
            onnx_path = optimize_onnx(onnx_path.as_posix(), onnx_config, opt_level=0, only_ort=self.only_ort)

        # Memory usage is reported as the on-disk size of the (optimized) ONNX file, in MB
        memory = pathlib.Path(onnx_path).stat().st_size / (1024**2)
        shutil.rmtree(TMP_FOLDER)

        return memory
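
A minimal usage sketch follows; it assumes TransformerFlexSearchSpace is constructed with its arch_type and implements random_sample() from the DiscreteSearchSpace API, returning an ArchaiModel whose metadata["config"] holds the architecture config consumed by evaluate():

    from archai.discrete_search.evaluators.nlp.transformer_flex_memory import (
        TransformerFlexOnnxMemory,
    )
    from archai.discrete_search.search_spaces.nlp.transformer_flex.search_space import (
        TransformerFlexSearchSpace,
    )

    # Build the search space and the evaluator (all benchmark settings keep their defaults)
    space = TransformerFlexSearchSpace("gpt2")
    evaluator = TransformerFlexOnnxMemory(space)

    # Sample an architecture and measure the size of its exported ONNX model, in MB
    arch = space.random_sample()
    print(f"ONNX model size: {evaluator.evaluate(arch):.2f} MB")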