Source code for archai.quantization.nlp.modules

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Any, Dict, Optional

import torch
import transformers

from archai.quantization.quantizers import FakeDynamicQuant


class FakeDynamicQuantHFConv1D(transformers.modeling_utils.Conv1D):
    """Translate a huggingface/transformers Conv1D layer into a QAT-ready Conv1D layer."""

    _FLOAT_MODULE = transformers.modeling_utils.Conv1D

    def __init__(
        self,
        *args,
        dynamic_weight: Optional[bool] = True,
        activation_reduce_range: Optional[bool] = True,
        bits: Optional[int] = 8,
        onnx_compatible: Optional[bool] = False,
        qconfig: Optional[Dict[torch.nn.Module, Any]] = None,
        **kwargs,
    ) -> None:
        """Initialize a fake quantized Conv1D layer.

        Args:
            dynamic_weight: Whether to use dynamic weights.
            activation_reduce_range: Whether to reduce the range of activations.
            bits: Number of quantization bits.
            onnx_compatible: Whether quantization is compatible with ONNX.
            qconfig: Quantization configuration.

        """

        super().__init__(*args, **kwargs)

        self.dynamic_weight = dynamic_weight

        if dynamic_weight:
            self.weight_fake_quant = FakeDynamicQuant(
                dtype=torch.qint8,
                reduce_range=False,
                bits=bits,
                onnx_compatible=onnx_compatible,
            )

        self.input_pre_process = FakeDynamicQuant(
            reduce_range=activation_reduce_range,
            bits=bits,
            onnx_compatible=onnx_compatible,
        )

    @property
    def fake_quant_weight(self) -> torch.Tensor:
        """Return a fake quantization over the weight matrix."""

        return self.weight_fake_quant(self.weight)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake-quantize the incoming activations before the matrix multiplication
        x = self.input_pre_process(x)

        size_out = x.size()[:-1] + (self.nf,)

        # Affine transform using the fake-quantized weight matrix
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.fake_quant_weight)
        x = x.view(*size_out)

        return x
    @classmethod
    def from_float(
        cls: FakeDynamicQuantHFConv1D,
        mod: torch.nn.Module,
        qconfig: Optional[Dict[torch.nn.Module, Any]] = None,
        activation_reduce_range: Optional[bool] = True,
        **kwargs,
    ) -> FakeDynamicQuantHFConv1D:
        """Map module from float to QAT-ready.

        Args:
            mod: Module to be mapped.
            qconfig: Quantization configuration.
            activation_reduce_range: Whether to reduce the range of activations.

        Returns:
            QAT-ready module.

        """

        assert type(mod) == cls._FLOAT_MODULE, (
            "qat." + cls.__name__ + ".from_float only works for " + cls._FLOAT_MODULE.__name__
        )

        if not qconfig:
            assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
            assert mod.qconfig, "Input float module must have a valid qconfig"
            qconfig = mod.qconfig

        qat_conv1d = cls(
            mod.nf,
            mod.weight.shape[0],
            activation_reduce_range=activation_reduce_range,
            qconfig=qconfig,
            **kwargs,
        )

        qat_conv1d.weight = mod.weight
        qat_conv1d.bias = mod.bias

        return qat_conv1d
    def to_float(self) -> torch.nn.Module:
        """Map module from QAT-ready to float.

        Returns:
            Float-based module.

        """

        weight = self.weight_fake_quant(self.weight)

        float_conv1d = transformers.modeling_utils.Conv1D(self.nf, self.weight.shape[0])
        float_conv1d.weight = torch.nn.Parameter(weight)
        float_conv1d.bias = self.bias

        return float_conv1d
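
# Illustrative usage sketch (not part of the original module): maps a float
# huggingface/transformers Conv1D to its QAT-ready counterpart and back.
# The layer sizes and the "fbgemm" default QAT qconfig below are assumptions
# made only for this example.
def _example_conv1d_round_trip() -> None:
    float_conv1d = transformers.modeling_utils.Conv1D(768, 768)
    float_conv1d.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

    # Float -> QAT-ready: weights and activations are fake-quantized in forward()
    qat_conv1d = FakeDynamicQuantHFConv1D.from_float(float_conv1d)
    outputs = qat_conv1d(torch.randn(1, 16, 768))
    assert outputs.shape == (1, 16, 768)

    # QAT-ready -> float: returns a plain Conv1D carrying the fake-quantized weights
    recovered_conv1d = qat_conv1d.to_float()
    assert isinstance(recovered_conv1d, transformers.modeling_utils.Conv1D)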
class FakeDynamicQuantHFConv1DForOnnx(FakeDynamicQuantHFConv1D):
    """Allow a QAT-ready huggingface/transformers Conv1D layer to be exported with ONNX."""

    def __init__(self, *args, **kwargs):
        """Initialize a fake quantized Conv1D layer compatible with ONNX."""

        kwargs["activation_reduce_range"] = False
        kwargs["onnx_compatible"] = True

        super().__init__(*args, **kwargs)
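
# Illustrative usage sketch (not part of the original module): the ONNX variant
# is created through the same from_float() path; its constructor only forces
# onnx_compatible=True and activation_reduce_range=False so the fake-quantized
# graph can later be exported with ONNX. Layer sizes and the "fbgemm" qconfig
# are example assumptions.
def _example_onnx_variant() -> None:
    float_conv1d = transformers.modeling_utils.Conv1D(768, 768)
    float_conv1d.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

    qat_conv1d_onnx = FakeDynamicQuantHFConv1DForOnnx.from_float(float_conv1d)
    outputs = qat_conv1d_onnx(torch.randn(1, 16, 768))
    assert outputs.shape == (1, 16, 768)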