Source code for archai.onnx.optimization

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional

from onnx import load_model
from onnxruntime.transformers.onnx_model_gpt2 import Gpt2OnnxModel
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime

from archai.common.file_utils import create_file_name_identifier
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.onnx.config_utils.onnx_config_base import OnnxConfig
from archai.onnx.optimization_utils.fusion_options import FusionOptions

logger = OrderedDictLogger(source=__name__)

AVAILABLE_ONNX_MODELS = {"gpt2": Gpt2OnnxModel, "gpt2-flex": Gpt2OnnxModel}


def optimize_onnx(
    onnx_model_path: str,
    onnx_config: OnnxConfig,
    use_gpu: Optional[bool] = False,
    opt_level: Optional[int] = 1,
    only_ort: Optional[bool] = False,
    float16: Optional[bool] = False,
    input_int32: Optional[bool] = False,
) -> str:
    """Optimize an ONNX model using a combination of standard ORT-based optimization
    and additional transformer-based optimization.

    Args:
        onnx_model_path: Path to the ONNX model to be optimized.
        onnx_config: ONNX configuration of the model to be optimized.
        use_gpu: Whether to use GPU during optimization.
        opt_level: ORT optimization level (0, 1, 2, or 99).
        only_ort: Whether to apply only the standard ORT optimization.
        float16: Whether to convert the graph to float16 precision.
        input_int32: Whether to change the graph inputs to int32.

    Returns:
        Path to the optimized ONNX model.

    """

    logger.info(f"Optimizing model: {onnx_model_path}")

    assert opt_level in [0, 1, 2, 99]
    ort_model_path = onnx_model_path

    # Applies standard ORT-based optimization
    if opt_level > 0:
        disabled_optimizers = []

        if opt_level > 1:
            # Disables some optimizers that might influence shape inference/attention fusion
            if not only_ort:
                disabled_optimizers = [
                    "MatMulScaleFusion",
                    "MatMulAddFusion",
                    "SimplifiedLayerNormFusion",
                    "GemmActivationFusion",
                    "BiasSoftmaxFusion",
                ]

        # Performs the standard ORT optimization
        ort_model_path = create_file_name_identifier(onnx_model_path, "-opt")
        optimize_by_onnxruntime(
            onnx_model_path,
            use_gpu=use_gpu,
            optimized_model_path=ort_model_path,
            opt_level=opt_level,
            disabled_optimizers=disabled_optimizers,
        )

    if not only_ort:
        model_type = onnx_config.config.model_type
        available_models = list(AVAILABLE_ONNX_MODELS.keys())
        assert model_type in available_models, f"`model_type`: {model_type} is not supported for `only_ort=False`."

        # Applies additional transformer-based optimization
        if onnx_config.is_ort_graph_optimizable:
            ort_model = load_model(ort_model_path)
            ort_model_path = create_file_name_identifier(onnx_model_path, "-opt")

            onnx_opt_model = AVAILABLE_ONNX_MODELS[model_type]
            options = FusionOptions(model_type)

            optimizer = onnx_opt_model(ort_model, *onnx_config.ort_graph_optimizer_args)
            optimizer.optimize(options)
            optimizer.topological_sort()

            if float16:
                ort_model_path = create_file_name_identifier(ort_model_path, "-fp16")
                optimizer.convert_float_to_float16(keep_io_types=True)

            if input_int32:
                optimizer.change_graph_inputs_to_int32()

            optimizer.save_model_to_file(ort_model_path)

    return ort_model_path
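
A minimal usage sketch (not part of the module source): the model path and the pre-built `onnx_config` below are illustrative assumptions; in practice, the `OnnxConfig` instance comes from the export step that produced the ONNX file.

    # Hypothetical example: "model.onnx" was exported beforehand and
    # `onnx_config` is the OnnxConfig instance used for that export.
    opt_model_path = optimize_onnx(
        "model.onnx",
        onnx_config,
        opt_level=1,
        float16=True,  # also writes a "-fp16"-suffixed file, e.g. "model-opt-fp16.onnx"
    )
    print(opt_model_path)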