from typing import List, Optional, Tuple


class Vocabulary:
    """Vocabulary class for mapping between tokens and indices."""

    def __init__(self, tokens: List[str], special_tokens: Optional[List[str]] = None):
        """Initializes the vocabulary.

        This function initializes the vocabulary by adding the special tokens first
        and then the tokens. The order of the tokens is preserved.

        If <unk> is not provided in special_tokens, the vocabulary will raise an
        exception when an unknown token is encountered.

        If padding is performed on the input tokens, the padding token <pad> should
        always be added to special_tokens.

        Args:
            tokens: List[str], list of tokens in the vocabulary.
            special_tokens: Optional[List[str]], list of special tokens to add to
                the vocabulary (e.g., <pad>, <unk>). Default is empty list.

        Note:
            If the vocabulary is used to convert output labels to indices, one
            should be very careful about the special tokens.
        """
        if special_tokens is None:
            special_tokens = []
        all_tokens = special_tokens + tokens
        self.token2idx = {}
        self.idx2token = {}
        self.idx = 0
        for token in all_tokens:
            self.add_token(token)

    def add_token(self, token: str) -> None:
        """Adds a token to the vocabulary."""
        if token not in self.token2idx:
            self.token2idx[token] = self.idx
            self.idx2token[self.idx] = token
            self.idx += 1

    def __call__(self, token: str) -> int:
        """Retrieves the index of the token.

        Note that if the token is not in the vocabulary, this function will try to
        return the index of <unk>. If <unk> is not in the vocabulary, an exception
        will be raised.
        """
        if token not in self.token2idx:
            if "<unk>" in self.token2idx:
                return self.token2idx["<unk>"]
            else:
                raise ValueError(f"Unknown token: {token}")
        return self.token2idx[token]

    def __len__(self) -> int:
        """Returns the size of the vocabulary."""
        return len(self.token2idx)

    def __contains__(self, token: str) -> bool:
        return token in self.token2idx
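
# A minimal usage sketch of Vocabulary on its own (hypothetical token space,
# for illustration only): special tokens are added first, so they occupy the
# lowest indices, and unknown lookups fall back to <unk> when it is present.
#
#   >>> vocab = Vocabulary(tokens=["a", "b"], special_tokens=["<pad>", "<unk>"])
#   >>> vocab("a")
#   2
#   >>> vocab("z")
#   1
#   >>> "b" in vocab, len(vocab)
#   (True, 4)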


class Tokenizer:
    """Tokenizer class for converting tokens to indices and vice versa.

    This class builds a vocabulary from the provided tokens and provides the
    functionality to convert tokens to indices and vice versa. It also provides
    the functionality to tokenize a batch of data.

    Examples:
        >>> from pyhealth.tokenizer import Tokenizer
        >>> token_space = ['A01A', 'A02A', 'A02B', 'A02X', 'A03A', 'A03B', 'A03C', 'A03D', 'A03E', \
        ...                'A03F', 'A04A', 'A05A', 'A05B', 'A05C', 'A06A', 'A07A', 'A07B', 'A07C', \
        ...                'A07D', 'A07E', 'A07F', 'A07X', 'A08A', 'A09A', 'A10A', 'A10B', 'A10X', \
        ...                'A11A', 'A11B', 'A11C', 'A11D', 'A11E', 'A11G', 'A11H', 'A11J', 'A12A', \
        ...                'A12B', 'A12C', 'A13A', 'A14A', 'A14B', 'A16A']
        >>> tokenizer = Tokenizer(tokens=token_space, special_tokens=["<pad>", "<unk>"])
    """

    def __init__(self, tokens: List[str], special_tokens: Optional[List[str]] = None):
        """Initializes the tokenizer.

        Args:
            tokens: List[str], list of tokens in the vocabulary.
            special_tokens: Optional[List[str]], list of special tokens to add to
                the vocabulary (e.g., <pad>, <unk>). Default is empty list.
        """
        self.vocabulary = Vocabulary(tokens=tokens, special_tokens=special_tokens)

    def get_padding_index(self) -> int:
        """Returns the index of the padding token."""
        return self.vocabulary("<pad>")

    def get_vocabulary_size(self) -> int:
        """Returns the size of the vocabulary.

        Examples:
            >>> tokenizer.get_vocabulary_size()
            44
        """
        return len(self.vocabulary)

    def convert_tokens_to_indices(self, tokens: List[str]) -> List[int]:
        """Converts a list of tokens to indices.

        Examples:
            >>> tokens = ['A03C', 'A03D', 'A03E', 'A03F', 'A04A', 'A05A', 'A05B', 'B035', 'C129']
            >>> indices = tokenizer.convert_tokens_to_indices(tokens)
            >>> print(indices)
            [8, 9, 10, 11, 12, 13, 14, 1, 1]
        """
        return [self.vocabulary(token) for token in tokens]

    def convert_indices_to_tokens(self, indices: List[int]) -> List[str]:
        """Converts a list of indices to tokens.

        Examples:
            >>> indices = [0, 1, 2, 3, 4, 5]
            >>> tokens = tokenizer.convert_indices_to_tokens(indices)
            >>> print(tokens)
            ['<pad>', '<unk>', 'A01A', 'A02A', 'A02B', 'A02X']
        """
        return [self.vocabulary.idx2token[idx] for idx in indices]

    def batch_encode_2d(
        self,
        batch: List[List[str]],
        padding: bool = True,
        truncation: bool = True,
        max_length: int = 512,
    ) -> List[List[int]]:
        """Converts a list of lists of tokens (2D) to indices.

        Args:
            batch: List of lists of tokens to convert to indices.
            padding: whether to pad the tokens to the max number of tokens in
                the batch (smart padding).
            truncation: whether to truncate the tokens to max_length. Truncation
                keeps the last max_length tokens of each list.
            max_length: maximum length of the tokens. This argument is ignored
                if truncation is False.

        Examples:
            >>> tokens = [
            ...     ['A03C', 'A03D', 'A03E', 'A03F'],
            ...     ['A04A', 'B035', 'C129']
            ... ]
            >>> indices = tokenizer.batch_encode_2d(tokens)
            >>> print('case 1:', indices)
            case 1: [[8, 9, 10, 11], [12, 1, 1, 0]]
            >>> indices = tokenizer.batch_encode_2d(tokens, padding=False)
            >>> print('case 2:', indices)
            case 2: [[8, 9, 10, 11], [12, 1, 1]]
            >>> indices = tokenizer.batch_encode_2d(tokens, max_length=3)
            >>> print('case 3:', indices)
            case 3: [[9, 10, 11], [12, 1, 1]]
        """
        if truncation:
            batch = [tokens[-max_length:] for tokens in batch]
        if padding:
            batch_max_length = max(len(tokens) for tokens in batch)
            batch = [
                tokens + ["<pad>"] * (batch_max_length - len(tokens))
                for tokens in batch
            ]
        return [[self.vocabulary(token) for token in tokens] for tokens in batch]

    def batch_decode_2d(
        self,
        batch: List[List[int]],
        padding: bool = False,
    ) -> List[List[str]]:
        """Converts a list of lists of indices (2D) to tokens.

        Args:
            batch: List of lists of indices to convert to tokens.
            padding: whether to keep the padding tokens in the output.

        Examples:
            >>> indices = [
            ...     [8, 9, 10, 11],
            ...     [12, 1, 1, 0]
            ... ]
            >>> tokens = tokenizer.batch_decode_2d(indices)
            >>> print('case 1:', tokens)
            case 1: [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', '<unk>', '<unk>']]
            >>> tokens = tokenizer.batch_decode_2d(indices, padding=True)
            >>> print('case 2:', tokens)
            case 2: [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', '<unk>', '<unk>', '<pad>']]
        """
        batch = [[self.vocabulary.idx2token[idx] for idx in tokens] for tokens in batch]
        if not padding:
            return [[token for token in tokens if token != "<pad>"] for tokens in batch]
        return batch
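
    # A round-trip sketch (hypothetical 3-token space, for illustration only):
    # encoding pads and maps out-of-vocabulary tokens to <unk>, so decoding is
    # lossy -- padding is stripped by default and unknown tokens come back as
    # '<unk>' rather than the original string.
    #
    #   >>> tk = Tokenizer(tokens=["a", "b", "c"], special_tokens=["<pad>", "<unk>"])
    #   >>> enc = tk.batch_encode_2d([["a", "b"], ["c", "zzz", "a"]])
    #   >>> enc
    #   [[2, 3, 0], [4, 1, 2]]
    #   >>> tk.batch_decode_2d(enc)
    #   [['a', 'b'], ['c', '<unk>', 'a']]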

    def batch_encode_3d(
        self,
        batch: List[List[List[str]]],
        padding: Tuple[bool, bool] = (True, True),
        truncation: Tuple[bool, bool] = (True, True),
        max_length: Tuple[int, int] = (10, 512),
    ) -> List[List[List[int]]]:
        """Converts a list of lists of lists of tokens (3D) to indices.

        Args:
            batch: List of lists of lists of tokens to convert to indices.
            padding: a tuple of two booleans indicating whether to pad the visits
                to the max number of visits in the batch and the tokens to the
                max number of tokens across all visits (smart padding).
            truncation: a tuple of two booleans indicating whether to truncate
                the visits and the tokens to the corresponding element in
                max_length. Truncation keeps the last elements along each
                dimension.
            max_length: a tuple of two integers indicating the maximum length of
                the tokens along the first and second dimension. This argument
                is ignored if truncation is False.

        Examples:
            >>> tokens = [
            ...     [
            ...         ['A03C', 'A03D', 'A03E', 'A03F'],
            ...         ['A08A', 'A09A'],
            ...     ],
            ...     [
            ...         ['A04A', 'B035', 'C129'],
            ...     ]
            ... ]
            >>> indices = tokenizer.batch_encode_3d(tokens)
            >>> print('case 1:', indices)
            case 1: [[[8, 9, 10, 11], [24, 25, 0, 0]], [[12, 1, 1, 0], [0, 0, 0, 0]]]
            >>> indices = tokenizer.batch_encode_3d(tokens, padding=(False, True))
            >>> print('case 2:', indices)
            case 2: [[[8, 9, 10, 11], [24, 25, 0, 0]], [[12, 1, 1, 0]]]
            >>> indices = tokenizer.batch_encode_3d(tokens, padding=(True, False))
            >>> print('case 3:', indices)
            case 3: [[[8, 9, 10, 11], [24, 25]], [[12, 1, 1], [0]]]
            >>> indices = tokenizer.batch_encode_3d(tokens, padding=(False, False))
            >>> print('case 4:', indices)
            case 4: [[[8, 9, 10, 11], [24, 25]], [[12, 1, 1]]]
            >>> indices = tokenizer.batch_encode_3d(tokens, max_length=(2, 2))
            >>> print('case 5:', indices)
            case 5: [[[10, 11], [24, 25]], [[1, 1], [0, 0]]]
        """
        if truncation[0]:
            batch = [tokens[-max_length[0]:] for tokens in batch]
        if truncation[1]:
            batch = [
                [tokens[-max_length[1]:] for tokens in visits] for visits in batch
            ]
        if padding[0]:
            batch_max_length = max(len(tokens) for tokens in batch)
            batch = [
                tokens + [["<pad>"]] * (batch_max_length - len(tokens))
                for tokens in batch
            ]
        if padding[1]:
            batch_max_length = max(
                max(len(tokens) for tokens in visits) for visits in batch
            )
            batch = [
                [
                    tokens + ["<pad>"] * (batch_max_length - len(tokens))
                    for tokens in visits
                ]
                for visits in batch
            ]
        return [
            [[self.vocabulary(token) for token in tokens] for tokens in visits]
            for visits in batch
        ]
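
    # Note on the padding order above: padding[0] first appends single-token
    # [["<pad>"]] visits so every patient has the same number of visits, and
    # padding[1] then pads every visit (including those placeholder visits) to
    # the batch-wide max number of tokens, which is why case 1 in the docstring
    # ends with a fully padded [0, 0, 0, 0] visit.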

    def batch_decode_3d(
        self,
        batch: List[List[List[int]]],
        padding: bool = False,
    ) -> List[List[List[str]]]:
        """Converts a list of lists of lists of indices (3D) to tokens.

        Args:
            batch: List of lists of lists of indices to convert to tokens.
            padding: whether to keep the padding tokens in the output.

        Examples:
            >>> indices = [
            ...     [
            ...         [8, 9, 10, 11],
            ...         [24, 25, 0, 0]
            ...     ],
            ...     [
            ...         [12, 1, 1, 0],
            ...         [0, 0, 0, 0]
            ...     ]
            ... ]
            >>> tokens = tokenizer.batch_decode_3d(indices)
            >>> print('case 1:', tokens)
            case 1: [[['A03C', 'A03D', 'A03E', 'A03F'], ['A08A', 'A09A']], [['A04A', '<unk>', '<unk>']]]
            >>> tokens = tokenizer.batch_decode_3d(indices, padding=True)
            >>> print('case 2:', tokens)
            case 2: [[['A03C', 'A03D', 'A03E', 'A03F'], ['A08A', 'A09A', '<pad>', '<pad>']], [['A04A', '<unk>', '<unk>', '<pad>'], ['<pad>', '<pad>', '<pad>', '<pad>']]]
        """
        batch = [
            self.batch_decode_2d(batch=visits, padding=padding) for visits in batch
        ]
        if not padding:
            batch = [[visit for visit in visits if visit != []] for visits in batch]
        return batch


if __name__ == "__main__":
    tokens = ["a", "b", "c", "d", "e", "f", "g", "h"]
    tokenizer = Tokenizer(tokens=tokens, special_tokens=["<pad>", "<unk>"])

    # 8 tokens + 2 special tokens -> vocabulary size of 10
    print(tokenizer.get_vocabulary_size())

    # "z" is out of vocabulary and maps to <unk> (index 1)
    out = tokenizer.convert_tokens_to_indices(["a", "b", "c", "d", "e", "z"])
    print(out)
    print(tokenizer.convert_indices_to_tokens(out))

    # 2D: the shorter sample is padded to the length of the longer one
    out = tokenizer.batch_encode_2d(
        [["a", "b", "c", "e", "z"], ["a", "b", "c", "d", "e", "z"]],
        padding=True,
        truncation=True,
        max_length=10,
    )
    print(out)
    print(tokenizer.batch_decode_2d(out, padding=False))

    # 3D: the first sample is padded to the max number of visits, and every
    # visit is padded to the max number of tokens in the batch
    out = tokenizer.batch_encode_3d(
        [
            [["a", "b", "c", "e", "z"], ["a", "b", "c", "d", "e", "z"]],
            [["a", "b", "c", "e", "z"], ["a", "b", "c", "d", "e", "z"], ["c", "f"]],
        ],
        padding=(True, True),
        truncation=(True, True),
        max_length=(10, 10),
    )
    print(out)
    print(tokenizer.batch_decode_3d(out, padding=False))
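
    # For reference, the demo above should print (vocabulary: <pad>=0, <unk>=1,
    # then a=2 through h=9):
    #   10
    #   [2, 3, 4, 5, 6, 1]
    #   ['a', 'b', 'c', 'd', 'e', '<unk>']
    #   [[2, 3, 4, 6, 1, 0], [2, 3, 4, 5, 6, 1]]
    #   [['a', 'b', 'c', 'e', '<unk>'], ['a', 'b', 'c', 'd', 'e', '<unk>']]
    # followed by the padded 3D encoding and its unpadded decoding.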