Commit 8f6e9c9

wcrzlhwtominvigo999 authored
feat(transformers): Transformers 4.54 base (#1387)
* upgrade activation_func to transformers v4.54
* feat(transformers): upgrade attn_mask/rope to 4.54
* feat(transformers): upgrade modeling_layers to 4.54
* feat(transformers): upgrade cache_utils to 4.54
* feat(transformers): upgrade modeling_utils to v4.54
* feat(transformers): upgrade generation/utils to v4.54
* feat(transformers): add ernie4.5 for validation
* fix get_type_hints problem
* fix get_type_hints problem
* fix get_type_hints problem
* fix metadata.get keyerror
* fix masking_utils alignment
* fix generation/utils logic
* fix get_output_embedding override bug
* fix __init_subclass__ bug
* suplement checkpoint_conversion_mapping
* feat(transformers): upgrade beam search to v4.54
* feat(transformers): upgrade candidate_generator to v4.54
* feat(transformers): upgrade logits_process/stopping_criteria to v4.54
* pre-commit
* pre-commit
* update backbone_utils
* update generic
* remove add_model_info_to_auto_map & update feature_extraction_utils.py
* remove add_model_info_to_auto_map & update image_processing_base.py
* remove add_model_info_to_auto_map & update processing_utils.py
* remove add_model_info_to_auto_map & update video_utils.py
* tokenization_utils.py update
* add_model_info_to_custom_pipelines
* update tokenization_utils_base.py
* update image_transforms.py
* update video_utils.py and image_utils.py
* update image_utils.py & image_processing_utils_fast.py
* update integration sdpa_attention.py
* update mask_utils.py
* update modeling_flash_attention_utils.py
* update modeling_outputs.py
* fix pre-commit errors
* fix pre-commit errors
* add modeling_layers.py from cui yushi
* fix import in transformers
* rm tokenization_utils.py and tokenization_utils_base.py
* resize stacked images one by one
* remove torchvision decoders
* fix get_default_dtype bug
* load module dynamically from mindone/transformers
* not support FA
* add video_processing_utils
* fix import error/add audio_utils/fix processor bug/attn_implementation check
* fix generic.py error
* fix generic.py error
* audio_utils.py
* audio_utils.py
* fix errors
* update processing_chameleon
* update processing_idefics
* update processing_llava_next
* update processing_llava_next_video
* update processing_llava_next_video
* update processing_llava_next_video
* update processing_llava_next_video
* update processing_qwen_2_5_omni
* update processing_siglip_fast
* rm ernie 45
* sdpa does not support
* sdpa does not support: aria
* fix attn_implementation configuration bug
* Fix attn_implementation
* revert test script changes
* warn user of sdpa usage
* fix fa bug/key_renaming_mapping bug
* pre-commit
* upgrade modeling_utils/save_pretrained to transformersv4.54
* refactor fa part
* Fix some model's UT
* adapt to index_copy
* in case head_dim=None
* in case head_dim=None in rope_utils
* rm num_batches_tracked cast
* revert _support_dynamic_input to _support_jit
* fix class name mismatch in generation/utils
* fix pa error/delete unused fa part
* remove unused part
* generation/utils ops-->mint
* copyright/pre-commit
* fix bugs
* supplement activation api
* reformat
* remove losskwargs
* fix disable_grouping bug in image processing
* fix attn_implementation setting in modeling_utils/from_pretrained
* fix attn_implementation setting in modeling_utils/from_pretrained
* fix modeling_utils/from_config mindspore_dtype setting, generation/utils device setting bug
* feat(transformers): add qwen3_vl/qwen3_vl_moe model
* fix moe precision bug
* fix qwen3_vl moe memory bugs
* supplement zero3 model weight shard for moe part
* fix qwen3_vl_moe precision bug
* fix qwen3_vl_moe precision bug
* fix moe part shard bug
* pre-commit
* reformat
* fix(transformers): fix typos in qwen3_vl docs
* feat(transformers): add processor for qwen3_vl (#1326)
* fix(transformers): supplement condition of taking model as processor
* fix(transformers): reformat generation/utils
* fix(transformers): supplement candidate generator
* fix(transformers): supplement logits processor
* feat(transformers): add assisted_generation/dola_generation/contrasive_search/group_beam_search/constrainted beam search
* reformat
* fix import bug
* fix ut bug
* update pyproject.toml
* pre-commit
* reformat
* update loss_type

---------

Co-authored-by: Didan Deng <[email protected]>
Co-authored-by: vigo999 <[email protected]>
1 parent de6c2c4 · commit 8f6e9c9

126 files changed (+18418, -11740 lines)

Lines changed: 83 additions & 0 deletions
# Qwen3-VL series

## Introduction

[Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series encompassing both dense and MoE variants, in Instruct and Thinking versions. Building upon its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong pure-text capabilities. Key architectural advancements include: enhanced MRoPE with an interleaved layout for better spatial-temporal modeling, DeepStack integration to effectively leverage multi-level features from the Vision Transformer (ViT), and improved video understanding through text-based time alignment, evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. These innovations collectively enable Qwen3-VL to achieve superior performance on complex multimodal tasks.

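Both variants are covered by this commit: the dense and MoE checkpoints load through two separate classes, and the example scripts further down use exactly these imports:

```python
from mindone.transformers import Qwen3VLForConditionalGeneration  # dense, e.g. Qwen3-VL-4B-Instruct
from mindone.transformers import Qwen3VLMoeForConditionalGeneration  # MoE, e.g. Qwen3-VL-30B-A3B-Instruct
```
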
# Get Started

## Requirements:

| mindspore | ascend driver | firmware | cann toolkit/kernel |
|-----------|----------------|----------------|--------------------|
| 2.6.0 | 24.1.RC3.b080 | 7.5.T11.0.B088 | 8.1.RC1 |

### Installation:

```
git clone https://github.com/mindspore-lab/mindone.git -b hf-transformers-4.54
cd mindone
pip install -e .
cd ..

# install transformers from source because the Qwen3-VL support (transformers v4.57.dev.0) has not been released yet
git clone https://github.com/huggingface/transformers.git
cd transformers
git reset --hard d0af4269ec260b9c4aeeda24c346a469e44799e1
pip install -e .
cd ..

cd mindone/examples/transformers/qwen3_vl
```

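After the two editable installs, a quick sanity check can confirm that the intended packages are picked up (optional; the values in the comments are expectations based on the pins above, not guarantees):

```python
import mindspore as ms
import transformers

import mindone.transformers

print(ms.__version__)  # expected 2.6.0, per the requirements table
print(transformers.__version__)  # expected a 4.57 dev version from the pinned commit
print(mindone.transformers.__file__)  # expected to point inside the cloned mindone repo
```
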
## Quick Start

Here is a usage example of Qwen3-VL-4B-Instruct. You can run it with the following command:

```bash
# for Qwen3-VL-4B-Instruct inference
python generate_qwen3_vl.py \
    --model_name "Qwen/Qwen3-VL-4B-Instruct" \
    --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" \
    --prompt "Describe this image."
```

```bash
# for Qwen3-VL-30B-A3B-Instruct inference
msrun --worker_num=2 --local_worker_num=2 --master_port=8118 \
    --log_dir=msrun_log --join=True --cluster_time_out=300 \
    generate_qwen3_vl_moe.py \
    --model_name "Qwen/Qwen3-VL-30B-A3B-Instruct" \
    --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" \
    --prompt "Describe this image."
```

Image:

![sample image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg)

Prompt: Describe this image.

Qwen3-VL-4B Outputs:

```
['Of course, here is detailed description of the image provided.\n\n
This is a close-up photograph of a Pallas\'s cat ($Felis$, $manul$),
an endangered wild feline species native to Central Aisa.
...
**Appearance:** It has a stocky and robust build with short legs
and a large head relative to its body size. Its fur is thick and dense,
appearing somewhat fluffy or "matted,", which is characteristic']
```

Qwen3-VL-30B Outputs:

```
['Of course, here is detailed description of the image provided.\n\n
This is a dynamic and charming photograph of a Palla's cat (also known as a manul) in a snowy enviroment.
...
"Appearance:" The cat has a very distinctive apperance, characterized by its stocky, low-slung body and exceptionally
thick, dense fur. This coat is a mix of brownish"]
```

`model_name` and `image` can also be set to local paths. Give it a try with various images and prompts 🤗🤗.

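As a sketch of the local-path workflow, the script's `generate()` can also be called directly from the example directory; the paths below are placeholders, not files shipped with the repo:

```python
import argparse

# run from mindone/examples/transformers/qwen3_vl so the script is importable
from generate_qwen3_vl import generate

args = argparse.Namespace(
    model_name="/path/to/Qwen3-VL-4B-Instruct",  # local checkpoint directory (placeholder)
    image="/path/to/your_image.jpg",  # local image file (placeholder)
    prompt="Describe this image.",
    attn_implementation="flash_attention_2",
)
generate(args)
```
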
## Inference Speed
| model name | mindspore version | precision* | cards | attention type | tokens/s |
|:------------------------------:|:-----------------:|:----------:|:-----:|:--------------:|:----------:|
| Qwen/Qwen3-VL-4B-Instruct | 2.6.0 | bf16 | 1 | flash_attn | 1.35 |
| Qwen/Qwen3-VL-30B-A3B-Instruct | 2.6.0 | bf16 | 2 | flash_attn | 0.5 |
Lines changed: 79 additions & 0 deletions
import argparse

import numpy as np

import mindspore as ms

from mindone.transformers import AutoProcessor, Qwen3VLForConditionalGeneration


def generate(args):
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        args.model_name,
        mindspore_dtype=ms.bfloat16,
        attn_implementation=args.attn_implementation,
    )

    processor = AutoProcessor.from_pretrained(
        args.model_name,
        use_fast=False,
    )

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "url": args.image,
                },
                {
                    "type": "text",
                    "text": args.prompt,
                },
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np"
    )

    # convert input to Tensor
    for key, value in inputs.items():
        if isinstance(value, np.ndarray):
            inputs[key] = ms.tensor(value)
        elif isinstance(value, list):
            inputs[key] = ms.Tensor(value)

    generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Qwen3VL demo.")

    parser.add_argument("--prompt", type=str, default="Describe this image.")
    parser.add_argument(
        "--image",
        type=str,
        default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    )
    parser.add_argument(
        "--model_name", type=str, default="Qwen/Qwen3-VL-4B-Instruct", help="Path to the pre-trained model."
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        choices=["flash_attention_2", "eager"],
    )

    # Parse the arguments
    args = parser.parse_args()

    generate(args)
Lines changed: 91 additions & 0 deletions
import argparse
from functools import partial

import numpy as np

import mindspore as ms
import mindspore.mint.distributed as dist
from mindspore.communication import GlobalComm

from mindone.trainers.zero import prepare_network
from mindone.transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration


def generate(args):
    model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
        args.model_name,
        mindspore_dtype=ms.bfloat16,
        attn_implementation=args.attn_implementation,
    )

    # use zero3 parallel
    shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
    model = shard_fn(model)

    processor = AutoProcessor.from_pretrained(
        args.model_name,
        use_fast=False,
    )

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "url": args.image,
                },
                {
                    "type": "text",
                    "text": args.prompt,
                },
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np"
    )

    # convert input to Tensor
    for key, value in inputs.items():
        if isinstance(value, np.ndarray):
            inputs[key] = ms.tensor(value)
        elif isinstance(value, list):
            inputs[key] = ms.Tensor(value)

    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Qwen3VLMoE demo.")

    parser.add_argument("--prompt", type=str, default="Describe this image.")
    parser.add_argument(
        "--image",
        type=str,
        default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    )
    parser.add_argument(
        "--model_name", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="Path to the pre-trained model."
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        choices=["flash_attention_2", "eager"],
    )

    # Parse the arguments
    args = parser.parse_args()

    # set up card communication
    dist.init_process_group(backend="hccl")
    ms.set_auto_parallel_context(parallel_mode="data_parallel")

    generate(args)

mindone/models/modules/parallel/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 from .conv import Conv1d, Conv2d, Conv3d, Mint_Conv2d, Mint_Conv3d
 from .dense import Dense, Linear
+from .moe_text_experts import MoeTextExperts
 
 # {Original MindSpore Cell: New Cell in ZeRO3}
 PARALLEL_MODULES = {
@@ -14,4 +15,6 @@
     mint.nn.Linear: Linear,
 }
 
+SPECIAL_CASE_FOR_PARALLEL_MODULES = {nn.Cell: MoeTextExperts}
+
 __all__ = ["Conv1d", "Conv2d", "Conv3d", "Mint_Conv2d", "Mint_Conv3d", "Dense", "Linear"]
Lines changed: 70 additions & 0 deletions
from typing import Literal, Optional

from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore import mint, nn
from mindspore.communication import get_group_size, get_rank
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode

from .param_wrapper import ZeroParamWrapper


class MoeTextExperts(nn.Cell):
    def __init__(
        self,
        net: nn.Cell,
        zero_stage: Literal[0, 1, 2, 3] = 0,
        optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP,
        cell_type: Optional[mstype.Type] = None,
    ):
        super().__init__(auto_prefix=False)
        self.net = net
        self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type)

    def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None):
        self.param_wrapper_gate_up_proj = nn.Identity()
        self.param_wrapper_down_proj = nn.Identity()
        if zero_stage == 3:
            # Init parallel settings
            is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
            op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1
            op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0
            self.op_group_size = op_group_size
            self.op_rank_id = op_rank_id
            self.param_wrapper_gate_up_proj = ZeroParamWrapper(
                self.net.gate_up_proj, zero_stage, optimizer_parallel_group, cell_type
            )
            # keep only this rank's slice of the expert weights (sharded along dim 0)
            if self.param_wrapper_gate_up_proj.need_rewrite:
                self.net.gate_up_proj.assign_value(
                    Tensor.from_numpy(
                        self.net.gate_up_proj.numpy().reshape(op_group_size, -1, *self.net.gate_up_proj.shape[1:])[
                            op_rank_id
                        ]
                    )
                )
            self.param_wrapper_down_proj = ZeroParamWrapper(
                self.net.down_proj, zero_stage, optimizer_parallel_group, cell_type
            )
            if self.param_wrapper_down_proj.need_rewrite:
                self.net.down_proj.assign_value(
                    Tensor.from_numpy(
                        self.net.down_proj.numpy().reshape(op_group_size, -1, *self.net.down_proj.shape[1:])[op_rank_id]
                    )
                )

    def construct(self, hidden_states, routing_weights, router_indices):
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.net.hidden_size)  # (num_tokens, hidden_size)

        # replicate the tokens for every expert so the expert projections run as batched matmuls
        hidden_states = hidden_states.repeat(self.net.num_experts, 1)
        hidden_states = hidden_states.view(self.net.num_experts, -1, self.net.hidden_size)

        gate_up = mint.bmm(hidden_states, self.param_wrapper_gate_up_proj(self.net.gate_up_proj))
        gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
        next_states = mint.bmm((up * self.net.act_fn(gate)), self.param_wrapper_down_proj(self.net.down_proj))
        next_states = next_states.reshape(self.net.num_experts, batch_size, -1, self.net.hidden_size)
        # weight each expert's output by its routing weight, then sum over experts
        next_states = next_states * routing_weights.swapaxes(0, 1).view(self.net.num_experts, batch_size, -1)[..., None]
        next_states = next_states.sum(dim=0)
        return next_states
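
A minimal single-card sketch of how this wrapper is exercised, with `zero_stage=0` so the ZeRO-3 sharding branch and distributed setup are skipped; `ToyExperts` and all sizes are made-up stand-ins for the real Qwen3-VL-MoE experts module, not part of this commit:

```python
import mindspore as ms
from mindspore import Parameter, mint, nn, ops

from mindone.models.modules.parallel import MoeTextExperts


class ToyExperts(nn.Cell):
    # a stand-in exposing the attributes MoeTextExperts reads from `net`
    def __init__(self, num_experts=4, hidden_size=8, intermediate_size=16):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.gate_up_proj = Parameter(ops.randn(num_experts, hidden_size, 2 * intermediate_size))
        self.down_proj = Parameter(ops.randn(num_experts, intermediate_size, hidden_size))
        self.act_fn = nn.SiLU()


experts = MoeTextExperts(ToyExperts())  # zero_stage defaults to 0, so both param wrappers are nn.Identity
x = ops.randn(2, 3, 8)  # (batch, seq_len, hidden_size)
routing_weights = ops.rand(2 * 3, 4)  # (num_tokens, num_experts) per-expert weights
router_indices = mint.zeros((2 * 3, 2), dtype=ms.int32)  # accepted but unused by construct
out = experts(x, routing_weights, router_indices)
print(out.shape)  # expected (2, 3, 8)
```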
