Skip to content

Commit 5f5ca02

Browse files
authored
[Model] Add vision encoder and input embeddings merger warmup for Qwen2.5 VL model (vllm-project#972)
Signed-off-by: Kewei Wang <[email protected]>
1 parent c187ed1 commit 5f5ca02

File tree

9 files changed

+137
-45
lines changed

9 files changed

+137
-45
lines changed

tests/models/jax/test_qwen2_5_vl.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616
# Import the module itself to allow patching
1717
# Corrected imports for the code under test
1818
from tpu_inference.models.jax.qwen2_5_vl import (
19-
AttentionMetadata, MultiModalEmbeddings, Qwen2_5_VisionAttention,
20-
Qwen2_5_VisionBlock, Qwen2_5_VisionMLP, Qwen2_5_VisionPatchEmbed,
21-
Qwen2_5_VisionPatchMerger, Qwen2_5_VisionRotaryEmbedding,
22-
Qwen2_5_VisionTransformer, Qwen2_5_VLForConditionalGeneration,
23-
Qwen2_5_VLImagePixelInputs, SegmentIds, apply_rotary_pos_emb_vision,
24-
generate_window_segment_ids)
19+
AttentionMetadata, Qwen2_5_VisionAttention, Qwen2_5_VisionBlock,
20+
Qwen2_5_VisionMLP, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionPatchMerger,
21+
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
22+
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLImagePixelInputs, SegmentIds,
23+
apply_rotary_pos_emb_vision, generate_window_segment_ids)
2524

2625

