Source code for archai.datasets.cv.tensorpack_lmdb_dataset_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from typing import Callable, List, Optional, Tuple

import cv2
from overrides import overrides
from torch.utils.data import ConcatDataset

from archai.api.dataset_provider import DatasetProvider
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.datasets.cv.tensorpack_lmdb_dataset_provider_utils import (
    TensorpackLmdbDataset,
)

logger = OrderedDictLogger(source=__name__)


class TensorpackLmdbDatasetProvider(DatasetProvider):
    """Tensorpack LMDB dataset provider."""

    def __init__(
        self,
        img_key: str,
        train_lmdb_file_path: str,
        val_lmdb_file_path: Optional[str] = None,
        test_lmdb_file_path: Optional[str] = None,
    ) -> None:
        """Initialize Tensorpack LMDB dataset provider.

        Args:
            img_key: Image key in the LMDB file.
            train_lmdb_file_path: Path to the training LMDB file.
            val_lmdb_file_path: Path to the validation LMDB file.
            test_lmdb_file_path: Path to the testing LMDB file.

        """

        self.img_key = img_key

        assert os.path.exists(train_lmdb_file_path), f"File {train_lmdb_file_path} must exist."
        self.train_lmdb_file_path = train_lmdb_file_path

        if val_lmdb_file_path is not None:
            assert os.path.exists(val_lmdb_file_path), f"File {val_lmdb_file_path} must exist."
            self.val_lmdb_file_path = val_lmdb_file_path

        if test_lmdb_file_path is not None:
            assert os.path.exists(test_lmdb_file_path), f"File {test_lmdb_file_path} must exist."
            self.test_lmdb_file_path = test_lmdb_file_path

    @overrides
    def get_train_dataset(
        self,
        mask_key: Optional[str] = None,
        serializer: Optional[str] = "msgpack",
        img_size: Optional[Tuple[int, ...]] = None,
        img_format: Optional[str] = "numpy",
        ones_mask: Optional[bool] = False,
        zeroes_mask: Optional[bool] = False,
        raise_errors: Optional[bool] = True,
        is_bgr: Optional[bool] = True,
        valid_resolutions: Optional[List[Tuple]] = None,
        augmentation_fn: Optional[Callable] = None,
        mask_interpolation_method: int = cv2.INTER_NEAREST,
    ) -> TensorpackLmdbDataset:
        return TensorpackLmdbDataset(
            self.train_lmdb_file_path,
            self.img_key,
            mask_key=mask_key,
            serializer=serializer,
            img_size=img_size,
            img_format=img_format,
            ones_mask=ones_mask,
            zeroes_mask=zeroes_mask,
            raise_errors=raise_errors,
            is_bgr=is_bgr,
            valid_resolutions=valid_resolutions,
            augmentation_fn=augmentation_fn,
            mask_interpolation_method=mask_interpolation_method,
        )

    @overrides
    def get_val_dataset(
        self,
        mask_key: Optional[str] = None,
        serializer: Optional[str] = "msgpack",
        img_size: Optional[Tuple[int, ...]] = None,
        img_format: Optional[str] = "numpy",
        ones_mask: Optional[bool] = False,
        zeroes_mask: Optional[bool] = False,
        raise_errors: Optional[bool] = True,
        is_bgr: Optional[bool] = True,
        valid_resolutions: Optional[List[Tuple]] = None,
        augmentation_fn: Optional[Callable] = None,
        mask_interpolation_method: int = cv2.INTER_NEAREST,
    ) -> TensorpackLmdbDataset:
        try:
            return TensorpackLmdbDataset(
                self.val_lmdb_file_path,
                self.img_key,
                mask_key=mask_key,
                serializer=serializer,
                img_size=img_size,
                img_format=img_format,
                ones_mask=ones_mask,
                zeroes_mask=zeroes_mask,
                raise_errors=raise_errors,
                is_bgr=is_bgr,
                valid_resolutions=valid_resolutions,
                augmentation_fn=augmentation_fn,
                mask_interpolation_method=mask_interpolation_method,
            )
        except Exception:
            logger.warn("Validation set not available. Returning training set ...")
            return self.get_train_dataset(
                mask_key=mask_key,
                serializer=serializer,
                img_size=img_size,
                img_format=img_format,
                ones_mask=ones_mask,
                zeroes_mask=zeroes_mask,
                raise_errors=raise_errors,
                is_bgr=is_bgr,
                valid_resolutions=valid_resolutions,
                augmentation_fn=augmentation_fn,
                mask_interpolation_method=mask_interpolation_method,
            )

    @overrides
    def get_test_dataset(
        self,
        mask_key: Optional[str] = None,
        serializer: Optional[str] = "msgpack",
        img_size: Optional[Tuple[int, ...]] = None,
        img_format: Optional[str] = "numpy",
        ones_mask: Optional[bool] = False,
        zeroes_mask: Optional[bool] = False,
        raise_errors: Optional[bool] = True,
        is_bgr: Optional[bool] = True,
        valid_resolutions: Optional[List[Tuple]] = None,
        augmentation_fn: Optional[Callable] = None,
        mask_interpolation_method: int = cv2.INTER_NEAREST,
    ) -> TensorpackLmdbDataset:
        try:
            return TensorpackLmdbDataset(
                self.test_lmdb_file_path,
                self.img_key,
                mask_key=mask_key,
                serializer=serializer,
                img_size=img_size,
                img_format=img_format,
                ones_mask=ones_mask,
                zeroes_mask=zeroes_mask,
                raise_errors=raise_errors,
                is_bgr=is_bgr,
                valid_resolutions=valid_resolutions,
                augmentation_fn=augmentation_fn,
                mask_interpolation_method=mask_interpolation_method,
            )
        except Exception:
            logger.warn("Testing set not available. Returning validation set ...")
            return self.get_val_dataset(
                mask_key=mask_key,
                serializer=serializer,
                img_size=img_size,
                img_format=img_format,
                ones_mask=ones_mask,
                zeroes_mask=zeroes_mask,
                raise_errors=raise_errors,
                is_bgr=is_bgr,
                valid_resolutions=valid_resolutions,
                augmentation_fn=augmentation_fn,
                mask_interpolation_method=mask_interpolation_method,
            )
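

# Usage sketch (illustrative, not part of the original module): the LMDB path
# and image key below are hypothetical placeholders. Note that the val/test
# getters fall back to the next-smaller split when their file was not given.
#
#   provider = TensorpackLmdbDatasetProvider(
#       img_key="image",                         # hypothetical key
#       train_lmdb_file_path="data/train.lmdb",  # hypothetical path
#   )
#   train_dataset = provider.get_train_dataset(img_size=(256, 256), is_bgr=True)
#   val_dataset = provider.get_val_dataset()  # no val file -> returns training set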


class MultiFileTensorpackLmdbDatasetProvider(DatasetProvider):
    """Multi-file Tensorpack LMDB dataset provider."""

    def __init__(
        self,
        img_key: List[str],
        train_lmdb_file_path: List[str],
        val_lmdb_file_path: Optional[List[str]] = None,
        test_lmdb_file_path: Optional[List[str]] = None,
    ) -> None:
        """Initialize multi-file Tensorpack LMDB dataset provider.

        Args:
            img_key: Image keys in the LMDB files.
            train_lmdb_file_path: Paths to the training LMDB files.
            val_lmdb_file_path: Paths to the validation LMDB files.
            test_lmdb_file_path: Paths to the testing LMDB files.

        """

        assert len(img_key) == len(
            train_lmdb_file_path
        ), "Number of image keys must be equal to the number of training LMDB files."
        self.img_key = img_key

        for file_path in train_lmdb_file_path:
            assert os.path.exists(file_path), f"File {file_path} must exist."
        self.train_lmdb_file_path = train_lmdb_file_path

        if val_lmdb_file_path is not None:
            for file_path in val_lmdb_file_path:
                assert os.path.exists(file_path), f"File {file_path} must exist."
            assert len(img_key) == len(
                val_lmdb_file_path
            ), "Number of image keys must be equal to the number of validation LMDB files."
            self.val_lmdb_file_path = val_lmdb_file_path

        if test_lmdb_file_path is not None:
            for file_path in test_lmdb_file_path:
                assert os.path.exists(file_path), f"File {file_path} must exist."
            assert len(img_key) == len(
                test_lmdb_file_path
            ), "Number of image keys must be equal to the number of testing LMDB files."
            self.test_lmdb_file_path = test_lmdb_file_path

    @overrides
    def get_train_dataset(
        self,
        mask_key: Optional[str] = None,
        serializer: Optional[str] = "msgpack",
        img_size: Optional[Tuple[int, ...]] = None,
        img_format: Optional[str] = "numpy",
        ones_mask: Optional[bool] = False,
        zeroes_mask: Optional[bool] = False,
        raise_errors: Optional[bool] = True,
        is_bgr: Optional[bool] = True,
        valid_resolutions: Optional[List[Tuple]] = None,
        augmentation_fn: Optional[Callable] = None,
        mask_interpolation_method: int = cv2.INTER_NEAREST,
    ) -> ConcatDataset:
        return ConcatDataset(
            [
                TensorpackLmdbDataset(
                    file_path,
                    img_key,
                    mask_key=mask_key,
                    serializer=serializer,
                    img_size=img_size,
                    img_format=img_format,
                    ones_mask=ones_mask,
                    zeroes_mask=zeroes_mask,
                    raise_errors=raise_errors,
                    is_bgr=is_bgr,
                    valid_resolutions=valid_resolutions,
                    augmentation_fn=augmentation_fn,
                    mask_interpolation_method=mask_interpolation_method,
                )
                for file_path, img_key in zip(self.train_lmdb_file_path, self.img_key)
            ]
        )

    @overrides
    def get_val_dataset(
        self,
        mask_key: Optional[str] = None,
        serializer: Optional[str] = "msgpack",
        img_size: Optional[Tuple[int, ...]] = None,
        img_format: Optional[str] = "numpy",
        ones_mask: Optional[bool] = False,
        zeroes_mask: Optional[bool] = False,
        raise_errors: Optional[bool] = True,
        is_bgr: Optional[bool] = True,
        valid_resolutions: Optional[List[Tuple]] = None,
        augmentation_fn: Optional[Callable] = None,
        mask_interpolation_method: int = cv2.INTER_NEAREST,
    ) -> ConcatDataset:
        try:
            return ConcatDataset(
                [
                    TensorpackLmdbDataset(
                        file_path,
                        img_key,
                        mask_key=mask_key,
                        serializer=serializer,
                        img_size=img_size,
                        img_format=img_format,
                        ones_mask=ones_mask,
                        zeroes_mask=zeroes_mask,
                        raise_errors=raise_errors,
                        is_bgr=is_bgr,
                        valid_resolutions=valid_resolutions,
                        augmentation_fn=augmentation_fn,
                        mask_interpolation_method=mask_interpolation_method,
                    )
                    for file_path, img_key in zip(self.val_lmdb_file_path, self.img_key)
                ]
            )
        except Exception:
            logger.warn("Validation set not available. Returning training set ...")
            return self.get_train_dataset(
                mask_key=mask_key,
                serializer=serializer,
                img_size=img_size,
                img_format=img_format,
                ones_mask=ones_mask,
                zeroes_mask=zeroes_mask,
                raise_errors=raise_errors,
                is_bgr=is_bgr,
                valid_resolutions=valid_resolutions,
                augmentation_fn=augmentation_fn,
                mask_interpolation_method=mask_interpolation_method,
            )

    @overrides
    def get_test_dataset(
        self,
        mask_key: Optional[str] = None,
        serializer: Optional[str] = "msgpack",
        img_size: Optional[Tuple[int, ...]] = None,
        img_format: Optional[str] = "numpy",
        ones_mask: Optional[bool] = False,
        zeroes_mask: Optional[bool] = False,
        raise_errors: Optional[bool] = True,
        is_bgr: Optional[bool] = True,
        valid_resolutions: Optional[List[Tuple]] = None,
        augmentation_fn: Optional[Callable] = None,
        mask_interpolation_method: int = cv2.INTER_NEAREST,
    ) -> ConcatDataset:
        try:
            return ConcatDataset(
                [
                    TensorpackLmdbDataset(
                        file_path,
                        img_key,
                        mask_key=mask_key,
                        serializer=serializer,
                        img_size=img_size,
                        img_format=img_format,
                        ones_mask=ones_mask,
                        zeroes_mask=zeroes_mask,
                        raise_errors=raise_errors,
                        is_bgr=is_bgr,
                        valid_resolutions=valid_resolutions,
                        augmentation_fn=augmentation_fn,
                        mask_interpolation_method=mask_interpolation_method,
                    )
                    for file_path, img_key in zip(self.test_lmdb_file_path, self.img_key)
                ]
            )
        except Exception:
            logger.warn("Testing set not available. Returning validation set ...")
            return self.get_val_dataset(
                mask_key=mask_key,
                serializer=serializer,
                img_size=img_size,
                img_format=img_format,
                ones_mask=ones_mask,
                zeroes_mask=zeroes_mask,
                raise_errors=raise_errors,
                is_bgr=is_bgr,
                valid_resolutions=valid_resolutions,
                augmentation_fn=augmentation_fn,
                mask_interpolation_method=mask_interpolation_method,
            )
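

# A minimal, hedged sketch of wiring the multi-file provider to a PyTorch
# DataLoader. File paths and image keys are hypothetical placeholders; the
# provider asserts at construction time that every LMDB file exists.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    provider = MultiFileTensorpackLmdbDatasetProvider(
        img_key=["image", "image"],  # hypothetical per-file image keys
        train_lmdb_file_path=["data/part0.lmdb", "data/part1.lmdb"],  # hypothetical paths
    )

    # `get_train_dataset` builds one TensorpackLmdbDataset per LMDB file and
    # concatenates them, so they can be consumed as a single dataset.
    train_dataset = provider.get_train_dataset(img_size=(256, 256))
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)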