Implementing a Custom Dataset Provider

Abstract base classes (ABCs) define a blueprint for a class. They declare the methods (and, optionally, attributes) that subclasses must provide, without necessarily implementing them. This makes ABCs useful for enforcing a consistent interface: a subclass that does not implement every abstract method cannot be instantiated, and code written against the base class works with any conforming implementation.
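To make the enforcement concrete, here is a minimal sketch using the standard-library abc module. MiniDatasetProvider and IncompleteProvider are hypothetical names, and abc.ABC stands in for the EnforceOverrides base used in this tutorial. A subclass that leaves any abstract method unimplemented raises TypeError as soon as it is instantiated; the exact message varies across Python versions.

```python
from abc import ABC, abstractmethod
from typing import Any


class MiniDatasetProvider(ABC):
    """Hypothetical stand-in for DatasetProvider, built on abc.ABC."""

    @abstractmethod
    def get_train_dataset(self) -> Any: ...

    @abstractmethod
    def get_val_dataset(self) -> Any: ...

    @abstractmethod
    def get_test_dataset(self) -> Any: ...


class IncompleteProvider(MiniDatasetProvider):
    # Only one of the three abstract methods is implemented
    def get_train_dataset(self) -> Any:
        return [1, 2, 3]


try:
    IncompleteProvider()
    instantiated = True
except TypeError as error:
    # Raised at instantiation time, not when a missing method is first called
    instantiated = False
    print(f"TypeError: {error}")
```

The methods that are still missing are recorded on the class itself in `IncompleteProvider.__abstractmethods__`.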

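First, we define a boilerplate for the DatasetProvider class, which mirrors the one implemented in the archai.api.dataset_provider module. It declares one abstract method per data split and inherits from EnforceOverrides, so any subclass method that overrides one of them must be decorated with @overrides.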

from abc import abstractmethod
from typing import Any

from overrides import EnforceOverrides


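class DatasetProvider(EnforceOverrides):
    """Abstract interface for providers of training, validation and test datasets."""

    def __init__(self) -> None:
        pass

    @abstractmethod
    def get_train_dataset(self) -> Any:
        """Return the training dataset."""

    @abstractmethod
    def get_val_dataset(self) -> Any:
        """Return the validation dataset."""

    @abstractmethod
    def get_test_dataset(self) -> Any:
        """Return the test dataset."""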
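Because callers rely only on these three getters, the same code can consume any concrete provider. The following self-contained sketch illustrates that idea under a few assumptions: it replaces EnforceOverrides with the standard-library abc.ABC so it runs without the overrides package, and BaseProvider, ListDatasetProvider and count_examples are hypothetical names used only for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class BaseProvider(ABC):
    """Hypothetical stdlib-only mirror of the DatasetProvider interface."""

    @abstractmethod
    def get_train_dataset(self) -> Any: ...

    @abstractmethod
    def get_val_dataset(self) -> Any: ...

    @abstractmethod
    def get_test_dataset(self) -> Any: ...


class ListDatasetProvider(BaseProvider):
    """Serves pre-split, in-memory Python lists."""

    def __init__(self, train: List[Any], val: List[Any], test: Optional[List[Any]] = None) -> None:
        self.train = train
        self.val = val
        self.test = test

    def get_train_dataset(self) -> List[Any]:
        return self.train

    def get_val_dataset(self) -> List[Any]:
        return self.val

    def get_test_dataset(self) -> List[Any]:
        # Fall back to the validation split when no test split was given
        return self.test if self.test is not None else self.val


def count_examples(provider: BaseProvider) -> Dict[str, int]:
    """Works with any provider, since it only calls the abstract getters."""
    return {
        "train": len(provider.get_train_dataset()),
        "val": len(provider.get_val_dataset()),
        "test": len(provider.get_test_dataset()),
    }


provider = ListDatasetProvider(train=list(range(8)), val=[8, 9])
print(count_examples(provider))  # {'train': 8, 'val': 2, 'test': 2}
```

Swapping ListDatasetProvider for any other subclass of BaseProvider would leave count_examples unchanged, which is the practical payoff of coding against the abstract interface.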

Torchvision-based Dataset Provider

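For a custom dataset provider, inheriting from this ABC guarantees that every concrete provider implements the three getter methods, giving callers a consistent interface for loading data. In this example, we implement a Torchvision-based dataset provider that supports MNIST and CIFAR-10, downloads the data to a root directory (dataroot by default) on first use, and applies ToTensor() when no transform is given: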

from typing import Callable, Optional

from overrides import overrides
from torch.utils.data import Dataset
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor


class TorchvisionDatasetProvider(DatasetProvider):
    SUPPORTED_DATASETS = {
        "mnist": MNIST,
        "cifar10": CIFAR10
    }

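    def __init__(self, dataset: str, root: str = "dataroot") -> None:
        super().__init__()

        # Fail early with a clear message instead of a KeyError on first access
        if dataset not in self.SUPPORTED_DATASETS:
            raise ValueError(
                f"`dataset` must be one of {list(self.SUPPORTED_DATASETS)}, but got `{dataset}`."
            )

        self.dataset = dataset
        self.root = root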

    @overrides
    def get_train_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        return self.SUPPORTED_DATASETS[self.dataset](
            self.root,
            train=True,
            download=True,
            transform=transform or ToTensor(),
            target_transform=target_transform,
        )

    @overrides
    def get_val_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        return self.SUPPORTED_DATASETS[self.dataset](
            self.root,
            train=False,
            download=True,
            transform=transform or ToTensor(),
            target_transform=target_transform,
        )

    @overrides
    def get_test_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        print(f"Testing set not available for `{self.dataset}`. Returning validation set ...")
        return self.get_val_dataset(transform=transform, target_transform=target_transform)

Using the Dataset Provider

Finally, we instantiate the provider and call its methods to retrieve the datasets:

dataset_provider = TorchvisionDatasetProvider("mnist")

train_dataset = dataset_provider.get_train_dataset()
val_dataset = dataset_provider.get_val_dataset()
print(train_dataset, val_dataset)

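# MNIST only ships a training and a test split, and the provider exposes the latter as the
# validation set, so `get_test_dataset()` falls back to returning the validation set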
test_dataset = dataset_provider.get_test_dataset()
print(test_dataset)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataroot\MNIST\raw\train-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\train-images-idx3-ubyte.gz to dataroot\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataroot\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\train-labels-idx1-ubyte.gz to dataroot\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw

Dataset MNIST
    Number of datapoints: 60000
    Root location: dataroot
    Split: Train
    StandardTransform
Transform: ToTensor() Dataset MNIST
    Number of datapoints: 10000
    Root location: dataroot
    Split: Test
    StandardTransform
Transform: ToTensor()
Testing set not available for `mnist`. Returning validation set ...
Dataset MNIST
    Number of datapoints: 10000
    Root location: dataroot
    Split: Test
    StandardTransform
Transform: ToTensor()
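As the output shows, get_val_dataset and get_test_dataset return the same 10,000 images here. If distinct validation and test sets are needed, one common option is to hold out a slice of the 60,000 training images for validation and keep the torchvision test split for testing. The helper below is a hypothetical, stdlib-only sketch that computes a reproducible index split; the function name and its defaults are assumptions, and the resulting index lists could then be wrapped with torch.utils.data.Subset (not shown).

```python
import random
from typing import List, Tuple


def holdout_split(num_examples: int, val_fraction: float = 0.1, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: shuffle indices and carve off a validation slice."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("`val_fraction` must be in the open interval (0, 1).")

    indices = list(range(num_examples))
    # A dedicated Random instance keeps the split reproducible without touching global state
    random.Random(seed).shuffle(indices)

    num_val = round(num_examples * val_fraction)
    # Returns (train_indices, val_indices); the two lists never overlap
    return indices[num_val:], indices[:num_val]


train_idx, val_idx = holdout_split(60000, val_fraction=0.1)
print(len(train_idx), len(val_idx))  # 54000 6000
```

Fixing the seed means every run (and every provider instance) sees the same validation subset, which keeps evaluation results comparable.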