2726
# --- Configuration Mocking ---
@@ -508,12 +507,12 @@ def test_get_input_embeddings(self, mock_merge_embeddings: MagicMock,
508507
np.testing.assert_array_equal(embeds, mock_text_embeds)
509508
mock_merge_embeddings.assert_not_called()
510509

511-
embeds_empty_mm = model.get_input_embeddings(input_ids, tuple())
510+
empty_mm = jnp.ones((0, model.config.hidden_size), )
511+
embeds_empty_mm = model.get_input_embeddings(input_ids, empty_mm)
512512
np.testing.assert_array_equal(embeds_empty_mm, mock_text_embeds)
513513
mock_merge_embeddings.assert_not_called()
514514

515-
mm_embeds: MultiModalEmbeddings = (jnp.ones(
516-
(5, model.config.hidden_size)), )
515+
mm_embeds = jnp.ones((5, model.config.hidden_size))
517516
mock_merged = jnp.ones((1, 15, model.config.hidden_size))
518517
mock_merge_embeddings.return_value = mock_merged
519518

tests/models/jax/utils/test_multi_modal_utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from tpu_inference.models.jax.utils.multi_modal_utils import (
7-
MultiModalEmbeddings, NestedTensors, _flatten_embeddings,
7+
MultiModalEmbeddings, NestedTensors, flatten_embeddings,
88
merge_multimodal_embeddings, sanity_check_mm_encoder_outputs)
99

1010
# --- Tests for sanity_check_mm_encoder_outputs ---
@@ -65,45 +65,45 @@ def test_sanity_check_wrong_dimensions_in_list():
6565
sanity_check_mm_encoder_outputs(embeddings, 1)
6666

6767

68-
# --- Tests for _flatten_embeddings ---
68+
# --- Tests for flatten_embeddings ---
6969

7070

7171
def test_flatten_single_array():
72-
"""Tests _flatten_embeddings with a single 2D array."""
72+
"""Tests flatten_embeddings with a single 2D array."""
7373
emb: NestedTensors = jnp.arange(12).reshape((3, 4))
74-
result = _flatten_embeddings(emb)
74+
result = flatten_embeddings(emb)
7575
np.testing.assert_array_equal(result, emb)
7676

7777

7878
def test_flatten_single_3d_array():
79-
"""Tests _flatten_embeddings with a single 3D array."""
79+
"""Tests flatten_embeddings with a single 3D array."""
8080
emb: NestedTensors = jnp.arange(24).reshape((2, 3, 4))
81-
result = _flatten_embeddings(emb)
81+
result = flatten_embeddings(emb)
8282
expected = jnp.arange(24).reshape((6, 4))
8383
np.testing.assert_array_equal(result, expected)
8484

8585

8686
def test_flatten_list_of_arrays():
87-
"""Tests _flatten_embeddings with a list of 2D arrays."""
87+
"""Tests flatten_embeddings with a list of 2D arrays."""
8888
emb: NestedTensors = [
8989
jnp.arange(12).reshape((3, 4)),
9090
jnp.arange(12, 20).reshape((2, 4))
9191
]
92-
result = _flatten_embeddings(emb)
92+
result = flatten_embeddings(emb)
9393
expected = jnp.arange(20).reshape((5, 4))
9494
np.testing.assert_array_equal(result, expected)
9595

9696

9797
def test_flatten_nested_list():
98-
"""Tests _flatten_embeddings with a nested list of arrays."""
98+
"""Tests flatten_embeddings with a nested list of arrays."""
9999
emb: NestedTensors = [
100100
jnp.arange(6).reshape((2, 3)),
101101
[
102102
jnp.arange(6, 12).reshape((2, 3)),
103103
jnp.arange(12, 15).reshape((1, 3))
104104
]
105105
]
106-
result = _flatten_embeddings(emb)
106+
result = flatten_embeddings(emb)
107107
expected = jnp.arange(15).reshape((5, 3))
108108
np.testing.assert_array_equal(result, expected)
109109

@@ -191,7 +191,7 @@ def test_merge_mm_embeds_count_too_many_no_raise(placeholder_id, base_embeds):
191191
# Check that the first 2 embeddings from mm_embeds_too_many were used.
192192
expected = np.array(inputs_embeds)
193193
is_mm = np.isin(input_ids, np.array(placeholder_id))
194-
expected[is_mm] = _flatten_embeddings(mm_embeds_too_many)[:2]
194+
expected[is_mm] = flatten_embeddings(mm_embeds_too_many)[:2]
195195
np.testing.assert_array_equal(result, expected)
196196
except Exception as e:
197197
pytest.fail(

tpu_inference/models/common/model_loader.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,9 @@ def combine_hidden_states(graphdef, state, hidden_states):
260260
model = nnx.merge(graphdef, state)
261261
return model.combine_hidden_states(hidden_states)
262262

263+
model = nnx.merge(graphdef, state)
264+
precompile_vision_encoder_fn = getattr(model, "precompile_vision_encoder",
265+
None)
263266
model_fn = functools.partial(run_model, graphdef)
264267
compute_logits_fn = functools.partial(run_compute_logits, graphdef)
265268
get_multimodal_embeddings_fn = functools.partial(
@@ -274,7 +277,14 @@ def combine_hidden_states(graphdef, state, hidden_states):
274277
jit_model,
275278
"get_mrope_input_positions") else jit_model.get_mrope_input_positions
276279

277-
return model_fn, compute_logits_fn, combine_hidden_states_fn, get_multimodal_embeddings_fn, get_input_embeddings_fn, get_mrope_input_positions_fn, state, lora_manager, model
280+
multimodal_fns = {
281+
"precompile_vision_encoder_fn": precompile_vision_encoder_fn,
282+
"get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
283+
"get_input_embeddings_fn": get_input_embeddings_fn,
284+
"get_mrope_input_positions_fn": get_mrope_input_positions_fn,
285+
}
286+
287+
return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model
278288

279289

280290
def get_vllm_model(
@@ -295,7 +305,7 @@ def get_vllm_model(
295305
compute_logits_fn = model.jit_compute_logits_func()
296306
# The model needs to be returned because LoRA weights are neither torch.nn.Parameter nor torch.nn.Buffer objects. After we load the LoRA weights and attach them to the torch.nn.Module, we can shard it and move it to TPU.
297307
combine_hidden_states_fn = None
298-
return jit_model, compute_logits_fn, combine_hidden_states_fn, None, None, None, params, lora_manager, model
308+
return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
299309

300310

301311
def get_model(

tpu_inference/models/jax/qwen2_5_vl.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -959,14 +959,13 @@ def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
959959

960960
def get_input_embeddings(
961961
self, input_ids: jax.Array,
962-
multimodal_embeddings: Optional[MultiModalEmbeddings]
963-
) -> jax.Array:
962+
multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
964963

965964
inputs_embeds = self.language_model.model.embed(input_ids)
966965

967966

968967
if multimodal_embeddings is not None \
969-
and len(multimodal_embeddings) != 0:
968+
and multimodal_embeddings.shape[0] != 0:
970969
inputs_embeds = merge_multimodal_embeddings(
971970
input_ids, inputs_embeds, multimodal_embeddings,
972971
[self.config.image_token_id, self.config.video_token_id])
@@ -1067,3 +1066,35 @@ def load_weights(self, rng_key: jax.Array) -> None:
10671066
model=self,
10681067
metadata_map=metadata_map,
10691068
mesh=self.mesh)
1069+
1070+
def precompile_vision_encoder(
    self,
    run_compilation_fn: Callable,
) -> None:
    """Warm up the single-image vision encoder for configured image shapes.

    The shapes are read from
    ``vllm_config.additional_config["vision_warmup_config"]["image_shapes"]``.
    Each entry is expected to be a two-element ``[height, width]`` list or
    tuple; both values are integer-divided by the vision ``patch_size`` to
    obtain the patch grid. Entries of any other form are skipped with a
    warning. If no warmup config or no shapes are present, nothing is
    compiled.

    Args:
        run_compilation_fn: Callback invoked once per valid shape as
            ``run_compilation_fn("single_image_encoder",
            self.get_single_image_embedding, dummy_pixel_values, grid_thw,
            image_shape=input_hw)``.
    """
    # Every level of the config may be absent or explicitly null; fall back
    # to empty containers so a partial config never raises here.
    additional_config = self.vllm_config.additional_config or {}
    warmup_config = additional_config.get("vision_warmup_config") or {}
    image_shapes = warmup_config.get("image_shapes") or []
    if not image_shapes:
        return

    vc = self.vllm_config.model_config.hf_config.vision_config
    dtype = self.vllm_config.model_config.dtype
    # Flattened size of one patch; identical for every warmup shape.
    patch_input_dim = (vc.in_channels * vc.temporal_patch_size *
                       vc.patch_size * vc.patch_size)

    for input_hw in image_shapes:
        if not isinstance(input_hw, (list, tuple)) or len(input_hw) != 2:
            logger.warning("Skipping invalid shape %s.", input_hw)
            continue

        h_input, w_input = input_hw
        # Warmup always uses a single temporal slice (t == 1).
        grid_thw = (1, h_input // vc.patch_size, w_input // vc.patch_size)
        num_patches = grid_thw[0] * grid_thw[1] * grid_thw[2]

        dummy_pixel_values = jnp.ones((num_patches, patch_input_dim), dtype)

        run_compilation_fn("single_image_encoder",
                           self.get_single_image_embedding,
                           dummy_pixel_values,
                           grid_thw,
                           image_shape=input_hw)

tpu_inference/models/jax/utils/multi_modal_utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def sanity_check_mm_encoder_outputs(
5050
"of the model's `get_multimodal_embeddings` method.")
5151

5252

53-
def _flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
53+
def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
5454
"""
5555
Recursively flattens and concatenates NestedTensors on all but the last
5656
dimension.
@@ -59,8 +59,7 @@ def _flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
5959
if isinstance(embeddings, jax.Array):
6060
return embeddings.reshape(-1, embeddings.shape[-1])
6161

62-
return jnp.concatenate([_flatten_embeddings(t) for t in embeddings],
63-
axis=0)
62+
return jnp.concatenate([flatten_embeddings(t) for t in embeddings], axis=0)
6463

6564

6665
def _embedding_count_expression(embeddings: NestedTensors) -> str:
@@ -79,7 +78,7 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
7978
def _merge_multimodal_embeddings(
8079
inputs_embeds: jax.Array,
8180
is_multimodal: jax.Array,
82-
multimodal_embeddings: NestedTensors,
81+
multimodal_embeddings: jax.Array,
8382
) -> jax.Array:
8483
"""
8584
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@@ -89,7 +88,6 @@ def _merge_multimodal_embeddings(
8988
Note:
9089
This returns a new array with the updated values.
9190
"""
92-
flattened = _flatten_embeddings(multimodal_embeddings)
9391
# The check for matching number of tokens is removed as it is not
9492
# JIT-compatible. If the shapes mismatch, JAX will raise an error
9593
# during execution anyway. The user-friendly error message is
@@ -99,10 +97,11 @@ def _merge_multimodal_embeddings(
9997
# NonConcreteBooleanIndexError.
10098
# Create a dummy row to handle indices for non-multimodal tokens.
10199
# The content of the dummy row does not matter as it will be masked out.
102-
dummy_row = jnp.zeros_like(flattened[0:1])
100+
dummy_row = jnp.zeros_like(multimodal_embeddings[0:1])
103101

104102
# Prepend the dummy row to the flattened embeddings.
105-
flattened_padded = jnp.concatenate([dummy_row, flattened], axis=0)
103+
flattened_padded = jnp.concatenate([dummy_row, multimodal_embeddings],
104+
axis=0)
106105

107106
# Create gather indices. For each token in the input sequence, this gives
108107
# the index into `flattened_padded`.
@@ -121,7 +120,7 @@ def _merge_multimodal_embeddings(
121120
def merge_multimodal_embeddings(
122121
input_ids: jax.Array,
123122
inputs_embeds: jax.Array,
124-
multimodal_embeddings: NestedTensors,
123+
multimodal_embeddings: jax.Array,
125124
placeholder_token_id: Union[int, list[int]],
126125
) -> jax.Array:
127126
"""

tpu_inference/runner/compilation_manager.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def capture_model(self) -> None:
7474
with self.runner.maybe_setup_dummy_loras(self.runner.lora_config):
7575
self._precompile_backbone_text_only()
7676
if self.runner.is_multimodal_model:
77+
self.runner.precompile_vision_encoder_fn(
78+
self._run_compilation, )
79+
self._precompile_input_embeddings_merger()
7780
self._precompile_backbone_with_inputs_embeds()
7881
if self.runner.scheduler_config.async_scheduling:
7982
self._precompile_substitute_placeholder_token()
@@ -86,6 +89,36 @@ def capture_model(self) -> None:
8689
if self.runner.speculative_config:
8790
self._precompile_speculative_decoding()
8891

92+
def _precompile_input_embeddings_merger(self) -> None:
93+
for num_tokens in self.runner.num_tokens_paddings:
94+
hidden_size = self.runner.vllm_config.model_config.get_hidden_size(
95+
)
96+
sharding = NamedSharding(self.runner.mesh, PartitionSpec())
97+
dummy_multimodal_embeddings = self._create_dummy_tensor(
98+
(num_tokens, hidden_size),
99+
self.runner.vllm_config.model_config.dtype,
100+
sharding=sharding)
101+
dummy_input_ids = self._create_dummy_tensor((num_tokens, ),
102+
jnp.int32)
103+
104+
self._run_compilation(
105+
"input_embeddings_merger",
106+
self.runner.get_input_embeddings_fn,
107+
self.runner.state,
108+
dummy_input_ids,
109+
dummy_multimodal_embeddings,
110+
num_tokens=num_tokens,
111+
)
112+
113+
self._run_compilation(
114+
"input_embeddings_merger_text_only",
115+
self.runner.get_input_embeddings_fn,
116+
self.runner.state,
117+
dummy_input_ids,
118+
None,
119+
num_tokens=num_tokens,
120+
)
121+
89122
def _precompile_backbone_helper(self, name, *, input_ids, positions,
90123
inputs_embeds) -> None:
91124
num_tokens = None

tpu_inference/runner/multimodal_manager.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from vllm.v1.worker.utils import (gather_mm_placeholders,
1010
scatter_mm_placeholders)
1111

12-
from tpu_inference.models.jax.utils.multi_modal_utils import \
13-
sanity_check_mm_encoder_outputs
12+
from tpu_inference.models.jax.utils.multi_modal_utils import (
13+
flatten_embeddings, sanity_check_mm_encoder_outputs)
1414

1515
if TYPE_CHECKING:
1616
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
@@ -158,10 +158,8 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
158158
is_embed=pos_info.is_embed,
159159
)
160160

161-
def gather_mm_embeddings(
162-
self,
163-
scheduler_output: "VllmSchedulerOutput",
164-
) -> list[jax.Array]:
161+
def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput",
162+
target_pad_len: int) -> list[jax.Array]:
165163
mm_embeds: list[jax.Array] = []
166164
for req_id in self.runner.input_batch.req_ids:
167165
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
@@ -205,4 +203,15 @@ def gather_mm_embeddings(
205203
is_embed=is_embed,
206204
)
207205
mm_embeds.append(mm_embeds_item)
208-
return mm_embeds
206+
if not mm_embeds:
207+
return None
208+
flattened_embeds = flatten_embeddings(mm_embeds)
209+
if flattened_embeds.shape[0] == 0:
210+
return None
211+
212+
padding = jnp.zeros((target_pad_len - flattened_embeds.shape[0],
213+
flattened_embeds.shape[1]),
214+
dtype=flattened_embeds.dtype)
215+
flattened_embeds = jnp.concatenate([flattened_embeds, padding], axis=0)
216+
217+
return flattened_embeds

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,22 @@ def _init_inputs(self) -> None:
381381
dtype=np.int64)
382382

383383
def load_model(self):
384-
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, self.get_multimodal_embeddings_fn, self.get_input_embeddings_fn, self.get_mrope_input_positions_fn, self.state, self.lora_manager, self.model = get_model(
384+
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
385385
self.vllm_config,
386386
self.rng_key,
387387
self.mesh,
388388
)
389389

390+
multimodal_fns = multimodal_fns or {}
391+
self.precompile_vision_encoder_fn = multimodal_fns.get(
392+
"precompile_vision_encoder_fn", None)
393+
self.get_multimodal_embeddings_fn = multimodal_fns.get(
394+
"get_multimodal_embeddings_fn", None)
395+
self.get_input_embeddings_fn = multimodal_fns.get(
396+
"get_input_embeddings_fn", None)
397+
self.get_mrope_input_positions_fn = multimodal_fns.get(
398+
"get_mrope_input_positions_fn", None)
399+
390400
if self.drafter is not None:
391401
logger.info("Loading drafter model...")
392402
self.drafter.load_model(self.state)
@@ -529,7 +539,8 @@ def _execute_model(
529539
# Run the multimodal encoder if any.
530540
# We have the modality embeds at this time.
531541
self.mm_manager.execute_mm_encoder(scheduler_output)
532-
mm_embeds = self.mm_manager.gather_mm_embeddings(scheduler_output)
542+
mm_embeds = self.mm_manager.gather_mm_embeddings(
543+
scheduler_output, input_ids.shape[0])
533544
else:
534545
mm_embeds = []
535546

@@ -970,8 +981,8 @@ def _get_input_ids_embeds(self, input_ids: jax.Array,
970981
if self.is_multimodal_model:
971982
inputs_embeds = self.get_input_embeddings_fn(
972983
self.state,
973-
input_ids=input_ids,
974-
multimodal_embeddings=mm_embeds,
984+
input_ids,
985+
mm_embeds,
975986
)
976987
return None, inputs_embeds
977988
else:

tpu_inference/spec_decode/jax/eagle3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949

5050
def load_model(self, target_model: Any) -> None:
5151
"""Loads the draft model."""
52-
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, _, _, self.state, _, _ = get_model(
52+
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
5353
self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
5454
del self.state.model['embed_tokens']
5555
self.state.model.embed_tokens = target_model.model.embed

0 commit comments

Comments
 (0)