Skip to content

Commit d45c3bb

Browse files
committed
add t5gemma
1 parent 8f6e9c9 commit d45c3bb

File tree

8 files changed

+1755
-0
lines changed

8 files changed

+1755
-0
lines changed

mindone/transformers/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,14 @@
13351335
from .models.trocr import TrOCRForCausalLM, TrOCRPreTrainedModel
13361336
from .models.tvp import TvpForVideoGrounding, TvpModel, TvpPreTrainedModel
13371337
from .models.udop import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopPreTrainedModel
1338+
from .models.t5gemma import (
1339+
T5GemmaEncoderModel,
1340+
T5GemmaForConditionalGeneration,
1341+
T5GemmaPreTrainedModel,
1342+
T5GemmaForSequenceClassification,
1343+
T5GemmaForTokenClassification,
1344+
T5GemmaModel,
1345+
)
13381346
from .models.umt5 import (
13391347
UMT5EncoderModel,
13401348
UMT5ForQuestionAnswering,

mindone/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@
228228
trocr,
229229
tvp,
230230
udop,
231+
t5gemma,
231232
umt5,
232233
unispeech,
233234
unispeech_sat,

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@
257257
("trocr", "TrOCRConfig"),
258258
("tvp", "TvpConfig"),
259259
("udop", "UdopConfig"),
260+
("t5gemma", "T5GemmaConfig"),
260261
("umt5", "UMT5Config"),
261262
("unispeech", "UniSpeechConfig"),
262263
("unispeech-sat", "UniSpeechSatConfig"),
@@ -521,6 +522,7 @@
521522
("swinv2", "Swin Transformer V2"),
522523
("swin2sr", "Swin2SR"),
523524
("t5", "T5"),
525+
("t5gemma", "T5Gemma"),
524526
("t5v1.1", "T5v1.1"),
525527
("table-transformer", "Table Transformer"),
526528
("tapas", "TAPAS"),

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@
233233
("timesformer", "TimesformerModel"),
234234
("tvp", "TvpModel"),
235235
("udop", "UdopModel"),
236+
("t5gemma", "T5GemmaModel"),
236237
("umt5", "UMT5Model"),
237238
("unispeech", "UniSpeechModel"),
238239
("unispeech-sat", "UniSpeechSatModel"),
@@ -328,6 +329,7 @@
328329
("vipllava", "VipLlavaForConditionalGeneration"),
329330
("visual_bert", "VisualBertForPreTraining"),
330331
("vit_mae", "ViTMAEForPreTraining"),
332+
("t5gemma", "T5GemmaForConditionalGeneration"),
331333
("wav2vec2", "Wav2Vec2ForPreTraining"),
332334
("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
333335
("xlm", "XLMWithLMHeadModel"),
@@ -397,6 +399,7 @@
397399
("squeezebert", "SqueezeBertForMaskedLM"),
398400
("t5", "T5ForConditionalGeneration"),
399401
("tapas", "TapasForMaskedLM"),
402+
("t5gemma", "T5GemmaForConditionalGeneration"),
400403
("wav2vec2", "Wav2Vec2ForMaskedLM"),
401404
("whisper", "WhisperForConditionalGeneration"),
402405
("xlm", "XLMWithLMHeadModel"),
@@ -831,6 +834,7 @@
831834
("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
832835
("squeezebert", "SqueezeBertForSequenceClassification"),
833836
("t5", "T5ForConditionalGeneration"),
837+
("t5gemma", "T5GemmaForConditionalGeneration"),
834838
("umt5", "UMT5ForConditionalGeneration"),
835839
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
836840
]
@@ -919,6 +923,7 @@
919923
("starcoder2", "Starcoder2ForSequenceClassification"),
920924
("t5", "T5ForSequenceClassification"),
921925
("tapas", "TapasForSequenceClassification"),
926+
("t5gemma", "T5GemmaForSequenceClassification"),
922927
("umt5", "UMT5ForSequenceClassification"),
923928
("xlm", "XLMForSequenceClassification"),
924929
("xlm-roberta", "XLMRobertaForSequenceClassification"),
@@ -1070,6 +1075,7 @@
10701075
("squeezebert", "SqueezeBertForTokenClassification"),
10711076
("stablelm", "StableLmForTokenClassification"),
10721077
("t5", "T5ForTokenClassification"),
1078+
("t5gemma", "T5GemmaForTokenClassification"),
10731079
("umt5", "UMT5ForTokenClassification"),
10741080
("xlm", "XLMForTokenClassification"),
10751081
("xlm-roberta", "XLMRobertaForTokenClassification"),
@@ -1257,6 +1263,7 @@
12571263
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
12581264
("squeezebert", "SqueezeBertModel"),
12591265
("t5", "T5EncoderModel"),
1266+
("t5gemma", "T5GemmaEncoderModel"),
12601267
("umt5", "UMT5EncoderModel"),
12611268
("xlm", "XLMModel"),
12621269
("xlm-roberta", "XLMRobertaModel"),
mindone/transformers/models/t5gemma/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# This code is adapted from https://github.com/huggingface/transformers
4+
# with modifications to run transformers on mindspore.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
from .modeling_t5gemma import *

0 commit comments

Comments
 (0)