Set Transformers¶
The mi_module_zoo.settransformer module contains implementations of transformer-based models that process sets.
- class mi_module_zoo.settransformer.SetTransformer(input_embedding_dim: int, set_embedding_dim: int, transformer_embedding_dim: Optional[int] = None, num_heads: int = 1, num_blocks: int = 2, num_seed_vectors: int = 1, use_isab: bool = False, num_inducing_points: Optional[int] = None, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType = MultiheadInitType.XAVIER, use_layer_norm: bool = True, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType = ElementwiseTransformType.SINGLE, use_elementwise_transform_pma: bool = True, dropout_rate: float = 0.0)[source]¶
The Set Transformer model. Reference: https://arxiv.org/abs/1810.00825
Generates an embedding from a set of features using several blocks of self-attention followed by pooling by multihead attention (PMA).
- Parameters
input_embedding_dim – Dimension of the input data, the embedded features.
set_embedding_dim – Dimension of the output data, the set embedding.
transformer_embedding_dim – Embedding dimension to be used in the set transformer blocks.
num_heads – Number of heads in each multi-head attention block.
num_blocks – Number of SAB (or ISAB) blocks in the model.
num_seed_vectors – Number of seed vectors used in the pooling block (PMA).
use_isab – Whether ISAB blocks should be used instead of SAB blocks.
num_inducing_points – Number of inducing points in each ISAB block; relevant only when use_isab is True.
multihead_init_type – How linear layers in nn.MultiheadAttention are initialised. Valid options are “xavier” and “kaiming”.
use_layer_norm – Whether layer normalisation should be used in MAB blocks.
elementwise_transform_type – What version of the elementwise transform (rFF) should be used. Valid options are “single” and “double”.
use_elementwise_transform_pma – Whether an elementwise transform (rFF) should be used in the PMA block.
- forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor [source]¶
- Parameters
x – Embedded features tensor with shape [batch_size, set_size, input_embedding_dim].
mask – Mask tensor with shape [batch_size, set_size]; True values are masked.
- Returns
Set embedding tensor with shape [batch_size, set_embedding_dim].
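A minimal usage sketch based on the signature above; the tensor sizes and hyperparameter choices are illustrative, not prescribed by the module:

```python
import torch
from mi_module_zoo.settransformer import SetTransformer

batch_size, set_size, input_dim = 4, 10, 32  # illustrative sizes

model = SetTransformer(
    input_embedding_dim=input_dim,
    set_embedding_dim=64,
    num_heads=4,
    num_blocks=2,
)

x = torch.randn(batch_size, set_size, input_dim)
# True marks masked (ignored) set elements; here the last two elements of each set are masked.
mask = torch.zeros(batch_size, set_size, dtype=torch.bool)
mask[:, -2:] = True

set_embedding = model(x, mask=mask)  # shape: [batch_size, set_embedding_dim]
```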
- class mi_module_zoo.settransformer.ISAB(embedding_dim: int, num_heads: int, num_inducing_points: int, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType, use_layer_norm: bool, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType, dropout_rate: float)[source]¶
Inducing-point self attention block. This reduces memory use and compute time from \(O(N^2)\) to \(O(NM)\) where \(N\) is the number of features and \(M\) is the number of inducing points.
Reference: https://arxiv.org/pdf/1810.00825.pdf
- Parameters
embedding_dim – Dimension of the input data.
num_heads – Number of heads.
num_inducing_points – Number of inducing points.
multihead_init_type – How linear layers in nn.MultiheadAttention are initialised.
use_layer_norm – Whether layer normalisation should be used in MAB blocks.
elementwise_transform_type – What version of the elementwise transform (rFF) should be used.
- forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor [source]¶
- Parameters
x – Input tensor with shape [batch_size, set_size, embedding_dim] to be used as query, key and value.
mask – Mask tensor with shape [batch_size, set_size]; True values are masked. If None, all elements are used. The mask enforces that only the unmasked values are attended to in multihead attention, but the output is generated for all elements of x.
- Returns
Attention output tensor with shape [batch_size, set_size, embedding_dim].
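A minimal sketch of constructing an ISAB block directly, assuming the enum members MultiheadInitType.XAVIER and ElementwiseTransformType.SINGLE shown in the SetTransformer defaults above; sizes are illustrative:

```python
import torch
from mi_module_zoo.settransformer import (
    ISAB,
    ElementwiseTransformType,
    MultiheadInitType,
)

block = ISAB(
    embedding_dim=64,
    num_heads=4,
    num_inducing_points=16,
    multihead_init_type=MultiheadInitType.XAVIER,
    use_layer_norm=True,
    elementwise_transform_type=ElementwiseTransformType.SINGLE,
    dropout_rate=0.0,
)

x = torch.randn(4, 100, 64)  # [batch_size, set_size, embedding_dim]
out = block(x)               # [batch_size, set_size, embedding_dim]
```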
- class mi_module_zoo.settransformer.PMA(embedding_dim: int, num_heads: int, num_seed_vectors: int, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType, use_layer_norm: bool, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType, use_elementwise_transform_pma: bool, dropout_rate: float)[source]¶
Pooling by Multihead Attention block of the Set Transformer model. Seed vectors attend to the given values.
- Parameters
embedding_dim – Dimension of the input data.
num_heads – Number of heads.
num_seed_vectors – Number of seed vectors.
multihead_init_type – How linear layers in nn.MultiheadAttention are initialised.
use_layer_norm – Whether layer normalisation should be used in MAB blocks.
elementwise_transform_type – What version of the elementwise transform (rFF) should be used.
use_elementwise_transform_pma – Whether an elementwise transform (rFF) should be used in the PMA block.
- forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor [source]¶
- Parameters
x – Input tensor with shape [batch_size, set_size, embedding_dim] to be used as key and value.
mask – Mask tensor with shape [batch_size, set_size]; True for masked elements. If None, everything is observed. The mask enforces that only the unmasked values are attended to in multihead attention.
- Returns
Attention output tensor with shape [batch_size, num_seed_vectors, embedding_dim].
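A minimal sketch of pooling a set with PMA, again assuming the enum members documented above; sizes are illustrative:

```python
import torch
from mi_module_zoo.settransformer import (
    PMA,
    ElementwiseTransformType,
    MultiheadInitType,
)

pool = PMA(
    embedding_dim=64,
    num_heads=4,
    num_seed_vectors=1,
    multihead_init_type=MultiheadInitType.XAVIER,
    use_layer_norm=True,
    elementwise_transform_type=ElementwiseTransformType.SINGLE,
    use_elementwise_transform_pma=True,
    dropout_rate=0.0,
)

x = torch.randn(4, 100, 64)  # [batch_size, set_size, embedding_dim]
pooled = pool(x)             # [batch_size, num_seed_vectors, embedding_dim] -> [4, 1, 64]
```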