Commit 6cdabdf

kwang3939 authored and sierraisland committed
[CI] Fix Qwen2.5 VL get_mrope_input_positions after vLLM change. (#934)
Signed-off-by: Kewei Wang <[email protected]>
1 parent 025eb49 commit 6cdabdf

File tree

2 files changed (+111 lines, -18 lines)


tpu_inference/models/common/model_loader.py

Lines changed: 2 additions & 2 deletions
@@ -270,8 +270,8 @@ def combine_hidden_states(graphdef, state, hidden_states):
         graphdef)
 
     get_mrope_input_positions_fn = None if not hasattr(
-        model_class,
-        "get_mrope_input_positions") else model_class.get_mrope_input_positions
+        jit_model,
+        "get_mrope_input_positions") else jit_model.get_mrope_input_positions
 
     return model_fn, compute_logits_fn, combine_hidden_states_fn, get_multimodal_embeddings_fn, get_input_embeddings_fn, get_mrope_input_positions_fn, state, lora_manager, model

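Why the lookup moves from model_class to jit_model: the second file in this commit turns get_mrope_input_positions into an instance method (the @classmethod decorator and the cls parameter are dropped), so the loader presumably has to probe the wrapped model object to obtain a bound callable rather than an unbound class attribute. Below is a minimal, self-contained sketch of that hasattr-based resolution pattern; the class names are hypothetical stand-ins for illustration, not the real model classes.

# Hypothetical stand-ins used only to illustrate the hasattr-based lookup.
class TextOnlyModel:
    """A model that does not implement the M-RoPE hook."""


class MultimodalModel:
    """A model that exposes get_mrope_input_positions as an instance method."""

    def get_mrope_input_positions(self, input_tokens, hf_config, **kwargs):
        ...


def resolve_mrope_fn(jit_model):
    # Same pattern as in model_loader.py above: probe the model *instance*,
    # so the result is either None or a bound method that is ready to call.
    return None if not hasattr(
        jit_model,
        "get_mrope_input_positions") else jit_model.get_mrope_input_positions


assert resolve_mrope_fn(TextOnlyModel()) is None
assert callable(resolve_mrope_fn(MultimodalModel()))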
tpu_inference/models/jax/qwen2_5_vl.py

Lines changed: 109 additions & 16 deletions
@@ -12,8 +12,6 @@
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
 from vllm.config import VllmConfig
-from vllm.model_executor.models.qwen2_5_vl import \
-    Qwen2_5_VLForConditionalGeneration as vllm_model_cls
 
 from tpu_inference import utils as utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -689,9 +687,8 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
         )
         self.language_model = Qwen2ForCausalLM(vllm_config, rng_key, mesh)
 
-    @classmethod
     def get_mrope_input_positions(
-        cls,
+        self,
         input_tokens: list[int],
         hf_config,
         image_grid_thw,
@@ -701,18 +698,114 @@ def get_mrope_input_positions(
         seq_len: int | None = None,
         audio_feature_lengths=None,
         use_audio_in_video: bool = False,
-    ):
-        return vllm_model_cls.get_mrope_input_positions(
-            input_tokens=input_tokens,
-            hf_config=hf_config,
-            image_grid_thw=image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
-            context_len=context_len,
-            seq_len=seq_len,
-            audio_feature_lengths=audio_feature_lengths,
-            use_audio_in_video=use_audio_in_video,
-        )
+    ) -> tuple[jax.Array, int]:
+        """Get mrope input positions and delta value."""
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(hf_config.vision_config,
+                                    "tokens_per_second", 1.0)
+
+        input_tokens_tensor = np.array(input_tokens)
+        vision_start_indices = np.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = np.sum(vision_tokens == image_token_id)
+        video_nums = np.sum(vision_tokens == video_token_id)
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            video_second_per_grid_t = 0.0
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_videos > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_second_per_grid_t = 1.0
+                if second_per_grid_ts:
+                    video_second_per_grid_t = second_per_grid_ts[video_index]
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                jnp.broadcast_to(
+                    jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
+                    (3, text_len)) + st_idx)
+
+            t_index = ((jnp.broadcast_to(
+                jnp.arange(llm_grid_t, dtype=jnp.int32).reshape(-1, 1),
+                (llm_grid_t, llm_grid_h * llm_grid_w)) *
+                        video_second_per_grid_t * tokens_per_second).astype(
+                            jnp.int32).flatten())
+
+            h_index = (jnp.broadcast_to(
+                jnp.arange(llm_grid_h, dtype=jnp.int32).reshape(1, -1, 1),
+                (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
+            w_index = (jnp.broadcast_to(
+                jnp.arange(llm_grid_w, dtype=jnp.int32).reshape(1, 1, -1),
+                (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
+
+            llm_pos_ids_list.append(
+                jnp.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+
+            llm_pos_ids_list.append(
+                jnp.broadcast_to(
+                    jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
+                    (3, text_len)) + st_idx)
+
+        llm_positions = jnp.concatenate(llm_pos_ids_list,
+                                        axis=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
 
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> jax.Array:

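For orientation, the new get_mrope_input_positions builds the usual M-RoPE layout: text tokens share one position across the temporal, height, and width rows; vision placeholder tokens get 3D indices over the spatially merged grid (with the temporal index additionally scaled by tokens_per_second and the per-grid seconds for video); the result is sliced to [context_len, seq_len) and returned together with mrope_position_delta. The snippet below is a small hand-worked sketch of that layout only, not a call into the model. It assumes a toy sequence of two text tokens, one vision-start token, a single image with a 1x4x4 patch grid, spatial_merge_size = 2, and one trailing text token.

import jax.numpy as jnp

# Toy sequence: 2 text tokens, a vision-start token, then a 1x4x4 image grid
# that collapses to 1x2x2 = 4 placeholder tokens (spatial_merge_size = 2),
# then 1 trailing text token -- 8 tokens in total.
text_prefix_len = 3                       # two text tokens plus vision-start
llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 2

# Text tokens: identical position in the t, h and w rows.
prefix = jnp.broadcast_to(
    jnp.arange(text_prefix_len).reshape(1, -1), (3, text_prefix_len))

# Vision tokens: 3D indices over the merged grid (single frame here),
# shifted past the text prefix.
t_idx = jnp.zeros(llm_grid_t * llm_grid_h * llm_grid_w, dtype=jnp.int32)
h_idx = jnp.repeat(jnp.arange(llm_grid_h), llm_grid_w)
w_idx = jnp.tile(jnp.arange(llm_grid_w), llm_grid_h)
vision = jnp.stack([t_idx, h_idx, w_idx]) + text_prefix_len

# Trailing text resumes one past the largest position used so far.
tail_start = int(vision.max()) + 1
tail = jnp.full((3, 1), tail_start)

positions = jnp.concatenate([prefix, vision, tail], axis=1)
print(positions)
# [[0 1 2 3 3 3 3 5]
#  [0 1 2 3 3 4 4 5]
#  [0 1 2 3 4 3 4 5]]
# mrope_position_delta would be max + 1 - seq_len = 6 - 8 = -2.

The real method derives the same structure from the hf_config token ids and the grid tensors; the toy numbers above are only meant to show how the position delta can be negative when vision tokens compress into a smaller position range.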