Exporting Models to ONNX#

Exporting a pre-trained model to ONNX converts it into a common format that can be easily integrated and deployed across different platforms. The conversion is done with a tool or library that translates the model's architecture, weights, and configuration. This allows the model to be used in a variety of settings, such as edge devices, cloud services, and web-based systems, with improved compatibility and performance.

Loading the Model#

The first step is to load a pre-trained NLP model. In this notebook, we use a pre-trained GPT-2 model from the Hugging Face Hub.

[1]:
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

Exporting to ONNX#

After the model has been loaded, we call Archai's export_to_onnx() method, which wraps the inner computation of an ONNX export. It also supports a set of arguments that can be adjusted to the input model and task, such as:

  • task: Task identifier, used to define the proper inputs and outputs.

  • use_past: Whether to include past key/value states in the model.

  • validate: Whether to validate the exported model against the reference model.

  • share_weights: Whether to share the embedding and softmax weights.

  • opset: ONNX opset version used for the export.

  • atol: Absolute tolerance allowed between the reference and exported models' outputs.

[2]:
from archai.common.file_utils import calculate_onnx_model_size
from archai.onnx.export import export_to_onnx

onnx_model_path = "model.onnx"
onnx_config = export_to_onnx(
    model,
    onnx_model_path,
    task="causal-lm",
    use_past=True,
    share_weights=True,
    opset=11,
    atol=1e-4,
)
print(f"Model: {calculate_onnx_model_size(onnx_model_path)}MB")
2023-03-21 15:16:14,303 - archai.onnx.export — INFO —  Exporting model: model.onnx
c:\Users\gderosa\Anaconda3\envs\archai\lib\site-packages\transformers\models\gpt2\modeling_gpt2.py:318: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  past_key, past_value = layer_past
2023-03-21 15:16:28,808 - archai.onnx.export — INFO —  Validating model ...
2023-03-21 15:16:28,808 - archai.onnx.onnx_loader — INFO —  Loading model: model.onnx
2023-03-21 15:16:30,917 - archai.onnx.export — DEBUG —  Matched outputs: {'present_0', 'present_9', 'present_5', 'present_7', 'present_6', 'present_10', 'present_3', 'present_4', 'present_8', 'present_11', 'present_2', 'present_1', 'probs'}
2023-03-21 15:16:30,925 - archai.onnx.export — DEBUG —  Validating output: probs
2023-03-21 15:16:30,927 - archai.onnx.export — DEBUG —  Matched shape: (2, 50257) (ONNX) and (2, 50257) (reference)
2023-03-21 15:16:30,933 - archai.onnx.export — DEBUG —  Matched difference: 1.0133e-06 < 0.0001
2023-03-21 15:16:30,933 - archai.onnx.export — DEBUG —  Validating output: present_0
2023-03-21 15:16:30,935 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,935 - archai.onnx.export — DEBUG —  Matched difference: 2.8610e-06 < 0.0001
2023-03-21 15:16:30,941 - archai.onnx.export — DEBUG —  Validating output: present_1
2023-03-21 15:16:30,942 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,943 - archai.onnx.export — DEBUG —  Matched difference: 4.8876e-06 < 0.0001
2023-03-21 15:16:30,944 - archai.onnx.export — DEBUG —  Validating output: present_2
2023-03-21 15:16:30,944 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,948 - archai.onnx.export — DEBUG —  Matched difference: 4.2915e-06 < 0.0001
2023-03-21 15:16:30,950 - archai.onnx.export — DEBUG —  Validating output: present_3
2023-03-21 15:16:30,950 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,954 - archai.onnx.export — DEBUG —  Matched difference: 1.2398e-05 < 0.0001
2023-03-21 15:16:30,954 - archai.onnx.export — DEBUG —  Validating output: present_4
2023-03-21 15:16:30,958 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,959 - archai.onnx.export — DEBUG —  Matched difference: 1.3351e-05 < 0.0001
2023-03-21 15:16:30,959 - archai.onnx.export — DEBUG —  Validating output: present_5
2023-03-21 15:16:30,966 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,969 - archai.onnx.export — DEBUG —  Matched difference: 7.6294e-06 < 0.0001
2023-03-21 15:16:30,969 - archai.onnx.export — DEBUG —  Validating output: present_6
2023-03-21 15:16:30,969 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,975 - archai.onnx.export — DEBUG —  Matched difference: 1.3351e-05 < 0.0001
2023-03-21 15:16:30,979 - archai.onnx.export — DEBUG —  Validating output: present_7
2023-03-21 15:16:30,979 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,982 - archai.onnx.export — DEBUG —  Matched difference: 7.6294e-06 < 0.0001
2023-03-21 15:16:30,982 - archai.onnx.export — DEBUG —  Validating output: present_8
2023-03-21 15:16:30,982 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,988 - archai.onnx.export — DEBUG —  Matched difference: 9.0599e-06 < 0.0001
2023-03-21 15:16:30,991 - archai.onnx.export — DEBUG —  Validating output: present_9
2023-03-21 15:16:30,991 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:30,996 - archai.onnx.export — DEBUG —  Matched difference: 8.5831e-06 < 0.0001
2023-03-21 15:16:30,999 - archai.onnx.export — DEBUG —  Validating output: present_10
2023-03-21 15:16:31,001 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:31,003 - archai.onnx.export — DEBUG —  Matched difference: 7.1526e-06 < 0.0001
2023-03-21 15:16:31,005 - archai.onnx.export — DEBUG —  Validating output: present_11
2023-03-21 15:16:31,005 - archai.onnx.export — DEBUG —  Matched shape: (2, 2, 12, 16, 64) (ONNX) and (2, 2, 12, 16, 64) (reference)
2023-03-21 15:16:31,008 - archai.onnx.export — DEBUG —  Matched difference: 6.9141e-06 < 0.0001
Model: 499.167738MB
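With the export complete, the model can be loaded back with ONNX Runtime for a quick sanity check. The snippet below is a minimal sketch, not part of Archai: the `probs` output shape follows the validation logs above, and the graph's input names are inspected with `get_inputs()` rather than hard-coded, since they depend on the export configuration.

```python
import os

import numpy as np


def greedy_next_token(probs: np.ndarray) -> np.ndarray:
    """Greedy decoding: index of the highest-probability token per batch row."""
    return probs.argmax(axis=-1)


# Loading the exported model requires onnxruntime and the file produced above.
if os.path.exists("model.onnx"):
    import onnxruntime as ort

    session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    # Inspect the graph's actual input names instead of guessing them.
    for graph_input in session.get_inputs():
        print(graph_input.name, graph_input.shape)
```

Feeding the session's `probs` output (shape `(batch, vocab)`, per the logs) to `greedy_next_token` yields the next-token ids for a simple decoding loop.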

Post-Export Optimization#

For Transformer-based models, ONNX Runtime offers a set of post-export optimization tools that enable node fusion and hence a more efficient graph. We call optimize_onnx(), passing the path of the previously exported ONNX model.

The print statement compares the models' sizes, but it is highly recommended to inspect the graph with an external tool, such as Netron.

[3]:
from archai.onnx.optimization import optimize_onnx

ort_model_path = optimize_onnx(onnx_model_path, onnx_config, opt_level=1)
print(f"Model-OPT: {calculate_onnx_model_size(ort_model_path)}MB")
2023-03-21 15:16:32,958 - archai.onnx.optimization — INFO —  Optimizing model: model.onnx
Model-OPT: 498.940725MB
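File size is only a proxy: latency is what node fusion actually targets. The timing helper below is an illustrative sketch (not an Archai API) that could compare the exported and optimized models once both are wrapped in ONNX Runtime sessions.

```python
import time


def mean_latency(fn, n_runs: int = 10) -> float:
    """Average wall-clock time of `fn` over `n_runs` calls, in seconds."""
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    return (time.perf_counter() - start) / n_runs


# Sketch: wrap each session's `run` call and compare the averages, e.g.:
#   mean_latency(lambda: session.run(None, inputs))      # exported model
#   mean_latency(lambda: opt_session.run(None, inputs))  # optimized model
```

Averaging over several runs smooths out warm-up effects and scheduling noise, which dominate single-call measurements.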

Post-Training Quantization (PTQ)#

Finally, either the exported or the post-optimized model can be dynamically quantized using the dynamic_quantization_onnx() method.

Note, however, that if the model has not been pre-trained with Quantization-Aware Training (QAT), dynamic quantization may produce different logits and degrade the model's performance.

[4]:
from archai.quantization.ptq import dynamic_quantization_onnx

qnt_model_path = dynamic_quantization_onnx(ort_model_path)
print(f"Model-QNT: {calculate_onnx_model_size(qnt_model_path)}MB")
2023-03-21 15:16:49,535 - archai.quantization.ptq — INFO —  Quantizing model: model-opt.onnx
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.0/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.0/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.0/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.1/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.1/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.1/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.2/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.2/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.2/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.3/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.3/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.3/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.4/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.4/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.4/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.5/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.5/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.5/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.6/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.6/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.6/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.7/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.7/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.7/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.8/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.8/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.8/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.9/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.9/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.9/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.10/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.10/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.10/attn/MatMul_1]
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/Reshape_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/Reshape_1_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/Reshape_2_output_0. Please add data type info for this tensor if your model has customized operators.
WARNING:root:Failed to infer data type of tensor: /transformer/h.11/attn/MatMul_1_output_0. Please add data type info for this tensor if your model has customized operators.
Ignore MatMul due to non constant B: /[/transformer/h.11/attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/h.11/attn/MatMul_1]
Model-QNT: 126.068565MB
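One way to gauge the degradation mentioned above is to run the reference and quantized models on identical inputs and compare their logits. The helper below is an illustrative sketch; the session setup is assumed to mirror the earlier cells.

```python
import numpy as np


def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Largest element-wise deviation between two output tensors."""
    return float(np.max(np.abs(a - b)))


# Sketch: run sessions over the exported and quantized models with the
# same inputs, then compare their `probs` outputs, e.g.:
#   diff = max_abs_diff(reference_probs, quantized_probs)
# A large diff signals that PTQ has shifted the logits noticeably.
```

Comparing logits on a handful of representative inputs is a cheap first check before running a full downstream evaluation of the quantized model.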