# LICENSE file in the root directory of this source tree.
import json
import re
+import shutil
import sys
from pathlib import Path
from typing import Optional
@@ -27,33 +28,62 @@ def convert_hf_checkpoint(
    if model_name is None:
        model_name = checkpoint_dir.name

+    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
+    # need to be copied into model.pth.
+    # Llama 3 70B can't easily be merged into one model.pth file, though, since the names of
+    # the weights in the state dict are the same in each consolidated.NN.pth file. Thus, it is
+    # not currently supported.
+    # Along with this, we need to copy the original/tokenizer.model file to tokenizer.model.
+    is_llama3 = "Llama-3" in model_name
+    if is_llama3:
+        # Check whether there are multiple original/consolidated.NN.pth files and report an
+        # error if so, since merging them is not supported for Llama 3.
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
+        if len(bin_files) > 1:
+            raise ValueError(
+                f"Multiple consolidated.NN.pth files found in {original_dir}. "
+                "Merging them into one model.pth file is not supported for Llama 3.")
+
+
    config = ModelArgs.from_name(model_name)
    print(f"Model config {config.__dict__}")

    # Load the json file containing weight mapping
-    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-    assert model_map_json.is_file()
-
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
-
-    weight_map = {
-        "model.embed_tokens.weight": "tok_embeddings.weight",
-        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-        "model.norm.weight": "norm.weight",
-        "lm_head.weight": "output.weight",
-    }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    if not is_llama3:
+        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
+
+        assert model_map_json.is_file()
+
+        with open(model_map_json) as json_map:
+            bin_index = json.load(json_map)
+
+        weight_map = {
+            "model.embed_tokens.weight": "tok_embeddings.weight",
+            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            "model.norm.weight": "norm.weight",
+            "lm_head.weight": "output.weight",
+        }
+        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    else:
+        # There is no separate pytorch_model.bin.index.json file for Llama 3.
+        # Instead, we just use all original/consolidated.NN.pth files directly,
+        # so no weight_map is needed.
+        weight_map = None
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
+

    def permute(w, n_head):
        dim = config.dim
@@ -68,32 +98,41 @@ def permute(w, n_head):
        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
        merged_result.update(state_dict)
    final_result = {}
-    for key, value in merged_result.items():
-        if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
-            new_key = weight_map[abstract_key]
-            if new_key is None:
-                continue
-            new_key = new_key.format(layer_num)
-        else:
-            new_key = weight_map[key]
-
-        final_result[new_key] = value
-
-    for key in tuple(final_result.keys()):
-        if "wq" in key:
-            q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
-            q = permute(q, config.n_head)
-            k = permute(k, config.n_local_heads)
-            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-            del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+    if weight_map is not None:
+        for key, value in merged_result.items():
+            if "layers" in key:
+                abstract_key = re.sub(r'(\d+)', '{}', key)
+                layer_num = re.search(r'\d+', key).group(0)
+                new_key = weight_map[abstract_key]
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = weight_map[key]
+
+            final_result[new_key] = value
+
+        for key in tuple(final_result.keys()):
+            if "wq" in key:
+                q = final_result[key]
+                k = final_result[key.replace("wq", "wk")]
+                v = final_result[key.replace("wq", "wv")]
+                q = permute(q, config.n_head)
+                k = permute(k, config.n_local_heads)
+                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+                del final_result[key]
+                del final_result[key.replace("wq", "wk")]
+                del final_result[key.replace("wq", "wv")]
+    else:
+        final_result = merged_result
    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
    torch.save(final_result, checkpoint_dir / "model.pth")
+    if is_llama3:
+        original_dir = checkpoint_dir / "original"
+        tokenizer_model = original_dir / "tokenizer.model"
+        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
+        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
+        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)

if __name__ == '__main__':
    import argparse
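
For context, here is a minimal, self-contained sketch of what the remapping loop (now guarded by `if weight_map is not None:`) does to a single Hugging Face key. The sample keys and layer index are made up for illustration, and the dict below is just a subset of the `weight_map` from the diff.

```python
import re
from typing import Optional

# Subset of the weight_map from the diff above, for illustration only.
weight_map = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    "model.norm.weight": "norm.weight",
}

def remap(key: str) -> Optional[str]:
    # Same logic as the conversion loop: abstract the layer index to "{}",
    # look up the template, then re-insert the index (None means "drop this key").
    if "layers" in key:
        abstract_key = re.sub(r"(\d+)", "{}", key)
        layer_num = re.search(r"\d+", key).group(0)
        new_key = weight_map[abstract_key]
        return None if new_key is None else new_key.format(layer_num)
    return weight_map[key]

print(remap("model.layers.3.self_attn.q_proj.weight"))       # layers.3.attention.wq.weight
print(remap("model.layers.3.self_attn.rotary_emb.inv_freq")) # None (key is dropped)
print(remap("model.norm.weight"))                             # norm.weight
```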
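And a rough sketch of how the new Llama 3 path might be exercised end to end. The directory layout, the `scripts.convert_hf_checkpoint` import path, and the keyword arguments are assumptions (the function signature isn't shown in this diff); the point is just that a Llama 3 8B checkpoint needs exactly one `original/consolidated.00.pth` plus `original/tokenizer.model`, which end up as `model.pth` and `tokenizer.model` in the checkpoint root.

```python
from pathlib import Path

# Hypothetical checkpoint location -- adjust to wherever the HF download lives.
ckpt = Path("checkpoints/meta-llama/Meta-Llama-3-8B")

# Inputs the new Llama 3 branch expects (assumed layout for the 8B model):
#   <ckpt>/original/consolidated.00.pth   -- single shard, saved as-is to model.pth
#   <ckpt>/original/tokenizer.model       -- tiktoken model, copied to <ckpt>/tokenizer.model
assert (ckpt / "original" / "consolidated.00.pth").is_file()
assert (ckpt / "original" / "tokenizer.model").is_file()

# Assumed import path and keyword arguments (consistent with how checkpoint_dir and
# model_name are used in the body); "Llama-3" in the name is what triggers the new branch.
from scripts.convert_hf_checkpoint import convert_hf_checkpoint
convert_hf_checkpoint(checkpoint_dir=ckpt, model_name=ckpt.name)

# Expected outputs:
#   <ckpt>/model.pth
#   <ckpt>/tokenizer.model
```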