# Source code for archai.datasets.cv.mnist_dataset_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Optional

from overrides import overrides
from torch.utils.data import Dataset
from torchvision.datasets import KMNIST, MNIST, QMNIST, FashionMNIST
from torchvision.transforms import ToTensor

from archai.api.dataset_provider import DatasetProvider
from archai.common.ordered_dict_logger import OrderedDictLogger

# Module-level logger; its `source` is set to this module's name.
logger = OrderedDictLogger(source=__name__)


class MnistDatasetProvider(DatasetProvider):
    """MNIST-based dataset provider.

    Maps a dataset name to one of the torchvision classes listed in
    `SUPPORTED_DATASETS` and instantiates it on demand for each split.

    """

    # Dataset name -> torchvision dataset class used to build it.
    SUPPORTED_DATASETS = {
        "fashion_mnist": FashionMNIST,
        "kmnist": KMNIST,
        "mnist": MNIST,
        "qmnist": QMNIST,
    }

    def __init__(
        self,
        dataset: Optional[str] = "mnist",
        root: Optional[str] = "dataroot",
    ) -> None:
        """Initialize MNIST-based dataset provider.

        Args:
            dataset: Name of dataset. Must be a key of `SUPPORTED_DATASETS`.
            root: Root directory where the dataset is saved.

        Raises:
            AssertionError: If `dataset` is not a key of `SUPPORTED_DATASETS`.

        """

        super().__init__()

        # NOTE(review): `assert` is skipped under `python -O`; it is kept as-is
        # so that the exception type seen by existing callers does not change.
        assert dataset in self.SUPPORTED_DATASETS, f"`dataset` should be one of: {list(self.SUPPORTED_DATASETS)}"

        self.dataset = dataset
        self.root = root

    def _load_split(
        self,
        train: bool,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Instantiate the configured dataset class for a single split.

        Args:
            train: Value forwarded as the `train` argument of the dataset class.
            transform: Transform forwarded to the dataset class. A fresh
                `ToTensor()` is used when this is falsy (e.g. `None`).
            target_transform: Target transform forwarded to the dataset class.

        Returns:
            Instance of the dataset class registered under `self.dataset`.

        """

        dataset_cls = self.SUPPORTED_DATASETS[self.dataset]

        return dataset_cls(
            self.root,
            train=train,
            download=True,
            transform=transform or ToTensor(),
            target_transform=target_transform,
        )

    @overrides
    def get_train_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Build the training dataset (`train=True`).

        Args:
            transform: Transform applied to the inputs. Defaults to `ToTensor()`.
            target_transform: Transform applied to the targets.

        Returns:
            Training dataset.

        """

        return self._load_split(True, transform=transform, target_transform=target_transform)

    @overrides
    def get_val_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Build the validation dataset (`train=False`).

        Args:
            transform: Transform applied to the inputs. Defaults to `ToTensor()`.
            target_transform: Transform applied to the targets.

        Returns:
            Validation dataset.

        """

        return self._load_split(False, transform=transform, target_transform=target_transform)

    @overrides
    def get_test_dataset(
        self,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Return the validation dataset in place of a testing set.

        A warning is logged on every call, then the call is delegated to
        `get_val_dataset` with the same arguments.

        Args:
            transform: Transform applied to the inputs. Defaults to `ToTensor()`.
            target_transform: Transform applied to the targets.

        Returns:
            Same dataset as returned by `get_val_dataset`.

        """

        logger.warn(f"Testing set not available for `{self.dataset}`. Returning validation set ...")

        # Routed through the public method so subclass overrides are honored.
        return self.get_val_dataset(transform=transform, target_transform=target_transform)