Source code for mi_module_zoo.utils.activation

import torch
from typing import Callable, Dict

try:
    from typing import Final
except ImportError:
    from typing_extensions import Final


def identity(x: torch.Tensor) -> torch.Tensor:
    return x


ACTIVATION_FNS: Final[Dict[str, Callable[[torch.Tensor], torch.Tensor]]] = {
    "celu": torch.nn.functional.celu,
    "elu": torch.nn.functional.elu,
    "gelu": torch.nn.functional.gelu,
    "hardswish": torch.nn.functional.hardswish,
    "hardtanh": torch.nn.functional.hardtanh,
    "identity": identity,
    "leaky_relu": torch.nn.functional.leaky_relu,
    "logsigmoid": torch.nn.functional.logsigmoid,
    "log_sigmoid": torch.nn.functional.logsigmoid,
    "none": identity,
    "relu": torch.relu,
    "relu6": torch.nn.functional.relu6,
    "rrelu": torch.nn.functional.rrelu,
    "sigmoid": torch.sigmoid,
    "selu": torch.nn.functional.selu,
    "silu": torch.nn.functional.silu,
    "softplus": torch.nn.functional.softplus,
    "swish": torch.nn.functional.silu,
    "tanh": torch.tanh,
}


[docs]def get_activation_fn(activation: str) -> Callable[[torch.Tensor], torch.Tensor]: """ Get an activation function by name. :param activation: the name of the activation function. :param activation: a case-insensitive name of the activation function. :returns: an activation function """ activation = activation.lower() if activation not in ACTIVATION_FNS: raise RuntimeError( "Supported activations are `{}`, not {}".format(ACTIVATION_FNS.keys(), activation) ) return ACTIVATION_FNS[activation]