From fb267450038a86294a11f87877af4365401cdd18 Mon Sep 17 00:00:00 2001
From: farrosalferro
Date: Thu, 14 Nov 2024 13:16:44 +0900
Subject: [PATCH 1/2] Add Nemotron GGUF Loading Support

---
 docs/source/en/gguf.md                 |  1 +
 src/transformers/integrations/ggml.py  | 27 +++++++++++++
 .../modeling_gguf_pytorch_utils.py     |  3 ++
 tests/quantization/ggml/test_ggml.py   | 40 +++++++++++++++++++
 4 files changed, 71 insertions(+)

diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 2da721b28986af..b1ed1f0d492ab9 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -87,6 +87,7 @@ For now the supported model architectures are the architectures that have been v
 - Starcoder2
 - T5
 - Mamba
+- Nemotron
 
 ## Example usage

diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index f4545f2698c017..57f0af5667e648 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -248,6 +248,20 @@
         "output_norm": "backbone.norm_f",
         "output.weight": "lm_head.weight",
     },
+    "nemotron": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_q": "self_attn.q_proj",
+        "attn_v": "self_attn.v_proj",
+        "attn_k": "self_attn.k_proj",
+        "attn_output": "self_attn.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.norm",
+    },
 }
 
 
@@ -397,6 +411,18 @@
         "ssm.time_step_rank": "time_step_rank",
         "ssm.inner_size": "intermediate_size",
     },
+    "nemotron": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "norm_eps",
+        "vocab_size": "vocab_size",
+    },
 }
 
 GGUF_TOKENIZER_MAPPING = {
@@ -793,6 +819,7 @@ def converted(self) -> Tokenizer:
     "starcoder2": GGUFGPTConverter,
     "t5": GGUFT5Converter,
     "mamba": GGUFGPTConverter,
+    "nemotron": GGUFGPTConverter,
 }

diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index f58bf330ce7db3..3e720c266c7832 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -129,6 +129,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
             )
             model_size = m.group().strip("-")  # only keeps `7b`
 
+    if "nemotron" in architecture:
+        updated_architecture = "nemotron"
+
     if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"Architecture {architecture + model_size} not supported")

diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 84278e7032537b..42b05f18449ded 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -61,6 +61,8 @@ class GgufIntegrationTests(unittest.TestCase):
     starcoder2_original_model_id = "bigcode/starcoder2-3b"
     mamba_original_model_id = "state-spaces/mamba-2.8b-hf"
     mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"
+    nemotron_original_model_id = "nvidia/Nemotron-Mini-4B-Instruct"
+    nemotron_model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -106,6 +108,8 @@ class GgufIntegrationTests(unittest.TestCase):
     fp16_starcoder2_gguf_model_id = "starcoder2-3b.fp16.gguf"
     q6_k_mamba_model_id = "ggml-model-Q6_K.gguf"
     fp16_mamba_model_id = "ggml-model-f16.gguf"
+    q6_k_nemotron_model_id = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"
+    fp16_nemotron_model_id = "Nemotron-Mini-4B-Instruct-f16.gguf"
 
     example_text = "Hello"
 
@@ -792,6 +796,42 @@ def test_mamba_q6_k(self):
         EXPECTED_TEXT = "Hello,I answerthe question.\n\nA"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_nemotron_weights_conversion_fp16(self):
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.nemotron_original_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            self.nemotron_model_id,
+            gguf_file=self.fp16_nemotron_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_state_dict = converted_model.state_dict()
+        original_state_dict = original_model.state_dict()
+
+        for layer_name, original_params in original_state_dict.items():
+            if layer_name in converted_state_dict:
+                self.assertEqual(original_params.shape, converted_state_dict[layer_name].shape)
+                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")
+
+    def test_nemotron_q6_k(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.nemotron_model_id,
+            gguf_file=self.q6_k_nemotron_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.nemotron_model_id, gguf_file=self.q6_k_nemotron_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
+        out = model.generate(text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "'Hello. hotmail.com.'"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset

From f0e85e92196f26e72767287eebe2fa9ac5b28eec Mon Sep 17 00:00:00 2001
From: farrosalferro
Date: Fri, 15 Nov 2024 11:37:16 +0900
Subject: [PATCH 2/2] Fix the Nemotron architecture assignment

---
 src/transformers/modeling_gguf_pytorch_utils.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 3e720c266c7832..f58bf330ce7db3 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -129,9 +129,6 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
             )
             model_size = m.group().strip("-")  # only keeps `7b`
 
-    if "nemotron" in architecture:
-        updated_architecture = "nemotron"
-
     if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"Architecture {architecture + model_size} not supported")
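
For reference, a minimal usage sketch of the loading path this series enables, mirroring the new tests above. The repo and file names are the ones the tests pin (availability of that Q6_K file in the repo is assumed); the API calls are the standard transformers ones exercised by the tests.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load a Nemotron GGUF checkpoint: transformers dequantizes the GGUF
    # tensors and remaps them onto the Nemotron architecture via the new
    # "nemotron" entries added to the mappings in ggml.py.
    model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
    gguf_file = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"

    tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
    model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

    inputs = tokenizer("Hello", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(out[0], skip_special_tokens=True))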