# Source code for archai.onnx.config_utils.codegen_onnx_config

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any, Optional, Tuple

from transformers.configuration_utils import PretrainedConfig

from archai.onnx.config_utils.onnx_config_base import OnnxConfigWithPast


class CodeGenOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for CodeGen models, with optional past key/values."""

    def __init__(
        self,
        config: PretrainedConfig,
        task: Optional[str] = "causal-lm",
        use_past: Optional[bool] = False,
    ) -> None:
        """Set up the configuration by delegating to ``OnnxConfigWithPast``.

        Args:
            config: Hugging Face configuration of the CodeGen model.
            task: Task the exported model targets (defaults to ``"causal-lm"``).
            use_past: Whether past key/values are enabled; forwarded to the base class.

        """

        # `past_key_values=2` is forwarded unchanged to the base class; presumably it
        # counts the cached tensors per layer (one key, one value) — TODO confirm
        # against OnnxConfigWithPast.
        super().__init__(config, past_key_values=2, task=task, use_past=use_past)

    @property
    def is_ort_graph_optimizable(self) -> bool:
        """Return ``False``: CodeGen is not flagged as ORT-graph-optimizable."""

        return False

    @property
    def ort_graph_optimizer_args(self) -> Tuple[Any, ...]:
        """Return ``(num_attention_heads, hidden_size)`` for the ORT graph optimizer.

        Neither property is defined on this class; both are inherited from the
        base configuration.
        """

        heads = self.num_attention_heads
        width = self.hidden_size
        return heads, width

    @property
    def num_layers(self) -> int:
        """Number of transformer layers, read from ``config.n_layer``."""

        layer_count = self.config.n_layer
        return layer_count