Source code for archai.quantization.quantizers

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional

import torch
from torch._C import dtype
from torch.quantization import MinMaxObserver

from archai.quantization.observers import OnnxDynamicObserver


class FakeDynamicQuant(torch.nn.Module):
    """Fake dynamic quantizer to allow for scale/zero point calculation during
    Quantization-Aware Training.

    This class allows inserting a fake dynamic quantization operator in a PyTorch
    model, in order to calculate scale and zero point values that can be used to
    quantize the model during training. The operator can be customized to use
    different quantization types (quint8 or qint8) and bit widths, and it can be
    made compatible with ONNX.

    Note:
        This module is only meant to be used during training, and should not be
        present in the final, deployed model.

    """

    def __init__(
        self,
        reduce_range: Optional[bool] = True,
        dtype: Optional[dtype] = torch.quint8,
        bits: Optional[int] = 8,
        onnx_compatible: Optional[bool] = False,
    ) -> None:
        """Initialize a customizable fake dynamic quantization operator.

        Args:
            reduce_range: Whether to reduce the range of quantization. This option
                is only supported for 8-bit quantization.
            dtype: Type of quantization operators. Supported values are
                `torch.quint8` and `torch.qint8`.
            bits: Number of bits used in the quantization. Supported values are
                8 and 16.
            onnx_compatible: Whether the quantization should be compatible with ONNX.

        """
        super().__init__()

        self.bits = bits
        self.reduce_range = reduce_range if bits == 8 else False
        self.dtype = dtype
        self.onnx_compatible = onnx_compatible

        assert dtype in (torch.quint8, torch.qint8)
        if dtype == torch.quint8:
            if self.reduce_range:
                self.qmin, self.qmax = 0, 2 ** (bits - 1)
            else:
                self.qmin, self.qmax = 0, 2**bits - 1
        else:
            if self.reduce_range:
                self.qmin, self.qmax = -(2 ** (bits - 2)), 2 ** (bits - 2) - 1
            else:
                self.qmin, self.qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype == torch.float32:
            if self.bits == 8:
                # 8-bit: derive qparams from a standard observer
                # (or the ONNX-compatible one).
                if self.dtype == torch.quint8:
                    qscheme = torch.per_tensor_affine
                else:
                    qscheme = torch.per_tensor_symmetric

                if self.onnx_compatible:
                    observer = OnnxDynamicObserver(dtype=self.dtype)
                else:
                    observer = MinMaxObserver(
                        dtype=self.dtype,
                        qscheme=qscheme,
                        reduce_range=self.reduce_range,
                    )

                observer(x)
                scale, zero_pointer = observer.calculate_qparams()
            else:
                # 16-bit: no stock observer applies, so compute qparams
                # directly from the tensor's min/max range.
                min_val, max_val = x.min(), x.max()

                initial_scale = (max_val - min_val) / float(self.qmax - self.qmin)

                min_zero_pointer = self.qmin - min_val / initial_scale
                max_zero_pointer = self.qmax - max_val / initial_scale
                min_zero_pointer_error = abs(self.qmin) - abs(min_val / initial_scale)
                max_zero_pointer_error = abs(self.qmax) - abs(max_val / initial_scale)

                if min_zero_pointer_error < max_zero_pointer_error:
                    initial_zero_pointer = min_zero_pointer
                else:
                    initial_zero_pointer = max_zero_pointer
                initial_zero_pointer = initial_zero_pointer.round()

                scale, zero_pointer = initial_scale, initial_zero_pointer

            # Prevents `zero_pointer` from being outside the range of the quantized dtype
            if zero_pointer > self.qmax:
                zero_pointer = torch.tensor(self.qmax)
            elif zero_pointer < self.qmin:
                zero_pointer = torch.tensor(self.qmin)

            x = torch.fake_quantize_per_tensor_affine(
                x, float(scale.item()), int(zero_pointer.item()), self.qmin, self.qmax
            )

            self._scale, self._zero_pointer = scale, zero_pointer

        return x
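
A minimal usage sketch (illustrative, not part of this module): passing a float32 tensor through `FakeDynamicQuant` fake-quantizes it for QAT and caches the computed qparams. The tensor values and shapes here are arbitrary.

import torch

from archai.quantization.quantizers import FakeDynamicQuant

# Default settings: 8-bit quint8 with reduce_range enabled.
quantizer = FakeDynamicQuant()

x = torch.randn(4, 16)  # any float32 tensor, e.g. a weight matrix
x_fq = quantizer(x)     # same shape and dtype, values snapped onto the 8-bit grid

# The observer-derived qparams are cached for later inspection or export.
print(quantizer._scale, quantizer._zero_pointer)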
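
Since the class is meant to be inserted into a model, a hedged sketch of how a layer might wire it in during QAT follows; the `FakeQuantLinear` wrapper below is hypothetical, not archai's API.

import torch

from archai.quantization.quantizers import FakeDynamicQuant


class FakeQuantLinear(torch.nn.Linear):
    """Hypothetical Linear layer whose weight passes through FakeDynamicQuant."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features)
        self.weight_fake_quant = FakeDynamicQuant(dtype=torch.qint8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake-quantize the weight on every forward pass so the training loss
        # sees quantization error while gradients still flow through the
        # fake-quantize op's straight-through backward.
        return torch.nn.functional.linear(x, self.weight_fake_quant(self.weight), self.bias)


layer = FakeQuantLinear(16, 8)
out = layer(torch.randn(2, 16))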