Source code for archai.supergraph.datasets.providers.fashion_mnist_provider
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torchvision
from overrides import overrides
from torchvision.transforms import transforms
from archai.common import utils
from archai.common.config import Config
from archai.supergraph.datasets.dataset_provider import (
DatasetProvider,
ImgSize,
TrainTestDatasets,
register_dataset_provider,
)
class FashionMnistProvider(DatasetProvider):
    """Dataset provider that serves the Fashion-MNIST train/test splits.

    Both splits are constructed with ``torchvision.datasets.FashionMNIST``
    rooted at the ``dataroot`` entry of the dataset config. ``download=True``
    is always passed, so torchvision fetches the files into that directory
    when they are not already present.
    """

    def __init__(self, conf_dataset: Config):
        """Initialise the provider and resolve its data directory.

        Args:
            conf_dataset: Dataset section of the configuration. Its
                ``dataroot`` entry is expanded with ``utils.full_path`` and
                kept as the root directory for both splits.
        """
        super().__init__(conf_dataset)
        self._dataroot = utils.full_path(conf_dataset['dataroot'])

    @overrides
    def get_datasets(self, load_train: bool, load_test: bool,
                     transform_train, transform_test) -> TrainTestDatasets:
        """Build the requested Fashion-MNIST splits.

        Args:
            load_train: When true, construct the training split.
            load_test: When true, construct the test split.
            transform_train: Transform handed to the training dataset.
            transform_test: Transform handed to the test dataset.

        Returns:
            A ``(trainset, testset)`` tuple; a split that was not requested
            is ``None``.
        """
        root = self._dataroot

        # The training split is built before the test split, matching the
        # order in which a first-time download would be triggered.
        trainset = (
            torchvision.datasets.FashionMNIST(
                root=root, train=True, download=True,
                transform=transform_train)
            if load_train else None
        )
        testset = (
            torchvision.datasets.FashionMNIST(
                root=root, train=False, download=True,
                transform=transform_test)
            if load_test else None
        )
        return trainset, testset
# Register FashionMnistProvider under the key 'fashion_mnist'.
# NOTE(review): presumably this lets a dataset config select this provider by
# that name — confirm against register_dataset_provider in dataset_provider.
register_dataset_provider('fashion_mnist', FashionMnistProvider)