causica.functional_relationships.deci_functional_relationships

Module Contents

Classes

DECIEmbedFunctionalRelationships

This is a FunctionalRelationships that wraps the DECIEmbedNN module.

class causica.functional_relationships.deci_functional_relationships.DECIEmbedFunctionalRelationships(shapes: dict[str, torch.Size], embedding_size: int, out_dim_g: int, num_layers_g: int, num_layers_zeta: int)

Bases: causica.functional_relationships.functional_relationships.FunctionalRelationships

This is a FunctionalRelationships that wraps the DECIEmbedNN module.

forward(samples: tensordict.TensorDict, graphs: torch.Tensor) -> tensordict.TensorDict

Calculates the predictions of the children from parents.

Functional relationships expect samples to have a batch shape composed, in order, of samples_shape, functions_shape, and graphs_shape. The graphs are expected to have a matching batch shape graphs_shape. The functional relationship is then applied to each sample using the corresponding function and graph, which allows batched processing of different functions (e.g. interventions) or graphs.
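As a rough illustration of this shape convention, the sketch below computes the tensor shapes implied by a choice of samples_shape, functions_shape and graphs_shape, using plain tuples (torch.Size is itself a tuple subclass). The helper expected_shapes and the node names "x" and "y" are hypothetical and not part of causica. The sketch also assumes that the number of graph nodes equals the number of entries in the shapes dictionary.

```python
def expected_shapes(samples_shape, functions_shape, graphs_shape, node_shapes):
    """Hypothetical helper: shapes implied by the batch-shape convention.

    Returns a dict of per-node sample shapes and the shape of the graphs tensor.
    """
    # The batch shape is the concatenation, in this order, of the three parts.
    batch = tuple(samples_shape) + tuple(functions_shape) + tuple(graphs_shape)
    # Assumption: one graph node per entry of the shapes dictionary.
    num_nodes = len(node_shapes)
    per_node = {name: batch + tuple(shape) for name, shape in node_shapes.items()}
    graph = tuple(graphs_shape) + (num_nodes, num_nodes)
    return per_node, graph


per_node, graph = expected_shapes(
    samples_shape=(100,),   # 100 samples
    functions_shape=(3,),   # e.g. 3 different interventions
    graphs_shape=(5,),      # 5 candidate graphs
    node_shapes={"x": (1,), "y": (2,)},
)
print(per_node)  # {'x': (100, 3, 5, 1), 'y': (100, 3, 5, 2)}
print(graph)     # (5, 2, 2)
```

Passing an empty tuple for functions_shape or graphs_shape gives the shapes for the unbatched case.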

Parameters:
samples: tensordict.TensorDict

Dictionary of variable samples, each of shape batch_shape_x + batch_shape_f + batch_shape_g + [node shape].

graphs: torch.Tensor

Tensor of shape batch_shape_g + [nodes, nodes].

Returns:

Dictionary of torch.Tensors of shape batch_shape_x + batch_shape_f + batch_shape_g + [node shape].
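To make the input/output contract more concrete, here is a minimal sketch of a toy linear functional relationship that follows the same batching convention using plain PyTorch broadcasting. It is not DECIEmbedNN and does not call causica: toy_linear_relationship and its weights argument are hypothetical, all nodes are assumed to be scalars stacked in a single tensor instead of a TensorDict, and graphs[..., i, j] = 1 is assumed to mean that node i is a parent of node j.

```python
import torch


def toy_linear_relationship(x: torch.Tensor, graphs: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a functional relationship (not the causica implementation).

    x:       samples_shape + functions_shape + graphs_shape + [nodes]
    graphs:  graphs_shape + [nodes, nodes]
    weights: functions_shape + [nodes, nodes], one weight matrix per function
    """
    # Insert a graph-batch axis into the weights so they broadcast against the graphs:
    # functions_shape + [1, nodes, nodes] * graphs_shape + [nodes, nodes]
    masked = weights.unsqueeze(-3) * graphs
    # Each child j receives sum_i x_i * masked[i, j]; matmul broadcasts the batch axes.
    return (x.unsqueeze(-2) @ masked).squeeze(-2)


nodes = 3
x = torch.ones(4, 3, 2, nodes)            # 4 samples, 3 functions, 2 graphs
graphs = torch.stack(
    [
        torch.zeros(nodes, nodes),                     # empty graph
        torch.triu(torch.ones(nodes, nodes), diagonal=1),  # fully connected DAG
    ]
)                                          # shape [2, nodes, nodes]
weights = torch.ones(3, nodes, nodes)      # one weight matrix per function

out = toy_linear_relationship(x, graphs, weights)
print(out.shape)  # same shape as x
```

The output keeps the full batch shape of the input samples, while the graphs are shared across the sample and function axes.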