Datasets

Augmentation Policies

archai.supergraph.datasets.aug_policies.fa_reduced_cifar10()[source]#
archai.supergraph.datasets.aug_policies.fa_resnet50_rimagenet()[source]#
archai.supergraph.datasets.aug_policies.fa_reduced_svhn()[source]#

Augmentations

class archai.supergraph.datasets.augmentation.Augmentation(policies)[source]#
archai.supergraph.datasets.augmentation.add_named_augs(transform_train, aug: List | str, cutout: int)[source]#
archai.supergraph.datasets.augmentation.ShearX(img, v)[source]#
archai.supergraph.datasets.augmentation.ShearY(img, v)[source]#
archai.supergraph.datasets.augmentation.TranslateX(img, v)[source]#
archai.supergraph.datasets.augmentation.TranslateY(img, v)[source]#
archai.supergraph.datasets.augmentation.TranslateXAbs(img, v)[source]#
archai.supergraph.datasets.augmentation.TranslateYAbs(img, v)[source]#
archai.supergraph.datasets.augmentation.Rotate(img, v)[source]#
archai.supergraph.datasets.augmentation.AutoContrast(img, _)[source]#
archai.supergraph.datasets.augmentation.Invert(img, _)[source]#
archai.supergraph.datasets.augmentation.Equalize(img, _)[source]#
archai.supergraph.datasets.augmentation.Flip(img, _)[source]#
archai.supergraph.datasets.augmentation.Solarize(img, v)[source]#
archai.supergraph.datasets.augmentation.Posterize(img, v)[source]#
archai.supergraph.datasets.augmentation.Posterize2(img, v)[source]#
archai.supergraph.datasets.augmentation.Contrast(img, v)[source]#
archai.supergraph.datasets.augmentation.Color(img, v)[source]#
archai.supergraph.datasets.augmentation.Brightness(img, v)[source]#
archai.supergraph.datasets.augmentation.Sharpness(img, v)[source]#
archai.supergraph.datasets.augmentation.Cutout(img, v)[source]#
archai.supergraph.datasets.augmentation.CutoutAbs(img, v)[source]#
archai.supergraph.datasets.augmentation.SamplePairing(imgs)[source]#
archai.supergraph.datasets.augmentation.augment_list(for_autoaug=True)[source]#
archai.supergraph.datasets.augmentation.get_augment(name)[source]#
archai.supergraph.datasets.augmentation.apply_augment(img, name, level)[source]#
archai.supergraph.datasets.augmentation.arsaug_policy()[source]#
archai.supergraph.datasets.augmentation.autoaug2arsaug(f)[source]#
archai.supergraph.datasets.augmentation.autoaug_paper_cifar10()[source]#
archai.supergraph.datasets.augmentation.autoaug_policy()[source]#
archai.supergraph.datasets.augmentation.float_parameter(level, maxval)[source]#
archai.supergraph.datasets.augmentation.int_parameter(level, maxval)[source]#
archai.supergraph.datasets.augmentation.no_duplicates(f)[source]#
archai.supergraph.datasets.augmentation.remove_deplicates(policies)[source]#
archai.supergraph.datasets.augmentation.policy_decoder(augment, num_policy, num_op)[source]#

Data

class archai.supergraph.datasets.data.DataLoaders(train_dl: DataLoader | None = None, val_dl: DataLoader | None = None, test_dl: DataLoader | None = None)[source]#
archai.supergraph.datasets.data.get_data(conf_loader: Config) DataLoaders[source]#
archai.supergraph.datasets.data.create_dataset_provider(conf_dataset: Config) DatasetProvider[source]#
archai.supergraph.datasets.data.get_dataloaders(ds_provider: DatasetProvider, load_train: bool, train_batch_size: int, load_test: bool, test_batch_size: int, aug, cutout: int, val_ratio: float, apex: ApexUtils, val_fold=0, img_size: int | None = None, train_workers: int | None = None, test_workers: int | None = None, target_lb=-1, max_batches: int = -1) Tuple[DataLoader | None, DataLoader | None, DataLoader | None][source]#
class archai.supergraph.datasets.data.SubsetSampler(indices)[source]#

Samples elements from a given list of indices, without replacement.

Parameters:

indices (sequence) – a sequence of indices

Dataset Provider

class archai.supergraph.datasets.dataset_provider.DatasetProvider(conf_dataset: Config)[source]#
abstract get_datasets(load_train: bool, load_test: bool, transform_train, transform_test) Tuple[Dataset | None, Dataset | None][source]#
abstract get_transforms(img_size: int | Tuple[int, int] | None) tuple[source]#
archai.supergraph.datasets.dataset_provider.register_dataset_provider(name: str, class_type: EnforceOverridesMeta) None[source]#
archai.supergraph.datasets.dataset_provider.get_provider_type(name: str) EnforceOverridesMeta[source]#

Distributed Stratified Sampler

class archai.supergraph.datasets.distributed_stratified_sampler.DistributedStratifiedSampler(dataset: Dataset, world_size: int | None = None, rank: int | None = None, shuffle: bool | None = True, val_ratio: float | None = 0.0, is_val_split: bool | None = False, max_samples: int | None = None)[source]#

Distributed stratified sampling of a dataset.

This sampler works in both distributed and non-distributed settings with no penalty in either mode, and is a replacement for the built-in torch.utils.data.DistributedSampler.

In a distributed setting, many instances of the same code run as processes known as replicas. Each replica is assigned a sequential number by the launcher, starting from 0, that uniquely identifies it. This is known as the global rank, or simply the rank. The number of replicas is known as the world size. In a non-distributed setting, world_size=1 and rank=0.

This sampler assumes that the labels for all datapoints are available in the dataset.targets property, which should be array-like and contain as many values as the length of the dataset. This property is already available for many popular datasets such as CIFAR and, with newer PyTorch versions, ImageFolder and DatasetFolder. If you are using a custom dataset, you can usually create this property with one line of code, such as dataset.targets = [yi for _, yi in dataset].

To do distributed sampling, each replica must shuffle with the same seed as all other replicas in every epoch and then choose a subset of the dataset for itself. Traditionally, the epoch number is used as the shuffling seed for each replica. However, this requires the training code to call sampler.set_epoch(epoch) to set the seed at every epoch.

set_epoch(epoch: int) None[source]#

Set the epoch for the current replica, which is used to seed the shuffling.

Parameters:

epoch – Epoch number.

Limit Dataset

class archai.supergraph.datasets.limit_dataset.LimitDataset(dataset, n)[source]#

Meta Dataset

class archai.supergraph.datasets.meta_dataset.MetaDataset(source: Dataset, transform=None, target_transform=None)[source]#