diff --git a/metaseq/data/append_token_dataset.py b/metaseq/data/append_token_dataset.py
index bf718304d..dd20df256 100644
--- a/metaseq/data/append_token_dataset.py
+++ b/metaseq/data/append_token_dataset.py
@@ -10,7 +10,7 @@

 
 class AppendTokenDataset(BaseWrapperDataset):
-    def __init__(self, dataset, token=None):
+    def __init__(self, dataset: BaseWrapperDataset, token: str = None):
         super().__init__(dataset)
         self.token = token
         if token is not None:
@@ -18,7 +18,7 @@ def __init__(self, dataset, token=None):
         else:
             self._sizes = dataset.sizes
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int):
         item = self.dataset[idx]
         if self.token is not None:
             item = torch.cat([item, item.new([self.token])])
@@ -28,13 +28,13 @@ def __getitem__(self, idx):
     def sizes(self):
         return self._sizes
 
-    def num_tokens(self, index):
+    def num_tokens(self, index: int):
         n = self.dataset.num_tokens(index)
         if self.token is not None:
             n += 1
         return n
 
-    def size(self, index):
+    def size(self, index: int):
         n = self.dataset.size(index)
         if self.token is not None:
             n += 1
diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py
index cb3fad545..26dac2ddd 100644
--- a/metaseq/data/encoders/gpt2_bpe_utils.py
+++ b/metaseq/data/encoders/gpt2_bpe_utils.py
@@ -12,10 +12,11 @@

 import json
 from functools import lru_cache
+from typing import Union
 
 
 @lru_cache()
-def bytes_to_unicode():
+def bytes_to_unicode() -> dict:
     """
     Returns list of utf-8 byte and a corresponding list of unicode strings.
     The reversible bpe codes work on unicode strings.
@@ -41,7 +42,7 @@ def bytes_to_unicode():
     return dict(zip(bs, cs))
 
 
-def get_pairs(word):
+def get_pairs(word: tuple) -> set:
     """Return set of symbol pairs in a word.
     Word is represented as tuple of symbols (symbols being variable-length strings).
     """
@@ -54,7 +55,7 @@

 
 class Encoder:
-    def __init__(self, encoder, bpe_merges, errors="replace"):
+    def __init__(self, encoder: dict, bpe_merges: list, errors: str = "replace"):
         self.encoder = encoder
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
@@ -75,7 +76,7 @@ def __init__(self, encoder, bpe_merges, errors="replace"):
             r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
         )
 
-    def bpe(self, token):
+    def bpe(self, token: str) -> Union[str, list]:
         if token in self.cache:
             return self.cache[token]
         word = tuple(token)
@@ -116,7 +117,7 @@ def bpe(self, token):
         self.cache[token] = word
         return word
 
-    def encode(self, text):
+    def encode(self, text: str) -> list:
         bpe_tokens = []
         for token in self.re.findall(self.pat, text):
             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
@@ -125,7 +126,7 @@ def encode(self, text):
             )
         return bpe_tokens
 
-    def decode(self, tokens):
+    def decode(self, tokens: list) -> str:
         text = "".join([self.decoder.get(token, token) for token in tokens])
         text = bytearray([self.byte_decoder[c] for c in text]).decode(
             "utf-8", errors=self.errors
@@ -133,7 +134,7 @@ def decode(self, tokens):
         )
         return text
 
-def get_encoder(encoder_json_path, vocab_bpe_path):
+def get_encoder(encoder_json_path: str, vocab_bpe_path: str) -> Encoder:
     with open(encoder_json_path, "r") as f:
         encoder = json.load(f)
     with open(vocab_bpe_path, "r", encoding="utf-8") as f:
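
As a quick sanity check of the annotated `AppendTokenDataset` API, here is a minimal sketch, assuming `metaseq` is importable; `ToyDataset` is a hypothetical stand-in that stubs only the attributes the wrapper touches (`sizes`, `num_tokens`, `size`, indexing). Note that the appended token is an integer id at runtime, since `__getitem__` materializes it with `item.new([self.token])`:

```python
import numpy as np
import torch

from metaseq.data.append_token_dataset import AppendTokenDataset


class ToyDataset:
    """Hypothetical stand-in for a real metaseq dataset."""

    def __init__(self, items):
        self.items = items
        self.sizes = np.array([len(t) for t in items])

    def __getitem__(self, idx):
        return self.items[idx]

    def __len__(self):
        return len(self.items)

    def num_tokens(self, index):
        return int(self.sizes[index])

    def size(self, index):
        return int(self.sizes[index])


base = ToyDataset([torch.tensor([5, 6, 7]), torch.tensor([8, 9])])
wrapped = AppendTokenDataset(base, token=2)  # 2 = EOS id, for illustration only

print(wrapped[0])             # tensor([5, 6, 7, 2])
print(wrapped.num_tokens(0))  # 4: three original tokens plus the appended one
```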
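
Similarly, a sketch of the `Encoder` round trip under the new signatures. This assumes the `regex` package is installed (the class imports it lazily) and local copies of GPT-2's `encoder.json` / `vocab.bpe`; both paths below are placeholders:

```python
from metaseq.data.encoders.gpt2_bpe_utils import get_encoder

# Placeholder paths: point these at the GPT-2 vocab files on disk.
bpe = get_encoder("encoder.json", "vocab.bpe")

ids = bpe.encode("Hello world")          # encode() -> list of int ids
assert bpe.decode(ids) == "Hello world"  # decode() -> str, lossless round trip
```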