# Source code for archai.datasets.cv.tensorpack_lmdb_dataset_provider_utils

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any, Callable, Dict, List, Optional, Tuple

import cv2
import lmdb
import msgpack
import numpy as np
import torch
from torch.utils.data import Dataset

from archai.common.ordered_dict_logger import OrderedDictLogger

# Module-level logger; `source=__name__` tags every record with this module's dotted path.
logger = OrderedDictLogger(source=__name__)


class TensorpackLmdbDataset(Dataset):
    """Tensorpack LMDB dataset.

    Reads (image, mask) samples from a Tensorpack-style LMDB file, where each
    entry is a serialized mapping keyed by `img_key` (and optionally `mask_key`).
    Images are decoded with OpenCV and returned as normalized float tensors.
    """

    def __init__(
        self,
        lmdb_file_path: str,
        img_key: str,
        mask_key: Optional[str] = None,
        serializer: Optional[str] = "msgpack",
        img_size: Optional[Tuple[int, ...]] = None,
        img_format: Optional[str] = "numpy",
        ones_mask: Optional[bool] = False,
        zeroes_mask: Optional[bool] = False,
        raise_errors: Optional[bool] = True,
        is_bgr: Optional[bool] = True,
        valid_resolutions: Optional[List[Tuple]] = None,
        augmentation_fn: Optional[Callable] = None,
        mask_interpolation_method: int = cv2.INTER_NEAREST,
    ) -> None:
        """Initialize Tensorpack LMDB dataset.

        Args:
            lmdb_file_path: Path to the LMDB file.
            img_key: Image key in LMDB file.
            mask_key: Mask key in LMDB file.
            serializer: Serializer used to serialize data in LMDB file
                (only ``"msgpack"`` is currently supported).
            img_size: Image size; when given, both image and mask are resized to it.
            img_format: Image format (only ``"numpy"`` is currently supported).
            ones_mask: Whether mask is composed of ones.
            zeroes_mask: Whether mask is composed of zeroes.
            raise_errors: Whether to raise errors. When ``False``, a failed
                sample is logged and ``None`` is returned instead.
            is_bgr: Whether the stored image is BGR; if so, it is converted to RGB.
            valid_resolutions: Valid (height, width) resolutions to assert against.
            augmentation_fn: Augmentation function called as
                ``augmentation_fn(image=..., mask=...)``.
            mask_interpolation_method: OpenCV interpolation flag used when
                resizing the mask (nearest by default to keep labels discrete).

        """
        self.lmdb_file_path = lmdb_file_path

        # Read-only environment; no lock so multiple DataLoader workers can
        # read concurrently. map_size is an upper bound (2 TiB), not an allocation.
        self.db = lmdb.open(
            lmdb_file_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=True,
            map_size=1099511627776 * 2,
            max_readers=100,
        )

        self.img_key = img_key
        self.mask_key = mask_key

        # Long-lived read transaction reused for all lookups.
        self.txn = self.db.begin()

        # `__keys__` is Tensorpack's bookkeeping entry, not a sample.
        self.keys = [k for k, _ in self.txn.cursor() if k != b"__keys__"]

        self.img_size = img_size
        self.serializer = serializer
        self.img_format = img_format

        self.ones_mask = ones_mask
        self.zeroes_mask = zeroes_mask

        assert not (self.ones_mask and self.zeroes_mask), "`ones_mask` and `zeroes_mask` are mutually exclusive."
        if self.mask_key is None:
            # Without a stored mask, a synthetic constant mask must be requested.
            assert (
                self.ones_mask or self.zeroes_mask
            ), "`ones_mask` or `zeroes_mask` must be True if `mask_key` is None."

        self.is_bgr = is_bgr
        self.raise_errors = raise_errors
        self.valid_resolutions = valid_resolutions
        self.augmentation_fn = augmentation_fn
        self.mask_interpolation_method = mask_interpolation_method

    def __len__(self) -> int:
        """Return length of the dataset."""

        return len(self.keys)

    def close(self) -> None:
        """Release the LMDB transaction and environment.

        The constructor opens an environment and a long-lived transaction that
        were previously never closed; call this when the dataset is no longer
        needed to free the underlying file handles.
        """

        self.txn.abort()
        self.db.close()

    def _get_datapoint(self, idx: int) -> Dict[str, Any]:
        """Get a data point from the dataset.

        Args:
            idx: Index of the data point.

        Returns:
            Deserialized data point (mapping with at least `img_key`).

        Raises:
            NotImplementedError: If `serializer` is not ``"msgpack"``.
            KeyError: If `img_key` or `mask_key` is missing from the sample.

        """
        key = self.keys[idx]
        value = self.txn.get(key)

        if self.serializer == "msgpack":
            sample = msgpack.loads(value)
        else:
            raise NotImplementedError(f"Unsupported serializer {self.serializer}")

        for d_key in [self.img_key, self.mask_key]:
            if d_key and d_key not in sample:
                available_keys = sample.keys() if isinstance(sample, dict) else []
                raise KeyError(f"{d_key} not found in sample. Available keys: {available_keys}")

            # Some serializations wrap the raw bytes as {b"data": <bytes>}; unwrap them.
            if d_key and isinstance(sample[d_key], dict) and b"data" in sample[d_key]:
                sample[d_key] = sample[d_key][b"data"]

        return sample

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a sample from the dataset.

        Args:
            idx: Index of the sample.

        Returns:
            Sample with `image` (float tensor, CHW, scaled to [0, 1]), `mask`
            (long tensor), `dataset_path` and `key`. Returns ``None`` when
            loading fails and `raise_errors` is ``False``.

        """
        try:
            sample = self._get_datapoint(idx)

            if self.img_format == "numpy":
                img = np.frombuffer(sample[self.img_key], dtype=np.uint8).reshape((-1, 1))
                img = cv2.imdecode(img, cv2.IMREAD_COLOR)
                # OpenCV decodes to BGR; flip channels to RGB when requested.
                img = img[..., ::-1].copy() if self.is_bgr else img

                if self.ones_mask:
                    mask = np.ones(img.shape[:2], dtype=np.uint8)
                elif self.zeroes_mask or len(sample[self.mask_key]) == 0:
                    # Empty stored mask falls back to an all-zeroes mask.
                    mask = np.zeros(img.shape[:2], dtype=np.uint8)
                else:
                    mask_cv2_buf = np.frombuffer(sample[self.mask_key], dtype=np.uint8).reshape((-1, 1))
                    mask = cv2.imdecode(mask_cv2_buf, cv2.IMREAD_GRAYSCALE)

                sample = {"image": img, "mask": mask}

                if self.augmentation_fn:
                    sample = self.augmentation_fn(**sample)

                if self.img_size:
                    sample["image"] = cv2.resize(sample["image"], self.img_size)
                    sample["mask"] = cv2.resize(
                        sample["mask"], self.img_size, interpolation=self.mask_interpolation_method
                    )

                if self.valid_resolutions:
                    # NOTE(review): these check the pre-augmentation/pre-resize
                    # arrays (`img`/`mask`), not `sample["image"]` — confirm intent.
                    assert img.shape[:2] in self.valid_resolutions
                    assert mask.shape[:2] in self.valid_resolutions
            else:
                raise NotImplementedError(f"Unsupported image format: {self.img_format}")

            return {
                "image": torch.tensor(sample["image"].transpose(2, 0, 1) / 255.0, dtype=torch.float),
                "mask": torch.tensor(sample["mask"], dtype=torch.long),
                "dataset_path": self.lmdb_file_path,
                "key": self.keys[idx],
            }

        except Exception:
            if self.raise_errors:
                # Bare `raise` preserves the original traceback (was `raise e`).
                raise
            # Best-effort mode: log and implicitly return None.
            # NOTE(review): default DataLoader collation fails on None samples;
            # callers are expected to filter these — confirm upstream handling.
            logger.error(f"Sample {idx} from dataset {self.lmdb_file_path} could not be loaded.")