'''
Adapted from https://github.com/ctlllll/SGConv
'''
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from einops import rearrange
import opt_einsum as oe
from archai.discrete_search.search_spaces.config import ArchConfig
from ..utils import get_optim_flag
try:
from .fftconv_ import fftconv_func
except ImportError:
fftconv_func = None
optimized = True
if optimized:
contract = oe.contract
else:
contract = torch.einsum
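# Note: `contract` is a drop-in alias for torch.einsum; opt_einsum additionally optimizes
# the contraction order for multi-operand expressions. Illustrative sketch (not part of
# the original code) of the pattern used by TransposedLinear below:
#   x = torch.randn(2, 16, 100)                       # (B, D_in, L)
#   w = torch.randn(32, 16)                           # (D_out, D_in)
#   y = contract('b u ..., v u -> b v ...', x, w)     # -> (2, 32, 100)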
def get_initializer(name, activation=None):
if activation in [ None, 'id', 'identity', 'linear', 'modrelu' ]:
nonlinearity = 'linear'
elif activation in ['relu', 'tanh', 'sigmoid']:
nonlinearity = activation
elif activation in ['gelu', 'swish', 'silu']:
nonlinearity = 'relu' # Close to ReLU so approximate with ReLU's gain
else:
raise NotImplementedError(f"get_initializer: activation {activation} not supported")
if name == 'uniform':
initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity)
elif name == 'normal':
initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity)
elif name == 'xavier':
initializer = torch.nn.init.xavier_normal_
elif name == 'zero':
initializer = partial(torch.nn.init.constant_, val=0)
elif name == 'one':
initializer = partial(torch.nn.init.constant_, val=1)
else:
raise NotImplementedError(f"get_initializer: initializer type {name} not supported")
return initializer
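# Illustrative usage (a sketch, not part of the original API):
#   linear = nn.Linear(64, 64)
#   get_initializer('uniform', activation='gelu')(linear.weight)
# picks Kaiming-uniform with a ReLU gain, since GELU is approximated by ReLU above.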
class modrelu(nn.Module):
def __init__(self, features):
# For now we just support square layers
super(modrelu, self).__init__()
self.features = features
self.b = nn.Parameter(torch.Tensor(self.features))
self.reset_parameters()
    def reset_parameters(self):
self.b.data.uniform_(-0.01, 0.01)
    def forward(self, inputs):
norm = torch.abs(inputs)
biased_norm = norm + self.b
magnitude = nn.functional.relu(biased_norm)
phase = torch.sign(inputs)
return phase * magnitude
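# modrelu computes sign(x) * relu(|x| + b): it shifts the magnitude by a learnable
# per-feature bias while preserving the sign. Minimal sketch (illustrative only):
#   act = modrelu(4)
#   act.b.data.fill_(-0.5)
#   act(torch.tensor([0.3, -2.0, 1.0, -0.1]))   # -> [0., -1.5, 0.5, -0.]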
class Modrelu(modrelu):
    def reset_parameters(self):
self.b.data.uniform_(-0.01, 0.01)
class TransposedLinear(nn.Module):
    """ Linear module applied over the second (channel) dimension.
    Assumes shape (B, D, L), where L can be one or more axes.
    """
def __init__(self, d_input, d_output, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.empty(d_output, d_input))
# nn.Linear default init
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
# nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent
if bias:
self.bias = nn.Parameter(torch.empty(d_output))
bound = 1 / math.sqrt(d_input)
nn.init.uniform_(self.bias, -bound, bound)
setattr(self.bias, "_optim", {"weight_decay": 0.0})
else:
self.bias = 0.0
    def forward(self, x):
num_axis = len(x.shape[2:]) # num_axis in L, for broadcasting bias
y = contract('b u ..., v u -> b v ...', x, self.weight) + \
self.bias.view(-1, *[1]*num_axis)
return y
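# Minimal sketch of TransposedLinear (illustrative): it mixes the channel (second)
# dimension and broadcasts the bias over any number of trailing length axes.
#   tl = TransposedLinear(64, 128)
#   y = tl(torch.randn(4, 64, 1024))            # (B, D_in, L) -> (4, 128, 1024)
# This is equivalent to nn.Linear applied after transposing (B, D, L) -> (B, L, D).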
class TransposedLN(nn.Module):
    """ LayerNorm module over the second (channel) dimension.
    Assumes shape (B, D, L), where L can be one or more axes.
    This is slow; a dedicated CUDA/Triton implementation should provide a substantial end-to-end speedup.
    """
def __init__(self, d, scalar=True):
super().__init__()
self.scalar = scalar
if self.scalar:
self.m = nn.Parameter(torch.zeros(1))
self.s = nn.Parameter(torch.ones(1))
setattr(self.m, "_optim", {"weight_decay": 0.0})
setattr(self.s, "_optim", {"weight_decay": 0.0})
else:
self.ln = nn.LayerNorm(d)
    def forward(self, x):
if self.scalar:
# calc. stats over D dim / channels
s, m = torch.std_mean(x, dim=1, unbiased=False, keepdim=True)
y = (self.s/s) * (x-m+self.m)
else:
# move channel to last axis, apply layer_norm, then move channel back to second axis
_x = self.ln(rearrange(x, 'b d ... -> b ... d'))
y = rearrange(_x, 'b ... d -> b d ...')
return y
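# TransposedLN with scalar=True normalizes each position by statistics over the channel
# dim with a single learnable shift/scale; scalar=False falls back to a standard
# nn.LayerNorm applied after moving channels to the last axis. Sketch (illustrative):
#   ln = TransposedLN(64)
#   y = ln(torch.randn(4, 64, 1024))            # same shape, normalized over dim=1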
def Activation(activation=None, size=None, dim=-1):
if activation in [None, 'id', 'identity', 'linear']:
return nn.Identity()
elif activation == 'tanh':
return nn.Tanh()
elif activation == 'relu':
return nn.ReLU()
elif activation == 'gelu':
return nn.GELU(approximate='none')
elif activation in ['swish', 'silu']:
return nn.SiLU()
elif activation == 'glu':
return nn.GLU(dim=dim)
elif activation == 'sigmoid':
return nn.Sigmoid()
elif activation == 'modrelu':
return Modrelu(size)
elif activation == 'ln':
return TransposedLN(dim)
else:
raise NotImplementedError(
"hidden activation '{}' is not implemented".format(activation))
def LinearActivation(
d_input, d_output, bias=True,
zero_bias_init=False,
transposed=False,
initializer=None,
activation=None,
activate=False, # Apply activation as part of this module
weight_norm=False,
**kwargs,
):
""" Returns a linear nn.Module with control over axes order, initialization, and activation """
# Construct core module
# linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear
linear_cls = TransposedLinear if transposed else nn.Linear
if activation == 'glu':
d_output *= 2
linear = linear_cls(d_input, d_output, bias=bias, **kwargs)
# Initialize weight
if initializer is not None:
get_initializer(initializer, activation)(linear.weight)
# Initialize bias
if bias and zero_bias_init:
nn.init.zeros_(linear.bias)
# Weight norm
if weight_norm:
linear = nn.utils.weight_norm(linear)
if activate and activation is not None:
activation = Activation(activation, d_output,
dim=1 if transposed else -1)
linear = nn.Sequential(linear, activation)
return linear
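# Illustrative sketch of LinearActivation: with activation='glu', the internal output
# width is doubled so that nn.GLU halves it back to d_output.
#   layer = LinearActivation(64, 128, transposed=True, activation='glu', activate=True)
#   y = layer(torch.randn(4, 64, 1024))         # (B, 64, L) -> (4, 128, 1024)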
class Normalization(nn.Module):
def __init__(
self,
d,
transposed=False, # Length dimension is -1 or -2
_name_='layer',
**kwargs
):
super().__init__()
self.transposed = transposed
self._name_ = _name_
if _name_ == 'layer':
self.channel = True # Normalize over channel dimension
if self.transposed:
self.norm = TransposedLN(d, **kwargs)
else:
self.norm = nn.LayerNorm(d, **kwargs)
elif _name_ == 'instance':
self.channel = False
norm_args = {'affine': False, 'track_running_stats': False}
norm_args.update(kwargs)
self.norm = nn.InstanceNorm1d(d, **norm_args) # (True, True) performs very poorly
elif _name_ == 'batch':
self.channel = False
norm_args = {'affine': True, 'track_running_stats': True}
norm_args.update(kwargs)
self.norm = nn.BatchNorm1d(d, **norm_args)
elif _name_ == 'group':
self.channel = False
            self.norm = nn.GroupNorm(1, d, **kwargs)
elif _name_ == 'none':
self.channel = True
self.norm = nn.Identity()
else: raise NotImplementedError
    def forward(self, x):
# Handle higher dimension logic
shape = x.shape
if self.transposed:
x = rearrange(x, 'b d ... -> b d (...)')
else:
            x = rearrange(x, 'b ... d -> b (...) d')
# The cases of LayerNorm / no normalization are automatically handled in all cases
# Instance/Batch Norm work automatically with transposed axes
if self.channel or self.transposed:
x = self.norm(x)
else:
x = x.transpose(-1, -2)
x = self.norm(x)
x = x.transpose(-1, -2)
x = x.view(shape)
return x
    def step(self, x, **kwargs):
assert self._name_ in ["layer", "none"]
if self.transposed: x = x.unsqueeze(-1)
x = self.forward(x)
if self.transposed: x = x.squeeze(-1)
return x
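# Normalization wraps layer/instance/batch/group norm behind a single interface and
# flattens any higher-dimensional input to (B, D, L) or (B, L, D) first. Sketch:
#   norm = Normalization(64, transposed=True, _name_='layer')
#   y = norm(torch.randn(4, 64, 1024))          # TransposedLN over the channel dim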
class GConv(nn.Module):
requires_length = True
def __init__(
self,
d_model,
d_state=64,
l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer than sequence. However, this can be marginally slower if the true length is not a power of 2
channels=1, # maps 1-dim to C-dim
bidirectional=False,
# Arguments for FF
activation='gelu', # activation in between SS and FF
ln=False, # Extra normalization
postact=None, # activation after FF
initializer=None, # initializer on FF
weight_norm=False, # weight normalization on FF
hyper_act=None, # Use a "hypernetwork" multiplication
use_fast_fftconv=False,
dropout=0.0,
transposed=True, # axis ordering (B, L, D) or (B, D, L)
verbose=False,
shift=False,
linear=False,
mode="cat_randn",
# SSM Kernel arguments
**kernel_args,
):
"""
d_state: the dimension of the state, also denoted by N
l_max: the maximum sequence length, also denoted by L
if this is not known at model creation, set l_max=1
channels: can be interpreted as a number of "heads"
bidirectional: bidirectional
dropout: standard dropout argument
transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension]
Other options are all experimental and should not need to be configured
"""
super().__init__()
self.h = d_model
self.n = d_state
self.bidirectional = bidirectional
self.ln = ln
self.channels = channels
self.transposed = transposed
self.shift = shift
self.linear = linear
self.mode = mode
self.l_max = l_max
self.verbose = verbose
self.use_fast_fftconv = use_fast_fftconv
if self.use_fast_fftconv:
assert fftconv_func is not None, 'Need to install fftconv'
assert self.channels == 1, 'channel must be 1 for fast FFTConv'
# optional multiplicative modulation GLU-style
# https://arxiv.org/abs/2002.05202
self.hyper = hyper_act is not None
if self.hyper:
channels *= 2
self.hyper_activation = Activation(hyper_act)
self.D = nn.Parameter(torch.randn(channels, self.h))
if self.bidirectional:
channels *= 2
# Pointwise
if not self.linear:
self.activation = Activation(activation)
dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout
self.dropout = dropout_fn(
dropout) if dropout > 0.0 else nn.Identity()
if self.ln:
self.norm = Normalization(
self.h*self.channels, transposed=transposed)
else:
self.norm = nn.Identity()
# position-wise output transform to mix features
if not self.linear:
self.output_linear = LinearActivation(
self.h*self.channels,
self.h,
transposed=self.transposed,
initializer=initializer,
activation=postact,
activate=True,
weight_norm=weight_norm,
)
self.init_scale = kernel_args.get('init_scale', 0)
self.kernel_dim = kernel_args.get('kernel_dim', 64)
self.num_scales = kernel_args.get(
'n_scales', 1+math.ceil(math.log2(l_max/self.kernel_dim))-self.init_scale)
if self.num_scales is None:
self.num_scales = 1 + \
math.ceil(math.log2(l_max/self.kernel_dim)) - self.init_scale
self.kernel_list = nn.ParameterList()
decay_min = kernel_args.get('decay_min', 2)
decay_max = kernel_args.get('decay_max', 2)
for _ in range(self.num_scales):
if 'randn' in mode:
kernel = nn.Parameter(torch.randn(
channels, self.h, self.kernel_dim))
elif 'cos' in mode:
kernel = nn.Parameter(torch.cat([torch.cos(torch.linspace(0, 2*i*math.pi, self.kernel_dim)).expand(
channels, 1, self.kernel_dim) for i in range(self.h)], dim=1)[:, torch.randperm(self.h), :])
else:
raise ValueError(f"Unknown mode {mode}")
kernel._optim = {
'lr': kernel_args.get('lr', 0.001),
}
self.kernel_list.append(kernel)
if 'learnable' in mode:
self.decay = nn.Parameter(torch.rand(
self.h) * (decay_max - decay_min) + decay_min)
if 'fixed' in mode:
self.decay.requires_grad = False
else:
self.decay._optim = {
'lr': kernel_args.get('lr', 0.001),
}
self.register_buffer('multiplier', torch.tensor(1.0))
else:
self.register_buffer('multiplier', torch.linspace(
decay_min, decay_max, self.h).view(1, -1, 1))
self.register_buffer('kernel_norm', torch.ones(self.channels, self.h, 1))
self.register_buffer('kernel_norm_initialized',
torch.tensor(0, dtype=torch.bool))
    def fft_conv(self, u, k, L):
if self.use_fast_fftconv:
k = rearrange(k, '1 h l -> h l')
dropout_mask = None
# No GeLU after the SSM
# We want output_hbl=True so that y has the same layout as u
y = fftconv_func(u, k, self.D.squeeze(0), dropout_mask, False, False, True)
y = rearrange(rearrange(y, 'b h l -> h b l'), 'h b l -> b h l')
# y = rearrange(y, 'b h l -> b 1 h l')
return y
k_f = torch.fft.rfft(k, n=2*L) # (C H L)
u_f = torch.fft.rfft(u, n=2*L) # (B H L)
# k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L)
y_f = contract('bhl,chl->bchl', u_f, k_f)
y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L)
# Compute D term in state space equation - essentially a skip connection
y = y + contract('bhl,ch->bchl', u, self.D)
# Reshape to flatten channels
return rearrange(y, '... c h l -> ... (c h) l')
    def forward(self, u, return_kernel=False):
"""
u: (B H L) if self.transposed else (B L H)
state: (H N) never needed unless you know what you're doing
Returns: same shape as u
"""
if not self.transposed:
u = u.transpose(-1, -2)
L = u.size(-1)
if self.use_fast_fftconv and L % 2 != 0:
u = F.pad(u, (0, 1))
kernel_list = []
interpolate_mode = 'nearest' if 'nearest' in self.mode else 'linear'
multiplier = self.multiplier
if 'sum' in self.mode:
for i in range(self.num_scales):
kernel = F.pad(
F.interpolate(
self.kernel_list[i],
scale_factor=2**(i+self.init_scale),
mode=interpolate_mode,
),
(0, self.kernel_dim*2**(self.num_scales-1+self.init_scale) -
self.kernel_dim*2**(i+self.init_scale)),
) * multiplier ** (self.num_scales - i - 1)
kernel_list.append(kernel)
k = sum(kernel_list)
elif 'cat' in self.mode:
for i in range(self.num_scales):
kernel = F.interpolate(
self.kernel_list[i],
scale_factor=2**(max(0, i-1)+self.init_scale),
mode=interpolate_mode,
) * multiplier ** (self.num_scales - i - 1)
kernel_list.append(kernel)
k = torch.cat(kernel_list, dim=-1)
else:
raise ValueError(f"Unknown mode {self.mode}")
if 'learnable' in self.mode:
k = k * torch.exp(-self.decay.view(1, -1, 1)*torch.log(
torch.arange(k.size(-1), device=k.device)+1).view(1, 1, -1))
if not self.kernel_norm_initialized:
self.kernel_norm = k.norm(dim=-1, keepdim=True).detach()
self.kernel_norm_initialized = torch.tensor(
1, dtype=torch.bool, device=k.device)
if self.verbose:
print(f"Kernel norm: {self.kernel_norm.mean()}")
print(f"Kernel size: {k.size()}")
if k.size(-1) > L:
k = k[..., :L]
elif k.size(-1) < L:
k = F.pad(k, (0, L - k.size(-1)))
k = k / self.kernel_norm # * (L / self.l_max) ** 0.5
# Convolution
if self.bidirectional:
k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
            k = F.pad(k0, (0, L)) \
                + F.pad(k1.flip(-1), (L, 0))
y = self.fft_conv(u, k, L)
if not self.linear:
y = self.dropout(self.activation(y))
if not self.transposed:
y = y.transpose(-1, -2)
if not self.linear:
y = self.norm(y)
y = self.output_linear(y)
if return_kernel:
return y, k
return y, None
@property
def d_state(self):
return self.h * self.n
@property
def d_output(self):
return self.h
@property
def state_to_tensor(self):
        return lambda state: rearrange(state, '... h n -> ... (h n)')
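# How GConv assembles its convolution (a descriptive sketch of the code above, not a
# spec): each of the num_scales sub-kernels of length kernel_dim is upsampled with
# F.interpolate to a progressively longer span, weighted by a power of `multiplier`
# (or a learnable decay), and either summed ('sum' mode) or concatenated ('cat' mode)
# into one long kernel k, which is then divided by its initial norm. The convolution
# itself uses the standard FFT identity: zero-padding u and k to length 2L makes
# circular convolution equal to linear (causal) convolution,
#   y = irfft(rfft(u, 2L) * rfft(k, 2L))[..., :L] + D * u
# Minimal numerical check of that identity (hypothetical snippet, not part of the module):
#   u, k, L = torch.randn(1, 1, 8), torch.randn(1, 1, 8), 8
#   y_fft = torch.fft.irfft(torch.fft.rfft(u, n=2*L) * torch.fft.rfft(k, n=2*L), n=2*L)[..., :L]
#   y_ref = F.conv1d(F.pad(u, (L - 1, 0)), k.flip(-1))   # causal linear convolution
#   torch.testing.assert_close(y_fft, y_ref, atol=1e-4, rtol=1e-4)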
class SGConv(nn.Module):
def __init__(self, arch_config: ArchConfig, hidden_size: int,
total_heads: int, op_heads: int,
hf_config: PretrainedConfig, **kwargs):
super().__init__()
assert hidden_size % total_heads == 0
self.hidden_size = hidden_size
self.total_heads = total_heads
self.op_heads = op_heads
# Architecture params
self.kernel_size = arch_config.pick('kernel_size')
self.use_fast_fftconv = get_optim_flag(hf_config, 'fast_fftconv')
self.channels = 1
self.op_size = op_heads * (hidden_size // total_heads)
self.in_proj = nn.Sequential(
nn.Linear(hidden_size, self.op_size * 2),
nn.GLU(dim=-1)
)
self.sgconv = GConv(
self.op_size, l_max=hf_config.max_position_embeddings,
channels=self.channels, kernel_dim=self.kernel_size,
use_fast_fftconv=self.use_fast_fftconv,
transposed=False, verbose=False
)
self.act = nn.GELU(approximate='none')
    def forward(self, x: torch.Tensor, **kwargs):
output, _ = self.sgconv(self.in_proj(x))
return self.act(output), None
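# Minimal sketch of how SGConv is driven (illustrative; assumes an ArchConfig exposing
# 'kernel_size' and an hf_config providing max_position_embeddings and the
# 'fast_fftconv' optimization flag, as used by the constructor above):
#   op = SGConv(arch_config, hidden_size=768, total_heads=12, op_heads=12, hf_config=hf_config)
#   y, _ = op(torch.randn(2, 2048, 768))        # (B, L, hidden_size) -> same shape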
if __name__ == '__main__':
B = 2 # batch size
H = 768 # d_model
L = 2048 # sequence length
device = 'cuda'
import torch.utils.benchmark as benchmark
flash_layer = GConv(d_model=H, l_max=L, kernel_dim=128, use_fast_fftconv=True, transposed=False).to(device)
layer = GConv(d_model=H, l_max=L, kernel_dim=128, use_fast_fftconv=False, transposed=False).to(device)
u = torch.randn(B, L, H, device=device, dtype=torch.float32, requires_grad=True)
t0 = benchmark.Timer(
stmt='flash_layer(u)',
globals={'flash_layer': flash_layer, 'u': u})
t1 = benchmark.Timer(
stmt='layer(u)',
globals={'layer': layer, 'u': u})
print(t0.timeit(100))
print(t1.timeit(100))