# Source code for archai.supergraph.datasets.providers.imagenet_provider
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from overrides import overrides
from torchvision import datasets
from torchvision.transforms import transforms
from archai.common import utils
from archai.common.config import Config
from archai.supergraph.datasets.dataset_provider import (
DatasetProvider,
ImgSize,
TrainTestDatasets,
register_dataset_provider,
)
class ImagenetProvider(DatasetProvider):
    """Dataset provider for ImageNet stored as torchvision ``ImageFolder`` trees.

    Training images are read from ``<dataroot>/ImageNet/train`` and validation
    images from ``<dataroot>/ImageNet/val``.
    """

    # PCA eigen-decomposition of ImageNet pixel colours. Only needed by the
    # optional "Lighting" augmentation, which is currently disabled in
    # ``get_transforms`` (kept here so re-enabling it is a one-line change).
    _IMAGENET_PCA = {
        'eigval': [0.2175, 0.0188, 0.0045],
        'eigvec': [
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ],
    }

    def __init__(self, conf_dataset:Config):
        """Create the provider.

        Args:
            conf_dataset: dataset config; its ``'dataroot'`` entry is resolved
                to a full path with ``utils.full_path`` and used as the root
                under which the ``ImageNet`` directory is looked up.
        """
        super().__init__(conf_dataset)
        self._dataroot = utils.full_path(conf_dataset['dataroot'])

    @overrides
    def get_datasets(self, load_train:bool, load_test:bool,
                     transform_train, transform_test)->TrainTestDatasets:
        """Build the requested ImageNet train/val ``ImageFolder`` datasets.

        Args:
            load_train: if True, load ``<dataroot>/ImageNet/train``.
            load_test: if True, load ``<dataroot>/ImageNet/val``.
            transform_train: transform passed to the training ``ImageFolder``.
            transform_test: transform passed to the validation ``ImageFolder``.

        Returns:
            ``(trainset, testset)``; a split that was not requested is ``None``.
            Every returned dataset is guaranteed to have a ``targets`` attribute.
        """
        trainset, testset = None, None
        if load_train:
            trainset = self._ensure_targets(datasets.ImageFolder(
                root=os.path.join(self._dataroot, 'ImageNet', 'train'),
                transform=transform_train))
        if load_test:
            # Apply the same compatibility shim as for the training split so
            # both splits expose the same attributes to callers.
            testset = self._ensure_targets(datasets.ImageFolder(
                root=os.path.join(self._dataroot, 'ImageNet', 'val'),
                transform=transform_test))
        return trainset, testset

    @staticmethod
    def _ensure_targets(dataset):
        """Add a ``targets`` label list to ``dataset`` if it lacks one; return it."""
        # Compatibility with older PyTorch: ImageFolder may only expose
        # ``samples`` as (path, label) pairs, so derive the labels from it.
        if not hasattr(dataset, 'targets'):
            dataset.targets = [label for _, label in dataset.samples]
        return dataset

    @overrides
    def get_transforms(self, img_size:ImgSize)->tuple:
        """Return ``(transform_train, transform_test)`` for ImageNet.

        Training: random-resized crop to 224 (bicubic), horizontal flip,
        colour jitter, tensor conversion and per-channel normalization.
        Test: resize to 256, center crop to 224, tensor conversion and
        normalization.

        Args:
            img_size: accepted for interface compatibility.
                NOTE(review): currently ignored; crop/resize sizes are
                hard-coded to 224/256 — confirm this is intended.
        """
        MEAN = [0.485, 0.456, 0.406]
        STD = [0.229, 0.224, 0.225]

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224,
                scale=(0.08, 1.0), # TODO: these two params are normally not specified
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2
            ),
            transforms.ToTensor(),
            # TODO: Lighting is not used in original darts paper
            # Lighting(0.1, self._IMAGENET_PCA['eigval'], self._IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=MEAN, std=STD)
        ])

        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD)
        ])

        return transform_train, transform_test
# Make this provider available under the 'imagenet' dataset name.
register_dataset_provider(
    'imagenet',
    ImagenetProvider,
)