Source code for archai.datasets.nlp.tokenizer_utils.word_tokenizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from collections import Counter, OrderedDict
from typing import List, Optional

from overrides import overrides

from archai.common.distributed_utils import sync_workers
from archai.common.file_utils import get_full_path
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.datasets.nlp.tokenizer_utils.token_config import (
    SpecialTokenEnum,
    TokenConfig,
)
from archai.datasets.nlp.tokenizer_utils.tokenizer_base import TokenizerBase

logger = OrderedDictLogger(source=__name__)


class WordTokenizer(TokenizerBase):
    """Word-based tokenizer."""

    def __init__(
        self,
        save_path: str,
        vocab_size: Optional[int] = None,
        bos_token: Optional[str] = None,
        eos_token: Optional[str] = "<eos>",
        unk_token: Optional[str] = "<unk>",
        min_frequency: Optional[int] = 0,
        lower_case: Optional[bool] = False,
        delimiter: Optional[str] = None,
        encode_special_tokens: Optional[bool] = True,
        decode_special_tokens: Optional[bool] = True,
    ) -> None:
        """Define the tokenization pipeline.

        Args:
            save_path: Path to save the tokenizer.
            vocab_size: Maximum size of the vocabulary.
            bos_token: Begin-of-sentence token.
            eos_token: End-of-sentence token.
            unk_token: Unknown token.
            min_frequency: Minimum frequency a token needs to enter the vocabulary.
            lower_case: Whether lower case should be applied.
            delimiter: Delimiter between tokens.
            encode_special_tokens: Whether special tokens should be encoded.
            decode_special_tokens: Whether special tokens should be decoded.

        """

        self.counter = Counter()

        # No prefix space or new line is needed because tokens are
        # delimited on white space (unlike in BBPE)
        self._config = TokenConfig(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=None,
            add_prefix_space=False,
            add_prefix_new_line=False,
            lower_case=lower_case,
        )
        assert self._config.unk_token, "`unk_token` must be supplied."

        self._bos = [self._config.bos_token] if self._config.bos_token else []
        self._eos = [self._config.eos_token] if self._config.eos_token else []

        self.save_path = save_path
        self.vocab_size = vocab_size
        self.min_frequency = min_frequency
        self.delimiter = delimiter
        self.encode_special_tokens = encode_special_tokens
        self.decode_special_tokens = decode_special_tokens

    @overrides
    def __len__(self) -> int:
        return len(self.idx2sym)
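    # Example construction (illustrative path and sizes): a tokenizer that
    # lower-cases its input and caps the vocabulary at 10,000 entries,
    # special tokens included:
    #
    #   tokenizer = WordTokenizer("/tmp/word-tokenizer", vocab_size=10000, lower_case=True)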
    @overrides
    def train(self, filepaths: List[str]) -> None:
        logger.info(
            f"Training tokenizer with min_frequency = {self.min_frequency} and vocab_size = {self.vocab_size}, "
            f"using {len(filepaths)} training file(s) at {self.save_path} ..."
        )

        assert len(filepaths), "At least one training file must be supplied."

        self._clear()
        for filepath in filepaths:
            self._add_file(filepath)

        # Special tokens are added regardless of vocab_size
        for sym in self._config.get_special_tokens():
            self._add_special(sym)

        remaining_len = self.vocab_size - len(self) if self.vocab_size is not None else None
        for sym, cnt in self.counter.most_common(remaining_len):
            if cnt < self.min_frequency:
                break
            self._add_symbol(sym)

        with sync_workers() as rank:
            if rank == 0:
                self._save()

        logger.info(f"Vocabulary size = {len(self)} | Unique tokens = {len(self.counter)}")
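    # On-disk layout written by `_save` (tokens shown are illustrative):
    # `vocab.txt` holds one symbol per line, ordered by index, with special
    # tokens first and the remaining words in descending frequency, e.g.:
    #
    #   <unk>
    #   <eos>
    #   the
    #   of
    #   and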
    @overrides
    def is_trained(self) -> bool:
        vocab_filepath = self._vocab_filepath()
        return os.path.exists(vocab_filepath)
    @overrides
    def load(self) -> None:
        vocab_filepath = self._vocab_filepath()

        self._clear()

        with open(vocab_filepath, "r", encoding="utf-8") as f:
            for line in f:
                symb = line.strip().split()[0]
                self._add_symbol(symb)

        self.unk_idx = self.sym2idx[self._config.unk_token]
    @overrides
    def encode_text(self, text: str) -> List[int]:
        symbols = self._tokenize_text(text)

        if self.encode_special_tokens:
            symbols = self._bos + symbols + self._eos

        toks = self.tokens_to_ids(symbols)

        return toks
    @overrides
    def decode_text(self, ids: List[int]) -> str:
        syms = self.ids_to_tokens(ids)

        if self.decode_special_tokens and len(syms):
            # Compares against the token strings (`self._bos`/`self._eos` are
            # lists, so comparing a symbol to them would never match)
            if self._bos and syms[0] == self._config.bos_token:
                syms = syms[1:]
            if self._eos and len(syms) and syms[-1] == self._config.eos_token:
                syms = syms[:-1]

        return " ".join(syms)
    @overrides
    def special_token_id(self, sp: SpecialTokenEnum) -> int:
        return self.token_to_id(self._config.special_token_name(sp))

    @overrides
    def token_to_id(self, t: str) -> int:
        return self._get_idx(t)

    @overrides
    def id_to_token(self, id: int) -> str:
        return self._get_sym(id)

    @overrides
    def tokens_to_ids(self, ts: List[str]) -> List[int]:
        return [self._get_idx(t) for t in ts]

    @overrides
    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        return [self._get_sym(id) for id in ids]
    def _preprocess_text(self, text: str) -> str:
        if self._config.add_prefix_space:
            text = " " + text
        if self._config.add_prefix_new_line:
            text = "\n" + text
        if self._config.lower_case:
            text = text.lower()

        return text

    def _add_file(self, path: str, verbose: Optional[bool] = True) -> None:
        if verbose:
            logger.debug(f"Counting file: {path}")
        assert os.path.exists(path), f"File does not exist: {path}"

        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    logger.debug(f"Completed line: {idx}")
                symbols = self._tokenize_text(line)
                self.counter.update(symbols)

    def _tokenize_text(self, text: str) -> List[str]:
        text = self._preprocess_text(text)

        # With the default `delimiter=None`, `str.split` collapses runs of
        # whitespace instead of producing empty strings
        symbols = text.split(self.delimiter)

        return symbols

    def _clear(self) -> None:
        self.idx2sym = []
        self.sym2idx = OrderedDict()

    def _vocab_filepath(self) -> str:
        vocab_dir = get_full_path(self.save_path, create_folder=True)
        return os.path.join(vocab_dir, "vocab.txt")

    def _save(self) -> None:
        vocab_filepath = self._vocab_filepath()
        with open(vocab_filepath, "w", encoding="utf-8") as f:
            f.write("\n".join(self.idx2sym))

    def _add_special(self, sym: str) -> None:
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

            # Exposes a shortcut attribute for each special token, e.g. `self.eos_idx`
            setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym])

    def _add_symbol(self, sym: str) -> None:
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def _get_sym(self, idx: int) -> str:
        assert 0 <= idx < len(self), f"Index {idx} out of range."
        return self.idx2sym[idx]

    def _get_idx(self, sym: str) -> int:
        # Out-of-vocabulary symbols fall back to the unknown token index
        return self.sym2idx.get(sym, self.unk_idx)
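

# A minimal usage sketch, not part of the original module: it trains the
# tokenizer on a throwaway corpus and round-trips a sentence through
# `encode_text` / `decode_text`. Paths and the printed indices are
# illustrative, and a single-process (non-distributed) run is assumed.
if __name__ == "__main__":
    import tempfile

    work_dir = tempfile.mkdtemp()
    corpus_path = os.path.join(work_dir, "corpus.txt")
    with open(corpus_path, "w", encoding="utf-8") as f:
        f.write("hello world\nhello tokenizer\n")

    tokenizer = WordTokenizer(save_path=work_dir, vocab_size=100)
    tokenizer.train([corpus_path])

    # Re-reading the saved `vocab.txt` mirrors how a fresh process would
    # restore the tokenizer
    tokenizer.load()

    ids = tokenizer.encode_text("hello world")
    print(ids)  # e.g. [2, 3, 1] for "hello", "world", "<eos>"
    print(tokenizer.decode_text(ids))  # -> "hello world"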