Skip to content

Add FastImageProcessor for InstructBLIPVideo #37611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/instructblipvideo.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
[[autodoc]] InstructBlipVideoImageProcessor
- preprocess

## InstructBlipVideoImageProcessorFast

[[autodoc]] InstructBlipVideoImageProcessorFast
- preprocess

## InstructBlipVideoVisionModel

[[autodoc]] InstructBlipVideoVisionModel
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
("instructblipvideo", ("InstructBlipVideoImageProcessor", "InstructBlipVideoImageProcessorFast")),
("janus", ("JanusImageProcessor")),
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/instructblipvideo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
if TYPE_CHECKING:
from .configuration_instructblipvideo import *
from .image_processing_instructblipvideo import *
from .image_processing_instructblipvideo_fast import *
from .modeling_instructblipvideo import *
from .processing_instructblipvideo import *
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...utils import add_start_docstrings, is_torch_available


if is_torch_available():
import torch


@add_start_docstrings(
    "Constructs a fast InstructBLIPVideo image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class InstructBlipVideoImageProcessorFast(BaseImageProcessorFast):
    # Defaults mirror the slow `InstructBlipVideoImageProcessor`.
    resample = PILImageResampling.BICUBIC
    do_resize = True
    size = {"height": 384, "width": 384}
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = True
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD
    do_convert_rgb = True

    def __call__(self, images=None, return_tensors="pt", input_data_format=None, **kwargs):
        """
        Preprocess one video or a batch of videos into 5D `pixel_values`.

        Args:
            images: One of
                - a single 4D array-like video (not a `torch.Tensor`), laid out as
                  `(frames, H, W, C)` or `(frames, C, H, W)`;
                - a non-empty list of 4D array-likes / tensors, each `(frames, C, H, W)`;
                - anything the base fast processor accepts (fallback path).
            return_tensors: Output tensor framework (default `"pt"`).
            input_data_format: Optional explicit channel layout, forwarded to the base
                processor.
            **kwargs: Extra preprocessing options forwarded to the base processor.

        Returns:
            `BatchFeature` whose `pixel_values` has shape `(batch_size, frames, C, H, W)`
            on the video paths, or the base processor's result on the fallback path.

        Raises:
            ValueError: If videos in a batch have differing frame counts.
        """
        # 1) Single 4D array-like video (not a torch.Tensor): (frames, H, W, C) or (frames, C, H, W).
        if hasattr(images, "ndim") and images.ndim == 4 and not isinstance(images, torch.Tensor):
            # BUGFIX: the channels-last test was inverted (`shape[-1] not in (1, 3, 4)`):
            # it transposed channels-FIRST input — scrambling (F, C, H, W) into
            # (F, W, C, H) — and left channels-last input untouched, contradicting the
            # original "(frames, H, W, C) → (frames, C, H, W)" comment. Transpose only
            # when the trailing axis looks like a channel count and axis 1 does not.
            # NOTE(review): inputs whose W happens to be 1/3/4 remain ambiguous — the
            # heuristic follows the original comment's intent; confirm against callers.
            if images.shape[-1] in (1, 3, 4) and images.shape[1] not in (1, 3, 4):
                # (frames, H, W, C) → (frames, C, H, W)
                images = images.transpose(0, 3, 1, 2)
            num_frames = images.shape[0]
            frame_features = super().__call__(
                [images[i] for i in range(num_frames)],
                return_tensors=return_tensors,
                input_data_format=input_data_format,
                **kwargs,
            )
            pixel_values = frame_features["pixel_values"]  # (frames, C, H, W)
            pixel_values = pixel_values.unsqueeze(0)  # (1, frames, C, H, W)
            return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

        # 2) Batched videos: non-empty list of 4D array-likes / tensors, each (frames, C, H, W).
        if isinstance(images, list) and len(images) > 0 and hasattr(images[0], "ndim") and images[0].ndim == 4:
            batch_size = len(images)
            num_frames = images[0].shape[0]
            # Fail early with a clear message instead of an opaque shape error below.
            if any(video.shape[0] != num_frames for video in images):
                raise ValueError("All videos in a batch must have the same number of frames.")
            flat_frames = [frame for video in images for frame in video]
            frame_features = super().__call__(
                flat_frames,
                return_tensors=return_tensors,
                input_data_format=input_data_format,
                **kwargs,
            )
            pixel_values = frame_features["pixel_values"]  # (batch_size * frames, C, H, W)
            # `reshape`, not `view`: the base processor may return a non-contiguous tensor,
            # on which `view` raises; `reshape` copies only when it must.
            pixel_values = pixel_values.reshape(batch_size, num_frames, *pixel_values.shape[1:])
            return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

        # 3) Fallback: default fast-processor behavior (single images, lists of images, ...).
        return super().__call__(
            images,
            return_tensors=return_tensors,
            input_data_format=input_data_format,
            **kwargs,
        )


# Public API of this module: only the fast image processor is exported.
__all__ = ["InstructBlipVideoImageProcessorFast"]
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
if is_vision_available():
from PIL import Image

from transformers import InstructBlipVideoImageProcessor
from transformers import InstructBlipVideoImageProcessor, InstructBlipVideoImageProcessorFast


class InstructBlipVideoProcessingTester:
Expand Down Expand Up @@ -109,6 +109,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
@require_vision
class InstructBlipVideoProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = InstructBlipVideoImageProcessor if is_vision_available() else None
fast_image_processing_class = InstructBlipVideoImageProcessorFast

def setUp(self):
super().setUp()
Expand All @@ -120,13 +121,14 @@ def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()

def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
Expand Down