Source code for archai.quantization.mixed_qat
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
from typing import Tuple

import torch

from archai.quantization.qat import prepare_with_qat
class MixedQAT(torch.nn.Module):
    """Mixed QAT (Quantization-Aware Training) model, which can be fine-tuned
    using a linear combination of the regular and QAT losses.

    """
    def __init__(self, model: torch.nn.Module, qat_weight: float = 0.2) -> None:
        """Initialize the class by keeping the incoming model and creating a
        QAT-prepared copy that shares its parameters.

        Args:
            model: Instance of the model that will be fine-tuned with Mixed QAT.
            qat_weight: Weight of the QAT-based loss in the linear combination.
                Must be between 0 and 1.

        """

        super().__init__()

        if qat_weight < 0.0 or qat_weight > 1.0:
            raise ValueError(f"qat_weight: {qat_weight} should be between 0 and 1.")

        self.qat_weight = qat_weight
        self.regular_weight = 1.0 - qat_weight

        self.model = model
        self.qat_model = copy.deepcopy(model)
        # Share every weight and bias so that gradient updates to the regular
        # model are reflected in the QAT model (and vice-versa)
        for module, qat_module in zip(self.model.modules(), self.qat_model.modules()):
            if hasattr(qat_module, "weight"):
                qat_module.weight = module.weight
            if hasattr(qat_module, "bias"):
                qat_module.bias = module.bias

        # Add fake quantization to the copy so it trains under quantization effects
        prepare_with_qat(self.qat_model, onnx_compatible=True)

        for param, qat_param in zip(self.model.parameters(), self.qat_model.parameters()):
            assert qat_param is param, "MixedQAT parameters are not fully shared."
    def forward(
        self, input_ids: torch.LongTensor, labels: torch.LongTensor, *args, **kwargs
    ) -> Tuple[torch.Tensor, ...]:
        """Compute the forward pass over both the regular and the QAT model."""

        outputs = self.model(input_ids=input_ids, labels=labels, *args, **kwargs)
        qat_outputs = self.qat_model(input_ids=input_ids, labels=labels, *args, **kwargs)

        # While training, return the linear combination of the regular and QAT losses
        if self.training:
            loss = outputs.loss * self.regular_weight + qat_outputs.loss * self.qat_weight
            return (loss,) + outputs[1:]

        # While evaluating, return the QAT model's outputs (the copy that will be quantized)
        return qat_outputs
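
A minimal usage sketch follows (not part of the original module). It assumes a Hugging Face-style model whose forward(input_ids=..., labels=...) returns an object with a .loss attribute, which is what MixedQAT.forward expects, and that prepare_with_qat supports the model's layer types; the GPT-2 checkpoint, optimizer, and batch shapes below are illustrative choices, not prescribed by the source.

import torch
from transformers import GPT2LMHeadModel

from archai.quantization.mixed_qat import MixedQAT

# Illustrative backbone: any module returning outputs with a `.loss` attribute
model = GPT2LMHeadModel.from_pretrained("gpt2")
mixed_model = MixedQAT(model, qat_weight=0.2)

# Parameters are shared between the two copies, so one optimizer updates both
optimizer = torch.optim.AdamW(mixed_model.parameters(), lr=1e-5)

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))

# Training step: the first output is
# (1 - qat_weight) * regular_loss + qat_weight * qat_loss
mixed_model.train()
loss = mixed_model(input_ids=input_ids, labels=input_ids)[0]
loss.backward()
optimizer.step()

# Evaluation returns the QAT model's outputs, i.e., the copy that will
# eventually be quantized for deployment
mixed_model.eval()
with torch.no_grad():
    qat_outputs = mixed_model(input_ids=input_ids, labels=input_ids)

Because all parameters are shared, a single optimizer over mixed_model.parameters() trains both copies at once, while the fake-quantization observers live only in the QAT copy.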