# Source code for olive.model.config.model_config

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from olive.common.config_utils import ConfigBase
from olive.common.pydantic_v1 import validator
from olive.model.config.registry import get_model_handler, is_valid_model_type
from olive.resource_path import create_resource_path


class ModelConfig(ConfigBase):
    """Input model config which will be used to create the model handler.

    ``type`` names the model handler to use; it is checked with
    ``is_valid_model_type`` when the config is validated. ``config`` holds
    the keyword arguments that are forwarded to the handler's constructor
    by :meth:`create_model`.

    For example, a config for llama2 looks like:

    .. code-block:: json

        {
            "input_model": {
                "type": "CompositePyTorchModel",
                "config": {
                    "model_path": "llama_v2",
                    "model_components": [
                        {
                            "name": "decoder_model",
                            "type": "PyTorchModel",
                            "config": {
                                "model_script": "user_script.py",
                                "io_config": {
                                    "input_names": ["tokens", "position_ids", "attn_mask", ...],
                                    "output_names": ["logits", "attn_mask_out", ...],
                                    "dynamic_axes": {
                                        "tokens": { "0": "batch_size", "1": "seq_len" },
                                        "position_ids": { "0": "batch_size", "1": "seq_len" },
                                        "attn_mask": { "0": "batch_size", "1": "max_seq_len" },
                                        ...
                                    }
                                },
                                "model_loader": "load_decoder_model",
                                "dummy_inputs_func": "decoder_inputs"
                            }
                        },
                        {
                            "name": "decoder_with_past_model",
                            "type": "PyTorchModel",
                            "config": {
                                "model_script": "user_script.py",
                                "io_config": {
                                    "input_names": ["tokens_increment", "position_ids_increment", "attn_mask", ...],
                                    "output_names": ["logits", "attn_mask_out", ...],
                                    "dynamic_axes": {
                                        "tokens_increment": { "0": "batch_size", "1": "seq_len_increment" },
                                        "position_ids_increment": { "0": "batch_size", "1": "seq_len_increment" },
                                        "attn_mask": { "0": "batch_size", "1": "max_seq_len" },
                                        ...
                                    }
                                },
                                "model_loader": "load_decoder_with_past_model",
                                "dummy_inputs_func": "decoder_with_past_inputs"
                            }
                        }
                    ]
                }
            }
        }
    """

    # Name of the model handler type; validated by ``validate_type``.
    type: str
    # Keyword arguments passed to the handler class in ``create_model``.
    config: dict

    @validator("type")
    def validate_type(cls, v):
        """Return ``v`` unchanged if it is a valid model type.

        Raises:
            ValueError: if ``is_valid_model_type(v)`` is falsy.
        """
        if is_valid_model_type(v):
            return v
        raise ValueError(f"Unknown model type {v}")

    def get_resource_strings(self):
        """Return the subset of ``config`` whose keys are resource keys.

        The resource keys are whatever the handler class for ``self.type``
        reports from ``get_resource_keys()``. Entries keep the order they
        have in ``config``.
        """
        handler_cls = get_model_handler(self.type)
        resource_keys = handler_cls.get_resource_keys()

        selected = {}
        for name, value in self.config.items():
            if name in resource_keys:
                selected[name] = value
        return selected

    def get_resource_paths(self):
        """Return the resource entries mapped through ``create_resource_path``.

        Keys are the same as those returned by ``get_resource_strings``.
        """
        resource_paths = {}
        for name, raw_value in self.get_resource_strings().items():
            resource_paths[name] = create_resource_path(raw_value)
        return resource_paths

    def create_model(self):
        """Instantiate the handler class for ``self.type`` with ``config`` as keyword arguments."""
        handler_cls = get_model_handler(self.type)
        return handler_cls(**self.config)