Admin-Torch documentation
A plug-and-play PyTorch wrapper for Adaptive model initialization (Admin).
For a neural network f with input x and randomly initialized weights w, we describe its stability (output_change_scale) as the expected magnitude of the output change caused by a perturbation δ of the weights, i.e., E[|f(x, w) - f(x, w + δ)|].
In our study, we show that an original n-layer Transformer's output_change_scale is O(n), which destabilizes its training. Admin stabilizes Transformer training by regulating this scale to O(logn) or O(1). We keep O(logn) as the default setting, which can handle most scenarios. If additional stability is needed, set output_change_scale to 'O(1)' instead.
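For example, a residual module that enforces the stricter O(1) scale can be requested directly through as_module() (a minimal sketch; the 12-layer encoder, and hence the 24 residual layers, is only an assumed configuration for illustration):

import admin_torch

# a hypothetical 12-layer Transformer encoder has 2 * 12 = 24 residual layers
residual = admin_torch.as_module(24, output_change_scale='O(1)')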
admin_torch.as_module()
- admin_torch.as_module(num_res_layers, output_change_scale='default', as_parameter=False, embed_dim=None) → admin_torch.admin.OmegaResidual
Calculate initialization for omega and return a residual module with the initialized omega.
- Parameters
  - num_res_layers (int, required) – The total number of residual layers. A typical n-layer Transformer encoder has 2n residual layers.
  - output_change_scale (str, optional, default = 'O(logn)') – The desired output change scale at initialization. Only 'O(n)', 'O(logn)' / 'default', and 'O(1)' are supported.
  - as_parameter (bool, optional, default = False) – Whether to make the rescaler a trainable parameter. Note that when it is trainable, the rescaler is a vector (similar to the weight vector in layer norm), and the embed_dim input is required.
  - embed_dim (int, optional, default = None) – The hidden state dimension of the shortcut connection. This field is required and only used when as_parameter == True.
- Returns
  An OmegaResidual module with the properly initialized omega inside.
- Return type
  admin_torch.OmegaResidual
Example
import torch.nn as nn
import admin_torch

class TransformerEncoderLayer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        num_layer = 2 * cfg.encoder_layers  # number of residual layers
        self.attn = nn.MultiheadAttention(cfg.embed_dim, cfg.num_heads)
        self.residual_attn = admin_torch.as_module(num_layer)
        self.ln_attn = nn.LayerNorm(cfg.embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(cfg.embed_dim, cfg.feedforward_dim),
            nn.ReLU(),
            nn.Linear(cfg.feedforward_dim, cfg.embed_dim),
        )
        self.residual_ffn = admin_torch.as_module(num_layer)
        self.ln_ffn = nn.LayerNorm(cfg.embed_dim)

    def forward(self, x):
        f_x, _ = self.attn(x, x, x)        # self-attention: query, key, value
        x = self.residual_attn(x, f_x)     # omega-rescaled shortcut + attention output
        x = self.ln_attn(x)
        f_x = self.ffn(x)
        x = self.residual_ffn(x, f_x)      # omega-rescaled shortcut + feed-forward output
        x = self.ln_ffn(x)
        return x
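A quick smoke test of the layer above might look like the following (a sketch assuming a hypothetical cfg with the fields used in the example; inputs follow nn.MultiheadAttention's default (seq_len, batch, embed_dim) layout):

import torch
from types import SimpleNamespace

# hypothetical configuration, for illustration only
cfg = SimpleNamespace(encoder_layers=6, embed_dim=512, num_heads=8, feedforward_dim=2048)

layer = TransformerEncoderLayer(cfg)
x = torch.randn(10, 32, cfg.embed_dim)   # (seq_len, batch, embed_dim)
out = layer(x)
print(out.shape)                         # torch.Size([10, 32, 512])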
admin_torch.as_parameter()
- admin_torch.as_parameter(network, parameter_name, num_res_layers, embed_dim, output_change_scale='default') → None
Calculate initialization for omega and register omega as a parameter (trainable).
- Parameters
  - network (torch.nn.Module, required) – The torch.nn.Module that contains the residual network. This is where omega will be registered.
  - parameter_name (str, required) – The name of omega (as a parameter). Omega can be accessed in the network using this name.
  - num_res_layers (int, required) – The total number of residual layers. A typical n-layer Transformer encoder has 2n residual layers.
  - embed_dim (int, required) – The hidden state dimension of the shortcut connection.
  - output_change_scale (str, optional, default = 'O(logn)') – The desired output change scale at initialization. Only 'O(n)', 'O(logn)' / 'default', and 'O(1)' are supported.
- Returns
  No return value. The initialized omega is registered as a parameter within network.
- Return type
  None
Example
import torch.nn as nn
import admin_torch

class TransformerEncoderLayer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        num_layer = 2 * cfg.encoder_layers  # number of residual layers
        self.attn = nn.MultiheadAttention(cfg.embed_dim, cfg.num_heads)
        admin_torch.as_parameter(self, 'attn_omega', num_layer, cfg.embed_dim)
        self.ln_attn = nn.LayerNorm(cfg.embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(cfg.embed_dim, cfg.feedforward_dim),
            nn.ReLU(),
            nn.Linear(cfg.feedforward_dim, cfg.embed_dim),
        )
        admin_torch.as_parameter(self, 'ffn_omega', num_layer, cfg.embed_dim)
        self.ln_ffn = nn.LayerNorm(cfg.embed_dim)

    def forward(self, x):
        f_x, _ = self.attn(x, x, x)        # self-attention: query, key, value
        x = x * self.attn_omega + f_x      # rescale the shortcut with the registered omega
        x = self.ln_attn(x)
        f_x = self.ffn(x)
        x = x * self.ffn_omega + f_x
        x = self.ln_ffn(x)
        return x
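Because as_parameter() registers omega through the module's parameter machinery, it appears in named_parameters() and is updated by the optimizer. A small check, reusing the hypothetical cfg from the earlier sketch and assuming omega is an embed_dim-sized vector as described for as_module(as_parameter=True):

from types import SimpleNamespace

cfg = SimpleNamespace(encoder_layers=6, embed_dim=512, num_heads=8, feedforward_dim=2048)
layer = TransformerEncoderLayer(cfg)                    # the class defined in the example above

print('attn_omega' in dict(layer.named_parameters()))  # True: omega is trainable
print(layer.attn_omega.shape)                          # assumed to be torch.Size([512])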
admin_torch.as_buffer()
- admin_torch.as_buffer(network, buffer_name, num_res_layers, output_change_scale='default') → None
Calculate initialization for omega and register omega as a buffer (not trainable).
- Parameters
  - network (torch.nn.Module, required) – The torch.nn.Module that contains the residual network. This is where omega will be registered.
  - buffer_name (str, required) – The name of omega (as a buffer). Omega can be accessed in the network using this name.
  - num_res_layers (int, required) – The total number of residual layers. A typical n-layer Transformer encoder has 2n residual layers.
  - output_change_scale (str, optional, default = 'O(logn)') – The desired output change scale at initialization. Only 'O(n)', 'O(logn)' / 'default', and 'O(1)' are supported.
- Returns
  No return value. The initialized omega is registered as a buffer within network.
- Return type
  None
Example
import torch.nn as nn
import admin_torch

class TransformerEncoderLayer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        num_layer = 2 * cfg.encoder_layers  # number of residual layers
        self.attn = nn.MultiheadAttention(cfg.embed_dim, cfg.num_heads)
        admin_torch.as_buffer(self, 'attn_omega', num_layer)
        self.ln_attn = nn.LayerNorm(cfg.embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(cfg.embed_dim, cfg.feedforward_dim),
            nn.ReLU(),
            nn.Linear(cfg.feedforward_dim, cfg.embed_dim),
        )
        admin_torch.as_buffer(self, 'ffn_omega', num_layer)
        self.ln_ffn = nn.LayerNorm(cfg.embed_dim)

    def forward(self, x):
        f_x, _ = self.attn(x, x, x)        # self-attention: query, key, value
        x = x * self.attn_omega + f_x      # rescale the shortcut with the registered omega
        x = self.ln_attn(x)
        f_x = self.ffn(x)
        x = x * self.ffn_omega + f_x
        x = self.ln_ffn(x)
        return x
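Unlike a parameter, a buffer receives no gradient updates, but it is still saved and restored with the module's state_dict. A small check (again with a hypothetical cfg, and assuming as_buffer() uses PyTorch's standard buffer registration):

from types import SimpleNamespace

cfg = SimpleNamespace(encoder_layers=6, embed_dim=512, num_heads=8, feedforward_dim=2048)
layer = TransformerEncoderLayer(cfg)                    # the class defined in the example above

print('attn_omega' in dict(layer.named_buffers()))      # True: omega is a non-trainable buffer
print('attn_omega' in dict(layer.named_parameters()))   # False: it is not a parameter
print('attn_omega' in layer.state_dict())               # True: it is checkpointed with the model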
admin_torch.OmegaResidual
- class admin_torch.OmegaResidual(*args: Any, **kwargs: Any)
Residual connection module with shortcut connection rescaling.
- Parameters
  - init_value (float, required) – The initialization value of the shortcut connection rescaler, omega.
  - as_parameter (bool, optional, default = False) – Whether to make the rescaler a trainable parameter. Note that when it is trainable, the rescaler is a vector (similar to the weight vector in layer norm), and the embed_dim input is required.
  - embed_dim (int, optional, default = None) – The hidden state dimension of the shortcut connection. This field is required and only used when as_parameter == True.
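OmegaResidual is normally obtained through as_module() rather than constructed by hand. A minimal sketch of how it is used on its own (the exact value of omega depends on num_res_layers and output_change_scale; the x * omega + f_x behavior is inferred from the as_parameter() example above):

import torch
import admin_torch

residual = admin_torch.as_module(24)   # OmegaResidual initialized for 24 residual layers
x = torch.randn(10, 32, 512)           # shortcut input
f_x = torch.randn(10, 32, 512)         # output of the residual branch
out = residual(x, f_x)                 # expected to compute omega * x + f_x
print(out.shape)                       # torch.Size([10, 32, 512])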