Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions metaseq/data/append_token_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@


class AppendTokenDataset(BaseWrapperDataset):
def __init__(self, dataset: BaseWrapperDataset, token: str = None):
    """Wrap *dataset* so that *token* is appended to every example.

    Args:
        dataset: the dataset being wrapped; must expose ``.sizes``.
        token: optional token to append. When not None, per-example
            sizes grow by one; otherwise the wrapped dataset's sizes
            are reused unchanged.

    NOTE(review): the annotation should arguably be ``Optional[str]``
    (implicit-Optional is deprecated by PEP 484); kept as-is to avoid
    touching file-level imports.
    """
    super().__init__(dataset)
    self.token = token
    if token is not None:
        # Appending one token lengthens every example by exactly 1.
        self._sizes = np.array(dataset.sizes) + 1
    else:
        self._sizes = dataset.sizes

def __getitem__(self, idx):
def __getitem__(self, idx: int):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([item, item.new([self.token])])
Expand All @@ -28,13 +28,13 @@ def __getitem__(self, idx):
def sizes(self):
    """Per-example sizes precomputed in ``__init__``: the wrapped
    dataset's sizes, incremented by 1 when a token is appended."""
    return self._sizes

def num_tokens(self, index: int):
    """Return the token count of example *index*.

    Delegates to the wrapped dataset's ``num_tokens`` and adds 1 when
    an append token is configured (matching ``__init__``'s size logic).
    """
    n = self.dataset.num_tokens(index)
    if self.token is not None:
        n += 1
    return n

def size(self, index):
def size(self, index: int):
n = self.dataset.size(index)
if self.token is not None:
n += 1
Expand Down
15 changes: 8 additions & 7 deletions metaseq/data/encoders/gpt2_bpe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

import json
from functools import lru_cache
from typing import Union


@lru_cache()
def bytes_to_unicode():
def bytes_to_unicode() -> dict:
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
Expand All @@ -41,7 +42,7 @@ def bytes_to_unicode():
return dict(zip(bs, cs))


def get_pairs(word):
def get_pairs(word: tuple) -> set:
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
Expand All @@ -54,7 +55,7 @@ def get_pairs(word):


class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
def __init__(self, encoder: dict, bpe_merges: list, errors: str = "replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
Expand All @@ -75,7 +76,7 @@ def __init__(self, encoder, bpe_merges, errors="replace"):
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

def bpe(self, token):
def bpe(self, token: str) -> Union[str, list]:
if token in self.cache:
return self.cache[token]
word = tuple(token)
Expand Down Expand Up @@ -116,7 +117,7 @@ def bpe(self, token):
self.cache[token] = word
return word

def encode(self, text):
def encode(self, text: str) -> list:
bpe_tokens = []
for token in self.re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
Expand All @@ -125,15 +126,15 @@ def encode(self, text):
)
return bpe_tokens

def decode(self, tokens: list) -> str:
    """Map a sequence of BPE tokens back to a text string.

    Each token is looked up in ``self.decoder`` (tokens with no entry
    are kept as-is), the resulting unicode proxy characters are mapped
    back to raw byte values via ``self.byte_decoder``, and the bytes
    are decoded as UTF-8 using the error policy set at construction.
    """
    text = "".join([self.decoder.get(token, token) for token in tokens])
    text = bytearray([self.byte_decoder[c] for c in text]).decode(
        "utf-8", errors=self.errors
    )
    return text


def get_encoder(encoder_json_path, vocab_bpe_path):
def get_encoder(encoder_json_path: str, vocab_bpe_path: str) -> Encoder:
with open(encoder_json_path, "r") as f:
encoder = json.load(f)
with open(vocab_bpe_path, "r", encoding="utf-8") as f:
Expand Down