From 81e74151f1f97a027c63ccbfcb39b12f4820d3a4 Mon Sep 17 00:00:00 2001 From: Fzilan Date: Tue, 28 Oct 2025 17:40:16 +0800 Subject: [PATCH 1/4] add hunyuanv1 modeling scripts --- mindone/transformers/models/__init__.py | 2 +- .../models/auto/configuration_auto.py | 4 + .../transformers/models/auto/modeling_auto.py | 7 + .../models/hunyuan_v1_dense/__init__.py | 1 + .../modeling_hunyuan_v1_dense.py | 510 +++++++++++++++ .../models/hunyuan_v1_moe/__init__.py | 1 + .../hunyuan_v1_moe/modeling_hunyuan_v1_moe.py | 581 ++++++++++++++++++ 7 files changed, 1105 insertions(+), 1 deletion(-) create mode 100644 mindone/transformers/models/hunyuan_v1_dense/__init__.py create mode 100644 mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py create mode 100644 mindone/transformers/models/hunyuan_v1_moe/__init__.py create mode 100644 mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 28e7ced270..980ed77c7c 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -275,4 +275,4 @@ from . import glm4v, minimax, qwen2_5_omni, vjepa2 if version.parse(transformers.__version__) >= version.parse("4.57.0"): - from . import qwen3_vl, qwen3_vl_moe + from . import qwen3_vl, qwen3_vl_moe, hunyuan_v1_dense, hunyuan_v1_moe diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index b1d25744a3..ec30c17327 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -126,6 +126,8 @@ ("helium", "HeliumConfig"), ("hiera", "HieraConfig"), ("hubert", "HubertConfig"), + ("hunyuan_v1_dense", "HunYuanDenseV1Config"), + ("hunyuan_v1_moe", "HunYuanMoEV1Config"), ("ibert", "IBertConfig"), ("idefics", "IdeficsConfig"), ("idefics2", "Idefics2Config"), @@ -394,6 +396,8 @@ ("helium", "Helium"), ("hiera", "Hiera"), ("hubert", "Hubert"), + ("hunyuan_v1_dense", "HunYuanDenseV1"), + ("hunyuan_v1_moe", "HunYuanMoeV1"), ("ibert", "I-BERT"), ("idefics", "IDEFICS"), ("idefics2", "Idefics2"), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index c89e50bc47..43d8d49945 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -120,6 +120,8 @@ ("helium", "HeliumModel"), ("hiera", "HieraModel"), ("hubert", "HubertModel"), + ("hunyuan_v1_dense", "HunYuanDenseV1Model"), + ("hunyuan_v1_moe", "HunYuanMoEV1Model"), ("ibert", "IBertModel"), ("idefics", "IdeficsModel"), ("idefics2", "Idefics2Model"), @@ -449,6 +451,9 @@ ("granite", "GraniteForCausalLM"), ("granitemoe", "GraniteMoeForCausalLM"), ("granitemoeshared", "GraniteMoeSharedForCausalLM"), + ("helium", "HeliumForCausalLM"), + ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"), + ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), ("layoutlm", "LayoutLMForMaskedLM"), ("jamba", "JambaForCausalLM"), @@ -879,6 +884,8 @@ ("gpt_neox", "SequenceClassification"), ("helium", "HeliumForSequenceClassification"), ("hubert", "HubertForSequenceClassification"), + ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"), + ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"), ("ibert", "IBertForSequenceClassification"), ("jamba", "JambaForSequenceClassification"), ("jetmoe", "JetMoeForSequenceClassification"), diff --git a/mindone/transformers/models/hunyuan_v1_dense/__init__.py b/mindone/transformers/models/hunyuan_v1_dense/__init__.py new file mode 100644 index 0000000000..b3301ea1d2 --- /dev/null +++ b/mindone/transformers/models/hunyuan_v1_dense/__init__.py @@ -0,0 +1 @@ +from .modeling_hunyuan_v1_dense import * diff --git a/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py new file mode 100644 index 0000000000..8b758f1fce --- /dev/null +++ b/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -0,0 +1,510 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_hunyuan_v1_dense.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import mindspore as ms +from mindspore import nn, mint + +from ...cache_utils import Cache + +from ...activations import ACT2FN +from ...cache_utils import DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, can_return_tuple +from ...utils.generic import check_model_inputs +from transformers import HunYuanDenseV1Config +from transformers.utils.deprecation import deprecate_kwarg +from mindone.models.utils import normal_, trunc_normal_, zeros_ + + +class HunYuanDenseV1RMSNorm(nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + HunYuanDenseV1RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = ms.Parameter(mint.ones(hidden_size)) + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class HunYuanDenseV1MLP(nn.Cell): + def __init__(self, config: HunYuanDenseV1Config, layer_idx=None, is_shared_mlp=False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + self.layer_idx = layer_idx + + def construct(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mint.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].broadcast_to(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.swapaxes(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.swapaxes(1, 2).contiguous() + + return attn_output, attn_weights + + +class HunYuanDenseV1Attention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def construct( + self, + hidden_states: ms.Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + attention_mask: Optional[ms.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ms.Tensor, ms.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class HunYuanDenseV1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = HunYuanDenseV1Attention(config=config, layer_idx=layer_idx) + + self.mlp = HunYuanDenseV1MLP(config) + self.input_layernorm = HunYuanDenseV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = HunYuanDenseV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> ms.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class HunYuanDenseV1PreTrainedModel(PreTrainedModel): + config: HunYuanDenseV1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["HunYuanDenseV1DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": HunYuanDenseV1DecoderLayer, + "attentions": HunYuanDenseV1Attention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, mint.nn.Linear): + normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + zeros_(module.bias) + elif isinstance(module, mint.nn.Embedding): + normal_(module.weight, mean=0.0, std=std) + if module.padding_idx is not None: + zeros_(module.weight.data[module.padding_idx]) + + +class HunYuanDenseV1RotaryEmbedding(nn.Cell): + inv_freq: ms.Tensor # fix linting for `register_buffer` + + def __init__(self, config: HunYuanDenseV1Config): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + if self.rope_type == "dynamic" and config.rope_scaling["alpha"]: + # DynamicNTKAlphaRotary + self.dim = config.head_dim + base = config.rope_theta * config.rope_scaling.get("alpha") ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (mint.arange(0, self.dim, 2).float() / self.dim)) + self.attention_scaling = 1.0 + else: + inv_freq, self.attention_scaling = self.rope_init_fn(self.config) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def construct(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = mint.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class HunYuanDenseV1Model(HunYuanDenseV1PreTrainedModel): + def __init__(self, config: HunYuanDenseV1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.CellList( + [HunYuanDenseV1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = HunYuanDenseV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = HunYuanDenseV1RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + cache_position: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: ms.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: ms.Tensor = mint.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class HunYuanDenseV1ForCausalLM(HunYuanDenseV1PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = HunYuanDenseV1Model(config) + self.vocab_size = config.vocab_size + self.lm_head = mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + logits_to_keep: Union[int, ms.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import HunYuanDenseV1ForCausalLM + >>> import mindspore as ms + + >>> model = HunYuanDenseV1ForCausalLM.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="np") + + >>> # Generate + >>> generate_ids = model.generate(ms.tensor(inputs.input_ids), max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class HunYuanDenseV1ForSequenceClassification(GenericForSequenceClassification, HunYuanDenseV1PreTrainedModel): + pass + + +__all__ = [ + "HunYuanDenseV1ForCausalLM", + "HunYuanDenseV1Model", + "HunYuanDenseV1PreTrainedModel", + "HunYuanDenseV1ForSequenceClassification", +] diff --git a/mindone/transformers/models/hunyuan_v1_moe/__init__.py b/mindone/transformers/models/hunyuan_v1_moe/__init__.py new file mode 100644 index 0000000000..414c33b039 --- /dev/null +++ b/mindone/transformers/models/hunyuan_v1_moe/__init__.py @@ -0,0 +1 @@ +from .modeling_hunyuan import * diff --git a/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py new file mode 100644 index 0000000000..7e66874ec3 --- /dev/null +++ b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -0,0 +1,581 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_hunyuan_v1_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import mindspore as ms +from mindspore import nn, mint + +from ...cache_utils import Cache + +from ...activations import ACT2FN +from ...cache_utils import DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, can_return_tuple +from transformers.utils.deprecation import deprecate_kwarg +from mindone.models.utils import normal_, trunc_normal_, zeros_ + +from ...utils.generic import check_model_inputs +from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config + + +class HunYuanMoEV1RMSNorm(nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + HunYuanMoEV1RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = ms.Parameter(mint.ones(hidden_size)) + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class HunYuanMoEV1MLP(nn.Cell): + def __init__(self, config: HunYuanMoEV1Config, layer_idx=None, is_shared_mlp=False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + self.layer_idx = layer_idx + + def construct(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mint.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].broadcast_to(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.swapaxes(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class HunYuanMoEV1Attention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def construct( + self, + hidden_states: ms.Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + attention_mask: Optional[ms.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ms.Tensor, ms.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class HunYuanMoEV1Gate(nn.Cell): + def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] + self.wg = mint.nn.Linear(config.hidden_size, num_experts, bias=False, dtype=ms.float32) + + def construct(self, hidden_states): + bsz, seq_len, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_size) + if self.wg.weight.dtype == ms.Tensor: + hidden_states = hidden_states.float() + logits = self.wg(hidden_states) + return logits + + +class HunYuanMoEV1Moe(nn.Cell): + def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] + self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] + self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx) + # self.wg = mint.nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=ms.Tensor) + self.experts = nn.CellList( + [HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(self.num_experts)] + ) + + self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_mlp = self.shared_mlp(hidden_states) + router_logits = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + + routing_weights = mint.functional.softmax(router_logits, dim=1, dtype=ms.float32) + routing_weights, selected_experts = mint.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = mint.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = mint.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hit = mint.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = mint.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + hidden_states_mlp + + +class HunYuanMoEV1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = HunYuanMoEV1Attention(config=config, layer_idx=layer_idx) + self.mlp = HunYuanMoEV1Moe(config, layer_idx=layer_idx) + self.input_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> ms.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class HunYuanMoEV1PreTrainedModel(PreTrainedModel): + config: HunYuanMoEV1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["HunYuanMoEV1DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": HunYuanMoEV1DecoderLayer, + "attentions": HunYuanMoEV1Attention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, mint.nn.Linear): + normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + zeros_(module.bias) + elif isinstance(module, mint.nn.Embedding): + normal_(module.weight, mean=0.0, std=std) + if module.padding_idx is not None: + zeros_(module.weight.data[module.padding_idx]) + + +class HunYuanMoEV1RotaryEmbedding(nn.Cell): + inv_freq: ms.Tensor # fix linting for `register_buffer` + + def __init__(self, config: HunYuanMoEV1Config): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + if self.rope_type == "dynamic" and config.rope_scaling["alpha"]: + # DynamicNTKAlphaRotary + self.dim = config.head_dim + base = config.rope_theta * config.rope_scaling.get("alpha") ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (mint.arange(0, self.dim, 2).float() / self.dim)) + self.attention_scaling = 1.0 + else: + inv_freq, self.attention_scaling = self.rope_init_fn(self.config) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def construct(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = mint.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class HunYuanMoEV1Model(HunYuanMoEV1PreTrainedModel): + def __init__(self, config: HunYuanMoEV1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.CellList( + [HunYuanMoEV1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = HunYuanMoEV1RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + cache_position: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: ms.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: ms.Tensor = mint.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class HunYuanMoEV1ForCausalLM(HunYuanMoEV1PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = HunYuanMoEV1Model(config) + self.vocab_size = config.vocab_size + self.lm_head = mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + logits_to_keep: Union[int, ms.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import HunYuanMoEV1ForCausalLM + >>> import mindspore as ms + + + >>> model = HunYuanMoEV1ForCausalLM.from_pretrained("meta-hunyuan_v1_moe/HunYuanMoEV1-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_moe/HunYuanMoEV1-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="np") + + >>> # Generate + >>> generate_ids = model.generate(ms.tensor(inputs.input_ids), max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class HunYuanMoEV1ForSequenceClassification(GenericForSequenceClassification, HunYuanMoEV1PreTrainedModel): + pass + + +__all__ = [ + "HunYuanMoEV1ForCausalLM", + "HunYuanMoEV1Model", + "HunYuanMoEV1PreTrainedModel", + "HunYuanMoEV1ForSequenceClassification", +] From 3cd1d8523977116c2177c7de37c0ec37a282da4b Mon Sep 17 00:00:00 2001 From: Fzilan Date: Wed, 29 Oct 2025 14:09:53 +0800 Subject: [PATCH 2/4] add fast ut --- mindone/transformers/models/__init__.py | 2 +- .../modeling_hunyuan_v1_dense.py | 18 +-- .../hunyuan_v1_moe/modeling_hunyuan_v1_moe.py | 23 +-- .../models/hunyuan_v1_dense/__init__.py | 0 .../test_modeling_hunyuan_v1_dense.py | 152 +++++++++++++++++ .../models/hunyuan_v1_moe/__init__.py | 0 .../test_modeling_hunyuan_v1_moe.py | 153 ++++++++++++++++++ 7 files changed, 321 insertions(+), 27 deletions(-) create mode 100644 tests/transformers_tests/models/hunyuan_v1_dense/__init__.py create mode 100644 tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py create mode 100644 tests/transformers_tests/models/hunyuan_v1_moe/__init__.py create mode 100644 tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 980ed77c7c..33d0daa3a3 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -275,4 +275,4 @@ from . import glm4v, minimax, qwen2_5_omni, vjepa2 if version.parse(transformers.__version__) >= version.parse("4.57.0"): - from . import qwen3_vl, qwen3_vl_moe, hunyuan_v1_dense, hunyuan_v1_moe + from . import hunyuan_v1_dense, hunyuan_v1_moe, qwen3_vl, qwen3_vl_moe diff --git a/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 8b758f1fce..8ef1d537ce 100644 --- a/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -21,13 +21,16 @@ from typing import Callable, Optional, Union +from transformers import HunYuanDenseV1Config +from transformers.utils.deprecation import deprecate_kwarg + import mindspore as ms -from mindspore import nn, mint +from mindspore import mint, nn -from ...cache_utils import Cache +from mindone.models.utils import normal_, zeros_ from ...activations import ACT2FN -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -37,9 +40,6 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, can_return_tuple from ...utils.generic import check_model_inputs -from transformers import HunYuanDenseV1Config -from transformers.utils.deprecation import deprecate_kwarg -from mindone.models.utils import normal_, trunc_normal_, zeros_ class HunYuanDenseV1RMSNorm(nn.Cell): @@ -383,9 +383,7 @@ def construct( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: ms.Tensor = mint.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] - ) + cache_position: ms.Tensor = mint.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -420,7 +418,6 @@ def construct( ) -@auto_docstring class HunYuanDenseV1ForCausalLM(HunYuanDenseV1PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -436,7 +433,6 @@ def __init__(self, config): self.post_init() @can_return_tuple - @auto_docstring def construct( self, input_ids: Optional[ms.Tensor] = None, diff --git a/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 7e66874ec3..388cd53ceb 100644 --- a/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -21,13 +21,16 @@ from typing import Callable, Optional, Union +from transformers import HunYuanMoEV1Config +from transformers.utils.deprecation import deprecate_kwarg + import mindspore as ms -from mindspore import nn, mint +from mindspore import mint, nn -from ...cache_utils import Cache +from mindone.models.utils import normal_, zeros_ from ...activations import ACT2FN -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -36,11 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, can_return_tuple -from transformers.utils.deprecation import deprecate_kwarg -from mindone.models.utils import normal_, trunc_normal_, zeros_ - from ...utils.generic import check_model_inputs -from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config class HunYuanMoEV1RMSNorm(nn.Cell): @@ -272,9 +271,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = mint.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype - ) + final_hidden_states = mint.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated @@ -413,7 +410,6 @@ def construct(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -@auto_docstring class HunYuanMoEV1Model(HunYuanMoEV1PreTrainedModel): def __init__(self, config: HunYuanMoEV1Config): super().__init__(config) @@ -454,9 +450,7 @@ def construct( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: ms.Tensor = mint.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] - ) + cache_position: ms.Tensor = mint.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -491,7 +485,6 @@ def construct( ) -@auto_docstring class HunYuanMoEV1ForCausalLM(HunYuanMoEV1PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} diff --git a/tests/transformers_tests/models/hunyuan_v1_dense/__init__.py b/tests/transformers_tests/models/hunyuan_v1_dense/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py b/tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py new file mode 100644 index 0000000000..3fce5f9a1e --- /dev/null +++ b/tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py @@ -0,0 +1,152 @@ +# Copyright (C) 2024 THL A29 Limited, a Tencent company and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the MindSpore HunYuanDenseV1 model.""" + +import inspect + +import numpy as np +import pytest +import torch +import transformers + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy +from transformers import HunYuanDenseV1Config + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] # not support graph mode yet + +if transformers.__version__ >= "4.57.0": + from transformers import HunYuanDenseV1Config + + class HunyuanV1DenseModelTester: + def __init__( + self, + batch_size=5, + seq_length=20, + ): + self.batch_size = batch_size + self.seq_length = seq_length + + def get_config(self): + return HunYuanDenseV1Config() + + def prepare_config_and_inputs(self): + config = self.get_config() + vocab_size = config.vocab_size + input_ids = ids_numpy([self.batch_size, self.seq_length], vocab_size) + attention_mask = np.tril(np.ones_like(input_ids)) + + return config, input_ids, attention_mask + + model_tester = HunyuanV1DenseModelTester() + config, input_ids, attention_mask = model_tester.prepare_config_and_inputs() + + HUNYUANV1DENSE_CASES = [ + [ + "HunYuanDenseV1ForCausalLM", + "transformers.HunYuanDenseV1ForCausalLM", + "mindone.transformers.HunYuanDenseV1ForCausalLM", + (config,), + {}, + (), + { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + { + "logits": 0, + }, + ], + ] + + @pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in HUNYUANV1DENSE_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], + ) + def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, + ): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + ms_inputs_kwargs.update({"use_cache": False}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + ) \ No newline at end of file diff --git a/tests/transformers_tests/models/hunyuan_v1_moe/__init__.py b/tests/transformers_tests/models/hunyuan_v1_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py b/tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py new file mode 100644 index 0000000000..21fb2ef520 --- /dev/null +++ b/tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py @@ -0,0 +1,153 @@ +# Copyright (C) 2024 THL A29 Limited, a Tencent company and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the MindSpore HunYuanMoEV1 model.""" + +import inspect + +import numpy as np +import pytest +import torch +import transformers + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy +from transformers import HunYuanMoEV1Config + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] # not support graph mode yet + +if transformers.__version__ >= "4.57.0": + from transformers import HunYuanMoEV1Config + + class HunyuanV1MoeModelTester: + def __init__( + self, + batch_size=5, + seq_length=20, + ): + self.batch_size = batch_size + self.seq_length = seq_length + + def get_config(self): + return HunYuanMoEV1Config() + + def prepare_config_and_inputs(self): + config = self.get_config() + vocab_size = config.vocab_size + input_ids = ids_numpy([self.batch_size, self.seq_length], vocab_size) + attention_mask = np.tril(np.ones_like(input_ids)) + + return config, input_ids, attention_mask + + model_tester = HunyuanV1DenseModelTester() + config, input_ids, attention_mask = model_tester.prepare_config_and_inputs() + + HUNYUANV1MOE_CASES = [ + [ + "HunYuanMoEV1ForCausalLM", + "transformers.HunYuanMoEV1ForCausalLM", + "mindone.transformers.HunYuanMoEV1ForCausalLM", + (config,), + {}, + (), + { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + { + "logits": 0, + }, + ], + ] + + @pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in HUNYUANV1MOE_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], + ) + def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, + ): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + ms_inputs_kwargs.update({"use_cache": False}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + ) + From e87277c7e623f3a671e98f0417439b8a11dbb45e Mon Sep 17 00:00:00 2001 From: Fzilan Date: Wed, 29 Oct 2025 18:32:01 +0800 Subject: [PATCH 3/4] bugfix --- mindone/transformers/__init__.py | 2 ++ .../hunyuan_v1_dense/modeling_hunyuan_v1_dense.py | 8 +++++--- .../transformers/models/hunyuan_v1_moe/__init__.py | 2 +- .../hunyuan_v1_moe/modeling_hunyuan_v1_moe.py | 8 +++++--- .../test_modeling_hunyuan_v1_dense.py | 12 ++++++++---- .../hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py | 13 ++++++++----- 6 files changed, 29 insertions(+), 16 deletions(-) diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 2cda1be9ad..1a95d80f27 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -1575,6 +1575,8 @@ from .models.vjepa2 import VJEPA2ForVideoClassification, VJEPA2Model, VJEPA2PreTrainedModel if version.parse(transformers.__version__) >= version.parse("4.57.0"): + from .models.hunyuan_v1_dense import HunYuanDenseV1ForCausalLM, HunYuanDenseV1ForSequenceClassification + from .models.hunyuan_v1_moe import HunYuanMoEV1ForCausalLM, HunYuanMoEV1ForSequenceClassification from .models.qwen3_vl import ( Qwen3VLForConditionalGeneration, Qwen3VLModel, diff --git a/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 8ef1d537ce..bd40bfa385 100644 --- a/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/mindone/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -121,7 +121,7 @@ def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].broadcast_to(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim)) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -299,7 +299,9 @@ def _init_weights(self, module): elif isinstance(module, mint.nn.Embedding): normal_(module.weight, mean=0.0, std=std) if module.padding_idx is not None: - zeros_(module.weight.data[module.padding_idx]) + weights = module.weight.data + weights[module.padding_idx] = 0.0 + module.weight.set_data(weights) class HunYuanDenseV1RotaryEmbedding(nn.Cell): @@ -331,7 +333,7 @@ def __init__(self, config: HunYuanDenseV1Config): @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def construct(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand((position_ids.shape[0], -1, 1)) position_ids_expanded = position_ids[:, None, :].float() # Force float32 diff --git a/mindone/transformers/models/hunyuan_v1_moe/__init__.py b/mindone/transformers/models/hunyuan_v1_moe/__init__.py index 414c33b039..07a242781b 100644 --- a/mindone/transformers/models/hunyuan_v1_moe/__init__.py +++ b/mindone/transformers/models/hunyuan_v1_moe/__init__.py @@ -1 +1 @@ -from .modeling_hunyuan import * +from .modeling_hunyuan_v1_moe import * diff --git a/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 388cd53ceb..77dbc6e361 100644 --- a/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -280,7 +280,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: # Loop over all available experts in the model and perform the computation on each expert expert_hit = mint.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - expert_layer = self.experts[expert_idx] + expert_layer = self.experts[int(expert_idx)] idx, top_x = mint.where(expert_mask[expert_idx].squeeze(0)) # Index the correct hidden states and compute the expert hidden state for @@ -366,7 +366,9 @@ def _init_weights(self, module): elif isinstance(module, mint.nn.Embedding): normal_(module.weight, mean=0.0, std=std) if module.padding_idx is not None: - zeros_(module.weight.data[module.padding_idx]) + weights = module.weight.data + weights[module.padding_idx] = 0.0 + module.weight.set_data(weights) class HunYuanMoEV1RotaryEmbedding(nn.Cell): @@ -398,7 +400,7 @@ def __init__(self, config: HunYuanMoEV1Config): @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def construct(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand((position_ids.shape[0], -1, 1)) position_ids_expanded = position_ids[:, None, :].float() # Force float32 diff --git a/tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py b/tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py index 3fce5f9a1e..9d45128a3b 100644 --- a/tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py +++ b/tests/transformers_tests/models/hunyuan_v1_dense/test_modeling_hunyuan_v1_dense.py @@ -29,8 +29,7 @@ generalized_parse_args, get_modules, ) -from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy -from transformers import HunYuanDenseV1Config +from tests.transformers_tests.models.modeling_common import ids_numpy DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} MODES = [1] # not support graph mode yet @@ -48,7 +47,12 @@ def __init__( self.seq_length = seq_length def get_config(self): - return HunYuanDenseV1Config() + return HunYuanDenseV1Config( + hidden_size=32, + intermediate_size=128, + num_attention_heads=16, + head_dim=2, + ) def prepare_config_and_inputs(self): config = self.get_config() @@ -149,4 +153,4 @@ def test_named_modules( assert (np.array(diffs) < THRESHOLD).all(), ( f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" - ) \ No newline at end of file + ) diff --git a/tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py b/tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py index 21fb2ef520..e5f7f86380 100644 --- a/tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py +++ b/tests/transformers_tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py @@ -29,8 +29,7 @@ generalized_parse_args, get_modules, ) -from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy -from transformers import HunYuanMoEV1Config +from tests.transformers_tests.models.modeling_common import ids_numpy DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} MODES = [1] # not support graph mode yet @@ -48,7 +47,12 @@ def __init__( self.seq_length = seq_length def get_config(self): - return HunYuanMoEV1Config() + return HunYuanMoEV1Config( + hidden_size=32, + intermediate_size=128, + num_attention_heads=16, + head_dim=2, + ) def prepare_config_and_inputs(self): config = self.get_config() @@ -58,7 +62,7 @@ def prepare_config_and_inputs(self): return config, input_ids, attention_mask - model_tester = HunyuanV1DenseModelTester() + model_tester = HunyuanV1MoeModelTester() config, input_ids, attention_mask = model_tester.prepare_config_and_inputs() HUNYUANV1MOE_CASES = [ @@ -150,4 +154,3 @@ def test_named_modules( f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" ) - From 4f0a84754721e181cffce72aad3c2f27dd2f15ef Mon Sep 17 00:00:00 2001 From: Fzilan Date: Thu, 30 Oct 2025 14:29:57 +0800 Subject: [PATCH 4/4] fix typo --- mindone/transformers/models/auto/configuration_auto.py | 2 +- .../models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index ec30c17327..9630bafcaf 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -397,7 +397,7 @@ ("hiera", "Hiera"), ("hubert", "Hubert"), ("hunyuan_v1_dense", "HunYuanDenseV1"), - ("hunyuan_v1_moe", "HunYuanMoeV1"), + ("hunyuan_v1_moe", "HunYuanMoEV1"), ("ibert", "I-BERT"), ("idefics", "IDEFICS"), ("idefics2", "Idefics2"), diff --git a/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 77dbc6e361..89509ab69b 100644 --- a/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/mindone/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -237,7 +237,7 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None): def construct(self, hidden_states): bsz, seq_len, hidden_size = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_size) - if self.wg.weight.dtype == ms.Tensor: + if self.wg.weight.dtype == ms.float32: hidden_states = hidden_states.float() logits = self.wg(hidden_states) return logits