From 19f87963185ef40a6e08f01a854e4250f239bad1 Mon Sep 17 00:00:00 2001
From: Chris Hua
Date: Mon, 30 Sep 2024 20:38:37 -0400
Subject: [PATCH 1/3] chore: add 3.2 1B and 3B versions

---
 model.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/model.py b/model.py
index c2657995..17178c4f 100644
--- a/model.py
+++ b/model.py
@@ -79,6 +79,12 @@ def from_name(cls, name: str):
     "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
         rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
     ),
+    "llama-3.2-1b": dict(block_size=131072, n_layer=16, n_head=32, n_local_heads=8, dim=2048, intermediate_size=8192, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
+    "llama-3.2-3b": dict(block_size=131072, n_layer=28, n_head=24, n_local_heads=8, dim=3072, intermediate_size=8192, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }
 
 class KVCache(nn.Module):

From fdea37a2729e8c3b84361434ee085bdcf0a92767 Mon Sep 17 00:00:00 2001
From: Chris Hua
Date: Mon, 30 Sep 2024 21:09:43 -0400
Subject: [PATCH 2/3] feat: refactor conversion to support 1B

1B does not split the model into multiple files, so we do not need to
merge the weights.
---
 scripts/convert_hf_checkpoint.py | 123 ++++++++++++++++++-------
 1 file changed, 73 insertions(+), 50 deletions(-)

diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index f14ba6ca..6b7dd88e 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -7,6 +7,7 @@
 import re
 import shutil
 import sys
+import os
 from pathlib import Path
 from typing import Optional
 from safetensors.torch import load_file as load_safetensors_file
@@ -35,26 +36,65 @@ def convert_hf_checkpoint(
     model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
     model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
     model_map_json = None
-    
+
     try:
-        assert model_map_json_safetensors.is_file()
-        model_map_json = model_map_json_safetensors
-        print(f"Found safetensors index at {model_map_json_safetensors}")
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
     except AssertionError:
-        print(f"{model_map_json_safetensors} not found")
-    if model_map_json is None:
-        try:
-            assert model_map_json_pytorch.is_file()
-            model_map_json = model_map_json_pytorch
-            print(f"Found pytorch index at {model_map_json_pytorch}")
-        except AssertionError:
-            print(f"{model_map_json_pytorch} not found")
-
-    if model_map_json is None: raise Exception("No model map found!")
-
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    # If the solo safetensors file exists, we should load it directly
+    model_solo_safetensors = checkpoint_dir / 'model.safetensors'
+    if model_solo_safetensors.is_file():
+        print(f"Found whole safetensors file at {model_solo_safetensors}")
{model_solo_safetensors}") + merged_result = load_safetensors_file(str(model_solo_safetensors), device="cpu") + else: + if model_map_json is None: + raise Exception("No model map found!") + + with open(model_map_json) as json_map: + bin_index = json.load(json_map) + + # Refactored merging logic into a separate function + merged_result = merge_weights(checkpoint_dir, bin_index) + + # Refactored key mapping logic into a separate function + final_result = map_keys(merged_result, config) + + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") + torch.save(final_result, checkpoint_dir / "model.pth") + + if 'llama-3-' in model_name.lower() or 'llama-3.1-' or 'llama-3.2-' in model_name.lower(): + if 'llama-3.1-405b' in model_name.lower(): + original_dir = checkpoint_dir / "original" / "mp16" + else: + original_dir = checkpoint_dir / "original" + tokenizer_model = original_dir / "tokenizer.model" + tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model" + print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}") + shutil.copy(tokenizer_model, tokenizer_model_tiktoken) +def merge_weights(checkpoint_dir, bin_index): + bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} + merged_result = {} + for file in sorted(bin_files): + if "safetensors" in str(file): + state_dict = load_safetensors_file(str(file), device="cpu") + merged_result.update(state_dict) + else: + state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + merged_result.update(state_dict) + return merged_result + +def map_keys(merged_result, config): weight_map = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", @@ -70,60 +110,43 @@ def convert_hf_checkpoint( "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", } - bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} - def permute(w, n_head): - dim = config.dim - return ( - w.view(n_head, 2, config.head_dim // 2, dim) - .transpose(1, 2) - .reshape(config.head_dim * n_head, dim) - ) - - merged_result = {} - for file in sorted(bin_files): - if "safetensors" in str(file): - state_dict = load_safetensors_file(str(file), device="cpu") - merged_result.update(state_dict) - else: - state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) - merged_result.update(state_dict) final_result = {} for key, value in merged_result.items(): if "layers" in key: abstract_key = re.sub(r'(\d+)', '{}', key) layer_num = re.search(r'\d+', key).group(0) - new_key = weight_map[abstract_key] + new_key = weight_map.get(abstract_key) if new_key is None: continue new_key = new_key.format(layer_num) else: - new_key = weight_map[key] + new_key = weight_map.get(key) - final_result[new_key] = value + if new_key is not None: + final_result[new_key] = value for key in tuple(final_result.keys()): if "wq" in key: q = final_result[key] k = final_result[key.replace("wq", "wk")] v = final_result[key.replace("wq", "wv")] - q = permute(q, config.n_head) - k = permute(k, config.n_local_heads) + q = permute(q, config.n_head, config) + k = permute(k, config.n_local_heads, config) final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) del final_result[key] del final_result[key.replace("wq", "wk")] del final_result[key.replace("wq", "wv")] - print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") - torch.save(final_result, checkpoint_dir / "model.pth") - if 'llama-3-' in model_name.lower() or 
-        if 'llama-3.1-405b' in model_name.lower():
-            original_dir = checkpoint_dir / "original" / "mp16"
-        else:
-            original_dir = checkpoint_dir / "original"
-        tokenizer_model = original_dir / "tokenizer.model"
-        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
-        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
-        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
+    return final_result
+
+def permute(w, n_head, config):
+    dim = config.dim
+    return (
+        w.view(n_head, 2, config.head_dim // 2, dim)
+        .transpose(1, 2)
+        .reshape(config.head_dim * n_head, dim)
+    )
+
 
 if __name__ == '__main__':
     import argparse

From 405962bb7361a776e0e5b0e026d800bffcde8d87 Mon Sep 17 00:00:00 2001
From: Chris Hua
Date: Mon, 30 Sep 2024 22:09:16 -0400
Subject: [PATCH 3/3] fix: support tied embeddings

Necessary for the smol guys: the 1B and 3B checkpoints tie their input and
output embeddings, so there is no separate lm_head weight to map.
---
 scripts/convert_hf_checkpoint.py | 134 ++++++++++++++++++-------------
 1 file changed, 76 insertions(+), 58 deletions(-)

diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index 6b7dd88e..a45b955c 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -7,7 +7,7 @@
 import re
 import shutil
 import sys
-import os
+from typing import Dict
 from pathlib import Path
 from typing import Optional
 from safetensors.torch import load_file as load_safetensors_file
@@ -23,7 +23,9 @@
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,
-    checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
+    checkpoint_dir: Path = Path(
+        "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"
+    ),
     model_name: Optional[str] = None,
 ) -> None:
     if model_name is None:
