# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Tuple
from torch.utils.data import DataLoader, Dataset, Sampler
from archai.common import apex_utils, utils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.datasets.distributed_stratified_sampler import DistributedStratifiedSampler
from archai.supergraph.datasets.augmentation import add_named_augs
from archai.supergraph.datasets.dataset_provider import (
DatasetProvider,
get_provider_type,
)
from archai.supergraph.datasets.limit_dataset import DatasetLike
# Module-level logger shared by every function in this file; obtained once
# at import time from get_global_logger().
logger = get_global_logger()
class DataLoaders:
    """Plain container holding the train, validation and test data loaders.

    Each loader defaults to ``None`` and is stored unchanged as a public
    attribute (``train_dl``, ``val_dl``, ``test_dl``).
    """

    def __init__(self, train_dl:Optional[DataLoader]=None,
                 val_dl:Optional[DataLoader]=None,
                 test_dl:Optional[DataLoader]=None) -> None:
        # no validation or copying: the loaders are kept exactly as given
        self.train_dl, self.val_dl, self.test_dl = train_dl, val_dl, test_dl
def get_data(conf_loader:Config)->DataLoaders:
    """Build the train/val/test data loaders described by a loader config.

    Args:
        conf_loader: loader configuration. The keys read are ``dataset``
            (which must itself contain ``max_batches``), ``aug``, ``cutout``,
            ``val_ratio``, ``val_fold``, ``load_train``, ``train_batch``,
            ``train_workers``, ``load_test``, ``test_batch``,
            ``test_workers`` and ``apex``; ``img_size`` is optional and
            defaults to ``None``.

    Returns:
        A ``DataLoaders`` instance wrapping the created loaders.

    Raises:
        AssertionError: if no train loader was created.
    """
    logger.pushd('data')
    try:
        # region conf vars
        # dataset
        conf_dataset = conf_loader['dataset']
        max_batches = conf_dataset['max_batches']

        aug = conf_loader['aug']
        cutout = conf_loader['cutout']
        val_ratio = conf_loader['val_ratio']
        val_fold = conf_loader['val_fold']
        img_size = conf_loader.get('img_size', None)
        load_train = conf_loader['load_train']
        train_batch = conf_loader['train_batch']
        train_workers = conf_loader['train_workers']
        load_test = conf_loader['load_test']
        test_batch = conf_loader['test_batch']
        test_workers = conf_loader['test_workers']
        conf_apex = conf_loader['apex']
        # endregion

        ds_provider = create_dataset_provider(conf_dataset)
        apex = apex_utils.ApexUtils(conf_apex)

        train_dl, val_dl, test_dl = get_dataloaders(ds_provider,
            load_train=load_train, train_batch_size=train_batch,
            load_test=load_test, test_batch_size=test_batch,
            aug=aug, cutout=cutout, val_ratio=val_ratio, val_fold=val_fold,
            img_size=img_size, train_workers=train_workers,
            test_workers=test_workers, max_batches=max_batches, apex=apex)

        assert train_dl is not None

        return DataLoaders(train_dl=train_dl, val_dl=val_dl, test_dl=test_dl)
    finally:
        # always balance the pushd above, even when a config key is missing,
        # loader construction fails or the assertion trips
        logger.popd()
def create_dataset_provider(conf_dataset:Config)->DatasetProvider:
    """Create the dataset provider selected by ``conf_dataset['name']``.

    Args:
        conf_dataset: dataset configuration; must contain ``name``,
            ``dataroot`` and ``storage_name``. The whole config is handed
            to the provider's constructor.

    Returns:
        An instance of the type returned by ``get_provider_type`` for the
        configured dataset name.
    """
    ds_name = conf_dataset['name']

    # dataroot and storage_name are read here only for the log entry; the
    # provider receives the complete config and reads what it needs itself
    log_entry = {
        'ds_name': ds_name,
        'dataroot': utils.full_path(conf_dataset['dataroot']),
        'storage_name': conf_dataset['storage_name'],
    }
    logger.info(log_entry)

    provider_cls = get_provider_type(ds_name)
    return provider_cls(conf_dataset)
