
Conversation

@iugoood (Contributor) commented on Nov 5, 2025

Add

1. Add the deepseek_v3 model.
2. Add unit tests (UT).

PS: quantized weights could not be validated.

Usage

```python
from transformers import AutoTokenizer
from mindone.transformers import DeepseekV3ForCausalLM
import mindspore as ms

model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="np")

# Generate
generate_ids = model.generate(ms.tensor(inputs.input_ids), max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
```

Performance

Experiments were run on Ascend Atlas 800T A2 machines with MindSpore 2.6.0.

| model       | mode     | speed        |
|-------------|----------|--------------|
| deepseek_v3 | pynative | 3.32 token/s |

@iugoood requested a review from vigo999 as a code owner on November 5, 2025 08:11
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @iugoood, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the mindone.transformers library by introducing full support for the DeepseekV3 large language model. It encompasses the addition of the model's core architecture, its integration into the existing auto-loading and configuration systems, and the inclusion of robust tests to verify its correctness and performance within the MindSpore framework. This enhancement allows users to leverage DeepseekV3 for various natural language processing tasks using MindSpore.

Highlights

  • DeepseekV3 Model Integration: The DeepseekV3 model, including its CausalLM, base model, and pre-trained model components, has been added to the mindone.transformers library.
  • Modular Architecture Support: New files for the DeepseekV3 model's modular components (e.g., RMSNorm, RotaryEmbedding, MLP, MoE, Attention, DecoderLayer) are introduced, adapted from HuggingFace Transformers.
  • Auto-Configuration and Auto-Modeling Updates: The library's auto-configuration and auto-modeling mechanisms have been updated to correctly recognize and load the DeepseekV3 model and its associated configuration.
  • Comprehensive Testing: A dedicated test suite (test_modeling_deepseek_v3.py) has been added to validate the MindSpore implementation against the original HuggingFace Transformers version, ensuring functional parity and numerical stability across different data types.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature              | Command               | Description                                                                                      |
|----------------------|-----------------------|--------------------------------------------------------------------------------------------------|
| Code Review          | `/gemini review`      | Performs a code review for the current pull request in its current state.                         |
| Pull Request Summary | `/gemini summary`     | Provides a summary of the current pull request in its current state.                              |
| Comment              | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments.   |
| Help                 | `/gemini help`        | Displays a list of available commands.                                                            |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment:

Code Review

This pull request adds support for the deepseek_v3 model, which appears to be based on the DeepSeek-V2 architecture. The changes are well-structured, including updates to auto-configuration classes, the model implementation, and a new test suite. My review focuses on the new model implementation in modeling_deepseek_v3.py. I've identified a few key areas for improvement, primarily concerning performance. The Mixture-of-Experts (MoE) layer and the interleaved rotary position embeddings are currently implemented in a way that could lead to significant performance bottlenecks. Additionally, the example code in the model's docstring uses an invalid model identifier, which should be corrected to ensure it's useful for users.

Comment on lines +165 to +189
```python
def moe(self, hidden_states: mindspore.Tensor, topk_indices: mindspore.Tensor, topk_weights: mindspore.Tensor):
    r"""
    CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
    to not have to do a loop here (deepseek has 256 experts soooo yeah).
    """
    final_hidden_states = mindspore.mint.zeros_like(hidden_states, dtype=topk_weights.dtype)
    expert_mask = mindspore.mint.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
    expert_mask = expert_mask.permute(2, 0, 1)

    for expert_idx in range(len(self.experts)):
        expert = self.experts[expert_idx]
        mask = expert_mask[expert_idx]
        token_indices, weight_indices = mindspore.mint.where(mask)

        if token_indices.numel() > 0:
            expert_weights = topk_weights[token_indices, weight_indices]
            expert_input = hidden_states[token_indices]
            expert_output = expert(expert_input)
            weighted_output = expert_output * expert_weights.unsqueeze(-1)
            final_hidden_states.index_add_(0, token_indices, weighted_output)

    # in the original deepseek, the output of the experts is gathered once we leave this module
    # thus the moe module is itself an IsolatedParallel module
    # and all experts are "local", meaning we shard but we don't gather
    return final_hidden_states.type(hidden_states.dtype)
```
Severity: high

The current implementation of the moe method iterates over each expert in a Python for loop. As acknowledged by the "CALL FOR CONTRIBUTION" comment, this is inefficient and will be a significant performance bottleneck, especially with a large number of experts. For better performance, this loop should be replaced with vectorized operations. A common optimization pattern for MoE layers is to use gather and scatter operations to group tokens by their assigned expert and then perform computations in a batched manner.
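
To make the suggestion concrete, below is a rough sketch of such a sort-based dispatch. It is not code from this PR: the helper name `moe_sorted_dispatch` is hypothetical, `mint.argsort` is assumed to behave like its PyTorch counterpart (only `zeros_like`, `one_hot`, `where`, and `index_add_` appear in the diff above), and a full fix would additionally fuse the expert weights into a grouped matmul so that the remaining per-expert loop disappears entirely.

```python
# Hedged sketch, not a drop-in replacement: tokens are gathered into expert-contiguous
# order once, each non-empty expert runs on a contiguous slice, and results are
# scattered back with one index_add_ per expert.
from mindspore import mint


def moe_sorted_dispatch(self, hidden_states, topk_indices, topk_weights):
    num_tokens, top_k = topk_indices.shape
    flat_expert_ids = topk_indices.reshape(-1)            # (num_tokens * top_k,)
    flat_weights = topk_weights.reshape(-1, 1)            # routing weight per assignment

    # Sort assignments by expert id so every expert sees one contiguous block.
    order = mint.argsort(flat_expert_ids)                 # assumed torch-like argsort
    token_ids = order // top_k                            # source token of each sorted slot
    gathered = hidden_states[token_ids]                   # gather inputs, grouped by expert

    # Per-expert token counts via the one_hot op already used in the diff.
    counts = mint.nn.functional.one_hot(flat_expert_ids, num_classes=len(self.experts)).sum(0)

    out = mint.zeros_like(hidden_states, dtype=topk_weights.dtype)
    start = 0
    for expert_idx, count in enumerate(counts.asnumpy().tolist()):
        if count == 0:                                    # skip experts that got no tokens
            continue
        end = start + count
        chunk = self.experts[expert_idx](gathered[start:end])
        chunk = chunk * flat_weights[order[start:end]]    # re-apply routing weights
        out.index_add_(0, token_ids[start:end], chunk)
        start = end
    return out.type(hidden_states.dtype)
```

Compared with the current loop, the masking work (one_hot, where) happens once instead of once per expert, and empty experts are skipped; the per-expert matmuls remain until the expert weights are fused.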

Comment on lines +273 to +308
```python
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""
    TODO let's just use the original freqs_cis computation to not have the view
    transpose + reshape! This is not optimized!
    Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`mindspore.Tensor`): The query tensor.
        k (`mindspore.Tensor`): The key tensor.
        cos (`mindspore.Tensor`): The cosine part of the rotary embedding.
        sin (`mindspore.Tensor`): The sine part of the rotary embedding.
        position_ids (`mindspore.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```
Severity: medium

The apply_rotary_pos_emb_interleave function, as noted by the TODO comment, is not optimized. It uses several view, transpose, and reshape operations to reorder the query and key tensors before applying rotary embeddings. These data layout transformations are computationally expensive and should be avoided. Consider refactoring this function to compute the rotary embeddings more directly on the interleaved layout to improve performance.
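
As one illustration of what a more direct formulation could look like, the sketch below reads the interleaved even/odd channels with strided slices and writes the rotated result out in the same half-split layout the current function produces. It is a hypothetical helper (`apply_rotary_pos_emb_interleave_direct`, not from this PR), it assumes cos/sin arrive with their two halves duplicated (the usual rotate_half convention), and whether it is actually faster depends on how MindSpore lowers strided slicing, so it should be benchmarked rather than taken as the fix.

```python
# Hedged sketch: same math as apply_rotary_pos_emb_interleave, without the
# view/transpose/reshape round trip. Assumes cos/sin have shape [..., head_dim]
# with identical first and second halves (c_0..c_{d/2-1}, c_0..c_{d/2-1}).
from mindspore import mint


def apply_rotary_pos_emb_interleave_direct(q, k, cos, sin, unsqueeze_dim=1):
    half = q.shape[-1] // 2
    cos = cos.unsqueeze(unsqueeze_dim)[..., :half]
    sin = sin.unsqueeze(unsqueeze_dim)[..., :half]

    def rotate(x):
        x_even = x[..., 0::2]  # x_i of each interleaved (x_i, y_i) pair
        x_odd = x[..., 1::2]   # y_i of each interleaved (x_i, y_i) pair
        # Emit the same half-split layout as the existing function:
        # first half = x*cos - y*sin, second half = x*sin + y*cos.
        return mint.cat([x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], dim=-1)

    return rotate(q), rotate(k)
```

Like the original, the output stays in the de-interleaved (half-split) ordering, so the rotated queries and keys remain consistent with each other and with downstream attention code.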

Comment on lines +634 to +635
>>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
Severity: medium

The example in the DeepseekV3ForCausalLM docstring uses the model identifier "meta-deepseek_v3/DeepseekV3-2-7b-hf", which does not exist on the Hugging Face Hub and seems to be a result of a copy-paste error. This makes the example unusable and confusing for users. Please update it to use a valid public model identifier for a DeepSeek model, for example, one from the deepseek-ai organization.

Suggested change:

```diff
->>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
->>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+>>> model = DeepseekV3ForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V2-Lite-Base")
+>>> tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite-Base")
```

@iugoood force-pushed the deepseek_v3_master branch from 35188da to 9de9b8d on November 7, 2025 01:22