# Source code for archai.supergraph.datasets.meta_dataset

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from torch.utils.data import Dataset


def _identity(x):
    """Return ``x`` unchanged; the default no-op transform for MetaDataset."""
    return x


class MetaDataset(Dataset):
    """Wrap a dataset so every sample also carries a per-sample metadata dict.

    Each item of the wrapped ``source`` must unpack into exactly two values
    ``(x, y)``. Indexing this dataset returns the triple
    ``(transform(x), target_transform(y), meta)`` where ``meta`` is the
    metadata dict stored for that index (initially ``{'idx': idx}``).
    """

    def __init__(self, source:Dataset, transform=None, target_transform=None) -> None:
        """Store the source dataset and build one metadata dict per sample.

        Args:
            source: Dataset to wrap; it must support ``len()`` and integer
                indexing that yields an ``(x, y)`` pair.
            transform: Optional callable applied to ``x``. ``None`` means
                "leave ``x`` unchanged".
            target_transform: Optional callable applied to ``y``. ``None``
                means "leave ``y`` unchanged".
        """
        self._source = source
        # Use a module-level function rather than a lambda as the no-op
        # default: lambdas cannot be pickled, which would make the dataset
        # unusable wherever it has to be serialized (e.g. worker processes
        # started with the "spawn" method).
        self.transform = transform if transform is not None else _identity
        self.target_transform = (
            target_transform if target_transform is not None else _identity
        )
        # One dict per sample, created once up front from the source length.
        self._meta = [{'idx': i} for i in range(len(source))]

    def __len__(self):
        """Return the number of samples in the wrapped source dataset."""
        return len(self._source)

    def __getitem__(self, idx):
        """Return ``(transform(x), target_transform(y), meta)`` for ``idx``.

        The returned ``meta`` is the stored dict itself, not a copy, so
        in-place changes made by the caller persist across accesses.
        """
        x, y = self._source[idx]
        return (self.transform(x), self.target_transform(y), self._meta[idx])