Relational Multi-Head Attention
The mi_module_zoo.relationalmultiheadattention module contains implementations of transformer-based models that process sets.
- class mi_module_zoo.relationalmultiheadattention.RelationalMultiheadAttention(*, num_heads: int, num_edge_types: int, key_query_dimension: int, value_dimension: int, output_dimension: int, dropout_rate: float, use_edge_value_biases: bool = False, edge_attention_bias_is_scalar: bool = False)
A relational multihead attention implementation supporting two ways of using additional relationship information between input elements:
Sparse relations (edges):

If edges are present and edge_attention_bias_is_scalar=False and use_edge_value_biases=True are set, this implements Eqs. (3) and (4) of Shaw, Peter, Jakob Uszkoreit, and Ashish Vaswani. “Self-attention with relative position representations.” In NAACL 2018. https://www.aclweb.org/anthology/N18-2074/ and Eq. (2) of Wang, Bailin, et al. “RAT-SQL: Relation-aware schema encoding and linking for text-to-SQL parsers.” In ACL 2020. https://arxiv.org/pdf/1911.04942.pdf

If edges are present and edge_attention_bias_is_scalar=True and use_edge_value_biases=False are set, this implements Sect. 3.1 of Hellendoorn, Vincent J., et al. “Global relational models of source code.” In ICLR 2020. https://openreview.net/pdf?id=B1lnbRNtwr
Dense relations, when all input elements have relationship information to all other elements in the input. This can be encoded in one or both of the following two ways:

Passing a dense dense_relations_kq of shape [batch_size, query_len, key_len, num_heads] in forward() for every query-key pair.

Passing a dense dense_relations_kv of shape [batch_size, query_len, key_len, num_heads, value_dimension] in forward() for every query-key pair.
If no edges are present and no dense relations are passed, this acts as a standard multihead attention layer.
- Parameters
num_heads – the number of attention heads.
num_edge_types – the number of discrete edge types.
key_query_dimension – the dimensionality of keys and queries (per head).
value_dimension – the dimension of the values (per head).
output_dimension – the output dimension (after the feedforward).
dropout_rate – the rate of dropout in \([0, 1)\).
use_edge_value_biases – should the edges (relations) use value biases?
edge_attention_bias_is_scalar – should edge attention biases be a scalar or of size key_query_dimension?
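A minimal construction sketch follows; the hyperparameter values and the names shaw_style / great_style are illustrative assumptions, not library defaults. It shows how the two sparse-relation variants described above are selected purely through the constructor flags:

    from mi_module_zoo.relationalmultiheadattention import RelationalMultiheadAttention

    # Shaw et al. / RAT-SQL variant: vector-valued edge attention biases
    # plus edge value biases. (Illustrative hyperparameters.)
    shaw_style = RelationalMultiheadAttention(
        num_heads=8,
        num_edge_types=4,
        key_query_dimension=64,
        value_dimension=64,
        output_dimension=512,
        dropout_rate=0.1,
        use_edge_value_biases=True,
        edge_attention_bias_is_scalar=False,
    )

    # GREAT variant (Hellendoorn et al.): one scalar attention bias per
    # edge type, and no edge value biases.
    great_style = RelationalMultiheadAttention(
        num_heads=8,
        num_edge_types=4,
        key_query_dimension=64,
        value_dimension=64,
        output_dimension=512,
        dropout_rate=0.1,
        use_edge_value_biases=False,
        edge_attention_bias_is_scalar=True,
    )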
- forward(*, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, masked_elements: Optional[torch.Tensor], edges: torch.Tensor, edge_types: torch.Tensor, dense_relations_kq: Optional[torch.Tensor] = None, dense_relations_kv: Optional[torch.Tensor] = None)
- Parameters
queries – [batch_size, query_len, D]
keys – [batch_size, key_len, D]
values – [batch_size, key_len, H]
masked_elements – bool tensor of shape [batch_size, key_len] or [batch_size, query_len, key_len]. True values are those that should be masked (no attention paid); None keeps everything unmasked.
edges – [num_edges, 3], where each row has the form (batch_idx, source_idx, target_idx).
edge_types – [num_edges] of integers from 0..num_edge_types.
dense_relations_kq – optional [batch_size, query_len, key_len, num_heads].
dense_relations_kv – optional [batch_size, query_len, key_len, num_heads, value_dimension].
- Returns
[batch_size, seq_size, H]
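A hedged end-to-end sketch of calling the layer for self-attention is shown below. The tensor sizes, the assumption that D and H both equal 512 here, and the sample edge list are all illustrative choices, not requirements of the API:

    import torch

    from mi_module_zoo.relationalmultiheadattention import RelationalMultiheadAttention

    layer = RelationalMultiheadAttention(
        num_heads=8, num_edge_types=4,
        key_query_dimension=64, value_dimension=64,
        output_dimension=512, dropout_rate=0.1,
    )

    batch_size, seq_len, d_model = 2, 10, 512
    x = torch.randn(batch_size, seq_len, d_model)

    # Two sparse edges; each row is (batch_idx, source_idx, target_idx).
    edges = torch.tensor([[0, 1, 3],
                          [1, 0, 2]], dtype=torch.int64)
    edge_types = torch.tensor([0, 2], dtype=torch.int64)  # discrete edge-type ids

    # True marks key positions that receive no attention (e.g. padding).
    masked_elements = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    masked_elements[:, -2:] = True

    # Optional dense per-pair attention biases, one per head.
    dense_kq = torch.randn(batch_size, seq_len, seq_len, 8)

    out = layer(
        queries=x, keys=x, values=x,  # self-attention: query_len == key_len
        masked_elements=masked_elements,
        edges=edges, edge_types=edge_types,
        dense_relations_kq=dense_kq,
    )
    # out: [batch_size, seq_size, H] per the Returns entry above.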