Source code for archai.discrete_search.evaluators.nlp.parameters

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Optional, Type

from overrides import overrides
from torch import nn

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import ModelEvaluator


class NonEmbeddingParamsProxy(ModelEvaluator):
    """Total number of non-embedding parameters."""

    def __init__(self, exclude_cls: Optional[List[Type[nn.Module]]] = None, trainable_only: Optional[bool] = True) -> None:
        """Initialize the evaluator.

        Used as a proxy for the perplexity of decoder-only transformer LMs.

        Args:
            exclude_cls: List of PyTorch module classes to exclude from
                parameter counting. If `None`, defaults to `[torch.nn.Embedding]`.
            trainable_only: Whether only trainable parameters should be counted.

        Reference:
            "LiteTransformerSearch: Training-free Neural Architecture Search
            for Efficient Language Models", Javaheripi et al., 2022.

        """
        # Fall back to excluding embedding layers when no classes are given.
        self.exclude_cls = exclude_cls or [nn.Embedding]
        self.trainable_only = trainable_only
    @overrides
    def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> float:
        # Count all parameters (or only trainable ones, if requested).
        total_params = sum(
            param.numel()
            for param in model.arch.parameters()
            if not self.trainable_only or param.requires_grad
        )

        # Count parameters owned by excluded module classes (embeddings by default).
        embed_params = sum(
            sum(param.numel() for param in module.parameters())
            for module in model.arch.modules()
            if isinstance(module, tuple(self.exclude_cls))
        )

        return total_params - embed_params
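
# --- Usage sketch (not part of the original module) ---
# A minimal example of how the proxy might be called. It assumes the
# `ArchaiModel(arch, archid)` constructor signature and uses a toy
# two-layer module as a stand-in for a real decoder-only LM.
if __name__ == "__main__":
    toy = nn.Sequential(
        nn.Embedding(1000, 64),  # 64,000 embedding weights: excluded by default
        nn.Linear(64, 64),       # 64 * 64 weights + 64 biases = 4,160 parameters
    )

    proxy = NonEmbeddingParamsProxy()
    # Expected output: 4160, since the embedding parameters are subtracted.
    print(proxy.evaluate(ArchaiModel(arch=toy, archid="toy-lm")))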