def _max_fold_size(num_items:int, max_batches:Optional[int], batch_size:int)->Optional[int]:
    """Return the item cap implied by ``max_batches``, or ``None`` for no cap."""
    # None, 0 and negative values (including the -1 default) all mean "no limit"
    if max_batches is None or max_batches <= 0:
        return None
    return min(num_items, max_batches*batch_size)


def get_dataloaders(ds_provider:DatasetProvider,
    load_train:bool, train_batch_size:int,
    load_test:bool, test_batch_size:int,
    aug, cutout:int, val_ratio:float, apex:apex_utils.ApexUtils,
    val_fold=0, img_size:Optional[int]=None, train_workers:Optional[int]=None,
    test_workers:Optional[int]=None, target_lb=-1, max_batches:int=-1) \
        -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:
    """Create train, validation and test loaders from a dataset provider.

    Args:
        ds_provider: source of the transforms and datasets.
        load_train: forwarded to the provider to request the train set.
        train_batch_size: batch size of the train and validation loaders.
        load_test: forwarded to the provider to request the test set.
        test_batch_size: batch size of the test loader.
        aug: augmentation spec passed to ``add_named_augs``.
        cutout: cutout setting passed to ``add_named_augs``.
        val_ratio: fraction of the train set used for validation; a
            validation loader is created only when this is > 0.
        apex: supplies world size and rank for the distributed samplers.
        val_fold: accepted for interface compatibility; not used here.
        img_size: passed to ``ds_provider.get_transforms``.
        train_workers: worker budget shared by the train and validation
            loaders (default 4; forced to 0 under a debugger).
        test_workers: workers of the test loader (default 4; forced to 0
            under a debugger).
        target_lb: accepted for interface compatibility; not used here.
        max_batches: when positive, caps each split at
            ``max_batches * batch_size`` items; ``None``, 0 or a negative
            value means no cap.

    Returns:
        ``(trainloader, validloader, testloader)``; an entry is ``None``
        when the corresponding loader was not created.
    """
    # if debugging in vscode, workers > 0 gets termination
    default_workers = 4
    if utils.is_debugging():
        train_workers = test_workers = 0
        logger.warn({'debugger': True})
    if train_workers is None:
        train_workers = default_workers # following NVidia DeepLearningExamples
    if test_workers is None:
        test_workers = default_workers

    # Split the training worker budget between the train and val loaders.
    # Both shares are derived from the original budget; computing the val
    # share from the already-reduced train share would shrink it twice.
    worker_budget = train_workers
    train_workers = round((1-val_ratio)*worker_budget)
    val_workers = round(val_ratio*worker_budget)
    logger.info({'train_workers': train_workers, 'val_workers': val_workers,
                 'test_workers':test_workers})

    transform_train, transform_test = ds_provider.get_transforms(img_size)
    add_named_augs(transform_train, aug, cutout)

    trainset, testset = _get_datasets(ds_provider,
        load_train, load_test, transform_train, transform_test)

    # TODO: pre-augmentation (set_preaug) is not wired up; set_preaug does
    # not exist in PyTorch

    trainloader, validloader, testloader = None, None, None

    if trainset:
        max_train_fold = _max_fold_size(len(trainset), max_batches, train_batch_size) # pyright: ignore[reportGeneralTypeIssues]
        logger.info({'val_ratio': val_ratio, 'max_train_batches': max_batches,
                     'max_train_fold': max_train_fold})

        # sample validation set from trainset if val_ratio > 0
        train_sampler, valid_sampler = _get_sampler(trainset, val_ratio=val_ratio,
                                                    shuffle=True, apex=apex,
                                                    max_items=max_train_fold)
        logger.info({'train_sampler_world_size':train_sampler.world_size,
                     'train_sampler_rank':train_sampler.rank,
                     'train_sampler_len': len(train_sampler)})
        if valid_sampler:
            logger.info({'valid_sampler_world_size':valid_sampler.world_size,
                         'valid_sampler_rank':valid_sampler.rank,
                         'valid_sampler_len': len(valid_sampler)
                        })

        # shuffle is performed by sampler at each epoch
        trainloader = DataLoader(trainset,
            batch_size=train_batch_size, shuffle=False,
            num_workers=train_workers,
            pin_memory=True,
            sampler=train_sampler, drop_last=False) # TODO: original paper has this True

        if val_ratio > 0.0:
            # validation batches are drawn from the train set via its own sampler
            validloader = DataLoader(trainset,
                batch_size=train_batch_size, shuffle=False,
                num_workers=val_workers,
                pin_memory=True,
                sampler=valid_sampler, drop_last=False)
        # else validloader is left as None

    if testset:
        max_test_fold = _max_fold_size(len(testset), max_batches, test_batch_size) # pyright: ignore[reportGeneralTypeIssues]
        logger.info({'max_test_batches': max_batches,
                     'max_test_fold': max_test_fold})

        # val_ratio=None: the test set is never split, so no val sampler is expected
        test_sampler, test_val_sampler = _get_sampler(testset, val_ratio=None,
                                                      shuffle=False, apex=apex,
                                                      max_items=max_test_fold)
        logger.info({'test_sampler_world_size':test_sampler.world_size,
                     'test_sampler_rank':test_sampler.rank,
                     'test_sampler_len': len(test_sampler)})
        assert test_val_sampler is None

        testloader = DataLoader(testset,
            batch_size=test_batch_size, shuffle=False,
            num_workers=test_workers,
            pin_memory=True,
            sampler=test_sampler, drop_last=False
        )

    assert val_ratio > 0.0 or validloader is None

    logger.info({
        'train_batch_size': train_batch_size, 'test_batch_size': test_batch_size,
        'train_batches': len(trainloader) if trainloader is not None else None,
        'val_batches': len(validloader) if validloader is not None else None,
        'test_batches': len(testloader) if testloader is not None else None
    })

    return trainloader, validloader, testloader
