## Summary

A boolean mask derived from text positions is incorrectly applied to the flattened image features during stitching, causing a shape mismatch and a runtime crash (`IndexError`) during generation.

**Impact:** blocks batched multimodal inference with images/videos.
## Repro
```python
import torch

from perception_models.apps.plm.transformer import LMTransformer

# Token embeddings: B=1, L_tok=8, D=16
h_tok = torch.zeros(1, 8, 16)

# Simulate 16 image feature rows (e.g., many patches).
# Shape matches what stitch_images_into_text expects before flatten:
# [num_chunks, tokens_per_chunk, D]
h_img = torch.zeros(16, 1, 16)

# image_pos_index marks text positions that should receive image features.
# Here we only have 8 valid text positions (L_tok=8), mapping to indices 0..7.
image_pos_index = torch.arange(8, dtype=torch.long).unsqueeze(0)  # shape [1, 8]

# One sample with 16 chunks; media contains non-text.
num_chunks = [16]
media_type = ["multi_image"]

# This reproduces the IndexError in the current (buggy) code:
# applying a boolean mask of length 8 to a 16-row feature tensor.
LMTransformer.stitch_images_into_text(
    None, h_tok, h_img, image_pos_index, num_chunks, media_type
)
```

Resulting traceback:

```text
File "/perception_models/apps/plm/transformer.py", line 229, in stitch_images_into_text
    h_tok[img_indices_B, img_indices_L] = h_img[non_text_indices].flatten(0, 1)[
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: The shape of the mask [8] at index 0 does not match the shape of the indexed tensor [16, 16] at index 0
```
## Root cause

A boolean mask over text positions (`valid_index_filter`) is incorrectly applied to the flattened image features, mixing two index spaces (text positions vs. image-feature rows).

Problematic code (`perception_models/apps/plm/transformer.py`, lines 225–231):
```python
img_indices_B, img_indices_L = torch.where(image_pos_index >= 0)
valid_index_filter = img_indices_L < h_tok.shape[1]
img_indices_L = img_indices_L[valid_index_filter]
img_indices_B = img_indices_B[valid_index_filter]
h_tok[img_indices_B, img_indices_L] = h_img[non_text_indices].flatten(0, 1)[valid_index_filter]
```
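For illustration, a minimal, library-free sketch of the same index-space mix-up (toy names; shapes taken from the repro above):

```python
import torch

# 16 flattened image-feature rows (image-feature index space).
flattened_img_feats = torch.zeros(16, 16)

# Boolean mask over the 8 valid text positions (text-position index space).
valid_index_filter = torch.ones(8, dtype=torch.bool)

# Indexing a 16-row tensor with a length-8 mask mixes the two index spaces
# and raises the IndexError shown in the traceback above.
try:
    flattened_img_feats[valid_index_filter]
except IndexError as e:
    print(e)
```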
## Proposed fix

Use `image_pos_index` to map each valid text position to its corresponding image-feature row; do not reapply the boolean mask to the image features:
```python
img_indices_B, img_indices_L = torch.where(image_pos_index >= 0)
valid_index_filter = img_indices_L < h_tok.shape[1]
img_indices_L = img_indices_L[valid_index_filter]
img_indices_B = img_indices_B[valid_index_filter]

# Map each valid text position to its corresponding image-feature row.
img_feat_indices = image_pos_index[img_indices_B, img_indices_L]
h_tok[img_indices_B, img_indices_L] = h_img[non_text_indices].flatten(0, 1)[img_feat_indices]
```
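As a sanity check, a small standalone sketch of the corrected mapping (not the library code: toy tensors with `flattened_img_feats` standing in for `h_img[non_text_indices].flatten(0, 1)`, shapes matching the repro above):

```python
import torch

# Shapes from the repro: B=1, L_tok=8, D=16, and 16 image-feature rows.
h_tok = torch.zeros(1, 8, 16)
flattened_img_feats = torch.arange(16 * 16, dtype=torch.float).reshape(16, 16)
image_pos_index = torch.arange(8, dtype=torch.long).unsqueeze(0)  # [1, 8]

img_indices_B, img_indices_L = torch.where(image_pos_index >= 0)
valid_index_filter = img_indices_L < h_tok.shape[1]
img_indices_L = img_indices_L[valid_index_filter]
img_indices_B = img_indices_B[valid_index_filter]

# Gather the image-feature row mapped to each valid text position.
img_feat_indices = image_pos_index[img_indices_B, img_indices_L]
h_tok[img_indices_B, img_indices_L] = flattened_img_feats[img_feat_indices]

# Each of the 8 text positions now holds its mapped image-feature row;
# no shape mismatch even though there are 16 feature rows in total.
assert torch.equal(h_tok[0], flattened_img_feats[:8])
print("OK: features stitched without IndexError")
```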