Commit 04609e6

Code implementation of Async Scheduler
1 parent a523939 commit 04609e6

2 files changed: +272 additions, -4 deletions

tpu_inference/runner/compilation_manager.py

Lines changed: 34 additions & 0 deletions
@@ -76,6 +76,8 @@ def capture_model(self) -> None:
         self._precompile_backbone_text_only()
         if self.runner.is_multimodal_model:
             self._precompile_backbone_with_inputs_embeds()
+        if self.runner.scheduler_config.async_scheduling:
+            self._precompile_substitute_placeholder_token()
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -148,6 +150,38 @@ def model_fn_wrapper(
             lora_metadata,
             num_tokens=num_tokens,
         )
+
+    def _precompile_substitute_placeholder_token(self) -> None:
+        """Precompiles the token substitution function for all expected input shapes.
+
+        It iterates through all potential padded token lengths
+        (`num_tokens_paddings`) and request batch sizes (`num_reqs_paddings`)
+        that the scheduler is expected to handle, ensuring a compiled version
+        is ready for each combination.
+        """
+        for num_tokens in self.runner.num_tokens_paddings:
+            padded_token_in_tpu_cur_input_indices = np.zeros((num_tokens, ), dtype=np.int32)
+            padded_token_in_tpu_pre_next_tokens_indices = np.zeros((num_tokens, ), dtype=np.int32)
+            for num_reqs in self.runner.num_reqs_paddings:
+                input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
+                # Needs to align with the sampling output.
+                next_tokens = self._create_dummy_tensor(
+                    (num_reqs, ),
+                    jnp.int32,
+                    sharding=NamedSharding(self.runner.mesh, PartitionSpec()))
+                placeholder_num = 1
+                self._run_compilation(
+                    "_substitute_placeholder_token_fn",
+                    self.runner._substitute_placeholder_token_fn,
+                    input_ids,
+                    padded_token_in_tpu_cur_input_indices,
+                    padded_token_in_tpu_pre_next_tokens_indices,
+                    next_tokens,
+                    placeholder_num,
+                    num_tokens=num_tokens,
+                    num_reqs=num_reqs,
+                )
 
     def _precompile_backbone_text_only(self) -> None:
         for num_tokens in self.runner.num_tokens_paddings:
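
The warm-up loop above compiles one variant of the substitution kernel per (num_tokens, num_reqs) padding pair so that no compilation happens on the serving path. A minimal, standalone sketch of that pattern (the bucket sizes and the scatter_first_token helper are illustrative stand-ins, not code from this change):

import jax
import jax.numpy as jnp

@jax.jit
def scatter_first_token(input_ids, next_tokens):
    # Toy stand-in for the real substitution kernel: copy the first sampled
    # token into the first input slot.
    return input_ids.at[0].set(next_tokens[0])

token_paddings = [16, 32, 64]   # hypothetical num_tokens buckets
req_paddings = [4, 8]           # hypothetical num_reqs buckets

for num_tokens in token_paddings:
    for num_reqs in req_paddings:
        dummy_ids = jnp.zeros((num_tokens,), dtype=jnp.int32)
        dummy_next = jnp.zeros((num_reqs,), dtype=jnp.int32)
        # Each distinct shape pair triggers one compilation; later calls with
        # the same shapes are cache hits.
        scatter_first_token(dummy_ids, dummy_next).block_until_ready()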

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 238 additions & 4 deletions
@@ -3,6 +3,8 @@
 import random
 from contextlib import nullcontext
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from dataclasses import dataclass
+import copy
 
 import jax
 import jax.numpy as jnp
@@ -22,7 +24,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             ModelRunnerOutput)
+                             ModelRunnerOutput, AsyncModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -79,6 +81,88 @@
     "uint8": torch.uint8,
 }
 
+class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
+    """Holds asynchronous model output specifically from a TPU runner.
+
+    This class acts as a wrapper around the standard ModelRunnerOutput. Its
+    primary purpose is to hold references to data still on the TPU device
+    (like the `next_tokens` JAX array) without blocking the main thread.
+
+    The `get_output()` method is called to resolve these async results,
+    triggering the JAX device-to-host (CPU) data transfer and populating
+    the final `ModelRunnerOutput` object.
+    """
+
+    def __init__(self,
+                 model_runner_output: ModelRunnerOutput,
+                 next_tokens: jax.Array,
+                 num_reqs: int,
+                 discard_sampled_tokens_req_indices: list[int],
+                 ):
+        self._model_runner_output = model_runner_output
+        self._next_tokens = next_tokens
+        self._num_reqs = num_reqs
+        self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices
+
+    def get_output(self) -> ModelRunnerOutput:
+        next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens))
+        selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs], 1)
+        valid_sampled_token_ids = selected_token_ids.tolist()
+        for i in self._discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+        self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
+        return self._model_runner_output
+
+
+@dataclass
+class AsyncPreResults:
+    req_ids: list[str]
+    next_tokens: jax.Array
+    request_seq_lens: list[tuple[int, CachedRequestState, int]]
+    discard_sampled_tokens_req_indices: list[int]
+    placeholder_req_id_to_index: dict[str, int]
+
+
+@functools.partial(jax.jit, donate_argnums=(0, 1, 2))
+def _substitute_placeholder_token(input_ids: jax.Array,
+                                  token_in_tpu_cur_input_indices: jax.Array,
+                                  token_in_tpu_pre_next_tokens_indices: jax.Array,
+                                  next_tokens: jax.Array,
+                                  placeholder_num: int):
+    """Substitutes placeholder tokens on the TPU for the async scheduler.
+
+    Args:
+        input_ids: the (padded) input token ids.
+        token_in_tpu_cur_input_indices: indices in `input_ids` whose placeholders
+            should be replaced. Same length as `input_ids`.
+        token_in_tpu_pre_next_tokens_indices: indices into `next_tokens` that hold
+            the replacement values. Same length as `input_ids`.
+        next_tokens: next tokens on the TPU from the previous step.
+        placeholder_num: number of placeholders;
+            placeholder_num <= len(token_in_tpu_cur_input_indices).
+    Returns:
+        input_ids after the placeholder tokens have been replaced.
+    """
+    assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \
+        "Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
+        f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}"
+
+    def updated_input_ids_array(i: int, current_input_ids: jax.Array) -> jax.Array:
+        """Iteratively updates the input_ids for all placeholders.
+
+        Args:
+            i: The current loop index.
+            current_input_ids: The loop carry state (the input_ids being modified).
+
+        Returns:
+            The updated input_ids array.
+        """
+        update_idx = token_in_tpu_cur_input_indices[i]
+        value_idx = token_in_tpu_pre_next_tokens_indices[i]
+        new_token_value = next_tokens[value_idx]
+        return current_input_ids.at[update_idx].set(new_token_value)
+
+    return jax.lax.fori_loop(0, placeholder_num, updated_input_ids_array,
+                             input_ids)
 
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
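
To make the fori_loop semantics concrete, here is a standalone toy rendering of the same substitution logic with made-up values (it re-states the function above rather than importing it). Only the first placeholder_num index pairs are applied, so the zero padding at the tail of the index arrays never touches other slots:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=(4,))
def substitute(input_ids, cur_indices, prev_indices, next_tokens, placeholder_num):
    def body(i, ids):
        # Write next_tokens[prev_indices[i]] into ids[cur_indices[i]].
        return ids.at[cur_indices[i]].set(next_tokens[prev_indices[i]])
    return jax.lax.fori_loop(0, placeholder_num, body, input_ids)

input_ids = jnp.array([11, 0, 12, 0, 13, 0, 0, 0], dtype=jnp.int32)  # 0s at 1 and 3 are placeholders
cur_indices = jnp.array([1, 3, 0, 0, 0, 0, 0, 0], dtype=jnp.int32)   # padded to len(input_ids)
prev_indices = jnp.array([0, 1, 0, 0, 0, 0, 0, 0], dtype=jnp.int32)
next_tokens = jnp.array([101, 102, 103, 104], dtype=jnp.int32)       # previous step's samples
print(substitute(input_ids, cur_indices, prev_indices, next_tokens, 2))
# [ 11 101  12 102  13   0   0   0]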
@@ -139,6 +223,9 @@ def __init__(
         self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
 
+        self._pre_async_results: AsyncPreResults | None = None
+        self._substitute_placeholder_token_fn = _substitute_placeholder_token
+
     def _init_random(self):
         if self.model_config.seed is None:
             self.model_config.seed = 0
@@ -343,13 +430,86 @@ def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput:
+    ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
+
         return self._execute_model(scheduler_output)[1]
 
+    def _modify_prev_results(self):
+        # If the copy to host has not completed yet, we simply wait here;
+        # device_get should return immediately because the copy was scheduled
+        # during the previous call.
+        assert self._pre_async_results is not None, "When we call _modify_prev_results(), self._pre_async_results should already exist"
+        pre_req_ids = self._pre_async_results.req_ids
+        pre_next_tokens = self._pre_async_results.next_tokens
+        pre_request_seq_lens = self._pre_async_results.request_seq_lens
+        pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices
+
+        next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens))
+        selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)], 1)
+        valid_sampled_token_ids = selected_token_ids.tolist()
+
+        # Mask out the sampled tokens that should not be sampled.
+        for i in pre_discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+        # Append the sampled tokens.
+        for pre_req_idx, req_state, _ in pre_request_seq_lens:
+            sampled_ids = valid_sampled_token_ids[pre_req_idx]
+            if not sampled_ids:
+                continue
+
+            # If the request is not active in the *current* batch
+            # (e.g. finished or evicted), skip it.
+            req_id = pre_req_ids[pre_req_idx]
+            if req_id not in self.input_batch.req_id_to_index:
+                continue
+
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            assert req_state is self.requests[req_id], "The req_state should be valid and identical"
+
+            # Updated during the previous execute.
+            end_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            assert len(sampled_ids) == 1, "Speculative decoding is not supported yet"
+            start_idx = end_idx - 1
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            # Replace the previous placeholder.
+            req_state.output_token_ids[-1] = sampled_ids[-1]
+
+    def _update_placeholder(self, discard_sampled_tokens_req_indices,
+                            request_seq_lens):
+        placeholder_req_id_to_index: dict[str, int] = {}
+        discard_sampled_tokens_req_indices_set = set(discard_sampled_tokens_req_indices)
+        for req_idx, req_state, _ in request_seq_lens:
+            if req_idx in discard_sampled_tokens_req_indices_set:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            # Spec decode is not supported yet, so assume only one new token.
+            end_idx = start_idx + 1
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            # Update the CPU tokens at the next execute and prepare the input from the TPU.
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+
+            # Placeholder; it will be updated on the next execute.
+            req_state.output_token_ids.extend([0])
+
+            placeholder_req_id_to_index[req_state.req_id] = req_idx
+        return placeholder_req_id_to_index
+
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput]:
+    ) -> tuple[AttentionMetadata, ModelRunnerOutput | AsyncTPUModelRunnerOutput]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
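
The two helpers above implement a one-step-delayed bookkeeping scheme: _update_placeholder advances the per-request token counts and records a 0 placeholder while the real sample is still on the device, and _modify_prev_results overwrites that placeholder on the next call once the device-to-host copy has landed. A simplified, hypothetical illustration with plain Python state rather than the runner's real data structures:

import numpy as np

# One request, tracked the way the input batch tracks it (hypothetical sizes).
max_model_len = 8
token_ids_cpu = np.zeros((1, max_model_len), dtype=np.int32)
output_token_ids: list[int] = []
num_tokens_no_spec = 3          # the prompt already occupies 3 slots

# Step N (_update_placeholder): the sampled token is still on the TPU, so the
# counters advance by one token and a 0 placeholder is recorded.
end_idx = num_tokens_no_spec + 1
output_token_ids.append(0)      # placeholder
num_tokens_no_spec = end_idx

# Step N+1 (_modify_prev_results): the device-to-host copy has landed, so the
# placeholder is replaced by the real sample from the previous step.
sampled_ids = [42]              # made-up token id fetched from the TPU
token_ids_cpu[0, end_idx - 1:end_idx] = sampled_ids
output_token_ids[-1] = sampled_ids[-1]
print(output_token_ids)         # [42]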
@@ -470,7 +630,7 @@ def _execute_model(
         num_reqs = self.input_batch.num_reqs
 
         # Update the cache state concurrently. Code above will not block until
-        # we use `selected_token_ids`. Add mark_step if post-processing changes
+        # We use `selected_token_ids`. Add mark_step if post-processing changes
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
         discard_sampled_tokens_req_indices = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
@@ -501,6 +661,48 @@ def _execute_model(
         for req_id in self.input_batch.req_ids[:num_reqs]:
             prompt_logprobs_dict[req_id] = None
 
+        # If the async scheduler is enabled.
+        if self.scheduler_config.async_scheduling:
+            # Get the previous results from the TPU and replace the placeholders.
+            if self._pre_async_results is not None:
+                assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet."
+                self._modify_prev_results()
+
+            # Set placeholders for the next tokens that are not yet generated.
+            placeholder_req_id_to_index: dict[str, int] = self._update_placeholder(
+                discard_sampled_tokens_req_indices, request_seq_lens)
+
+            if logprobs is not None:
+                logprobs_lists = logprobs.tolists()
+            else:
+                logprobs_lists = None
+
+            # Save the previous results.
+            next_tokens = jax.copy_to_host_async(next_tokens)
+            self._pre_async_results = AsyncPreResults(
+                req_ids=req_ids,
+                next_tokens=next_tokens,
+                request_seq_lens=request_seq_lens,
+                discard_sampled_tokens_req_indices=discard_sampled_tokens_req_indices,
+                placeholder_req_id_to_index=placeholder_req_id_to_index,
+            )
+
+            # Return the model output to the executor.
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=copy.deepcopy(self.input_batch.req_id_to_index),
+                sampled_token_ids=[],  # Filled in by the async get
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict=prompt_logprobs_dict,
+                pooler_output=[],
+                kv_connector_output=kv_connector_output,
+            )
+            # Return attn_metadata, model_runner_output.
+            async_model_runner_output = AsyncTPUModelRunnerOutput(
+                model_runner_output, next_tokens, num_reqs,
+                discard_sampled_tokens_req_indices)
+            return attn_metadata, async_model_runner_output
+
         if spec_decode_metadata is None:
             next_tokens = np.asarray(jax.device_get(next_tokens))
             selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
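
When this branch is taken, the executor receives an AsyncTPUModelRunnerOutput whose sampled_token_ids list is still empty; the ids only materialize once get_output() is called later. A minimal sketch of the deferred device-to-host pattern that get_output() relies on, using made-up data rather than the runner's buffers:

import jax
import jax.numpy as jnp
import numpy as np

next_tokens = jnp.arange(8, dtype=jnp.int32)   # stand-in for the sampled tokens
next_tokens.copy_to_host_async()               # schedule the copy without blocking

# ... the runner would return the async wrapper here and keep going ...

num_reqs = 5
next_tokens_cpu = np.asarray(jax.device_get(next_tokens))   # cheap: copy already in flight
sampled_token_ids = np.expand_dims(next_tokens_cpu[:num_reqs], 1).tolist()
print(sampled_token_ids)  # [[0], [1], [2], [3], [4]]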
@@ -596,6 +798,26 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
         # For each scheduled token, what are the corresponding req index.
         req_indices = np.repeat(self.arange_cpu[:num_reqs],
                                 num_scheduled_tokens_per_req)
+        token_in_tpu_cur_input_indices = np.array([])
+        token_in_tpu_pre_next_tokens_indices = np.array([])
+        if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
+            # If previous async results exist, prepare the token substitution here.
+            # The actual substitution is performed on the TPU later in this function.
+            token_in_tpu_cur_input_indices_list = []
+            token_in_tpu_pre_next_tokens_indices_list = []
+            acc_cur_len = 0
+            for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+                acc_cur_len += num_scheduled_tokens_per_req[i]
+                assert req_id is not None
+                if req_id not in self._pre_async_results.placeholder_req_id_to_index:
+                    continue
+
+                token_in_tpu_cur_input_indices_list.append(acc_cur_len - 1)
+                token_in_tpu_pre_next_tokens_indices_list.append(
+                    self._pre_async_results.placeholder_req_id_to_index[req_id])
+
+            if len(token_in_tpu_cur_input_indices_list) > 0:
+                token_in_tpu_cur_input_indices = np.array(token_in_tpu_cur_input_indices_list)
+                token_in_tpu_pre_next_tokens_indices = np.array(token_in_tpu_pre_next_tokens_indices_list)
 
         # Get batched arange.
         # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
@@ -701,6 +923,18 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
             self.mesh, (input_ids, positions, block_tables, query_start_loc,
                         seq_lens, logits_indices, request_distribution))
 
+        if self.scheduler_config.async_scheduling and len(token_in_tpu_cur_input_indices) > 0:
+            assert self._pre_async_results is not None
+            idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
+            padded_token_in_tpu_cur_input_indices = np.pad(token_in_tpu_cur_input_indices, (0, idx_pad_len))
+            padded_token_in_tpu_pre_next_tokens_indices = np.pad(token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len))
+            with self.maybe_forbid_compile:
+                input_ids = self._substitute_placeholder_token_fn(
+                    input_ids,
+                    padded_token_in_tpu_cur_input_indices,
+                    padded_token_in_tpu_pre_next_tokens_indices,
+                    self._pre_async_results.next_tokens,
+                    len(token_in_tpu_cur_input_indices))
+
         if self.lora_config is not None:
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,

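Note that the index arrays are zero-padded up to len(input_ids) before the call, so the jitted substitution always sees one of the precompiled shapes, while the trailing placeholder_num argument keeps the fori_loop from reading the padded tail. A small sketch of that shape bookkeeping with made-up sizes:

import numpy as np

padded_num_tokens = 16                            # hypothetical padding bucket
cur_indices = np.array([3, 7], dtype=np.int32)    # two real placeholder slots
prev_indices = np.array([0, 1], dtype=np.int32)

idx_pad_len = padded_num_tokens - len(cur_indices)
padded_cur = np.pad(cur_indices, (0, idx_pad_len))     # trailing zeros
padded_prev = np.pad(prev_indices, (0, idx_pad_len))
placeholder_num = len(cur_indices)                     # loop bound stays 2

assert padded_cur.shape == padded_prev.shape == (padded_num_tokens,)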