# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List, Optional, Union

from datasets import load_dataset as hf_load_dataset
from datasets import load_from_disk as hf_load_from_disk
from datasets.arrow_dataset import Dataset
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
from datasets.iterable_dataset import IterableDataset
from datasets.splits import Split
from datasets.utils.version import Version
from overrides import overrides

from archai.api.dataset_provider import DatasetProvider
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.datasets.nlp.hf_dataset_provider_utils import should_refresh_cache

logger = OrderedDictLogger(source=__name__)


class HfHubDatasetProvider(DatasetProvider):
    """Hugging Face Hub dataset provider."""

    def __init__(
        self,
        dataset_name: str,
        dataset_config_name: Optional[str] = None,
        data_dir: Optional[str] = None,
        data_files: Optional[Union[str, List[str], Dict[str, Union[str, List[str]]]]] = None,
        cache_dir: Optional[str] = None,
        revision: Optional[Union[str, Version]] = None,
    ) -> None:
        """Initialize Hugging Face Hub dataset provider.

        Args:
            dataset_name: Name of the dataset.
            dataset_config_name: Name of the dataset configuration.
            data_dir: Path to the data directory.
            data_files: Path(s) to the data file(s).
            cache_dir: Path to the read/write cache directory.
            revision: Version of the dataset to load.

        """

        super().__init__()

        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.data_dir = data_dir
        self.data_files = data_files
        self.cache_dir = cache_dir
        self.revision = revision
    def get_dataset(
        self,
        split: Optional[Union[str, Split]] = None,
        refresh_cache: Optional[bool] = False,
        keep_in_memory: Optional[bool] = False,
        streaming: Optional[bool] = False,
    ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
        """Load a dataset (or split) from the Hugging Face Hub.

        Args:
            split: Split to load. If `None`, all available splits are returned.
            refresh_cache: Whether the cache should be refreshed (re-downloaded).
            keep_in_memory: Whether the dataset should be kept in memory.
            streaming: Whether the dataset should be streamed instead of downloaded.

        """

        return hf_load_dataset(
            self.dataset_name,
            name=self.dataset_config_name,
            data_dir=self.data_dir,
            data_files=self.data_files,
            split=split,
            cache_dir=self.cache_dir,
            download_mode=should_refresh_cache(refresh_cache),
            keep_in_memory=keep_in_memory,
            revision=self.revision,
            streaming=streaming,
        )
    @overrides
    def get_train_dataset(
        self,
        split: Optional[Union[str, Split]] = "train",
        refresh_cache: Optional[bool] = False,
        keep_in_memory: Optional[bool] = False,
        streaming: Optional[bool] = False,
    ) -> Union[Dataset, IterableDataset]:
        return self.get_dataset(
            split=split, refresh_cache=refresh_cache, keep_in_memory=keep_in_memory, streaming=streaming
        )
    @overrides
    def get_val_dataset(
        self,
        split: Optional[Union[str, Split]] = "validation",
        refresh_cache: Optional[bool] = False,
        keep_in_memory: Optional[bool] = False,
        streaming: Optional[bool] = False,
    ) -> Union[Dataset, IterableDataset]:
        try:
            return self.get_dataset(
                split=split, refresh_cache=refresh_cache, keep_in_memory=keep_in_memory, streaming=streaming
            )
        except ValueError:
            logger.warn(f"Validation set not available for `{self.dataset_name}`. Returning full training set ...")
            return self.get_dataset(
                split="train", refresh_cache=refresh_cache, keep_in_memory=keep_in_memory, streaming=streaming
            )
    @overrides
    def get_test_dataset(
        self,
        split: Optional[Union[str, Split]] = "test",
        refresh_cache: Optional[bool] = False,
        keep_in_memory: Optional[bool] = False,
        streaming: Optional[bool] = False,
    ) -> Union[Dataset, IterableDataset]:
        try:
            return self.get_dataset(
                split=split, refresh_cache=refresh_cache, keep_in_memory=keep_in_memory, streaming=streaming
            )
        except ValueError:
            logger.warn(f"Testing set not available for `{self.dataset_name}`. Returning full validation set ...")
            return self.get_dataset(
                split="validation", refresh_cache=refresh_cache, keep_in_memory=keep_in_memory, streaming=streaming
            )
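
# Example usage for `HfHubDatasetProvider` (a minimal sketch; `wikitext` and its
# `wikitext-103-raw-v1` configuration are illustrative Hub identifiers, not a
# requirement of this provider):
#
#   provider = HfHubDatasetProvider("wikitext", dataset_config_name="wikitext-103-raw-v1")
#   train_dataset = provider.get_train_dataset()
#   val_dataset = provider.get_val_dataset()  # falls back to the `train` split if unavailable
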
class HfDiskDatasetProvider(DatasetProvider):
    """Hugging Face disk-saved dataset provider."""

    def __init__(
        self,
        data_dir: str,
        keep_in_memory: Optional[bool] = False,
    ) -> None:
        """Initialize Hugging Face disk-saved dataset provider.

        Args:
            data_dir: Path to the disk-saved dataset.
            keep_in_memory: Whether the dataset should be kept in memory.

        """

        super().__init__()

        self.data_dir = data_dir
        self.keep_in_memory = keep_in_memory

        # Pre-loads the dataset at instantiation time to avoid loading it multiple times
        self.dataset = hf_load_from_disk(self.data_dir, keep_in_memory=keep_in_memory)
    @overrides
    def get_train_dataset(self) -> Dataset:
        if isinstance(self.dataset, DatasetDict):
            return self.dataset["train"]
        return self.dataset
    @overrides
    def get_val_dataset(self) -> Dataset:
        if isinstance(self.dataset, DatasetDict):
            try:
                return self.dataset["validation"]
            except KeyError:
                logger.warn("Validation set not available. Returning training set ...")
        return self.get_train_dataset()
    @overrides
    def get_test_dataset(self) -> Dataset:
        if isinstance(self.dataset, DatasetDict):
            try:
                return self.dataset["test"]
            except KeyError:
                logger.warn("Testing set not available. Returning validation set ...")
        return self.get_val_dataset()
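
# Example usage for `HfDiskDatasetProvider` (a minimal sketch; the path below is
# hypothetical and assumes a dataset previously persisted with
# `datasets.Dataset.save_to_disk` or `datasets.DatasetDict.save_to_disk`):
#
#   provider = HfDiskDatasetProvider("path/to/saved_dataset")
#   train_dataset = provider.get_train_dataset()
#   test_dataset = provider.get_test_dataset()  # falls back to validation/train if missing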