Source code for archai.quantization.observers

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Tuple

import torch


class OnnxDynamicObserver:
    """Dynamic observer that is compliant with ONNX-based graphs.

    The observer records the global minimum and maximum of the last tensor it
    was called with and derives per-tensor quantization parameters from them.

    It performs symmetric or asymmetric quantization depending on the `dtype`
    provided: `torch.qint8` selects symmetric quantization (zero point fixed
    at 0), while `torch.quint8` selects asymmetric quantization.

    """

    def __init__(self, dtype: torch.dtype) -> None:
        """Initialize the observer and set the quantization bounds.

        Args:
            dtype: Type of quantization operators. This should be either
                `torch.quint8` or `torch.qint8`.

        Raises:
            ValueError: If `dtype` is neither `torch.quint8` nor `torch.qint8`.

        """

        # Validate explicitly (rather than with `assert`) so that the check
        # is still enforced when Python runs with optimizations enabled.
        if dtype == torch.quint8:
            qmin, qmax = 0, 255
        elif dtype == torch.qint8:
            qmin, qmax = -128, 127
        else:
            raise ValueError(
                f"dtype must be torch.quint8 or torch.qint8, got {dtype!r}."
            )

        self.dtype = dtype
        # Smallest positive float32 increment; lower bound for the scale so
        # that it is never zero.
        self.eps = torch.finfo(torch.float32).eps
        self.qmin, self.qmax = qmin, qmax

    def __call__(self, x: torch.Tensor) -> None:
        """Record the minimum and maximum values of `x`.

        The reduction runs over every element of `x`, and the results are
        stored as one-element 1-D tensors in `self.min_val` / `self.max_val`,
        replacing the values from any previous call.

        Args:
            x: Tensor to be observed.

        """

        x = x.detach().float()
        self.min_val, self.max_val = x.min().view(-1), x.max().view(-1)

    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate the quantization parameters.

        The observer must have been called on a tensor beforehand, since this
        method reads `self.min_val` and `self.max_val`.

        Returns:
            Tuple of `(scale, zero_point)`. `scale` is clamped to be at least
            `self.eps` and `zero_point` is an `int64` tensor.

        """

        if self.dtype == torch.qint8:
            # Symmetric: the range is centered on zero, so the scale comes
            # from the largest absolute bound and the zero point is always 0.
            max_abs = torch.max(self.max_val.clamp(min=0), -self.min_val.clamp(max=0))
            scale = (max_abs / self.qmax).clamp(min=self.eps)
            zero_point = torch.zeros_like(scale).to(torch.int64)

            return scale, zero_point

        # Asymmetric: map [min_val, max_val] onto [qmin, qmax].
        scale = (self.max_val - self.min_val) / float(self.qmax - self.qmin)
        scale = scale.clamp(min=self.eps)

        zero_point = self.qmin - torch.round(self.min_val / scale)
        zero_point = zero_point.clamp(min=self.qmin, max=self.qmax).to(torch.int64)

        return scale, zero_point