Source code for olive.model.handler.composite

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

from olive.common.config_utils import serialize_to_json, validate_config
from olive.common.utils import dict_diff
from olive.constants import Framework, ModelFileFormat
from olive.hardware.accelerator import Device
from olive.model.config.model_config import ModelConfig
from olive.model.config.registry import model_handler_registry
from olive.model.handler.base import OliveModelHandler

logger = logging.getLogger(__name__)


[docs]@model_handler_registry("CompositeModel") class CompositeModelHandler(OliveModelHandler): """CompositeModel represents multiple component models. The only responsibility of CompositeModelHandler is to provider a get_model_components which will iterate all the child models. Whisper is an example composite model that has encoder and decoder components. CompositeModelHandler is a collection of Models. All the child model in the container should have same model type. """ def __init__( self, model_components: List[Union[OliveModelHandler, Dict[str, Any]]], model_component_names: List[str], model_attributes: Optional[Dict[str, Any]] = None, ): super().__init__( model_path=None, framework=Framework.ONNX, model_file_format=ModelFileFormat.COMPOSITE_MODEL, model_attributes=model_attributes, ) self._model_components = [ validate_config(m, ModelConfig).create_model() if isinstance(m, dict) else m for m in model_components ] assert all( isinstance(m, OliveModelHandler) for m in self._model_components ), "All components must be OliveModelHandler or dict" assert len(self._model_components) == len(model_component_names), "Number of components and names must match" self.model_component_names = model_component_names @property def model_components(self): for m in self._model_components: # the parent attributes should be inherited by the child model # child attributes take precedence m.model_attributes = {**(self.model_attributes or {}), **(m.model_attributes or {})} yield m def to_json(self, check_object: bool = False): json_dict = { "type": self.model_type, "config": {"model_attributes": self.model_attributes, "model_component_names": self.model_component_names}, } json_dict["config"]["model_components"] = [] for m in self._model_components: component_json = m.to_json(check_object) # only keep attributes that are different from the parent component_json["config"]["model_attributes"] = dict_diff( component_json["config"]["model_attributes"], self.model_attributes ) json_dict["config"]["model_components"].append(component_json) return serialize_to_json(json_dict, check_object) def get_model_components(self) -> List[Tuple[str, OliveModelHandler]]: return zip(self.model_component_names, self.model_components) def load_model(self, rank: int = None, cache_model: bool = True): raise NotImplementedError def prepare_session( self, inference_settings: Optional[Dict[str, Any]] = None, device: Device = Device.CPU, execution_providers: Union[str, List[str]] = None, rank: Optional[int] = None, ): raise RuntimeError("CompositeModelHandler doesn't have a session of its own") def run_session( self, session: Any = None, inputs: Union[Dict[str, Any], List[Any], Tuple[Any, ...]] = None, **kwargs: Dict[str, Any], ) -> Any: raise RuntimeError("CompositeModelHandler doesn't have a session of its own")