Implementing a Custom Dataset Provider#
Abstract base classes (ABCs) define a blueprint for a class, specifying its methods and attributes but not their implementation. They are important for enforcing a consistent interface: they impose a set of requirements on implementing classes and make it easier to write code that works with multiple implementations.
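As a quick stdlib-only illustration of this idea, consider a hypothetical `Provider` blueprint built on `abc.ABC` (the Archai base class below relies on `EnforceOverrides` instead, so this is only an analogy; the `Provider` and `ListProvider` names are made up for the sketch):

```python
from abc import ABC, abstractmethod


class Provider(ABC):
    """A blueprint: declares required methods without implementing them."""

    @abstractmethod
    def get_train_dataset(self):
        ...


class ListProvider(Provider):
    """A concrete implementation that satisfies the blueprint."""

    def get_train_dataset(self):
        return [1, 2, 3]


# The ABC itself cannot be instantiated; the concrete subclass can.
try:
    Provider()
except TypeError:
    print("Provider is abstract and cannot be instantiated")

print(ListProvider().get_train_dataset())  # [1, 2, 3]
```

Any code written against `Provider` then works with every concrete implementation, which is exactly the property the dataset provider below exploits.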
First, we define a boilerplate for the DatasetProvider class, which is the same as the one implemented in the archai.api.dataset_provider module.
[1]:
from abc import abstractmethod
from typing import Any

from overrides import EnforceOverrides


class DatasetProvider(EnforceOverrides):
    def __init__(self) -> None:
        pass

    @abstractmethod
    def get_train_dataset(self) -> Any:
        pass

    @abstractmethod
    def get_val_dataset(self) -> Any:
        pass

    @abstractmethod
    def get_test_dataset(self) -> Any:
        pass
Torchvision-based Dataset Provider#
In the context of a custom dataset provider, using ABCs can help ensure that the provider implements the required methods and provides a consistent interface for loading and processing data. In this example, we will implement a Torchvision-based dataset provider, as follows:
[2]:
from typing import Callable, Optional

from overrides import overrides
from torch.utils.data import Dataset
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor


class TorchvisionDatasetProvider(DatasetProvider):
    SUPPORTED_DATASETS = {
        "mnist": MNIST,
        "cifar10": CIFAR10,
    }

    def __init__(self, dataset: str, root: Optional[str] = "dataroot") -> None:
        super().__init__()

        self.dataset = dataset
        self.root = root

    @overrides
    def get_train_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        return self.SUPPORTED_DATASETS[self.dataset](
            self.root,
            train=True,
            download=True,
            transform=transform or ToTensor(),
            target_transform=target_transform,
        )

    @overrides
    def get_val_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        return self.SUPPORTED_DATASETS[self.dataset](
            self.root,
            train=False,
            download=True,
            transform=transform or ToTensor(),
            target_transform=target_transform,
        )

    @overrides
    def get_test_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        print(f"Testing set not available for `{self.dataset}`. Returning validation set ...")
        return self.get_val_dataset(transform=transform, target_transform=target_transform)
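Note that an unsupported dataset name only surfaces as a bare KeyError once one of the getters indexes SUPPORTED_DATASETS. A small validation helper (a hypothetical, stdlib-only sketch; not part of Archai) shows how a provider could fail fast with a descriptive error instead:

```python
SUPPORTED_DATASETS = ("mnist", "cifar10")  # keys mirrored from the provider above


def validate_dataset_name(dataset: str) -> str:
    """Return the name if supported, else raise a descriptive error."""
    if dataset not in SUPPORTED_DATASETS:
        supported = ", ".join(SUPPORTED_DATASETS)
        raise ValueError(f"`{dataset}` is not supported. Available datasets: {supported}.")
    return dataset


print(validate_dataset_name("mnist"))  # mnist
```

Calling such a check in `__init__` would move the failure from dataset-retrieval time to construction time.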
Using the Dataset Provider#
Finally, one needs to call the implemented methods to retrieve the datasets, as follows:
[3]:
dataset_provider = TorchvisionDatasetProvider("mnist")
train_dataset = dataset_provider.get_train_dataset()
val_dataset = dataset_provider.get_val_dataset()
print(train_dataset, val_dataset)
# As there is no `test_dataset` available, it returns the validation set
test_dataset = dataset_provider.get_test_dataset()
print(test_dataset)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataroot\MNIST\raw\train-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\train-images-idx3-ubyte.gz to dataroot\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataroot\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\train-labels-idx1-ubyte.gz to dataroot\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw
Dataset MNIST
Number of datapoints: 60000
Root location: dataroot
Split: Train
StandardTransform
Transform: ToTensor() Dataset MNIST
Number of datapoints: 10000
Root location: dataroot
Split: Test
StandardTransform
Transform: ToTensor()
Testing set not available for `mnist`. Returning validation set ...
Dataset MNIST
Number of datapoints: 10000
Root location: dataroot
Split: Test
StandardTransform
Transform: ToTensor()
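Since the provider returns standard torch.utils.data.Dataset objects, they plug directly into a DataLoader for batching and shuffling. A minimal sketch, using a synthetic TensorDataset as a stand-in for MNIST so that no download is required:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the MNIST training set: 64 fake 1x28x28 images.
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
fake_train_dataset = TensorDataset(images, labels)

# Batch and shuffle exactly as one would with `get_train_dataset()`'s output.
train_loader = DataLoader(fake_train_dataset, batch_size=16, shuffle=True)

batch_images, batch_labels = next(iter(train_loader))
print(batch_images.shape)  # torch.Size([16, 1, 28, 28])
print(batch_labels.shape)  # torch.Size([16])
```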