Exporting Models to ONNX
Exporting a pre-trained model to ONNX converts it into a common format that can be integrated and deployed across different platforms. An export tool or library serializes the model’s architecture, weights, and configuration. The exported model can then be used in a variety of settings, such as edge devices, cloud services, and web-based systems, with broader runtime compatibility and, often, better inference performance.
Loading the Model
The first step is to load an NLP model. In this notebook, we will use a pre-trained GPT-2 model from the Hugging Face Hub.
[1]:
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained("gpt2")
Exporting to ONNX
After the model has been loaded, we call Archai’s export_to_onnx() method, which wraps all the inner computation of an ONNX export. Additionally, it supports a set of arguments that can be defined according to the input model and task, such as:
task: Task identifier to use proper inputs/outputs.
use_past: Whether to include past key/values in the model.
validate: Whether to validate the exported model.
share_weights: Whether to share the embedding and softmax weights.
opset: Set of operations to use with ONNX.
atol: Tolerance between input and exported model.
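To make the role of validate and atol more concrete, the following is a minimal sketch of a tolerance check. It assumes the comparison is a maximum absolute difference between the reference (PyTorch) outputs and the exported (ONNX) outputs; Archai's actual implementation may differ. The helper names and sample values are hypothetical.

```python
def max_abs_diff(reference, exported):
    """Largest element-wise absolute difference between two flat sequences."""
    if len(reference) != len(exported):
        raise ValueError("outputs must have the same number of elements")
    return max(abs(r - e) for r, e in zip(reference, exported))


def outputs_match(reference, exported, atol=1e-4):
    """True when every element differs by less than `atol` (assumed rule)."""
    return max_abs_diff(reference, exported) < atol


# Hypothetical values standing in for PyTorch and ONNX Runtime outputs.
reference_probs = [0.125, 0.250, 0.625]
exported_probs = [0.125001, 0.249999, 0.625002]

# Differences around 1e-6 pass a 1e-4 tolerance; a 1e-3 difference does not.
print(outputs_match(reference_probs, exported_probs, atol=1e-4))
print(outputs_match(reference_probs, [0.126, 0.250, 0.624], atol=1e-4))
```

A looser atol accepts more numerical drift between the two runtimes, while a tighter one may reject an export that is functionally fine.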
[2]:
from archai.common.file_utils import calculate_onnx_model_size
from archai.onnx.export import export_to_onnx
onnx_model_path = "model.onnx"
onnx_config = export_to_onnx(
    model,
    onnx_model_path,
    task="causal-lm",
    use_past=True,
    share_weights=True,
    opset=11,
    atol=1e-4,
)
print(f"Model: {calculate_onnx_model_size(onnx_model_path)}MB")
2023-03-21 15:16:14,303 - archai.onnx.export — INFO — Exporting model: model.onnx
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\transformers\models\gpt2\modeling_gpt2.py:318: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
past_key, past_value = layer_past
2023-03-21 15:16:28,808 - archai.onnx.export — INFO — Validating model ...
2023-03-21 15:16:28,808 - archai.onnx.onnx_loader — INFO — Loading model: model.onnx
2023-03-21 15:16:30,917 - archai.onnx.export — DEBUG — Matched outputs: {'present_0', 'present_9', 'present_5', 'present_7', 'present_6', 'present_10', 'present_3', 'present_4', 'present_8', 'present_11', 'present_2', 'present_1', 'probs'}
2023-03-21 15:16:30,925 - archai.onnx.export — DEBUG — Validating output: probs
2023-03-21 15:16:30,927 - archai.onnx.export — DEBUG — Matched shape: (2, 50257) (ONNX) and (2, 50257) (reference)
2023-03-21 15:16:30,933 - archai.onnx.export — DEBUG — Matched difference: 1.0133e-06 < 0.0001
2023-03-21 15:16:30,933 - archai.onnx.export — DEBUG — Validating output: present_0
2023-03-21 15:16:30,935 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,935 - archai.onnx.export — DEBUG — Matched difference: 2.8610e-06 < 0.0001
2023-03-21 15:16:30,941 - archai.onnx.export — DEBUG — Validating output: present_1
2023-03-21 15:16:30,942 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,943 - archai.onnx.export — DEBUG — Matched difference: 4.8876e-06 < 0.0001
2023-03-21 15:16:30,944 - archai.onnx.export — DEBUG — Validating output: present_2
2023-03-21 15:16:30,944 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,948 - archai.onnx.export — DEBUG — Matched difference: 4.2915e-06 < 0.0001
2023-03-21 15:16:30,950 - archai.onnx.export — DEBUG — Validating output: present_3
2023-03-21 15:16:30,950 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,954 - archai.onnx.export — DEBUG — Matched difference: 1.2398e-05 < 0.0001
2023-03-21 15:16:30,954 - archai.onnx.export — DEBUG — Validating output: present_4
2023-03-21 15:16:30,958 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,959 - archai.onnx.export — DEBUG — Matched difference: 1.3351e-05 < 0.0001
2023-03-21 15:16:30,959 - archai.onnx.export — DEBUG — Validating output: present_5
2023-03-21 15:16:30,966 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,969 - archai.onnx.export — DEBUG — Matched difference: 7.6294e-06 < 0.0001
2023-03-21 15:16:30,969 - archai.onnx.export — DEBUG — Validating output: present_6
2023-03-21 15:16:30,969 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,975 - archai.onnx.export — DEBUG — Matched difference: 1.3351e-05 < 0.0001
2023-03-21 15:16:30,979 - archai.onnx.export — DEBUG — Validating output: present_7
2023-03-21 15:16:30,979 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,982 - archai.onnx.export — DEBUG — Matched difference: 7.6294e-06 < 0.0001
2023-03-21 15:16:30,982 - archai.onnx.export — DEBUG — Validating output: present_8
2023-03-21 15:16:30,982 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,988 - archai.onnx.export — DEBUG — Matched difference: 9.0599e-06 < 0.0001
2023-03-21 15:16:30,991 - archai.onnx.export — DEBUG — Validating output: present_9
2023-03-21 15:16:30,991 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,996 - archai.onnx.export — DEBUG — Matched difference: 8.5831e-06 < 0.0001
2023-03-21 15:16:30,999 - archai.onnx.export — DEBUG — Validating output: present_10
2023-03-21 15:16:31,001 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:31,003 - archai.onnx.export — DEBUG — Matched difference: 7.1526e-06 < 0.0001
2023-03-21 15:16:31,005 - archai.onnx.export — DEBUG — Validating output: present_11
2023-03-21 15:16:31,005 - archai.onnx.export — DEBUG — Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:31,008 - archai.onnx.export — DEBUG — Matched difference: 6.9141e-06 < 0.0001
Model: 499.167738MB
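As a sanity check on the numbers above, the reported size and the present_* shapes can be roughly reproduced from GPT-2's published hyperparameters. This is a back-of-the-envelope sketch, assuming float32 weights, 1 MB = 10^6 bytes, tied embedding and output weights, and a (key/value, batch, heads, sequence, head_dim) layout for the cached tensors. It counts weights only, so the graph structure and constants stored in the file are not included.

```python
# GPT-2 (small) hyperparameters, as published in its Hugging Face config.
VOCAB_SIZE = 50257
N_POSITIONS = 1024
N_EMBD = 768
N_LAYER = 12
N_HEAD = 12


def gpt2_parameter_count(vocab, positions, embd, layers):
    """Count parameters with the output projection tied to the token embedding."""
    embeddings = vocab * embd + positions * embd
    attention = (embd * 3 * embd + 3 * embd) + (embd * embd + embd)
    mlp = (embd * 4 * embd + 4 * embd) + (4 * embd * embd + embd)
    layer_norms = 2 * (2 * embd)  # two LayerNorms per block, weight and bias
    final_layer_norm = 2 * embd
    return embeddings + layers * (attention + mlp + layer_norms) + final_layer_norm


params = gpt2_parameter_count(VOCAB_SIZE, N_POSITIONS, N_EMBD, N_LAYER)
size_mb = params * 4 / 1e6  # 4 bytes per float32 weight, 1 MB = 10**6 bytes

# Assumed layout of one `present_*` output.
batch_size, seq_len = 2, 16  # values read from the validation log above
present_shape = (2, batch_size, N_HEAD, seq_len, N_EMBD // N_HEAD)

print(f"parameters: {params:,}")
print(f"float32 weights: {size_mb:.2f} MB")
print(f"present shape: {present_shape}")
```

The estimate lands slightly below the reported file size, which is consistent with the exported file carrying some overhead beyond the raw weights.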
Post-Export Optimization
For Transformer-based models, ONNX Runtime offers a set of post-export optimization tools that enable node fusion and, hence, a more optimized graph. We can therefore call optimize_onnx(), passing the path of the previously exported ONNX model.
The printed sizes allow a comparison between the models, but it is highly recommended to also inspect the graph with an external tool, such as Netron.
[3]:
from archai.onnx.optimization import optimize_onnx
ort_model_path = optimize_onnx(onnx_model_path, onnx_config, opt_level=1)
print(f"Model-OPT: {calculate_onnx_model_size(ort_model_path)}MB")
2023-03-21 15:16:32,958 - archai.onnx.optimization — INFO — Optimizing model: model.onnx
Model-OPT: 498.940725MB
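The size barely changes because fusion rewrites the graph's operators rather than its weights. The idea can be illustrated with a toy sketch that treats a graph as a flat list of operator names and merges each MatMul followed by an Add into a single node. Real ONNX graphs are directed acyclic graphs with named tensors, and ONNX Runtime's fusion passes are considerably more involved; the fused operator name below is hypothetical.

```python
def fuse_matmul_add(nodes):
    """Replace each MatMul immediately followed by Add with one fused node."""
    fused = []
    i = 0
    while i < len(nodes):
        if nodes[i] == "MatMul" and i + 1 < len(nodes) and nodes[i + 1] == "Add":
            fused.append("FusedMatMulAdd")  # hypothetical op name for illustration
            i += 2
        else:
            fused.append(nodes[i])
            i += 1
    return fused


# A hypothetical, linearized slice of a Transformer block.
graph = ["LayerNorm", "MatMul", "Add", "Gelu", "MatMul", "Add"]
optimized = fuse_matmul_add(graph)

print(len(graph), "->", len(optimized), "nodes")
print(optimized)
```

Fewer nodes means fewer kernel launches and intermediate tensors at inference time, even though the stored parameters are the same.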
Post-Training Quantization (PTQ)
Finally, either the exported or the post-optimized model can be dynamically quantized using the dynamic_quantization_onnx() method.
Please note, however, that if the model has not been trained with Quantization Aware Training (QAT), the quantized model might produce different logits and perform worse.
[4]:
from archai.quantization.ptq import dynamic_quantization_onnx
qnt_model_path = dynamic_quantization_onnx(ort_model_path)
print(f"Model-QNT: {calculate_onnx_model_size(qnt_model_path)}MB")
2023-03-21 15:16:49,535 - archai.quantization.ptq — INFO — Quantizing model: model-opt.onnx
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.0/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.0/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.1/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.1/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.2/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.2/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.3/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.3/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.4/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.4/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.5/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.5/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.6/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.6/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.7/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.7/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.8/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.8/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.9/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.9/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.10/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.10/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.11/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.11/attn/MatMul_1]
Model-QNT: 126.068565MB
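The roughly fourfold size reduction is what one would expect when 32-bit floating-point weights are stored as 8-bit integers. The sketch below shows the basic idea with a symmetric, per-tensor int8 scheme chosen for simplicity; it is an assumption for illustration, and the scheme ONNX Runtime applies may differ. The weight values are hypothetical.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization of floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Map int8 codes back to approximate float values."""
    return [q * scale for q in quantized]


weights = [0.52, -1.27, 0.003, 0.9]  # hypothetical float32 weights
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(codes, scale)
print(f"max round-trip error: {max_error:.4f}")

# Size ratio taken from the notebook output: optimized vs. quantized model.
ratio = 498.940725 / 126.068565
print(f"compression: {ratio:.2f}x")
```

The round-trip error is bounded by half the scale, which is why small weights can collapse to zero and why the quantized model's logits can drift from the original's.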