
Commit 4280a4a

WIP: llama3 support, tiktoken tokenizer
1 parent c21a889 commit 4280a4a

File tree: 4 files changed, 122 additions and 3 deletions


generate.py (+8, -3)
@@ -35,7 +35,7 @@ def device_sync(device):
 from sentencepiece import SentencePieceProcessor
 
 from model import Transformer
-
+from tokenizer import get_tokenizer
 
 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
@@ -268,7 +268,10 @@ def main(
     """
     assert checkpoint_path.is_file(), checkpoint_path
 
-    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    if "Llama-3" in str(checkpoint_path):
+        tokenizer_path = checkpoint_path.parent / "original/tokenizer.model"
+    else:
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), tokenizer_path
 
     global print
@@ -297,7 +300,9 @@ def main(
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
+
+    #tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
     prompt_length = encoded.size(0)
 
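For context, the new branch only inspects the checkpoint path string: any path containing "Llama-3" is expected to keep its tiktoken BPE file under original/tokenizer.model, while other checkpoints keep tokenizer.model next to the weights. A minimal standalone sketch of that resolution logic (the checkpoint directory names below are hypothetical examples, not paths from the repo):

    from pathlib import Path

    # Hypothetical checkpoint locations; only the "Llama-3" substring matters.
    llama3_ckpt = Path("checkpoints/Meta-Llama-3-8B/model.pth")
    llama2_ckpt = Path("checkpoints/Llama-2-7b-chat-hf/model.pth")

    def resolve_tokenizer_path(checkpoint_path: Path) -> Path:
        # Mirrors the branch added in main(): Llama 3 ships its tiktoken BPE file
        # under original/tokenizer.model; older checkpoints keep tokenizer.model
        # next to the weights.
        if "Llama-3" in str(checkpoint_path):
            return checkpoint_path.parent / "original/tokenizer.model"
        return checkpoint_path.parent / "tokenizer.model"

    print(resolve_tokenizer_path(llama3_ckpt))  # .../Meta-Llama-3-8B/original/tokenizer.model
    print(resolve_tokenizer_path(llama2_ckpt))  # .../Llama-2-7b-chat-hf/tokenizer.model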

model.py (+2)
@@ -65,6 +65,8 @@ def from_name(cls, name: str):
     "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
     "stories15M": dict(n_layer=6, n_head=6, dim=288),
     "stories110M": dict(n_layer=12, n_head=12, dim=768),
+    "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
+    "Llama-3-70B": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256),
 }
 
 class KVCache(nn.Module):
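The two new entries register the Llama 3 shapes (8192-token block size, grouped-query attention with 8 KV heads, 128256-entry vocabulary) in the existing name-keyed config table that from_name consumes. A rough sketch of how such a table can be queried by model name; the helper below is illustrative only and is not the repo's actual from_name implementation:

    # Illustrative lookup over a name-keyed config table like the one extended above.
    # It demonstrates how a checkpoint directory name such as "Meta-Llama-3-8B"
    # can select the "Llama-3-8B" entry by substring match.
    transformer_configs = {
        "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8,
                           dim=4096, intermediate_size=14336, vocab_size=128256),
        "Llama-3-70B": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8,
                            dim=8192, intermediate_size=28672, vocab_size=128256),
    }

    def config_for(name: str) -> dict:
        matches = [cfg for key, cfg in transformer_configs.items() if key in name]
        assert len(matches) == 1, f"ambiguous or unknown model name: {name}"
        return matches[0]

    print(config_for("Meta-Llama-3-8B")["vocab_size"])  # 128256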

requirements.txt (+1)
@@ -1,2 +1,3 @@
 torch
 sentencepiece
+tiktoken

tokenizer.py (+111)
@@ -0,0 +1,111 @@
+import os
+import sentencepiece as spm
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from pathlib import Path
+from typing import Dict
+
+class TokenizerInterface:
+    def __init__(self, model_path):
+        self.model_path = model_path
+
+    def encode(self, text):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def decode(self, tokens):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def bos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def eos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+class SentencePieceWrapper(TokenizerInterface):
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        self.processor = spm.SentencePieceProcessor(str(model_path))
+
+    def encode(self, text):
+        return self.processor.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        return self.processor.DecodeIds(tokens)
+
+    def bos_id(self):
+        return self.processor.bos_id()
+
+    def eos_id(self):
+        return self.processor.eos_id()
+
+class TiktokenWrapper(TokenizerInterface):
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+    """
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        assert os.path.isfile(model_path), str(model_path)
+        mergeable_ranks = load_tiktoken_bpe(str(model_path))
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|reserved_special_token_2|>",
+            "<|reserved_special_token_3|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|reserved_special_token_4|>",
+            "<|eot_id|>", # end of turn
+        ] + [
+            f"<|reserved_special_token_{i}|>"
+            for i in range(5, self.num_reserved_special_tokens - 5)
+        ]
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.model = tiktoken.Encoding(
+            name=Path(model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        # BOS / EOS token IDs
+        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
+        self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+
+    def encode(self, text):
+        return self.model.encode(text)
+
+    def decode(self, tokens):
+        return self.model.decode(tokens)
+
+    def bos_id(self):
+        return self._bos_id
+
+    def eos_id(self):
+        return self._eos_id
+
+def get_tokenizer(tokenizer_model_path, model_name):
+    """
+    Factory function to get the appropriate tokenizer based on the model name.
+
+    Args:
+    - tokenizer_model_path (str): The file path to the tokenizer model.
+    - model_name (str): The name of the model, used to determine the tokenizer type.
+
+    Returns:
+    - TokenizerInterface: An instance of a tokenizer.
+    """
+    if "Llama-3" in str(model_name):
+        return TiktokenWrapper(tokenizer_model_path)
+    else:
+        return SentencePieceWrapper(tokenizer_model_path)
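A short usage sketch of the new factory. The checkpoint path below is a hypothetical example; the printed BOS/EOS values follow from the file itself, assuming the Llama 3 BPE file contains 128000 mergeable ranks (128256 vocabulary minus 256 reserved special tokens):

    from pathlib import Path

    from tokenizer import get_tokenizer

    # Hypothetical checkpoint layout; get_tokenizer only checks the model name
    # for the "Llama-3" substring to pick TiktokenWrapper over SentencePieceWrapper.
    tokenizer_path = Path("checkpoints/Meta-Llama-3-8B/original/tokenizer.model")
    tok = get_tokenizer(tokenizer_path, "Meta-Llama-3-8B")

    ids = tok.encode("Hello, world!")
    print(ids)              # token ids from the tiktoken BPE
    print(tok.decode(ids))  # "Hello, world!"

    # With 128000 mergeable ranks, the first two special tokens land at:
    print(tok.bos_id())     # 128000  (<|begin_of_text|>)
    print(tok.eos_id())     # 128001  (<|end_of_text|>)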
