diff --git a/eval.py b/eval.py
index 7e8f841e..d38abf86 100644
--- a/eval.py
+++ b/eval.py
@@ -18,7 +18,7 @@
 torch._inductor.config.triton.cudagraphs = True
 torch._dynamo.config.cache_size_limit = 100000
 
-from sentencepiece import SentencePieceProcessor
+from tokenizer import get_tokenizer
 
 from model import Transformer
 
@@ -217,7 +217,7 @@ def main(
     assert checkpoint_path.is_file(), checkpoint_path
 
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), tokenizer_path
+    assert tokenizer_path.is_file(), str(tokenizer_path)
 
     device = 'cuda'
     precision = torch.bfloat16
@@ -231,7 +231,7 @@ def main(
 
     model.eval()
 
-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
     torch.manual_seed(1234)
 
diff --git a/generate.py b/generate.py
index 8446d115..24ba553d 100644
--- a/generate.py
+++ b/generate.py
@@ -32,10 +32,8 @@ def device_sync(device):
 
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from sentencepiece import SentencePieceProcessor
-
 from model import Transformer
-
+from tokenizer import get_tokenizer
 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
@@ -269,7 +267,7 @@ def main(
     assert checkpoint_path.is_file(), checkpoint_path
 
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), tokenizer_path
+    assert tokenizer_path.is_file(), str(tokenizer_path)
 
     global print
     from tp import maybe_init_dist
@@ -297,7 +295,8 @@ def main(
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
+
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
     prompt_length = encoded.size(0)
 
diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index ffe71131..9aa076b6 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -175,7 +175,7 @@ def main(
     assert checkpoint_path.is_file(), checkpoint_path
 
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), tokenizer_path
+    assert tokenizer_path.is_file(), str(tokenizer_path)
 
     global print
     rank = maybe_init_dist()
diff --git a/model.py b/model.py
index fbb60405..0660bc2b 100644
--- a/model.py
+++ b/model.py
@@ -65,6 +65,7 @@ def from_name(cls, name: str):
     "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
     "stories15M": dict(n_layer=6, n_head=6, dim=288),
     "stories110M": dict(n_layer=12, n_head=12, dim=768),
+    "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
 }
 
 class KVCache(nn.Module):
diff --git a/quantize.py b/quantize.py
index af17a698..4ebbe5f5 100644
--- a/quantize.py
+++ b/quantize.py
@@ -9,7 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from sentencepiece import SentencePieceProcessor
+from tokenizer import get_tokenizer
 
 try:
     from GPTQ import GenericGPTQRunner, InputRecorder
@@ -578,8 +578,8 @@ def quantize(
         quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
 
         tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+        assert tokenizer_path.is_file(), str(tokenizer_path)
+        tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
         quantized_state_dict = quant_handler.create_quantized_state_dict(
             tokenizer,
diff --git a/requirements.txt b/requirements.txt
index 762cb095..04f828ca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 torch
 sentencepiece
+tiktoken
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index b92114c4..8a221067 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import json
 import re
+import shutil
 import sys
 from pathlib import Path
 from typing import Optional
@@ -27,33 +28,62 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = checkpoint_dir.name
 
+    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
+    # need to be copied into model.pth.
+    # Llama 3 70B can't be easily merged into one model.pth file, though, since the names
+    # of the weights in the state dict are the same in each consolidated.NN.pth file. Thus,
+    # it is not currently supported.
+    # Along with this, we need to copy the original/tokenizer.model file (a tiktoken model)
+    # up to tokenizer.model in the checkpoint directory.
+    is_llama3 = "Llama-3" in model_name
+    if is_llama3:
+        # Check whether there are multiple original/consolidated.NN.pth files and report
+        # an error if so for Llama 3.
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
+        if len(bin_files) > 1:
+            raise ValueError(
+                f"Multiple consolidated.NN.pth files found in {original_dir}. "
+                "Merging them into one model.pth file is not supported for Llama 3.")
+
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-    assert model_map_json.is_file()
-
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
-
-    weight_map = {
-        "model.embed_tokens.weight": "tok_embeddings.weight",
-        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-        "model.norm.weight": "norm.weight",
-        "lm_head.weight": "output.weight",
-    }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    if not is_llama3:
+        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
+
+        assert model_map_json.is_file()
+
+        with open(model_map_json) as json_map:
+            bin_index = json.load(json_map)
+
+        weight_map = {
+            "model.embed_tokens.weight": "tok_embeddings.weight",
+            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            "model.norm.weight": "norm.weight",
+            "lm_head.weight": "output.weight",
+        }
+        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    else:
+        # There is no separate pytorch_model.bin.index.json file for Llama 3;
+        # instead, we just load the original/consolidated.NN.pth files directly
+        # (exactly one, given the check above), so there is no HF-to-gpt-fast
+        # weight map to apply.
+        weight_map = None
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
 
     def permute(w, n_head):
         dim = config.dim
@@ -68,32 +98,41 @@ def permute(w, n_head):
         state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
         merged_result.update(state_dict)
     final_result = {}
-    for key, value in merged_result.items():
-        if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
-            new_key = weight_map[abstract_key]
-            if new_key is None:
-                continue
-            new_key = new_key.format(layer_num)
-        else:
-            new_key = weight_map[key]
-
-        final_result[new_key] = value
-
-    for key in tuple(final_result.keys()):
-        if "wq" in key:
-            q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
-            q = permute(q, config.n_head)
-            k = permute(k, config.n_local_heads)
-            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-            del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+    if weight_map is not None:
+        for key, value in merged_result.items():
+            if "layers" in key:
+                abstract_key = re.sub(r'(\d+)', '{}', key)
+                layer_num = re.search(r'\d+', key).group(0)
+                new_key = weight_map[abstract_key]
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = weight_map[key]
+
+            final_result[new_key] = value
+
+        for key in tuple(final_result.keys()):
+            if "wq" in key:
+                q = final_result[key]
+                k = final_result[key.replace("wq", "wk")]
+                v = final_result[key.replace("wq", "wv")]
+                q = permute(q, config.n_head)
+                k = permute(k, config.n_local_heads)
+                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+                del final_result[key]
+                del final_result[key.replace("wq", "wk")]
+                del final_result[key.replace("wq", "wv")]
+    else:
+        final_result = merged_result
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
+    if is_llama3:
+        original_dir = checkpoint_dir / "original"
+        tokenizer_model = original_dir / "tokenizer.model"
+        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
+        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
+        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
 
 if __name__ == '__main__':
     import argparse
diff --git a/tokenizer.py b/tokenizer.py
new file mode 100644
index 00000000..c62a0c5b
--- /dev/null
+++ b/tokenizer.py
@@ -0,0 +1,111 @@
+import os
+import sentencepiece as spm
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from pathlib import Path
+from typing import Dict
+
+class TokenizerInterface:
+    def __init__(self, model_path):
+        self.model_path = model_path
+
+    def encode(self, text):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def decode(self, tokens):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def bos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def eos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+class SentencePieceWrapper(TokenizerInterface):
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        self.processor = spm.SentencePieceProcessor(str(model_path))
+
+    def encode(self, text):
+        return self.processor.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        return self.processor.DecodeIds(tokens)
+
+    def bos_id(self):
+        return self.processor.bos_id()
+
+    def eos_id(self):
+        return self.processor.eos_id()
+
+class TiktokenWrapper(TokenizerInterface):
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+    """
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
+
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        assert os.path.isfile(model_path), str(model_path)
+        mergeable_ranks = load_tiktoken_bpe(str(model_path))
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|reserved_special_token_2|>",
+            "<|reserved_special_token_3|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|reserved_special_token_4|>",
+            "<|eot_id|>", # end of turn
+        ] + [
+            f"<|reserved_special_token_{i}|>"
+            for i in range(5, self.num_reserved_special_tokens - 5)
+        ]
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.model = tiktoken.Encoding(
+            name=Path(model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        # BOS / EOS token IDs
+        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
+        self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+
+    def encode(self, text):
+        return self.model.encode(text)
+
+    def decode(self, tokens):
+        return self.model.decode(tokens)
+
+    def bos_id(self):
+        return self._bos_id
+
+    def eos_id(self):
+        return self._eos_id
+
+def get_tokenizer(tokenizer_model_path, model_name):
+    """
+    Factory function to get the appropriate tokenizer based on the model name.
+
+    Args:
+    - tokenizer_model_path (str): The file path to the tokenizer model.
+    - model_name (str): The name of the model, used to determine the tokenizer type.
+
+    Returns:
+    - TokenizerInterface: An instance of a tokenizer.
+    """
+    if "Llama-3" in str(model_name):
+        return TiktokenWrapper(tokenizer_model_path)
+    else:
+        return SentencePieceWrapper(tokenizer_model_path)
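
Below is a minimal usage sketch of the new tokenizer.get_tokenizer factory as it is exercised by the callers in this diff (generate.py, eval.py, quantize.py). The checkpoint path and prompt are hypothetical, and the sketch assumes the checkpoint directory has already been converted by scripts/convert_hf_checkpoint.py so that tokenizer.model sits next to model.pth:

from pathlib import Path

from tokenizer import get_tokenizer

# Hypothetical checkpoint layout produced by scripts/convert_hf_checkpoint.py.
checkpoint_path = Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth")
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)

# The checkpoint path doubles as the "model name" hint: any path containing
# "Llama-3" selects TiktokenWrapper; everything else falls back to SentencePiece.
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

# encode() returns token ids without BOS; callers such as encode_tokens() in
# generate.py prepend tokenizer.bos_id() themselves.
tokens = [tokenizer.bos_id()] + tokenizer.encode("Hello, world!")
print(tokens)
print(tokenizer.decode(tokens[1:]))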