# Source code for archai.datasets.nlp.tokenizer_utils.tokenizer_base
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import abstractmethod
from typing import List, Optional
import torch
from overrides import EnforceOverrides
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.datasets.nlp.tokenizer_utils.token_config import SpecialTokenEnum
logger = OrderedDictLogger(source=__name__)
class TokenizerBase(EnforceOverrides):
    """Abstract base class for tokenizers.

    This class serves as a base for training, encoding and decoding. The class enforces
    implementation of nine methods: `__len__`, `train`, `is_trained`, `load`, `encode_text`,
    `decode_text`, `special_token_id`, `token_to_id` and `id_to_token`.

    Note:
        This class is inherited from `EnforceOverrides` and any overridden methods in the
        subclass should be decorated with `@overrides` to ensure they are properly overridden.

    """

    @abstractmethod
    def __len__(self) -> int:
        """Get the length of the vocabulary.

        Returns:
            The length of the vocabulary.

        """
        pass

    @abstractmethod
    def train(self, filepaths: List[str]) -> None:
        """Train the tokenizer on a list of files.

        Args:
            filepaths: A list of paths to input files.

        """
        pass

    @abstractmethod
    def is_trained(self) -> bool:
        """Check if the vocabulary has been trained.

        Returns:
            `True` if the vocabulary has been trained, `False` otherwise.

        """
        pass

    @abstractmethod
    def load(self) -> None:
        """Load a pre-trained tokenizer."""
        pass

    @abstractmethod
    def encode_text(self, text: str) -> List[int]:
        """Encode text into tokens.

        Args:
            text: The input text to encode.

        Returns:
            The encoded text (tokens).

        """
        pass

    @abstractmethod
    def decode_text(self, ids: List[int]) -> str:
        """Decode tokens into text.

        Args:
            ids: The tokens to decode.

        Returns:
            The decoded tokens (text).

        """
        pass

    @abstractmethod
    def special_token_id(self, sp: SpecialTokenEnum) -> int:
        """Get the identifier of a special token.

        Args:
            sp: The special token's enumerator.

        Returns:
            The special token's identifier.

        """
        pass

    @abstractmethod
    def token_to_id(self, t: str) -> int:
        """Convert a string-based token to its identifier.

        Args:
            t: The string-based token.

        Returns:
            The token's identifier.

        """
        pass

    @abstractmethod
    def id_to_token(self, id: int) -> str:
        """Convert a token identifier to its string-based representation.

        Args:
            id: The token's identifier.

        Returns:
            The string-based token.

        """
        pass

    def tokens_to_ids(self, ts: List[str]) -> List[int]:
        """Convert a list of string-based tokens to their corresponding identifiers.

        Args:
            ts: A list of string-based tokens.

        Returns:
            The identifiers corresponding to the input tokens.

        """
        return [self.token_to_id(t) for t in ts]

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        """Convert a list of tokens' identifiers to their string-based representations.

        Args:
            ids: A list of tokens' identifiers.

        Returns:
            The string-based representations of the input tokens.

        """
        return [self.id_to_token(id) for id in ids]

    def encode_file(self, path: str, verbose: Optional[bool] = True) -> torch.Tensor:
        """Encode text from an input file.

        This method reads text from the specified file and encodes it using
        the `encode_text` method. It also includes options for verbosity and
        efficiently handling large datasets by converting the encoded tokens
        to a `torch.Tensor` every 500k lines.

        Args:
            path: The path to the input file.
            verbose: Whether to add verbosity to the logger.

        Returns:
            The encoded tokens.

        """
        logger.info(f"Encoding file: {path}")

        encoded = []
        tensor_encoded = torch.LongTensor()

        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                # Flushes the Python list into a torch.Tensor every 500k lines,
                # otherwise the list holding the tokens uses a lot of RAM.
                if idx > 0 and idx % 500000 == 0:
                    tensor_encoded = torch.cat((tensor_encoded, torch.LongTensor(encoded)))
                    encoded = []

                    if verbose:
                        logger.debug(f"Completed line: {idx}")

                tokens = self.encode_text(line)
                encoded.extend(tokens)

        # Flush any tokens accumulated since the last 500k-line checkpoint.
        if len(encoded) > 0:
            tensor_encoded = torch.cat((tensor_encoded, torch.LongTensor(encoded)))

        return tensor_encoded