from typing import List, Optional, Tuple
class Vocabulary:
    """Bidirectional mapping between string tokens and integer indices.

    Indices are handed out densely, starting at 0, in the order in which
    tokens are first registered.
    """

    def __init__(self, tokens: List[str], special_tokens: Optional[List[str]] = None):
        """Builds the vocabulary.

        Special tokens are registered before the regular tokens, so they get
        the lowest indices. A token that appears more than once keeps the
        index of its first occurrence; first-appearance order is preserved.

        Args:
            tokens: List[str], regular tokens to register.
            special_tokens: Optional[List[str]], special tokens (e.g., <pad>,
                <unk>) registered ahead of ``tokens``. ``None`` is treated as
                an empty list.

        Note:
            Looking up a token that is not in the vocabulary falls back to
            <unk> only when <unk> was registered; otherwise it raises
            ``ValueError``. Callers that pad sequences should include <pad>
            in ``special_tokens``. Take care with special tokens when the
            vocabulary is used to map output labels to indices.
        """
        prefix = special_tokens if special_tokens is not None else []
        ordered_tokens = prefix + tokens
        self.token2idx = {}
        self.idx2token = {}
        self.idx = 0  # next index to assign
        for token in ordered_tokens:
            self.add_token(token)

    def add_token(self, token):
        """Registers ``token`` under the next free index; no-op if already known."""
        if token in self.token2idx:
            return
        new_index = self.idx
        self.token2idx[token] = new_index
        self.idx2token[new_index] = token
        self.idx = new_index + 1

    def __call__(self, token):
        """Returns the index of ``token``.

        A token missing from the vocabulary is mapped to the index of <unk>
        when <unk> is registered; otherwise ``ValueError`` is raised.
        """
        if token in self.token2idx:
            return self.token2idx[token]
        if "<unk>" not in self.token2idx:
            raise ValueError("Unknown token: {}".format(token))
        return self.token2idx["<unk>"]

    def __len__(self):
        """Returns the number of distinct registered tokens."""
        return len(self.token2idx)

    def __contains__(self, token):
        """Reports whether ``token`` has been registered."""
        return token in self.token2idx
