ONNX#
ONNX Forward#
- archai.onnx.onnx_forward.gpt2_onnx_forward(self, input_ids: LongTensor, past_key_values: Tuple[FloatTensor, ...] | None = None) Dict[str, FloatTensor] [source]#
Forward pass through the GPT-2 model with ONNX exportability.
This method overrides the default GPT-2 forward method and returns both output probabilities and past key/values.
- Parameters:
input_ids – Input tensor with token identifiers.
past_key_values – Past pre-computed key/values tensor.
- Returns:
Output probabilities and past key/values.
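The return contract above (output probabilities and past key/values packed into one dictionary) can be illustrated with a toy numpy stand-in. Everything here — `toy_onnx_forward`, the vocabulary size, and the fake key/values — is illustrative, not the library's implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def toy_onnx_forward(input_ids, past_key_values=None):
    # Toy stand-in: map token IDs to one-hot "logits" over a small vocabulary.
    vocab_size = 8
    logits = np.eye(vocab_size)[input_ids % vocab_size]
    # Append this step's fake key/values to whatever was passed in.
    present = (past_key_values or ()) + (logits,)
    # Like gpt2_onnx_forward, return probabilities and past key/values together.
    return {"probs": softmax(logits), "past_key_values": present}

out = toy_onnx_forward(np.array([[1, 2, 3]]))
```

Returning a flat dictionary of tensors (rather than a nested model-output object) is what keeps the traced graph ONNX-exportable.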
ONNX Loader#
- archai.onnx.onnx_loader.load_from_onnx(onnx_model_path: str, providers: List[str] | None = None) InferenceSession [source]#
Load an ONNX-based model from file.
This function loads an ONNX-based model from the specified file path and returns an ONNX inference session. Session-level performance optimization options are also configured.
- Parameters:
onnx_model_path – Path to the ONNX model file.
providers – List of providers to use for inference.
- Returns:
ONNX inference session.
Export#
- archai.onnx.export.validate_onnx_outputs(onnx_config: OnnxConfig, reference_model: Module, onnx_model_path: str, atol: float) None [source]#
Validate the outputs of an ONNX model against a reference PyTorch model.
- Parameters:
onnx_config – Configuration for ONNX model.
reference_model – PyTorch model to use as reference.
onnx_model_path – Path to the ONNX model.
atol – Tolerance value for comparing the model outputs.
- Raises:
ValueError – If the shapes or values of the ONNX model outputs do not match the reference model outputs within the specified tolerance.
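The shape-then-value comparison this validation performs can be sketched in plain numpy; the function name and the output dictionaries are illustrative stand-ins:

```python
import numpy as np

def validate_sketch(reference_outputs, onnx_outputs, atol):
    # Mirror the documented behavior: raise ValueError on any shape
    # mismatch or value difference beyond the absolute tolerance.
    for name, ref in reference_outputs.items():
        out = onnx_outputs[name]
        if ref.shape != out.shape:
            raise ValueError(f"{name}: shape {out.shape} != {ref.shape}")
        if not np.allclose(ref, out, atol=atol):
            raise ValueError(f"{name}: values differ beyond atol={atol}")

# Passes silently: the 1e-5 discrepancy is within atol=1e-4.
validate_sketch({"probs": np.array([0.1, 0.9])},
                {"probs": np.array([0.1, 0.9 + 1e-5])}, atol=1e-4)
```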
- archai.onnx.export.export_to_onnx(model: Module, output_model_path: str, task: str | None = 'causal-lm', use_past: bool | None = True, validate: bool | None = True, share_weights: bool | None = True, opset: int | None = 11, atol: float | None = 0.0001) OnnxConfig [source]#
Export a pre-trained PyTorch model to ONNX format.
- Parameters:
model – Instance of the PyTorch model to be exported.
output_model_path – Path to save the exported ONNX model.
task – Task identifier to use proper inputs/outputs.
use_past – Whether to include past key/values in the model.
validate – Whether to validate the exported model.
share_weights – Whether to share the embedding and softmax weights.
opset – Version of the ONNX operator set to use for export.
atol – Absolute tolerance used when validating the exported model's outputs against the reference model.
- Returns:
ONNX configuration of the model that was exported.
Export (Utilities)#
- archai.onnx.export_utils.prepare_model_for_onnx(model: Module, model_type: str) Module [source]#
Prepare a PyTorch model for ONNX export by modifying the forward function and performing any additional pre-processing steps.
- Parameters:
model – Instance of the model to prepare for ONNX export.
model_type – Type of model.
- Returns:
The prepared PyTorch model, ready for ONNX export.
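The core move — rebinding the instance's forward method to an ONNX-friendly one selected by model type — can be illustrated in plain Python. `FakeModel` and the dispatch table are illustrative stand-ins, not the library's code:

```python
import types

class FakeModel:
    def forward(self, x):
        return x  # original forward returns raw outputs

def onnx_friendly_forward(self, x):
    # Stand-in for an export-friendly forward such as gpt2_onnx_forward.
    return {"probs": x}

def prepare_sketch(model, model_type):
    # Dispatch on model_type and bind the replacement forward to the instance.
    forwards = {"gpt2": onnx_friendly_forward}
    model.forward = types.MethodType(forwards[model_type], model)
    return model

model = prepare_sketch(FakeModel(), "gpt2")
```

Binding with `types.MethodType` patches only this instance, leaving the model class itself untouched.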
Optimization#
- archai.onnx.optimization.optimize_onnx(onnx_model_path: str, onnx_config: OnnxConfig, use_gpu: bool | None = False, opt_level: int | None = 1, only_ort: bool | None = False, float16: bool | None = False, input_int32: bool | None = False) str [source]#
Optimize an ONNX model using a combination of standard ORT-based optimization and additional transformer-based optimization.
- Parameters:
onnx_model_path – Path to the ONNX model to be optimized.
onnx_config – ONNX configuration of model to be optimized.
use_gpu – Whether to use GPU during optimization.
opt_level – Level of optimization.
only_ort – Whether to only apply ORT optimization.
float16 – Whether to convert the graph to float16 precision.
input_int32 – Whether to cast the model inputs to int32.
- Returns:
Path to the optimized ONNX model.
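How the flags combine can be illustrated with a plain-Python staging sketch. The stage names, their ordering, and the output-path convention are hypothetical; only the flag semantics are taken from the parameters documented above:

```python
def optimize_sketch(onnx_model_path, only_ort=False, float16=False):
    # Illustrative staging only -- no real graph work happens here.
    stages = ["ort-graph-optimizations"]          # always applied
    if not only_ort:
        stages.append("transformer-fusions")      # skipped when only_ort=True
    if float16:
        stages.append("float16-conversion")       # presumably applied last
    # Hypothetical naming convention for the optimized model's path.
    return onnx_model_path.replace(".onnx", "-opt.onnx"), stages

opt_path, stages = optimize_sketch("gpt2.onnx", float16=True)
```

As documented, the real function writes the optimized graph to disk and returns its path rather than a stage list.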