@@ -304,11 +304,11 @@ def test_gather_mm_embeddings_chunked_prefill(self):
304304 mock_scheduler_output_1 .num_scheduled_tokens = {req_id : 20 }
305305
306306 gathered_embeds_1 = self .runner .mm_manager .gather_mm_embeddings (
307- mock_scheduler_output_1 )
307+ mock_scheduler_output_1 , target_pad_len = 10 )
308308
309- assert len (gathered_embeds_1 ) == 1
310309 expected_embeds_1 = encoder_embedding [0 :10 ]
311- np .testing .assert_array_equal (np .asarray (gathered_embeds_1 [0 ]),
310+ assert gathered_embeds_1 .shape == expected_embeds_1 .shape
311+ np .testing .assert_array_equal (np .asarray (gathered_embeds_1 ),
312312 np .asarray (expected_embeds_1 ))
313313
314314 # ----- Step 2: Middle chunk of prefill -----
@@ -317,11 +317,11 @@ def test_gather_mm_embeddings_chunked_prefill(self):
317317 mock_scheduler_output_2 .num_scheduled_tokens = {req_id : 30 }
318318
319319 gathered_embeds_2 = self .runner .mm_manager .gather_mm_embeddings (
320- mock_scheduler_output_2 )
320+ mock_scheduler_output_2 , target_pad_len = 30 )
321321
322- assert len (gathered_embeds_2 ) == 1
323322 expected_embeds_2 = encoder_embedding [10 :40 ]
324- np .testing .assert_array_equal (np .asarray (gathered_embeds_2 [0 ]),
323+ assert gathered_embeds_2 .shape == expected_embeds_2 .shape
324+ np .testing .assert_array_equal (np .asarray (gathered_embeds_2 ),
325325 np .asarray (expected_embeds_2 ))
326326
327327 # ----- Step 3: Last chunk of prefill -----
@@ -330,11 +330,11 @@ def test_gather_mm_embeddings_chunked_prefill(self):
330330 mock_scheduler_output_3 .num_scheduled_tokens = {req_id : 30 }
331331
332332 gathered_embeds_3 = self .runner .mm_manager .gather_mm_embeddings (
333- mock_scheduler_output_3 )
333+ mock_scheduler_output_3 , target_pad_len = 16 )
334334
335- assert len (gathered_embeds_3 ) == 1
336335 expected_embeds_3 = encoder_embedding [40 :56 ]
337- np .testing .assert_array_equal (np .asarray (gathered_embeds_3 [0 ]),
336+ assert gathered_embeds_3 .shape == expected_embeds_3 .shape
337+ np .testing .assert_array_equal (np .asarray (gathered_embeds_3 ),
338338 np .asarray (expected_embeds_3 ))
339339
340340 def test_calc_mrope_positions (self ):
0 commit comments