causica.functional_relationships.deci_functional_relationships
Module Contents
Classes
DECIEmbedFunctionalRelationships: a FunctionalRelationships that wraps the DECIEmbedNN module.
- class causica.functional_relationships.deci_functional_relationships.DECIEmbedFunctionalRelationships(shapes: dict[str, torch.Size], embedding_size: int, out_dim_g: int, num_layers_g: int, num_layers_zeta: int)
Bases: causica.functional_relationships.functional_relationships.FunctionalRelationships
A FunctionalRelationships that wraps the DECIEmbedNN module.
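The constructor takes shapes, a mapping from each variable name to its shape. A wrapper around a single network presumably has to lay these variables out side by side in one flat vector and remember which columns belong to which variable; that bookkeeping is an assumption about the internals, not something this page documents. The plain-Python sketch below shows one way to do it. The helper names variable_layout and flatten are hypothetical and not part of causica; tuples stand in for torch.Size (which is a tuple subclass, so the same code accepts it).

```python
from math import prod


def variable_layout(shapes: dict[str, tuple[int, ...]]) -> tuple[dict[str, slice], list[list[int]]]:
    """Hypothetical helper: place variables side by side in one flat vector.

    Returns the slice each variable occupies and a 0/1 group mask with one
    row per variable and one column per flattened entry.
    """
    slices: dict[str, slice] = {}
    start = 0
    for name, shape in shapes.items():
        size = prod(shape)  # number of scalar entries; prod(()) == 1 for a scalar
        slices[name] = slice(start, start + size)
        start += size
    total = start
    # mask[v][c] == 1 iff flattened column c belongs to variable v
    mask = [
        [1 if sl.start <= c < sl.stop else 0 for c in range(total)]
        for sl in slices.values()
    ]
    return slices, mask


def flatten(values: dict[str, list[float]], slices: dict[str, slice]) -> list[float]:
    """Hypothetical helper: concatenate already-flattened variables in layout order."""
    total = max(sl.stop for sl in slices.values())
    out = [0.0] * total
    for name, sl in slices.items():
        assert len(values[name]) == sl.stop - sl.start, f"wrong size for {name}"
        out[sl] = values[name]
    return out


shapes = {"x": (1,), "y": (2,)}
slices, mask = variable_layout(shapes)
flat = flatten({"x": [0.5], "y": [1.0, 2.0]}, slices)
```

Here "x" occupies column 0 and "y" columns 1 and 2, so the mask rows are [1, 0, 0] and [0, 1, 1]. A per-variable mask of this kind is what would let a network that sees a flat vector still treat each multi-dimensional variable as a single node of the graph.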
- forward(samples: tensordict.TensorDict, graphs: torch.Tensor) → tensordict.TensorDict
Calculates the predictions of the children from their parents.
Functional relationships expect samples to have a batch shape ordered as samples_shape, functions_shape, graphs_shape, and the graphs to have the matching batch shape graphs_shape. The functional relationship is then applied to each sample using the corresponding function and graph, which allows different functions (e.g. interventions) or graphs to be processed in a single batch.
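The batching rule above can be made concrete with a toy model. The sketch below is hypothetical and is not how DECIEmbedNN computes its outputs: it substitutes a linear function, child_j = sum_i x_i * A[i][j] * W[i][j], where A is a 0/1 adjacency matrix (A[i][j] = 1 means i is a parent of j) and W holds one weight matrix per function. Nested lists indexed as [samples][functions][graphs][nodes] play the role of the batch shape samples_shape, functions_shape, graphs_shape, and the explicit loops show that batch index (s, f, g) is paired with function f and graph g.

```python
Matrix = list[list[float]]


def batched_linear_predict(
    samples: list[list[list[list[float]]]],  # [samples][functions][graphs][nodes]
    graphs: list[Matrix],                    # [graphs][parent][child], 0/1 adjacency
    weights: list[Matrix],                   # [functions][parent][child], toy parameters
) -> list[list[list[list[float]]]]:
    """Hypothetical linear stand-in for a batched functional relationship."""
    out = []
    for per_sample in samples:
        assert len(per_sample) == len(weights), "functions_shape mismatch"
        out_s = []
        for f, per_function in enumerate(per_sample):
            assert len(per_function) == len(graphs), "graphs_shape mismatch"
            out_f = []
            for g, x in enumerate(per_function):
                adj, w = graphs[g], weights[f]  # function f paired with graph g
                n = len(x)
                # Each child only receives contributions from its parents in graph g.
                out_f.append([sum(x[i] * adj[i][j] * w[i][j] for i in range(n)) for j in range(n)])
            out_s.append(out_f)
        out.append(out_s)
    return out


graphs = [[[0, 1], [0, 0]],   # graph 0: node 0 -> node 1
          [[0, 0], [1, 0]]]   # graph 1: node 1 -> node 0
weights = [[[0, 2], [3, 0]],  # function 0
           [[0, -1], [-1, 0]]]  # function 1
x = [1.0, 10.0]
samples = [[[x, x], [x, x]]]  # batch shape (1, 2, 2), two nodes
preds = batched_linear_predict(samples, graphs, weights)
```

The output keeps the same batch shape as the input. For example, preds[0][0][1] applies function 0 under graph 1, where node 0 is predicted from its single parent, node 1, giving [30.0, 0.0]; the root node's prediction is 0 because it has no parents.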