Source code for archai.supergraph.datasets.dataset_provider
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import abstractmethod
from typing import Dict, Optional, Tuple, Union
from overrides import EnforceOverrides
from torch.utils.data.dataset import Dataset
from archai.common.config import Config
# Two-element tuple of datasets; either element may be None.
# NOTE(review): the name suggests the order is (train, test) - confirm against implementations.
TrainTestDatasets = Tuple[Optional[Dataset], Optional[Dataset]]
# Image size given as a single int, a pair of ints, or None.
ImgSize = Optional[Union[int, Tuple[int, int]]]
class DatasetProvider(EnforceOverrides):
    """Base class for objects that hand out datasets.

    Subclasses are expected to override `get_datasets`, which is marked
    abstract here.
    """

    def __init__(self, conf_dataset: Config):
        """Initialize the provider.

        Args:
            conf_dataset: Dataset configuration. The base class does not read
                it; presumably it is consumed by subclasses.
        """
        super().__init__()

    @abstractmethod
    def get_datasets(self, load_train: bool, load_test: bool,
                     transform_train, transform_test) -> TrainTestDatasets:
        """Return a two-element tuple of datasets.

        The base implementation does nothing and returns None.

        Args:
            load_train: Presumably whether the training dataset is requested.
            load_test: Presumably whether the test dataset is requested.
            transform_train: Presumably the transform applied to the training
                dataset - confirm against implementations.
            transform_test: Presumably the transform applied to the test
                dataset - confirm against implementations.

        Returns:
            A `TrainTestDatasets` tuple; either element may be None.
        """
# Alias used below to annotate provider *classes* (not instances).
# Note that type(DatasetProvider) evaluates to the metaclass of DatasetProvider.
DatasetProviderType = type(DatasetProvider)
# Module-level registry mapping a provider name to its registered class object;
# written by register_dataset_provider and read by get_provider_type.
_providers: Dict[str, DatasetProviderType] = {}
def register_dataset_provider(name: str, class_type: DatasetProviderType) -> None:
    """Add a provider class to the module-level registry under `name`.

    Args:
        name: Key under which the class is stored.
        class_type: Provider class object to store.

    Raises:
        KeyError: If `name` is already present in the registry.
    """
    # The registry dict is only mutated in place, never rebound, so no
    # `global` declaration is needed.
    already_registered = name in _providers
    if already_registered:
        raise KeyError(f'dataset provider with name {name} has already been registered')
    _providers[name] = class_type
def get_provider_type(name: str) -> DatasetProviderType:
    """Look up the provider class registered under `name`.

    Args:
        name: Key previously passed to `register_dataset_provider`.

    Returns:
        The class object stored in the registry for `name`.

    Raises:
        KeyError: If no provider has been registered under `name`. The
            message lists the names that are currently registered.
    """
    try:
        return _providers[name]
    except KeyError:
        # Re-raise the same exception type with a descriptive message, in
        # line with the error raised on duplicate registration; `from None`
        # hides the uninformative bare-key KeyError from the traceback.
        raise KeyError(
            f'dataset provider with name {name} has not been registered; '
            f'registered providers: {sorted(_providers)}'
        ) from None