Unified Llama 3 (8b,70b) + Safetensors support #169

Merged
merged 9 commits on Jun 26, 2024
Changes from 3 commits
model.py: 4 additions & 3 deletions
@@ -44,14 +44,14 @@ def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
config = [config for config in transformer_configs if config.lower() in str(name).lower() or config.lower() in str(name).lower()]

# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
# take the longer name (as it has more symbols matched)
if len(config) > 1:
config.sort(key=len, reverse=True)
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match

return cls(**transformer_configs[config[0]])


@@ -65,7 +65,8 @@ def from_name(cls, name: str):
"Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
"llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
"llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256),
}

class KVCache(nn.Module):
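For readers skimming the diff: below is a minimal, self-contained sketch of how the case-insensitive fuzzy lookup above is intended to behave. The trimmed config table and the helper name resolve_config are illustrative only, not part of the PR.

# Sketch of the fuzzy config lookup from model.py, with a trimmed config table.
transformer_configs = {
    "7B": dict(n_layer=32, n_head=32, dim=4096),
    "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096),
    "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096),
}

def resolve_config(name: str) -> dict:
    # Exact match wins outright.
    if name in transformer_configs:
        return transformer_configs[name]
    # Case-insensitive substring match, as in the updated line of the diff.
    matches = [c for c in transformer_configs if c.lower() in name.lower()]
    # Prefer the longest key, so "Mistral-7B" beats the bare "7B" entry.
    matches.sort(key=len, reverse=True)
    assert matches and (len(matches) == 1 or len(matches[0]) != len(matches[1])), name
    return transformer_configs[matches[0]]

print(resolve_config("Meta-Llama-3-8B-Instruct"))  # resolves via the "llama-3-8b" entry
print(resolve_config("Mistral-7B-v0.1"))           # longest match wins over "7B"

Lower-casing the new config keys together with the case-insensitive comparison lets checkpoint directory names such as Meta-Llama-3-8B or Meta-Llama-3-70B-Instruct resolve without an exact-case match, while the longest-match rule keeps the generic "7B" entry from shadowing more specific ones.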
requirements.txt: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
torch
sentencepiece
tiktoken
blobfile
scripts/convert_hf_checkpoint.py: 71 additions & 83 deletions
@@ -9,7 +9,7 @@
import sys
from pathlib import Path
from typing import Optional

from safetensors.torch import load_file as load_safetensors_file
import torch

# support running without installing as a package
@@ -28,62 +28,49 @@ def convert_hf_checkpoint(
if model_name is None:
model_name = checkpoint_dir.name

# Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
# need to be copied into model.pth.
# Llama 3 70B can't be easily merged into one model.pth file, though, since the names of the
# weights in the state dict are the same in each consolidated.NN.pth file. Thus, it is not
# currently supported.
# Along with this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
is_llama3 = "Llama-3" in model_name
if is_llama3:
# Check if we have multiple original/consolidated.NN.pth files and report error
# if we do for Llama 3.
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
if len(bin_files) > 1:
raise ValueError(
f"Multiple consolidated.NN.pth files found in {original_dir}. "
"Merging them into one model.pth file is not supported for Llama 3.")


config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
if not is_llama3:
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
# There is no separate pytorch_model.bin.index.json file for llama3.
# Instead, we will just use all original/consolidated.NN.pth files.
# so, we use model.safetensors.index.json
weight_map = None
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}

model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = None

try:
assert model_map_json_safetensors.is_file()
model_map_json = model_map_json_safetensors
print(f"Found safetensors index at {model_map_json_safetensors}")
except AssertionError:
print(f"{model_map_json_safetensors} not found")
if model_map_json is None:
try:
assert model_map_json_pytorch.is_file()
model_map_json = model_map_json_pytorch
print(f"Found pytorch index at {model_map_json_pytorch}")
except AssertionError:
print(f"{model_map_json_pytorch} not found")

if model_map_json is None: raise Exception("No model map found!")

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_head):
dim = config.dim
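As a reading aid, here is a condensed sketch of what the hunk above sets up, together with the format-aware shard loading that appears a little further down in the diff: the converter prefers model.safetensors.index.json, falls back to pytorch_model.bin.index.json, and fails if neither exists, then loads each shard with safetensors or torch.load depending on the file type. The helper names are hypothetical, and the PR's assert/except pattern is collapsed into plain is_file() checks, which is assumed to be behaviorally equivalent.

import json
from pathlib import Path

import torch
from safetensors.torch import load_file as load_safetensors_file

def locate_weight_index(checkpoint_dir: Path) -> Path:
    # Prefer the safetensors index; fall back to the PyTorch bin index.
    for candidate in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
        index_path = checkpoint_dir / candidate
        if index_path.is_file():
            print(f"Found weight index at {index_path}")
            return index_path
        print(f"{index_path} not found")
    raise FileNotFoundError("No model map found!")

def shard_files(checkpoint_dir: Path) -> set:
    # The index maps parameter names to shard files; the converter iterates
    # over the deduplicated set of shards.
    with open(locate_weight_index(checkpoint_dir)) as f:
        bin_index = json.load(f)
    return {checkpoint_dir / shard for shard in bin_index["weight_map"].values()}

def load_shard(path: Path) -> dict:
    # Dispatch on file type, mirroring the branch added further down the diff.
    if "safetensors" in str(path):
        return load_safetensors_file(str(path), device="cpu")
    return torch.load(str(path), map_location="cpu", mmap=True, weights_only=True)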
@@ -95,39 +82,40 @@ def permute(w, n_head):

merged_result = {}
for file in sorted(bin_files):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
if "safetensors" in str(file):
state_dict = load_safetensors_file(str(file), device="cpu")
merged_result.update(state_dict)
else:
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
final_result = {}
if weight_map is not None:
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
else:
final_result = merged_result
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
if is_llama3:
if 'llama-3' in model_name.lower():
original_dir = checkpoint_dir / "original"
tokenizer_model = original_dir / "tokenizer.model"
tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
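To follow the merge loop above in isolation, here is a self-contained sketch of its two steps: Hugging Face parameter names are rewritten into layer-agnostic templates so they can be looked up in weight_map, and the separate wq/wk/wv projections are then permuted and concatenated into the single fused wqkv tensor that model.py's attention expects. The toy dimensions and the single layer are made up; head_dim is assumed to equal dim // n_head, matching the model config.

import re
import torch

# Toy dimensions for illustration only.
dim, n_head, n_local_heads = 16, 4, 2
head_dim = dim // n_head  # assumed relationship, as in the model config

def permute(w, n_head):
    # Reorder rows into the rotary-embedding layout used by model.py (as in the diff).
    return (
        w.view(n_head, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(head_dim * n_head, dim)
    )

# One toy layer in Hugging Face naming.
merged_result = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(n_head * head_dim, dim),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(n_local_heads * head_dim, dim),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(n_local_heads * head_dim, dim),
}
weight_map = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
}

# Step 1: rename, using a layer-agnostic template to look up the target name.
final_result = {}
for key, value in merged_result.items():
    abstract_key = re.sub(r"(\d+)", "{}", key)    # e.g. model.layers.{}.self_attn.q_proj.weight
    layer_num = re.search(r"\d+", key).group(0)   # "0"
    final_result[weight_map[abstract_key].format(layer_num)] = value

# Step 2: fuse q/k/v into one wqkv weight, permuting q and k first.
for key in tuple(final_result.keys()):
    if "wq" in key:
        q = permute(final_result.pop(key), n_head)
        k = permute(final_result.pop(key.replace("wq", "wk")), n_local_heads)
        v = final_result.pop(key.replace("wq", "wv"))
        final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])

print({k: tuple(v.shape) for k, v in final_result.items()})
# {'layers.0.attention.wqkv.weight': (32, 16)}

The permute compensates for the different rotary-embedding row layout in the Hugging Face export, so q and k are reordered before the fusion.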
tokenizer.py: 2 additions & 1 deletion
@@ -105,7 +105,8 @@ def get_tokenizer(tokenizer_model_path, model_name):
Returns:
- TokenizerInterface: An instance of a tokenizer.
"""
if "Llama-3" in str(model_name):

if "llama-3" in str(model_name).lower():
return TiktokenWrapper(tokenizer_model_path)
else:
return SentencePieceWrapper(tokenizer_model_path)
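Finally, a short usage sketch of the updated tokenizer selection. The checkpoint path is hypothetical, and encode/decode are assumed to be exposed by the wrappers as elsewhere in tokenizer.py: any model name containing "llama-3" in any casing now gets the tiktoken-backed tokenizer, while other models keep SentencePiece.

from pathlib import Path

from tokenizer import get_tokenizer

# Hypothetical location of a converted checkpoint.
checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3-8B")

# "llama-3" matches case-insensitively, so this returns the tiktoken-backed
# wrapper; a Llama 2 or Mistral directory name would return SentencePiece.
tokenizer = get_tokenizer(checkpoint_dir / "tokenizer.model", checkpoint_dir.name)

tokens = tokenizer.encode("Hello, world!")
print(tokenizer.decode(tokens))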