Set Transformers

The mi_module_zoo.settransformer module contains implementations of transformer-based models that process sets.

class mi_module_zoo.settransformer.SetTransformer(input_embedding_dim: int, set_embedding_dim: int, transformer_embedding_dim: Optional[int] = None, num_heads: int = 1, num_blocks: int = 2, num_seed_vectors: int = 1, use_isab: bool = False, num_inducing_points: Optional[int] = None, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType = MultiheadInitType.XAVIER, use_layer_norm: bool = True, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType = ElementwiseTransformType.SINGLE, use_elementwise_transform_pma: bool = True, dropout_rate: float = 0.0)[source]

The Set Transformer model (https://arxiv.org/abs/1810.00825).

Generates an embedding from a set of features using several blocks of self-attention followed by pooling by multihead attention.

Parameters
  • input_embedding_dim – Dimension of the input data, the embedded features.

  • set_embedding_dim – Dimension of the output data, the set embedding.

  • transformer_embedding_dim – Embedding dimension to be used in the set transformer blocks.

  • num_heads – Number of heads in each multi-head attention block.

  • num_blocks – Number of SABs in the model.

  • num_seed_vectors – Number of seed vectors used in the pooling block (PMA).

  • use_isab – Whether ISAB blocks should be used instead of SAB blocks.

  • num_inducing_points – Number of inducing points used by the ISAB blocks.

  • multihead_init_type – How linear layers in nn.MultiheadAttention are initialised. Valid options are “xavier” and “kaiming”.

  • use_layer_norm – Whether layer normalisation should be used in MAB blocks.

  • elementwise_transform_type – What version of the elementwise transform (rFF) should be used. Valid options are “single” and “double”.

  • use_elementwise_transform_pma – Whether an elementwise transform (rFF) should be used in the PMA block.

  • dropout_rate – Dropout rate.

forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor[source]
Parameters
  • x – Embedded features tensor with shape [batch_size, set_size, input_embedding_dim].

  • mask – Mask tensor with shape [batch_size, set_size]; True values are masked.

Returns

Set embedding tensor with shape [batch_size, set_embedding_dim].
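
Example: a minimal usage sketch, not taken from the library's documentation. The class and parameter names follow the signature above; the tensor shapes and hyperparameter values are arbitrary and assume the package is importable as mi_module_zoo.

    import torch
    from mi_module_zoo.settransformer import SetTransformer

    # A batch of 4 sets, each with 10 elements embedded in 32 dimensions.
    x = torch.randn(4, 10, 32)

    # True marks masked (e.g. padded) elements; here the last 3 elements
    # of every set are masked out.
    mask = torch.zeros(4, 10, dtype=torch.bool)
    mask[:, 7:] = True

    model = SetTransformer(
        input_embedding_dim=32,
        set_embedding_dim=64,
        num_heads=4,
        num_blocks=2,
    )
    set_embedding = model(x, mask)  # shape: [4, 64]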

class mi_module_zoo.settransformer.ISAB(embedding_dim: int, num_heads: int, num_inducing_points: int, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType, use_layer_norm: bool, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType, dropout_rate: float)[source]

Inducing-point self-attention block. This reduces memory use and compute time from \(O(N^2)\) to \(O(NM)\) where \(N\) is the number of features and \(M\) is the number of inducing points.

Reference: https://arxiv.org/pdf/1810.00825.pdf

Parameters
  • embedding_dim – Dimension of the input data.

  • num_heads – Number of heads.

  • num_inducing_points – Number of inducing points.

  • multihead_init_type – How linear layers in nn.MultiheadAttention are initialised.

  • use_layer_norm – Whether layer normalisation should be used in MAB blocks.

  • elementwise_transform_type – What version of the elementwise transform (rFF) should be used.

  • dropout_rate – Dropout rate.

forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor[source]
Parameters
  • x – Input tensor with shape [batch_size, set_size, embedding_dim] to be used as query, key and value.

  • mask – Mask tensor with shape [batch_size, set_size]; True values are masked. If None, all elements are used. The mask enforces that only the unmasked values are attended to in multihead attention, but the output is generated for all elements of x.

Returns

Attention output tensor with shape [batch_size, set_size, embedding_dim].
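
Example: a sketch of using an ISAB block on its own. The enum arguments mirror the defaults shown in the SetTransformer signature above; all shapes and values are arbitrary.

    import torch
    from mi_module_zoo.settransformer import (
        ISAB,
        ElementwiseTransformType,
        MultiheadInitType,
    )

    # A batch of 2 sets with 100 elements each, embedded in 64 dimensions.
    x = torch.randn(2, 100, 64)

    # 16 inducing points keep the attention cost at O(N*M) rather than O(N^2).
    block = ISAB(
        embedding_dim=64,
        num_heads=4,
        num_inducing_points=16,
        multihead_init_type=MultiheadInitType.XAVIER,
        use_layer_norm=True,
        elementwise_transform_type=ElementwiseTransformType.SINGLE,
        dropout_rate=0.0,
    )
    out = block(x)  # shape: [2, 100, 64], same as the input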

class mi_module_zoo.settransformer.PMA(embedding_dim: int, num_heads: int, num_seed_vectors: int, multihead_init_type: mi_module_zoo.settransformer.MultiheadInitType, use_layer_norm: bool, elementwise_transform_type: mi_module_zoo.settransformer.ElementwiseTransformType, use_elementwise_transform_pma: bool, dropout_rate: float)[source]

Pooling by Multihead Attention block of the Set Transformer model. Seed vectors attend to the given values.

Parameters
  • embedding_dim – Dimension of the input data.

  • num_heads – Number of heads.

  • num_seed_vectors – Number of seed vectors.

  • multihead_init_type – How linear layers in nn.MultiheadAttention are initialised.

  • use_layer_norm – Whether layer normalisation should be used in MAB blocks.

  • elementwise_transform_type – What version of the elementwise transform (rFF) should be used.

  • use_elementwise_transform_pma – Whether an elementwise transform (rFF) should be used in the PMA block.

  • dropout_rate – Dropout rate.

forward(x: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor[source]
Parameters
  • x – Input tensor with shape [batch_size, set_size, embedding_dim] to be used as key and value.

  • mask – Mask tensor with shape [batch_size, set_size]; True values are masked. If None, all elements are used. The mask enforces that only the unmasked values are attended to in multihead attention.

Returns

Attention output tensor with shape [batch_size, num_seed_vectors, embedding_dim].
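
Example: a sketch of using PMA directly to pool a set into a fixed number of outputs. The enum arguments mirror the defaults shown in the SetTransformer signature above; all shapes and values are arbitrary.

    import torch
    from mi_module_zoo.settransformer import (
        PMA,
        ElementwiseTransformType,
        MultiheadInitType,
    )

    x = torch.randn(2, 100, 64)  # [batch_size, set_size, embedding_dim]

    # Pool each set of 100 elements down to 4 seed-vector outputs.
    pool = PMA(
        embedding_dim=64,
        num_heads=4,
        num_seed_vectors=4,
        multihead_init_type=MultiheadInitType.XAVIER,
        use_layer_norm=True,
        elementwise_transform_type=ElementwiseTransformType.SINGLE,
        use_elementwise_transform_pma=True,
        dropout_rate=0.0,
    )
    pooled = pool(x)  # shape: [2, 4, 64]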

class mi_module_zoo.settransformer.MultiheadInitType(value)[source]

Initialization type for multihead attention.

XAVIER = 0
KAIMING = 1
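
Example: a brief sketch of selecting Kaiming initialisation via the enum instead of the default; the remaining parameter values are arbitrary.

    from mi_module_zoo.settransformer import MultiheadInitType, SetTransformer

    # Use Kaiming initialisation for the nn.MultiheadAttention linear layers.
    model = SetTransformer(
        input_embedding_dim=32,
        set_embedding_dim=64,
        multihead_init_type=MultiheadInitType.KAIMING,
    )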