# Source code for archai.api.dataset_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import abstractmethod
from typing import Any

from overrides import EnforceOverrides


class DatasetProvider(EnforceOverrides):
    """Base interface for objects that supply datasets.

    A provider exposes up to three dataset splits - training, validation and
    testing - through `get_train_dataset`, `get_val_dataset` and
    `get_test_dataset`. All three are declared abstract, so concrete providers
    are expected to implement each of them. The returned objects are typed as
    `Any`: no particular dataset structure is imposed by this interface.

    Note:
        Because this class derives from `EnforceOverrides`, every method that a
        subclass overrides should carry the `@overrides` decorator so that the
        override is validated.

    Examples:
        >>> class MyDatasetProvider(DatasetProvider):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...
        ...     @overrides
        ...     def get_train_dataset(self) -> Any:
        ...         return torchvision.datasets.MNIST(train=True)
        ...
        ...     @overrides
        ...     def get_val_dataset(self) -> Any:
        ...         return torchvision.datasets.MNIST(train=False)
        ...
        ...     @overrides
        ...     def get_test_dataset(self) -> Any:
        ...         return torchvision.datasets.MNIST(train=False)

    """

    def __init__(self) -> None:
        """Create the provider; the base class keeps no state of its own."""

    @abstractmethod
    def get_train_dataset(self) -> Any:
        """Return the training split.

        Returns:
            The training dataset instance.

        """

    @abstractmethod
    def get_val_dataset(self) -> Any:
        """Return the validation split.

        Returns:
            The validation dataset instance. Implementations without a
            dedicated validation split should return the training dataset.

        """

    @abstractmethod
    def get_test_dataset(self) -> Any:
        """Return the testing split.

        Returns:
            The testing dataset instance. Implementations without a dedicated
            testing split should return the training or validation dataset.

        """