Commit 091515a

Unified Llama 3 (8b, 70b) + Safetensors support (#169)

* unify llama 3 support
* add safetensors support
* bug fix
* add safetensors to reqs
* rope base bug fix, thanks @xavierpuigf (from a comment on #169)
* update model.py

1 parent 900cd67, commit 091515a

File tree: 4 files changed (+80, -87 lines)


model.py
Lines changed: 5 additions & 3 deletions

@@ -44,14 +44,14 @@ def from_name(cls, name: str):
         if name in transformer_configs:
             return cls(**transformer_configs[name])
         # fuzzy search
-        config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
+        config = [config for config in transformer_configs if config.lower() in str(name).lower()]
 
         # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
         # take longer name (as it have more symbols matched)
         if len(config) > 1:
             config.sort(key=len, reverse=True)
             assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
-
+
         return cls(**transformer_configs[config[0]])
 
 
@@ -65,7 +65,9 @@ def from_name(cls, name: str):
     "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
     "stories15M": dict(n_layer=6, n_head=6, dim=288),
     "stories110M": dict(n_layer=12, n_head=12, dim=768),
-    "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
+
+    "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
+    "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
 }
 
 class KVCache(nn.Module):
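
For context, the updated fuzzy search is what lets a Hugging Face checkpoint directory such as Meta-Llama-3-8B resolve to the llama-3-8b config key regardless of case. A minimal standalone sketch of that logic (config values elided; the input names are illustrative, not from the commit):

    # Reproduction of the updated fuzzy search in ModelArgs.from_name.
    transformer_configs = {"Mistral-7B": {}, "llama-3-8b": {}, "llama-3-70b": {}}

    def resolve(name: str) -> str:
        config = [c for c in transformer_configs if c.lower() in str(name).lower()]
        if len(config) > 1:
            # Prefer the longest key: it matched the most characters.
            config.sort(key=len, reverse=True)
            assert len(config[0]) != len(config[1]), name  # exactly one best match
        return config[0]

    print(resolve("Meta-Llama-3-8B"))            # llama-3-8b
    print(resolve("Meta-Llama-3-70B-Instruct"))  # llama-3-70b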

requirements.txt
Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
 torch
 sentencepiece
 tiktoken
+blobfile
+safetensors

scripts/convert_hf_checkpoint.py
Lines changed: 71 additions & 83 deletions

@@ -9,7 +9,7 @@
 import sys
 from pathlib import Path
 from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
 import torch
 
 # support running without installing as a package
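
The aliased import above is the only safetensors API the script needs: load_file reads a .safetensors shard into a plain dict of tensors, with no pickle step. A usage sketch (the shard filename is illustrative):

    from safetensors.torch import load_file as load_safetensors_file

    # Returns dict[str, torch.Tensor]; keys are the HF parameter names.
    state_dict = load_safetensors_file("model-00001-of-00004.safetensors", device="cpu")
    print(list(state_dict)[:3])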
@@ -28,62 +28,49 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = checkpoint_dir.name
 
-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
-    # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_head):
         dim = config.dim
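
The hunk above drops the Llama-3-specific branch in favor of plain index-file detection: prefer model.safetensors.index.json, else fall back to pytorch_model.bin.index.json. A condensed sketch with the same behavior as the committed try/except version (not the committed code itself):

    from pathlib import Path

    def find_model_map(checkpoint_dir: Path) -> Path:
        # Prefer the safetensors index, then the PyTorch one.
        for name in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
            index = checkpoint_dir / name
            if index.is_file():
                print(f"Found index at {index}")
                return index
        raise Exception("No model map found!")

Either index file carries the same weight_map structure, so the rest of the conversion is format-agnostic.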
@@ -95,39 +82,40 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
     final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
+    if 'llama-3' in model_name.lower():
         original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
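
After remapping, the script fuses each layer's q/k/v projections into a single wqkv tensor, permuting q and k first for the rotary-embedding layout (the body of permute is outside this diff). A toy run of the fusion step with hypothetical shapes, permutation omitted:

    import torch

    # Hypothetical shapes for dim=8 with grouped-query attention (k/v smaller than q).
    final_result = {
        "layers.0.attention.wq.weight": torch.randn(8, 8),
        "layers.0.attention.wk.weight": torch.randn(4, 8),
        "layers.0.attention.wv.weight": torch.randn(4, 8),
    }

    for key in tuple(final_result.keys()):
        if "wq" in key:
            q = final_result[key]
            k = final_result[key.replace("wq", "wk")]
            v = final_result[key.replace("wq", "wv")]
            # The real script first applies permute(q, n_head) and
            # permute(k, n_local_heads) to match the rotary layout.
            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
            del final_result[key]
            del final_result[key.replace("wq", "wk")]
            del final_result[key.replace("wq", "wv")]

    print(sorted(final_result))  # ['layers.0.attention.wqkv.weight']
    print(final_result["layers.0.attention.wqkv.weight"].shape)  # torch.Size([16, 8])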

tokenizer.py
Lines changed: 2 additions & 1 deletion

@@ -105,7 +105,8 @@ def get_tokenizer(tokenizer_model_path, model_name):
     Returns:
     - TokenizerInterface: An instance of a tokenizer.
     """
-    if "Llama-3" in str(model_name):
+
+    if "llama-3" in str(model_name).lower():
         return TiktokenWrapper(tokenizer_model_path)
     else:
         return SentencePieceWrapper(tokenizer_model_path)
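
With the lowercased check, any Llama 3 variant name selects the tiktoken-based wrapper while everything else keeps SentencePiece. A quick sketch of the dispatch (model names illustrative):

    # Mirrors the selection in get_tokenizer.
    for name in ("Meta-Llama-3-8B", "llama-3-70b", "Mistral-7B"):
        kind = "tiktoken" if "llama-3" in str(name).lower() else "sentencepiece"
        print(f"{name}: {kind}")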