@@ -32,8 +34,33 @@ def convert_hf_checkpoint(
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
-    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    # Check for solo safetensors file
+    model_solo_safetensors = checkpoint_dir / "model.safetensors"
+    if model_solo_safetensors.is_file():
+        print(f"Found whole safetensors file at {model_solo_safetensors}")
+        state_dict = load_safetensors_file(str(model_solo_safetensors), device="cpu")
+    else:
+        # If solo file doesn't exist, merge indices
+        state_dict = merge_model_indices(checkpoint_dir)
+
+    final_result = process_state_dict(state_dict, config)
+
+    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
+    torch.save(final_result, checkpoint_dir / "model.pth")
+
+    if "llama-3-" in model_name.lower() or "llama-3.1-" in model_name.lower() or "llama-3.2-" in model_name.lower():
+        if "llama-3.1-405b" in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
+        tokenizer_model = original_dir / "tokenizer.model"
+        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
+        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
+        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
+
+
+def merge_model_indices(checkpoint_dir: Path) -> Dict[str, torch.Tensor]:
+    model_map_json_safetensors = checkpoint_dir / "model.safetensors.index.json"
     model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
     model_map_json = None
 
@@ -51,58 +78,37 @@ def convert_hf_checkpoint(
     except AssertionError:
         print(f"{model_map_json_pytorch} not found")
 
-    # If the solo safetensors file exists, we should load it directly
-    model_solo_safetensors = checkpoint_dir / 'model.safetensors'
-    if model_solo_safetensors.is_file():
-        print(f"Found whole safetensors file at {model_solo_safetensors}")
-        merged_result = load_safetensors_file(str(model_solo_safetensors), device="cpu")
-    else:
-        if model_map_json is None:
-            raise Exception("No model map found!")
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
+    if model_map_json is None:
+        raise Exception("No model map found!")
 
-        # Refactored merging logic into a separate function
-        merged_result = merge_weights(checkpoint_dir, bin_index)
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
 
-    # Refactored key mapping logic into a separate function
-    final_result = map_keys(merged_result, config)
-
-    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
-    torch.save(final_result, checkpoint_dir / "model.pth")
-
-    if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower() or 'llama-3.2-' in model_name.lower():
-        if 'llama-3.1-405b' in model_name.lower():
-            original_dir = checkpoint_dir / "original" / "mp16"
-        else:
-            original_dir = checkpoint_dir / "original"
-        tokenizer_model = original_dir / "tokenizer.model"
-        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
-        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
-        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
-def merge_weights(checkpoint_dir, bin_index):
     bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+
     merged_result = {}
     for file in sorted(bin_files):
         if "safetensors" in str(file):
             state_dict = load_safetensors_file(str(file), device="cpu")
-            merged_result.update(state_dict)
         else:
-            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-            merged_result.update(state_dict)
+            state_dict = torch.load(
+                str(file), map_location="cpu", mmap=True, weights_only=True
+            )
+        merged_result.update(state_dict)
     return merged_result
 
-def map_keys(merged_result, config):
+
+def process_state_dict(
+    state_dict: Dict[str, torch.Tensor], config: ModelArgs
+) -> Dict[str, torch.Tensor]:
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
         "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
@@ -112,10 +118,10 @@ def map_keys(merged_result, config):
     }
 
     final_result = {}
-    for key, value in merged_result.items():
+    for key, value in state_dict.items():
         if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
+            abstract_key = re.sub(r"(\d+)", "{}", key)
+            layer_num = re.search(r"\d+", key).group(0)
             new_key = weight_map.get(abstract_key)
             if new_key is None:
                 continue
@@ -123,36 +129,48 @@ def map_keys(merged_result, config):
             new_key = new_key.format(layer_num)
         else:
             new_key = weight_map.get(key)
 
-        if new_key is not None:
+        if new_key:
             final_result[new_key] = value
 
+    # tie embeddings if the output weight does not exist
+    # necessary for 1B and 3B models
+    if "output.weight" not in final_result:
+        print("Tying embeddings - this is only necessary for 1B and 3B models")
+        final_result["output.weight"] = final_result["tok_embeddings.weight"]
+
+    def permute(w, n_head):
+        dim = config.dim
+        return (
+            w.view(n_head, 2, config.head_dim // 2, dim)
+            .transpose(1, 2)
+            .reshape(config.head_dim * n_head, dim)
+        )
+
     for key in tuple(final_result.keys()):
         if "wq" in key:
             q = final_result[key]
             k = final_result[key.replace("wq", "wk")]
             v = final_result[key.replace("wq", "wv")]
-            q = permute(q, config.n_head, config)
-            k = permute(k, config.n_local_heads, config)
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-    return final_result
 
-def permute(w, n_head, config):
-    dim = config.dim
-    return (
-        w.view(n_head, 2, config.head_dim // 2, dim)
-        .transpose(1, 2)
-        .reshape(config.head_dim * n_head, dim)
-    )
+    return final_result
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
-    parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
-    parser.add_argument('--model_name', type=str, default=None)
+
+    parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.")
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=Path,
+        default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"),
+    )
+    parser.add_argument("--model_name", type=str, default=None)
     args = parser.parse_args()
     convert_hf_checkpoint(