causica.lightning.data_modules.synthetic_data_module

Module Contents

Classes

SyntheticDataModule

A datamodule to produce datasets and their underlying causal graphs and interventions.

class causica.lightning.data_modules.synthetic_data_module.SyntheticDataModule(sem_samplers: list[causica.data_generation.samplers.sem_sampler.SEMSampler] | collections.abc.Callable[[], list[causica.data_generation.samplers.sem_sampler.SEMSampler]], train_batch_size: int, test_batch_size: int, dataset_size: int, num_interventions: int = 0, num_intervention_samples: int = 100, num_sems: int = 0, batches_per_metaset: int = 1, sample_interventions: bool = False, sample_counterfactuals: bool = False, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = True, prefetch_factor: int = 16)[source]

Bases: pytorch_lightning.LightningDataModule

A datamodule to produce datasets and their underlying causal graphs and interventions.

This datamodule samples synthetic datasets from a list of SEM samplers and returns batches of CausalDataset objects. The underlying tensors are therefore not yet stacked and must be processed by the module consuming the data. Currently, the data module uses the SubsetBatchSampler to draw batches from the different sampled SEMs, ensuring that each batch comes from SEMs with the same number of nodes. This makes batching straightforward, because the tensors within a batch already share the same shape.
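
A minimal construction sketch, using only the required arguments from the signature above; make_sem_samplers is a hypothetical helper that returns a list of configured SEMSampler instances:

# Hypothetical sketch: make_sem_samplers is user code returning list[SEMSampler];
# the argument names mirror the constructor signature documented above.
from causica.lightning.data_modules.synthetic_data_module import SyntheticDataModule

data_module = SyntheticDataModule(
    sem_samplers=make_sem_samplers,  # list[SEMSampler] or a callable returning one
    train_batch_size=8,
    test_batch_size=8,
    dataset_size=256,                # number of synthetic datasets to generate
    num_workers=0,
)

In typical use the instance is handed to a Trainer, which then calls the hooks documented below (see the sketch after train_dataloader()).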

Note

This data module does not currently support DDP, because the batch sampler does not simply consume indices from a regular sampler. In the future, DDP support will be added by seeding each compute node individually so that it generates different data samples, and by adapting the batch sampler so that an epoch has the same length regardless of the number of compute nodes.

This data module might generate duplicate data samples when num_workers > 1, because the individual workers are not seeded separately.
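
For illustration of why unseeded workers can duplicate on-the-fly data (this is a generic PyTorch pattern, not something the module currently applies): a worker_init_fn passed to a DataLoader can give each worker its own seed.

# Generic PyTorch sketch, not part of SyntheticDataModule: seed each dataloader
# worker from torch's per-worker seed so on-the-fly generation differs per worker.
import random

import numpy as np
import torch

def seed_worker(worker_id: int) -> None:
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# DataLoader(..., worker_init_fn=seed_worker) would then apply this per worker.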

prepare_data() → None[source]

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

Warning

DO NOT set state to the model (use setup instead) since this is NOT called on every device

Example:

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In a distributed environment, prepare_data can be called in two ways (using prepare_data_per_node)

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.

  2. Once in total. Only called on GLOBAL_RANK=0.

Example:

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True

# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()

_get_dataset(dataset_size: int) → torch.utils.data.ConcatDataset[source]

Builds causal datasets given the SEM samplers.

Parameters:
dataset_size: int

Number of samples of the causal dataset (i.e., the number of datasets generated).

Returns:

dataset object
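
Conceptually (a simplified stand-in, not the actual implementation), one dataset is created per sampled SEM and the results are concatenated, so the batch sampler can later draw indices that stay within SEMs sharing a node count:

# Illustrative sketch only: TensorDataset stands in for the real CausalDataset.
import torch
from torch.utils.data import ConcatDataset, TensorDataset

samples_per_sem = 32
per_sem_datasets = [
    TensorDataset(torch.randn(samples_per_sem, num_nodes))  # one dataset per sampled SEM
    for num_nodes in (3, 3, 5)
]
dataset = ConcatDataset(per_sem_datasets)  # roughly what _get_dataset returns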

train_dataloader()[source]

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  - download in prepare_data()

  - process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
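
In normal use these dataloader hooks are not called directly; the Trainer requests them during fitting. A generic sketch, where MyLightningModule is a hypothetical model and data_module is the instance constructed earlier:

# Generic Lightning usage sketch; MyLightningModule is a placeholder, not part of causica.
import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1, accelerator="cpu", devices=1)
trainer.fit(MyLightningModule(), datamodule=data_module)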

val_dataloader()[source]

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

test_dataloader()[source]

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see this section.

For data processing use the following pattern:

  - download in prepare_data()

  - process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.