|
| 1 | +import copy |
1 | 2 | import functools |
2 | 3 | import os |
3 | 4 | import random |
4 | 5 | from contextlib import nullcontext |
| 6 | +from dataclasses import dataclass |
5 | 7 | from typing import Any, Callable, Dict, List, Optional, Tuple, cast |
6 | 8 |
|
7 | 9 | import jax |
|
21 | 23 | from vllm.utils import cdiv |
22 | 24 | from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput |
23 | 25 | from vllm.v1.kv_cache_interface import KVCacheConfig |
24 | | -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, |
25 | | - ModelRunnerOutput) |
| 26 | +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, |
| 27 | + DraftTokenIds, ModelRunnerOutput) |
26 | 28 | from vllm.v1.request import Request |
27 | 29 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
28 | 30 | from vllm.v1.worker.kv_connector_model_runner_mixin import \ |
|
80 | 82 | } |
81 | 83 |
|
82 | 84 |
|
| 85 | +class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput): |
| 86 | + """Holds asynchronous model output specifically from a TPU runner. |
| 87 | +     |
| 88 | + This class acts as a wrapper around the standard ModelRunnerOutput. Its |
| 89 | + primary purpose is to hold references to data still on the TPU device |
| 90 | + (like the `next_tokens` JAX array) without blocking the main thread. |
| 91 | +     |
| 92 | + The `get_output()` method is called to resolve these async results, |
| 93 | + triggering the JAX device-to-host (CPU) data transfer and populating |
| 94 | + the final `ModelRunnerOutput` object. |
| 95 | + """ |
| 96 | + |
| 97 | + def __init__( |
| 98 | + self, |
| 99 | + model_runner_output: ModelRunnerOutput, |
| 100 | + next_tokens: jax.Array, |
| 101 | + num_reqs: int, |
| 102 | + discard_sampled_tokens_req_indices: list[int], |
| 103 | + ): |
| 104 | + self._model_runner_output = model_runner_output |
| 105 | + self._next_tokens = next_tokens |
| 106 | + self._num_reqs = num_reqs |
| 107 | + self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices |
| 108 | + |
| 109 | + def get_output(self) -> ModelRunnerOutput: |
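| | +        # NOTE: jax.device_get blocks until the asynchronous device-to-host copy |
| | +        # scheduled at submission time has completed, so this is where the host |
| | +        # actually synchronizes with the TPU. |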
| 110 | + next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens)) |
| 111 | + selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs], |
| 112 | + 1) |
| 113 | + valid_sampled_token_ids = selected_token_ids.tolist() |
| 114 | + for i in self._discard_sampled_tokens_req_indices: |
| 115 | + valid_sampled_token_ids[i].clear() |
| 116 | + self._model_runner_output.sampled_token_ids = valid_sampled_token_ids |
| 117 | + return self._model_runner_output |
| 118 | + |
| 119 | + |
| 120 | +@dataclass |
| 121 | +class AsyncPreResults: |
| 122 | + req_ids: list[str] |
| 123 | + next_tokens: jax.Array |
| 124 | + request_seq_lens: list[tuple[int, CachedRequestState, int]] |
| 125 | + discard_sampled_tokens_req_indices: list[int] |
| 126 | + placeholder_req_id_to_index: dict[str, int] |
| 127 | + |
| 128 | + |
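| | +# NOTE: donate_argnums marks input_ids and the two index arrays as donated, |
| | +# letting XLA reuse their device buffers for the result instead of allocating |
| | +# fresh copies. |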
| 129 | +@functools.partial(jax.jit, donate_argnums=(0, 1, 2)) |
| 130 | +def _substitute_placeholder_token( |
| 131 | + input_ids: jax.Array, token_in_tpu_cur_input_indices: jax.Array, |
| 132 | + token_in_tpu_pre_next_tokens_indices: jax.Array, |
| 133 | + next_tokens: jax.Array, placeholder_num: int): |
| 134 | +    """Substitute placeholder tokens with the previous step's TPU outputs (async scheduler). |
| 135 | +     |
| 136 | +    Args: |
| 137 | +        input_ids: the (padded) input token IDs for the current step. |
| 138 | +        token_in_tpu_cur_input_indices: indices in `input_ids` that hold placeholders to overwrite. Padded to the same length as `input_ids`. |
| 139 | +        token_in_tpu_pre_next_tokens_indices: for each placeholder, the index into `next_tokens` holding the real token. Padded to the same length as `input_ids`. |
| 140 | +        next_tokens: the tokens sampled on the TPU in the previous step. |
| 141 | +        placeholder_num: number of real placeholders; placeholder_num <= len(token_in_tpu_cur_input_indices). |
| 142 | +    Returns: |
| 143 | +        input_ids with the placeholder tokens replaced. |
| 144 | + """ |
| 145 | + assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \ |
| 146 | + f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \ |
| 147 | + f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}" |
| 148 | + |
| 149 | + def updated_input_ids_array(i: int, |
| 150 | + current_input_ids: jax.Array) -> jax.Array: |
| 151 | + """ |
| 152 | + Iteratively updates the input_ids for all placeholders. |
| 153 | +         |
| 154 | + Args: |
| 155 | + i: The current loop index. |
| 156 | + current_input_ids: The loop carry state (the input_ids being modified). |
| 157 | +         |
| 158 | + Returns: |
| 159 | + The updated input_ids array. |
| 160 | + """ |
| 161 | + update_idx = token_in_tpu_cur_input_indices[i] |
| 162 | + value_idx = token_in_tpu_pre_next_tokens_indices[i] |
| 163 | + new_token_value = next_tokens[value_idx] |
| 164 | + return current_input_ids.at[update_idx].set(new_token_value) |
| 165 | + |
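| | +    # fori_loop runs exactly placeholder_num iterations, so the zero-padded tail |
| | +    # of the index arrays is never read; since the bound is a traced value, one |
| | +    # compiled program handles any number of placeholders. |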
| 166 | + return jax.lax.fori_loop(0, placeholder_num, updated_input_ids_array, |
| 167 | + input_ids) |
| 168 | + |
| 169 | + |
83 | 170 | class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin): |
84 | 171 |
|
85 | 172 | def __init__( |
@@ -139,6 +226,9 @@ def __init__( |
139 | 226 | self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[ |
140 | 227 | cache_config.cache_dtype] |
141 | 228 |
|
| 229 | + self._pre_async_results: AsyncPreResults | None = None |
| 230 | + self._substitute_placeholder_token_fn = _substitute_placeholder_token |
| 231 | + |
142 | 232 | def _init_random(self): |
143 | 233 | if self.model_config.seed is None: |
144 | 234 | self.model_config.seed = 0 |
@@ -343,13 +433,88 @@ def execute_model( |
343 | 433 | self, |
344 | 434 | scheduler_output: "VllmSchedulerOutput", |
345 | 435 | intermediate_tensors: Optional[IntermediateTensors] = None, |
346 | | - ) -> ModelRunnerOutput: |
| 436 | + ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput: |
| 437 | + |
347 | 438 | return self._execute_model(scheduler_output)[1] |
348 | 439 |
|
| 440 | + def _modify_prev_results(self): |
| 441 | +        # The device-to-host copy was already scheduled in the previous step, so |
| 442 | +        # the device_get below should return almost immediately; if the copy has not finished yet, we simply block until it has. |
| 443 | +        assert self._pre_async_results is not None, "_modify_prev_results() requires previous async results to exist" |
| 444 | + pre_req_ids = self._pre_async_results.req_ids |
| 445 | + pre_next_tokens = self._pre_async_results.next_tokens |
| 446 | + pre_request_seq_lens = self._pre_async_results.request_seq_lens |
| 447 | + pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices |
| 448 | + |
| 449 | + next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens)) |
| 450 | + selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)], |
| 451 | + 1) |
| 452 | + valid_sampled_token_ids = selected_token_ids.tolist() |
| 453 | + |
| 454 | + # Mask out the sampled tokens that should not be sampled. |
| 455 | + for i in pre_discard_sampled_tokens_req_indices: |
| 456 | + valid_sampled_token_ids[i].clear() |
| 457 | + # Append sampled tokens |
| 458 | + for pre_req_idx, req_state, _ in pre_request_seq_lens: |
| 459 | + sampled_ids = valid_sampled_token_ids[pre_req_idx] |
| 460 | + if not sampled_ids: |
| 461 | + continue |
| 462 | + |
| 463 | +            # If the request is not active in the *current* batch (e.g. finished or evicted), skip it. |
| 464 | + req_id = pre_req_ids[pre_req_idx] |
| 465 | + if req_id not in self.input_batch.req_id_to_index: |
| 466 | + continue |
| 467 | + |
| 468 | + req_idx = self.input_batch.req_id_to_index[req_id] |
| 469 | + assert req_state is self.requests[ |
| 470 | + req_id], "The req_state should be valid and identical" |
| 471 | + |
| 472 | +            # end_idx already counts the placeholder slot reserved on the previous execute. |
| 473 | +            end_idx = self.input_batch.num_tokens_no_spec[req_idx] |
| 474 | +            assert len(sampled_ids) == 1, "spec decode is not supported yet" |
| 475 | + start_idx = end_idx - 1 |
| 476 | + assert end_idx <= self.max_model_len, ( |
| 477 | + "Sampled token IDs exceed the max model length. " |
| 478 | + f"Total number of tokens: {end_idx} > max_model_len: " |
| 479 | + f"{self.max_model_len}") |
| 480 | + |
| 481 | + self.input_batch.token_ids_cpu[req_idx, |
| 482 | + start_idx:end_idx] = sampled_ids |
| 483 | + # Replace previous placeholder |
| 484 | + req_state.output_token_ids[-1] = sampled_ids[-1] |
| 485 | + |
| 486 | + def _update_placeholder(self, discard_sampled_tokens_req_indices, |
| 487 | + request_seq_lens): |
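| | +        """Reserve a slot for each request's not-yet-generated next token. |
| | +         |
| | +        Bumps the CPU token counts, appends a placeholder (0) to the request's |
| | +        output_token_ids, and records the request's batch index so that the next |
| | +        step can substitute the real token sampled on the TPU. |
| | +        """ |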
| 488 | + placeholder_req_id_to_index: dict[str, int] = {} |
| 489 | + discard_sampled_tokens_req_indices_set = set( |
| 490 | + discard_sampled_tokens_req_indices) |
| 491 | + for req_idx, req_state, _ in request_seq_lens: |
| 492 | + if req_idx in discard_sampled_tokens_req_indices_set: |
| 493 | + continue |
| 494 | + |
| 495 | + start_idx = self.input_batch.num_tokens_no_spec[req_idx] |
| 496 | + # Not supporting spec decode yet, assume only 1 new token |
| 497 | + end_idx = start_idx + 1 |
| 498 | + assert end_idx <= self.max_model_len, ( |
| 499 | + "Sampled token IDs exceed the max model length. " |
| 500 | + f"Total number of tokens: {end_idx} > max_model_len: " |
| 501 | + f"{self.max_model_len}") |
| 502 | + |
| 503 | +            # Advance the token counts now to reserve the placeholder slot; the real token is written to the CPU arrays on the next execute and patched into the TPU input on device. |
| 504 | + self.input_batch.num_tokens_no_spec[req_idx] = end_idx |
| 505 | + self.input_batch.num_tokens[req_idx] = end_idx |
| 506 | + |
| 507 | +            # Append a placeholder value (0); it is replaced with the real token on the next execute. |
| 508 | +            req_state.output_token_ids.append(0) |
| 509 | + |
| 510 | + placeholder_req_id_to_index[req_state.req_id] = req_idx |
| 511 | + return placeholder_req_id_to_index |
| 512 | + |
349 | 513 | def _execute_model( |
350 | 514 | self, |
351 | 515 | scheduler_output: "VllmSchedulerOutput", |
352 | | - ) -> tuple[AttentionMetadata, ModelRunnerOutput]: |
| 516 | + ) -> tuple[AttentionMetadata, ModelRunnerOutput |
| 517 | + | AsyncTPUModelRunnerOutput]: |
353 | 518 | self.persistent_batch_manager.update_states( |
354 | 519 | scheduler_output, self.get_mrope_input_positions_fn) |
355 | 520 | if not scheduler_output.total_num_scheduled_tokens: |
@@ -470,7 +635,7 @@ def _execute_model( |
470 | 635 | num_reqs = self.input_batch.num_reqs |
471 | 636 |
|
472 | 637 | # Update the cache state concurrently. Code above will not block until |
473 | | - # we use `selected_token_ids`. Add mark_step if post-processing changes |
| 638 | +    # we use `selected_token_ids`. Add a mark_step if post-processing changes. |
474 | 639 | request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] |
475 | 640 | discard_sampled_tokens_req_indices = [] |
476 | 641 | for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): |
@@ -501,6 +666,51 @@ def _execute_model( |
501 | 666 | for req_id in self.input_batch.req_ids[:num_reqs]: |
502 | 667 | prompt_logprobs_dict[req_id] = None |
503 | 668 |
|
| 669 | +        # If the async scheduler is enabled: |
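| | +        # NOTE: with async scheduling this step returns placeholder outputs |
| | +        # immediately; the real sampled tokens (still on the TPU) are patched in |
| | +        # at the start of the next step, overlapping device compute with |
| | +        # host-side bookkeeping. |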
| 670 | + if self.scheduler_config.async_scheduling: |
| 671 | + # Get previous results from TPU and replace the placeholder. |
| 672 | + if self._pre_async_results is not None: |
| 673 | + assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet." |
| 674 | + self._modify_prev_results() |
| 675 | + |
| 676 | +            # Set placeholders for the next tokens, which have not been generated yet. |
| 677 | + placeholder_req_id_to_index: dict[ |
| 678 | + str, int] = self._update_placeholder( |
| 679 | + discard_sampled_tokens_req_indices, request_seq_lens) |
| 680 | + |
| 681 | + if logprobs is not None: |
| 682 | + logprobs_lists = logprobs.tolists() |
| 683 | + else: |
| 684 | + logprobs_lists = None |
| 685 | + |
| 686 | +            # Schedule the device-to-host copy now and save this step's results. |
| 687 | +            next_tokens.copy_to_host_async() |
| 688 | + self._pre_async_results = AsyncPreResults( |
| 689 | + req_ids=req_ids, |
| 690 | + next_tokens=next_tokens, |
| 691 | + request_seq_lens=request_seq_lens, |
| 692 | + discard_sampled_tokens_req_indices= |
| 693 | + discard_sampled_tokens_req_indices, |
| 694 | + placeholder_req_id_to_index=placeholder_req_id_to_index, |
| 695 | + ) |
| 696 | + |
| 697 | + # Return Model output to executor |
| 698 | + model_runner_output = ModelRunnerOutput( |
| 699 | + req_ids=req_ids, |
| 700 | + req_id_to_index=copy.deepcopy( |
| 701 | + self.input_batch.req_id_to_index), |
| 702 | +                sampled_token_ids=[],  # Filled in by AsyncTPUModelRunnerOutput.get_output() |
| 703 | + logprobs=logprobs_lists, |
| 704 | + prompt_logprobs_dict=prompt_logprobs_dict, |
| 705 | + pooler_output=[], |
| 706 | + kv_connector_output=kv_connector_output, |
| 707 | + ) |
| 708 | + # Return attn_metadata, model_runner_output |
| 709 | + async_model_runner_output = AsyncTPUModelRunnerOutput( |
| 710 | + model_runner_output, next_tokens, num_reqs, |
| 711 | + discard_sampled_tokens_req_indices) |
| 712 | + return attn_metadata, async_model_runner_output |
| 713 | + |
504 | 714 | if spec_decode_metadata is None: |
505 | 715 | next_tokens = np.asarray(jax.device_get(next_tokens)) |
506 | 716 | selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1) |
@@ -596,6 +806,30 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"): |
596 | 806 | # For each scheduled token, what are the corresponding req index. |
597 | 807 | req_indices = np.repeat(self.arange_cpu[:num_reqs], |
598 | 808 | num_scheduled_tokens_per_req) |
| 809 | + token_in_tpu_cur_input_indices = np.array([]) |
| 810 | + token_in_tpu_pre_next_tokens_indices = np.array([]) |
| 811 | + if self.scheduler_config.async_scheduling and self._pre_async_results is not None: |
| 812 | +            # If previous async results exist, prepare the placeholder-token substitution here. |
| 813 | +            # The actual substitution is performed on the TPU later in this function. |
| 814 | + token_in_tpu_cur_input_indices_list = [] |
| 815 | + token_in_tpu_pre_next_tokens_indices_list = [] |
| 816 | + acc_cur_len = 0 |
| 817 | + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): |
| 818 | + acc_cur_len += num_scheduled_tokens_per_req[i] |
| 819 | + assert req_id is not None |
| 820 | + if req_id not in self._pre_async_results.placeholder_req_id_to_index: |
| 821 | + continue |
| 822 | + |
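| | +                # The placeholder occupies the request's last scheduled token slot; |
| | +                # map that flat position to the request's row in the previous step's |
| | +                # next_tokens array. |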
| 823 | + token_in_tpu_cur_input_indices_list.append(acc_cur_len - 1) |
| 824 | + token_in_tpu_pre_next_tokens_indices_list.append( |
| 825 | + self._pre_async_results.placeholder_req_id_to_index[req_id] |
| 826 | + ) |
| 827 | + |
| 828 | + if len(token_in_tpu_cur_input_indices_list) > 0: |
| 829 | + token_in_tpu_cur_input_indices = np.array( |
| 830 | + token_in_tpu_cur_input_indices_list) |
| 831 | + token_in_tpu_pre_next_tokens_indices = np.array( |
| 832 | + token_in_tpu_pre_next_tokens_indices_list) |
599 | 833 |
|
600 | 834 | # Get batched arange. |
601 | 835 | # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] |
@@ -701,6 +935,21 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"): |
701 | 935 | self.mesh, (input_ids, positions, block_tables, query_start_loc, |
702 | 936 | seq_lens, logits_indices, request_distribution)) |
703 | 937 |
|
| 938 | + if self.scheduler_config.async_scheduling and len( |
| 939 | + token_in_tpu_cur_input_indices) > 0: |
| 940 | + assert self._pre_async_results is not None |
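| | +            # The jitted substitution expects index arrays shaped like input_ids |
| | +            # (see the assert in _substitute_placeholder_token), so zero-pad both |
| | +            # index lists; padded entries are ignored because the loop stops at the |
| | +            # real placeholder count. |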
| 941 | + idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices) |
| 942 | + padded_token_in_tpu_cur_input_indices = np.pad( |
| 943 | + token_in_tpu_cur_input_indices, (0, idx_pad_len)) |
| 944 | + padded_token_in_tpu_pre_next_tokens_indices = np.pad( |
| 945 | + token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len)) |
| 946 | + with self.maybe_forbid_compile: |
| 947 | + input_ids = self._substitute_placeholder_token_fn( |
| 948 | + input_ids, padded_token_in_tpu_cur_input_indices, |
| 949 | + padded_token_in_tpu_pre_next_tokens_indices, |
| 950 | + self._pre_async_results.next_tokens, |
| 951 | + len(token_in_tpu_cur_input_indices)) |
| 952 | + |
704 | 953 | if self.lora_config is not None: |
705 | 954 | self.lora_utils.set_active_loras( |
706 | 955 | num_scheduled_tokens_per_req, total_num_scheduled_tokens, |
|