Source code for archai.datasets.nlp.nvidia_dataset_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Optional

from overrides import overrides

from archai.api.dataset_provider import DatasetProvider
from archai.common.distributed_utils import sync_workers
from archai.datasets.nlp.nvidia_dataset_provider_utils import Corpus


class NvidiaDatasetProvider(DatasetProvider):
    """Dataset provider that serves the train/valid/test splits of a ``Corpus``.

    On construction the corpus is loaded from its cache; if that fails it is
    trained/encoded and (for most datasets) written back to the cache.
    """

    def __init__(
        self,
        dataset_name: Optional[str] = "wt103",
        dataset_dir: Optional[str] = "",
        cache_dir: Optional[str] = "cache",
        vocab_type: Optional[str] = "gpt2",
        vocab_size: Optional[int] = None,
        refresh_cache: Optional[bool] = False,
    ) -> None:
        """Create the underlying corpus, building and caching it on a cache miss.

        Args:
            dataset_name: Identifier of the dataset to provide.
            dataset_dir: Directory holding the raw dataset.
            cache_dir: Directory used for the corpus cache.
            vocab_type: Kind of vocabulary/tokenizer to use.
            vocab_size: Vocabulary size, forwarded to ``Corpus``.
            refresh_cache: Forwarded to ``Corpus``; whether the cache should be refreshed.
        """
        super().__init__()

        self.corpus = Corpus(
            dataset_name,
            dataset_dir,
            cache_dir,
            vocab_type,
            vocab_size=vocab_size,
            refresh_cache=refresh_cache,
        )

        # A truthy result from ``load()`` means the corpus is ready; nothing else to do.
        if self.corpus.load():
            return

        # Load failed: build the corpus from scratch.
        self.corpus.train_and_encode()

        # Inside the worker-synchronized section, only rank 0 writes the cache,
        # and the "lm1b" dataset is deliberately never saved here.
        with sync_workers() as worker_rank:
            is_writer = worker_rank == 0
            if is_writer and dataset_name != "lm1b":
                self.corpus.save_cache()

    # NOTE(review): the getters below are annotated ``List[int]``, but the
    # concrete split type is whatever ``Corpus`` stores — confirm against Corpus.

    @overrides
    def get_train_dataset(self) -> List[int]:
        """Return the corpus' training split."""
        return self.corpus.train

    @overrides
    def get_val_dataset(self) -> List[int]:
        """Return the corpus' validation split (``Corpus.valid``)."""
        return self.corpus.valid

    @overrides
    def get_test_dataset(self) -> List[int]:
        """Return the corpus' test split."""
        return self.corpus.test