Relational Transformer Encoder/Decoder Layers¶
This module contains relational transformer encoder and decoder layers that support sparse and dense relations among their inputs.
- class mi_module_zoo.relationaltransformerlayers.RelationalTransformerEncoderLayer(d_model: int, key_query_dimension: int, value_dimension: int, num_heads: int, num_edge_types: int, add_reverse_edges: bool = True, dim_feedforward: int = 2048, dropout_rate: float = 0.1, activation: str = 'relu', use_edge_value_biases: bool = False, edge_attention_bias_is_scalar: bool = False, rezero_mode: Literal['off', 'scalar', 'vector', 'scalar-tied'] = 'off', normalisation_mode: Literal['off', 'prenorm', 'postnorm'] = 'postnorm')[source]¶
A relational transformer encoder layer that supports both discrete/sparse edge types and dense (all-to-all) relations, different ReZero modes, and different normalisation modes.
- Parameters
d_model – the dimensionality of the inputs/outputs of the transformer layer.
key_query_dimension – the dimensionality of the keys/queries in the multihead attention.
value_dimension – the dimensionality of the multihead attention values.
num_heads – the number of attention heads.
num_edge_types – the number of discrete edge types. If 0, no discrete edge types are used.
add_reverse_edges – if num_edge_types > 0, should reverse edge types be introduced?
dim_feedforward – the dimensionality of the feedforward hidden layer in the transformer layer.
dropout_rate – the dropout rate in \([0, 1)\).
activation – the activation function used in the feedforward layer. Defaults to ReLU.
use_edge_value_biases – should the discrete edges (relations) use value biases?
edge_attention_bias_is_scalar – should edge_attention_biases be scalar or of size key_query_dimension?
rezero_mode – The following ReZero modes are supported:
  "off": no ReZero.
  "scalar": sublayers (attention / fully connected) are scaled by a single scalar, i.e. alpha is a scalar in:
    x' = x + alpha * SelfAtt(x)
    x'' = x' + alpha * Boom(x')
    return x''
  "vector": sublayers (attention / fully connected) are scaled by one value per dimension, i.e. alpha is a vector in:
    x' = x + alpha * SelfAtt(x)
    x'' = x' + alpha * Boom(x')
    return x''
normalisation_mode – Three different modes are supported:
  "off": use no layer norm at all. Likely to diverge unless ReZero is used as well.
  "prenorm": normalise values before each sublayer (attention / fully connected):
    x' = x + SelfAtt(LN(x))
    x'' = x' + Boom(LN(x'))
    return x''
  "postnorm": normalise values after each sublayer:
    x' = LN(x + SelfAtt(x))
    x'' = LN(x' + Boom(x'))
    return x''
- forward(src: torch.Tensor, src_mask: torch.Tensor, edges: torch.Tensor, edge_types: torch.Tensor, dense_relations_kq: Optional[torch.Tensor] = None, dense_relations_kv: Optional[torch.Tensor] = None, post_self_att_hook: Optional[Callable[[torch.Tensor, Union[float, torch.Tensor]], torch.Tensor]] = None)[source]¶
- Parameters
src – A [batch_size, seq_len, D] tensor.
src_mask – A [batch_size, seq_len] or [batch_size, seq_len (query), seq_len (key)] bool tensor. True values are those that should be masked (no attention paid).
edges – [num_edges, 3], where each row has the form (batch_idx, source_idx, target_idx), or an empty tensor of shape (0, 3) if sparse edges are unused.
edge_types – [num_edges] of integers in 0..num_edge_types-1, or an empty tensor if sparse edges are unused.
dense_relations_kq – Optional [batch_size, seq_len, seq_len, num_heads].
dense_relations_kv – Optional [batch_size, seq_len, seq_len, num_heads, value_dimension].
- Returns
[batch_size, seq_len, D]
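For illustration, a sketch of a forward call with two sparse edges, reusing the layer constructed in the sketch above; the shapes follow the parameter descriptions:

```python
import torch

batch_size, seq_len, d_model = 2, 10, 256

src = torch.randn(batch_size, seq_len, d_model)
src_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # nothing masked

# Two sparse edges, one per batch element: (batch_idx, source_idx, target_idx).
edges = torch.tensor([[0, 1, 3],
                      [1, 0, 2]])
edge_types = torch.tensor([0, 2])  # integers in 0..num_edge_types-1

out = layer(src, src_mask, edges, edge_types)  # [batch_size, seq_len, d_model]
```

To use dense relations instead, pass an empty (0, 3) edges tensor and an empty edge_types tensor, and supply dense_relations_kq / dense_relations_kv tensors with the shapes listed above.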
- class mi_module_zoo.relationaltransformerlayers.RelationalTransformerDecoderLayer(d_model: int, key_query_dimension: int, value_dimension: int, num_heads: int, num_self_edge_types: int, num_edge_types_to_encoder: int, add_reverse_edges: bool = True, dim_feedforward: int = 2048, dropout_rate: float = 0.1, activation: str = 'relu', use_edge_value_biases: bool = False, edge_attention_bias_is_scalar: bool = False, rezero_mode: Literal['off', 'scalar', 'vector', 'scalar-tied'] = 'off', normalisation_mode: Literal['off', 'prenorm', 'postnorm'] = 'postnorm')[source]¶
A relational transformer decoder layer. See RelationalTransformerEncoderLayer for more information.
- forward(tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor, memory_mask: torch.Tensor, self_edges: torch.Tensor, self_edge_types: torch.Tensor, encoder_edges: torch.Tensor, encoder_edge_types: torch.Tensor, dense_self_relations_kq: Optional[torch.Tensor] = None, dense_self_relations_kv: Optional[torch.Tensor] = None, dense_encoder_relations_kq: Optional[torch.Tensor] = None, dense_encoder_relations_kv: Optional[torch.Tensor] = None)[source]¶
- Parameters
tgt – A [batch_size, seq_len, D] tensor.
memory – A [batch_size, mem_len, D] tensor.
tgt_mask – A [batch_size, seq_len] or [batch_size, seq_len, seq_len] bool tensor. True values are those that should be masked (no attention paid). For "causal" attention, tgt_mask should be 3D with tgt_mask[:, i, j] = j > i, i.e. the strictly upper-triangular elements should be True.
memory_mask – A [batch_size, mem_len] bool tensor. True values are those that should be masked (no attention paid).
self_edges – [num_self_edges, 3], where each row has the form (batch_idx, source_idx, target_idx), or an empty tensor of shape (0, 3) if sparse edges are unused.
self_edge_types – [num_self_edges] of integers in 0..num_self_edge_types-1.
encoder_edges – [num_enc_edges, 3], where each row has the form (batch_idx, source_idx, target_idx), or an empty tensor of shape (0, 3) if sparse edges are unused. Note: target_idx refers to elements in the memory.
encoder_edge_types – [num_enc_edges] of integers in 0..num_edge_types_to_encoder-1.
dense_self_relations_kq – Optional [batch_size, seq_len, seq_len, num_heads] for the relationships within the decoder.
dense_self_relations_kv – Optional [batch_size, seq_len, seq_len, num_heads, value_dimension] for the relationships within the decoder.
dense_encoder_relations_kq – Optional [batch_size, seq_len, mem_len, num_heads] for the relationships between the encoded inputs and the decoder.
dense_encoder_relations_kv – Optional [batch_size, seq_len, mem_len, num_heads, value_dimension] for the relationships between the encoded inputs and the decoder.
- Returns
[batch_size, seq_len, D]
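A sketch of constructing and calling the decoder layer with a causal self-attention mask and no sparse edges; as above, the hyperparameter values are illustrative:

```python
import torch
from mi_module_zoo.relationaltransformerlayers import RelationalTransformerDecoderLayer

batch_size, seq_len, mem_len, d_model = 2, 6, 10, 256

decoder_layer = RelationalTransformerDecoderLayer(
    d_model=d_model,
    key_query_dimension=32,
    value_dimension=32,
    num_heads=8,
    num_self_edge_types=2,
    num_edge_types_to_encoder=3,
)

tgt = torch.randn(batch_size, seq_len, d_model)
memory = torch.randn(batch_size, mem_len, d_model)

# Causal mask: tgt_mask[:, i, j] = j > i, i.e. the strictly
# upper-triangular elements are True (masked).
tgt_mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1
).unsqueeze(0).expand(batch_size, -1, -1)
memory_mask = torch.zeros(batch_size, mem_len, dtype=torch.bool)  # attend to all memory

# Sparse edges unused: empty (0, 3) edge tensors and empty type tensors.
no_edges = torch.zeros(0, 3, dtype=torch.long)
no_types = torch.zeros(0, dtype=torch.long)

out = decoder_layer(
    tgt, memory, tgt_mask, memory_mask,
    self_edges=no_edges, self_edge_types=no_types,
    encoder_edges=no_edges, encoder_edge_types=no_types,
)  # [batch_size, seq_len, d_model]
```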