
Bug: IndexError blocks batched multimodal inference with images/videos #86

@EladSharony

Description

Summary

A boolean mask derived from text positions is incorrectly applied to the flattened image features during stitching, causing a shape mismatch and a runtime crash during generation.

Impact: blocks batched multimodal inference with images/videos.

Repro

import torch
from perception_models.apps.plm.transformer import LMTransformer

# Token embeddings: B=1, L_tok=8, D=16
h_tok = torch.zeros(1, 8, 16)

# Simulate 16 image feature rows (e.g., many patches)
# Shape matches what stitch_images_into_text expects before flatten: [num_chunks, tokens_per_chunk, D]
h_img = torch.zeros(16, 1, 16)

# image_pos_index marks text positions that should receive image features.
# Here we only have 8 valid text positions (L_tok=8), mapping to indices 0..7
image_pos_index = torch.arange(8, dtype=torch.long).unsqueeze(0)  # shape [1, 8]

# One sample with 16 chunks; media contains non-text
num_chunks = [16]
media_type = ["multi_image"]

# This reproduces the IndexError in the current (buggy) code:
# applying a boolean mask of length 8 to a 16-row feature tensor.
LMTransformer.stitch_images_into_text(
    None, h_tok, h_img, image_pos_index, num_chunks, media_type
)
Observed traceback:

File "/perception_models/apps/plm/transformer.py", line 229, in stitch_images_into_text
    h_tok[img_indices_B, img_indices_L] = h_img[non_text_indices].flatten(0, 1)[
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: The shape of the mask [8] at index 0 does not match the shape of the indexed tensor [16, 16] at index 0

Root cause

`valid_index_filter` is a boolean mask over *text* positions (length = number of candidate text indices), but it is also applied to the flattened *image* features (length = total number of feature rows). The two index spaces have different lengths whenever the number of image feature rows differs from the number of valid text positions, so the boolean-mask indexing fails.

Problematic code (perception_models/apps/plm/transformer.py, lines 225-231):

img_indices_B, img_indices_L = torch.where(image_pos_index >= 0)
valid_index_filter = img_indices_L < h_tok.shape[1]
img_indices_L = img_indices_L[valid_index_filter]
img_indices_B = img_indices_B[valid_index_filter]
h_tok[img_indices_B, img_indices_L] = h_img[non_text_indices].flatten(0, 1)[valid_index_filter]
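The failure mode can be shown standalone, without the library: PyTorch requires a boolean mask's length to match the size of the dimension it indexes, so a mask built over 8 text positions cannot index 16 flattened image feature rows (the tensor names below are illustrative, not from the library):

```python
import torch

# 16 flattened image feature rows, D=16 (as in the repro above)
flat_img_feats = torch.zeros(16, 16)

# Boolean mask built over 8 *text* positions -- wrong index space
valid_index_filter = torch.ones(8, dtype=torch.bool)

try:
    flat_img_feats[valid_index_filter]  # mask of length 8 vs. 16 rows
except IndexError as e:
    print(e)  # same shape-mismatch error as in the traceback
```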

Proposed fix

Use image_pos_index to map each valid text position to its corresponding image feature row; do not reapply the boolean mask to the image features:

img_indices_B, img_indices_L = torch.where(image_pos_index >= 0)
valid_index_filter = img_indices_L < h_tok.shape[1]
img_indices_L = img_indices_L[valid_index_filter]
img_indices_B = img_indices_B[valid_index_filter]

# Map each text position to its corresponding image feature row
img_feat_indices = image_pos_index[img_indices_B, img_indices_L]
h_tok[img_indices_B, img_indices_L] = h_img[non_text_indices].flatten(0, 1)[img_feat_indices]
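A minimal sanity check of the corrected gather, using the repro's shapes (this is a sketch with dummy tensors, and `flat_img` stands in for the already-flattened `h_img[non_text_indices].flatten(0, 1)`):

```python
import torch

h_tok = torch.zeros(1, 8, 16)                                      # B=1, L_tok=8, D=16
flat_img = torch.arange(16 * 16, dtype=torch.float).view(16, 16)   # 16 flattened feature rows
image_pos_index = torch.arange(8).unsqueeze(0)                     # text pos i -> feature row i

img_indices_B, img_indices_L = torch.where(image_pos_index >= 0)
valid_index_filter = img_indices_L < h_tok.shape[1]
img_indices_L = img_indices_L[valid_index_filter]
img_indices_B = img_indices_B[valid_index_filter]

# Integer indices into the image-feature space, not a text-space boolean mask
img_feat_indices = image_pos_index[img_indices_B, img_indices_L]
h_tok[img_indices_B, img_indices_L] = flat_img[img_feat_indices]

# The 8 text positions now hold the first 8 feature rows; no IndexError
assert torch.equal(h_tok[0], flat_img[:8])
```

Integer (gather-style) indexing sidesteps the length constraint entirely: each text position names exactly which feature row it should receive, so the number of image rows no longer has to match the number of text positions.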
