From 1a3c2bbfbc1e9bb79b92a125c3fb7b2bbf13c9bf Mon Sep 17 00:00:00 2001
From: Yuchen Zhang
Date: Fri, 21 Apr 2023 22:59:31 -0700
Subject: [PATCH 01/13] start adding type for gpt2_bpe_utils

---
 metaseq/data/encoders/gpt2_bpe_utils.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py
index cb3fad545..6dfe0ff8f 100644
--- a/metaseq/data/encoders/gpt2_bpe_utils.py
+++ b/metaseq/data/encoders/gpt2_bpe_utils.py
@@ -15,7 +15,7 @@
 
 
 @lru_cache()
-def bytes_to_unicode():
+def bytes_to_unicode() -> dict:
     """
     Returns list of utf-8 byte and a corresponding list of unicode strings.
     The reversible bpe codes work on unicode strings.
@@ -41,7 +41,7 @@ def bytes_to_unicode():
     return dict(zip(bs, cs))
 
 
-def get_pairs(word):
+def get_pairs(word): #TODO
     """Return set of symbol pairs in a word.
     Word is represented as tuple of symbols (symbols being variable-length strings).
     """
@@ -75,7 +75,7 @@ def __init__(self, encoder, bpe_merges, errors="replace"):
         r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
     )
 
-    def bpe(self, token):
+    def bpe(self, token: list) -> str:
         if token in self.cache:
             return self.cache[token]
         word = tuple(token)
@@ -116,7 +116,7 @@ def bpe(self, token):
         self.cache[token] = word
         return word
 
-    def encode(self, text):
+    def encode(self, text: str) -> list:
         bpe_tokens = []
         for token in self.re.findall(self.pat, text):
             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
@@ -125,7 +125,7 @@ def encode(self, text):
         )
         return bpe_tokens
 
-    def decode(self, tokens):
+    def decode(self, tokens: list) -> str:
         text = "".join([self.decoder.get(token, token) for token in tokens])
         text = bytearray([self.byte_decoder[c] for c in text]).decode(
             "utf-8", errors=self.errors
         )
         return text
 
 
-def get_encoder(encoder_json_path, vocab_bpe_path):
+def get_encoder(encoder_json_path, vocab_bpe_path) -> Encoder:
     with open(encoder_json_path, "r") as f:
         encoder = json.load(f)
     with open(vocab_bpe_path, "r", encoding="utf-8") as f:
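[Reviewer note, not part of the patch series] For checking these annotations against call sites, here is a minimal usage sketch of the module's public API. The vocabulary paths are placeholders for the GPT-2 release files (encoder.json, vocab.bpe), and the sample string is arbitrary:

    # Illustrative only: exercises the signatures annotated in PATCH 01/13.
    from metaseq.data.encoders.gpt2_bpe_utils import get_encoder

    bpe = get_encoder("encoder.json", "vocab.bpe")  # returns an Encoder
    ids = bpe.encode("Hello world")                 # str -> list of int token ids
    assert bpe.decode(ids) == "Hello world"         # list of ids -> str

One thing to watch: `encode()` calls `self.bpe(token)` with a joined string, so `token: list` in this patch looks wrong; the series keeps revisiting this signature (PATCH 04, 06, 10, 12).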
""" From 6f88b3ce883c653534dc277d21e6a7501ff68658 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Wed, 26 Apr 2023 22:17:22 -0700 Subject: [PATCH 03/13] add typing --- metaseq/data/encoders/gpt2_bpe_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 095ad66d2..2b603a9f8 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -134,7 +134,7 @@ def decode(self, tokens: list) -> str: return text -def get_encoder(encoder_json_path, vocab_bpe_path) -> Encoder: +def get_encoder(encoder_json_path: str, vocab_bpe_path: str) -> Encoder: with open(encoder_json_path, "r") as f: encoder = json.load(f) with open(vocab_bpe_path, "r", encoding="utf-8") as f: From a88d550ff11414384e168f2b1530aeaaddc4ceed Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Thu, 27 Apr 2023 10:58:35 -0700 Subject: [PATCH 04/13] test results --- metaseq/data/encoders/gpt2_bpe_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 2b603a9f8..85f2df4d9 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -55,7 +55,7 @@ def get_pairs(word: tuple) -> set: class Encoder: - def __init__(self, encoder, bpe_merges, errors="replace"): + def __init__(self, encoder, bpe_merges, errors="replace"): # TODO self.encoder = encoder self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding @@ -76,7 +76,7 @@ def __init__(self, encoder, bpe_merges, errors="replace"): r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" ) - def bpe(self, token: list) -> str: + def bpe(self, token: list): if token in self.cache: return self.cache[token] word = tuple(token) From 47542858385095ae82c232cfd315c6eae9512fb6 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Fri, 28 Apr 2023 15:20:06 -0700 Subject: [PATCH 05/13] test --- metaseq/data/encoders/gpt2_bpe_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 85f2df4d9..7a433e85c 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -42,7 +42,7 @@ def bytes_to_unicode() -> dict: return dict(zip(bs, cs)) -def get_pairs(word: tuple) -> set: +def get_pairs(word): """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings). 
""" From 86b627ced24d21bd035f0079790a582f66872aaf Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Tue, 2 May 2023 15:28:08 -0700 Subject: [PATCH 06/13] try to pass test --- metaseq/data/encoders/gpt2_bpe_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 7a433e85c..9308bd332 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -76,7 +76,7 @@ def __init__(self, encoder, bpe_merges, errors="replace"): # TODO r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" ) - def bpe(self, token: list): + def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token) From 67bf9fffb94e433a291b4ed524ae5de0339eaed3 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Wed, 3 May 2023 13:56:52 -0700 Subject: [PATCH 07/13] pass test --- metaseq/data/encoders/gpt2_bpe_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 9308bd332..c8b992855 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -12,7 +12,6 @@ import json from functools import lru_cache -from typing import Set @lru_cache() From 6c6c0466af4c870168a4ccbcf4433bac1f272ed8 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Wed, 3 May 2023 17:16:00 -0700 Subject: [PATCH 08/13] add tuple --- metaseq/data/encoders/gpt2_bpe_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index c8b992855..2b99f79fe 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -41,7 +41,7 @@ def bytes_to_unicode() -> dict: return dict(zip(bs, cs)) -def get_pairs(word): +def get_pairs(word: tuple) -> set: """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings). """ From 225bcea16dfcdf2f320e9a52efa0a83581d8b14f Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Wed, 3 May 2023 18:58:44 -0700 Subject: [PATCH 09/13] add type --- metaseq/data/encoders/gpt2_bpe_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 2b99f79fe..132c6a4dc 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -26,16 +26,16 @@ def bytes_to_unicode() -> dict: And avoids mapping to whitespace/control characters the bpe code barfs on. 
""" bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) ) cs = bs[:] n = 0 - for b in range(2**8): + for b in range(2 ** 8): if b not in bs: bs.append(b) - cs.append(2**8 + n) + cs.append(2 ** 8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) @@ -54,7 +54,7 @@ def get_pairs(word: tuple) -> set: class Encoder: - def __init__(self, encoder, bpe_merges, errors="replace"): # TODO + def __init__(self, encoder, bpe_merges: list, errors: str = "replace"): # TODO self.encoder = encoder self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding From 4fdd25d4036b4c830ec0bb3eab093f4753b32bdc Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Wed, 3 May 2023 19:50:24 -0700 Subject: [PATCH 10/13] add type --- metaseq/data/encoders/gpt2_bpe_utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 132c6a4dc..8fd8b377d 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -12,6 +12,7 @@ import json from functools import lru_cache +from typing import Any @lru_cache() @@ -26,16 +27,16 @@ def bytes_to_unicode() -> dict: And avoids mapping to whitespace/control characters the bpe code barfs on. """ bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) ) cs = bs[:] n = 0 - for b in range(2 ** 8): + for b in range(2**8): if b not in bs: bs.append(b) - cs.append(2 ** 8 + n) + cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) @@ -54,7 +55,9 @@ def get_pairs(word: tuple) -> set: class Encoder: - def __init__(self, encoder, bpe_merges: list, errors: str = "replace"): # TODO + def __init__( + self, encoder: dict, bpe_merges: list, errors: str = "replace" + ): self.encoder = encoder self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding @@ -75,7 +78,7 @@ def __init__(self, encoder, bpe_merges: list, errors: str = "replace"): # TODO r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" ) - def bpe(self, token): + def bpe(self, token: str) -> Any[str, list]: if token in self.cache: return self.cache[token] word = tuple(token) From 77bb07a240cec8bb93704720bad6030458c8e95f Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Wed, 3 May 2023 19:53:42 -0700 Subject: [PATCH 11/13] fix format --- metaseq/data/encoders/gpt2_bpe_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py index 8fd8b377d..a784cf325 100644 --- a/metaseq/data/encoders/gpt2_bpe_utils.py +++ b/metaseq/data/encoders/gpt2_bpe_utils.py @@ -55,9 +55,7 @@ def get_pairs(word: tuple) -> set: class Encoder: - def __init__( - self, encoder: dict, bpe_merges: list, errors: str = "replace" - ): + def __init__(self, encoder: dict, bpe_merges: list, errors: str = "replace"): self.encoder = encoder self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding From 
From ccefaacc2f81c2eb8c26dbc57f689c9919b9d3db Mon Sep 17 00:00:00 2001
From: Yuchen Zhang
Date: Wed, 3 May 2023 20:22:55 -0700
Subject: [PATCH 12/13] use union for one typing

---
 metaseq/data/encoders/gpt2_bpe_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/metaseq/data/encoders/gpt2_bpe_utils.py b/metaseq/data/encoders/gpt2_bpe_utils.py
index a784cf325..26dac2ddd 100644
--- a/metaseq/data/encoders/gpt2_bpe_utils.py
+++ b/metaseq/data/encoders/gpt2_bpe_utils.py
@@ -12,7 +12,7 @@
 
 import json
 from functools import lru_cache
-from typing import Any
+from typing import Union
 
 
 @lru_cache()
@@ -76,7 +76,7 @@ def __init__(self, encoder: dict, bpe_merges: list, errors: str = "replace"):
         r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
     )
 
-    def bpe(self, token: str) -> Any[str, list]:
+    def bpe(self, token: str) -> Union[str, list]:
         if token in self.cache:
             return self.cache[token]
         word = tuple(token)

From 4f2f847755284ab0cba52197d884de3eb9986f04 Mon Sep 17 00:00:00 2001
From: Yuchen Zhang
Date: Mon, 8 May 2023 15:02:37 -0700
Subject: [PATCH 13/13] add typing for append_token_dataset

---
 metaseq/data/append_token_dataset.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/metaseq/data/append_token_dataset.py b/metaseq/data/append_token_dataset.py
index bf718304d..dd20df256 100644
--- a/metaseq/data/append_token_dataset.py
+++ b/metaseq/data/append_token_dataset.py
@@ -10,7 +10,7 @@
 
 
 class AppendTokenDataset(BaseWrapperDataset):
-    def __init__(self, dataset, token=None):
+    def __init__(self, dataset: BaseWrapperDataset, token: str = None):
         super().__init__(dataset)
         self.token = token
         if token is not None:
@@ -18,7 +18,7 @@ def __init__(self, dataset: BaseWrapperDataset, token: str = None):
         else:
             self._sizes = dataset.sizes
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int):
         item = self.dataset[idx]
         if self.token is not None:
             item = torch.cat([item, item.new([self.token])])
@@ -28,13 +28,13 @@ def __getitem__(self, idx):
     def sizes(self):
         return self._sizes
 
-    def num_tokens(self, index):
+    def num_tokens(self, index: int):
         n = self.dataset.num_tokens(index)
         if self.token is not None:
             n += 1
         return n
 
-    def size(self, index):
+    def size(self, index: int):
         n = self.dataset.size(index)
         if self.token is not None:
             n += 1
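[Reviewer note, not part of the patch series] A self-contained sketch of `AppendTokenDataset` under the new annotations. `ToyDataset` and the token id `2` are made up for illustration; any metaseq dataset exposing `sizes`, `num_tokens`, and `size` would do. Since `__getitem__` appends via `item.new([self.token])`, i.e. as a numeric element of a tensor, callers pass a token id, so the `token: str` annotation above may deserve a second look:

    import numpy as np
    import torch

    from metaseq.data.append_token_dataset import AppendTokenDataset

    class ToyDataset:
        """Minimal stand-in for a wrapped metaseq dataset (illustrative)."""

        def __init__(self, tensors):
            self.tensors = tensors
            self.sizes = np.array([len(t) for t in tensors])

        def __getitem__(self, idx):
            return self.tensors[idx]

        def __len__(self):
            return len(self.tensors)

        def num_tokens(self, idx):
            return self.sizes[idx]

        def size(self, idx):
            return self.sizes[idx]

    ds = ToyDataset([torch.tensor([4, 5, 6])])
    wrapped = AppendTokenDataset(ds, token=2)  # 2: a hypothetical eos id
    print(wrapped[0])       # tensor([4, 5, 6, 2]): the id is appended
    print(wrapped.size(0))  # 4 == ds.size(0) + 1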