
Commit 885a3f2

committed
Code implementation of Async Scheduler
Signed-off-by: cychiuak <[email protected]>
1 parent 0aa5183 commit 885a3f2

File tree

2 files changed: +291, -5 lines


tpu_inference/runner/compilation_manager.py

Lines changed: 37 additions & 0 deletions
@@ -76,6 +76,8 @@ def capture_model(self) -> None:
         self._precompile_backbone_text_only()
         if self.runner.is_multimodal_model:
             self._precompile_backbone_with_inputs_embeds()
+        if self.runner.scheduler_config.async_scheduling:
+            self._precompile_substitute_placeholder_token()
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -149,6 +151,41 @@ def model_fn_wrapper(
                 num_tokens=num_tokens,
             )

+    def _precompile_substitute_placeholder_token(self) -> None:
+        """Precompiles the token substitution function for all expected input shapes.
+
+        It iterates through all potential padded token lengths
+        (`num_tokens_paddings`) and request batch sizes (`num_reqs_paddings`)
+        that the scheduler is expected to handle, ensuring a compiled version
+        is ready for each combination.
+        """
+
+        for num_tokens in self.runner.num_tokens_paddings:
+            padded_token_in_tpu_cur_input_indices = np.zeros((num_tokens, ),
+                                                             dtype=np.int32)
+            padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
+                (num_tokens, ), dtype=jnp.int32)
+            for num_reqs in self.runner.num_reqs_paddings:
+                input_ids = self._create_dummy_tensor((num_tokens, ),
+                                                      jnp.int32)
+                # Needs to align with the sampling output.
+                next_tokens = self._create_dummy_tensor(
+                    (num_reqs, ),
+                    jnp.int32,
+                    sharding=NamedSharding(self.runner.mesh, PartitionSpec()))
+                placeholder_num = 1
+                self._run_compilation(
+                    "_substitute_placeholder_token_fn",
+                    self.runner._substitute_placeholder_token_fn,
+                    input_ids,
+                    padded_token_in_tpu_cur_input_indices,
+                    padded_token_in_tpu_pre_next_tokens_indices,
+                    next_tokens,
+                    placeholder_num,
+                    num_tokens=num_tokens,
+                    num_reqs=num_reqs,
+                )
+
     def _precompile_backbone_text_only(self) -> None:
         for num_tokens in self.runner.num_tokens_paddings:
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
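Note (illustrative, not part of the diff): every (num_tokens, num_reqs) padding combination is warmed up because XLA compiles one executable per input shape, so any bucket skipped here would trigger a JIT compile on the serving path. Below is a minimal sketch of the same warm-up idea with a stand-in function; `scatter_tokens` and the padding lists are assumptions, and the vectorized scatter ignores the placeholder_num cutoff that the real kernel applies.

import jax
import jax.numpy as jnp

@jax.jit
def scatter_tokens(input_ids, cur_idx, src_idx, next_tokens):
    # Stand-in for the substitution kernel: write next_tokens[src_idx]
    # into input_ids at positions cur_idx (applies every index pair).
    return input_ids.at[cur_idx].set(next_tokens[src_idx])

token_paddings = [16, 32, 64]   # assumed num_tokens_paddings
req_paddings = [4, 8]           # assumed num_reqs_paddings

for num_tokens in token_paddings:
    for num_reqs in req_paddings:
        ids = jnp.zeros((num_tokens,), jnp.int32)
        idx = jnp.zeros((num_tokens,), jnp.int32)
        nxt = jnp.zeros((num_reqs,), jnp.int32)
        scatter_tokens(ids, idx, idx, nxt)   # compile once per shape bucket

# Any serving-time call that hits one of the warmed shapes reuses the cache.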

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 254 additions & 5 deletions
@@ -1,7 +1,9 @@
+import copy
 import functools
 import os
 import random
 from contextlib import nullcontext
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast

 import jax
@@ -21,8 +23,8 @@
 from vllm.utils import cdiv
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.kv_connector_model_runner_mixin import \
from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -80,6 +82,91 @@
 }


+class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
+    """Holds asynchronous model output specifically from a TPU runner.
+
+    This class acts as a wrapper around the standard ModelRunnerOutput. Its
+    primary purpose is to hold references to data still on the TPU device
+    (like the `next_tokens` JAX array) without blocking the main thread.
+
+    The `get_output()` method is called to resolve these async results,
+    triggering the JAX device-to-host (CPU) data transfer and populating
+    the final `ModelRunnerOutput` object.
+    """
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        next_tokens: jax.Array,
+        num_reqs: int,
+        discard_sampled_tokens_req_indices: list[int],
+    ):
+        self._model_runner_output = model_runner_output
+        self._next_tokens = next_tokens
+        self._num_reqs = num_reqs
+        self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices
+
+    def get_output(self) -> ModelRunnerOutput:
+        next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens))
+        selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
+                                            1)
+        valid_sampled_token_ids = selected_token_ids.tolist()
+        for i in self._discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+        self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
+        return self._model_runner_output
+
+
+@dataclass
+class AsyncPreResults:
+    req_ids: list[str]
+    next_tokens: jax.Array
+    request_seq_lens: list[tuple[int, CachedRequestState, int]]
+    discard_sampled_tokens_req_indices: list[int]
+    placeholder_req_id_to_index: dict[str, int]
+
+
+@functools.partial(jax.jit, donate_argnums=(0, 1, 2))
+def _substitute_placeholder_token(
+        input_ids: jax.Array, token_in_tpu_cur_input_indices: jax.Array,
+        token_in_tpu_pre_next_tokens_indices: jax.Array,
+        next_tokens: jax.Array, placeholder_num: int):
+    """Substitutes placeholder tokens on the TPU for the async scheduler.
+
+    Args:
+        input_ids: padded input token ids for the current step.
+        token_in_tpu_cur_input_indices: placeholder positions in input_ids.
+            Same length as input_ids.
+        token_in_tpu_pre_next_tokens_indices: source positions in next_tokens.
+            Same length as input_ids.
+        next_tokens: next tokens on the TPU from the previous step.
+        placeholder_num: number of placeholders;
+            placeholder_num <= len(token_in_tpu_cur_input_indices).
+    Returns:
+        input_ids with the placeholder tokens replaced.
+    """
+    assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \
+        f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
+        f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}"
+
+    def updated_input_ids_array(i: int,
+                                current_input_ids: jax.Array) -> jax.Array:
+        """Iteratively updates input_ids for each placeholder.
+
+        Args:
+            i: The current loop index.
+            current_input_ids: The loop carry state (the input_ids being modified).
+
+        Returns:
+            The updated input_ids array.
+        """
+        update_idx = token_in_tpu_cur_input_indices[i]
+        value_idx = token_in_tpu_pre_next_tokens_indices[i]
+        new_token_value = next_tokens[value_idx]
+        return current_input_ids.at[update_idx].set(new_token_value)
+
+    return jax.lax.fori_loop(0, placeholder_num, updated_input_ids_array,
+                             input_ids)
+
+
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

     def __init__(
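A toy invocation of the jitted helper defined above (values are made up; it assumes this module's imports are available). The index arrays are padded to the length of input_ids, and only the first placeholder_num pairs are applied.

import numpy as np
import jax.numpy as jnp

input_ids = jnp.asarray(np.array([101, 102, 0, 103], dtype=np.int32))  # 0 marks the placeholder slot
cur_idx = jnp.asarray(np.array([2, 0, 0, 0], dtype=np.int32))    # padded to len(input_ids)
src_idx = jnp.asarray(np.array([1, 0, 0, 0], dtype=np.int32))    # padded to len(input_ids)
next_tokens = jnp.asarray(np.array([7, 9], dtype=np.int32))      # previous step's samples

out = _substitute_placeholder_token(input_ids, cur_idx, src_idx, next_tokens,
                                    placeholder_num=1)
# out == [101, 102, 9, 103]: position 2 now holds next_tokens[1]; the zero
# padding in cur_idx/src_idx is never read because the loop stops after 1 step.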
@@ -139,6 +226,9 @@ def __init__(
         self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]

+        self._pre_async_results: AsyncPreResults | None = None
+        self._substitute_placeholder_token_fn = _substitute_placeholder_token
+
     def _init_random(self):
         if self.model_config.seed is None:
             self.model_config.seed = 0
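Storing the already-jitted function on the runner means the precompile pass in compilation_manager.py and the serving path in _prepare_inputs call the same wrapper and therefore share one compilation cache. A small sketch of that JAX behavior with a toy function (the names here are assumptions, not repository code):

import jax
import jax.numpy as jnp

@jax.jit
def bump(x):
    return x + 1

bump(jnp.zeros((8,), jnp.int32))   # compiles and caches on this wrapper
bump(jnp.ones((8,), jnp.int32))    # same wrapper + same shape -> cache hit

other = jax.jit(lambda x: x + 1)
other(jnp.zeros((8,), jnp.int32))  # a different wrapper compiles again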
@@ -343,13 +433,88 @@ def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput:
+    ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
+
         return self._execute_model(scheduler_output)[1]

+    def _modify_prev_results(self):
+        # If the copy to host has not finished yet, we just wait; device_get
+        # should return immediately because the transfer was scheduled in the
+        # previous call.
+        assert self._pre_async_results is not None, "When we call _modify_prev_results(), self._pre_async_results should already exist"
+        pre_req_ids = self._pre_async_results.req_ids
+        pre_next_tokens = self._pre_async_results.next_tokens
+        pre_request_seq_lens = self._pre_async_results.request_seq_lens
+        pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices
+
+        next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens))
+        selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
+                                            1)
+        valid_sampled_token_ids = selected_token_ids.tolist()
+
+        # Mask out the sampled tokens that should not be sampled.
+        for i in pre_discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+        # Append sampled tokens.
+        for pre_req_idx, req_state, _ in pre_request_seq_lens:
+            sampled_ids = valid_sampled_token_ids[pre_req_idx]
+            if not sampled_ids:
+                continue
+
+            # If the request is not active in the *current* batch
+            # (e.g. finished or evicted), skip it.
+            req_id = pre_req_ids[pre_req_idx]
+            if req_id not in self.input_batch.req_id_to_index:
+                continue
+
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            assert req_state is self.requests[
+                req_id], "The req_state should be valid and identical"
+
+            # Updated on the previous execute.
+            end_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            assert len(sampled_ids) == 1, "do not support spec decode yet"
+            start_idx = end_idx - 1
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            # Replace the previous placeholder.
+            req_state.output_token_ids[-1] = sampled_ids[-1]
+
+    def _update_placeholder(self, discard_sampled_tokens_req_indices,
+                            request_seq_lens):
+        placeholder_req_id_to_index: dict[str, int] = {}
+        discard_sampled_tokens_req_indices_set = set(
+            discard_sampled_tokens_req_indices)
+        for req_idx, req_state, _ in request_seq_lens:
+            if req_idx in discard_sampled_tokens_req_indices_set:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            # Not supporting spec decode yet, so assume only 1 new token.
+            end_idx = start_idx + 1
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            # The CPU-side tokens are updated at the next execute, when the
+            # inputs are prepared from the TPU.
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+
+            # Placeholder; it will be updated on the next execute.
+            req_state.output_token_ids.extend([0])
+
+            placeholder_req_id_to_index[req_state.req_id] = req_idx
+        return placeholder_req_id_to_index
+
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput]:
+    ) -> tuple[AttentionMetadata, ModelRunnerOutput
+               | AsyncTPUModelRunnerOutput]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
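The two methods above implement one-step-delayed bookkeeping: _update_placeholder reserves a slot for the token still being sampled on the TPU, and _modify_prev_results overwrites that slot on the next call once the transfer has landed. A plain-list sketch of the lifecycle (toy values, no vLLM structures):

output_token_ids = [11, 12]      # tokens produced so far for one request

# Step N (_update_placeholder): the new token is still on the TPU, so append
# a placeholder and remember this request's row in step N's batch.
output_token_ids.append(0)
placeholder_row = 1

# Step N+1 (_modify_prev_results): step N's next_tokens are now on the host,
# so the placeholder is replaced in place.
next_tokens_cpu = [42, 17]       # toy sampled tokens from step N
output_token_ids[-1] = next_tokens_cpu[placeholder_row]
assert output_token_ids == [11, 12, 17]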
@@ -470,7 +635,7 @@ def _execute_model(
         num_reqs = self.input_batch.num_reqs

         # Update the cache state concurrently. Code above will not block until
-        # we use `selected_token_ids`. Add mark_step if post-processing changes
+        # We use `selected_token_ids`. Add mark_step if post-processing changes
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
         discard_sampled_tokens_req_indices = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
@@ -501,6 +666,51 @@ def _execute_model(
         for req_id in self.input_batch.req_ids[:num_reqs]:
             prompt_logprobs_dict[req_id] = None

+        # If the async scheduler is enabled.
+        if self.scheduler_config.async_scheduling:
+            # Get the previous results from the TPU and replace the placeholders.
+            if self._pre_async_results is not None:
+                assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet."
+                self._modify_prev_results()
+
+            # Set placeholders for the next tokens that are not yet generated.
+            placeholder_req_id_to_index: dict[
+                str, int] = self._update_placeholder(
+                    discard_sampled_tokens_req_indices, request_seq_lens)
+
+            if logprobs is not None:
+                logprobs_lists = logprobs.tolists()
+            else:
+                logprobs_lists = None
+
+            # Save the previous results.
+            next_tokens = jax.copy_to_host_async(next_tokens)
+            self._pre_async_results = AsyncPreResults(
+                req_ids=req_ids,
+                next_tokens=next_tokens,
+                request_seq_lens=request_seq_lens,
+                discard_sampled_tokens_req_indices=
+                discard_sampled_tokens_req_indices,
+                placeholder_req_id_to_index=placeholder_req_id_to_index,
+            )
+
+            # Return the model output to the executor.
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=copy.deepcopy(
+                    self.input_batch.req_id_to_index),
+                sampled_token_ids=[],  # Filled in by the async get.
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict=prompt_logprobs_dict,
+                pooler_output=[],
+                kv_connector_output=kv_connector_output,
+            )
+            # Return attn_metadata, model_runner_output.
+            async_model_runner_output = AsyncTPUModelRunnerOutput(
+                model_runner_output, next_tokens, num_reqs,
+                discard_sampled_tokens_req_indices)
+            return attn_metadata, async_model_runner_output
+
         if spec_decode_metadata is None:
             next_tokens = np.asarray(jax.device_get(next_tokens))
             selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
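Under async scheduling the executor receives the wrapper instead of a finished ModelRunnerOutput and calls get_output() only when it actually needs the token ids, which is what triggers the device-to-host copy. A toy resolution of the wrapper is sketched below; the ModelRunnerOutput kwargs mirror the ones used above, all values are made up, and passing kv_connector_output=None is an assumption.

import jax.numpy as jnp

toy_output = ModelRunnerOutput(
    req_ids=["req-0", "req-1"],
    req_id_to_index={"req-0": 0, "req-1": 1},
    sampled_token_ids=[],        # filled in by get_output()
    logprobs=None,
    prompt_logprobs_dict={"req-0": None, "req-1": None},
    pooler_output=[],
    kv_connector_output=None,    # assumption: None is acceptable here
)
wrapper = AsyncTPUModelRunnerOutput(
    toy_output,
    next_tokens=jnp.array([5, 6, 0, 0], dtype=jnp.int32),  # padded to 4 rows
    num_reqs=2,
    discard_sampled_tokens_req_indices=[1],  # e.g. a partial-prefill request
)
resolved = wrapper.get_output()
assert resolved.sampled_token_ids == [[5], []]   # row 1 was discarded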
@@ -596,6 +806,30 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
         # For each scheduled token, what are the corresponding req index.
         req_indices = np.repeat(self.arange_cpu[:num_reqs],
                                 num_scheduled_tokens_per_req)
+        token_in_tpu_cur_input_indices = np.array([])
+        token_in_tpu_pre_next_tokens_indices = np.array([])
+        if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
+            # If previous async results exist, prepare the token substitution
+            # here. The actual substitution is performed on the TPU later in
+            # this function.
+            token_in_tpu_cur_input_indices_list = []
+            token_in_tpu_pre_next_tokens_indices_list = []
+            acc_cur_len = 0
+            for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+                acc_cur_len += num_scheduled_tokens_per_req[i]
+                assert req_id is not None
+                if req_id not in self._pre_async_results.placeholder_req_id_to_index:
+                    continue
+
+                token_in_tpu_cur_input_indices_list.append(acc_cur_len - 1)
+                token_in_tpu_pre_next_tokens_indices_list.append(
+                    self._pre_async_results.placeholder_req_id_to_index[req_id]
+                )
+
+            if len(token_in_tpu_cur_input_indices_list) > 0:
+                token_in_tpu_cur_input_indices = np.array(
+                    token_in_tpu_cur_input_indices_list)
+                token_in_tpu_pre_next_tokens_indices = np.array(
+                    token_in_tpu_pre_next_tokens_indices_list)

         # Get batched arange.
         # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
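The placeholder of a request is always the last of its scheduled tokens in the flattened input, so the loop above is effectively a running sum minus one. A NumPy toy with made-up request ids and row numbers:

import numpy as np

num_scheduled_tokens_per_req = np.array([2, 5, 3])            # three requests
last_token_pos = np.cumsum(num_scheduled_tokens_per_req) - 1  # [1, 6, 9]

# Suppose requests 0 and 2 carry placeholders that sat in rows 0 and 3 of the
# previous step's batch:
token_in_tpu_cur_input_indices = last_token_pos[[0, 2]]       # [1, 9]
token_in_tpu_pre_next_tokens_indices = np.array([0, 3])       # rows in next_tokens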
@@ -701,6 +935,21 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
             self.mesh, (input_ids, positions, block_tables, query_start_loc,
                         seq_lens, logits_indices, request_distribution))

+        if self.scheduler_config.async_scheduling and len(
+                token_in_tpu_cur_input_indices) > 0:
+            assert self._pre_async_results is not None
+            idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
+            padded_token_in_tpu_cur_input_indices = np.pad(
+                token_in_tpu_cur_input_indices, (0, idx_pad_len))
+            padded_token_in_tpu_pre_next_tokens_indices = np.pad(
+                token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len))
+            with self.maybe_forbid_compile:
+                input_ids = self._substitute_placeholder_token_fn(
+                    input_ids, padded_token_in_tpu_cur_input_indices,
+                    padded_token_in_tpu_pre_next_tokens_indices,
+                    self._pre_async_results.next_tokens,
+                    len(token_in_tpu_cur_input_indices))
+
         if self.lora_config is not None:
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
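Padding the index arrays to len(input_ids) keeps the call on the shapes that were precompiled, and the zero-filled tail is harmless because the kernel only iterates placeholder_num times. A small NumPy sketch with assumed sizes:

import numpy as np

padded_input_len = 16                                   # assumed len(input_ids) after padding
token_in_tpu_cur_input_indices = np.array([1, 9], dtype=np.int32)
idx_pad_len = padded_input_len - len(token_in_tpu_cur_input_indices)   # 14
padded = np.pad(token_in_tpu_cur_input_indices, (0, idx_pad_len))
# padded has shape (16,) with zeros in positions 2..15; placeholder_num == 2,
# so the fori_loop never reads those zeros and input_ids[0] is left untouched.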
