
Commit e27df8d

feat(transformers): support Qwen2.5VLImageProcessorFast/ Qwen2.5VLVideoProcessor (#1429)
1 parent eb6a005 commit e27df8d

6 files changed: +460 -42 lines changed

mindone/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1124,6 +1124,7 @@
     Qwen2VLImageProcessorFast,
     Qwen2VLModel,
     Qwen2VLPreTrainedModel,
+    Qwen2VLVideoProcessor,
 )
 from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
 from .models.recurrent_gemma import RecurrentGemmaForCausalLM, RecurrentGemmaModel, RecurrentGemmaPreTrainedModel
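
With this re-export in place, the video processor becomes importable from the package root. A quick sanity check, assuming a working mindone install:

    from mindone.transformers import Qwen2VLVideoProcessor  # re-exported by this commit

    print(Qwen2VLVideoProcessor.__name__)  # Qwen2VLVideoProcessor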

mindone/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@
         ("oneformer", ("OneFormerImageProcessor",)),
         ("owlv2", ("Owlv2ImageProcessor",)),
         ("owlvit", ("OwlViTImageProcessor",)),
-        ("qwen2_5_vl", ("Qwen2VLImageProcessor",)),
+        ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
         ("sam", ("SamImageProcessor",)),
         ("segformer", ("SegformerImageProcessor",)),
         ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),

mindone/transformers/models/auto/video_processing_auto.py

Lines changed: 5 additions & 1 deletion
@@ -42,7 +42,11 @@
     # the transformers package is used with Microsoft's Pylance language server.
     VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
 else:
-    VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict()
+    VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
+        [
+            ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
+        ]
+    )

 if version.parse(transformers.__version__) >= version.parse("4.57.0"):
     VIDEO_PROCESSOR_MAPPING_NAMES.update(
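
With qwen2_5_vl in VIDEO_PROCESSOR_MAPPING_NAMES, auto dispatch now covers the video path as well. A hedged sketch under the same assumptions (AutoVideoProcessor exported at the package root, illustrative checkpoint id):

    from mindone.transformers import AutoVideoProcessor

    # dispatches to Qwen2VLVideoProcessor via the new mapping entry
    video_processor = AutoVideoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")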

mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py

Lines changed: 98 additions & 40 deletions
@@ -26,29 +26,39 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Union
+from typing import Optional, Union

+import numpy as np
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

 import mindspore as ms

 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
 from ...video_utils import VideoInput


 class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
-    fps: Union[List[float], float]
+    fps: Union[list[float], float]
+
+
+class Qwen2_5_VLImagesKwargs(ImagesKwargs):
+    min_pixels: Optional[int]
+    max_pixels: Optional[int]
+    patch_size: Optional[int]
+    temporal_patch_size: Optional[int]
+    merge_size: Optional[int]


 class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Qwen2_5_VLImagesKwargs
     videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
     _defaults = {
         "text_kwargs": {
             "padding": False,
+            "return_mm_token_type_ids": False,
         },
-        "videos_kwargs": {"fps": 2.0},
     }
@@ -62,25 +72,37 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
             The image processor is a required input.
         tokenizer ([`Qwen2TokenizerFast`], *optional*):
             The tokenizer is a required input.
+        video_processor ([`Qwen2_5_VLVideoProcessor`], *optional*):
+            The video processor is a required input.
         chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
             in a chat into a tokenizable string.
     """

-    attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
+    attributes = ["image_processor", "tokenizer", "video_processor"]

     image_processor_class = "AutoImageProcessor"
+    video_processor_class = "AutoVideoProcessor"
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

-    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
         self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
         self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
-        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if getattr(tokenizer, "image_token_id", None)
+            else tokenizer.convert_tokens_to_ids(self.image_token)
+        )
+        self.video_token_id = (
+            tokenizer.video_token_id
+            if getattr(tokenizer, "video_token_id", None)
+            else tokenizer.convert_tokens_to_ids(self.video_token)
+        )
+        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

     def __call__(
         self,
         images: ImageInput = None,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
         videos: VideoInput = None,
         **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
     ) -> BatchFeature:
@@ -91,14 +113,14 @@ def __call__(
         Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

         Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[ms.Tensor]`):
+            images (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[ms.Tensor]`):
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. Both channels-first and channels-last formats are supported.
-            text (`str`, `List[str]`, `List[List[str]]`):
+            text (`str`, `list[str]`, `list[list[str]]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            videos (`np.ndarray`, `ms.Tensor`, `List[np.ndarray]`, `List[ms.Tensor]`):
+            videos (`np.ndarray`, `ms.Tensor`, `list[np.ndarray]`, `list[ms.Tensor]`):
                 The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                 tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
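
The typed Qwen2_5_VLImagesKwargs keys added earlier in this file (min_pixels, max_pixels, patch_size, temporal_patch_size, merge_size) can be overridden per call and are routed to the image processor through images_kwargs. A minimal sketch, assuming processor is an already-constructed Qwen2_5_VLProcessor and image a PIL image; the pixel budgets are illustrative:

    inputs = processor(
        text=["<|image_pad|> Describe this image."],
        images=[image],
        min_pixels=256 * 28 * 28,    # illustrative lower pixel budget
        max_pixels=1280 * 28 * 28,   # illustrative upper pixel budget
        return_tensors="ms",
    )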
@@ -124,69 +146,105 @@
             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs,
         )
+
+        image_inputs = videos_inputs = {}
         if images is not None:
-            image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
+            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
             image_grid_thw = image_inputs["image_grid_thw"]
-        else:
-            image_inputs = {}
-            image_grid_thw = None

         if videos is not None:
-            videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
+            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
+            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
             video_grid_thw = videos_inputs["video_grid_thw"]

-            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
             if isinstance(fps, (int, float)):
-                second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
+                second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
             elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
-                second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
+                second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
             else:
                 raise ValueError(
                     f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the "
                     f"length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                 )
             videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

-        else:
-            videos_inputs = {}
-            video_grid_thw = None
-
         if not isinstance(text, list):
             text = [text]

-        if image_grid_thw is not None:
+        text = text.copy()  # below lines change text in-place
+        if images is not None:
             merge_length = self.image_processor.merge_size**2
             index = 0
             for i in range(len(text)):
                 while self.image_token in text[i]:
-                    text[i] = text[i].replace(
-                        self.image_token,
-                        "<|placeholder|>" * (image_grid_thw[index].prod().item() // merge_length),
-                        1,
-                    )
+                    num_image_tokens = image_grid_thw[index].prod().item() // merge_length
+                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                     index += 1
                 text[i] = text[i].replace("<|placeholder|>", self.image_token)

-        if video_grid_thw is not None:
-            merge_length = self.image_processor.merge_size**2
+        if videos is not None:
+            merge_length = self.video_processor.merge_size**2
             index = 0
             for i in range(len(text)):
                 while self.video_token in text[i]:
-                    text[i] = text[i].replace(
-                        self.video_token,
-                        "<|placeholder|>" * (video_grid_thw[index].prod().item() // merge_length),
-                        1,
-                    )
+                    num_video_tokens = video_grid_thw[index].prod().item() // merge_length
+                    text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
                     index += 1
                 text[i] = text[i].replace("<|placeholder|>", self.video_token)

         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors="np")
         if return_tensors == "ms":
             for k, v in text_inputs.items():
                 text_inputs[k] = ms.tensor(v)
+        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
+
+        if return_mm_token_type_ids:
+            array_ids = np.array(text_inputs["input_ids"])
+            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+            mm_token_type_ids[array_ids == self.image_token_id] = 1
+            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+    def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
+        """
+        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+        Args:
+            image_sizes (`list[list[int]]`, *optional*):
+                The input sizes formatted as (height, width) per each image.
+            video_sizes (`list[list[int]]`, *optional*):
+                The input sizes formatted as (num_frames, height, width) per each video.
+        Returns:
+            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+            input modalities, along with other useful data.
+        """

-        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
+        vision_data = {}
+        if image_sizes is not None:
+            images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
+            images_kwargs.update(kwargs)
+            merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
+
+            num_image_patches = [
+                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+                for image_size in image_sizes
+            ]
+            num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
+            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+        if video_sizes is not None:
+            videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
+            videos_kwargs.update(kwargs)
+            num_video_patches = [
+                self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
+                for video_size in video_sizes
+            ]
+            num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
+            vision_data["num_video_tokens"] = num_video_tokens
+
+        return MultiModalData(**vision_data)

     def batch_decode(self, *args, **kwargs):
         """
@@ -214,13 +272,13 @@ def post_process_image_text_to_text(
             or `(sequence_length,)`.
         skip_special_tokens (`bool`, *optional*, defaults to `True`):
             Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
-        Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
             Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
         **kwargs:
             Additional arguments to be passed to the tokenizer's `batch_decode method`.

     Returns:
-        `List[str]`: The decoded text.
+        `list[str]`: The decoded text.
     """
     return self.tokenizer.batch_decode(
         generated_outputs,
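
Taken together, the processor now fans its inputs out to three sub-processors (tokenizer, image processor, video processor) and merges the results into one BatchFeature. A hedged end-to-end sketch; the checkpoint id and the random inputs are placeholders, and AutoProcessor is assumed to be exported from mindone.transformers as in upstream:

    import numpy as np
    from mindone.transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)     # dummy image
    video = np.random.randint(0, 256, (8, 224, 224, 3), dtype=np.uint8)  # dummy 8-frame clip

    inputs = processor(
        text=["<|image_pad|> and <|video_pad|> in one prompt"],
        images=[image],
        videos=[video],
        fps=2.0,              # consumed by the video path to derive second_per_grid_ts
        return_tensors="ms",
    )
    # expected keys include input_ids, pixel_values, image_grid_thw,
    # pixel_values_videos, video_grid_thw and second_per_grid_ts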

mindone/transformers/models/qwen2_vl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@
 from .image_processing_qwen2_vl import *
 from .image_processing_qwen2_vl_fast import *
 from .modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel, Qwen2VLPreTrainedModel
+from .video_processing_qwen2_vl import *
