Source code for archai.supergraph.datasets.distributed_stratified_sampler

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Iterable, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from torch.utils.data import Sampler
from torch.utils.data.dataset import Dataset


class DistributedStratifiedSampler(Sampler):
    """Distributed stratified sampling of dataset.

    This sampler works in distributed as well as non-distributed settings with no penalty in
    either mode and is a replacement for the built-in `torch.utils.data.DistributedSampler`.

    In a distributed setting, many instances of the same code run as processes known as
    replicas. Each replica is assigned a sequential number by the launcher, starting from 0,
    which uniquely identifies it. This is known as the global rank, or simply rank. The number
    of replicas is known as the world size. For the non-distributed setting, `world_size=1`
    and `rank=0`.

    This sampler assumes that labels for each datapoint are available in the `dataset.targets`
    property, which should be array-like and contain as many values as the length of the
    dataset. This is already available for many popular datasets such as CIFAR and, with newer
    PyTorch versions, `ImageFolder` as well as `DatasetFolder`. If you are using a custom
    dataset, you can usually create this property with one line of code such as
    `dataset.targets = [yi for _, yi in dataset]`.

    To do distributed sampling, each replica must shuffle with the same seed as all other
    replicas at every epoch and then choose some subset of the dataset for itself.
    Traditionally, the epoch number is used as the shuffling seed on each replica. This,
    however, requires that the training code calls `sampler.set_epoch(epoch)` at every epoch.

    """

    def __init__(
        self,
        dataset: Dataset,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: Optional[bool] = True,
        val_ratio: Optional[float] = 0.0,
        is_val_split: Optional[bool] = False,
        max_samples: Optional[int] = None,
    ) -> None:
        """Initialize the sampler.

        Args:
            dataset: Input dataset.
            world_size: Total number of replicas. If `None`, it is auto-detected
                (1 in the non-distributed setting).
            rank: Global rank of this replica. If `None`, it is auto-detected
                (0 in the non-distributed setting).
            shuffle: Whether shuffling should be applied for every epoch.
            val_ratio: Creates a validation split when set to > 0.
            is_val_split: Whether the validation split should be returned.
            max_samples: Maximum number of samples for each replica.
""" # CIFAR-10 amd DatasetFolder has this attribute # For others it may be easy to add from outside assert ( hasattr(dataset, "targets") and dataset.targets is not None ), "dataset needs to have targets attribute to work with this sampler" if world_size is None: if dist.is_available() and dist.is_initialized(): world_size = dist.get_world_size() else: world_size = 1 if rank is None: if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() else: rank = 0 if val_ratio is None: val_ratio = 0.0 assert world_size >= 1 assert rank >= 0 and rank < world_size assert val_ratio < 1.0 and val_ratio >= 0.0 self.dataset = dataset self.world_size = world_size self.rank = rank self.epoch = 0 # Used as a seed self.shuffle = shuffle self.data_len = len(self.dataset) self.max_samples = max_samples if max_samples is not None and max_samples >= 0 else None assert self.data_len == len(dataset.targets) self.val_ratio = val_ratio self.is_val_split = is_val_split # Computes duplications of dataset to make it divisible by world_size self.replica_len = self.replica_len_full = int(math.ceil(float(self.data_len) / self.world_size)) self.total_size = self.replica_len_full * self.world_size assert self.total_size >= self.data_len if self.max_samples is not None: self.replica_len = min(self.replica_len_full, self.max_samples) self.main_split_len = int(math.floor(self.replica_len * (1 - val_ratio))) self.val_split_len = self.replica_len - self.main_split_len self._len = self.val_split_len if self.is_val_split else self.main_split_len def __len__(self) -> int: return self._len def __iter__(self) -> Iterable: indices, targets = self._get_indices() indices, targets = self._split_rank(indices, targets) indices, targets = self._limit_indices(indices, targets, self.max_samples) indices, _ = self._split_indices(indices, targets, self.val_split_len, self.is_val_split) assert len(indices) == self._len if self.shuffle and self.val_ratio > 0.0 and self.epoch > 0: np.random.shuffle(indices) return iter(indices) def _split_rank(self, indices: np.ndarray, targets: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: if self.world_size > 1: replica_fold_idxs = None rfolder = StratifiedKFold(n_splits=self.world_size, shuffle=False) folds = rfolder.split(indices, targets) for _ in range(self.rank + 1): _, replica_fold_idxs = next(folds) assert replica_fold_idxs is not None and len(replica_fold_idxs) == self.replica_len_full return indices[replica_fold_idxs], targets[replica_fold_idxs] assert self.world_size == 1 return indices, targets def _get_indices(self) -> Tuple[np.ndarray, np.ndarray]: if self.shuffle: g = torch.Generator() g.manual_seed(self._get_seed()) indices = torch.randperm(self.data_len, generator=g).numpy() else: indices = np.arange(self.data_len) if self.total_size > self.data_len: indices = np.append(indices, indices[: (self.total_size - self.data_len)]) else: assert self.total_size == self.data_len, "`total_size` cannot be less than dataset size." 
        targets = np.array(list(self.dataset.targets[i] for i in indices))
        assert len(indices) == self.total_size

        return indices, targets

    def _limit_indices(
        self, indices: np.ndarray, targets: np.ndarray, max_samples: Optional[int]
    ) -> Tuple[np.ndarray, np.ndarray]:
        if max_samples is not None:
            return self._split_indices(indices, targets, len(indices) - max_samples, False)

        return indices, targets

    def _get_seed(self) -> int:
        return self.epoch if self.val_ratio == 0.0 else 0

    def _split_indices(
        self, indices: np.ndarray, targets: np.ndarray, val_size: int, return_val_split: bool
    ) -> Tuple[np.ndarray, np.ndarray]:
        if val_size:
            assert isinstance(val_size, int)

            vfolder = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=self._get_seed())
            vfolder = vfolder.split(indices, targets)

            train_idx, valid_idx = next(vfolder)
            idxs = valid_idx if return_val_split else train_idx

            return indices[idxs], targets[idxs]

        return indices, targets
    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for the current replica, which is used to seed the shuffling.

        Args:
            epoch: Epoch number.

        """
        self.epoch = epoch
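A minimal usage sketch follows (not part of the module source above). The toy `TensorDataset`, batch size, epoch count, and loop are illustrative assumptions; the import path simply follows the module name shown at the top of this page.

import torch
from torch.utils.data import DataLoader, TensorDataset

from archai.supergraph.datasets.distributed_stratified_sampler import DistributedStratifiedSampler

# Toy dataset; custom datasets need a `targets` attribute, which can be added in one line,
# mirroring the suggestion in the class docstring.
data = torch.randn(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(data, labels)
dataset.targets = [yi.item() for _, yi in dataset]

# Stratified train/validation samplers; world_size and rank are auto-detected from
# torch.distributed when a process group is initialized, otherwise they default to 1 and 0.
train_sampler = DistributedStratifiedSampler(dataset, val_ratio=0.1, is_val_split=False)
val_sampler = DistributedStratifiedSampler(dataset, val_ratio=0.1, is_val_split=True)

train_loader = DataLoader(dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=64, sampler=val_sampler)

for epoch in range(5):
    train_sampler.set_epoch(epoch)  # seed the per-epoch shuffle, as the class docstring requires
    for x, y in train_loader:
        pass  # training step goes here (illustrative)

Because both samplers are built with the same `val_ratio`, the train and validation splits are disjoint and each preserves the class distribution of the underlying dataset on every replica.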