Source code for mi_module_zoo.mlp
import torch.nn as nn
from typing import List, Sequence
[docs]def construct_mlp(
input_dim: int,
out_dim: int,
hidden_layer_dims: Sequence[int],
activation_layer: nn.Module = nn.ReLU(),
) -> nn.Sequential:
"""
Construct a multi-linear perceptron (MLP). No non-linearity is applied at the final layer.
:param input_dim: the input dimension of the MLP.
:param out_dim: the input dimension of the MLP.
:param hidden_layer_dims: a list of zero or more integers indicating the dimensions of the
hidden layers.
:param activation_layer: the activation layer used between the input and hidden layers.
:returns: a :class:`nn.Sequential` with the constructed MLP.
"""
layers: List[nn.Module] = []
cur_hidden_dim = input_dim
for hidden_layer_dim in hidden_layer_dims:
layers.append(nn.Linear(cur_hidden_dim, hidden_layer_dim))
layers.append(activation_layer)
cur_hidden_dim = hidden_layer_dim
layers.append(nn.Linear(cur_hidden_dim, out_dim))
return nn.Sequential(*layers)