import sys
from pathlib import Path
from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
import torch

# support running without installing as a package
@@ -28,62 +28,49 @@ def convert_hf_checkpoint(
    if model_name is None:
        model_name = checkpoint_dir.name

-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
-    # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
    config = ModelArgs.from_name(model_name)
    print(f"Model config {config.__dict__}")

    # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

    def permute(w, n_head):
        dim = config.dim
@@ -95,39 +82,40 @@ def permute(w, n_head):

    merged_result = {}
    for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
    final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
    torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
+    if 'llama-3' in model_name.lower():
        original_dir = checkpoint_dir / "original"
        tokenizer_model = original_dir / "tokenizer.model"
        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
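For anyone trying this change locally, here is a minimal sanity-check sketch (not part of the diff; the checkpoint directory below is a placeholder you would point at a model already converted by this script). It loads the merged `model.pth` written above and confirms that the fused `attention.wqkv.weight` keys produced by the remapping loop are present:

```python
# Hedged sketch: verify the output of the conversion above.
# The directory is a placeholder; replace it with the checkpoint you converted.
from pathlib import Path

import torch

checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3-8B")  # assumed location
state_dict = torch.load(
    checkpoint_dir / "model.pth",  # file written by torch.save(...) in the diff
    map_location="cpu",
    mmap=True,
    weights_only=True,
)

# The remapping loop fuses wq/wk/wv into a single wqkv tensor per layer.
assert any(key.endswith("attention.wqkv.weight") for key in state_dict)
print(f"Loaded {len(state_dict)} tensors, e.g. {sorted(state_dict)[:3]}")
```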