Source code for archai.quantization.qat

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
from typing import Dict, Optional, Type

import torch

from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.quantization.modules import (
    FakeDynamicQuantConv1d,
    FakeDynamicQuantConv1dForOnnx,
    FakeDynamicQuantLinear,
    FakeDynamicQuantLinearForOnnx,
    FakeQuantEmbedding,
    FakeQuantEmbeddingForOnnx,
)

DYNAMIC_QAT_MODULE_MAP = {
    torch.nn.Conv1d: FakeDynamicQuantConv1d,
    torch.nn.Linear: FakeDynamicQuantLinear,
    torch.nn.Embedding: FakeQuantEmbedding,
}
ONNX_DYNAMIC_QAT_MODULE_MAP = {
    torch.nn.Conv1d: FakeDynamicQuantConv1dForOnnx,
    torch.nn.Linear: FakeDynamicQuantLinearForOnnx,
    torch.nn.Embedding: FakeQuantEmbeddingForOnnx,
}
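
# A minimal sketch of how these maps are consumed (the `layer` variable below is
# purely illustrative): the type of a float-based module is the lookup key and
# the value is the QAT-ready class that replaces it.
#
#   layer = torch.nn.Linear(4, 2)
#   qat_cls = DYNAMIC_QAT_MODULE_MAP[type(layer)]  # FakeDynamicQuantLinear
#   qat_layer = qat_cls.from_float(layer, torch.quantization.get_default_qat_qconfig("qnnpack"))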

try:
    import transformers

    from archai.quantization.nlp.modules import (
        FakeDynamicQuantHFConv1D,
        FakeDynamicQuantHFConv1DForOnnx,
    )

    DYNAMIC_QAT_MODULE_MAP[transformers.modeling_utils.Conv1D] = FakeDynamicQuantHFConv1D
    ONNX_DYNAMIC_QAT_MODULE_MAP[transformers.modeling_utils.Conv1D] = FakeDynamicQuantHFConv1DForOnnx
except ModuleNotFoundError:
    print("`archai.quantization.nlp` is not available. If needed, install it with: pip install archai[nlp]")

logger = OrderedDictLogger(source=__name__)


def qat_to_float_modules(model: torch.nn.Module) -> None:
    """Convert QAT-ready modules to float-based modules.

    This function converts all QAT-ready modules in the input model to float-based
    modules. It does this recursively, so all sub-modules within the input model
    will also be converted if applicable.

    Args:
        model: QAT-ready module to be converted.

    """

    for name in list(model._modules):
        module = model._modules[name]
        if hasattr(module, "to_float"):
            model._modules[name] = module.to_float()
        else:
            qat_to_float_modules(module)
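
# Usage sketch: given a hypothetical `qat_model` produced by `prepare_with_qat`
# (defined below), this restores its float-based modules in place.
#
#   qat_to_float_modules(qat_model)
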
def float_to_qat_modules(
    model: torch.nn.Module,
    module_mapping: Optional[Dict[Type[torch.nn.Module], Type[torch.nn.Module]]] = DYNAMIC_QAT_MODULE_MAP,
    qconfig: Optional[torch.quantization.QConfig] = None,
    **kwargs,
) -> None:
    """Convert float-based modules to QAT-ready modules.

    This function converts all float-based modules in the input model to QAT-ready
    modules using the provided module mapping. It does this recursively, so all
    sub-modules within the input model will also be converted if applicable.
    A quantization configuration can also be supplied.

    Args:
        model: Float-based module to be converted.
        module_mapping: Maps between float-based and QAT-ready modules.
        qconfig: Quantization configuration to be used for the conversion.

    """

    for name in list(model._modules):
        module = model._modules[name]
        if type(module) in module_mapping:
            if not hasattr(module, "qconfig"):
                module.qconfig = qconfig
            model._modules[name] = module_mapping[type(module)].from_float(module, qconfig, **kwargs)
        else:
            float_to_qat_modules(module, module_mapping=module_mapping, qconfig=qconfig, **kwargs)
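
# Usage sketch, assuming the default module mapping; the qconfig shown mirrors
# the one built by `prepare_with_qat` below, and `model` is any float-based module.
#
#   qconfig = torch.quantization.get_default_qat_qconfig("qnnpack")
#   float_to_qat_modules(model, qconfig=qconfig)
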
def prepare_with_qat(
    model: torch.nn.Module,
    inplace: bool = True,
    onnx_compatible: bool = False,
    backend: str = "qnnpack",
    **kwargs,
) -> torch.nn.Module:
    """Prepare a float-based PyTorch model for quantization-aware training (QAT).

    This function prepares the model by inserting QAT-based modules and
    configurations. If `inplace` is `True`, the input model is modified directly;
    otherwise, a deep copy of it is prepared and returned.

    Args:
        model: Float-based PyTorch module to be prepared for QAT.
        inplace: Whether the prepared QAT model should replace the original model.
        onnx_compatible: Whether the prepared QAT model should be compatible with ONNX.
        backend: Quantization backend to be used.

    Returns:
        The input model (or a copy of it), ready for QAT.

    """

    logger.info("Preparing model with QAT ...")

    prepared_model = model if inplace else copy.deepcopy(model)

    qconfig = torch.quantization.get_default_qat_qconfig(backend)
    module_mapping = ONNX_DYNAMIC_QAT_MODULE_MAP if onnx_compatible else DYNAMIC_QAT_MODULE_MAP

    float_to_qat_modules(prepared_model, module_mapping=module_mapping, qconfig=qconfig, **kwargs)

    return prepared_model
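
# End-to-end sketch of the intended QAT flow; the model and the training loop
# are placeholders, not part of this module.
#
#   model = torch.nn.Sequential(torch.nn.Linear(8, 4))
#   model = prepare_with_qat(model, inplace=False)
#   # ... fine-tune `model` so the fake-quantization observers can gather
#   # weight/activation statistics ...
#   qat_to_float_modules(model)  # convert back to float-based modules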