class Tokenizer:
    """Tokenizer class for converting tokens to indices and vice versa.

    This class builds a ``Vocabulary`` from the provided tokens. It converts
    flat lists of tokens to indices and back, and encodes/decodes 2D
    (list of sequences) and 3D (list of lists of sequences) batches with
    optional truncation and padding.
    """

    def __init__(self, tokens: List[str], special_tokens: Optional[List[str]] = None):
        """Initializes the tokenizer.

        Args:
            tokens: List[str], list of tokens in the vocabulary.
            special_tokens: Optional[List[str]], list of special tokens to add to
                the vocabulary (e.g., <pad>, <unk>). Default is empty list.
        """
        self.vocabulary = Vocabulary(tokens=tokens, special_tokens=special_tokens)

    @staticmethod
    def _truncate(sequence: list, max_length: int) -> list:
        """Returns a new list holding the last ``max_length`` items of ``sequence``."""
        # ``sequence[-max_length:]`` keeps everything when max_length == 0
        # (because -0 == 0), so compute the start offset explicitly.
        return sequence[max(len(sequence) - max_length, 0):]

    @staticmethod
    def _pad_to_length(sequence: list, length: int, pad_item) -> list:
        """Returns a new list: ``sequence`` right-padded with ``pad_item`` to ``length``.

        Sequences already at least ``length`` long are returned unchanged
        (as a copy).
        """
        return sequence + [pad_item] * (length - len(sequence))

    def get_padding_index(self) -> int:
        """Returns the index of the padding token <pad>.

        Raises:
            ValueError: if <pad> is not in the vocabulary. Without this check,
                the lookup would silently return the index of <unk> whenever
                <unk> is present.
        """
        if "<pad>" not in self.vocabulary:
            raise ValueError("Padding token <pad> is not in the vocabulary.")
        return self.vocabulary("<pad>")

    def get_vocabulary_size(self) -> int:
        """Returns the size of the vocabulary."""
        return len(self.vocabulary)

    def convert_tokens_to_indices(self, tokens: List[str]) -> List[int]:
        """Converts a list of tokens to indices.

        Unknown tokens map to <unk> if it is in the vocabulary; otherwise
        ``ValueError`` is raised.
        """
        return [self.vocabulary(token) for token in tokens]

    def convert_indices_to_tokens(self, indices: List[int]) -> List[str]:
        """Converts a list of indices to tokens.

        Raises:
            KeyError: if an index is not in the vocabulary.
        """
        return [self.vocabulary.idx2token[idx] for idx in indices]

    def batch_encode_2d(
        self,
        batch: List[List[str]],
        padding: bool = True,
        truncation: bool = True,
        max_length: int = 512,
    ) -> List[List[int]]:
        """Converts a list of lists of tokens (2D) to indices.

        Truncation is applied before padding.

        Args:
            batch: List of lists of tokens to convert to indices.
            padding: whether to right-pad every sequence with <pad> up to the
                length of the longest sequence in the batch (smart padding).
                <pad> should be in the vocabulary; otherwise padding maps to
                <unk> (if present) or raises ``ValueError``.
            truncation: whether to truncate each sequence to its LAST
                ``max_length`` tokens (tokens are dropped from the front).
            max_length: maximum length of each sequence. This argument is
                ignored if truncation is False.

        Returns:
            List of lists of indices. An empty batch yields an empty list.
        """
        if truncation:
            batch = [self._truncate(tokens, max_length) for tokens in batch]
        if padding:
            # default=0 keeps an empty batch from crashing max().
            longest = max((len(tokens) for tokens in batch), default=0)
            batch = [self._pad_to_length(tokens, longest, "<pad>") for tokens in batch]
        return [[self.vocabulary(token) for token in tokens] for tokens in batch]

    def batch_decode_2d(
        self,
        batch: List[List[int]],
        padding: bool = False,
    ) -> List[List[str]]:
        """Converts a list of lists of indices (2D) to tokens.

        Args:
            batch: List of lists of indices to convert to tokens.
            padding: whether to keep the <pad> tokens in the output. When
                False, every <pad> token is removed.

        Returns:
            List of lists of tokens.
        """
        batch = [[self.vocabulary.idx2token[idx] for idx in tokens] for tokens in batch]
        if not padding:
            return [[token for token in tokens if token != "<pad>"] for tokens in batch]
        return batch

    def batch_encode_3d(
        self,
        batch: List[List[List[str]]],
        padding: Tuple[bool, bool] = (True, True),
        truncation: Tuple[bool, bool] = (True, True),
        max_length: Tuple[int, int] = (10, 512),
    ) -> List[List[List[int]]]:
        """Converts a list of lists of lists of tokens (3D) to indices.

        Each element of ``batch`` is a list of visits; each visit is a list of
        tokens. Truncation is applied before padding.

        Args:
            batch: List of lists of lists of tokens to convert to indices.
            padding: a tuple of two booleans. padding[0] pads every sample
                with ``["<pad>"]`` visits up to the largest visit count in the
                batch. padding[1] pads every visit with <pad> up to the longest
                visit in the whole batch (smart padding).
            truncation: a tuple of two booleans. truncation[0] keeps only the
                last max_length[0] visits of each sample. truncation[1] keeps
                only the last max_length[1] tokens of each visit.
            max_length: a tuple of two integers giving the maximum number of
                visits per sample and of tokens per visit. Each entry is
                ignored if the corresponding truncation flag is False.

        Returns:
            List of lists of lists of indices. An empty batch yields an empty
            list.
        """
        if truncation[0]:
            batch = [self._truncate(visits, max_length[0]) for visits in batch]
        if truncation[1]:
            batch = [
                [self._truncate(tokens, max_length[1]) for tokens in visits]
                for visits in batch
            ]
        if padding[0]:
            # default=0 keeps an empty batch from crashing max().
            most_visits = max((len(visits) for visits in batch), default=0)
            batch = [
                self._pad_to_length(visits, most_visits, ["<pad>"]) for visits in batch
            ]
        if padding[1]:
            # Flatten over all visits, so that samples with zero visits
            # (or an empty batch) do not crash max().
            longest_visit = max(
                (len(tokens) for visits in batch for tokens in visits), default=0
            )
            batch = [
                [self._pad_to_length(tokens, longest_visit, "<pad>") for tokens in visits]
                for visits in batch
            ]
        return [
            [[self.vocabulary(token) for token in tokens] for tokens in visits]
            for visits in batch
        ]

    def batch_decode_3d(
        self,
        batch: List[List[List[int]]],
        padding: bool = False,
    ) -> List[List[List[str]]]:
        """Converts a list of lists of lists of indices (3D) to tokens.

        Args:
            batch: List of lists of lists of indices to convert to tokens.
            padding: whether to keep the padding in the output. When False,
                <pad> tokens are removed and then every visit left empty is
                dropped. This includes visits that held only <pad> tokens.

        Returns:
            List of lists of lists of tokens.
        """
        batch = [
            self.batch_decode_2d(batch=visits, padding=padding) for visits in batch
        ]
        if not padding:
            batch = [[visit for visit in visits if visit != []] for visits in batch]
        return batch
if __name__ == "__main__":
    # Small demo: eight regular tokens plus <pad>/<unk>; "z" is out of
    # vocabulary and is therefore encoded as <unk>.
    demo_vocab = ["a", "b", "c", "d", "e", "f", "g", "h"]
    tokenizer = Tokenizer(tokens=demo_vocab, special_tokens=["<pad>", "<unk>"])
    print(tokenizer.get_vocabulary_size())

    # Flat round-trip: tokens -> indices -> tokens.
    flat_indices = tokenizer.convert_tokens_to_indices(["a", "b", "c", "d", "e", "z"])
    print(flat_indices)
    print(tokenizer.convert_indices_to_tokens(flat_indices))

    # 2D batch: sequences of unequal length, padded to the longest one.
    sequences = [
        ["a", "b", "c", "e", "z"],
        ["a", "b", "c", "d", "e", "z"],
    ]
    encoded_2d = tokenizer.batch_encode_2d(
        sequences,
        padding=True,
        truncation=True,
        max_length=10,
    )
    print(encoded_2d)
    print(tokenizer.batch_decode_2d(encoded_2d, padding=False))

    # 3D batch: samples with different numbers of visits and visit lengths.
    samples = [
        [["a", "b", "c", "e", "z"], ["a", "b", "c", "d", "e", "z"]],
        [["a", "b", "c", "e", "z"], ["a", "b", "c", "d", "e", "z"], ["c", "f"]],
    ]
    encoded_3d = tokenizer.batch_encode_3d(
        samples,
        padding=(True, True),
        truncation=(True, True),
        max_length=(10, 10),
    )
    print(encoded_3d)
    print(tokenizer.batch_decode_3d(encoded_3d, padding=False))