 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
 from vllm.config import VllmConfig
-from vllm.model_executor.models.qwen2_5_vl import \
-    Qwen2_5_VLForConditionalGeneration as vllm_model_cls
 
 from tpu_inference import utils as utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -689,9 +687,8 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
         )
         self.language_model = Qwen2ForCausalLM(vllm_config, rng_key, mesh)
 
-    @classmethod
     def get_mrope_input_positions(
-        cls,
+        self,
         input_tokens: list[int],
         hf_config,
         image_grid_thw,
@@ -701,18 +698,134 @@ def get_mrope_input_positions(
         seq_len: int | None = None,
         audio_feature_lengths=None,
         use_audio_in_video: bool = False,
-    ):
-        return vllm_model_cls.get_mrope_input_positions(
-            input_tokens=input_tokens,
-            hf_config=hf_config,
-            image_grid_thw=image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
-            context_len=context_len,
-            seq_len=seq_len,
-            audio_feature_lengths=audio_feature_lengths,
-            use_audio_in_video=use_audio_in_video,
-        )
+    ) -> tuple[jax.Array, int]:
+        """Get mrope input positions and delta value."""
+
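+        # M-RoPE tracks three position rows, one per axis (temporal, height,
+        # width); plain text advances all three rows in lockstep.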
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(hf_config.vision_config,
+                                    "tokens_per_second", 1.0)
+
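+        # Each vision segment is announced by a vision_start token; the token
+        # right after it tells whether an image or a video grid follows.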
+        input_tokens_tensor = np.array(input_tokens)
+        vision_start_indices = np.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = np.sum(vision_tokens == image_token_id)
+        video_nums = np.sum(vision_tokens == video_token_id)
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
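+        # Walk the prompt one vision segment at a time: first emit 1-D text
+        # positions up to the segment, then 3-D grid positions for it.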
+        for _ in range(image_nums + video_nums):
+            video_second_per_grid_t = 0.0
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_videos > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
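+            # Whichever placeholder appears first decides the next modality;
+            # len(input_tokens) + 1 serves as an "absent" sentinel above.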
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_second_per_grid_t = 1.0
+                if second_per_grid_ts:
+                    video_second_per_grid_t = second_per_grid_ts[video_index]
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
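+            # Every spatial_merge_size x spatial_merge_size patch block folds
+            # into one LLM token, so the height/width grids shrink accordingly.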
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
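+            # Text positions are scalar indices broadcast to all three rows,
+            # continuing from the largest position emitted so far.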
+            st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                jnp.broadcast_to(
+                    jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
+                    (3, text_len)) + st_idx)
+
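+            # Temporal indices scale by seconds-per-grid * tokens_per_second
+            # so video frames advance with wall-clock time; for images the
+            # scale is 0.0, pinning the single frame to temporal position 0.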
+            t_index = ((jnp.broadcast_to(
+                jnp.arange(llm_grid_t, dtype=jnp.int32).reshape(-1, 1),
+                (llm_grid_t, llm_grid_h * llm_grid_w)) *
+                        video_second_per_grid_t * tokens_per_second).astype(
+                            jnp.int32).flatten())
+
+            h_index = (jnp.broadcast_to(
+                jnp.arange(llm_grid_h, dtype=jnp.int32).reshape(1, -1, 1),
+                (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
+            w_index = (jnp.broadcast_to(
+                jnp.arange(llm_grid_w, dtype=jnp.int32).reshape(1, 1, -1),
+                (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
+
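+            # Grid positions are offset past the text span emitted just above.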
+            llm_pos_ids_list.append(
+                jnp.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
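+        # Positions for any text trailing the final vision segment.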
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+
+            llm_pos_ids_list.append(
+                jnp.broadcast_to(
+                    jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
+                    (3, text_len)) + st_idx)
+
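+        # The delta measures how far the final M-RoPE position runs ahead of
+        # (or behind) the raw token count; later decode steps derive their
+        # positions as token_index + delta.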
+        llm_positions = jnp.concatenate(llm_pos_ids_list,
+                                        axis=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
 
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> jax.Array: