# Source code for archai.supergraph.datasets.providers.svhn_provider
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torchvision
from overrides import overrides
from torch.utils.data import ConcatDataset
from torchvision.transforms import transforms
from archai.common import utils
from archai.common.config import Config
from archai.supergraph.datasets.dataset_provider import (
DatasetProvider,
ImgSize,
TrainTestDatasets,
register_dataset_provider,
)
class SvhnProvider(DatasetProvider):
    """Dataset provider backed by ``torchvision.datasets.SVHN``.

    The training dataset is the concatenation of the ``'train'`` and
    ``'extra'`` splits; the test dataset is the ``'test'`` split.
    """

    def __init__(self, conf_dataset: Config):
        """Initialise the provider and remember the dataset root.

        Args:
            conf_dataset: Dataset configuration. Its ``'dataroot'`` entry is
                passed through ``utils.full_path`` and kept as the root
                directory handed to torchvision.
        """
        super().__init__(conf_dataset)
        self._dataroot = utils.full_path(conf_dataset['dataroot'])

    @overrides
    def get_datasets(self, load_train: bool, load_test: bool,
                     transform_train, transform_test) -> TrainTestDatasets:
        """Build the requested SVHN datasets.

        Args:
            load_train: When true, build the training dataset.
            load_test: When true, build the test dataset.
            transform_train: Transform passed to the ``'train'`` and
                ``'extra'`` splits.
            transform_test: Transform passed to the ``'test'`` split.

        Returns:
            A ``(trainset, testset)`` pair. Each element is ``None`` when the
            corresponding ``load_*`` flag is false.
        """

        def fetch(split_name, transform):
            # Every split shares the same root and is requested with
            # download=True; only the split name and transform vary.
            return torchvision.datasets.SVHN(
                root=self._dataroot,
                split=split_name,
                download=True,
                transform=transform,
            )

        train_data = None
        if load_train:
            # Order matters for indexing into the concatenation: 'train'
            # samples come first, followed by 'extra'.
            parts = [fetch(name, transform_train) for name in ('train', 'extra')]
            train_data = ConcatDataset(parts)

        test_data = fetch('test', transform_test) if load_test else None

        return train_data, test_data
# Register SvhnProvider under the name 'svhn'; this runs when the module is imported.
register_dataset_provider('svhn', SvhnProvider)