class SubsetSampler(Sampler):
    """Sampler that yields a fixed sequence of indices in the given order.

    Every index is produced exactly once per pass (no replacement).

    Arguments:
        indices (sequence): the indices to yield
    """

    def __init__(self, indices):
        # kept by reference, not copied
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        # a fresh iterator per call, so the sampler can be reused each epoch
        return iter(self.indices)
def _get_datasets(ds_provider:DatasetProvider, load_train:bool, load_test:bool,
        transform_train, transform_test)\
            ->Tuple[DatasetLike, DatasetLike]:
    """Fetch the (train, test) dataset pair from ``ds_provider``.

    All arguments are forwarded unchanged, in order, to
    ``ds_provider.get_datasets``.
    """
    datasets = ds_provider.get_datasets(load_train, load_test,
                                        transform_train, transform_test)
    # unpack explicitly so anything other than exactly two items fails here
    trainset, testset = datasets
    return trainset, testset
# NOTE: the target_lb parameter of get_dataloaders would allow filtering the dataset for a specific class; it is currently not used
def _get_sampler(dataset:Dataset, val_ratio:Optional[float], shuffle:bool,
        max_items:Optional[int], apex:apex_utils.ApexUtils)\
            ->Tuple[DistributedStratifiedSampler, Optional[DistributedStratifiedSampler]]:
    """Create the train sampler and, when requested, the validation sampler.

    Args:
        dataset: dataset both samplers draw from.
        val_ratio: passed to both samplers; ``None`` means no validation
            sampler is created.
        shuffle: passed to both samplers.
        max_items: passed to both samplers as ``max_samples``.
        apex: provides ``world_size`` and ``global_rank`` for the samplers.

    Returns:
        ``(train_sampler, valid_sampler)``; ``valid_sampler`` is ``None``
        when ``val_ratio`` is ``None``.
    """
    # Both samplers deliberately get the same settings, shuffle included:
    # we cannot shuffle just for train or just for val because in
    # distributed mode both must come from the same shard.
    shared_kwargs = dict(val_ratio=val_ratio, shuffle=shuffle,
                         max_samples=max_items,
                         world_size=apex.world_size, rank=apex.global_rank)

    train_sampler = DistributedStratifiedSampler(dataset, is_val_split=False,
                                                 **shared_kwargs)
    if val_ratio is None:
        return train_sampler, None

    valid_sampler = DistributedStratifiedSampler(dataset, is_val_split=True,
                                                 **shared_kwargs)
    return train_sampler, valid_sampler