Skip to content

DataModules

pdearena.data.datamodule

PDEDataModule

Bases: LightningDataModule

Defines the standard dataloading process for PDE data.

Does not support generalization to different parameterizations or time. Consider using pdearena.data.cond_datamodule.CondPDEDataModule for that.

Parameters:

Name Type Description Default
task str

The task to be solved.

required
data_dir str

The path to the data directory.

required
time_history int

The number of time steps in the past.

required
time_future int

The number of time steps in the future.

required
time_gap int

The number of time steps between the past and the future to be skipped.

required
pde dict

The PDE to be solved.

required
batch_size int

The batch size.

required
pin_memory bool

Whether to pin memory.

required
num_workers int

The number of workers. When using more than one worker on multi-GPU systems, make sure the number of shards is divisible by the number of workers times the number of GPUs.

required
train_limit_trajectories int

The number of trajectories to be used for training. This is from each shard.

required
valid_limit_trajectories int

The number of trajectories to be used for validation. This is from each shard.

required
test_limit_trajectories int

The number of trajectories to be used for testing. This is from each shard.

required
usegrid bool

Whether to use a grid. Defaults to False.

False
Source code in pdearena/data/datamodule.py
class PDEDataModule(LightningDataModule):
    """Defines the standard dataloading process for PDE data.

    Does not support generalization to different parameterizations or time.
    Consider using [pdearena.data.cond_datamodule.CondPDEDataModule][] for that.

    Args:
        task (str): The task to be solved. Used as the key into ``DATAPIPE_REGISTRY``.
        data_dir (str): The path to the data directory.
        time_history (int): The number of time steps in the past.
        time_future (int): The number of time steps in the future.
        time_gap (int): The number of time steps between the past and the future to be skipped.
        pde (PDEDataConfig): The PDE to be solved. Not saved into ``hparams``; kept as ``self.pde``.
        batch_size (int): The batch size.
        pin_memory (bool): Whether to pin memory.
        num_workers (int): The number of workers. When using more than one worker on multi-GPU
            systems, make sure the number of shards is divisible by the number of workers times
            the number of GPUs.
        train_limit_trajectories (int): The number of trajectories to be used for training. This is from each shard.
        valid_limit_trajectories (int): The number of trajectories to be used for validation. This is from each shard.
        test_limit_trajectories (int): The number of trajectories to be used for testing. This is from each shard.
        usegrid (bool, optional): Whether to use a grid. Defaults to False.
    """

    def __init__(
        self,
        task: str,
        data_dir: str,
        time_history: int,
        time_future: int,
        time_gap: int,
        pde: PDEDataConfig,
        batch_size: int,
        pin_memory: bool,
        num_workers: int,
        train_limit_trajectories: int,
        valid_limit_trajectories: int,
        test_limit_trajectories: int,
        usegrid: bool = False,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.pde = pde

        # ``pde`` is excluded from the saved hyperparameters and kept as a plain
        # attribute; every other constructor argument is read via ``self.hparams``.
        self.save_hyperparameters(ignore="pde", logger=False)

    def _build_dp(self, dp_cls, limit_trajectories: int):
        """Instantiate one registered datapipe with this module's shared settings."""
        return dp_cls(
            pde=self.pde,
            data_path=self.data_dir,
            limit_trajectories=limit_trajectories,
            usegrid=self.hparams.usegrid,
            time_history=self.hparams.time_history,
            time_future=self.hparams.time_future,
            time_gap=self.hparams.time_gap,
        )

    def _build_loader(self, dataset, collate_fn, shuffle: bool = False, drop_last: bool = False) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's worker/memory/batch settings."""
        return DataLoader(
            dataset=dataset,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
        )

    def setup(self, stage: Optional[str] = None):
        """Build the train/valid/test datapipes registered for ``hparams.task``.

        ``stage`` is accepted for the Lightning hook signature but is not used:
        all splits are built on every call.
        """
        dps = DATAPIPE_REGISTRY[self.hparams.task]
        self.train_dp = self._build_dp(dps["train"], self.hparams.train_limit_trajectories)
        # ``dps["valid"]`` and ``dps["test"]`` each hold two entries: index 0 feeds
        # the one-step (timestep) loader and index 1 the rollout loader.
        self.valid_dp1 = self._build_dp(dps["valid"][0], self.hparams.valid_limit_trajectories)
        self.valid_dp2 = self._build_dp(dps["valid"][1], self.hparams.valid_limit_trajectories)
        self.test_dp_onestep = self._build_dp(dps["test"][0], self.hparams.test_limit_trajectories)
        self.test_dp = self._build_dp(dps["test"][1], self.hparams.test_limit_trajectories)

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete final batches are dropped."""
        return self._build_loader(self.train_dp, collate_fn_cat, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return ``[timestep_loader, rollout_loader]`` for validation."""
        timestep_loader = self._build_loader(self.valid_dp1, collate_fn_cat)
        # TODO: might need to reduce the batch size for the rollout loader
        rollout_loader = self._build_loader(self.valid_dp2, collate_fn_stack)
        return [timestep_loader, rollout_loader]

    def test_dataloader(self):
        """Return ``[timestep_loader, rollout_loader]`` for testing."""
        rollout_loader = self._build_loader(self.test_dp, collate_fn_stack)
        timestep_loader = self._build_loader(self.test_dp_onestep, collate_fn_cat)
        return [timestep_loader, rollout_loader]

pdearena.data.cond_datamodule

CondPDEDataModule

Bases: LightningDataModule

Defines the dataloading process for conditioned PDE data.

Supports generalization experiments.

Parameters:

Name Type Description Default
task str

Name of the task.

required
data_dir str

Path to the data directory.

required
pde dict

Dictionary containing the PDE class and its arguments.

required
batch_size int

Batch size.

required
pin_memory bool

Whether to pin memory.

required
num_workers int

Number of workers.

required
train_limit_trajectories int

Number of trajectories to use for training.

required
valid_limit_trajectories int

Number of trajectories to use for validation.

required
test_limit_trajectories int

Number of trajectories to use for testing.

required
eval_dts List[int]

List of timesteps to use for evaluation. Defaults to [1, 2, 4, 8, 16].

[1, 2, 4, 8, 16]
usegrid bool

Whether to use the grid. Defaults to False.

False
Source code in pdearena/data/cond_datamodule.py
class CondPDEDataModule(LightningDataModule):
    """Defines the dataloading process for conditioned PDE data.

    Supports generalization experiments.

    Args:
        task (str): Name of the task. Used as the key into ``DATAPIPE_REGISTRY``.
        data_dir (str): Path to the data directory.
        pde (PDEDataConfig): The PDE configuration. Not saved into ``hparams``; kept as ``self.pde``.
        batch_size (int): Batch size.
        pin_memory (bool): Whether to pin memory.
        num_workers (int): Number of workers.
        train_limit_trajectories (int): Number of trajectories to use for training.
        valid_limit_trajectories (int): Number of trajectories to use for validation.
        test_limit_trajectories (int): Number of trajectories to use for testing.
        eval_dts (List[int], optional): Timestep strides to evaluate on. One validation and
            one test timestep datapipe is built per entry, passed as ``delta_t``.
            Defaults to [1, 2, 4, 8, 16].
        usegrid (bool, optional): Whether to use the grid for the training and rollout-test
            datapipes. Defaults to False.
    """

    def __init__(
        self,
        task: str,
        data_dir: str,
        pde: PDEDataConfig,
        batch_size: int,
        pin_memory: bool,
        num_workers: int,
        train_limit_trajectories: int,
        valid_limit_trajectories: int,
        test_limit_trajectories: int,
        eval_dts: List[int] = [1, 2, 4, 8, 16],
        usegrid: bool = False,
    ):
        super().__init__()
        # The default ``eval_dts`` list is one object shared by every call. Rebind
        # the local to a private copy *before* ``save_hyperparameters`` (Lightning
        # collects init args from this frame's locals). This way neither
        # ``self.eval_dts`` nor ``hparams`` aliases the shared default, and
        # mutating one instance cannot leak into later instances.
        eval_dts = list(eval_dts)
        self.data_dir = data_dir
        self.eval_dts = eval_dts
        self.pde = pde
        # ``pde`` is excluded from the saved hyperparameters and kept as a plain attribute.
        self.save_hyperparameters(ignore="pde", logger=False)

    def _build_dp(self, dp_cls, limit_trajectories: int, usegrid: bool, **extra_kwargs):
        """Instantiate one registered datapipe with time_history=1, time_future=1, time_gap=0."""
        return dp_cls(
            pde=self.pde,
            data_path=self.data_dir,
            limit_trajectories=limit_trajectories,
            usegrid=usegrid,
            time_history=1,
            time_future=1,
            time_gap=0,
            **extra_kwargs,
        )

    def _build_loader(self, dataset, collate_fn, shuffle: bool = False, drop_last: bool = False) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's worker/memory/batch settings."""
        return DataLoader(
            dataset=dataset,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
        )

    def setup(self, stage=None):
        """Build the train, per-``delta_t`` valid/test, and rollout-test datapipes.

        ``stage`` is accepted for the Lightning hook signature but is not used.
        """
        dps = DATAPIPE_REGISTRY[self.hparams.task]
        self.train_dp = self._build_dp(
            dps["train"], self.hparams.train_limit_trajectories, self.hparams.usegrid
        )
        # NOTE(review): the per-``delta_t`` timestep datapipes below always use
        # ``usegrid=False``, even when ``hparams.usegrid`` is True for training.
        # Confirm this is intentional.
        self.valid_dps = [
            self._build_dp(dps["valid"], self.hparams.valid_limit_trajectories, False, delta_t=dt)
            for dt in self.eval_dts
        ]

        # ``dps["test"][1]`` is the rollout datapipe; ``dps["test"][0]`` is built once per ``delta_t``.
        self.test_dp = self._build_dp(
            dps["test"][1], self.hparams.test_limit_trajectories, self.hparams.usegrid
        )
        self.test_dps = [
            self._build_dp(dps["test"][0], self.hparams.test_limit_trajectories, False, delta_t=dt)
            for dt in self.eval_dts
        ]

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete final batches are dropped."""
        return self._build_loader(self.train_dp, collate_fn_cat, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return one timestep loader per entry of ``eval_dts``."""
        return [self._build_loader(dp, collate_fn_cat) for dp in self.valid_dps]

    def test_dataloader(self):
        """Return ``[rollout_loader, *timestep_loaders]`` (one timestep loader per ``eval_dts`` entry)."""
        rollout_loader = self._build_loader(self.test_dp, collate_fn_stack)
        timestep_loaders = [self._build_loader(dp, collate_fn_cat) for dp in self.test_dps]
        return [rollout_loader] + timestep_loaders