Source code for archai.datasets.cv.coco_dataset_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Optional

from overrides import overrides
from torch.utils.data import Dataset
from torchvision.datasets import CocoCaptions, CocoDetection
from torchvision.transforms import ToTensor

from archai.api.dataset_provider import DatasetProvider
from archai.common.ordered_dict_logger import OrderedDictLogger

# Module-level logger; `source` is set to this module's name.
logger = OrderedDictLogger(source=__name__)


class CocoDatasetProvider(DatasetProvider):
    """Dataset provider built on top of torchvision's COCO dataset classes.

    The dataset name given at construction time selects an entry of
    `SUPPORTED_DATASETS`; that class is instantiated whenever a split is requested.

    """

    # Maps the public dataset name to the torchvision class that loads it.
    SUPPORTED_DATASETS = {
        "coco_captions": CocoCaptions,
        "coco_detection": CocoDetection,
    }

    def __init__(
        self,
        dataset: Optional[str] = "coco_captions",
        root: Optional[str] = "dataroot",
    ) -> None:
        """Set up the provider for one of the supported COCO datasets.

        Args:
            dataset: Name of the dataset; must be a key of `SUPPORTED_DATASETS`.
            root: Root directory handed to the dataset class.

        """

        super().__init__()

        assert dataset in self.SUPPORTED_DATASETS, f"`dataset` should be one of: {list(self.SUPPORTED_DATASETS)}"

        self.root = root
        self.dataset = dataset

    @overrides
    def get_train_dataset(
        self,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Instantiate the selected COCO dataset class.

        Args:
            ann_file: Annotation file passed as the second positional argument
                of the dataset class.
            transform: Transform forwarded to the dataset class. When it is not
                provided, `ToTensor()` is used instead.
            target_transform: Target transform forwarded to the dataset class.

        Returns:
            Instance of the dataset class registered under `self.dataset`.

        """

        dataset_cls = self.SUPPORTED_DATASETS[self.dataset]

        # Fall back to a plain tensor conversion when the caller supplies no transform.
        image_transform = transform if transform else ToTensor()

        return dataset_cls(
            self.root,
            ann_file,
            transform=image_transform,
            target_transform=target_transform,
        )

    @overrides
    def get_val_dataset(
        self,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Log a warning and return the training dataset in place of a validation set.

        Args:
            ann_file: Annotation file forwarded to `get_train_dataset`.
            transform: Transform forwarded to `get_train_dataset`.
            target_transform: Target transform forwarded to `get_train_dataset`.

        Returns:
            Whatever `get_train_dataset` returns for the same arguments.

        """

        message = f"Validation set not available for `{self.dataset}`. Returning training set ..."
        logger.warn(message)

        fallback = self.get_train_dataset(
            ann_file,
            transform=transform,
            target_transform=target_transform,
        )
        return fallback

    @overrides
    def get_test_dataset(
        self,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> Dataset:
        """Log a warning and return the validation dataset in place of a test set.

        Args:
            ann_file: Annotation file forwarded to `get_val_dataset`.
            transform: Transform forwarded to `get_val_dataset`.
            target_transform: Target transform forwarded to `get_val_dataset`.

        Returns:
            Whatever `get_val_dataset` returns for the same arguments.

        """

        message = f"Testing set not available for `{self.dataset}`. Returning validation set ..."
        logger.warn(message)

        fallback = self.get_val_dataset(
            ann_file,
            transform=transform,
            target_transform=target_transform,
        )
        return fallback