Skip to content

Commit 50e45b0

Browse files
committed
Finalizing
1 parent 4280a4a commit 50e45b0

File tree

6 files changed

+95
-63
lines changed

6 files changed

+95
-63
lines changed

eval.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
torch._inductor.config.triton.cudagraphs = True
1919
torch._dynamo.config.cache_size_limit = 100000
2020

21-
from sentencepiece import SentencePieceProcessor
21+
from tokenizer import get_tokenizer
2222

2323
from model import Transformer
2424

@@ -217,7 +217,7 @@ def main(
217217
assert checkpoint_path.is_file(), checkpoint_path
218218

219219
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
220-
assert tokenizer_path.is_file(), tokenizer_path
220+
assert tokenizer_path.is_file(), str(tokenizer_path)
221221

222222
device = 'cuda'
223223
precision = torch.bfloat16
@@ -231,7 +231,7 @@ def main(
231231

232232
model.eval()
233233

234-
tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
234+
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
235235

236236
torch.manual_seed(1234)
237237

generate.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ def device_sync(device):
3232
wd = Path(__file__).parent.parent.resolve()
3333
sys.path.append(str(wd))
3434

35-
from sentencepiece import SentencePieceProcessor
36-
3735
from model import Transformer
3836
from tokenizer import get_tokenizer
3937

@@ -268,11 +266,8 @@ def main(
268266
"""
269267
assert checkpoint_path.is_file(), checkpoint_path
270268

271-
if "Llama-3" in str(checkpoint_path):
272-
tokenizer_path = checkpoint_path.parent / "original/tokenizer.model"
273-
else:
274-
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
275-
assert tokenizer_path.is_file(), tokenizer_path
269+
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
270+
assert tokenizer_path.is_file(), str(tokenizer_path)
276271

277272
global print
278273
from tp import maybe_init_dist
@@ -302,7 +297,6 @@ def main(
302297

303298
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
304299

305-
#tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
306300
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
307301
prompt_length = encoded.size(0)
308302

mixtral-moe/generate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def main(
175175
assert checkpoint_path.is_file(), checkpoint_path
176176

177177
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
178-
assert tokenizer_path.is_file(), tokenizer_path
178+
assert tokenizer_path.is_file(), str(tokenizer_path)
179179

180180
global print
181181
rank = maybe_init_dist()

model.py

-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def from_name(cls, name: str):
6666
"stories15M": dict(n_layer=6, n_head=6, dim=288),
6767
"stories110M": dict(n_layer=12, n_head=12, dim=768),
6868
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
69-
"Llama-3-70B": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256),
7069
}
7170

7271
class KVCache(nn.Module):

quantize.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111
import torch.nn.functional as F
12-
from sentencepiece import SentencePieceProcessor
12+
from tokenizer import get_tokenizer
1313

1414
try:
1515
from GPTQ import GenericGPTQRunner, InputRecorder
@@ -578,8 +578,8 @@ def quantize(
578578
quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
579579

580580
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
581-
assert tokenizer_path.is_file(), tokenizer_path
582-
tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
581+
assert tokenizer_path.is_file(), str(tokenizer_path)
582+
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
583583

584584
quantized_state_dict = quant_handler.create_quantized_state_dict(
585585
tokenizer,

scripts/convert_hf_checkpoint.py

+86-47
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66
import json
77
import re
8+
import shutil
89
import sys
910
from pathlib import Path
1011
from typing import Optional
@@ -27,33 +28,62 @@ def convert_hf_checkpoint(
2728
if model_name is None:
2829
model_name = checkpoint_dir.name
2930

31+
# Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
32+
# need to be copied into model.pth.
33+
# Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
34+
# weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
35+
# currently supported.
36+
# Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
37+
is_llama3 = "Llama-3" in model_name
38+
if is_llama3:
39+
# Check if we have multiple original/consolidated.NN.pth files and report error
40+
# if we do for Llama 3.
41+
original_dir = checkpoint_dir / "original"
42+
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
43+
bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
44+
if len(bin_files) > 1:
45+
raise ValueError(
46+
f"Multiple consolidated.NN.pth files found in {original_dir}. "
47+
"Merging them into one model.pth file is not supported for Llama 3.")
48+
49+
3050
config = ModelArgs.from_name(model_name)
3151
print(f"Model config {config.__dict__}")
3252

3353
# Load the json file containing weight mapping
34-
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
35-
36-
assert model_map_json.is_file()
37-
38-
with open(model_map_json) as json_map:
39-
bin_index = json.load(json_map)
40-
41-
weight_map = {
42-
"model.embed_tokens.weight": "tok_embeddings.weight",
43-
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
44-
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
45-
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
46-
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
47-
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
48-
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
49-
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
50-
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
51-
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
52-
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
53-
"model.norm.weight": "norm.weight",
54-
"lm_head.weight": "output.weight",
55-
}
56-
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
54+
if not is_llama3:
55+
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
56+
57+
assert model_map_json.is_file()
58+
59+
with open(model_map_json) as json_map:
60+
bin_index = json.load(json_map)
61+
62+
weight_map = {
63+
"model.embed_tokens.weight": "tok_embeddings.weight",
64+
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
65+
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
66+
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
67+
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
68+
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
69+
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
70+
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
71+
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
72+
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
73+
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
74+
"model.norm.weight": "norm.weight",
75+
"lm_head.weight": "output.weight",
76+
}
77+
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
78+
else:
79+
# There is no separate pytorch_model.bin.index.json file for llama3.
80+
# Instead, we will just use all original/consolidated.NN.pth files.
81+
# so, we use model.safetensors.index.json
82+
weight_map = None
83+
original_dir = checkpoint_dir / "original"
84+
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
85+
bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
86+
5787

5888
def permute(w, n_head):
5989
dim = config.dim
@@ -68,32 +98,41 @@ def permute(w, n_head):
6898
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
6999
merged_result.update(state_dict)
70100
final_result = {}
71-
for key, value in merged_result.items():
72-
if "layers" in key:
73-
abstract_key = re.sub(r'(\d+)', '{}', key)
74-
layer_num = re.search(r'\d+', key).group(0)
75-
new_key = weight_map[abstract_key]
76-
if new_key is None:
77-
continue
78-
new_key = new_key.format(layer_num)
79-
else:
80-
new_key = weight_map[key]
81-
82-
final_result[new_key] = value
83-
84-
for key in tuple(final_result.keys()):
85-
if "wq" in key:
86-
q = final_result[key]
87-
k = final_result[key.replace("wq", "wk")]
88-
v = final_result[key.replace("wq", "wv")]
89-
q = permute(q, config.n_head)
90-
k = permute(k, config.n_local_heads)
91-
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
92-
del final_result[key]
93-
del final_result[key.replace("wq", "wk")]
94-
del final_result[key.replace("wq", "wv")]
101+
if weight_map is not None:
102+
for key, value in merged_result.items():
103+
if "layers" in key:
104+
abstract_key = re.sub(r'(\d+)', '{}', key)
105+
layer_num = re.search(r'\d+', key).group(0)
106+
new_key = weight_map[abstract_key]
107+
if new_key is None:
108+
continue
109+
new_key = new_key.format(layer_num)
110+
else:
111+
new_key = weight_map[key]
112+
113+
final_result[new_key] = value
114+
115+
for key in tuple(final_result.keys()):
116+
if "wq" in key:
117+
q = final_result[key]
118+
k = final_result[key.replace("wq", "wk")]
119+
v = final_result[key.replace("wq", "wv")]
120+
q = permute(q, config.n_head)
121+
k = permute(k, config.n_local_heads)
122+
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
123+
del final_result[key]
124+
del final_result[key.replace("wq", "wk")]
125+
del final_result[key.replace("wq", "wv")]
126+
else:
127+
final_result = merged_result
95128
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
96129
torch.save(final_result, checkpoint_dir / "model.pth")
130+
if is_llama3:
131+
original_dir = checkpoint_dir / "original"
132+
tokenizer_model = original_dir / "tokenizer.model"
133+
tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
134+
print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
135+
shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
97136

98137
if __name__ == '__main__':
99138
import argparse

0 commit comments

Comments
 (0)