# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
import torch
import yaml
from olive.common.config_utils import serialize_to_json, validate_config
from olive.common.user_module_loader import UserModuleLoader
from olive.constants import Framework, ModelFileFormat
from olive.hardware.accelerator import Device
from olive.model.config import (
HfComponent,
HfConfig,
IoConfig,
complete_kv_cache_with_model_attributes,
extend_io_config_with_kv_cache,
)
from olive.model.config.registry import model_handler_registry
from olive.model.handler.base import OliveModelHandler
from olive.model.handler.mixin import DummyInputsMixin, HfConfigMixin, MLFlowMixin, PytorchKvCacheMixin
from olive.model.utils.hf_utils import load_hf_model_from_model_class
from olive.resource_path import OLIVE_RESOURCE_ANNOTATIONS, ResourceType, create_resource_path
logger = logging.getLogger(__name__)
[docs]@model_handler_registry("PyTorchModel")
class PyTorchModelHandler(
OliveModelHandler, HfConfigMixin, DummyInputsMixin, PytorchKvCacheMixin, MLFlowMixin
): # pylint: disable=too-many-ancestors
"""PyTorch model handler.
Besides the model loading for PyTorch model, the model handler also provides the following functionalities:
* Get the model io configuration either from user provider io_config or from hf_config. The priority is user
provided io_config is higher than hf_config.
* Get the dummy inputs for PyTorch model used to evaluate the latency.
* All kinds of Hf model functionalities by HfConfigMixin.
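
    Example (a minimal sketch; the path and io_config values below are hypothetical
    placeholders, not defaults of this class):

        handler = PyTorchModelHandler(
            model_path="my_model.pt",
            io_config={"input_names": ["x"], "input_shapes": [[1, 3]], "output_names": ["y"]},
        )
        model = handler.load_model().eval()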
"""
resource_keys: Tuple[str, ...] = ("model_path", "script_dir", "model_script", "adapter_path")
json_config_keys: Tuple[str, ...] = (
"model_file_format",
"model_loader",
"dummy_inputs_func",
"hf_config",
)
def __init__(
self,
model_path: OLIVE_RESOURCE_ANNOTATIONS = None,
model_file_format: ModelFileFormat = ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_loader: Union[str, Callable] = None,
model_script: Union[str, Path] = None,
script_dir: Union[str, Path] = None,
io_config: Union[Dict[str, Any], IoConfig, str, Callable] = None,
dummy_inputs_func: Union[str, Callable] = None,
hf_config: Union[Dict[str, Any], HfConfig] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
model_attributes: Optional[Dict[str, Any]] = None,
mlflow_transformer_model_cache_dir: Optional[str] = None,
):
if not (
isinstance(model_loader, Callable)
or (isinstance(model_loader, str) and model_script)
or model_path
or hf_config
):
            raise ValueError(
                "Either model_path, hf_config, or a valid model_loader (a callable, or a function"
                " name together with model_script) is required."
            )
self.mlflow_transformer_model_cache_dir = mlflow_transformer_model_cache_dir
self.model_loader = model_loader
self.model = None
super().__init__(
framework=Framework.PYTORCH,
model_file_format=model_file_format,
model_path=model_path,
model_attributes=model_attributes,
io_config=io_config,
)
self.add_resources(locals())
self.hf_config = None
if hf_config:
self.hf_config = validate_config(hf_config, HfConfig)
hf_model_config = self.get_hf_model_config().to_dict()
model_attr = self.model_attributes or {}
hf_model_config.update(model_attr)
self.model_attributes = hf_model_config
        # ensure that script_dir is a local folder
script_dir_resource = create_resource_path(self.script_dir)
if script_dir_resource:
assert script_dir_resource.type == ResourceType.LocalFolder, "script_dir must be a local directory."
        # ensure that model_script is a local file or a string name
model_script_resource = create_resource_path(self.model_script)
if model_script_resource:
assert model_script_resource.type in (
ResourceType.LocalFile,
ResourceType.StringName,
), "model_script must be a local file or a string name."
self.dummy_inputs_func = dummy_inputs_func
self.dummy_inputs = None
@property
def script_dir(self) -> str:
return self.get_resource("script_dir")
@property
def model_script(self) -> str:
return self.get_resource("model_script")
@property
def adapter_path(self) -> str:
return self.get_resource("adapter_path")
def get_mlflow_transformers_dir(self):
return self.mlflow_transformer_model_cache_dir or self.model_path
def load_model(self, rank: int = None) -> torch.nn.Module:
if self.model is not None:
return self.model
        # Load the user module first since loading the model may require user-defined classes
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
        # Loading priority: custom model_loader -> special file formats (TorchScript/MLflow) -> hf_config -> regular model path
if self.model_loader is not None:
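            # model_loader may be a callable or the name of a function defined in model_script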
model = user_module_loader.call_object(self.model_loader, self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_TORCH_SCRIPT:
model = torch.jit.load(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_MLFLOW_MODEL:
model = self._load_mlflow_model()
elif self.hf_config and (self.hf_config.model_class or self.hf_config.task):
model = self.load_hf_model(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_ENTIRE_MODEL:
model = torch.load(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_SLICE_GPT_MODEL:
model = self._load_slicegpt_model()
elif self.model_file_format == ModelFileFormat.PYTORCH_STATE_DICT:
raise ValueError("Please use customized model loader to load state dict of model.")
else:
raise ValueError(f"Unsupported model file format: {self.model_file_format}")
# we only have peft adapters for now
if self.adapter_path:
from peft import PeftModel
model = PeftModel.from_pretrained(model, self.adapter_path)
self.model = model
return model
def get_component_model(self, component: HfComponent, rank: Optional[int] = None) -> "PyTorchModelHandler":
if component.component_func is None:
logger.debug("component_func is not provided, using hf_config to get component")
model_component = self.load_hf_model(self.model_path)
else:
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
model_component = user_module_loader.call_object(component.component_func, self)
        # the model_component default argument binds the value at definition time, fixing ruff B023:
        # https://docs.astral.sh/ruff/rules/function-uses-loop-variable/
def model_loader(_, model_component=model_component):
return model_component
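        # drop nested components from the copied hf_config so the component model does not recurse into them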
component_hf_config = deepcopy(self.hf_config).dict()
component_hf_config.pop("components", None)
return PyTorchModelHandler(
model_loader=model_loader,
io_config=component.io_config,
dummy_inputs_func=component.dummy_inputs_func,
model_script=self.model_script,
script_dir=self.script_dir,
hf_config=HfConfig.parse_obj(component_hf_config),
model_attributes=self.model_attributes,
)
def prepare_session(
self,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):
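        # For PyTorch, the "session" is just the model in eval mode; inference_settings,
        # device, and execution_providers are accepted for interface compatibility but unused here.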
return self.load_model(rank).eval()
def _load_mlflow_model(self):
logger.info("Loading MLFlow model from %s", self.model_path)
mlflow_transformers_path = self.to_mlflow_transformer_model(self.get_mlflow_transformers_dir())
with open(os.path.join(self.model_path, "MLmodel")) as fp:
mlflow_data = yaml.safe_load(fp)
# default flavor is "hftransformersv2" from azureml.evaluate.mlflow>=0.0.8
# "hftransformers" from azureml.evaluate.mlflow<0.0.8
# TODO(trajep): let user specify flavor name if needed
# to support other flavors in mlflow not only hftransformers
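        # An MLmodel file declares its flavors in a "flavors" section, e.g. (illustrative only;
        # the class name below is a placeholder):
        #   flavors:
        #     hftransformersv2:
        #       hf_pretrained_class: AutoModelForCausalLM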
hf_pretrained_class = None
flavors = mlflow_data.get("flavors", {})
if not flavors:
            raise ValueError(
                "Invalid MLFlow model format. Please make sure the input model format"
                " matches the output of mlflow.transformers.save_model or"
                " aml_mlflow.hftransformers.save_model from azureml.evaluate.mlflow."
            )
if "hftransformersv2" in flavors:
hf_pretrained_class = flavors["hftransformersv2"].get("hf_pretrained_class", "AutoModel")
elif "hftransformers" in flavors:
hf_pretrained_class = flavors["hftransformers"].get("hf_pretrained_class", "AutoModel")
else:
raise ValueError(
"Unsupported MLFlow model flavor. Currently only support hftransformersv2/hftransformers."
)
loading_args = self.hf_config.get_loading_args_from_pretrained() if self.hf_config else {}
loaded_model = load_hf_model_from_model_class(hf_pretrained_class, mlflow_transformers_path, **loading_args)
loaded_model.eval()
return loaded_model
def _load_slicegpt_model(self):
logger.info("Loading SliceGPT model from %s", self.model_path)
        from slicegpt.hf_utils import load_sliced_model
loaded_model, _ = load_sliced_model(self.hf_config.model_name, self.model_path)
return loaded_model
def to_json(self, check_object: bool = False):
config = super().to_json(check_object)
# add _io_config to config to keep what was provided at init
config["config"]["io_config"] = self._io_config
# only keep model_attributes that are not in hf_config
if self.model_attributes and self.hf_config:
model_attributes = {}
hf_config_dict = self.get_hf_model_config().to_dict()
for key, value in self.model_attributes.items():
if key not in hf_config_dict or hf_config_dict[key] != value:
model_attributes[key] = value
config["config"]["model_attributes"] = model_attributes or None
return serialize_to_json(config, check_object)
def get_user_io_config(self, io_config: Union[Dict[str, Any], IoConfig, str, Callable]) -> Dict[str, Any]:
"""Resolve io_config to a dictionary.
If io_config is a string name or a callable, it will be called to get io_config.
"""
io_config_obj = None
if isinstance(io_config, dict):
io_config_obj = IoConfig.parse_obj(io_config)
elif isinstance(io_config, IoConfig):
# return a new copy of io_config to avoid modifying the original one
io_config_obj = io_config.copy(deep=True)
elif isinstance(io_config, (str, Callable)):
# io_config is a string name or a callable
logger.debug("Calling %s to get io_config", io_config)
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
io_config = user_module_loader.call_object(io_config, self)
io_config_obj = validate_config(io_config, IoConfig)
# TODO(anyone): infer if to use kv_cache from task config
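        # fill in the kv_cache config from the model attributes, then add the kv-cache
        # tensors to the io config's inputs/outputs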
if io_config_obj.kv_cache:
kv_cache_config = complete_kv_cache_with_model_attributes(io_config_obj.kv_cache, self.model_attributes)
io_config_obj = extend_io_config_with_kv_cache(io_config_obj, kv_cache_config)
return io_config_obj.dict(exclude_none=True)
@property
def io_config(self) -> Dict[str, Any]:
"""Return io config of the model.
Priority: io_config > hf_config (using onnx_config)
"""
io_config = None
if self._io_config:
# io_config is provided
io_config = self.get_user_io_config(self._io_config)
elif self.hf_config and self.hf_config.task and not self.hf_config.components:
# hf_config is provided
logger.debug("Trying hf onnx_config to get io_config")
# For MLFlow model, get io config from model_name instead of model_path
# TODO(xiaoyu): more investigation on the integration between MLFlow and HF
io_config = self.get_hf_io_config()
if io_config:
logger.debug("Got io_config from hf_config")
return io_config
[docs]@model_handler_registry("DistributedPyTorchModel")
class DistributedPyTorchModelHandler(OliveModelHandler, HfConfigMixin):
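    """Handler for a PyTorch model distributed as one model file per rank.

    The per-rank files live under model_path and are named according to
    model_name_pattern (e.g. DEFAULT_RANKED_MODEL_NAME_FORMAT "model_{:02d}" yields
    "model_00", "model_01", ...); load_model(rank) wraps the file for a single rank
    in a PyTorchModelHandler.
    """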
resource_keys: Tuple[str, ...] = ("model_path", "script_dir", "model_script", "adapter_path")
json_config_keys: Tuple[str, ...] = (
"model_name_pattern",
"num_ranks",
"model_loader",
"io_config",
"dummy_inputs_func",
"hf_config",
)
DEFAULT_RANKED_MODEL_NAME_FORMAT: ClassVar[str] = "model_{:02d}"
def __init__(
self,
model_path: OLIVE_RESOURCE_ANNOTATIONS,
model_name_pattern: str,
num_ranks: int,
model_file_format: ModelFileFormat = ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_loader: Union[str, Callable] = None,
model_script: Union[str, Path] = None,
script_dir: Union[str, Path] = None,
io_config: Union[Dict[str, Any], IoConfig, str, Callable] = None,
dummy_inputs_func: Union[str, Callable] = None,
hf_config: Union[Dict[str, Any], HfConfig] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
model_attributes: Optional[Dict[str, Any]] = None,
):
super().__init__(
framework=Framework.PYTORCH,
model_file_format=model_file_format,
model_path=model_path,
model_attributes=model_attributes,
io_config=io_config,
)
self.add_resources(locals())
self.model_name_pattern = model_name_pattern
self.num_ranks = num_ranks
self.model_loader = model_loader
self.dummy_inputs_func = dummy_inputs_func
self.hf_config = validate_config(hf_config, HfConfig) if hf_config else None
@property
def script_dir(self) -> str:
return self.get_resource("script_dir")
@property
def model_script(self) -> str:
return self.get_resource("model_script")
@property
def adapter_path(self) -> str:
return self.get_resource("adapter_path")
def ranked_model_name(self, rank: int) -> str:
return self.model_name_pattern.format(rank)
def ranked_model_path(self, rank: int) -> Union[Path, str]:
return Path(self.model_path) / self.ranked_model_name(rank)
def load_model(self, rank: int = None) -> PyTorchModelHandler:
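        # Unlike PyTorchModelHandler.load_model, this returns a handler wrapping the given
        # rank's model file rather than a torch.nn.Module.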
return PyTorchModelHandler(
model_path=self.ranked_model_path(rank),
model_file_format=ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_loader=self.model_loader,
model_script=self.model_script,
script_dir=self.script_dir,
io_config=self._io_config,
dummy_inputs_func=self.dummy_inputs_func,
hf_config=self.hf_config,
adapter_path=self.adapter_path,
model_attributes=self.model_attributes,
)
def get_component_model(self, component: HfComponent, rank: int = 0) -> PyTorchModelHandler:
# TODO(shaahji): Add support for 'HfComponent.component_func'
hf_config = deepcopy(self.hf_config).dict()
hf_config.pop("components", None)
return PyTorchModelHandler(
model_path=self.ranked_model_path(rank),
model_file_format=ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_script=self.model_script,
script_dir=self.script_dir,
io_config=component.io_config,
dummy_inputs_func=component.dummy_inputs_func,
hf_config=HfConfig.parse_obj(hf_config),
adapter_path=self.adapter_path,
model_attributes=self.model_attributes,
)
def prepare_session(
self,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.GPU, # pylint: disable=signature-differs
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = 0,
) -> torch.nn.Module:
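        # The first load_model(rank) returns the per-rank PyTorchModelHandler;
        # the second loads the actual torch.nn.Module from it.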
return self.load_model(rank).load_model(rank).eval()