Source code for archai.datasets.nlp.fast_hf_dataset_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import json
import os
import pickle
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from datasets.dataset_dict import DatasetDict
from overrides import overrides
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

from archai.api.dataset_provider import DatasetProvider
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.datasets.nlp.fast_hf_dataset_provider_utils import (
    FastHfDataset,
    SHMArray,
    process_with_memory_map_files,
    process_with_shared_memory,
    xor,
)
from archai.datasets.nlp.hf_dataset_provider_utils import tokenize_concatenated_dataset

logger = OrderedDictLogger(source=__name__)

if sys.version_info.major == 3 and sys.version_info.minor >= 8:
    ALLOW_SHARED_MEMORY = True
else:
    logger.warn("Shared memory is not available in Python < 3.8.")
    ALLOW_SHARED_MEMORY = False


class FastHfDatasetProvider(DatasetProvider):
    """Fast Hugging Face-based dataset provider."""

    def __init__(
        self,
        train_file: str,
        validation_file: str,
        test_file: str,
        tokenizer: Optional[AutoTokenizer] = None,
    ) -> None:
        """Initialize Fast Hugging Face-based dataset provider.

        Args:
            train_file: Path to the training array file (.npy).
            validation_file: Path to the validation array file (.npy).
            test_file: Path to the test array file (.npy).
            tokenizer: Instance of tokenizer to use.

        """

        super().__init__()

        self.train_file = train_file
        self.validation_file = validation_file
        self.test_file = test_file
        self.tokenizer = tokenizer

        # Windows does not allow tests to memory map the same file
        # when tests are running in parallel
        self.mmap_mode = None if os.name == "nt" and os.getenv("PYTEST_CURRENT_TEST") else "r"

    @staticmethod
    def _create_splits(dataset_dict: DatasetDict, validation_split: float, shuffle: bool, seed: int) -> DatasetDict:
        if "validation" not in dataset_dict:
            logger.info("Creating validation split ...")

            validation_split = validation_split or 0.1
            tmp_dataset_dict = dataset_dict["train"].train_test_split(
                test_size=validation_split, shuffle=shuffle, seed=seed
            )
            dataset_dict["train"] = tmp_dataset_dict["train"]
            dataset_dict["validation"] = tmp_dataset_dict["test"]

        if "test" not in dataset_dict:
            logger.info("Creating test split ...")

            tmp_dataset_dict = dataset_dict["validation"].train_test_split(test_size=0.25, shuffle=shuffle, seed=seed)
            dataset_dict["validation"] = tmp_dataset_dict["train"]
            dataset_dict["test"] = tmp_dataset_dict["test"]

        return dataset_dict

    @staticmethod
    def _encode_dataset(
        dataset_dict: DatasetDict,
        tokenizer: AutoTokenizer,
        mapping_fn: Callable[[Any], Dict[str, Any]],
        mapping_fn_kwargs: Dict[str, Any],
        mapping_column_name: List[str],
        use_eos_token: bool,
        dtype: np.dtype,
        num_workers: int,
    ) -> DatasetDict:
        logger.info("Encoding dataset ...")
        logger.info(f"Number of workers: {num_workers} | EOS token: {use_eos_token}")

        mapping_fn = mapping_fn or tokenize_concatenated_dataset
        mapping_fn_kwargs = mapping_fn_kwargs or {
            "tokenizer": tokenizer,
            "mapping_column_name": mapping_column_name,
            "use_eos_token": use_eos_token,
            "dtype": dtype,
        }

        column_names = dataset_dict["train"].column_names
        encoded_dataset_dict = dataset_dict.map(
            mapping_fn,
            fn_kwargs=mapping_fn_kwargs,
            batched=True,
            num_proc=num_workers,
            remove_columns=column_names,
        )

        return encoded_dataset_dict

    @staticmethod
    def _close_mem_maps(processed_dataset_dict: DatasetDict) -> None:
        for key in processed_dataset_dict:
            dataset = processed_dataset_dict[key]
            if isinstance(dataset, np.memmap) and dataset._mmap is not None:
                dataset._mmap.close()

    @staticmethod
    def _process_dataset_to_memory(
        dataset_dict: DatasetDict, cache_dir: str, dtype: np.dtype, num_workers: int, use_shared_memory: bool
    ) -> Dict[str, Union[SHMArray, np.ndarray]]:
        logger.info("Processing dataset to memory ...")
        logger.info(f"Number of workers: {num_workers} | Shared memory: {use_shared_memory}")

        if use_shared_memory:
            return process_with_shared_memory(dataset_dict, dtype, num_proc=num_workers)

        return process_with_memory_map_files(dataset_dict, cache_dir, dtype, num_proc=num_workers)

    @staticmethod
    def _save_dataset(
        dataset_dict: Dict[str, Union[SHMArray, np.ndarray]],
        tokenizer: AutoTokenizer,
        cache_dir: str,
        use_shared_memory: bool,
    ) -> Dict[str, Path]:
        logger.info(f"Saving dataset to: {cache_dir}")

        cache_files = {}
        for split, dataset in dataset_dict.items():
            np.save(cache_dir / f"{split}.npy", dataset)

            # If using shared memory, dataset needs to have its shared memory
            # unlinked to prevent memory leak
            if use_shared_memory:
                dataset.shm.unlink()

            # If not using shared memory, dataset needs to have its memory map
            # closed to prevent an additional .bin file
            if not use_shared_memory:
                dataset._mmap.close()
                Path(cache_dir / f"{split}.bin").unlink()

            cache_files[f"{split}_file"] = cache_dir / f"{split}.npy"

        with open(cache_dir / "tokenizer.pkl", "wb") as f:
            pickle.dump(tokenizer, f)

        return cache_files
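
    # Note (illustrative, not part of the original module): after `_save_dataset`
    # runs, `cache_dir` holds one .npy file per split plus a pickled tokenizer,
    # and `from_disk`/`from_hub` then write a config.json alongside them:
    #
    #   cache/
    #   ├── train.npy
    #   ├── validation.npy
    #   ├── test.npy
    #   ├── tokenizer.pkl
    #   └── config.json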
    @classmethod
    def from_disk(
        cls: FastHfDatasetProvider,
        dataset_file_path: str,
        tokenizer: Optional[AutoTokenizer] = None,
        tokenizer_name: Optional[str] = None,
        mapping_fn: Optional[Callable[[Any], Dict[str, Any]]] = None,
        mapping_fn_kwargs: Optional[Dict[str, Any]] = None,
        mapping_column_name: Optional[List[str]] = None,
        validation_split: Optional[float] = 0.0,
        shuffle: Optional[bool] = True,
        seed: Optional[int] = 42,
        num_workers: Optional[int] = 1,
        use_eos_token: Optional[bool] = True,
        use_shared_memory: Optional[bool] = True,
        cache_dir: Optional[str] = "cache",
    ) -> FastHfDatasetProvider:
        """Load a dataset provider by loading and encoding data from disk.

        Args:
            dataset_file_path: Path to the dataset file stored on disk.
            tokenizer: Instance of tokenizer to use.
            tokenizer_name: Name of the tokenizer, if `tokenizer` has not been passed.
            mapping_fn: A function that maps the dataset. If not provided,
                the default `tokenize_concatenated_dataset` function will be used.
            mapping_fn_kwargs: Keyword arguments to pass to `mapping_fn`.
            mapping_column_name: The columns in the dataset to be tokenized.
                If `str`, only one column will be tokenized.
                If `List[str]`, multiple columns will be tokenized.
            validation_split: Fraction of the dataset to use for validation.
            shuffle: Whether to shuffle the dataset.
            seed: Random seed.
            num_workers: Number of workers to use for encoding.
            use_eos_token: Whether to use EOS token to separate sequences.
            use_shared_memory: Whether to use shared memory for caching.
            cache_dir: Root path to the cache directory.

        Returns:
            Dataset provider.

        """

        assert xor(tokenizer, tokenizer_name), "`tokenizer` and `tokenizer_name` are mutually exclusive."
        tokenizer = tokenizer or AutoTokenizer.from_pretrained(tokenizer_name)

        dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32
        use_shared_memory = use_shared_memory and ALLOW_SHARED_MEMORY

        cache_dir = Path(cache_dir)
        if cache_dir.is_dir():
            logger.warn(f"Cache: {cache_dir} already exists and will be overwritten.")
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Ensure that loaded dataset is always a dictionary
        logger.info(f"Loading dataset from: {dataset_file_path}")
        disk_dataset_dict = load_from_disk(dataset_file_path)
        if not isinstance(disk_dataset_dict, DatasetDict):
            disk_dataset_dict = DatasetDict({"train": disk_dataset_dict})

        # Ensure that `validation` and `test` splits are available
        disk_dataset_dict = FastHfDatasetProvider._create_splits(disk_dataset_dict, validation_split, shuffle, seed)

        encoded_dataset_dict = FastHfDatasetProvider._encode_dataset(
            disk_dataset_dict,
            tokenizer,
            mapping_fn,
            mapping_fn_kwargs,
            mapping_column_name,
            use_eos_token,
            dtype,
            num_workers,
        )
        processed_dataset_dict = FastHfDatasetProvider._process_dataset_to_memory(
            encoded_dataset_dict, cache_dir, dtype, num_workers, use_shared_memory
        )

        cache_files = FastHfDatasetProvider._save_dataset(
            processed_dataset_dict, tokenizer, cache_dir, use_shared_memory
        )
        FastHfDatasetProvider._close_mem_maps(processed_dataset_dict)

        with open(cache_dir / "config.json", "w") as f:
            json.dump(
                {
                    "dataset_file_path": dataset_file_path,
                    "tokenizer": {
                        "name_or_path": tokenizer.name_or_path,
                        "model_max_length": None,
                    },
                    "mapping_column_name": mapping_column_name or ["text"],
                    "validation_split": validation_split,
                    "shuffle": shuffle,
                    "seed": seed,
                    "use_eos_token": use_eos_token,
                },
                f,
            )

        return FastHfDatasetProvider(**cache_files, tokenizer=tokenizer)
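
    # Example (illustrative sketch): building a provider from a dataset that was
    # previously saved with `datasets.Dataset.save_to_disk`. The path, tokenizer
    # name, and argument values below are assumptions for illustration only.
    #
    #   provider = FastHfDatasetProvider.from_disk(
    #       "data/my_dataset",
    #       tokenizer_name="gpt2",
    #       validation_split=0.1,
    #       num_workers=4,
    #       cache_dir="cache/my_dataset",
    #   )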
    @classmethod
    def from_hub(
        cls: FastHfDatasetProvider,
        dataset_name: str,
        dataset_config_name: Optional[str] = None,
        data_dir: Optional[str] = None,
        data_files: Optional[Union[List[str], Dict[str, Union[str, List[str]]]]] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        tokenizer_name: Optional[str] = None,
        mapping_fn: Optional[Callable[[Any], Dict[str, Any]]] = None,
        mapping_fn_kwargs: Optional[Dict[str, Any]] = None,
        mapping_column_name: Optional[List[str]] = None,
        validation_split: Optional[float] = 0.0,
        shuffle: Optional[bool] = True,
        seed: Optional[int] = 42,
        num_workers: Optional[int] = 1,
        use_eos_token: Optional[bool] = True,
        use_shared_memory: Optional[bool] = True,
        cache_dir: Optional[str] = "cache",
    ) -> FastHfDatasetProvider:
        """Load a dataset provider by downloading and encoding data from the Hugging Face Hub.

        Args:
            dataset_name: Name of the dataset.
            dataset_config_name: Name of the dataset configuration.
            data_dir: Path to the data directory.
            data_files: Path to the source data file(s).
            tokenizer: Instance of tokenizer to use.
            tokenizer_name: Name of the tokenizer, if `tokenizer` has not been passed.
            mapping_fn: A function that maps the dataset. If not provided,
                the default `tokenize_concatenated_dataset` function will be used.
            mapping_fn_kwargs: Keyword arguments to pass to `mapping_fn`.
            mapping_column_name: The columns in the dataset to be tokenized.
                If `str`, only one column will be tokenized.
                If `List[str]`, multiple columns will be tokenized.
            validation_split: Fraction of the dataset to use for validation.
            shuffle: Whether to shuffle the dataset.
            seed: Random seed.
            num_workers: Number of workers to use for encoding.
            use_eos_token: Whether to use EOS token to separate sequences.
            use_shared_memory: Whether to use shared memory for caching.
            cache_dir: Root path to the cache directory.

        Returns:
            Dataset provider.

        """

        assert xor(tokenizer, tokenizer_name), "`tokenizer` and `tokenizer_name` are mutually exclusive."
        tokenizer = tokenizer or AutoTokenizer.from_pretrained(tokenizer_name)

        dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32
        use_shared_memory = use_shared_memory and ALLOW_SHARED_MEMORY

        cache_dir = Path(cache_dir)
        if cache_dir.is_dir():
            logger.warn(f"Cache: {cache_dir} already exists and will be overwritten.")
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Ensure that downloaded dataset is always a dictionary
        logger.info("Downloading dataset ...")
        hub_dataset_dict = load_dataset(
            dataset_name, name=dataset_config_name, data_dir=data_dir, data_files=data_files
        )
        if not isinstance(hub_dataset_dict, DatasetDict):
            hub_dataset_dict = DatasetDict({"train": hub_dataset_dict})

        # Ensure that `validation` and `test` splits are available
        hub_dataset_dict = FastHfDatasetProvider._create_splits(hub_dataset_dict, validation_split, shuffle, seed)

        encoded_dataset_dict = FastHfDatasetProvider._encode_dataset(
            hub_dataset_dict,
            tokenizer,
            mapping_fn,
            mapping_fn_kwargs,
            mapping_column_name,
            use_eos_token,
            dtype,
            num_workers,
        )
        processed_dataset_dict = FastHfDatasetProvider._process_dataset_to_memory(
            encoded_dataset_dict, cache_dir, dtype, num_workers, use_shared_memory
        )

        cache_files = FastHfDatasetProvider._save_dataset(
            processed_dataset_dict, tokenizer, cache_dir, use_shared_memory
        )
        FastHfDatasetProvider._close_mem_maps(processed_dataset_dict)

        with open(cache_dir / "config.json", "w") as f:
            json.dump(
                {
                    "dataset_name": dataset_name,
                    "dataset_config_name": dataset_config_name,
                    "data_dir": data_dir,
                    "data_files": data_files,
                    "tokenizer": {
                        "name_or_path": tokenizer.name_or_path,
                        "model_max_length": None,
                    },
                    "mapping_column_name": mapping_column_name or ["text"],
                    "validation_split": validation_split,
                    "shuffle": shuffle,
                    "seed": seed,
                    "use_eos_token": use_eos_token,
                },
                f,
            )

        return FastHfDatasetProvider(**cache_files, tokenizer=tokenizer)
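
    # Example (illustrative sketch): downloading and encoding a Hub dataset. The
    # dataset, configuration, and tokenizer identifiers are assumptions chosen
    # for illustration only.
    #
    #   provider = FastHfDatasetProvider.from_hub(
    #       "wikitext",
    #       dataset_config_name="wikitext-103-raw-v1",
    #       tokenizer_name="gpt2",
    #       cache_dir="cache/wikitext",
    #   )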
    @classmethod
    def from_cache(cls: FastHfDatasetProvider, cache_dir: str) -> FastHfDatasetProvider:
        """Load a dataset provider from a cache directory.

        Args:
            cache_dir: Path to the cache directory.

        Returns:
            Dataset provider.

        """

        logger.info(f"Loading dataset from: {cache_dir}")

        cache_dir = Path(cache_dir)
        cache_train_file = cache_dir / "train.npy"
        cache_validation_file = cache_dir / "validation.npy"
        cache_test_file = cache_dir / "test.npy"

        tokenizer_file = cache_dir / "tokenizer.pkl"
        try:
            with open(tokenizer_file, "rb") as f:
                tokenizer = pickle.load(f)
        except Exception:
            logger.warn(f"Could not load tokenizer.pkl from {cache_dir}.")
            tokenizer = None

        return FastHfDatasetProvider(cache_train_file, cache_validation_file, cache_test_file, tokenizer=tokenizer)
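
    # Example (illustrative sketch): once `from_disk` or `from_hub` has populated
    # a cache directory, the provider can be restored without re-encoding:
    #
    #   provider = FastHfDatasetProvider.from_cache("cache/wikitext")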
    @overrides
    def get_train_dataset(self, seq_len: Optional[int] = 1) -> FastHfDataset:
        input_ids = np.load(self.train_file, mmap_mode=self.mmap_mode)
        return FastHfDataset(input_ids, seq_len=seq_len)

    @overrides
    def get_val_dataset(self, seq_len: Optional[int] = 1) -> FastHfDataset:
        input_ids = np.load(self.validation_file, mmap_mode=self.mmap_mode)
        return FastHfDataset(input_ids, seq_len=seq_len)

    @overrides
    def get_test_dataset(self, seq_len: Optional[int] = 1) -> FastHfDataset:
        input_ids = np.load(self.test_file, mmap_mode=self.mmap_mode)
        return FastHfDataset(input_ids, seq_len=seq_len)
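
    # Example (illustrative sketch): the getters return `FastHfDataset` instances,
    # which compose with `torch.utils.data.DataLoader`. The sequence length and
    # batch size below are arbitrary assumptions.
    #
    #   train_dataset = provider.get_train_dataset(seq_len=1024)
    #   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)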
@dataclass
class FastDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    """Language modeling data collator compatible with `FastHfDataset`.

    Args:
        use_shifted_labels: Whether to use the original labels (shifted) or the non-shifted labels.

    """

    use_shifted_labels: bool = False
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            return super().torch_call(examples)

        batch = super().torch_call([example[0] for example in examples])

        if self.use_shifted_labels:
            batch["labels"] = torch.stack([example[1] for example in examples], dim=0)

        return batch
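
# Example (illustrative sketch): pairing the collator with a provider in a
# DataLoader. `mlm=False` selects causal language modeling; setting
# `use_shifted_labels=True` stacks the pre-shifted labels yielded by
# `FastHfDataset` instead of the collator's default labels. All values are
# assumptions for illustration only.
#
#   collator = FastDataCollatorForLanguageModeling(
#       tokenizer=provider.tokenizer, mlm=False, use_shifted_labels=True
#   )
#   loader = torch.utils.data.DataLoader(
#       provider.get_train_dataset(seq_len=1024), batch_size=8, collate_fn=collator
#   )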