
Conversation

@alien-0119 alien-0119 commented Nov 6, 2025

What does this PR do?

Adds # (feature)
Add the T5Gemma model and fast unit tests.

Depends on #1412.

Usage Example:

import mindspore as ms

from mindone.transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer


dtype = ms.bfloat16

tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2b-2b-prefixlm-it")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5gemma-2b-2b-prefixlm-it",
    mindspore_dtype=dtype,
)

messages = [
    {"role": "user", "content": "Tell me an unknown interesting biology fact about the brain."},
]
inputs = tokenizer.apply_chat_template(messages, return_tensors="np", return_dict=True, add_generation_prompt=True)
inputs = {k: ms.tensor(v) for k, v in inputs.items()}

outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
# <bos>It's hard to find truly "unknown" facts about the brain, as research is constantly uncovering new things. However, here's an interesting and lesser

Performance:
Experiments were run on Ascend Atlas 800T A2 machines with MindSpore 2.7.0 in PyNative mode.

| model                            | precision | weight load (s) | s/step |
|----------------------------------|-----------|-----------------|--------|
| google/t5gemma-2b-2b-prefixlm-it | fp32      | 229.860         | 0.633  |
| google/t5gemma-2b-2b-prefixlm-it | fp16      | 232.853         | 0.678  |
| google/t5gemma-2b-2b-prefixlm-it | bf16      | 142.434         | 0.691  |
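For reference, the s/step figures above can be produced with a simple wall-clock helper. The sketch below is a generic version of such a measurement; the `run_step` callable, step count, and warmup count are placeholders, not part of the PR:

```python
import time

def measure_s_per_step(run_step, num_steps=10, warmup=2):
    """Average wall-clock seconds per call to run_step.

    Warmup iterations are excluded so one-time costs (graph build,
    weight load) do not skew the per-step figure.
    """
    for _ in range(warmup):
        run_step()
    start = time.perf_counter()
    for _ in range(num_steps):
        run_step()
    return (time.perf_counter() - start) / num_steps
```

In the PR's setting, `run_step` would wrap a single decoding step of `model.generate`; here it is only a stand-in.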

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the documentation guidelines.
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (better record it for data loading, model inference, or training tasks)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@alien-0119 alien-0119 requested a review from vigo999 as a code owner November 6, 2025 06:11
@gemini-code-assist

Summary of Changes

Hello @alien-0119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the mindone/transformers library by adding the T5Gemma model. The changes encompass the full implementation of the T5Gemma architecture, its integration into the existing auto-loading mechanisms for configurations and models, and the inclusion of a dedicated test suite to ensure the new model's reliability and compatibility within the MindSpore ecosystem.

Highlights

  • New Model Integration: Introduces the T5Gemma model architecture into the mindone/transformers library, expanding the range of available models.
  • Comprehensive Implementation: Adds core T5Gemma components including T5GemmaModel, T5GemmaEncoderModel, and task-specific heads like T5GemmaForConditionalGeneration, T5GemmaForSequenceClassification, and T5GemmaForTokenClassification.
  • Auto-Configuration and Auto-Modeling Support: Integrates T5Gemma into the AutoConfig and AutoModel systems, allowing for easy loading and use of the new model within the framework.
  • Dedicated Test Suite: Includes a new test file to validate the functional correctness and numerical precision of the T5Gemma implementation against its PyTorch counterpart.
@gemini-code-assist bot left a comment
Code Review

This pull request adds support for the T5Gemma model, including its configuration, model definition, and tests. The implementation looks solid, but there are a few areas for improvement. I've pointed out some issues related to code consistency, such as alphabetical ordering of imports and mappings. There's also a potential bug in the model parallelism configuration (_no_split_modules) and some code duplication that could be refactored. Additionally, I've suggested replacing a wildcard import with explicit imports for better code clarity.

config: T5GemmaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["T5GemmaBlock"]
high

The _no_split_modules list contains "T5GemmaBlock", but this class is not defined in the file. This will likely cause issues with model parallelism. It seems it should refer to the layer classes, like T5GemmaEncoderLayer and T5GemmaDecoderLayer.

Suggested change
_no_split_modules = ["T5GemmaBlock"]
_no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]
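The mismatch the reviewer found can be caught mechanically. A hypothetical sanity check is sketched below; the set of defined class names is illustrative, not the file's actual contents:

```python
# Hypothetical sanity check: every name listed in _no_split_modules should
# correspond to a class that is actually defined, otherwise device-map /
# model-parallel splitting silently has nothing to match against.
DEFINED_CLASSES = {"T5GemmaEncoderLayer", "T5GemmaDecoderLayer", "T5GemmaSelfAttention"}

def check_no_split_modules(no_split_modules, defined=DEFINED_CLASSES):
    """Return the entries that do not name a defined class (empty = consistent)."""
    return [name for name in no_split_modules if name not in defined]

print(check_no_split_modules(["T5GemmaBlock"]))                            # ['T5GemmaBlock']
print(check_no_split_modules(["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]))  # []
```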

("vipllava", "VipLlavaForConditionalGeneration"),
("visual_bert", "VisualBertForPreTraining"),
("vit_mae", "ViTMAEForPreTraining"),
("t5gemma", "T5GemmaForConditionalGeneration"),
medium

The new entries for t5gemma are not in alphabetical order in several mapping lists (MODEL_FOR_PRETRAINING_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES). To maintain consistency within the file, please place them in their correct alphabetical positions, usually after the corresponding t5 entry.
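To illustrate the ordering the reviewer is asking for, here is a small sketch using a hypothetical excerpt of one of the mapping lists (the entries shown are illustrative, not the file's full contents):

```python
import bisect

# Hypothetical excerpt of a (model name -> class name) mapping that is
# supposed to stay in alphabetical order by model name.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = [
    ("t5", "T5ForConditionalGeneration"),
    ("vipllava", "VipLlavaForConditionalGeneration"),
    ("visual_bert", "VisualBertForPreTraining"),
    ("vit_mae", "ViTMAEForPreTraining"),
]

# Insert the new entry at its sorted position (right after "t5")
# instead of appending it at the end of the list.
entry = ("t5gemma", "T5GemmaForConditionalGeneration")
bisect.insort(MODEL_FOR_PRETRAINING_MAPPING_NAMES, entry)

print([name for name, _ in MODEL_FOR_PRETRAINING_MAPPING_NAMES])
# t5gemma now sits between t5 and vipllava
```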

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_t5gemma import *
medium

Wildcard imports (*) are discouraged as they can pollute the namespace and make it unclear which names are being imported. Please use explicit imports and define __all__ to control what is exported from the module. It is also good practice to keep the imported names and __all__ list sorted alphabetically.

Suggested change
from .modeling_t5gemma import *
from .modeling_t5gemma import (
    T5GemmaEncoderModel,
    T5GemmaForConditionalGeneration,
    T5GemmaForSequenceClassification,
    T5GemmaForTokenClassification,
    T5GemmaModel,
    T5GemmaPreTrainedModel,
)

__all__ = [
    "T5GemmaEncoderModel",
    "T5GemmaForConditionalGeneration",
    "T5GemmaForSequenceClassification",
    "T5GemmaForTokenClassification",
    "T5GemmaModel",
    "T5GemmaPreTrainedModel",
]

Comment on lines +588 to +591
_can_record_outputs = {
    "hidden_states": T5GemmaDecoderLayer,
    "attentions": T5GemmaAttention,
}
medium

The T5GemmaAttention class is used here for recording attention outputs, but it appears to be a redundant implementation of self-attention, very similar to T5GemmaSelfAttention. Other parts of the model like T5GemmaEncoder already use T5GemmaSelfAttention for this purpose. To reduce code duplication and improve maintainability, consider using T5GemmaSelfAttention here and removing the T5GemmaAttention class.

Suggested change
_can_record_outputs = {
    "hidden_states": T5GemmaDecoderLayer,
    "attentions": T5GemmaAttention,
}
_can_record_outputs = {
    "hidden_states": T5GemmaDecoderLayer,
    "attentions": T5GemmaSelfAttention,
}

@alien-0119 alien-0119 added the "new model" label (add new model to mindone) Nov 6, 2025