Relational Transformer Encoder/Decoder Layers

This module contains relational transformer encoder and decoder layers that support sparse and dense relations among the inputs.

class mi_module_zoo.relationaltransformerlayers.RelationalTransformerEncoderLayer(d_model: int, key_query_dimension: int, value_dimension: int, num_heads: int, num_edge_types: int, add_reverse_edges: bool = True, dim_feedforward: int = 2048, dropout_rate: float = 0.1, activation: str = 'relu', use_edge_value_biases: bool = False, edge_attention_bias_is_scalar: bool = False, rezero_mode: Literal['off', 'scalar', 'vector', 'scalar-tied'] = 'off', normalisation_mode: Literal['off', 'prenorm', 'postnorm'] = 'postnorm')[source]

A relational transformer encoder layer that supports both discrete/sparse edge types and dense (all-to-all) relations, different ReZero modes, and different normalisation modes.

Parameters
  • d_model – the dimensionality of the inputs/outputs of the transformer layer.

  • key_query_dimension – the dimensionality of key/queries in the multihead attention.

  • value_dimension – the dimensionality of the multihead attention values.

  • num_heads – the number of attention heads.

  • num_edge_types – the number of discrete edge types. If 0, no discrete edge types are used.

  • add_reverse_edges – if num_edge_types > 0, should reverse edges be introduced with their own edge types?

  • dim_feedforward – the dimensionality of the feedforward hidden layer in the transformer layer.

  • dropout_rate – the dropout rate in \([0, 1)\).

  • activation – the activation function to be used in the feedforward layer. Defaults to ReLU.

  • use_edge_value_biases – should the discrete edges (relations) use value biases?

  • edge_attention_bias_is_scalar – should the edge attention biases be scalars or vectors of size key_query_dimension?

  • rezero_mode

    The following modes are supported (a combined code sketch appears after this parameter list):

    • "off": No ReZero use.

    • "scalar": Sublayers (attention / fully connected) are scaled by a single scalar, i.e., alpha is a scalar in the following:

      x' = x + alpha * SelfAtt(x)
      x'' = x' + alpha * Boom(x')
      return x''
      

      See https://arxiv.org/pdf/2003.04887.pdf.

    • "vector": Sublayers (attention / fully connected) are scaled by one value per dim, i.e., alpha is a vector in the following:

      x' = x + alpha * SelfAtt(x)
      x'' = x' + alpha * Boom(x')
      return x''
      

      See https://arxiv.org/pdf/2103.17239.pdf.

  • normalisation_mode

    Three different modes are supported (see the combined sketch after this parameter list):

    • "off": use no layer norm at all. Likely to diverge without using ReZero as well.

    • "prenorm": Normalise values before each sublayer (attention / fully connected):

      x' = x + SelfAtt(LN(x))
      x'' = x' + Boom(LN(x'))
      return x''
      
    • "postnorm": Normalise values after each sublayer:

      x' = LN(x + SelfAtt(x))
      x'' = LN(x' + Boom(x'))
      return x''
      

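The rezero_mode and normalisation_mode options compose. The following is a minimal sketch of that composition around a single sublayer, using a hypothetical ResidualSublayer helper; it only illustrates the update rules above and is not the library's actual implementation:

    import torch
    import torch.nn as nn

    class ResidualSublayer(nn.Module):
        """Hypothetical helper: wraps one sublayer (self-attention or
        feedforward) with the ReZero/normalisation behaviour described
        above. Not the actual mi_module_zoo implementation."""

        def __init__(self, d_model, sublayer, rezero_mode="off",
                     normalisation_mode="postnorm"):
            super().__init__()
            self.sublayer = sublayer
            self.norm = nn.LayerNorm(d_model)
            self.normalisation_mode = normalisation_mode
            if rezero_mode == "scalar":
                self.alpha = nn.Parameter(torch.zeros(1))        # one scalar, init 0
            elif rezero_mode == "vector":
                self.alpha = nn.Parameter(torch.zeros(d_model))  # one value per dim
            else:  # "off"
                self.alpha = None

        def forward(self, x):
            # "prenorm": normalise the sublayer input.
            h = self.norm(x) if self.normalisation_mode == "prenorm" else x
            out = self.sublayer(h)
            # ReZero: scale the sublayer output by the learnable alpha.
            if self.alpha is not None:
                out = self.alpha * out
            x = x + out
            # "postnorm": normalise the residual sum.
            if self.normalisation_mode == "postnorm":
                x = self.norm(x)
            return x

With rezero_mode="scalar" and normalisation_mode="off" this reduces to the ReZero update x' = x + alpha * SelfAtt(x); with rezero_mode="off" and normalisation_mode="postnorm" it reduces to the standard post-norm update x' = LN(x + SelfAtt(x)).
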
forward(src: torch.Tensor, src_mask: torch.Tensor, edges: torch.Tensor, edge_types: torch.Tensor, dense_relations_kq: Optional[torch.Tensor] = None, dense_relations_kv: Optional[torch.Tensor] = None, post_self_att_hook: Optional[Callable[[torch.Tensor, Union[float, torch.Tensor]], torch.Tensor]] = None)[source]
Parameters
  • src – A [batch_size, seq_len, D] tensor.

  • src_mask – A [batch_size, seq_len] or [batch_size, seq_len (query), seq_len (key)] bool tensor. True values are those that should be masked (no attention paid).

  • edges – A [num_edges, 3] tensor; each row has the form (batch_idx, source_idx, target_idx). Pass an empty tensor of shape (0, 3) if sparse edges are unused.

  • edge_types – A [num_edges] tensor of integers in 0..num_edge_types-1, or an empty tensor if sparse edges are unused.

  • dense_relations_kq – Optional [batch_size, seq_len, seq_len, num_heads] tensor.

  • dense_relations_kv – Optional [batch_size, seq_len, seq_len, num_heads, value_dimension] tensor.

Returns

A [batch_size, seq_len, D] tensor.
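
A minimal usage sketch based on the signature above (all sizes are illustrative):

    import torch
    from mi_module_zoo.relationaltransformerlayers import (
        RelationalTransformerEncoderLayer,
    )

    layer = RelationalTransformerEncoderLayer(
        d_model=256, key_query_dimension=32, value_dimension=32,
        num_heads=8, num_edge_types=3,
    )

    batch_size, seq_len = 4, 10
    src = torch.randn(batch_size, seq_len, 256)
    src_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # nothing masked

    # Two sparse edges; each row is (batch_idx, source_idx, target_idx).
    edges = torch.tensor([[0, 1, 2], [3, 0, 9]])
    edge_types = torch.tensor([0, 2])  # integers in 0..num_edge_types-1

    out = layer(src, src_mask, edges, edge_types)  # [4, 10, 256]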

class mi_module_zoo.relationaltransformerlayers.RelationalTransformerDecoderLayer(d_model: int, key_query_dimension: int, value_dimension: int, num_heads: int, num_self_edge_types: int, num_edge_types_to_encoder: int, add_reverse_edges: bool = True, dim_feedforward: int = 2048, dropout_rate: float = 0.1, activation: str = 'relu', use_edge_value_biases: bool = False, edge_attention_bias_is_scalar: bool = False, rezero_mode: Literal['off', 'scalar', 'vector', 'scalar-tied'] = 'off', normalisation_mode: Literal['off', 'prenorm', 'postnorm'] = 'postnorm')[source]

A relational transformer decoder layer. See the RelationalTransformerEncoderLayer for more information.

forward(tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor, memory_mask: torch.Tensor, self_edges: torch.Tensor, self_edge_types: torch.Tensor, encoder_edges: torch.Tensor, encoder_edge_types: torch.Tensor, dense_self_relations_kq: Optional[torch.Tensor] = None, dense_self_relations_kv: Optional[torch.Tensor] = None, dense_encoder_relations_kq: Optional[torch.Tensor] = None, dense_encoder_relations_kv: Optional[torch.Tensor] = None)[source]
Parameters
  • tgt – A [batch_size, seq_len, D] tensor.

  • memory – A [batch_size, mem_len, D] tensor.

  • tgt_mask – A [batch_size, seq_len] or [batch_size, seq_len, seq_len] bool tensor. True values are those that should be masked (no attention paid). For “causal” attention, tgt_mask should be 3D with tgt_mask[:, i, j] = j > i, i.e. the strictly upper-triangular elements should be True (see the example after this method).

  • memory_mask – A [batch_size, mem_len] bool tensor. True values are those that should be masked (no attention paid).

  • self_edges – A [num_self_edges, 3] tensor; each row has the form (batch_idx, source_idx, target_idx). Pass an empty tensor of shape (0, 3) if sparse edges are unused.

  • self_edge_types – A [num_self_edges] tensor of integers in 0..num_self_edge_types-1, or an empty tensor if sparse edges are unused.

  • encoder_edges – A [num_enc_edges, 3] tensor; each row has the form (batch_idx, source_idx, target_idx). Pass an empty tensor of shape (0, 3) if sparse edges are unused. Note: target_idx refers to elements in the memory.

  • encoder_edge_types – A [num_enc_edges] tensor of integers in 0..num_edge_types_to_encoder-1, or an empty tensor if sparse edges are unused.

  • dense_self_relations_kq – Optional [batch_size, seq_len, seq_len, num_heads] tensor for the relationships within the decoder.

  • dense_self_relations_kv – Optional [batch_size, seq_len, seq_len, num_heads, value_dimension] tensor for the relationships within the decoder.

  • dense_encoder_relations_kq – Optional [batch_size, seq_len, mem_len, num_heads] tensor for the relationships between the encoded inputs and the decoder.

  • dense_encoder_relations_kv – Optional [batch_size, seq_len, mem_len, num_heads, value_dimension] tensor for the relationships between the encoded inputs and the decoder.

Returns

A [batch_size, seq_len, D] tensor.
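
A minimal usage sketch based on the signature above, including the causal tgt_mask construction described earlier (all sizes are illustrative):

    import torch
    from mi_module_zoo.relationaltransformerlayers import (
        RelationalTransformerDecoderLayer,
    )

    layer = RelationalTransformerDecoderLayer(
        d_model=256, key_query_dimension=32, value_dimension=32,
        num_heads=8, num_self_edge_types=2, num_edge_types_to_encoder=2,
    )

    batch_size, seq_len, mem_len = 4, 10, 12
    tgt = torch.randn(batch_size, seq_len, 256)
    memory = torch.randn(batch_size, mem_len, 256)

    # Causal mask: tgt_mask[:, i, j] = j > i, i.e. the strictly
    # upper-triangular elements are True (masked).
    tgt_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    tgt_mask = tgt_mask.unsqueeze(0).expand(batch_size, -1, -1)
    memory_mask = torch.zeros(batch_size, mem_len, dtype=torch.bool)

    # No sparse edges in this example: pass empty tensors.
    no_edges = torch.zeros(0, 3, dtype=torch.long)
    no_edge_types = torch.zeros(0, dtype=torch.long)

    out = layer(tgt, memory, tgt_mask, memory_mask,
                no_edges, no_edge_types, no_edges, no_edge_types)  # [4, 10, 256]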