Add Nemotron GGUF Loading Support #34725

Open · wants to merge 2 commits into base: main · Changes from 1 commit
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -87,6 +87,7 @@ For now the supported model architectures are the architectures that have been v
- Starcoder2
- T5
- Mamba
- Nemotron

## Example usage
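With Nemotron added to the list above, loading a GGUF checkpoint follows the same pattern documented in this section for the other architectures. A minimal sketch, using the repository and quantized filename that this PR's tests rely on (assuming the files are available on the Hub):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
gguf_file = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"

# Both the tokenizer and the (dequantized) model are rebuilt from the single GGUF file.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file, torch_dtype=torch.float16)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```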

27 changes: 27 additions & 0 deletions src/transformers/integrations/ggml.py
@@ -248,6 +248,20 @@
        "output_norm": "backbone.norm_f",
        "output.weight": "lm_head.weight",
    },
    "nemotron": {
        "token_embd": "model.embed_tokens",
        "blk": "model.layers",
        "ffn_up": "mlp.up_proj",
        "ffn_down": "mlp.down_proj",
        "ffn_norm": "post_attention_layernorm",
        "attn_norm": "input_layernorm",
        "attn_q": "self_attn.q_proj",
        "attn_v": "self_attn.v_proj",
        "attn_k": "self_attn.k_proj",
        "attn_output": "self_attn.o_proj",
        "output.weight": "lm_head.weight",
        "output_norm": "model.norm",
    },
}
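To illustrate what this new entry drives, here is a rough sketch of the renaming it enables: each GGUF tensor-name component is swapped for its transformers counterpart so that the checkpoint keys line up with the Nemotron state dict. The dictionary subset and the `rename` helper below are illustrative only; the actual renaming code in `ggml.py` is more involved.

```python
# Illustrative subset of the mapping added above.
NEMOTRON_TENSOR_MAP = {
    "token_embd": "model.embed_tokens",
    "blk": "model.layers",
    "attn_q": "self_attn.q_proj",
    "attn_output": "self_attn.o_proj",
    "ffn_norm": "post_attention_layernorm",
    "output_norm": "model.norm",
}

def rename(gguf_name: str) -> str:
    # Swap every known GGUF component for its transformers counterpart,
    # leaving layer indices and suffixes like "weight" untouched.
    return ".".join(NEMOTRON_TENSOR_MAP.get(part, part) for part in gguf_name.split("."))

assert rename("blk.0.attn_q.weight") == "model.layers.0.self_attn.q_proj.weight"
assert rename("output_norm.weight") == "model.norm.weight"
```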


@@ -397,6 +411,18 @@
        "ssm.time_step_rank": "time_step_rank",
        "ssm.inner_size": "intermediate_size",
    },
    "nemotron": {
        "context_length": "max_position_embeddings",
        "block_count": "num_hidden_layers",
        "feed_forward_length": "intermediate_size",
        "embedding_length": "hidden_size",
        "rope.dimension_count": None,
        "rope.freq_base": "rope_theta",
        "attention.head_count": "num_attention_heads",
        "attention.head_count_kv": "num_key_value_heads",
        "attention.layer_norm_rms_epsilon": "norm_eps",
        "vocab_size": "vocab_size",
    },
}

GGUF_TOKENIZER_MAPPING = {
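The second new block maps `nemotron.*` metadata keys from the GGUF header onto Nemotron config attributes; a value of `None` (as for `rope.dimension_count`) marks metadata that is read but not carried into the config. A simplified sketch of that translation follows; the helper name and exact parsing flow are assumptions, not the actual code in `load_gguf_checkpoint`.

```python
# Illustrative subset of the mapping added above.
NEMOTRON_CONFIG_MAP = {
    "context_length": "max_position_embeddings",
    "embedding_length": "hidden_size",
    "attention.head_count": "num_attention_heads",
    "rope.dimension_count": None,  # read from the header, intentionally not mapped
}

def to_config_kwargs(gguf_metadata: dict) -> dict:
    # gguf_metadata holds header fields such as {"nemotron.context_length": 4096, ...}
    kwargs = {}
    for key, value in gguf_metadata.items():
        target = NEMOTRON_CONFIG_MAP.get(key.removeprefix("nemotron."))
        if target is not None:
            kwargs[target] = value
    return kwargs

print(to_config_kwargs({"nemotron.context_length": 4096, "nemotron.embedding_length": 3072}))
# -> {'max_position_embeddings': 4096, 'hidden_size': 3072}
```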
@@ -793,6 +819,7 @@ def converted(self) -> Tokenizer:
    "starcoder2": GGUFGPTConverter,
    "t5": GGUFT5Converter,
    "mamba": GGUFGPTConverter,
    "nemotron": GGUFGPTConverter,
}
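Registering `GGUFGPTConverter` under `"nemotron"` means the tokenizer fields parsed from the GGUF file are rebuilt through the generic GPT-style converter. A rough sketch of how such a registration is consumed; the wrapper function is hypothetical, and only `GGUFGPTConverter` and its `converted()` method are visible in this diff.

```python
from transformers.integrations.ggml import GGUFGPTConverter

def tokenizer_from_gguf_fields(tokenizer_dict: dict):
    # tokenizer_dict: the tokenizer-related fields extracted from the GGUF file
    # (tokens, merges, special token ids, ...). The converter turns them into a
    # fast `tokenizers.Tokenizer` that AutoTokenizer then wraps.
    return GGUFGPTConverter(tokenizer_dict).converted()
```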


3 changes: 3 additions & 0 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -129,6 +129,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
)
model_size = m.group().strip("-") # only keeps `7b`

if "nemotron" in architecture:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you explicitly assign architecture to updated one if it is the same?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry if my answer seems obvious, but isn't it for addressing cases where the "architecture" does not only contain "nemotron"? I took reference on what you did for the qwen2moe, so I think It's better to also do it for nemotron. But I tested it without these lines and it passes through. What do you think? And thank you for reviewing! As this is my first time contributing, please let me know if anything seems odds or is there any better implementation. Thank you!

Copy link
Contributor

@VladOS95-cyber VladOS95-cyber Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, for qwen2moe, I explicitly assigned another architecture name, because gguf file contains qwen2moe, but later, execution chain expects to get qwen2_moe for config, model processing and so on. You provided the same name "nemotron". So, there is no reason to explicitly assign updated architecture to the same name and even to mention nemotron, because gguf processing takes it from config by default.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry for misreading the code and thank you for pointing it out! I have deleted the unnecessary lines in the new commit. Please let me know if there is something needs to be fixed.

updated_architecture = "nemotron"
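To make the reviewer's point concrete, here is a minimal sketch of the renaming pattern being referenced (the `normalize_architecture` helper is hypothetical, not the exact upstream code): a reassignment is only needed when the architecture name stored in the GGUF file differs from the name transformers expects downstream.

```python
def normalize_architecture(architecture: str) -> str:
    # GGUF metadata says "qwen2moe", but the config/model plumbing downstream
    # expects "qwen2_moe", so that name has to be rewritten explicitly.
    if "qwen2moe" in architecture:
        return "qwen2_moe"
    # For Nemotron the GGUF name and the transformers name are both "nemotron",
    # so no special case is needed; the name passes through unchanged.
    return architecture
```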

if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture + model_size} not supported")

40 changes: 40 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -61,6 +61,8 @@ class GgufIntegrationTests(unittest.TestCase):
    starcoder2_original_model_id = "bigcode/starcoder2-3b"
    mamba_original_model_id = "state-spaces/mamba-2.8b-hf"
    mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"
    nemotron_original_model_id = "nvidia/Nemotron-Mini-4B-Instruct"
    nemotron_model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"

    # standard quants
    q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -106,6 +108,8 @@ class GgufIntegrationTests(unittest.TestCase):
    fp16_starcoder2_gguf_model_id = "starcoder2-3b.fp16.gguf"
    q6_k_mamba_model_id = "ggml-model-Q6_K.gguf"
    fp16_mamba_model_id = "ggml-model-f16.gguf"
    q6_k_nemotron_model_id = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"
    fp16_nemotron_model_id = "Nemotron-Mini-4B-Instruct-f16.gguf"

    example_text = "Hello"

@@ -792,6 +796,42 @@ def test_mamba_q6_k(self):
        EXPECTED_TEXT = "Hello,I answerthe question.\n\nA"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_nemotron_weights_conversion_fp16(self):
        original_model = AutoModelForCausalLM.from_pretrained(
            self.nemotron_original_model_id,
            torch_dtype=torch.float16,
        )

        converted_model = AutoModelForCausalLM.from_pretrained(
            self.nemotron_model_id,
            gguf_file=self.fp16_nemotron_model_id,
            torch_dtype=torch.float16,
        )

        converted_state_dict = converted_model.state_dict()
        original_state_dict = original_model.state_dict()

        for layer_name, original_params in original_state_dict.items():
            if layer_name in converted_state_dict:
                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
            else:
                raise ValueError(f"Layer {layer_name} is not presented in GGUF model")

    def test_nemotron_q6_k(self):
        model = AutoModelForCausalLM.from_pretrained(
            self.nemotron_model_id,
            gguf_file=self.q6_k_nemotron_model_id,
            torch_dtype=torch.float16,
        )

        tokenizer = AutoTokenizer.from_pretrained(self.nemotron_model_id, gguf_file=self.q6_k_nemotron_model_id)
        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
        out = model.generate(text, max_new_tokens=10)

        EXPECTED_TEXT = "'Hello. hotmail.com.'"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_tokenization_xnli(self):
        import tqdm
        from datasets import load_dataset