# Source code for archai.onnx.onnx_loader
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from os import environ
from typing import List, Optional
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from archai.common.ordered_dict_logger import OrderedDictLogger
logger = OrderedDictLogger(source=__name__)
def load_from_onnx(onnx_model_path: str, providers: Optional[List[str]] = None) -> InferenceSession:
    """Create an ONNX Runtime inference session for a model stored on disk.

    Besides building the session, this function overwrites the
    ``OMP_NUM_THREADS`` and ``OMP_WAIT_POLICY`` environment variables of the
    current process, limits intra-op execution to a single thread, and
    requests the ``ORT_ENABLE_ALL`` graph optimization level.

    Args:
        onnx_model_path: Path to the ONNX model file.
        providers: List of providers to use for inference. When ``None`` or
            empty, ``["CPUExecutionProvider"]`` is used.

    Returns:
        ONNX inference session, after ``disable_fallback()`` has been called on it.

    """

    logger.info(f"Loading model: {onnx_model_path}")

    # Performance-related settings: a single thread is used both for the
    # OpenMP environment variable and for the session's intra-op thread pool.
    num_threads = 1
    environ.update(
        {
            "OMP_NUM_THREADS": str(num_threads),
            "OMP_WAIT_POLICY": "ACTIVE",
        }
    )

    session_options = SessionOptions()
    session_options.intra_op_num_threads = num_threads
    session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Default to CPU execution when the caller gives no (or an empty) provider list.
    if not providers:
        providers = ["CPUExecutionProvider"]

    session = InferenceSession(
        onnx_model_path,
        sess_options=session_options,
        providers=providers,
    )
    session.disable_fallback()

    return session