Skip to content

Commit fbe2e3f

Browse files
authored
Update merge_lora.py (#284)
1 parent 06fbf45 commit fbe2e3f

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

tools/llama/merge_lora.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,18 @@ def merge(llama_config, lora_config, llama_weight, lora_weight, output):
5252
lora_state_dict = lora_state_dict["state_dict"]
5353

5454
# remove prefix model.
55-
llama_state_dict = {
56-
k.replace("model.", ""): v
57-
for k, v in llama_state_dict.items()
58-
if k.startswith("model.")
59-
}
60-
lora_state_dict = {
61-
k.replace("model.", ""): v
62-
for k, v in lora_state_dict.items()
63-
if k.startswith("model.")
64-
}
55+
if any(k.startswith("model.") for k in llama_state_dict.keys()):
56+
llama_state_dict = {
57+
k.replace("model.", ""): v
58+
for k, v in llama_state_dict.items()
59+
if k.startswith("model.")
60+
}
61+
if any(k.startswith("model.") for k in lora_state_dict.keys()):
62+
lora_state_dict = {
63+
k.replace("model.", ""): v
64+
for k, v in lora_state_dict.items()
65+
if k.startswith("model.")
66+
}
6567

6668
logger.info(f"Found {len(llama_state_dict)} keys in llama model")
6769
logger.info(f"Found {len(lora_state_dict)} keys in lora model")

0 commit comments

Comments
 (0)