Creating CV-based Data#
In this notebook, we will create CV-based data using PyTorch’s torchvision library, which provides access to popular CV datasets, such as CIFAR10, MNIST, and ImageNet, among others.
Loading the Data#
The first step is to create an instance of the ``MnistDatasetProvider``, which pre-loads the dataset and offers three methods to retrieve the splits: ``get_train_dataset()``, ``get_val_dataset()`` and ``get_test_dataset()``.
[1]:
from archai.datasets.cv.mnist_dataset_provider import MnistDatasetProvider
dataset_provider = MnistDatasetProvider()
train_dataset = dataset_provider.get_train_dataset()
val_dataset = dataset_provider.get_val_dataset()
print(train_dataset, val_dataset)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataroot\MNIST\raw\train-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\train-images-idx3-ubyte.gz to dataroot\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataroot\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\train-labels-idx1-ubyte.gz to dataroot\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-images-idx3-ubyte.gz to dataroot\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting dataroot\MNIST\raw\t10k-labels-idx1-ubyte.gz to dataroot\MNIST\raw
Dataset MNIST
Number of datapoints: 60000
Root location: dataroot
Split: Train
StandardTransform
Transform: ToTensor() Dataset MNIST
Number of datapoints: 10000
Root location: dataroot
Split: Test
StandardTransform
Transform: ToTensor()
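The objects returned by the provider behave like plain indexable collections of (image, label) pairs: ``len()`` gives the number of datapoints, and indexing returns one sample. The sketch below illustrates that interface with a hypothetical ``ToyDataset`` stand-in (plain Python, no torch download), not the actual MNIST class:

```python
# Sketch of the mapping-style interface exposed by torchvision
# datasets such as MNIST. ToyDataset is a hypothetical stand-in;
# with real data you would index train_dataset directly, e.g.
# image, label = train_dataset[0].
class ToyDataset:
    def __init__(self, samples):
        self.samples = samples  # list of (image, label) pairs

    def __len__(self):
        # Number of datapoints (60000 for the MNIST train split).
        return len(self.samples)

    def __getitem__(self, idx):
        # Return a single (image, label) pair.
        return self.samples[idx]

dataset = ToyDataset([([0.0, 0.5], 7), ([1.0, 0.2], 3)])
image, label = dataset[0]
print(len(dataset), label)  # → 2 7
```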
Transforming the Data#
Additionally, the torchvision library supports various data augmentation techniques, such as random cropping, flipping, and rotation, which increase the effective size and diversity of the dataset and can lead to better model generalization and performance. These can be applied as a post-processing function or by passing the ``transform`` argument when retrieving the dataset.
By default, every dataset retrieved by the dataset providers uses the ``ToTensor()`` transform.
[2]:
from torchvision import transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
transformed_train_dataset = dataset_provider.get_train_dataset(transform=transform)
print(transformed_train_dataset)
Dataset MNIST
Number of datapoints: 60000
Root location: dataroot
Split: Train
StandardTransform
Transform: Compose(
ToTensor()
Normalize(mean=(0.1307,), std=(0.3081,))
)
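The ``Normalize`` step above applies a simple per-channel formula to each pixel: subtract the channel mean and divide by the channel standard deviation, i.e. ``(x - mean) / std``. A minimal sketch of that arithmetic in plain Python (the ``normalize`` helper is illustrative, not part of torchvision):

```python
# What Normalize((0.1307,), (0.3081,)) computes per pixel value,
# shown without a torch dependency. 0.1307 and 0.3081 are the
# conventional mean and std of the MNIST training images.
mean, std = 0.1307, 0.3081

def normalize(pixel):
    """Channel-wise normalization: (x - mean) / std."""
    return (pixel - mean) / std

# A pixel equal to the dataset mean maps to exactly 0.
print(normalize(0.1307))  # → 0.0
```

After this transform, the inputs are roughly zero-centered with unit variance, which typically helps optimization converge faster.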