|
3 | 3 | import random |
4 | 4 | from contextlib import nullcontext |
5 | 5 | from typing import Any, Callable, Dict, List, Optional, Tuple, cast |
| 6 | +from dataclasses import dataclass |
| 7 | +import copy |
6 | 8 |
|
7 | 9 | import jax |
8 | 10 | import jax.numpy as jnp |
|
22 | 24 | from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput |
23 | 25 | from vllm.v1.kv_cache_interface import KVCacheConfig |
24 | 26 | from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, |
25 | | - ModelRunnerOutput) |
| 27 | + ModelRunnerOutput, AsyncModelRunnerOutput) |
26 | 28 | from vllm.v1.request import Request |
27 | 29 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
28 | 30 | from vllm.v1.worker.kv_connector_model_runner_mixin import \ |
|
79 | 81 | "uint8": torch.uint8, |
80 | 82 | } |
81 | 83 |
|
| 84 | +class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput): |
| 85 | + """Holds asynchronous model output specifically from a TPU runner. |
| 86 | +
|
| 87 | + This class acts as a wrapper around the standard ModelRunnerOutput. Its |
| 88 | + primary purpose is to hold references to data still on the TPU device |
| 89 | + (like the `next_tokens` JAX array) without blocking the main thread. |
| 90 | +
|
| 91 | + The `get_output()` method is called to resolve these async results, |
| 92 | + triggering the JAX device-to-host (CPU) data transfer and populating |
| 93 | + the final `ModelRunnerOutput` object. |
| 94 | + """ |
| 95 | + |
| 96 | + def __init__(self, |
| 97 | + model_runner_output: ModelRunnerOutput, |
| 98 | + next_tokens: jax.Array, |
| 99 | + num_reqs: int, |
| 100 | + discard_sampled_tokens_req_indices: list[int], |
| 101 | + ): |
| 102 | + self._model_runner_output = model_runner_output |
| 103 | + self._next_tokens = next_tokens |
| 104 | + self._num_reqs = num_reqs |
| 105 | + self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices |
| 106 | + |
| 107 | + |
| 108 | + def get_output(self) -> ModelRunnerOutput: |
| 109 | + next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens)) |
| 110 | + selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs], 1) |
| 111 | + valid_sampled_token_ids = selected_token_ids.tolist() |
| 112 | + for i in self._discard_sampled_tokens_req_indices: |
| 113 | + valid_sampled_token_ids[i].clear() |
| 114 | + self._model_runner_output.sampled_token_ids = valid_sampled_token_ids |
| 115 | + return self._model_runner_output |
| 116 | + |
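As a quick illustration of the resolution path described in the docstring above, here is a minimal, self-contained sketch of how a consumer could resolve the wrapper. The call site and the `_FakeOutput` stand-in are assumptions for illustration only; they are not part of this change.

```python
# Hypothetical usage sketch; _FakeOutput stands in for ModelRunnerOutput and
# AsyncTPUModelRunnerOutput is assumed to be importable from this module.
import jax.numpy as jnp

class _FakeOutput:
    sampled_token_ids = None

wrapped = AsyncTPUModelRunnerOutput(
    model_runner_output=_FakeOutput(),
    next_tokens=jnp.array([11, 22, 33, 44]),   # still resident on the device
    num_reqs=3,
    discard_sampled_tokens_req_indices=[1],    # request 1 should be discarded
)
resolved = wrapped.get_output()                # triggers the device-to-host copy
print(resolved.sampled_token_ids)              # [[11], [], [33]]
```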
| 117 | +@dataclass |
| 118 | +class AsyncPreResults: |
| 119 | + req_ids: list[str] |
| 120 | + next_tokens: jax.Array |
| 121 | + request_seq_lens: list[tuple[int, CachedRequestState, int]] |
| 122 | + discard_sampled_tokens_req_indices: list[int] |
| 123 | + placeholder_req_id_to_index: dict[str, int] |
| 124 | + |
| 125 | + |
| 126 | +@functools.partial(jax.jit, donate_argnums=(0, 1, 2)) |
| 127 | +def _substitute_placeholder_token(input_ids: jax.Array, |
| 128 | + token_in_tpu_cur_input_indices: jax.Array, |
| 129 | + token_in_tpu_pre_next_tokens_indices: jax.Array, |
| 130 | + next_tokens: jax.Array, |
| 131 | + placeholder_num: int): |
| 132 | + """Substitute placeholder tokens from TPU for async scheduler |
| 133 | +
|
| 134 | + Args: |
| 135 | +        input_ids: padded input token IDs for the current step. |
| 136 | +        token_in_tpu_cur_input_indices: indices in input_ids whose placeholder tokens should be replaced. Padded to the same length as input_ids. |
| 137 | +        token_in_tpu_pre_next_tokens_indices: for each placeholder, the index into next_tokens holding its replacement value. Padded to the same length as input_ids. |
| 138 | +        next_tokens: tokens sampled on the TPU in the previous step. |
| 139 | +        placeholder_num: number of real (unpadded) placeholders; placeholder_num <= len(token_in_tpu_cur_input_indices). |
| 140 | +    Returns: |
| 141 | +        input_ids with the placeholder tokens replaced. |
| 142 | + """ |
| 143 | + assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \ |
| 144 | + f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \ |
| 145 | + f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}" |
| 146 | + |
| 147 | + def updated_input_ids_array(i: int, current_input_ids: jax.Array) -> jax.Array: |
| 148 | + """ |
| 149 | + Iteratively updates the input_ids for all placeholders. |
| 150 | +
| 151 | + Args: |
| 152 | + i: The current loop index. |
| 153 | + current_input_ids: The loop carry state (the input_ids being modified). |
| 154 | +
|
| 155 | + Returns: |
| 156 | + The updated input_ids array. |
| 157 | + """ |
| 158 | + update_idx = token_in_tpu_cur_input_indices[i] |
| 159 | + value_idx = token_in_tpu_pre_next_tokens_indices[i] |
| 160 | + new_token_value = next_tokens[value_idx] |
| 161 | + return current_input_ids.at[update_idx].set(new_token_value) |
| 162 | + |
| 163 | + return jax.lax.fori_loop(0, placeholder_num, |
| 164 | + updated_input_ids_array, |
| 165 | + input_ids) |
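To make the padded-index convention in the docstring concrete, here is a small self-contained sketch of the substitution semantics; the numeric values are illustrative only.

```python
# Sketch of _substitute_placeholder_token with made-up values. The index
# arrays are padded to len(input_ids); only the first placeholder_num entries
# are applied by the fori_loop.
import jax.numpy as jnp

input_ids   = jnp.array([101, 0, 102, 0, 103, 0])  # placeholders at indices 1 and 3
cur_idx     = jnp.array([1, 3, 0, 0, 0, 0])        # padded to len(input_ids)
val_idx     = jnp.array([0, 2, 0, 0, 0, 0])        # padded to len(input_ids)
next_tokens = jnp.array([555, 666, 777])           # previous step's sampled tokens

out = _substitute_placeholder_token(input_ids, cur_idx, val_idx, next_tokens, 2)
print(out)  # expected: [101 555 102 777 103 0]
```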
82 | 166 |
|
83 | 167 | class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin): |
84 | 168 |
|
@@ -139,6 +223,9 @@ def __init__( |
139 | 223 | self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[ |
140 | 224 | cache_config.cache_dtype] |
141 | 225 |
|
| 226 | + self._pre_async_results: AsyncPreResults | None = None |
| 227 | + self._substitute_placeholder_token_fn = _substitute_placeholder_token |
| 228 | + |
142 | 229 | def _init_random(self): |
143 | 230 | if self.model_config.seed is None: |
144 | 231 | self.model_config.seed = 0 |
@@ -343,13 +430,86 @@ def execute_model( |
343 | 430 | self, |
344 | 431 | scheduler_output: "VllmSchedulerOutput", |
345 | 432 | intermediate_tensors: Optional[IntermediateTensors] = None, |
346 | | - ) -> ModelRunnerOutput: |
| 433 | + ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput: |
| 434 | + |
347 | 435 | return self._execute_model(scheduler_output)[1] |
348 | 436 |
|
| 437 | + def _modify_prev_results(self): |
| 438 | +        # The device-to-host copy was scheduled in the previous step, so |
| 439 | +        # device_get should return almost immediately; otherwise we simply wait here. |
| 440 | +        assert self._pre_async_results is not None, "_modify_prev_results() requires self._pre_async_results to exist" |
| 441 | + pre_req_ids = self._pre_async_results.req_ids |
| 442 | + pre_next_tokens = self._pre_async_results.next_tokens |
| 443 | + pre_request_seq_lens = self._pre_async_results.request_seq_lens |
| 444 | + pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices |
| 445 | + |
| 446 | + next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens)) |
| 447 | + selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)], 1) |
| 448 | + valid_sampled_token_ids = selected_token_ids.tolist() |
| 449 | + |
| 450 | + |
| 451 | +        # Drop sampled tokens for requests whose outputs should be discarded. |
| 452 | + for i in pre_discard_sampled_tokens_req_indices: |
| 453 | + valid_sampled_token_ids[i].clear() |
| 454 | +        # Write the sampled tokens back and replace the placeholders. |
| 455 | + for pre_req_idx, req_state, _ in pre_request_seq_lens: |
| 456 | + sampled_ids = valid_sampled_token_ids[pre_req_idx] |
| 457 | + if not sampled_ids: |
| 458 | + continue |
| 459 | + |
| 460 | + # If request not active in the *current* batch (e.g. finished or evicted), skip it. |
| 461 | + req_id = pre_req_ids[pre_req_idx] |
| 462 | + if req_id not in self.input_batch.req_id_to_index: |
| 463 | + continue |
| 464 | + |
| 465 | + |
| 466 | + req_idx = self.input_batch.req_id_to_index[req_id] |
| 467 | +            assert req_state is self.requests[req_id], "req_state must be the same object as self.requests[req_id]" |
| 468 | + |
| 469 | +            # num_tokens_no_spec was already advanced for this request in the previous execute step. |
| 470 | + end_idx = self.input_batch.num_tokens_no_spec[req_idx] |
| 471 | +            assert len(sampled_ids) == 1, "spec decode is not supported yet" |
| 472 | + start_idx = end_idx - 1 |
| 473 | + assert end_idx <= self.max_model_len, ( |
| 474 | + "Sampled token IDs exceed the max model length. " |
| 475 | + f"Total number of tokens: {end_idx} > max_model_len: " |
| 476 | + f"{self.max_model_len}") |
| 477 | + |
| 478 | + self.input_batch.token_ids_cpu[req_idx, |
| 479 | + start_idx:end_idx] = sampled_ids |
| 480 | + # Replace previous placeholder |
| 481 | + req_state.output_token_ids[-1] = sampled_ids[-1] |
| 482 | + |
| 483 | +    def _update_placeholder(self, discard_sampled_tokens_req_indices: list[int], request_seq_lens: list[tuple[int, CachedRequestState, int]]) -> dict[str, int]: |
| 484 | +        placeholder_req_id_to_index: dict[str, int] = {} |
| 485 | + discard_sampled_tokens_req_indices_set = set(discard_sampled_tokens_req_indices) |
| 486 | + for req_idx, req_state, _ in request_seq_lens: |
| 487 | + if req_idx in discard_sampled_tokens_req_indices_set: |
| 488 | + continue |
| 489 | + |
| 490 | + start_idx = self.input_batch.num_tokens_no_spec[req_idx] |
| 491 | + # Not supporting spec decode yet, assume only 1 new token |
| 492 | + end_idx = start_idx + 1 |
| 493 | + assert end_idx <= self.max_model_len, ( |
| 494 | + "Sampled token IDs exceed the max model length. " |
| 495 | + f"Total number of tokens: {end_idx} > max_model_len: " |
| 496 | + f"{self.max_model_len}") |
| 497 | + |
| 498 | +            # These counts are used when preparing inputs for the next execute step; the actual CPU token values are written back then. |
| 499 | + self.input_batch.num_tokens_no_spec[req_idx] = end_idx |
| 500 | + self.input_batch.num_tokens[req_idx] = end_idx |
| 501 | + |
| 502 | +            # Append a placeholder token; it is overwritten with the real sampled token on the next execute step. |
| 503 | +            req_state.output_token_ids.append(0) |
| 504 | + |
| 505 | + placeholder_req_id_to_index[req_state.req_id] = req_idx |
| 506 | + return placeholder_req_id_to_index |
| 507 | + |
| 508 | + |
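Taken together, `_update_placeholder` and `_modify_prev_results` implement a two-step lifecycle for each sampled token. A toy, self-contained simulation of the CPU-side bookkeeping (plain Python lists only; none of the runner classes are used here):

```python
# Step N: _update_placeholder appends a 0 placeholder for the token that has
# not yet been copied off the TPU.
output_token_ids = [7]          # tokens emitted so far for one request
output_token_ids.append(0)      # placeholder for the in-flight token

# Step N+1: _modify_prev_results overwrites the placeholder with the real token
# once the async device-to-host copy has landed.
sampled_ids = [42]              # arrived from the TPU
output_token_ids[-1] = sampled_ids[-1]

print(output_token_ids)         # [7, 42]
```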
349 | 509 | def _execute_model( |
350 | 510 | self, |
351 | 511 | scheduler_output: "VllmSchedulerOutput", |
352 | | - ) -> tuple[AttentionMetadata, ModelRunnerOutput]: |
| 512 | + ) -> tuple[AttentionMetadata, ModelRunnerOutput | AsyncTPUModelRunnerOutput]: |
353 | 513 | self.persistent_batch_manager.update_states( |
354 | 514 | scheduler_output, self.get_mrope_input_positions_fn) |
355 | 515 | if not scheduler_output.total_num_scheduled_tokens: |
@@ -470,7 +630,7 @@ def _execute_model( |
470 | 630 | num_reqs = self.input_batch.num_reqs |
471 | 631 |
|
472 | 632 | # Update the cache state concurrently. Code above will not block until |
473 | 633 | # we use `selected_token_ids`. Add mark_step if post-processing changes |
474 | 634 | request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] |
475 | 635 | discard_sampled_tokens_req_indices = [] |
476 | 636 | for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): |
@@ -501,6 +661,48 @@ def _execute_model( |
501 | 661 | for req_id in self.input_batch.req_ids[:num_reqs]: |
502 | 662 | prompt_logprobs_dict[req_id] = None |
503 | 663 |
|
| 664 | + # If async scheduler enabled |
| 665 | + if self.scheduler_config.async_scheduling: |
| 666 | + # Get previous results from TPU and replace the placeholder. |
| 667 | + if self._pre_async_results is not None: |
| 668 | + assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet." |
| 669 | + self._modify_prev_results() |
| 670 | + |
| 671 | +            # Insert placeholders for the next tokens, which have not been generated yet. |
| 672 | +            placeholder_req_id_to_index: dict[str, int] = self._update_placeholder(discard_sampled_tokens_req_indices, request_seq_lens) |
| 673 | + |
| 674 | + if logprobs is not None: |
| 675 | + logprobs_lists = logprobs.tolists() |
| 676 | + else: |
| 677 | + logprobs_lists = None |
| 678 | + |
| 679 | +            # Kick off the async device-to-host copy and save this step's results. |
| 680 | +            next_tokens.copy_to_host_async() |
| 681 | + self._pre_async_results = AsyncPreResults( |
| 682 | + req_ids=req_ids, |
| 683 | + next_tokens=next_tokens, |
| 684 | + request_seq_lens=request_seq_lens, |
| 685 | + discard_sampled_tokens_req_indices=discard_sampled_tokens_req_indices, |
| 686 | + placeholder_req_id_to_index=placeholder_req_id_to_index, |
| 687 | + ) |
| 688 | + |
| 689 | + # Return Model output to executor |
| 690 | + model_runner_output = ModelRunnerOutput( |
| 691 | + req_ids=req_ids, |
| 692 | + req_id_to_index=copy.deepcopy(self.input_batch.req_id_to_index), |
| 693 | +                sampled_token_ids=[],  # Filled in later by AsyncTPUModelRunnerOutput.get_output() |
| 694 | + logprobs=logprobs_lists, |
| 695 | + prompt_logprobs_dict=prompt_logprobs_dict, |
| 696 | + pooler_output=[], |
| 697 | + kv_connector_output=kv_connector_output, |
| 698 | + ) |
| 699 | + # Return attn_metadata, model_runner_output |
| 700 | + async_model_runner_output = AsyncTPUModelRunnerOutput(model_runner_output, |
| 701 | + next_tokens, |
| 702 | + num_reqs, |
| 703 | + discard_sampled_tokens_req_indices) |
| 704 | + return attn_metadata, async_model_runner_output |
| 705 | + |
504 | 706 | if spec_decode_metadata is None: |
505 | 707 | next_tokens = np.asarray(jax.device_get(next_tokens)) |
506 | 708 | selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1) |
@@ -596,6 +798,26 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"): |
596 | 798 | # For each scheduled token, what are the corresponding req index. |
597 | 799 | req_indices = np.repeat(self.arange_cpu[:num_reqs], |
598 | 800 | num_scheduled_tokens_per_req) |
| 801 | + token_in_tpu_cur_input_indices = np.array([]) |
| 802 | + token_in_tpu_pre_next_tokens_indices = np.array([]) |
| 803 | + if self.scheduler_config.async_scheduling and self._pre_async_results is not None: |
| 804 | +            # If previous async results exist, prepare the token substitution here. |
| 805 | +            # The actual substitution is performed on the TPU later in this function. |
| 806 | + token_in_tpu_cur_input_indices_list = [] |
| 807 | + token_in_tpu_pre_next_tokens_indices_list = [] |
| 808 | + acc_cur_len = 0 |
| 809 | + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): |
| 810 | + acc_cur_len += num_scheduled_tokens_per_req[i] |
| 811 | + assert req_id is not None |
| 812 | + if req_id not in self._pre_async_results.placeholder_req_id_to_index: |
| 813 | + continue |
| 814 | + |
| 815 | + token_in_tpu_cur_input_indices_list.append(acc_cur_len-1) |
| 816 | + token_in_tpu_pre_next_tokens_indices_list.append(self._pre_async_results.placeholder_req_id_to_index[req_id]) |
| 817 | + |
| 818 | + if len(token_in_tpu_cur_input_indices_list) > 0: |
| 819 | + token_in_tpu_cur_input_indices = np.array(token_in_tpu_cur_input_indices_list) |
| 820 | + token_in_tpu_pre_next_tokens_indices = np.array(token_in_tpu_pre_next_tokens_indices_list) |
599 | 821 |
|
600 | 822 | # Get batched arange. |
601 | 823 | # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] |
@@ -701,6 +923,18 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"): |
701 | 923 | self.mesh, (input_ids, positions, block_tables, query_start_loc, |
702 | 924 | seq_lens, logits_indices, request_distribution)) |
703 | 925 |
|
| 926 | + if self.scheduler_config.async_scheduling and len(token_in_tpu_cur_input_indices) > 0: |
| 927 | + assert self._pre_async_results is not None |
| 928 | + idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices) |
| 929 | + padded_token_in_tpu_cur_input_indices = np.pad(token_in_tpu_cur_input_indices, (0, idx_pad_len)) |
| 930 | + padded_token_in_tpu_pre_next_tokens_indices = np.pad(token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len)) |
| 931 | + with self.maybe_forbid_compile: |
| 932 | + input_ids = self._substitute_placeholder_token_fn(input_ids, |
| 933 | + padded_token_in_tpu_cur_input_indices, |
| 934 | + padded_token_in_tpu_pre_next_tokens_indices, |
| 935 | + self._pre_async_results.next_tokens, |
| 936 | + len(token_in_tpu_cur_input_indices)) |
| 937 | + |
704 | 938 | if self.lora_config is not None: |
705 | 939 | self.lora_utils.set_active_loras( |
706 | 940 | num_scheduled_tokens_per_req, total_num_scheduled_tokens, |
|