Source code for archai.datasets.cv.aircraft_dataset_provider
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, Optional
from overrides import overrides
from torch.utils.data import Dataset
from torchvision.datasets import FGVCAircraft
from torchvision.transforms import ToTensor
from archai.api.dataset_provider import DatasetProvider
[docs]class AircraftDatasetProvider(DatasetProvider):
"""FGVC Aircraft dataset provider."""
def __init__(
self,
root: Optional[str] = "dataroot",
) -> None:
"""Initialize FGVC Aircraft dataset provider.
Args:
root: Root directory of dataset where is saved.
"""
super().__init__()
self.root = root
[docs] @overrides
def get_train_dataset(
self,
annotation_level: Optional[str] = "variant",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> Dataset:
return FGVCAircraft(
self.root,
split="train",
annotation_level=annotation_level,
transform=transform or ToTensor(),
target_transform=target_transform,
download=True,
)
[docs] @overrides
def get_val_dataset(
self,
annotation_level: Optional[str] = "variant",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> Dataset:
return FGVCAircraft(
self.root,
split="val",
annotation_level=annotation_level,
transform=transform or ToTensor(),
target_transform=target_transform,
download=True,
)
[docs] @overrides
def get_test_dataset(
self,
annotation_level: Optional[str] = "variant",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> Dataset:
return FGVCAircraft(
self.root,
split="test",
annotation_level=annotation_level,
transform=transform or ToTensor(),
target_transform=target_transform,
download=True,
)