# Source code for archai.onnx.config_utils.gpt2_onnx_config
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Mapping, Optional, Tuple
import torch
from overrides import overrides
from transformers.configuration_utils import PretrainedConfig
from archai.onnx.config_utils.onnx_config_base import OnnxConfig, OnnxConfigWithPast
class GPT2OnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for GPT-2, with optional past key/values support."""

    def __init__(
        self,
        config: PretrainedConfig,
        task: Optional[str] = "causal-lm",
        use_past: Optional[bool] = False,
    ) -> None:
        """Initialize the GPT-2 ONNX configuration.

        Args:
            config: Hugging Face model configuration.
            task: Task the model will be exported for.
            use_past: Whether past key/values should be part of the exported graph.

        """
        # GPT-2 attention caches two tensors (key and value) per layer.
        super().__init__(config, task=task, use_past=use_past, past_key_values=2)

    @property
    def num_layers(self) -> int:
        """Number of transformer layers (GPT-2 names this attribute ``n_layer``)."""
        return self.config.n_layer

    @property
    def is_ort_graph_optimizable(self) -> bool:
        """Whether the exported graph supports ORT graph optimization (always true for GPT-2)."""
        return True

    @property
    def ort_graph_optimizer_args(self) -> Tuple[Any, ...]:
        """Arguments forwarded to the ORT graph optimizer: (num_attention_heads, hidden_size)."""
        return (self.num_attention_heads, self.hidden_size)
class GPT2FlexOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for GPT-2 Flex, with optional past key/values support."""

    def __init__(
        self,
        config: PretrainedConfig,
        task: Optional[str] = "causal-lm",
        use_past: Optional[bool] = False,
    ) -> None:
        """Initialize the GPT-2 Flex ONNX configuration.

        Args:
            config: Hugging Face model configuration.
            task: Task the model will be exported for.
            use_past: Whether past key/values should be part of the exported graph.

        """
        # GPT-2 attention caches two tensors (key and value) per layer.
        super().__init__(config, task=task, use_past=use_past, past_key_values=2)

    @property
    def num_layers(self) -> int:
        """Number of transformer layers (GPT-2 names this attribute ``n_layer``)."""
        return self.config.n_layer

    @property
    def is_ort_graph_optimizable(self) -> bool:
        """Whether ORT graph optimization applies.

        Flex models may use a different head count per layer; the ORT optimizer
        only supports a single head count, so every layer must agree.
        """
        # A uniform (or empty) per-layer head-count sequence collapses to <= 1 distinct value.
        distinct_head_counts = set(self.num_attention_heads)
        return len(distinct_head_counts) <= 1

    @property
    def ort_graph_optimizer_args(self) -> Tuple[Any, ...]:
        """Arguments forwarded to the ORT graph optimizer: (num_attention_heads, hidden_size)."""
        # Head counts are uniform when optimizable, so the first entry is representative.
        uniform_heads = self.num_attention_heads[0]
        return (uniform_heads, self.hidden_size)