Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mindone/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@
DebertaV2Model,
DebertaV2PreTrainedModel,
)
from .models.deepseek_v3 import DeepseekV3ForCausalLM, DeepseekV3Model, DeepseekV3PreTrainedModel
from .models.deit import (
DeiTForImageClassification,
DeiTForImageClassificationWithTeacher,
Expand Down
1 change: 1 addition & 0 deletions mindone/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
dac,
data2vec,
dbrx,
deepseek_v3,
deit,
deprecated,
depth_anything,
Expand Down
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
("dinov2", "Dinov2Config"),
("dinov2_with_registers", "Dinov2WithRegistersConfig"),
("deit", "DeiTConfig"),
("deepseek_v3", "DeepseekV3Config"),
("distilbert", "DistilBertConfig"),
("dpr", "DPRConfig"),
("dpt", "DPTConfig"),
Expand Down Expand Up @@ -351,6 +352,7 @@
("deit", "DeiT"),
("depth_anything", "Depth Anything"),
("depth_pro", "DepthPro"),
("deepseek_v3", "DeepSeek-V3"),
("detr", "DETR"),
("diffllama", "DiffLlama"),
("dinov2", "DINOv2"),
Expand Down
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
("deit", "DeiTModel"),
("depth_pro", "DepthProModel"),
("detr", "DetrModel"),
("deepseek_v3", "DeepseekV3Model"),
("diffllama", "DiffLlamaModel"),
("dinov2", "Dinov2Model"),
("dinov2_with_registers", "Dinov2WithRegistersModel"),
Expand Down Expand Up @@ -429,6 +430,7 @@
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
("diffllama", "DiffLlamaForCausalLM"),
("deepseek_v3", "DeepseekV3ForCausalLM"),
("emu3", "Emu3ForCausalLM"),
("falcon", "FalconForCausalLM"),
("fuyu", "FuyuForCausalLM"),
Expand Down
17 changes: 17 additions & 0 deletions mindone/transformers/models/deepseek_v3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_deepseek_v3 import *
674 changes: 674 additions & 0 deletions mindone/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Large diffs are not rendered by default.

Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the MindSpore DeepseekV3 model."""

import inspect

import numpy as np
import pytest
import torch
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config

import mindspore as ms

from tests.modeling_test_utils import (
MS_DTYPE_MAPPING,
PT_DTYPE_MAPPING,
compute_diffs,
generalized_parse_args,
get_modules,
)
from tests.transformers_tests.models.modeling_common import ids_numpy

# Per-dtype tolerance table: the keys are the dtype names the test below is
# parametrized over, and each value is the bound that every computed
# PyTorch-vs-MindSpore diff must stay strictly under for that dtype.
DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2}
# Values handed to `ms.set_context(mode=...)` by the test below.
# NOTE(review): 1 is presumably MindSpore's PyNative mode — confirm.
MODES = [1]


class DeepseekV3ModelTester:
    """Fixture factory producing a small DeepseekV3 config and NumPy inputs.

    Every constructor argument is stored unchanged as an attribute of the same
    name. `get_config` forwards the model-related ones to `DeepseekV3Config`,
    while the remaining ones drive the shapes and switches used when the
    input arrays are generated.
    """

    def __init__(
        self,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        intermediate_size=37,
        moe_intermediate_size=12,
        num_hidden_layers=5,
        num_attention_heads=4,
        num_key_value_heads=4,
        n_shared_experts=1,
        n_routed_experts=8,
        routed_scaling_factor=2.5,
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=16,
        v_head_dim=32,
        qk_nope_head_dim=32,
        n_group=2,
        topk_group=1,
        num_experts_per_tok=8,
        first_k_dense_replace=2,
        norm_topk_prob=True,
        aux_loss_alpha=0.001,
        hidden_act="silu",
        max_position_embeddings=512,
        initializer_range=0.02,
        attention_probs_dropout_prob=0.1,
        type_vocab_size=16,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        """Record all settings on the instance; nothing is computed here."""
        # Input shapes and switches controlling which inputs get generated.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        # Expert, routing and attention-projection settings.
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.aux_loss_alpha = aux_loss_alpha

        # Remaining config values.
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

        # Bounds used when generating token-type ids and labels.
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices

        self.pad_token_id = pad_token_id
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Generate the config and every optional input array.

        Returns:
            A 7-tuple `(config, input_ids, token_type_ids, input_mask,
            sequence_labels, token_labels, choice_labels)`. An entry is
            `None` when the switch that enables it is off.
        """
        # The order of the `ids_numpy` calls is kept exactly as-is in case the
        # helper draws from a shared random state.
        input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size)

        # Lower-triangular ones with the same shape as `input_ids`.
        input_mask = np.tril(np.ones_like(input_ids)) if self.use_input_mask else None

        if self.use_token_type_ids:
            token_type_ids = ids_numpy([self.batch_size, self.seq_length], self.type_vocab_size)
        else:
            token_type_ids = None

        if self.use_labels:
            sequence_labels = ids_numpy([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_numpy([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_numpy([self.batch_size], self.num_choices)
        else:
            sequence_labels = token_labels = choice_labels = None

        return (
            self.get_config(),
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Return a `DeepseekV3Config` populated from this tester's attributes.

        Eager attention and `use_cache=True` are always requested, and
        `attention_probs_dropout_prob` is passed as `attention_dropout`.
        """
        # Attributes handed to the config under their own name, in call order.
        same_name_fields = (
            "vocab_size",
            "hidden_size",
            "intermediate_size",
            "moe_intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "num_key_value_heads",
            "n_shared_experts",
            "n_routed_experts",
            "routed_scaling_factor",
            "kv_lora_rank",
            "q_lora_rank",
            "qk_rope_head_dim",
            "v_head_dim",
            "qk_nope_head_dim",
            "n_group",
            "topk_group",
            "num_experts_per_tok",
            "first_k_dense_replace",
            "norm_topk_prob",
            "aux_loss_alpha",
            "hidden_act",
            "max_position_embeddings",
            "initializer_range",
        )
        config_kwargs = {"attn_implementation": "eager"}
        config_kwargs.update((field, getattr(self, field)) for field in same_name_fields)
        config_kwargs.update(
            use_cache=True,
            pad_token_id=self.pad_token_id,
            attention_dropout=self.attention_probs_dropout_prob,
        )
        return DeepseekV3Config(**config_kwargs)

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` holding only the ids and the mask.

        The token-type ids and the three label arrays produced by
        `prepare_config_and_inputs` are discarded.
        """
        config, input_ids, _, input_mask, _, _, _ = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "attention_mask": input_mask}


# Module-level fixtures: built once at import time and reused by every
# parametrized case below.
model_tester = DeepseekV3ModelTester()
config, inputs_dict = model_tester.prepare_config_and_inputs_for_common()


# Each case lists, in order: a display name, the PyTorch module path, the
# MindSpore module path, constructor args, constructor kwargs, forward args,
# forward kwargs, and a map from PyTorch output attribute name to the index
# of the matching MindSpore output.
DEEPSEEKV3_CASES = [
    [
        "DeepseekV3Model",
        "transformers.DeepseekV3Model",
        "mindone.transformers.DeepseekV3Model",
        (config,),
        {},
        (),
        {key: inputs_dict[key] for key in ("input_ids", "attention_mask")},
        {"last_hidden_state": 0},
    ],
]


# NOTE(review): the original note here read "transformers need >= 4.41.2";
# DeepseekV3 presumably requires a newer transformers release — confirm the
# actual minimum version.
@pytest.mark.parametrize(
    "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode",
    [
        [*case, dtype, mode]
        for case in DEEPSEEKV3_CASES
        for dtype in DTYPE_AND_THRESHOLDS
        for mode in MODES
    ],
)
def test_named_modules(
    name,
    pt_module,
    ms_module,
    init_args,
    init_kwargs,
    inputs_args,
    inputs_kwargs,
    outputs_map,
    dtype,
    mode,
):
    """Check that the MindSpore module matches its PyTorch counterpart.

    Both modules are built via `get_modules`, called with the same inputs, and
    the selected outputs are compared with `compute_diffs`. The test fails if
    any diff is not strictly below the threshold registered for the dtype in
    `DTYPE_AND_THRESHOLDS`.
    """
    ms.set_context(mode=mode)

    pt_model, ms_model, pt_dtype, ms_dtype = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs)
    pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args(
        pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs
    )

    # Some modules always compute in float precision and need an explicit
    # `hidden_dtype` so they can cast before returning; pass it only when the
    # PyTorch forward signature accepts it.
    forward_params = inspect.signature(pt_model.forward).parameters
    if "hidden_dtype" in forward_params:
        pt_inputs_kwargs["hidden_dtype"] = PT_DTYPE_MAPPING[pt_dtype]
        ms_inputs_kwargs["hidden_dtype"] = MS_DTYPE_MAPPING[ms_dtype]

    with torch.no_grad():
        pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs)
    ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)

    if outputs_map:
        # Pair each named PyTorch output with the MindSpore output at the
        # mapped index, flattening list/tuple outputs into individual entries.
        pt_selected = []
        ms_selected = []
        for pt_key, ms_idx in outputs_map.items():
            pt_item = getattr(pt_outputs, pt_key)
            ms_item = ms_outputs[ms_idx]
            if isinstance(pt_item, (list, tuple)):
                pt_selected.extend(pt_item)
                ms_selected.extend(ms_item)
            else:
                pt_selected.append(pt_item)
                ms_selected.append(ms_item)
        diffs = compute_diffs(pt_selected, ms_selected)
    else:
        diffs = compute_diffs(pt_outputs, ms_outputs)

    THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
    diffs_array = np.array(diffs)
    assert (diffs_array < THRESHOLD).all(), (
        f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, "
        f"Outputs({diffs_array.tolist()}) has diff bigger than {THRESHOLD}"
    )