From f289d578312c80df0c16c7fc1506e44eec66c306 Mon Sep 17 00:00:00 2001 From: alien_0119 Date: Fri, 17 Oct 2025 15:26:34 +0800 Subject: [PATCH] add t5gemma --- mindone/transformers/__init__.py | 9 + mindone/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../transformers/models/auto/modeling_auto.py | 7 + .../transformers/models/t5gemma/__init__.py | 17 + .../models/t5gemma/modeling_t5gemma.py | 1374 +++++++++++++++++ .../models/t5gemma/__init__.py | 0 .../models/t5gemma/test_modeling_t5gemma.py | 349 +++++ 8 files changed, 1759 insertions(+) create mode 100644 mindone/transformers/models/t5gemma/__init__.py create mode 100644 mindone/transformers/models/t5gemma/modeling_t5gemma.py create mode 100644 tests/transformers_tests/models/t5gemma/__init__.py create mode 100644 tests/transformers_tests/models/t5gemma/test_modeling_t5gemma.py diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 2cda1be9ad..5e0d107ac6 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -1318,6 +1318,15 @@ T5Model, T5PreTrainedModel, ) +from .models.t5gemma import ( + T5GemmaEncoder, + T5GemmaEncoderModel, + T5GemmaForConditionalGeneration, + T5GemmaForSequenceClassification, + T5GemmaForTokenClassification, + T5GemmaModel, + T5GemmaPreTrainedModel, +) from .models.table_transformer import ( TableTransformerForObjectDetection, TableTransformerModel, diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 28e7ced270..8e0876fc14 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -221,6 +221,7 @@ swinv2, switch_transformers, t5, + t5gemma, table_transformer, tapas, textnet, diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index b1d25744a3..aa4bb5b1df 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -257,6 +257,7 @@ ("trocr", "TrOCRConfig"), ("tvp", "TvpConfig"), ("udop", "UdopConfig"), + ("t5gemma", "T5GemmaConfig"), ("umt5", "UMT5Config"), ("unispeech", "UniSpeechConfig"), ("unispeech-sat", "UniSpeechSatConfig"), @@ -521,6 +522,7 @@ ("swinv2", "Swin Transformer V2"), ("swin2sr", "Swin2SR"), ("t5", "T5"), + ("t5gemma", "T5Gemma"), ("t5v1.1", "T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index c89e50bc47..ef833f9c38 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -233,6 +233,7 @@ ("timesformer", "TimesformerModel"), ("tvp", "TvpModel"), ("udop", "UdopModel"), + ("t5gemma", "T5GemmaModel"), ("umt5", "UMT5Model"), ("unispeech", "UniSpeechModel"), ("unispeech-sat", "UniSpeechSatModel"), @@ -328,6 +329,7 @@ ("vipllava", "VipLlavaForConditionalGeneration"), ("visual_bert", "VisualBertForPreTraining"), ("vit_mae", "ViTMAEForPreTraining"), + ("t5gemma", "T5GemmaForConditionalGeneration"), ("wav2vec2", "Wav2Vec2ForPreTraining"), ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"), ("xlm", "XLMWithLMHeadModel"), @@ -397,6 +399,7 @@ ("squeezebert", "SqueezeBertForMaskedLM"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), + ("t5gemma", "T5GemmaForConditionalGeneration"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), ("whisper", 
"WhisperForConditionalGeneration"), ("xlm", "XLMWithLMHeadModel"), @@ -831,6 +834,7 @@ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("squeezebert", "SqueezeBertForSequenceClassification"), ("t5", "T5ForConditionalGeneration"), + ("t5gemma", "T5GemmaForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] @@ -919,6 +923,7 @@ ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), + ("t5gemma", "T5GemmaForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), ("xlm", "XLMForSequenceClassification"), ("xlm-roberta", "XLMRobertaForSequenceClassification"), @@ -1070,6 +1075,7 @@ ("squeezebert", "SqueezeBertForTokenClassification"), ("stablelm", "StableLmForTokenClassification"), ("t5", "T5ForTokenClassification"), + ("t5gemma", "T5GemmaForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), @@ -1257,6 +1263,7 @@ ("roberta-prelayernorm", "RobertaPreLayerNormModel"), ("squeezebert", "SqueezeBertModel"), ("t5", "T5EncoderModel"), + ("t5gemma", "T5GemmaEncoderModel"), ("umt5", "UMT5EncoderModel"), ("xlm", "XLMModel"), ("xlm-roberta", "XLMRobertaModel"), diff --git a/mindone/transformers/models/t5gemma/__init__.py b/mindone/transformers/models/t5gemma/__init__.py new file mode 100644 index 0000000000..82c6d09276 --- /dev/null +++ b/mindone/transformers/models/t5gemma/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_t5gemma import * diff --git a/mindone/transformers/models/t5gemma/modeling_t5gemma.py b/mindone/transformers/models/t5gemma/modeling_t5gemma.py new file mode 100644 index 0000000000..f2c1a2fb33 --- /dev/null +++ b/mindone/transformers/models/t5gemma/modeling_t5gemma.py @@ -0,0 +1,1374 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +from transformers import T5GemmaConfig, T5GemmaModuleConfig + +import mindspore +from mindspore import Parameter, Tensor, mint, nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, can_return_tuple, check_model_inputs, logging +from ...utils.generic import OutputRecorder + +logger = logging.get_logger(__name__) + + +class T5GemmaRMSNorm(nn.Cell): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = Parameter(mint.zeros(dim)) + + def _norm(self, x): + return x * mint.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def construct(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst T5Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class T5GemmaMLP(nn.Cell): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + self.dropout = mint.nn.Dropout(config.dropout_rate) + + def construct(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5GemmaRotaryEmbedding(nn.Cell): + def __init__(self, config): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config) + self.register_buffer("inv_freq", inv_freq, 
persistent=False) + self.original_inv_freq = self.inv_freq + + @mindspore._no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def construct(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand((position_ids.shape[0], -1, 1)) + position_ids_expanded = position_ids[:, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = mint.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mint.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: + """ + This is the equivalent of repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim)) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Cell, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: Optional[Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> tuple[Tensor, Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = mint.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class T5GemmaSelfAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + # Requied by flash attention: encoder selfattention is non-causal + self.is_causal = config.is_decoder + + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + def construct( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Optional[Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[Tensor, Optional[Tensor], Optional[tuple[Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = 
self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class T5GemmaCrossAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = False + + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + + self.k_proj = mint.nn.Linear( + config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + + if config.cross_attention_hidden_size is None: + raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.") + + def construct( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor], + encoder_hidden_states: Optional[Tensor], + past_key_value: Optional[Cache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[Tensor, Optional[Tensor], Optional[tuple[Tensor]]]: + if encoder_hidden_states is None: + raise ValueError("Encoder hidden state is required for cross attention.") + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + curr_past_key_value = past_key_value.cross_attention_cache + + if past_key_value is None or not is_updated: + encoder_input_shape = encoder_hidden_states.shape[:-1] + encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim) + key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + value_states = 
self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) + past_key_value.is_updated[self.layer_idx] = True + else: + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=None, + softcap=self.attn_logit_softcapping, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class T5GemmaEncoderLayer(GradientCheckpointingLayer): + """Encoder sub-layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.config = config + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + + self.self_attn = T5GemmaSelfAttention( + config=config, + layer_idx=layer_idx, + ) + self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.mlp = T5GemmaMLP(config) + self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.dropout = mint.nn.Dropout(config.dropout_rate) + + def construct( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + **kwargs, + ) -> tuple[Tensor]: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5GemmaDecoderLayer(T5GemmaEncoderLayer): + """Decoder sub-layer: an extra cross-attention layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx) + self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def construct( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + 
cache_position: Optional[Tensor] = None, + encoder_hidden_states: Optional[Tensor] = None, + encoder_attention_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value.self_attention_cache if past_key_value is not None else None, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_cross_attn_layernorm(hidden_states) + hidden_states, _ = self.cross_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + hidden_states = self.post_cross_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5GemmaClassificationHead(nn.Cell): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0): + super().__init__() + self.dropout = mint.nn.Dropout(p=classifier_dropout_rate) + self.out_proj = mint.nn.Linear(hidden_size, num_labels) + + def construct(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class T5GemmaLMHead(nn.Cell): + """Head for language modeling (generation) tasks.""" + + def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False): + super().__init__() + self.out_proj = mint.nn.Linear(hidden_size, vocab_size, bias=bias) + + def construct(self, hidden_states: Tensor) -> Tensor: + logits = self.out_proj(hidden_states) + return logits + + +class T5GemmaAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5GemmaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = True + + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == 
"sliding_attention" else None + + def construct( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Optional[Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[Tensor, Optional[Tensor], Optional[tuple[Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class T5GemmaPreTrainedModel(PreTrainedModel): + config: T5GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["T5GemmaBlock"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": T5GemmaDecoderLayer, + "attentions": T5GemmaAttention, + } + + def _init_weights(self, module): + # TODO: support intialization for encoders and decoders separately(?) + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, T5GemmaClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, T5GemmaLMHead): + if not self.config.tie_word_embeddings: + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + + def _shift_right(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_start_token_id = self.config.decoder.bos_token_id + pad_token_id = self.config.decoder.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. 
") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def bidirectional_mask_function(attention_mask: Optional[Tensor]) -> Callable: + """ + This creates bidirectional attention mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if attention_mask is None: + return mint.ones((), dtype=mindspore.bool) + return attention_mask[batch_idx, kv_idx].to(mindspore.bool) + + return inner_mask + + +def sliding_window_bidirectional_mask_function(sliding_window: int) -> Callable: + """ + This creates bidirectional attention mask with sliding window. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return (q_idx - sliding_window < kv_idx) & (kv_idx < q_idx + sliding_window) + + return inner_mask + + +def make_default_2d_attention_mask( + token_ids: Optional[Tensor], + hidden_states: Tensor, + pad_token_id: Optional[int], +) -> Tensor: + """Construct the default attention mask.""" + if token_ids is not None: + if pad_token_id is None: + raise ValueError("`pad_token_id` is required for padding information.") + attention_mask = (token_ids != pad_token_id).to(mindspore.long) + else: + attention_mask = mint.ones((hidden_states.shape[0], hidden_states.shape[1]), dtype=mindspore.long) + return attention_mask + + +class T5GemmaEncoder(T5GemmaPreTrainedModel): + _can_record_outputs = { + "attentions": T5GemmaSelfAttention, + "hidden_states": T5GemmaEncoderLayer, + } + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = T5GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.layers = nn.CellList( + [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = mint.nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + cache_position = mint.arange(0, inputs_embeds.shape[1]) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + 
"past_key_values": None, + "position_ids": position_ids, + } + self_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(attention_mask), + ), + "sliding_attention": create_sliding_window_causal_mask( + **mask_kwargs, + or_mask_function=sliding_window_bidirectional_mask_function(self.config.sliding_window), + and_mask_function=bidirectional_mask_function(attention_mask), + ), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + normalizer = Tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class T5GemmaDecoder(T5GemmaEncoder): + _can_record_outputs = { + "attentions": OutputRecorder(T5GemmaSelfAttention, index=1), + "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1), + "hidden_states": T5GemmaDecoderLayer, + } + + def __init__(self, config): + super().__init__(config) + self.layers = nn.CellList( + [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.post_init() + + @check_model_inputs + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[Tensor] = None, + encoder_hidden_states: Optional[Tensor] = None, + encoder_attention_mask: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPastAndCrossAttentions: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states` must be given in decoder") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.training and use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache( + self_attention_cache=DynamicCache(), + cross_attention_cache=DynamicCache(), + ) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = mint.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None and past_key_values is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, + } + self_attn_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": 
create_sliding_window_causal_mask(**mask_kwargs), + } + + if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": encoder_hidden_states, + "attention_mask": encoder_attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": None, + } + cross_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(encoder_attention_mask), + ), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + normalizer = Tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + past_key_values, + use_cache, + cache_position, + encoder_hidden_states, + cross_attn_mask_mapping["full_attention"], + **kwargs, + ) + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class T5GemmaModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if not config.is_encoder_decoder: + raise ValueError("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.decoder = T5GemmaDecoder(config.decoder) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + decoder_input_ids: Optional[Tensor] = None, + decoder_attention_mask: Optional[Tensor] = None, + decoder_position_ids: Optional[Tensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[Tensor] = None, + decoder_inputs_embeds: Optional[Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`ms.Tensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) + """ + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=( + decoder_outputs.hidden_states + if kwargs.get("output_hidden_states", False) + else (decoder_outputs.last_hidden_state,) + ), + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class T5GemmaEncoderModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if config.is_encoder_decoder: + raise ValueError("T5GemmaEncoderModel only supports encoder-only model. Use `T5GemmaModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.post_init() + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + return encoder_outputs + + +class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"] + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5GemmaConfig): + config.is_encoder_decoder = True + super().__init__(config) + + self.model = T5GemmaModel(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLM" + + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def _tie_weights(self): + # Decoder input and output embeddings are tied. 
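+ # Note: when `tie_word_embeddings` is set, `_tie_or_clone_weights` below makes `lm_head.out_proj`
+ # share its weight with the decoder input embedding (`get_decoder().get_input_embeddings()`),
+ # so the output projection and the decoder embedding stay identical after loading or resizing.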
+ if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings()) + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @can_return_tuple + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + decoder_input_ids: Optional[Tensor] = None, + decoder_attention_mask: Optional[Tensor] = None, + decoder_position_ids: Optional[Tensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[Tensor] = None, + decoder_inputs_embeds: Optional[Tensor] = None, + labels: Optional[Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[Tensor] = None, + logits_to_keep: Union[int, Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[Tensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`ms.Tensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + if self.training and self.config._attn_implementation != "eager": + msg = ( + "It is strongly recommended to train T5Gemma models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`." + "Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." 
+ ) + logger.warning_once(msg) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + decoder_config = self.get_decoder().config + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = mint.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: Tensor): + return self._shift_right(labels) + + +class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None): + r""" + is_encoder_decoder (`Optional`, *optional*): + Whether use encoder_decoder for sequence classification. When set to False, only encoder is used. 
+ """ + if is_encoder_decoder is not None: + config.is_encoder_decoder = is_encoder_decoder + super().__init__(config) + self.num_labels = config.num_labels + + if config.is_encoder_decoder: + self.model = T5GemmaModel(config) + else: + self.model = T5GemmaEncoderModel(config) + + hidden_size = config.encoder.hidden_size + if config.is_encoder_decoder: + hidden_size = config.decoder.hidden_size + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + decoder_input_ids: Optional[Tensor] = None, + decoder_attention_mask: Optional[Tensor] = None, + decoder_position_ids: Optional[Tensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[Tensor] = None, + decoder_inputs_embeds: Optional[Tensor] = None, + labels: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`ms.Tensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None): + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode." + ) + + # Following T5, we automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided + if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None): + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + if self.config.is_encoder_decoder: + outputs: Seq2SeqModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + else: + outputs: BaseModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.hidden_states + attentions = outputs.attentions + + logits = self.score(last_hidden_state) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(mindspore.int32) + token_indices = mint.arange(input_ids.shape[-1], dtype=mindspore.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + + if self.config.is_encoder_decoder: + last_non_pad_token += 1 # due to the right shift. + last_non_pad_token = mint.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[mint.arange(batch_size), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +class T5GemmaForTokenClassification(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None): + r""" + is_encoder_decoder (`Optional`, *optional*): + Whether use encoder_decoder for token classification. When set to False, only encoder is used. 
+ """ + if is_encoder_decoder is not None: + config.is_encoder_decoder = is_encoder_decoder + super().__init__(config) + self.num_labels = config.num_labels + + if config.is_encoder_decoder: + self.model = T5GemmaModel(config) + else: + self.model = T5GemmaEncoderModel(config) + + hidden_size = config.encoder.hidden_size + if config.is_encoder_decoder: + hidden_size = config.decoder.hidden_size + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + decoder_input_ids: Optional[Tensor] = None, + decoder_attention_mask: Optional[Tensor] = None, + decoder_position_ids: Optional[Tensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[Tensor] = None, + decoder_inputs_embeds: Optional[Tensor] = None, + labels: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + r""" + decoder_position_ids (`ms.Tensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None): + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode." + ) + + if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None): + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + if self.config.is_encoder_decoder: + outputs: Seq2SeqModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + else: + outputs: BaseModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.hidden_states + attentions = outputs.attentions + + logits = self.score(last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = [ + "T5GemmaForConditionalGeneration", + "T5GemmaModel", + "T5GemmaEncoder", # for diffusers + "T5GemmaEncoderModel", + "T5GemmaPreTrainedModel", + "T5GemmaForSequenceClassification", + "T5GemmaForTokenClassification", +] diff --git a/tests/transformers_tests/models/t5gemma/__init__.py b/tests/transformers_tests/models/t5gemma/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/t5gemma/test_modeling_t5gemma.py b/tests/transformers_tests/models/t5gemma/test_modeling_t5gemma.py new file mode 100644 index 0000000000..5741df305d --- /dev/null +++ b/tests/transformers_tests/models/t5gemma/test_modeling_t5gemma.py @@ -0,0 +1,349 @@ +"""Adapted from https://github.com/huggingface/transformers/tree/main/tests/models/t5gemma/test_modeling_t5gemma.py.""" + +# This module contains test cases that are defined in the `.test_cases.py` file, structured as lists or tuples like +# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map]. +# +# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective +# initialization parameters and inputs for the forward. The testing framework adopted here is designed to generically +# parse these parameters to assess and compare the precision of forward outcomes between the two frameworks. +# +# In cases where models have unique initialization procedures or require testing with specialized output formats, +# it is necessary to develop distinct, dedicated test cases. 
+import inspect + +import numpy as np +import pytest +import torch +from transformers import T5GemmaConfig, T5GemmaModuleConfig + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import ids_numpy + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] + + +class T5GemmaModelTester: + config_class = T5GemmaConfig + module_config_class = T5GemmaModuleConfig + + def __init__( + self, + batch_size=13, + is_training=True, + use_attention_mask=True, + use_labels=True, + vocab_size=99, + # decoder-specific + seq_length=7, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=37, + # encoder-specific + encoder_seq_length=7, + encoder_hidden_size=32, + encoder_num_hidden_layers=2, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_intermediate_size=37, + # common + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + # special ids + eos_token_id=1, + pad_token_id=0, + bos_token_id=2, + ): + self.batch_size = batch_size + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + # decoder + self.seq_length = seq_length + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + # encoder + self.encoder_seq_length = encoder_seq_length + self.encoder_hidden_size = encoder_hidden_size + self.encoder_num_hidden_layers = encoder_num_hidden_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + self.encoder_intermediate_size = encoder_intermediate_size + # common + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.head_dim = self.hidden_size // self.num_attention_heads + # assume encoder and decoder have the same head dimension. + assert self.head_dim == self.encoder_hidden_size // self.encoder_num_attention_heads + # special ids + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + # assume the number of attention heads are the same across encoder and decoder + # only used for generation testing purpose. 
+ assert self.num_attention_heads == self.encoder_num_attention_heads + + def get_encoder_config(self): + return self.module_config_class( + vocab_size=self.vocab_size, + hidden_size=self.encoder_hidden_size, + num_hidden_layers=self.encoder_num_hidden_layers, + num_attention_heads=self.encoder_num_attention_heads, + num_key_value_heads=self.encoder_num_key_value_heads, + intermediate_size=self.encoder_intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + head_dim=self.head_dim, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + ) + + def get_decoder_config(self): + return self.module_config_class( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + cross_attention_hidden_size=self.encoder_hidden_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=True, + initializer_range=self.initializer_range, + head_dim=self.head_dim, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + ) + + def get_config(self, is_encoder_decoder=True): + return self.config_class( + encoder=self.get_encoder_config(), + decoder=self.get_decoder_config(), + is_encoder_decoder=is_encoder_decoder, + # Used for generation test. + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + ) + + def prepare_config_and_inputs(self): + input_ids = ids_numpy([self.batch_size, self.encoder_seq_length], self.vocab_size) + decoder_input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size) + + # Remove BOS symbols from inputs. 
+ input_ids = np.where(input_ids == self.bos_token_id, 42, input_ids) + decoder_input_ids = np.where(decoder_input_ids == self.bos_token_id, 42, decoder_input_ids) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_numpy([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_numpy([self.batch_size, self.seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_numpy([self.batch_size, self.seq_length], self.vocab_size) + + config = self.get_config() + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + return config, inputs_dict + + +model_tester = T5GemmaModelTester() +config, inputs_dict = model_tester.prepare_config_and_inputs_for_common() + + +TEST_CASES = [ + [ + "T5GemmaModel", + "transformers.T5GemmaModel", + "mindone.transformers.T5GemmaModel", + (config,), + {}, + (), + inputs_dict, + { + "last_hidden_state": "last_hidden_state", + }, + ], + [ + "T5GemmaForConditionalGeneration", + "transformers.T5GemmaForConditionalGeneration", + "mindone.transformers.T5GemmaForConditionalGeneration", + (config,), + {}, + (), + inputs_dict, + { + "logits": "logits", + }, + ], + [ + "T5GemmaForSequenceClassification", + "transformers.T5GemmaForSequenceClassification", + "mindone.transformers.T5GemmaForSequenceClassification", + (config,), + {}, + (), + inputs_dict, + { + "logits": "logits", + }, + ], + [ + "T5GemmaForTokenClassification", + "transformers.T5GemmaForTokenClassification", + "mindone.transformers.T5GemmaForTokenClassification", + (config,), + {}, + (), + inputs_dict, + { + "logits": "logits", + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in TEST_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in 
outputs_map.items():
+            # print("===map", pt_key, ms_idx)
+            pt_output = getattr(pt_outputs, pt_key)
+            ms_output = ms_outputs[ms_idx]
+            if isinstance(pt_output, (list, tuple)):
+                pt_outputs_n += list(pt_output)
+                ms_outputs_n += list(ms_output)
+            else:
+                pt_outputs_n.append(pt_output)
+                ms_outputs_n.append(ms_output)
+        diffs = compute_diffs(pt_outputs_n, ms_outputs_n)
+    else:
+        diffs = compute_diffs(pt_outputs, ms_outputs)
+
+    THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
+    assert (np.array(diffs) < THRESHOLD).all(), (
+        f"ms_dtype: {ms_dtype}, pt_dtype: {pt_dtype}, "
+        f"outputs have diffs ({np.array(diffs).tolist()}) bigger than the threshold {THRESHOLD}"
+    )
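+
+
+# A minimal sketch of how these parametrized cases are typically exercised locally; the exact pytest
+# invocation below is an assumption and not mandated by this test file:
+#
+#     pytest tests/transformers_tests/models/t5gemma/test_modeling_t5gemma.py -k "T5GemmaModel and fp32"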