Source code for archai.discrete_search.search_spaces.nlp.transformer_flex.models.configuration_gpt2_flex

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any, List, Optional

from transformers.models.gpt2.configuration_gpt2 import GPT2Config


def _map_to_list(variable: Any, size: int) -> List[Any]:
    """Normalize `variable` into a list of length `size`.

    Lists are truncated or padded with their first element to reach `size`;
    any non-list value is repeated `size` times.
    """

    if isinstance(variable, list):
        size_diff = size - len(variable)

        if size_diff < 0:
            return variable[:size]
        elif size_diff == 0:
            return variable
        elif size_diff > 0:
            return variable + [variable[0]] * size_diff

    return [variable] * size
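
An illustrative sketch of the helper's behavior (not part of the original module): a scalar is broadcast, a short list is padded with its first element, and a long list is truncated.

# Illustrative calls of _map_to_list:
_map_to_list(12, 3)           # -> [12, 12, 12]     (scalar broadcast)
_map_to_list([8, 16], 4)      # -> [8, 16, 8, 8]    (padded with the first element)
_map_to_list([1, 2, 3], 2)    # -> [1, 2]           (truncated to size)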


class GPT2FlexConfig(GPT2Config):
    model_type = "gpt2-flex"

    def __init__(self, *args, primer_square: Optional[bool] = False, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.primer_square = primer_square
        if primer_square:
            self.activation_function = "relu"

        self.n_inner = self.n_inner if self.n_inner is not None else 4 * self.n_embd
        self.n_inner = _map_to_list(self.n_inner, self.n_layer)
        self.n_head = _map_to_list(self.n_head, self.n_layer)
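
A minimal usage sketch (not part of the original module), assuming transformers is installed and this module is importable: the per-layer fields n_inner and n_head accept either a single value or a list with one entry per layer.

# Hypothetical example: build a 3-layer config with per-layer head counts.
config = GPT2FlexConfig(
    n_layer=3,
    n_embd=768,
    n_head=[12, 8, 12],   # already per-layer, kept as-is
    n_inner=3072,         # scalar, broadcast to one entry per layer
    primer_square=True,   # sets activation_function to "relu"
)
assert config.n_head == [12, 8, 12]
assert config.n_inner == [3072, 3072, 3072]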