# Source code for olive.model.config.model_config

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from pathlib import Path
from typing import Dict

from olive.common.config_utils import NestedConfig
from olive.common.constants import LOCAL_INPUT_MODEL_ID
from olive.common.pydantic_v1 import Field, validator
from olive.common.utils import hash_dict, hash_file, hash_string
from olive.model.config.registry import get_model_handler, is_valid_model_type
from olive.resource_path import create_resource_path

logger = logging.getLogger(__name__)


class ModelConfig(NestedConfig):
    """Input model config from which a model handler is built.

    ``type`` names a registered model handler and ``config`` holds the keyword
    arguments passed to that handler's constructor.
    """

    type: str = Field(description="The type of the model handler.")
    config: dict = Field(description="The config for the model handler. Used to initialize the model handler.")

    @validator("type")
    def validate_type(cls, v):
        """Reject model types unknown to the registry; return the type lowercased."""
        if is_valid_model_type(v):
            return v.lower()
        raise ValueError(f"Unknown model type {v}")

    def get_resource_strings(self):
        """Return the subset of ``config`` whose keys the handler class declares as resource keys.

        Entries keep the iteration order of ``config``; values are returned unchanged.
        """
        handler_cls = get_model_handler(self.type)
        wanted_keys = handler_cls.get_resource_keys()
        return {key: value for key, value in self.config.items() if key in wanted_keys}

    def get_resource_paths(self):
        """Return ``get_resource_strings()`` with each value passed through ``create_resource_path``."""
        paths = {}
        for key, raw_value in self.get_resource_strings().items():
            paths[key] = create_resource_path(raw_value)
        return paths

    def create_model(self):
        """Instantiate the registered handler class for ``type`` with ``config`` as keyword arguments."""
        return get_model_handler(self.type)(**self.config)

    def get_model_id(self):
        """Return a short identifier for this model config.

        Returns ``LOCAL_INPUT_MODEL_ID`` if any value in ``config`` is callable.
        Otherwise returns the first 8 characters of ``hash_dict`` over:
        - the result of ``get_model_identifier()``;
        - a deep copy of this config with ``model_path``, ``adapter_path`` and,
          when ``model_attributes`` is non-empty, its ``additional_files`` and
          ``_name_or_path`` entries removed.
        ``self`` is not modified.
        """
        if any(callable(value) for value in self.config.values()):
            return LOCAL_INPUT_MODEL_ID

        # Computed before copying, matching the original evaluation order
        # (this may query the Hugging Face Hub, see get_model_identifier).
        identifier = self.get_model_identifier()

        stripped = deepcopy(self)
        trimmed_config = stripped.config
        # NOTE(review): these entries are presumably dropped because their content
        # is already reflected in `identifier` (file hashes / hub sha) — confirm.
        for path_key in ("model_path", "adapter_path"):
            trimmed_config.pop(path_key, None)
        attributes = trimmed_config.get("model_attributes")
        if attributes:
            for attr_key in ("additional_files", "_name_or_path"):
                attributes.pop(attr_key, None)

        payload = {"model_identifier": identifier, "model_config": stripped.dict()}
        return hash_dict(payload)[:8]

    def get_model_identifier(self):
        """Return a value identifying the model artifacts referenced by ``config``.

        Checks, in order:
        - For an ``hfmodel`` whose ``model_path`` resolves to a string name and
          which has no ``adapter_path``: the ``sha`` from
          ``huggingface_hub.repo_info(model_path)``, a Hugging Face Hub query.
        - For an AzureML resource ``model_path``: the resource's ``get_path()``.
        - Otherwise: ``hash_string`` over the concatenation of the sorted
          per-file hashes from ``_get_model_files_hash``.

        Raises:
            ImportError: in the Hub case, when ``huggingface_hub`` cannot be imported
                (the failure is also logged).
        """
        model_path = self.config.get("model_path")
        if model_path:
            resource = create_resource_path(model_path)
            is_plain_hub_model = (
                self.type == "hfmodel"
                and resource.is_string_name()
                and self.config.get("adapter_path") is None
            )
            if is_plain_hub_model:
                try:
                    # huggingface_hub is a dependency of transformers
                    from huggingface_hub import repo_info
                except ImportError as exc:
                    logger.exception(
                        "huggingface_hub is not installed. "
                        "Please install huggingface_hub for supporting Huggingface model."
                    )
                    raise ImportError("huggingface_hub is not installed.") from exc
                return repo_info(model_path).sha
            if resource.is_azureml_resource():
                return resource.get_path()

        # Sorting makes the result independent of directory-listing order.
        ordered_digests = sorted(self._get_model_files_hash(self.config))
        return hash_string("".join(ordered_digests))

    def _get_model_files_hash(self, config: Dict):
        """Collect per-file hashes for every local path referenced by ``config``.

        Paths come from the truthy values of ``model_path``, ``adapter_path``,
        ``model_script`` and ``script_dir``, followed by
        ``model_attributes["additional_files"]`` when present. Each value is
        wrapped in ``Path``; non-path values are not filtered out here.
        """
        path_keys = ("model_path", "adapter_path", "model_script", "script_dir")
        candidates = [Path(config[key]) for key in path_keys if config.get(key)]
        attributes = config.get("model_attributes") or {}
        candidates.extend(Path(extra) for extra in attributes.get("additional_files") or [])

        digests = []
        for candidate in candidates:
            digests.extend(self._get_file_hash(candidate))
        return digests

    def _get_file_hash(self, file_path: Path):
        """Return hashes for ``file_path``.

        - A regular file yields one entry: the first 8 characters of ``hash_file``
          read with ``block_size=1024 * 1024``.
        - A directory is recursed in ``iterdir()`` order.
        - A path that is neither a file nor a directory yields ``[]``.
        Only the values returned by ``hash_file`` are collected; path names are
        not added.
        """
        if file_path.is_file():
            return [hash_file(file_path, block_size=1024 * 1024)[:8]]
        if file_path.is_dir():
            return [digest for child in file_path.iterdir() for digest in self._get_file_hash(child)]
        return []