Source code for mi_module_zoo.relationaltransformerlayers

import torch
from torch import nn
from typing import Callable, Optional, Tuple, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from mi_module_zoo.relationalmultiheadattention import RelationalMultiheadAttention
from mi_module_zoo.utils.activation import get_activation_fn


class RelationalTransformerEncoderLayer(nn.Module):
    """
    A relational transformer encoder layer that supports both discrete/sparse
    edge types and dense (all-to-all) relations, different ReZero modes, and
    different normalization modes.

    Args:
        d_model: the dimensionality of the inputs/outputs of the transformer layer.
        key_query_dimension: the dimensionality of keys/queries in the multihead attention.
        value_dimension: the dimensionality of the multihead attention values.
        num_heads: the number of attention heads.
        num_edge_types: the number of discrete edge types. If ``0``, no discrete
            edge types are used.
        add_reverse_edges: if ``num_edge_types > 0``, should reverse edge types
            be introduced?
        dim_feedforward: the dimensionality of the feedforward hidden layer in
            the transformer layer.
        dropout_rate: the dropout rate in :math:`[0, 1)`.
        activation: the activation function to be used in the feedforward layer.
            Defaults to ReLU.
        use_edge_value_biases: should the discrete edges (relations) use value biases?
        edge_attention_bias_is_scalar: should ``edge_attention_biases`` be a scalar
            or of size ``key_query_dimension``?
        rezero_mode: Four different modes are supported:

            * ``"off"``: No ReZero use.
            * ``"scalar"``: Sublayers (attention / fully connected) are scaled by
              a single scalar, i.e., ``alpha`` is a scalar in the following::

                  x' = x + alpha * SelfAtt(x)
                  x'' = x' + alpha * Boom(x')
                  return x''

              See https://arxiv.org/pdf/2003.04887.pdf.
            * ``"scalar-tied"``: As ``"scalar"``, but a single ``alpha`` is shared
              across both sublayers, as in the original ReZero implementation.
            * ``"vector"``: Sublayers (attention / fully connected) are scaled by
              one value per dimension, i.e., ``alpha`` is a vector in the following::

                  x' = x + alpha * SelfAtt(x)
                  x'' = x' + alpha * Boom(x')
                  return x''

              See https://arxiv.org/pdf/2103.17239.pdf.
        normalisation_mode: Three different modes are supported:

            * ``"off"``: use no layer norm at all. Likely to diverge unless
              ReZero is also used.
            * ``"prenorm"``: Normalise values before each sublayer
              (attention / fully connected)::

                  x' = x + SelfAtt(LN(x))
                  x'' = x' + Boom(LN(x'))
                  return x''

            * ``"postnorm"``: Normalise values after each sublayer::

                  x' = LN(x + SelfAtt(x))
                  x'' = LN(x' + Boom(x'))
                  return x''
    """

    def __init__(
        self,
        d_model: int,
        key_query_dimension: int,
        value_dimension: int,
        num_heads: int,
        num_edge_types: int,
        add_reverse_edges: bool = True,
        dim_feedforward: int = 2048,
        dropout_rate: float = 0.1,
        activation: str = "relu",
        use_edge_value_biases: bool = False,
        edge_attention_bias_is_scalar: bool = False,
        rezero_mode: Literal["off", "scalar", "vector", "scalar-tied"] = "off",
        normalisation_mode: Literal["off", "prenorm", "postnorm"] = "postnorm",
    ):
        super().__init__()
        assert 0 <= dropout_rate < 1
        self.self_attn = RelationalMultiheadAttention(
            num_heads=num_heads,
            output_dimension=d_model,
            dropout_rate=dropout_rate,
            num_edge_types=2 * num_edge_types if add_reverse_edges else num_edge_types,
            key_query_dimension=key_query_dimension,
            value_dimension=value_dimension,
            use_edge_value_biases=use_edge_value_biases,
            edge_attention_bias_is_scalar=edge_attention_bias_is_scalar,
        )

        # Implementation of the feedforward model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self._selfatt_head_transforms = nn.Linear(
            in_features=d_model,
            out_features=num_heads * (2 * key_query_dimension + value_dimension),
            bias=False,
        )

        self._normalisation_mode = normalisation_mode
        if normalisation_mode in ("prenorm", "postnorm"):
            self.norm1: Optional[nn.LayerNorm] = nn.LayerNorm(d_model)
            self.norm2: Optional[nn.LayerNorm] = nn.LayerNorm(d_model)
        elif normalisation_mode == "off":
            self.norm1 = None
            self.norm2 = None
        else:
            raise ValueError(f"Unrecognized normalization mode `{normalisation_mode}`.")

        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.activation = get_activation_fn(activation)

        self._rezero_mode = rezero_mode
        if rezero_mode == "off":
            self._alpha1: Union[float, torch.Tensor] = 1.0
            self._alpha2: Union[float, torch.Tensor] = 1.0
        elif rezero_mode == "scalar":
            self._alpha1 = nn.Parameter(torch.tensor(0.0))
            self._alpha2 = nn.Parameter(torch.tensor(0.0))
        elif rezero_mode == "scalar-tied":
            # The original ReZero setting: https://github.com/majumderb/rezero/blob/e2c94a825c5564217e8cf4d75a28d59cab1d7029/rezero/transformer/rztx.py#L47
            self._alpha1 = nn.Parameter(torch.tensor(0.0))
            self._alpha2 = self._alpha1
        elif rezero_mode == "vector":
            self._alpha1 = nn.Parameter(torch.zeros(size=(d_model,)))
            self._alpha2 = nn.Parameter(torch.zeros(size=(d_model,)))
        else:
            raise ValueError(f"Unrecognized rezero mode `{rezero_mode}`.")

        self._num_edge_types = num_edge_types
        self._add_reverse_edges = add_reverse_edges

    def _compute_qkv(
        self, input_seq_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        keys_queries_values = self._selfatt_head_transforms(input_seq_states).reshape(
            input_seq_states.shape[0],
            input_seq_states.shape[1],
            self.self_attn._num_heads,
            -1,
        )
        queries, keys, values = torch.split(
            keys_queries_values,
            split_size_or_sections=[
                self.self_attn._key_query_dim,
                self.self_attn._key_query_dim,
                self.self_attn._value_dim,
            ],
            dim=-1,
        )
        # queries: [B, query_len, num_heads, key_dim]
        # keys:    [B, memory_len, num_heads, key_dim]
        # values:  [B, memory_len, num_heads, value_dim]
        return queries, keys, values
    def forward(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor,
        edges: torch.Tensor,
        edge_types: torch.Tensor,
        dense_relations_kq: Optional[torch.Tensor] = None,
        dense_relations_kv: Optional[torch.Tensor] = None,
        post_self_att_hook: Optional[
            Callable[[torch.Tensor, Union[float, torch.Tensor]], torch.Tensor]
        ] = None,
    ):
        """
        :param src: A ``[batch_size, seq_len, D]`` tensor.
        :param src_mask: A ``[batch_size, seq_len]`` or
            ``[batch_size, seq_len (query), seq_len (key)]`` bool tensor. ``True``
            values are those that should be masked (no attention paid).
        :param edges: ``[num_edges, 3]`` where each row has the form
            ``(batch_idx, source_idx, target_idx)``, or an empty tensor of shape
            ``(0, 3)`` if sparse edges are unused.
        :param edge_types: ``[num_edges]`` of integers from ``0..num_edge_types-1``,
            or an empty tensor if sparse edges are unused.
        :param dense_relations_kq: Optional ``[batch_size, seq_len, seq_len, num_heads]``.
        :param dense_relations_kv: Optional
            ``[batch_size, seq_len, seq_len, num_heads, value_dimension]``.
        :return: ``[batch_size, seq_len, D]``
        """
        # --- Sublayer 1: Self-attention.
        attn_input = src
        if self._normalisation_mode == "prenorm":
            attn_input = self.norm1(src)

        if self._add_reverse_edges:
            # Create reverse edges with offset edge types.
            edge_sample_ids = edges[:, 0].repeat(2)
            edge_sources = torch.cat([edges[:, 1], edges[:, 2]])
            edge_targets = torch.cat([edges[:, 2], edges[:, 1]])
            edges = torch.stack((edge_sample_ids, edge_sources, edge_targets), dim=-1)
            edge_types = torch.cat([edge_types, edge_types + self._num_edge_types])

        queries, keys, values = self._compute_qkv(attn_input)
        src2 = self.self_attn(
            queries=queries,
            keys=keys,
            values=values,
            masked_elements=src_mask,
            edges=edges,
            edge_types=edge_types,
            dense_relations_kq=dense_relations_kq,
            dense_relations_kv=dense_relations_kv,
        )
        src2 = self._alpha1 * src2
        src = src + self.dropout1(src2)
        if post_self_att_hook is not None:
            src = post_self_att_hook(src, self._alpha1)
        if self._normalisation_mode == "postnorm":
            src = self.norm1(src)

        # --- Sublayer 2: Fully connected.
        fc_input = src
        if self._normalisation_mode == "prenorm":
            fc_input = self.norm2(fc_input)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(fc_input))))
        src2 = self._alpha2 * src2
        src = src + self.dropout2(src2)
        if self._normalisation_mode == "postnorm":
            src = self.norm2(src)
        return src
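
# --- Usage sketch (editorial addition, not part of the original module). ---
# A minimal example of driving the encoder layer. All hyperparameters, tensor
# shapes, and the toy edges below are illustrative assumptions; the shapes
# follow the forward() docstring above.
def _encoder_layer_usage_example() -> None:
    layer = RelationalTransformerEncoderLayer(
        d_model=64,
        key_query_dimension=16,
        value_dimension=16,
        num_heads=4,
        num_edge_types=3,
    )
    batch_size, seq_len = 2, 10
    src = torch.randn(batch_size, seq_len, 64)
    # All-False mask: no positions are masked out.
    src_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    # Two sparse edges, each (batch_idx, source_idx, target_idx), with
    # edge types in 0..num_edge_types-1.
    edges = torch.tensor([[0, 1, 4], [1, 0, 9]])
    edge_types = torch.tensor([0, 2])
    out = layer(src, src_mask, edges=edges, edge_types=edge_types)
    assert out.shape == (batch_size, seq_len, 64)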
class RelationalTransformerDecoderLayer(nn.Module):
    """
    A relational transformer decoder layer. See
    :class:`.RelationalTransformerEncoderLayer` for more information.
    """

    def __init__(
        self,
        d_model: int,
        key_query_dimension: int,
        value_dimension: int,
        num_heads: int,
        num_self_edge_types: int,
        num_edge_types_to_encoder: int,
        add_reverse_edges: bool = True,
        dim_feedforward: int = 2048,
        dropout_rate: float = 0.1,
        activation: str = "relu",
        use_edge_value_biases: bool = False,
        edge_attention_bias_is_scalar: bool = False,
        rezero_mode: Literal["off", "scalar", "vector", "scalar-tied"] = "off",
        normalisation_mode: Literal["off", "prenorm", "postnorm"] = "postnorm",
    ):
        super().__init__()
        self.decoder = RelationalTransformerEncoderLayer(
            d_model=d_model,
            key_query_dimension=key_query_dimension,
            value_dimension=value_dimension,
            num_heads=num_heads,
            num_edge_types=2 * num_self_edge_types if add_reverse_edges else num_self_edge_types,
            add_reverse_edges=add_reverse_edges,
            dim_feedforward=dim_feedforward,
            dropout_rate=dropout_rate,
            activation=activation,
            use_edge_value_biases=use_edge_value_biases,
            edge_attention_bias_is_scalar=edge_attention_bias_is_scalar,
            rezero_mode=rezero_mode,
            normalisation_mode=normalisation_mode,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self._key_query_dim = key_query_dimension
        self._value_dim = value_dimension
        self._multi_head_att_transforms = nn.Linear(
            in_features=d_model,
            out_features=num_heads * (key_query_dimension + value_dimension),
            bias=False,
        )
        self._query_transforms = nn.Linear(
            in_features=d_model, out_features=num_heads * key_query_dimension, bias=False
        )
        self.multihead_attn = RelationalMultiheadAttention(
            num_heads=num_heads,
            output_dimension=d_model,
            dropout_rate=dropout_rate,
            num_edge_types=2 * num_edge_types_to_encoder
            if add_reverse_edges
            else num_edge_types_to_encoder,
            key_query_dimension=key_query_dimension,
            value_dimension=value_dimension,
            use_edge_value_biases=use_edge_value_biases,
            edge_attention_bias_is_scalar=edge_attention_bias_is_scalar,
        )
    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory_mask: torch.Tensor,
        self_edges: torch.Tensor,
        self_edge_types: torch.Tensor,
        encoder_edges: torch.Tensor,
        encoder_edge_types: torch.Tensor,
        dense_self_relations_kq: Optional[torch.Tensor] = None,
        dense_self_relations_kv: Optional[torch.Tensor] = None,
        dense_encoder_relations_kq: Optional[torch.Tensor] = None,
        dense_encoder_relations_kv: Optional[torch.Tensor] = None,
    ):
        """
        :param tgt: A ``[batch_size, seq_len, D]`` tensor.
        :param memory: A ``[batch_size, mem_len, D]`` tensor.
        :param tgt_mask: A ``[batch_size, seq_len]`` or ``[batch_size, seq_len, seq_len]``
            bool tensor. ``True`` values are those that should be masked (no attention
            paid). For "causal" attention, ``tgt_mask`` should be 3D and
            ``tgt_mask[:, i, j] = j > i``, i.e. the strictly upper-triangular elements
            should be ``True``.
        :param memory_mask: A ``[batch_size, mem_len]`` bool tensor. ``True`` values
            are those that should be masked (no attention paid).
        :param self_edges: ``[num_self_edges, 3]`` where each row has the form
            ``(batch_idx, source_idx, target_idx)``, or an empty tensor of shape
            ``(0, 3)`` if sparse edges are unused.
        :param self_edge_types: ``[num_self_edges]`` of integers from
            ``0..num_self_edge_types-1``.
        :param encoder_edges: ``[num_enc_edges, 3]`` where each row has the form
            ``(batch_idx, source_idx, target_idx)``, or an empty tensor of shape
            ``(0, 3)`` if sparse edges are unused. Note: ``target_idx`` refers to
            elements in the memory.
        :param encoder_edge_types: ``[num_enc_edges]`` of integers from
            ``0..num_edge_types_to_encoder-1``.
        :param dense_self_relations_kq: Optional ``[batch_size, seq_len, seq_len, num_heads]``
            relationships within the decoder.
        :param dense_self_relations_kv: Optional
            ``[batch_size, seq_len, seq_len, num_heads, value_dimension]`` relationships
            within the decoder.
        :param dense_encoder_relations_kq: Optional ``[batch_size, seq_len, mem_len, num_heads]``
            relationships between the encoded inputs and the decoder.
        :param dense_encoder_relations_kv: Optional
            ``[batch_size, seq_len, mem_len, num_heads, value_dimension]`` relationships
            between the encoded inputs and the decoder.
        :return: ``[batch_size, seq_len, D]``
        """

        def callback(src: torch.Tensor, rezero_alpha: Union[float, torch.Tensor]) -> torch.Tensor:
            # Cross-attention over the encoder memory, run after the decoder's
            # self-attention sublayer via post_self_att_hook.
            kv = self._multi_head_att_transforms(memory).reshape(
                memory.shape[0], memory.shape[1], self.multihead_attn._num_heads, -1
            )
            keys, values = (
                kv[:, :, :, : self._key_query_dim],
                kv[:, :, :, self._key_query_dim :],
            )
            queries = self._query_transforms(src).reshape(
                src.shape[0], src.shape[1], self.multihead_attn._num_heads, -1
            )
            src2 = self.multihead_attn(
                queries=queries,
                keys=keys,
                values=values,
                masked_elements=memory_mask,
                edges=encoder_edges,
                edge_types=encoder_edge_types,
                dense_relations_kq=dense_encoder_relations_kq,
                dense_relations_kv=dense_encoder_relations_kv,
            )
            return src + self.dropout(rezero_alpha * src2)

        return self.decoder(
            tgt,
            tgt_mask,
            edges=self_edges,
            edge_types=self_edge_types,
            dense_relations_kq=dense_self_relations_kq,
            dense_relations_kv=dense_self_relations_kv,
            post_self_att_hook=callback,
        )
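
# --- Usage sketch (editorial addition, not part of the original module). ---
# A minimal example of driving the decoder layer with a causal self-attention
# mask and no sparse edges. The encoder memory and all hyperparameters are
# illustrative assumptions.
def _decoder_layer_usage_example() -> None:
    layer = RelationalTransformerDecoderLayer(
        d_model=64,
        key_query_dimension=16,
        value_dimension=16,
        num_heads=4,
        num_self_edge_types=2,
        num_edge_types_to_encoder=1,
    )
    batch_size, seq_len, mem_len = 2, 8, 12
    tgt = torch.randn(batch_size, seq_len, 64)
    memory = torch.randn(batch_size, mem_len, 64)
    # Causal mask: True strictly above the diagonal, so position i cannot
    # attend to positions j > i (see the tgt_mask docstring above).
    tgt_mask = (
        torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        .unsqueeze(0)
        .expand(batch_size, -1, -1)
    )
    memory_mask = torch.zeros(batch_size, mem_len, dtype=torch.bool)
    # Empty (0, 3) and (0,) tensors signal that sparse edges are unused.
    no_edges = torch.zeros(0, 3, dtype=torch.long)
    no_edge_types = torch.zeros(0, dtype=torch.long)
    out = layer(
        tgt,
        memory,
        tgt_mask,
        memory_mask,
        self_edges=no_edges,
        self_edge_types=no_edge_types,
        encoder_edges=no_edges,
        encoder_edge_types=no_edge_types,
    )
    assert out.shape == (batch_size, seq_len, 64)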