Set Transformers¶
The mi_module_zoo.settransformer module contains implementations of
transformer-based models that operate on sets.
- class mi_module_zoo.settransformer.SetTransformer(input_embedding_dim: int, set_embedding_dim: int, transformer_embedding_dim: Optional[int] = None, num_heads: int = 1, num_blocks: int = 2, num_seed_vectors: int = 1, use_isab: bool = False, num_inducing_points: Optional[int] = None, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType = MultiheadInitType.XAVIER, use_layer_norm: bool = True, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType = ElementwiseTransformType.SINGLE, use_elementwise_transform_pma: bool = True, dropout_rate: float = 0.0)[source]¶
The Set Transformer model (https://arxiv.org/abs/1810.00825).
Generates an embedding from a set of features using several self-attention blocks followed by pooling by multihead attention.
- Parameters
input_embedding_dim – Dimension of the input data, the embedded features.
set_embedding_dim – Dimension of the output data, the set embedding.
transformer_embedding_dim – Embedding dimension to be used in the set transformer blocks.
num_heads – Number of heads in each multi-head attention block.
num_blocks – Number of SABs in the model.
num_seed_vectors – Number of seed vectors used in the pooling block (PMA).
use_isab – Should ISAB blocks be used instead of SAB blocks.
num_inducing_points – Number of inducing points used in the ISAB blocks (only relevant when use_isab is True).
multihead_init_type – How linear layers in nn.MultiheadAttention are initialised. Valid options are “xavier” and “kaiming”.
use_layer_norm – Whether layer normalisation should be used in MAB blocks.
elementwise_transform_type – What version of the elementwise transform (rFF) should be used. Valid options are “single” and “double”.
use_elementwise_transform_pma – Whether an elementwise transform (rFF) should be used in the PMA block.
- forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) torch.Tensor[source]¶
- Parameters
x – Embedded features tensor with shape [batch_size, set_size, input_embedding_dim].
mask – Mask tensor with shape [batch_size, set_size]; True values are masked.
- Returns
Set embedding tensor with shape
[batch_size, set_embedding_dim].
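A minimal usage sketch of SetTransformer. The dimensions and hyperparameters below are illustrative, and the mask is assumed to be a boolean tensor in which True marks elements to ignore, matching the description above:

```python
import torch
from mi_module_zoo.settransformer import SetTransformer

batch_size, set_size, input_dim, set_dim = 4, 10, 32, 64

model = SetTransformer(
    input_embedding_dim=input_dim,
    set_embedding_dim=set_dim,
    num_heads=4,
    num_blocks=2,
)

x = torch.randn(batch_size, set_size, input_dim)            # embedded features
mask = torch.zeros(batch_size, set_size, dtype=torch.bool)  # True values are masked
mask[:, -2:] = True                                         # e.g. ignore the last two elements of each set

set_embedding = model(x, mask)                              # [batch_size, set_embedding_dim]
assert set_embedding.shape == (batch_size, set_dim)
```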
- class mi_module_zoo.settransformer.ISAB(embedding_dim: int, num_heads: int, num_inducing_points: int, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType, use_layer_norm: bool, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType, dropout_rate: float)[source]¶
Inducing-point self-attention block. This reduces memory use and compute time from \(O(N^2)\) to \(O(NM)\), where \(N\) is the number of set elements (features) and \(M\) is the number of inducing points.
Reference: https://arxiv.org/pdf/1810.00825.pdf
- Parameters
embedding_dim – Dimension of the input data.
num_heads – Number of heads.
num_inducing_points – Number of inducing points.
multihead_init_type – How linear layers in nn.MultiheadAttention are initialised.
use_layer_norm – Whether layer normalisation should be used in MAB blocks.
elementwise_transform_type – What version of the elementwise transform (rFF) should be used.
- forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) torch.Tensor[source]¶
- Parameters
x – Input tensor with shape [batch_size, set_size, embedding_dim], used as query, key and value.
mask – Mask tensor with shape [batch_size, set_size]; True values are masked. If None, all elements are used. The mask ensures that only unmasked values are attended to in multihead attention, but the output is generated for all elements of x.
- Returns
Attention output tensor with shape
[batch_size, set_size, embedding_dim].
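A short sketch of using an ISAB block directly. The dimensions are illustrative; the enum members MultiheadInitType.XAVIER and ElementwiseTransformType.SINGLE are the defaults shown in the SetTransformer signature above:

```python
import torch
from mi_module_zoo.settransformer import (
    ISAB,
    ElementwiseTransformType,
    MultiheadInitType,
)

block = ISAB(
    embedding_dim=32,
    num_heads=4,
    num_inducing_points=8,          # M inducing points; attention cost scales as O(N * M)
    multihead_init_type=MultiheadInitType.XAVIER,
    use_layer_norm=True,
    elementwise_transform_type=ElementwiseTransformType.SINGLE,
    dropout_rate=0.0,
)

x = torch.randn(4, 100, 32)         # [batch_size, set_size, embedding_dim]
out = block(x)                      # mask=None: all elements are used
assert out.shape == x.shape         # output is produced for every set element
```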
- class mi_module_zoo.settransformer.PMA(embedding_dim: int, num_heads: int, num_seed_vectors: int, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType, use_layer_norm: bool, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType, use_elementwise_transform_pma: bool, dropout_rate: float)[source]¶
Pooling by Multihead Attention block of the Set Transformer model. Seed vectors attend to the given values.
- Parameters
embedding_dim – Dimension of the input data.
num_heads – Number of heads.
num_seed_vectors – Number of seed vectors.
multihead_init_type – How linear layers in nn.MultiheadAttention are initialised.
use_layer_norm – Whether layer normalisation should be used in MAB blocks.
elementwise_transform_type – What version of the elementwise transform (rFF) should be used.
use_elementwise_transform_pma – Whether an elementwise transform (rFF) should be used in the PMA block.
- forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) torch.Tensor[source]¶
- Parameters
x – Input tensor with shape [batch_size, set_size, embedding_dim], used as key and value.
mask – Mask tensor with shape [batch_size, set_size]; True for masked elements. If None, everything is observed. The mask ensures that only unmasked values are attended to in multihead attention.
- Returns
Attention output tensor with shape
[batch_size, num_seed_vectors, embedding_dim].
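A short sketch of pooling a set with PMA. The dimensions are illustrative, and the enum members used are the defaults from the SetTransformer signature above:

```python
import torch
from mi_module_zoo.settransformer import (
    PMA,
    ElementwiseTransformType,
    MultiheadInitType,
)

pool = PMA(
    embedding_dim=32,
    num_heads=4,
    num_seed_vectors=1,              # one pooled vector per set
    multihead_init_type=MultiheadInitType.XAVIER,
    use_layer_norm=True,
    elementwise_transform_type=ElementwiseTransformType.SINGLE,
    use_elementwise_transform_pma=True,
    dropout_rate=0.0,
)

x = torch.randn(4, 100, 32)          # [batch_size, set_size, embedding_dim]
pooled = pool(x)                     # seed vectors attend to the set elements
assert pooled.shape == (4, 1, 32)    # [batch_size, num_seed_vectors, embedding_dim]
```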