# Source code for archai.datasets.cv.caltech_dataset_provider
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, List, Optional, Union
from overrides import overrides
from torch.utils.data import Dataset
from torchvision.datasets import Caltech101, Caltech256
from torchvision.transforms import ToTensor
from archai.api.dataset_provider import DatasetProvider
from archai.common.ordered_dict_logger import OrderedDictLogger
# Module-level logger; its `source` is set to this module's name.
logger = OrderedDictLogger(source=__name__)
class CaltechDatasetProvider(DatasetProvider):
    """Dataset provider backed by the `Caltech101` / `Caltech256` classes.

    The train, validation and test getters all return the same underlying
    dataset: the validation and test getters log a warning and then delegate
    to `get_train_dataset`.
    """

    # Maps every accepted `dataset` name to the class instantiated for it.
    SUPPORTED_DATASETS = {
        "caltech101": Caltech101,
        "caltech256": Caltech256,
    }

    def __init__(
        self,
        dataset: Optional[str] = "caltech101",
        root: Optional[str] = "dataroot",
    ) -> None:
        """Record which dataset to serve and where it is stored.

        Args:
            dataset: Name of the dataset; must be a key of `SUPPORTED_DATASETS`.
            root: Directory handed to the dataset class as its first argument.

        Raises:
            AssertionError: If `dataset` is not a key of `SUPPORTED_DATASETS`.

        """

        super().__init__()

        assert dataset in self.SUPPORTED_DATASETS, f"`dataset` should be one of: {list(self.SUPPORTED_DATASETS)}"

        self.dataset = dataset
        self.root = root

    @overrides
    def get_train_dataset(
        self,
        target_type: Optional[Union[str, List[str]]] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Instantiate the configured dataset class.

        The dataset class is always called with `download=True`.

        Args:
            target_type: Forwarded to the dataset class only when the configured
                dataset is "caltech101"; ignored otherwise.
            transform: Passed to the dataset class as `transform`. When falsy
                (e.g. `None`), a new `ToTensor()` is used instead.
            target_transform: Passed through unchanged as `target_transform`.

        Returns:
            Instance of the class registered for `self.dataset`.

        """

        # Only the "caltech101" case receives the `target_type` keyword.
        extra_kwargs = {}
        if self.dataset == "caltech101":
            extra_kwargs["target_type"] = target_type

        dataset_cls = self.SUPPORTED_DATASETS[self.dataset]
        image_transform = transform if transform else ToTensor()

        return dataset_cls(
            self.root,
            download=True,
            transform=image_transform,
            target_transform=target_transform,
            **extra_kwargs,
        )

    @overrides
    def get_val_dataset(
        self,
        target_type: Optional[Union[str, List[str]]] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Log a warning and return the training dataset in place of a validation set.

        Args:
            target_type: Forwarded to `get_train_dataset`.
            transform: Forwarded to `get_train_dataset`.
            target_transform: Forwarded to `get_train_dataset`.

        Returns:
            Whatever `get_train_dataset` returns for the same arguments.

        """

        logger.warn(f"Validation set not available for `{self.dataset}`. Returning training set ...")

        train_dataset = self.get_train_dataset(
            target_type=target_type,
            transform=transform,
            target_transform=target_transform,
        )
        return train_dataset

    @overrides
    def get_test_dataset(
        self,
        target_type: Optional[Union[str, List[str]]] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Log a warning and return the training dataset in place of a test set.

        Args:
            target_type: Forwarded to `get_train_dataset`.
            transform: Forwarded to `get_train_dataset`.
            target_transform: Forwarded to `get_train_dataset`.

        Returns:
            Whatever `get_train_dataset` returns for the same arguments.

        """

        logger.warn(f"Testing set not available for `{self.dataset}`. Returning training set ...")

        train_dataset = self.get_train_dataset(
            target_type=target_type,
            transform=transform,
            target_transform=target_transform,
        )
        return train_dataset