Source code for archai.quantization.ptq

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Optional

import onnx
import torch
from onnx import onnx_pb as onnx_proto
from onnx.onnx_ml_pb2 import NodeProto
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
from onnxruntime.quantization.operators.base_operator import QuantOperatorBase
from onnxruntime.quantization.quant_utils import attribute_to_kwarg, ms_domain
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization.registry import IntegerOpsRegistry

from archai.common.file_utils import create_file_name_identifier
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.quantization.quantization_utils import rgetattr, rsetattr

logger = OrderedDictLogger(source=__name__)


class GemmQuant(QuantOperatorBase):
    """Quantized version of the Gemm operator."""

    def __init__(self, onnx_quantizer: ONNXQuantizer, onnx_node: NodeProto) -> None:
        """Override initialization method with custom arguments.

        Args:
            onnx_quantizer: An instance of the quantizer itself.
            onnx_node: Node to be quantized.

        """

        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self) -> None:
        """Quantize a Gemm node into QGemm.

        This method replaces the original `Gemm` node with a `QGemm` node, which is
        a quantized version of the `Gemm` operator. It also adds a `Cast` node to
        cast the output of `QGemm` to float, and an `Add` node to sum the remaining
        bias to the `Gemm` output.

        """

        node = self.node
        assert node.op_type == "Gemm"

        # Updates original attributes to current node
        kwargs = {}
        for attribute in node.attribute:
            kwargs.update(attribute_to_kwarg(attribute))
        # `beta` is dropped because the bias is re-added as a separate `Add` node below
        kwargs.pop("beta", None)

        # Adds proper domain and missing attributes
        kwargs["domain"] = ms_domain
        kwargs["transA"] = 0

        # Creates proper inputs for the QGemm node
        (q_names, zp_names, scale_names, nodes) = self.quantizer.quantize_inputs(
            node, [0, 1], reduce_range=True, op_level_per_channel=True
        )

        qgemm_inputs = []
        for q_name, scale_name, zp_name in zip(q_names, scale_names, zp_names):
            qgemm_inputs += [q_name, scale_name, zp_name]

        # Adds a "QGemm" node to replace original Gemm with its quantized version
        qgemm_output = node.output[0] + "_output_quantized"
        qgemm_name = node.name + "_quant" if node.name != "" else ""
        qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output], qgemm_name, **kwargs)
        nodes.append(qgemm_node)

        # Adds a "Cast" node to cast QGemm output to float
        cast_op_output = qgemm_output + "_cast_output"
        cast_node = onnx.helper.make_node(
            "Cast", [qgemm_output], [cast_op_output], qgemm_output + "_cast", to=onnx_proto.TensorProto.FLOAT
        )
        nodes.append(cast_node)

        # Adds an "Add" node to sum the remaining bias to the Gemm output
        # Note: the bias initializer name is hard-coded to a specific model graph
        bias_node = onnx.helper.make_node(
            "Add", [cast_node.output[0], "crit.out_layers_biases.0"], [node.output[0]], qgemm_name + "_output_add"
        )
        nodes.append(bias_node)

        # Adds new nodes to the original quantizer list
        self.quantizer.new_nodes += nodes
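
# After `GemmQuant.quantize` runs, the original full-precision node has been
# rewritten into a small subgraph (a sketch only; tensor names are illustrative,
# not the actual names produced by the quantizer):
#
#   Gemm(A, B, bias)  ->  QGemm(A_q, B_q, ...)  ->  Cast(to: float)  ->  Add(+ bias)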


def add_new_quant_operators() -> None:
    """Add support for new quantization operators by changing internal
    onnxruntime registry dictionaries.

    """

    # Changes the internal `IntegerOpsRegistry`
    # and adds support for new quantization operators
    IntegerOpsRegistry["Gemm"] = GemmQuant
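
# Minimal usage sketch (the file name is hypothetical): register `GemmQuant`
# before quantizing so that `quantize_dynamic` picks it up via `IntegerOpsRegistry`.
# This mirrors the commented-out call inside `dynamic_quantization_onnx` below:
#
#   add_new_quant_operators()
#   qnt_model_path = dynamic_quantization_onnx("model.onnx")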


def dynamic_quantization_onnx(onnx_model_path: str) -> str:
    """Perform dynamic quantization on an ONNX model.

    The quantized model is saved to a new file with "-int8" appended
    to the original file name.

    Args:
        onnx_model_path: Path to the ONNX model to be quantized.

    Returns:
        Path to the dynamic quantized ONNX model.

    """

    logger.info(f"Quantizing model: {onnx_model_path}")

    # Adds new quantization operators
    # For now, we are only adding support for Gemm
    # add_new_quant_operators()

    # Performs the dynamic quantization
    qnt_model_path = create_file_name_identifier(onnx_model_path, "-int8")
    quantize_dynamic(onnx_model_path, qnt_model_path, per_channel=False, reduce_range=False, optimize_model=False)

    return qnt_model_path
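
# Usage sketch, assuming "model.onnx" exists on disk. The quantized file keeps
# the original name with the "-int8" identifier appended (e.g., "model-int8.onnx")
# and can be loaded with a standard onnxruntime session:
#
#   qnt_model_path = dynamic_quantization_onnx("model.onnx")
#   session = onnxruntime.InferenceSession(qnt_model_path)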


def dynamic_quantization_torch(
    model: torch.nn.Module,
    embedding_layers: Optional[List[str]] = ["word_emb", "transformer.wpe", "transformer.wte"],
) -> torch.nn.Module:
    """Perform dynamic quantization on a PyTorch model.

    This function performs dynamic quantization on the input PyTorch model,
    including any specified embedding layers.

    Args:
        model: PyTorch model to be quantized.
        embedding_layers: List of string-based identifiers of embedding layers
            to be quantized.

    Returns:
        Dynamic quantized PyTorch model.

    """

    logger.info("Quantizing model ...")

    # Sets the number of threads
    # Quantized model only uses a maximum of 1 thread
    torch.set_num_threads(1)

    # Performs an initial dynamic quantization
    model_qnt = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, inplace=False)

    # Currently, the code below works as a caveat to quantize the embedding layers
    for layer in embedding_layers:
        # Checks if the supplied embedding layer really exists
        if rgetattr(model_qnt, layer, 0):
            # Sets the appropriate `qconfig` for embedding layers
            attr = layer + ".qconfig"
            rsetattr(model_qnt, attr, torch.quantization.float_qparams_weight_only_qconfig)

    # Prepares the model for quantization and quantizes it
    model_qnt = torch.quantization.prepare(model_qnt, inplace=False)
    model_qnt = torch.quantization.convert(model_qnt, inplace=False)

    return model_qnt
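
# Usage sketch with a toy model. The default `embedding_layers` follow archai's
# transformer attribute names (e.g., "word_emb"); for a model without those
# attributes the lookup simply skips them, so an empty list is passed here to
# be explicit:
#
#   model = torch.nn.Sequential(torch.nn.Linear(8, 8))
#   model_qnt = dynamic_quantization_torch(model, embedding_layers=[])
#   output = model_qnt(torch.randn(1, 8))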