Creating CV-based Data#

In this notebook, we will create CV-based data using PyTorch’s torchvision library, which provides access to popular CV datasets such as CIFAR10, MNIST, and ImageNet.

Loading the Data#

The first step is to create an instance of the MnistDatasetProvider, which pre-loads the dataset and offers three methods to retrieve its splits: get_train_dataset(), get_val_dataset(), and get_test_dataset().

[1]:
from archai.datasets.cv.mnist_dataset_provider import MnistDatasetProvider

dataset_provider = MnistDatasetProvider()

train_dataset = dataset_provider.get_train_dataset()
val_dataset = dataset_provider.get_val_dataset()
print(train_dataset, val_dataset)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataroot\MNIST\raw\train-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\train-images-idx3-ubyte.gz to dataroot\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataroot\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\train-labels-idx1-ubyte.gz to dataroot\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw

Dataset MNIST
    Number of datapoints: 60000
    Root location: dataroot
    Split: Train
    StandardTransform
Transform: ToTensor() Dataset MNIST
    Number of datapoints: 10000
    Root location: dataroot
    Split: Test
    StandardTransform
Transform: ToTensor()
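The returned objects are standard torchvision datasets, so they plug directly into torch.utils.data.DataLoader for batched iteration. The sketch below uses a synthetic TensorDataset as a stand-in for MNIST (so it runs without downloading anything); the dataset returned by get_train_dataset() can be batched the same way.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 100 fake 1x28x28 images with integer labels in [0, 10)
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
train_dataset = TensorDataset(images, labels)

# Any map-style dataset (including the one returned by
# MnistDatasetProvider.get_train_dataset()) can be wrapped like this
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

batch_images, batch_labels = next(iter(train_loader))
print(batch_images.shape)  # torch.Size([32, 1, 28, 28])
```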

Transforming the Data#

Additionally, the torchvision library supports various data augmentation techniques, such as random cropping, flipping, and rotation, which increase the size and diversity of the dataset and lead to better model generalization and performance. Transforms can be applied as a post-processing function or by passing the transform argument when retrieving the dataset.

By default, every dataset retrieved by the dataset providers uses the ``ToTensor()`` transform.

[2]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

transformed_train_dataset = dataset_provider.get_train_dataset(transform=transform)
print(transformed_train_dataset)
Dataset MNIST
    Number of datapoints: 60000
    Root location: dataroot
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )