causica.lightning.data_modules.synthetic_data_module¶
Module Contents¶
Classes¶
SyntheticDataModule | A datamodule to produce datasets and their underlying causal graphs and interventions.
class causica.lightning.data_modules.synthetic_data_module.SyntheticDataModule(sem_samplers: list[causica.data_generation.samplers.sem_sampler.SEMSampler] | collections.abc.Callable[[], list[causica.data_generation.samplers.sem_sampler.SEMSampler]], train_batch_size: int, test_batch_size: int, dataset_size: int, num_interventions: int = 0, num_intervention_samples: int = 100, num_sems: int = 0, batches_per_metaset: int = 1, sample_interventions: bool = False, sample_counterfactuals: bool = False, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = True, prefetch_factor: int = 16)[source]¶
Bases: pytorch_lightning.LightningDataModule
A datamodule to produce datasets and their underlying causal graphs and interventions.
This datamodule samples synthetic datasets from a list of SEM samplers and returns batches of CausalDataset objects. This means that the underlying tensors are not yet stacked and will need to be processed by the module using the data. Currently, the data module uses the SubsetBatchSampler to sample batches from the different sampled SEMs, ensuring that each batch is drawn from SEMs with the same number of nodes. This is important because the tensors within a batch then share the same shape, which makes batching easy.
Note
This data module does not currently support DDP, because the batch sampler does not simply consume indices from a regular sampler. In the future, DDP support will be added by seeding each compute node individually so that it generates different data samples, and by adapting the batch sampler to ensure that an epoch has the same length regardless of the number of compute nodes.
This data module might generate duplicate data samples when num_workers > 1, because the individual workers are not individually seeded.
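For orientation, a minimal construction sketch is given below. The keyword arguments mirror the constructor signature documented above; the SEM-sampler factory make_sem_samplers is hypothetical and stands in for whatever SEM samplers your experiment defines.

    # Minimal sketch, not taken from the causica examples: make_sem_samplers is a
    # hypothetical factory returning a list of SEMSampler instances.
    from causica.lightning.data_modules.synthetic_data_module import SyntheticDataModule


    def make_sem_samplers():
        """Hypothetical: build SEM samplers for the graph/node configurations of interest."""
        raise NotImplementedError


    data_module = SyntheticDataModule(
        sem_samplers=make_sem_samplers,  # a pre-built list[SEMSampler] is accepted as well
        train_batch_size=32,
        test_batch_size=64,
        dataset_size=1_000,
        num_interventions=1,
        num_intervention_samples=100,
        sample_interventions=True,
        num_workers=0,  # see the note above about duplicate samples when num_workers > 1
    )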
- prepare_data() → None[source]¶
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
Warning
DO NOT set state to the model (use setup instead) since this is NOT called on every device.
Example:
    def prepare_data(self):
        # good
        download_data()
        tokenize()
        etc()

        # bad
        self.split = data_split
        self.some_state = some_other_state()
In a distributed environment, prepare_data can be called in two ways (using prepare_data_per_node):
- Once per node. This is the default and is only called on LOCAL_RANK=0.
- Once in total. Only called on GLOBAL_RANK=0.
Example:
    # DEFAULT
    # called once per node on LOCAL_RANK=0 of that node
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = True


    # call on GLOBAL_RANK=0 (great for shared file systems)
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = False
This is called before requesting the dataloaders:
    model.prepare_data()
    initialize_distributed()
    model.setup(stage)
    model.train_dataloader()
    model.val_dataloader()
    model.test_dataloader()
    model.predict_dataloader()
- _get_dataset(dataset_size: int) → torch.utils.data.ConcatDataset[source]¶
Builds causal datasets given the SEM samplers.
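The sketch below illustrates the general pattern such a method can follow; it is not the library implementation. SingleSEMDataset is a hypothetical placeholder for the per-sampler dataset built internally, and the assumption that SEMSampler exposes a sample() method is marked in the code.

    # Schematic sketch only, not the library implementation. SingleSEMDataset is a
    # hypothetical per-sampler dataset; the call sem_sampler.sample() is an assumption.
    from torch.utils.data import ConcatDataset, Dataset


    class SingleSEMDataset(Dataset):
        """Hypothetical dataset serving `size` items drawn from one sampled SEM."""

        def __init__(self, sem_sampler, size: int):
            self.sem = sem_sampler.sample()  # assumption: SEMSampler exposes sample()
            self.size = size

        def __len__(self) -> int:
            return self.size

        def __getitem__(self, idx):
            raise NotImplementedError("Draw observational/interventional data from self.sem here.")


    def get_dataset_sketch(sem_samplers, dataset_size: int) -> ConcatDataset:
        # Concatenation keeps the index ranges of each SEM contiguous, which is what
        # lets a batch sampler draw every batch from SEMs with the same number of nodes.
        return ConcatDataset([SingleSEMDataset(s, dataset_size) for s in sem_samplers])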
- train_dataloader()[source]¶
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- fit()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
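Continuing the construction sketch above, the dataloaders produced by this datamodule are typically consumed through a Lightning Trainer; model below is a hypothetical LightningModule that knows how to handle CausalDataset batches.

    # Minimal usage sketch; `model` is a hypothetical LightningModule.
    import pytorch_lightning as pl

    trainer = pl.Trainer(max_epochs=1, accelerator="cpu", devices=1)  # DDP is not supported, see the note above
    trainer.fit(model, datamodule=data_module)

    # Alternatively, inspect a single training batch directly:
    data_module.prepare_data()
    data_module.setup(stage="fit")
    batch = next(iter(data_module.train_dataloader()))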
- val_dataloader()[source]¶
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
- fit()
- validate()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
- test_dataloader()[source]¶
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- test()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.