# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union
from olive.common.config_utils import serialize_to_json, validate_config
from olive.common.constants import DEFAULT_HF_TASK
from olive.common.hf.utils import load_model_from_task
from olive.common.utils import dict_diff
from olive.constants import Framework
from olive.hardware.accelerator import Device
from olive.model.config import HfLoadKwargs, IoConfig
from olive.model.config.registry import model_handler_registry
from olive.model.handler.base import OliveModelHandler
from olive.model.handler.mixin import HfMixin, MLFlowTransformersMixin
from olive.model.handler.pytorch import PyTorchModelHandlerBase
from olive.resource_path import OLIVE_RESOURCE_ANNOTATIONS
if TYPE_CHECKING:
import torch
logger = logging.getLogger(__name__)
[docs]@model_handler_registry("HFModel")
class HfModelHandler(PyTorchModelHandlerBase, MLFlowTransformersMixin, HfMixin): # pylint: disable=too-many-ancestors
resource_keys: Tuple[str, ...] = ("model_path", "adapter_path")
json_config_keys: Tuple[str, ...] = ("task", "load_kwargs", "generative")
def __init__(
self,
model_path: OLIVE_RESOURCE_ANNOTATIONS,
task: str = DEFAULT_HF_TASK,
load_kwargs: Union[Dict[str, Any], HfLoadKwargs] = None,
io_config: Union[Dict[str, Any], IoConfig, str] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
model_attributes: Optional[Dict[str, Any]] = None,
generative: bool = False,
):
super().__init__(
framework=Framework.PYTORCH,
model_file_format=None,
model_path=model_path,
model_attributes=model_attributes,
io_config=io_config,
generative=generative,
)
self.add_resources(locals())
self.task = task
self.load_kwargs = validate_config(load_kwargs, HfLoadKwargs, warn_unused_keys=False) if load_kwargs else None
self.model_attributes = {**self.get_hf_model_config().to_dict(), **(self.model_attributes or {})}
self.model = None
self.dummy_inputs = None
@property
def model_name_or_path(self) -> str:
"""Return the path to valid hf transformers checkpoint.
Call this instead of model_path if you expect a checkpoint path.
"""
return self.get_mlflow_transformers_path() or self.model_path
@property
def adapter_path(self) -> str:
"""Return the path to the peft adapter."""
return self.get_resource("adapter_path")
def load_model(self, rank: int = None, cache_model: bool = True) -> "torch.nn.Module":
"""Load the model from the model path."""
if self.model:
model = self.model
else:
model = load_model_from_task(self.task, self.model_path, **self.get_load_kwargs())
# we only have peft adapters for now
if self.adapter_path:
from peft import PeftModel
model = PeftModel.from_pretrained(model, self.adapter_path)
self.model = model if cache_model else None
return model
@property
def io_config(self) -> Dict[str, Any]:
"""Return io config of the model.
Priority: io_config > hf onnx_config
"""
io_config = None
if self._io_config:
# io_config is provided
io_config = self.get_resolved_io_config(
self._io_config, force_kv_cache=self.task.endswith("-with-past"), model_attributes=self.model_attributes
)
else:
logger.debug("Trying hf optimum export config to get io_config")
io_config = self.get_hf_io_config()
if io_config:
logger.debug("Got io_config from hf optimum export config")
return io_config
def get_dummy_inputs(self, filter_hook=None, filter_hook_kwargs=None):
"""Return a dummy input for the model."""
if self.dummy_inputs is not None:
return self.dummy_inputs
# Priority: io_config > hf onnx_config
dummy_inputs = self._get_dummy_inputs_from_io_config(
filter_hook=filter_hook,
filter_hook_kwargs=filter_hook_kwargs,
)
if dummy_inputs:
return dummy_inputs
logger.debug("Trying hf optimum export config to get dummy inputs")
dummy_inputs = self.get_hf_dummy_inputs()
if dummy_inputs:
logger.debug("Got dummy inputs from hf optimum export config")
if dummy_inputs is None:
raise ValueError(
"Unable to get dummy inputs for the model. Please provide io_config or install an optimum version that"
" supports the model for export."
)
return dummy_inputs
def to_json(self, check_object: bool = False):
config = super().to_json(check_object)
# only keep model_attributes that are not in hf model config
hf_model_config_dict = self.get_hf_model_config().to_dict()
config["config"]["model_attributes"] = dict_diff(self.model_attributes, hf_model_config_dict)
return serialize_to_json(config, check_object)
[docs]@model_handler_registry("DistributedHfModel")
class DistributedHfModelHandler(OliveModelHandler):
json_config_keys: Tuple[str, ...] = (
"model_name_pattern",
"num_ranks",
"task",
"load_kwargs",
"io_config",
"generative",
)
DEFAULT_RANKED_MODEL_NAME_FORMAT: ClassVar[str] = "model_{:02d}"
def __init__(
self,
model_path: OLIVE_RESOURCE_ANNOTATIONS,
model_name_pattern: str,
num_ranks: int,
task: str,
load_kwargs: Union[Dict[str, Any], HfLoadKwargs] = None,
io_config: Union[Dict[str, Any], IoConfig] = None,
model_attributes: Optional[Dict[str, Any]] = None,
generative: bool = False,
):
super().__init__(
framework=Framework.PYTORCH,
model_file_format=None,
model_path=model_path,
model_attributes=model_attributes,
io_config=io_config,
generative=generative,
)
self.add_resources(locals())
self.model_name_pattern = model_name_pattern
self.num_ranks = num_ranks
self.task = task
self.load_kwargs = load_kwargs
def ranked_model_name(self, rank: int) -> str:
return self.model_name_pattern.format(rank)
def ranked_model_path(self, rank: int) -> Union[Path, str]:
return Path(self.model_path) / self.ranked_model_name(rank)
def load_model(self, rank: int = None, cache_model: bool = True) -> HfModelHandler:
return HfModelHandler(
model_path=self.ranked_model_path(rank),
task=self.task,
load_kwargs=self.load_kwargs,
io_config=self.io_config,
model_attributes=self.model_attributes,
generative=self.generative,
)
def prepare_session(
self,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.GPU, # pylint: disable=signature-differs
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = 0,
) -> "torch.nn.Module":
return self.load_model(rank).load_model(rank).eval()
def run_session(
self,
session: Any = None,
inputs: Union[Dict[str, Any], List[Any], Tuple[Any, ...]] = None,
**kwargs: Dict[str, Any],
) -> Any:
if isinstance(inputs, dict):
results = session.generate(**inputs, **kwargs) if self.generative else session(**inputs, **kwargs)
else:
results = session.generate(inputs, **kwargs) if self.generative else session(inputs, **kwargs)
return results