# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
import torch
import yaml
from olive.common.config_utils import serialize_to_json, validate_config
from olive.common.user_module_loader import UserModuleLoader
from olive.common.utils import copy_dir
from olive.constants import Framework, ModelFileFormat
from olive.hardware.accelerator import Device
from olive.model.config import HfComponent, HfConfig, IoConfig
from olive.model.config.registry import model_handler_registry
from olive.model.handler.base import OliveModelHandler
from olive.model.handler.mixin import DummyInputsMixin, HfConfigMixin
from olive.model.utils.hf_utils import load_hf_model_from_model_class
from olive.resource_path import OLIVE_RESOURCE_ANNOTATIONS, ResourceType, create_resource_path
logger = logging.getLogger(__name__)
@model_handler_registry("PyTorchModel")
class PyTorchModelHandler(OliveModelHandler, HfConfigMixin, DummyInputsMixin): # pylint: disable=too-many-ancestors
"""PyTorch model handler.
Besides the model loading for PyTorch model, the model handler also provides the following functionalities:
* Get the model io configuration either from user provider io_config or from hf_config. The priority is user
provided io_config is higher than hf_config.
* Get the dummy inputs for PyTorch model used to evaluate the latency.
* All kinds of Hf model functionalities by HfConfigMixin.
"""
resource_keys: Tuple[str, ...] = ("model_path", "script_dir", "model_script", "adapter_path")
json_config_keys: Tuple[str, ...] = (
"model_file_format",
"model_loader",
"io_config",
"dummy_inputs_func",
"hf_config",
)
def __init__(
self,
model_path: OLIVE_RESOURCE_ANNOTATIONS = None,
model_file_format: ModelFileFormat = ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_loader: Union[str, Callable] = None,
model_script: Union[str, Path] = None,
script_dir: Union[str, Path] = None,
io_config: Union[Dict[str, Any], IoConfig, str, Callable] = None,
dummy_inputs_func: Union[str, Callable] = None,
hf_config: Union[Dict[str, Any], HfConfig] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
model_attributes: Optional[Dict[str, Any]] = None,
):
if not (
isinstance(model_loader, Callable)
or (isinstance(model_loader, str) and model_script)
or model_path
or hf_config
):
            raise ValueError(
                "model_path or hf_config is required since model_loader is not callable"
                " or model_script is not provided"
            )
self.model_loader = model_loader
self.model = None
super().__init__(
framework=Framework.PYTORCH,
model_file_format=model_file_format,
model_path=model_path,
model_attributes=model_attributes,
)
self.add_resources(locals())
self.hf_config = None
if hf_config:
self.hf_config = validate_config(hf_config, HfConfig)
hf_model_config = self.get_hf_model_config().to_dict()
model_attr = self.model_attributes or {}
hf_model_config.update(model_attr)
self.model_attributes = hf_model_config
        # ensure that script_dir is a local folder
script_dir_resource = create_resource_path(self.script_dir)
if script_dir_resource:
assert script_dir_resource.type == ResourceType.LocalFolder, "script_dir must be a local directory."
        # ensure that model_script is a local file or a string name
model_script_resource = create_resource_path(self.model_script)
if model_script_resource:
assert model_script_resource.type in (
ResourceType.LocalFile,
ResourceType.StringName,
), "model_script must be a local file or a string name."
# io config for conversion to onnx
self.io_config = (
validate_config(io_config, IoConfig).dict() if isinstance(io_config, (IoConfig, dict)) else io_config
)
self.dummy_inputs_func = dummy_inputs_func
self.dummy_inputs = None
@property
def script_dir(self) -> str:
return self.get_resource("script_dir")
@property
def model_script(self) -> str:
return self.get_resource("model_script")
@property
def adapter_path(self) -> str:
return self.get_resource("adapter_path")
def load_model(self, rank: int = None) -> torch.nn.Module:
if self.model is not None:
return self.model
        # Load the user module first since loading the model may require user-defined classes
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
        # Loading order: custom loader or special file format -> hf config -> regular model path
if self.model_loader is not None:
model = user_module_loader.call_object(self.model_loader, self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_TORCH_SCRIPT:
model = torch.jit.load(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_MLFLOW_MODEL:
model = self._load_mlflow_model()
elif self.hf_config and (self.hf_config.model_class or self.hf_config.task):
model = self.load_hf_model(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_ENTIRE_MODEL:
model = torch.load(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_STATE_DICT:
raise ValueError("Please use customized model loader to load state dict of model.")
else:
raise ValueError(f"Unsupported model file format: {self.model_file_format}")
        # only PEFT adapters are supported for now
if self.adapter_path:
from peft import PeftModel
model = PeftModel.from_pretrained(model, self.adapter_path)
self.model = model
return model
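    # Illustrative sketch with hypothetical file and function names: when model_loader is a
    # string, it is resolved from model_script and called with model_path, so a user script
    # for a state-dict checkpoint could look like:
    #
    #   # my_model_script.py
    #   import torch
    #
    #   def load_my_model(model_path):
    #       model = MyNet()  # user-defined architecture
    #       model.load_state_dict(torch.load(model_path))
    #       return model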
def get_component_model(self, component: HfComponent, rank: Optional[int] = None) -> "PyTorchModelHandler":
if component.component_func is None:
logger.debug("component_func is not provided, using hf_config to get component")
model_component = self.load_hf_model(self.model_path)
else:
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
model_component = user_module_loader.call_object(component.component_func, self)
            # model_component is bound as a default argument to avoid ruff B023:
            # https://docs.astral.sh/ruff/rules/function-uses-loop-variable/
def model_loader(_, model_component=model_component):
return model_component
component_hf_config = deepcopy(self.hf_config).dict()
component_hf_config.pop("components", None)
return PyTorchModelHandler(
model_loader=model_loader,
io_config=component.io_config,
dummy_inputs_func=component.dummy_inputs_func,
model_script=self.model_script,
script_dir=self.script_dir,
hf_config=HfConfig.parse_obj(component_hf_config),
model_attributes=self.model_attributes,
)
def prepare_session(
self,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):
return self.load_model(rank).eval()
def _load_mlflow_model(self):
logger.info("Loading MLFlow model from %s", self.model_path)
with tempfile.TemporaryDirectory(prefix="mlflow_tmp") as tmp_dir:
copy_dir(os.path.join(self.model_path, "data/model"), tmp_dir, dirs_exist_ok=True)
copy_dir(os.path.join(self.model_path, "data/config"), tmp_dir, dirs_exist_ok=True)
copy_dir(os.path.join(self.model_path, "data/tokenizer"), tmp_dir, dirs_exist_ok=True)
with open(os.path.join(self.model_path, "MLmodel")) as fp:
mlflow_data = yaml.safe_load(fp)
# default flavor is "hftransformersv2" from azureml.evaluate.mlflow>=0.0.8
# "hftransformers" from azureml.evaluate.mlflow<0.0.8
# TODO(trajep): let user specify flavor name if needed
# to support other flavors in mlflow not only hftransformers
hf_pretrained_class = None
flavors = mlflow_data.get("flavors", {})
if not flavors:
                    raise ValueError(
                        "Invalid MLFlow model format. Please make sure the input model"
                        " format is the same as the output of mlflow.transformers.save_model"
                        " or aml_mlflow.hftransformers.save_model from azureml.evaluate.mlflow"
                    )
if "hftransformersv2" in flavors:
hf_pretrained_class = flavors["hftransformersv2"].get("hf_pretrained_class", "AutoModel")
elif "hftransformers" in flavors:
hf_pretrained_class = flavors["hftransformers"].get("hf_pretrained_class", "AutoModel")
else:
                    raise ValueError(
                        "Unsupported MLFlow model flavor. Only the hftransformersv2 and hftransformers"
                        " flavors are currently supported."
                    )
loading_args = self.hf_config.get_loading_args_from_pretrained() if self.hf_config else {}
loaded_model = load_hf_model_from_model_class(hf_pretrained_class, tmp_dir, **loading_args)
loaded_model.eval()
return loaded_model
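    # Sketch of the directory layout assumed by _load_mlflow_model above, inferred from the
    # copy_dir calls and the MLmodel read:
    #
    #   <model_path>/
    #       MLmodel            # yaml file with a "flavors" section
    #       data/
    #           model/
    #           config/
    #           tokenizer/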
def to_json(self, check_object: bool = False):
config = super().to_json(check_object)
# only keep model_attributes that are not in hf_config
if self.model_attributes and self.hf_config:
model_attributes = {}
hf_config_dict = self.get_hf_model_config().to_dict()
for key, value in self.model_attributes.items():
if key not in hf_config_dict or hf_config_dict[key] != value:
model_attributes[key] = value
config["config"]["model_attributes"] = model_attributes or None
return serialize_to_json(config, check_object)
def get_user_io_config(self, io_config: Union[Dict[str, Any], IoConfig, str, Callable]) -> Dict[str, Any]:
"""Resolve io_config to a dictionary.
If io_config is a string name or a callable, it will be called to get io_config.
"""
if isinstance(io_config, dict):
# io_config is provided
return io_config
if isinstance(io_config, IoConfig):
# io_config is an IoConfig
return io_config.dict()
if isinstance(io_config, (str, Callable)):
# io_config is a string name or a callable
logger.debug("Calling %s to get io_config", io_config)
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
io_config = user_module_loader.call_object(io_config, self)
return validate_config(io_config, IoConfig).dict()
return None
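    # Illustrative sketch with a hypothetical function name and shapes: a callable io_config
    # defined in model_script receives the handler instance and returns a dict or IoConfig, e.g.:
    #
    #   def get_io_config(model):
    #       return {
    #           "input_names": ["input_ids", "attention_mask"],
    #           "input_shapes": [[1, 128], [1, 128]],
    #           "output_names": ["logits"],
    #       }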
def get_io_config(self) -> Dict[str, Any]:
"""Return io config of the model.
Priority: io_config > hf_config (using onnx_config)
"""
io_config = None
if self.io_config:
# io_config is provided
io_config = self.get_user_io_config(self.io_config)
elif self.hf_config and self.hf_config.task and not self.hf_config.components:
# hf_config is provided
logger.debug("Using hf onnx_config to get io_config")
# For MLFlow model, get io config from model_name instead of model_path
# TODO(xiaoyu): more investigation on the integration between MLFlow and HF
io_config = self.get_hf_io_config()
return io_config
@model_handler_registry("DistributedPyTorchModel")
class DistributedPyTorchModelHandler(OliveModelHandler, HfConfigMixin):
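    """Distributed PyTorch model handler.

    Handles a PyTorch model that is split across ranks, with one model file per rank resolved
    via model_name_pattern under model_path.
    """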
resource_keys: Tuple[str, ...] = ("model_path", "script_dir", "model_script", "adapter_path")
json_config_keys: Tuple[str, ...] = (
"model_name_pattern",
"num_ranks",
"model_loader",
"io_config",
"dummy_inputs_func",
"hf_config",
)
DEFAULT_RANKED_MODEL_NAME_FORMAT: ClassVar[str] = "model_{:02d}"
def __init__(
self,
model_path: OLIVE_RESOURCE_ANNOTATIONS,
model_name_pattern: str,
num_ranks: int,
model_file_format: ModelFileFormat = ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_loader: Union[str, Callable] = None,
model_script: Union[str, Path] = None,
script_dir: Union[str, Path] = None,
io_config: Union[Dict[str, Any], IoConfig, str, Callable] = None,
dummy_inputs_func: Union[str, Callable] = None,
hf_config: Union[Dict[str, Any], HfConfig] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
model_attributes: Optional[Dict[str, Any]] = None,
):
super().__init__(
framework=Framework.PYTORCH,
model_file_format=model_file_format,
model_path=model_path,
model_attributes=model_attributes,
)
self.add_resources(locals())
self.model_name_pattern = model_name_pattern
self.num_ranks = num_ranks
self.model_loader = model_loader
self.io_config = (
validate_config(io_config, IoConfig).dict() if isinstance(io_config, (IoConfig, dict)) else io_config
)
self.dummy_inputs_func = dummy_inputs_func
self.hf_config = validate_config(hf_config, HfConfig) if hf_config else None
@property
def script_dir(self) -> str:
return self.get_resource("script_dir")
@property
def model_script(self) -> str:
return self.get_resource("model_script")
@property
def adapter_path(self) -> str:
return self.get_resource("adapter_path")
def ranked_model_name(self, rank: int) -> str:
return self.model_name_pattern.format(rank)
def ranked_model_path(self, rank: int) -> Union[Path, str]:
return Path(self.model_path) / self.ranked_model_name(rank)
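    # Example: with the class default pattern "model_{:02d}", ranked_model_name(3) returns
    # "model_03", so ranked_model_path(3) resolves to <model_path>/model_03.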
def load_model(self, rank: int = None) -> PyTorchModelHandler:
return PyTorchModelHandler(
model_path=self.ranked_model_path(rank),
model_file_format=ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_loader=self.model_loader,
model_script=self.model_script,
script_dir=self.script_dir,
io_config=self.io_config,
dummy_inputs_func=self.dummy_inputs_func,
hf_config=self.hf_config,
adapter_path=self.adapter_path,
model_attributes=self.model_attributes,
)
def get_component_model(self, component: HfComponent, rank: int = 0) -> PyTorchModelHandler:
# TODO(shaahji): Add support for 'HfComponent.component_func'
hf_config = deepcopy(self.hf_config).dict()
hf_config.pop("components", None)
return PyTorchModelHandler(
model_path=self.ranked_model_path(rank),
model_file_format=ModelFileFormat.PYTORCH_ENTIRE_MODEL,
model_script=self.model_script,
script_dir=self.script_dir,
io_config=component.io_config,
dummy_inputs_func=component.dummy_inputs_func,
hf_config=HfConfig.parse_obj(hf_config),
adapter_path=self.adapter_path,
model_attributes=self.model_attributes,
)
def prepare_session(
self,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.GPU, # pylint: disable=signature-differs
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = 0,
) -> torch.nn.Module:
return self.load_model(rank).load_model(rank).eval()
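# Illustrative usage sketch with hypothetical path and values: a distributed handler points at
# a directory of per-rank model files and hands out a per-rank PyTorchModelHandler.
#
#   handler = DistributedPyTorchModelHandler(
#       model_path="/path/to/ranked_models",
#       model_name_pattern="model_{:02d}",
#       num_ranks=4,
#   )
#   session = handler.prepare_session(rank=0)  # torch.nn.Module in eval mode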