# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from collections import OrderedDict
from typing import Any, Mapping, Optional, Tuple
import torch
from overrides import overrides
from overrides.enforce import EnforceOverrides
from transformers.configuration_utils import PretrainedConfig
class OnnxConfig(EnforceOverrides):
    """Base ONNX configuration.

    This class defines a base ONNX configuration for a specific task, which includes the
    input and output structure required for ONNX models, as well as additional properties
    and methods for handling ONNX Runtime (ORT) graph optimization.

    """

    # Supported tasks mapped to their output specification:
    # output name -> {axis index: symbolic dimension name}.
    DEFAULT_TASK_OUTPUTS = {"causal-lm": OrderedDict({"probs": {0: "batch_size"}})}

    def __init__(
        self,
        config: PretrainedConfig,
        task: Optional[str] = "causal-lm",
    ) -> None:
        """Initialize the ONNX configuration by verifying whether the
        specified `task` is supported.

        Args:
            config: Configuration of the model being exported.
            task: Type of task that the exported model will be used for.

        Raises:
            AssertionError: If `task` is not a key of `DEFAULT_TASK_OUTPUTS`.

        """

        # Explicit check instead of a bare `assert`, so validation still happens under
        # `python -O`. `AssertionError` is kept for callers that already catch it.
        if task not in self.DEFAULT_TASK_OUTPUTS:
            raise AssertionError(f"`task`: {task} is not supported yet.")

        self.config = config
        self.task = task

    @property
    def is_ort_graph_optimizable(self) -> bool:
        """Return whether configuration supports additional graph optimization.

        The base configuration always returns `False`.

        """

        return False

    @property
    def ort_graph_optimizer_args(self) -> Optional[Tuple[Any, ...]]:
        """Return additional arguments used by the ORT graph optimizer.

        The base configuration is not graph-optimizable, so it has no arguments and
        returns `None`. Configurations that enable optimization presumably override this.

        """

        return None

    def get_outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Get the ONNX-based outputs structure.

        Returns:
            ONNX-based outputs for the configured task.

        """

        # Deep copy so callers (e.g. subclasses adding extra outputs) can mutate the
        # result without altering the shared class-level `DEFAULT_TASK_OUTPUTS`.
        return copy.deepcopy(self.DEFAULT_TASK_OUTPUTS[self.task])
class OnnxConfigWithPast(OnnxConfig):
    """ONNX configuration that can export past key/values.

    Extends `OnnxConfig` so that the exported ONNX model may consume and produce
    cached past key/values (the model configuration's `use_cache` mode).

    """

    def __init__(
        self,
        config: PretrainedConfig,
        task: Optional[str] = "causal-lm",
        use_past: Optional[bool] = False,
        past_key_values: Optional[int] = 2,
    ) -> None:
        """Set up the configuration and toggle caching on the model configuration.

        Note that `config` is modified in place: `use_cache` is always written, and
        `past_key_values` is written only when `use_past` is truthy.

        Args:
            config: Model's configuration.
            task: Type of task that the exported model will be used for.
            use_past: Whether past key/values should be used.
            past_key_values: Number of past-related entries (2 for key and value).

        """

        super().__init__(config, task=task)

        cache_enabled = bool(use_past)
        self.config.use_cache = cache_enabled
        if cache_enabled:
            self.config.past_key_values = past_key_values

        self.use_past = use_past

    def _required_config_attr(self, attr_name: str) -> Any:
        """Fetch `attr_name` from the model configuration, or fail with a hint to override."""

        if hasattr(self.config, attr_name):
            return getattr(self.config, attr_name)
        raise AttributeError(f"Please override `{attr_name}` with correct attribute.")

    @property
    def hidden_size(self) -> int:
        """Dimensionality of hidden units, read from `config.hidden_size`.

        Raises `AttributeError` when the configuration lacks that attribute; override
        this property for configurations that name it differently.

        """

        return self._required_config_attr("hidden_size")

    @property
    def num_layers(self) -> int:
        """Number of layers, read from `config.num_layers`.

        Raises `AttributeError` when the configuration lacks that attribute; override
        this property for configurations that name it differently.

        """

        return self._required_config_attr("num_layers")

    @property
    def num_attention_heads(self) -> int:
        """Number of attention heads, read from `config.num_attention_heads`.

        Raises `AttributeError` when the configuration lacks that attribute; override
        this property for configurations that name it differently.

        """

        return self._required_config_attr("num_attention_heads")

    @overrides
    def get_outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Get the task outputs, plus one `present_{i}` entry per layer when past is used."""

        outputs = super().get_outputs()
        if not self.use_past:
            return outputs

        # Each present tensor: [past_key_values, batch_size, n_head, total_seq_len, d_head],
        # with total_seq_len = seq_len + past_seq_len; axes 1 and 3 are symbolic.
        outputs.update(
            (f"present_{layer_idx}", {1: "batch_size", 3: "total_seq_len"})
            for layer_idx in range(self.num_layers)
        )
        return outputs