Relational Multi-Head Attention

The mi_module_zoo.relationalmultiheadattention module contains implementations of transformer-based models that process sets.

class mi_module_zoo.relationalmultiheadattention.RelationalMultiheadAttention(*, num_heads: int, num_edge_types: int, key_query_dimension: int, value_dimension: int, output_dimension: int, dropout_rate: float, use_edge_value_biases: bool = False, edge_attention_bias_is_scalar: bool = False)

A relational multihead attention implementation supporting two ways of using additional relationship information between input elements:

  • Sparse relations (edges):

    • If edges are present, edge_attention_bias_is_scalar=False, and use_edge_value_biases=True, this implements Eqs. (3) and (4) of Shaw, Peter, Jakob Uszkoreit, and Ashish Vaswani. “Self-attention with relative position representations.” In NAACL 2018. https://www.aclweb.org/anthology/N18-2074/

      and Eq. (2) of Wang, Bailin, et al. “RAT-SQL: Relation-aware schema encoding and linking for text-to-SQL parsers.” In ACL 2020. https://arxiv.org/pdf/1911.04942.pdf

    • If edges are present, edge_attention_bias_is_scalar=True, and use_edge_value_biases=False, this implements Sect. 3.1 of Hellendoorn, Vincent J., et al. “Global relational models of source code.” In ICLR 2020. https://openreview.net/pdf?id=B1lnbRNtwr

  • Dense relations, when every input element has relationship information for every other element in the input. These can be encoded in one or both of the following ways:

    • Passing a dense tensor dense_relations_kq of shape [batch_size, query_len, key_len, num_heads] to forward(), with an entry for every query-key pair.

    • Passing a dense tensor dense_relations_kv of shape [batch_size, query_len, key_len, num_heads, value_dimension] to forward(), with an entry for every query-key pair.

  • If no edges are present and no dense relations are passed, this acts as a standard multihead attention layer.
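The sparse-edge variants above can be illustrated with a minimal single-head sketch in PyTorch. It is not the module's implementation: the function name relational_attention_single_head and the bias tables key_bias_table / value_bias_table are hypothetical, batching is dropped (so edges here have two columns instead of three), and treating source_idx as the query position and target_idx as the key position is an assumption. With a vector key bias and a value-bias table it follows the Shaw et al. form e_ij = q_i·(k_j + a^K_ij)/sqrt(d) and z_i = Σ_j α_ij (v_j + a^V_ij). For the scalar variant this sketch assumes the size-1 bias is broadcast across the key features before the dot product with the query; how the library combines it may differ.

```python
import math
from typing import Optional

import torch


def relational_attention_single_head(
    q: torch.Tensor,  # [query_len, d] projected queries
    k: torch.Tensor,  # [key_len, d] projected keys
    v: torch.Tensor,  # [key_len, d_v] projected values
    edges: torch.Tensor,  # [num_edges, 2] long rows of (query_idx, key_idx), assumed layout
    edge_types: torch.Tensor,  # [num_edges] long, values in [0, num_edge_types)
    key_bias_table: torch.Tensor,  # [num_edge_types, d] (vector) or [num_edge_types, 1] (scalar)
    value_bias_table: Optional[torch.Tensor] = None,  # [num_edge_types, d_v] or None
) -> torch.Tensor:
    """Hypothetical single-head relation-aware attention; returns [query_len, d_v]."""
    query_len, d = q.shape
    key_len = k.shape[0]
    rows, cols = edges[:, 0], edges[:, 1]

    # a^K_ij: scatter per-edge-type key biases into a dense per-pair tensor (zero = no edge).
    # accumulate=True sums biases if several edges connect the same pair (an assumption).
    a_k = torch.zeros(query_len, key_len, key_bias_table.shape[1], dtype=q.dtype)
    a_k.index_put_((rows, cols), key_bias_table[edge_types], accumulate=True)

    # e_ij = q_i . (k_j + a^K_ij) / sqrt(d); a size-1 bias broadcasts over the d features.
    logits = (q @ k.T + (q.unsqueeze(1) * a_k).sum(-1)) / math.sqrt(d)
    alpha = torch.softmax(logits, dim=-1)  # [query_len, key_len]

    out = alpha @ v  # sum_j alpha_ij v_j
    if value_bias_table is not None:
        # Add sum_j alpha_ij a^V_ij, the relation-specific value term.
        a_v = torch.zeros(query_len, key_len, value_bias_table.shape[1], dtype=v.dtype)
        a_v.index_put_((rows, cols), value_bias_table[edge_types], accumulate=True)
        out = out + (alpha.unsqueeze(-1) * a_v).sum(dim=1)
    return out
```

Passing an empty edges tensor leaves a^K and a^V at zero, which reduces the sketch to plain scaled dot-product attention, matching the last bullet above.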

Parameters
  • num_heads – the number of attention heads.

  • num_edge_types – the number of discrete edge types.

  • key_query_dimension – the dimensionality of keys and queries (per head).

  • value_dimension – the dimension of the values (per head).

  • output_dimension – the output dimension (after the feedforward).

  • dropout_rate – the rate of dropout in [0, 1).

  • use_edge_value_biases – should the edges (relations) use value biases?

  • edge_attention_bias_is_scalar – should the edge attention biases be scalars or vectors of size key_query_dimension?
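The two boolean flags mainly change the size of the per-edge-type bias parameters. The helper below is a hypothetical sketch of one such layout: the name make_edge_bias_tables, the use of nn.Embedding, and concatenating the heads along the last dimension are all assumptions for illustration, not the module's actual attributes.

```python
from typing import Optional, Tuple

from torch import nn


def make_edge_bias_tables(
    num_heads: int,
    num_edge_types: int,
    key_query_dimension: int,
    value_dimension: int,
    use_edge_value_biases: bool = False,
    edge_attention_bias_is_scalar: bool = False,
) -> Tuple[nn.Embedding, Optional[nn.Embedding]]:
    """Hypothetical per-edge-type bias tables driven by the constructor flags."""
    # Attention (key) bias: one scalar per head, or one key_query_dimension vector per head.
    per_head_key_bias = 1 if edge_attention_bias_is_scalar else key_query_dimension
    key_bias = nn.Embedding(num_edge_types, num_heads * per_head_key_bias)
    # Value bias: one value_dimension vector per head, created only when enabled.
    value_bias = (
        nn.Embedding(num_edge_types, num_heads * value_dimension)
        if use_edge_value_biases
        else None
    )
    return key_bias, value_bias
```

For example, a Shaw/RAT-SQL-style configuration (vector key biases plus value biases) with num_heads=8 and key_query_dimension=value_dimension=16 yields two tables of width 128, while the scalar configuration without value biases yields a single table of width num_heads.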

forward(*, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, masked_elements: Optional[torch.Tensor], edges: torch.Tensor, edge_types: torch.Tensor, dense_relations_kq: Optional[torch.Tensor] = None, dense_relations_kv: Optional[torch.Tensor] = None)
Parameters
  • queries – [batch_size, query_len, D]

  • keys – [batch_size, key_len, D]

  • values – [batch_size, key_len, H]

  • masked_elements – bool tensor of shape [batch_size, key_len] or [batch_size, query_len, key_len]. True values mark elements that should be masked (no attention is paid to them). None keeps everything unmasked.

  • edges – [num_edges, 3]; each row has the form (batch_idx, source_idx, target_idx).

  • edge_types – [num_edges] integers in [0, num_edge_types).

  • dense_relations_kq – Optional [batch_size, query_len, key_len, num_heads]

  • dense_relations_kv – Optional [batch_size, query_len, key_len, num_heads, value_dimension]
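The edges, edge_types and masked_elements formats above can be prepared as in the following sketch. The helpers edges_to_type_index and expand_key_mask are hypothetical, and mapping source_idx to queries and target_idx to keys is an assumption; the sketch also assumes at most one edge per (batch, source, target) triple.

```python
import torch


def edges_to_type_index(edges, edge_types, batch_size, query_len, key_len):
    """Dense [batch_size, query_len, key_len] edge-type index, -1 where no edge exists."""
    index = torch.full((batch_size, query_len, key_len), -1, dtype=torch.long)
    index[edges[:, 0], edges[:, 1], edges[:, 2]] = edge_types
    return index


def expand_key_mask(masked_elements, query_len):
    """Broadcast a [batch_size, key_len] mask to [batch_size, query_len, key_len]."""
    if masked_elements.dim() == 2:
        return masked_elements.unsqueeze(1).expand(-1, query_len, -1)
    return masked_elements  # already [batch_size, query_len, key_len]


batch_size, query_len, key_len, num_edge_types = 2, 3, 3, 4

# Edges in the documented layout: one (batch_idx, source_idx, target_idx) row per edge.
edges = torch.tensor([[0, 0, 1], [0, 1, 2], [1, 2, 0]], dtype=torch.long)
edge_types = torch.tensor([0, 3, 1], dtype=torch.long)
assert edges.shape == (edge_types.shape[0], 3)
assert 0 <= int(edge_types.min()) and int(edge_types.max()) < num_edge_types

# True = masked: in batch 0 the last key is padding and receives no attention.
masked_elements = torch.tensor([[False, False, True], [False, False, False]])

type_index = edges_to_type_index(edges, edge_types, batch_size, query_len, key_len)
logits = torch.randn(batch_size, query_len, key_len)
logits = logits.masked_fill(expand_key_mask(masked_elements, query_len), float("-inf"))
weights = torch.softmax(logits, dim=-1)  # a fully masked row would produce NaNs here
```

Masking with -inf before the softmax gives the masked keys exactly zero weight; if every key in a row may be masked, that case needs separate handling.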

Returns

[batch_size, query_len, output_dimension]
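How the dense relation tensors could enter a multi-head computation is sketched below. DenseRelationalAttentionSketch is a hypothetical class, not the library's RelationalMultiheadAttention: unlike the documented constructor it takes explicit input dimensions, and it assumes that dense_relations_kq is added to the scaled logits and that dense_relations_kv is mixed in as Σ_j α_ij r_ij, mirroring the sparse value biases. In this sketch the result is projected to output_dimension.

```python
import math
from typing import Optional

import torch
from torch import nn


class DenseRelationalAttentionSketch(nn.Module):
    """Hypothetical multi-head attention with dense relation biases (not the library class)."""

    def __init__(self, query_key_input_dim: int, value_input_dim: int, num_heads: int,
                 key_query_dimension: int, value_dimension: int, output_dimension: int,
                 dropout_rate: float = 0.0):
        super().__init__()
        self.h, self.dk, self.dv = num_heads, key_query_dimension, value_dimension
        self.q_proj = nn.Linear(query_key_input_dim, num_heads * key_query_dimension)
        self.k_proj = nn.Linear(query_key_input_dim, num_heads * key_query_dimension)
        self.v_proj = nn.Linear(value_input_dim, num_heads * value_dimension)
        self.out_proj = nn.Linear(num_heads * value_dimension, output_dimension)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor,
                masked_elements: Optional[torch.Tensor] = None,
                dense_relations_kq: Optional[torch.Tensor] = None,
                dense_relations_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
        b, lq, _ = queries.shape
        lk = keys.shape[1]
        q = self.q_proj(queries).view(b, lq, self.h, self.dk)
        k = self.k_proj(keys).view(b, lk, self.h, self.dk)
        v = self.v_proj(values).view(b, lk, self.h, self.dv)

        # Scaled dot-product logits per head: [b, lq, lk, h].
        logits = torch.einsum("bqhd,bkhd->bqkh", q, k) / math.sqrt(self.dk)
        if dense_relations_kq is not None:
            logits = logits + dense_relations_kq  # assumed: additive bias per pair and head
        if masked_elements is not None:
            if masked_elements.dim() == 2:  # [b, lk] -> [b, 1, lk]
                masked_elements = masked_elements.unsqueeze(1)
            logits = logits.masked_fill(masked_elements.unsqueeze(-1), float("-inf"))
        alpha = self.dropout(torch.softmax(logits, dim=2))  # normalise over keys

        out = torch.einsum("bqkh,bkhd->bqhd", alpha, v)
        if dense_relations_kv is not None:
            # assumed: sum_j alpha_ij r_ij, a per-pair value bias as in the sparse case
            out = out + torch.einsum("bqkh,bqkhd->bqhd", alpha, dense_relations_kv)
        return self.out_proj(out.reshape(b, lq, self.h * self.dv))
```

With all-zero relation tensors this sketch returns the same result as calling it without relations, consistent with the class description's fallback to standard multihead attention.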