
Commit 30d69b3

llama3 8B support, tiktoken tokenizer (#158)
* WIP: llama3 support, tiktoken tokenizer
* Finalizing
1 parent c21a889 commit 30d69b3

File tree

8 files changed: +210 -59 lines

eval.py

+3 -3

@@ -18,7 +18,7 @@
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000

-from sentencepiece import SentencePieceProcessor
+from tokenizer import get_tokenizer

from model import Transformer

@@ -217,7 +217,7 @@ def main(
    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), tokenizer_path
+    assert tokenizer_path.is_file(), str(tokenizer_path)

    device = 'cuda'
    precision = torch.bfloat16

@@ -231,7 +231,7 @@ def main(

    model.eval()

-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

    torch.manual_seed(1234)

generate.py

+4 -5

@@ -32,10 +32,8 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

-from sentencepiece import SentencePieceProcessor
-
from model import Transformer
-
+from tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)

@@ -269,7 +267,7 @@ def main(
    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), tokenizer_path
+    assert tokenizer_path.is_file(), str(tokenizer_path)

    global print
    from tp import maybe_init_dist

@@ -297,7 +295,8 @@ def main(
    device_sync(device=device) # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
+
    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
    prompt_length = encoded.size(0)

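The call sites above only touch the small TokenizerInterface surface (encode, decode, bos_id, eos_id), so swapping SentencePiece for tiktoken is transparent to generation. As an illustration only (not part of this commit), a BOS-prepending helper in the spirit of generate.py's encode_tokens could look like this:

import torch

def encode_with_bos(tokenizer, text, device="cuda"):
    # Hypothetical helper: prepend the model's BOS id, then tokenize the prompt
    # with whichever wrapper get_tokenizer returned (SentencePiece or tiktoken).
    tokens = [tokenizer.bos_id()] + tokenizer.encode(text)
    return torch.tensor(tokens, dtype=torch.int, device=device)
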
mixtral-moe/generate.py

+1 -1

@@ -175,7 +175,7 @@ def main(
    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), tokenizer_path
+    assert tokenizer_path.is_file(), str(tokenizer_path)

    global print
    rank = maybe_init_dist()

model.py

+1

@@ -65,6 +65,7 @@ def from_name(cls, name: str):
    "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
    "stories15M": dict(n_layer=6, n_head=6, dim=288),
    "stories110M": dict(n_layer=12, n_head=12, dim=768),
+    "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
}

class KVCache(nn.Module):

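A minimal sketch of what this new entry provides, assuming ModelArgs.from_name resolves a config by matching the name against the keys of this table (as the surrounding entries suggest); the values below are copied from the added line:

from model import ModelArgs

config = ModelArgs.from_name("Llama-3-8B")
print(config.block_size)     # 8192
print(config.vocab_size)     # 128256, vs. 32000 for the Mistral-7B entry
print(config.n_local_heads)  # 8 key/value heads, as in the Mistral-7B entry
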
quantize.py

+3 -3

@@ -9,7 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from sentencepiece import SentencePieceProcessor
+from tokenizer import get_tokenizer

try:
    from GPTQ import GenericGPTQRunner, InputRecorder

@@ -578,8 +578,8 @@ def quantize(
        quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+        assert tokenizer_path.is_file(), str(tokenizer_path)
+        tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

        quantized_state_dict = quant_handler.create_quantized_state_dict(
            tokenizer,

requirements.txt

+1

@@ -1,2 +1,3 @@
torch
sentencepiece
+tiktoken

scripts/convert_hf_checkpoint.py

+86 -47

@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import json
import re
+import shutil
import sys
from pathlib import Path
from typing import Optional

@@ -27,33 +28,62 @@ def convert_hf_checkpoint(
    if model_name is None:
        model_name = checkpoint_dir.name

+    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
+    # need to be copied into model.pth.
+    # Llama 3 70B can't easily be merged into one model.pth file, though, since the names
+    # of the weights in the state dict are the same in each consolidated.NN.pth file, so it
+    # is not currently supported.
+    # Along with this, we need to copy original/tokenizer.model (tiktoken format) into the
+    # checkpoint directory as tokenizer.model.
+    is_llama3 = "Llama-3" in model_name
+    if is_llama3:
+        # Check if we have multiple original/consolidated.NN.pth files and report an error
+        # if we do for Llama 3.
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
+        if len(bin_files) > 1:
+            raise ValueError(
+                f"Multiple consolidated.NN.pth files found in {original_dir}. "
+                "Merging them into one model.pth file is not supported for Llama 3.")

    config = ModelArgs.from_name(model_name)
    print(f"Model config {config.__dict__}")

    # Load the json file containing weight mapping
-    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-    assert model_map_json.is_file()
-
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
-
-    weight_map = {
-        "model.embed_tokens.weight": "tok_embeddings.weight",
-        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-        "model.norm.weight": "norm.weight",
-        "lm_head.weight": "output.weight",
-    }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    if not is_llama3:
+        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
+
+        assert model_map_json.is_file()
+
+        with open(model_map_json) as json_map:
+            bin_index = json.load(json_map)
+
+        weight_map = {
+            "model.embed_tokens.weight": "tok_embeddings.weight",
+            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            "model.norm.weight": "norm.weight",
+            "lm_head.weight": "output.weight",
+        }
+        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    else:
+        # There is no separate pytorch_model.bin.index.json file for llama3.
+        # Instead, we just use all original/consolidated.NN.pth files,
+        # so there is no HF-style weight map to apply.
+        weight_map = None
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}

    def permute(w, n_head):
        dim = config.dim

@@ -68,32 +98,41 @@ def permute(w, n_head):
        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
        merged_result.update(state_dict)
    final_result = {}
-    for key, value in merged_result.items():
-        if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
-            new_key = weight_map[abstract_key]
-            if new_key is None:
-                continue
-            new_key = new_key.format(layer_num)
-        else:
-            new_key = weight_map[key]
-
-        final_result[new_key] = value
-
-    for key in tuple(final_result.keys()):
-        if "wq" in key:
-            q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
-            q = permute(q, config.n_head)
-            k = permute(k, config.n_local_heads)
-            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-            del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+    if weight_map is not None:
+        for key, value in merged_result.items():
+            if "layers" in key:
+                abstract_key = re.sub(r'(\d+)', '{}', key)
+                layer_num = re.search(r'\d+', key).group(0)
+                new_key = weight_map[abstract_key]
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = weight_map[key]
+
+            final_result[new_key] = value
+
+        for key in tuple(final_result.keys()):
+            if "wq" in key:
+                q = final_result[key]
+                k = final_result[key.replace("wq", "wk")]
+                v = final_result[key.replace("wq", "wv")]
+                q = permute(q, config.n_head)
+                k = permute(k, config.n_local_heads)
+                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+                del final_result[key]
+                del final_result[key.replace("wq", "wk")]
+                del final_result[key.replace("wq", "wv")]
+    else:
+        final_result = merged_result
    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
    torch.save(final_result, checkpoint_dir / "model.pth")
+    if is_llama3:
+        original_dir = checkpoint_dir / "original"
+        tokenizer_model = original_dir / "tokenizer.model"
+        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
+        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
+        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)

if __name__ == '__main__':
    import argparse

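As a rough sanity check of the Llama 3 path (illustrative only; the directory name below is an assumption, not part of this commit), after running the conversion script the checkpoint directory should contain the merged weights and the copied tokenizer:

from pathlib import Path

checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3-8B")  # hypothetical location

# Inputs the script expects for Llama 3: a single original/consolidated.NN.pth shard
# plus the tiktoken-format tokenizer shipped alongside it.
assert (checkpoint_dir / "original" / "consolidated.00.pth").is_file()
assert (checkpoint_dir / "original" / "tokenizer.model").is_file()

# Outputs written by convert_hf_checkpoint: model.pth (the HF key remapping is skipped
# for Llama 3, so this is just the consolidated weights) and the copied tokenizer.model.
assert (checkpoint_dir / "model.pth").is_file()
assert (checkpoint_dir / "tokenizer.model").is_file()
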
tokenizer.py

+111

@@ -0,0 +1,111 @@
+import os
+import sentencepiece as spm
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from pathlib import Path
+from typing import Dict
+
+class TokenizerInterface:
+    def __init__(self, model_path):
+        self.model_path = model_path
+
+    def encode(self, text):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def decode(self, tokens):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def bos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def eos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+class SentencePieceWrapper(TokenizerInterface):
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        self.processor = spm.SentencePieceProcessor(str(model_path))
+
+    def encode(self, text):
+        return self.processor.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        return self.processor.DecodeIds(tokens)
+
+    def bos_id(self):
+        return self.processor.bos_id()
+
+    def eos_id(self):
+        return self.processor.eos_id()
+
+class TiktokenWrapper(TokenizerInterface):
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+    """
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        assert os.path.isfile(model_path), str(model_path)
+        mergeable_ranks = load_tiktoken_bpe(str(model_path))
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|reserved_special_token_2|>",
+            "<|reserved_special_token_3|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|reserved_special_token_4|>",
+            "<|eot_id|>",  # end of turn
+        ] + [
+            f"<|reserved_special_token_{i}|>"
+            for i in range(5, self.num_reserved_special_tokens - 5)
+        ]
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.model = tiktoken.Encoding(
+            name=Path(model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        # BOS / EOS token IDs
+        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
+        self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+
+    def encode(self, text):
+        return self.model.encode(text)
+
+    def decode(self, tokens):
+        return self.model.decode(tokens)
+
+    def bos_id(self):
+        return self._bos_id
+
+    def eos_id(self):
+        return self._eos_id
+
+def get_tokenizer(tokenizer_model_path, model_name):
+    """
+    Factory function to get the appropriate tokenizer based on the model name.
+
+    Args:
+    - tokenizer_model_path (str): The file path to the tokenizer model.
+    - model_name (str): The name of the model, used to determine the tokenizer type.
+
+    Returns:
+    - TokenizerInterface: An instance of a tokenizer.
+    """
+    if "Llama-3" in str(model_name):
+        return TiktokenWrapper(tokenizer_model_path)
+    else:
+        return SentencePieceWrapper(tokenizer_model_path)

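A minimal usage sketch mirroring the call sites in eval.py and generate.py; the checkpoint path is illustrative, and the returned ids assume the Llama 3 special-token layout defined above:

from pathlib import Path
from tokenizer import get_tokenizer

checkpoint_path = Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth")  # hypothetical path
tokenizer_path = checkpoint_path.parent / "tokenizer.model"

# "Llama-3" appears in the checkpoint path, so this returns a TiktokenWrapper;
# any other name falls back to SentencePieceWrapper.
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

tokens = tokenizer.encode("Hello, world")
print(tokens)                                   # list of token ids
print(tokenizer.decode(tokens))                 # "Hello, world"
print(tokenizer.bos_id(), tokenizer.eos_id())   # ids of <|begin_of_text|> and <|end_of_text|>
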