diff --git a/tests/e2e/test_async_scheduler.py b/tests/e2e/test_async_scheduler.py
new file mode 100644
index 000000000..017ef9a8f
--- /dev/null
+++ b/tests/e2e/test_async_scheduler.py
@@ -0,0 +1,236 @@
+from __future__ import annotations
+
+import random
+import string
+import time
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0,
+                          max_tokens=120,
+                          ignore_eos=True,
+                          repetition_penalty=1,
+                          frequency_penalty=0,
+                          presence_penalty=0,
+                          min_p=0,
+                          logprobs=None)
+
+
+@pytest.fixture
+def model_name():
+    return "Qwen/Qwen2.5-1.5B-Instruct"
+
+
+def get_performance_test_prompts():
+    """
+    Generates a list of prompts with a fixed word count.
+
+    Returns:
+        A list of `num_prompts` strings, each containing
+        `input_len_words` repeated words.
+    """
+    num_prompts = 500
+    input_len_words = 120
+    prompts = []
+
+    # For example, for w = 's' the generated prompt is
+    # "Keep repeating: s s s ..."
+    num_repetitions = input_len_words
+    prefix = "Keep repeating: "
+
+    for _ in range(num_prompts):
+        # 1. Pick a random lowercase letter.
+        w = random.choice(list(string.ascii_lowercase))
+
+        # 2. Create the string of repeated words
+        #    (num_repetitions words in total).
+        repeating_part = " ".join([w] * num_repetitions)
+
+        # 3. Combine with the prefix.
+        prompts.append(f"{prefix}{repeating_part}")
+
+    return prompts
+
+
+def get_correctness_test_prompts():
+    """
+    Returns a static list of prompts designed to test a model's
+    ability to follow complex instructions and ensure correctness.
+
+    Returns:
+        A list of strings, where each string is a test prompt.
+    """
+
+    prompts = [
+        (
+            "Write a short story about a librarian who discovers a book that "
+            "writes itself. Write it in 1900s English style. Make sure there "
+            "are no mistakes. This is my homework and I want perfection."
+        ),
+        (
+            "Compose a poem about the sound of a city at night. Write it in "
+            "Shakespearean style. Make sure there are no mistakes. This is my "
+            "homework and I want perfection."
+        ),
+        (
+            "Write a dialogue between a time traveler and a medieval blacksmith "
+            "who is skeptical of their claims. Make sure there are no mistakes."
+        ),
+        (
+            "Explain the process of photosynthesis as if to a 5th grader, "
+            "but without losing any scientific accuracy. Every step must be "
+            "correct and in the right order. I will be checking this against "
+            "a textbook."
+        ),
+        (
+            "Write a Python function that finds the median of a list of numbers. "
+            "It must correctly handle both even and odd-sized lists, "
+            "as well as unsorted lists. Provide a perfect, bug-free "
+            "implementation. I will be running unit tests on it."
+        ),
+        (
+            "List the first 10 presidents of the United States. Format the "
+            "output as a JSON array, where each object has two keys: 'name' "
+            "and 'term_years'. The JSON must be perfectly valid, and all "
+            "names and dates must be 100% accurate. This is for a production "
+            "system."
+        )
+    ]
+
+    return prompts
+
+
+def _test_performance_helper(
+        monkeypatch: pytest.MonkeyPatch,
+        sampling_config: SamplingParams,
+        model_name: str,
+        min_speedup: float,
+):
+    '''
+    Helper function to test async scheduler decoding performance.
+    Compares timing between the reference LLM and the async LLM using
+    Qwen2.5-1.5B-Instruct.
+    '''
+
+    with monkeypatch.context():
+        # 500 prompts of 120 words each.
+        test_prompts = get_performance_test_prompts()
+
+        # Test reference LLM timing.
+        ref_llm = LLM(model=model_name,
+                      max_model_len=800,
+                      max_num_seqs=24,
+                      max_num_batched_tokens=512,
+                      enable_prefix_caching=False,
+                      async_scheduling=0)
+
+        start_time = time.time()
+        _ = ref_llm.generate(test_prompts, sampling_config)
+        ref_time = time.time() - start_time
+
+        del ref_llm
+        # Wait for the TPUs to be released.
+        time.sleep(10)
+
+        # Test async LLM timing.
+        async_llm = LLM(model=model_name,
+                        max_model_len=800,
+                        max_num_seqs=24,
+                        max_num_batched_tokens=512,
+                        enable_prefix_caching=False,
+                        async_scheduling=1)
+
+        start_time = time.time()
+        _ = async_llm.generate(test_prompts, sampling_config)
+        async_time = time.time() - start_time
+
+        del async_llm
+        # Wait for the TPUs to be released.
+        time.sleep(10)
+
+        speedup = ref_time / async_time
+        print(f"Reference LLM time: {ref_time:.2f}s")
+        print(f"Async LLM time: {async_time:.2f}s")
+        print(f"Speedup: {speedup:.2f}x")
+
+        assert speedup >= min_speedup, (
+            f"Expected at least {min_speedup}x speedup for the async "
+            f"scheduler, got {speedup:.2f}x")
+
+
+def test_performance(
+        monkeypatch: pytest.MonkeyPatch,
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    '''
+    Tests that async scheduling provides a significant performance
+    improvement. Compares timing between the reference LLM and the async
+    LLM using Qwen2.5-1.5B-Instruct, expecting the async LLM to be at
+    least 1.3x faster than the reference LLM.
+    '''
+    min_speedup = 1.3
+    _test_performance_helper(
+        monkeypatch, sampling_config, model_name, min_speedup)
+
+
+def _test_correctness_helper(
+        monkeypatch: pytest.MonkeyPatch,
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    '''
+    Helper function to test async scheduler correctness.
+    The outputs of the original LLM and the async LLM should be identical
+    when async scheduling is enabled.
+
+    Known edge case (KV cache swapping):
+    In this case, even though the temperature is set to 0, the output is
+    still slightly different every time. This is expected behaviour, as
+    the normal scheduler behaves the same way, and it is therefore
+    difficult to design a test for this scenario.
+    '''
+    with monkeypatch.context():
+        test_prompts = get_correctness_test_prompts()
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=1024,
+                      max_num_seqs=100,
+                      async_scheduling=0)
+        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+
+        del ref_llm
+        # Wait for the TPUs to be released.
+        time.sleep(10)
+
+        async_llm = LLM(model=model_name,
+                        max_model_len=1024,
+                        max_num_seqs=100,
+                        async_scheduling=1)
+        async_outputs = async_llm.generate(test_prompts, sampling_config)
+
+        matches = 0
+        misses = 0
+        for ref_output, async_output in zip(ref_outputs, async_outputs):
+            if ref_output.outputs[0].text == async_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"async_output: {async_output.outputs[0].text}")
+
+        assert misses == 0, (
+            f"{misses} of {len(ref_outputs)} outputs differ between the "
+            "reference and async LLMs")
+        del async_llm, async_outputs
+        # Wait for the TPUs to be released.
+        time.sleep(10)
+
+
+def test_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    The outputs of the original LLM and the async LLM should be identical
+    when async scheduling is enabled.
+    '''
+    _test_correctness_helper(
+        monkeypatch, sampling_config, model_name)
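A minimal sketch, not part of this patch, of how the duplicated engine setup and teardown in the helpers above could be factored out. The LLM arguments are the ones used in `_test_performance_helper`; the name `_timed_generate` is hypothetical and introduced here only for illustration.

import time
from vllm import LLM, SamplingParams

def _timed_generate(model_name: str, prompts: list[str],
                    sampling_params: SamplingParams,
                    async_scheduling: int) -> float:
    # Build an engine with the same settings as the performance test.
    llm = LLM(model=model_name,
              max_model_len=800,
              max_num_seqs=24,
              max_num_batched_tokens=512,
              enable_prefix_caching=False,
              async_scheduling=async_scheduling)
    start = time.time()
    llm.generate(prompts, sampling_params)
    elapsed = time.time() - start
    del llm  # Release the TPUs before the next engine is created.
    time.sleep(10)
    return elapsed

# ref_time = _timed_generate(model, prompts, params, async_scheduling=0)
# async_time = _timed_generate(model, prompts, params, async_scheduling=1)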
diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py
index 96d2a1988..6a34e7ed7 100644
--- a/tpu_inference/runner/compilation_manager.py
+++ b/tpu_inference/runner/compilation_manager.py
@@ -75,6 +75,8 @@ def capture_model(self) -> None:
         self._precompile_backbone_text_only()
         if self.runner.is_multimodal_model:
             self._precompile_backbone_with_inputs_embeds()
+        if self.runner.scheduler_config.async_scheduling:
+            self._precompile_substitute_placeholder_token()
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -148,6 +150,41 @@ def model_fn_wrapper(
             num_tokens=num_tokens,
         )
 
+    def _precompile_substitute_placeholder_token(self) -> None:
+        """Precompiles the token substitution function for all expected input shapes.
+
+        Iterates through all padded token lengths (`num_tokens_paddings`)
+        and request batch sizes (`num_reqs_paddings`) that the scheduler is
+        expected to handle, ensuring a compiled version is ready for each
+        combination.
+        """
+        for num_tokens in self.runner.num_tokens_paddings:
+            padded_token_in_tpu_cur_input_indices = np.zeros((num_tokens, ),
+                                                             dtype=np.int32)
+            padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
+                (num_tokens, ), dtype=np.int32)
+            for num_reqs in self.runner.num_reqs_paddings:
+                input_ids = self._create_dummy_tensor((num_tokens, ),
+                                                      jnp.int32)
+                # Must match the shape and sharding of the sampler output.
+                next_tokens = self._create_dummy_tensor(
+                    (num_reqs, ),
+                    jnp.int32,
+                    sharding=NamedSharding(self.runner.mesh, PartitionSpec()))
+                placeholder_num = 1
+                self._run_compilation(
+                    "_substitute_placeholder_token_fn",
+                    self.runner._substitute_placeholder_token_fn,
+                    input_ids,
+                    padded_token_in_tpu_cur_input_indices,
+                    padded_token_in_tpu_pre_next_tokens_indices,
+                    next_tokens,
+                    placeholder_num,
+                    num_tokens=num_tokens,
+                    num_reqs=num_reqs,
+                )
+
     def _precompile_backbone_text_only(self) -> None:
         for num_tokens in self.runner.num_tokens_paddings:
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
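The precompilation loop above exists because `jax.jit` specializes on input shapes: every distinct `(num_tokens, num_reqs)` padding combination produces its own executable, so warming them all at capture time keeps compilation off the serving path. A standalone illustration of that shape-keyed cache; the padded lengths below are made-up stand-ins for `num_tokens_paddings`, not values from the runner.

import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    return x * 2

# Stand-ins for the runner's num_tokens_paddings.
for padded_len in (16, 32, 64):
    # Each new shape triggers one compilation; block to make it happen now.
    double(jnp.zeros((padded_len, ), jnp.int32)).block_until_ready()

# Later calls with any of the shapes above reuse the cached executables;
# an unseen shape (e.g. 48) would trigger a fresh compilation at serve time.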
diff --git a/tpu_inference/runner/tpu_jax_runner.py b/tpu_inference/runner/tpu_jax_runner.py
index 988e61920..e02ff52d5 100644
--- a/tpu_inference/runner/tpu_jax_runner.py
+++ b/tpu_inference/runner/tpu_jax_runner.py
@@ -1,7 +1,9 @@
+import copy
 import functools
 import os
 import random
 from contextlib import nullcontext
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
 import jax
@@ -21,8 +23,8 @@
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -80,6 +82,77 @@
 }
 
 
+class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
+    """Holds asynchronous model output from a TPU runner.
+
+    This class is a wrapper around the standard ModelRunnerOutput. Its
+    primary purpose is to hold references to data that is still on the TPU
+    device (such as the `next_tokens` JAX array) without blocking the main
+    thread.
+
+    The `get_output()` method resolves these async results, triggering the
+    JAX device-to-host (CPU) transfer and populating the final
+    `ModelRunnerOutput` object.
+    """
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        next_tokens: jax.Array,
+        num_reqs: int,
+        discard_sampled_tokens_req_indices: list[int],
+    ):
+        self._model_runner_output = model_runner_output
+        self._next_tokens = next_tokens
+        self._num_reqs = num_reqs
+        self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices
+
+    def get_output(self) -> ModelRunnerOutput:
+        next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens))
+        selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
+                                            1)
+        valid_sampled_token_ids = selected_token_ids.tolist()
+        for i in self._discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+        self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
+        return self._model_runner_output
+
+
+@dataclass
+class AsyncPreResults:
+    req_ids: list[str]
+    next_tokens: jax.Array
+    request_seq_lens: list[tuple[int, CachedRequestState, int]]
+    discard_sampled_tokens_req_indices: list[int]
+    placeholder_req_id_to_index: dict[str, int]
+
+
+@functools.partial(jax.jit, donate_argnums=(0, 1, 2))
+def _substitute_placeholder_token(
+        input_ids: jax.Array, token_in_tpu_cur_input_indices: jax.Array,
+        token_in_tpu_pre_next_tokens_indices: jax.Array,
+        next_tokens: jax.Array, placeholder_num: int):
+    """Substitutes placeholder tokens on the TPU for the async scheduler.
+
+    Args:
+        input_ids: padded input token ids for the current step.
+        token_in_tpu_cur_input_indices: indices in `input_ids` whose
+            placeholder values must be replaced. Padded to the same length
+            as `input_ids`.
+        token_in_tpu_pre_next_tokens_indices: for each placeholder, the index
+            into `next_tokens` that holds its real value. Padded to the same
+            length as `input_ids`.
+        next_tokens: next tokens on the TPU from the previous step.
+        placeholder_num: number of placeholders;
+            placeholder_num <= len(token_in_tpu_cur_input_indices).
+    Returns:
+        `input_ids` with the placeholder tokens replaced.
+    """
+    assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \
+        f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
+        f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}"
+
+    # Update input_ids for all placeholders; entries beyond placeholder_num
+    # are padding and simply write their original values back unchanged.
+    mask = jnp.arange(input_ids.shape[0]) < placeholder_num
+    new_token_values = next_tokens[token_in_tpu_pre_next_tokens_indices]
+    original_values = input_ids[token_in_tpu_cur_input_indices]
+    update_values = jnp.where(mask, new_token_values, original_values)
+    return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
+
+
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def __init__(
@@ -139,6 +212,9 @@ def __init__(
         self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
 
+        self._pre_async_results: AsyncPreResults | None = None
+        self._substitute_placeholder_token_fn = _substitute_placeholder_token
+
     def _init_random(self):
         if self.model_config.seed is None:
             self.model_config.seed = 0
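Because the index arrays are padded to the full `input_ids` length, the jitted function compiles once per padded shape, and `placeholder_num` masks the padding out. A self-contained sketch of the same masked-scatter trick with toy values (runnable on any JAX backend; names here are illustrative, not the runner's):

import jax
import jax.numpy as jnp

@jax.jit
def substitute(input_ids, cur_idx, src_idx, next_tokens, placeholder_num):
    # Same idea as _substitute_placeholder_token: gather the real tokens,
    # keep original values where the index arrays are just padding.
    mask = jnp.arange(input_ids.shape[0]) < placeholder_num
    new_vals = next_tokens[src_idx]
    old_vals = input_ids[cur_idx]
    return input_ids.at[cur_idx].set(jnp.where(mask, new_vals, old_vals))

input_ids = jnp.array([11, 0, 22, 0, 33, 44], dtype=jnp.int32)  # 0 = placeholder
cur_idx = jnp.array([1, 3, -1, -1, -1, -1], dtype=jnp.int32)    # padded with -1
src_idx = jnp.array([0, 1, -1, -1, -1, -1], dtype=jnp.int32)
next_tokens = jnp.array([101, 202], dtype=jnp.int32)

print(substitute(input_ids, cur_idx, src_idx, next_tokens, 2))
# -> [ 11 101  22 202  33  44]; the padded entries write back unchanged values.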
@@ -343,13 +419,88 @@ def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput:
+    ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
         return self._execute_model(scheduler_output)[1]
 
+    def _modify_prev_results(self):
+        # device_get should return almost immediately: the copy to host was
+        # already scheduled during the previous call, so at worst we briefly
+        # wait for it to finish.
+        assert self._pre_async_results is not None, (
+            "self._pre_async_results must exist when _modify_prev_results() "
+            "is called")
+        pre_req_ids = self._pre_async_results.req_ids
+        pre_next_tokens = self._pre_async_results.next_tokens
+        pre_request_seq_lens = self._pre_async_results.request_seq_lens
+        pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices
+
+        next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens))
+        selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
+                                            1)
+        valid_sampled_token_ids = selected_token_ids.tolist()
+
+        # Mask out the sampled tokens that should be discarded.
+        for i in pre_discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+
+        # Write the real sampled tokens over the placeholders.
+        for pre_req_idx, req_state, _ in pre_request_seq_lens:
+            sampled_ids = valid_sampled_token_ids[pre_req_idx]
+            if not sampled_ids:
+                continue
+
+            # If the request is not active in the *current* batch
+            # (e.g. finished or evicted), skip it.
+            req_id = pre_req_ids[pre_req_idx]
+            if req_id not in self.input_batch.req_id_to_index:
+                continue
+
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            assert req_state is self.requests[
+                req_id], "req_state should be valid and identical"
+
+            # num_tokens_no_spec was already advanced past the placeholder
+            # during the previous execute.
+            end_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            assert len(sampled_ids) == 1, (
+                "speculative decoding is not supported yet")
+            start_idx = end_idx - 1
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            # Replace the previous placeholder.
+            req_state.output_token_ids[-1] = sampled_ids[-1]
+
+    def _update_placeholder(self, discard_sampled_tokens_req_indices,
+                            request_seq_lens):
+        placeholder_req_id_to_index: dict[str, int] = {}
+        discard_sampled_tokens_req_indices_set = set(
+            discard_sampled_tokens_req_indices)
+        for req_idx, req_state, _ in request_seq_lens:
+            if req_idx in discard_sampled_tokens_req_indices_set:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            # Speculative decoding is not supported yet, so assume exactly
+            # one new token.
+            end_idx = start_idx + 1
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            # Advance the CPU token counts now; the real token value is
+            # filled in on the next execute, with the TPU-side substitution
+            # done in _prepare_inputs.
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+
+            # Append a placeholder token; it is overwritten with the real
+            # token on the next execute.
+            req_state.output_token_ids.extend([0])
+
+            placeholder_req_id_to_index[req_state.req_id] = req_idx
+        return placeholder_req_id_to_index
+
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput]:
+    ) -> tuple[AttentionMetadata, ModelRunnerOutput
+               | AsyncTPUModelRunnerOutput]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
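The two helpers above implement a simple two-step protocol: `_update_placeholder()` appends a dummy token for output that is still on the TPU and records which sampler row will hold the real value, and `_modify_prev_results()` patches that dummy on the following step once the host copy has landed. A toy, pure-Python illustration of the bookkeeping (the class and variable names here are illustrative only, not runner code):

class ToyRequestState:
    def __init__(self):
        self.output_token_ids: list[int] = []

req = ToyRequestState()
placeholder_req_id_to_index: dict[str, int] = {}

# Step N (_update_placeholder): the sampled token is still on the TPU,
# so append a placeholder and remember which sampler row holds it.
req.output_token_ids.append(0)
placeholder_req_id_to_index["req-0"] = 0   # row 0 of next_tokens

# Step N+1 (_modify_prev_results): the host copy of next_tokens arrived,
# so overwrite the placeholder with the real token id.
next_tokens_cpu = [1234]
req.output_token_ids[-1] = next_tokens_cpu[placeholder_req_id_to_index["req-0"]]
assert req.output_token_ids == [1234]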
@@ -470,7 +621,7 @@ def _execute_model(
         num_reqs = self.input_batch.num_reqs
 
         # Update the cache state concurrently. Code above will not block until
-        # we use `selected_token_ids`. Add mark_step if post-processing changes
+        # we use `selected_token_ids`. Add mark_step if post-processing changes.
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
         discard_sampled_tokens_req_indices = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
@@ -501,6 +652,51 @@ def _execute_model(
         for req_id in self.input_batch.req_ids[:num_reqs]:
             prompt_logprobs_dict[req_id] = None
 
+        # Async scheduling path.
+        if self.scheduler_config.async_scheduling:
+            # Get the previous results from the TPU and replace the
+            # placeholders.
+            if self._pre_async_results is not None:
+                assert (not self.speculative_config
+                        and spec_decode_metadata is None), (
+                    "Async scheduling does not support speculative decoding "
+                    "yet.")
+                self._modify_prev_results()
+
+            # Set placeholders for the next tokens, which have not been
+            # generated yet.
+            placeholder_req_id_to_index: dict[
+                str, int] = self._update_placeholder(
+                    discard_sampled_tokens_req_indices, request_seq_lens)
+
+            if logprobs is not None:
+                logprobs_lists = logprobs.tolists()
+            else:
+                logprobs_lists = None
+
+            # Save this step's results so the next step can resolve the
+            # placeholders. Kick off the device-to-host copy now so the
+            # later device_get does not block.
+            next_tokens.copy_to_host_async()
+            self._pre_async_results = AsyncPreResults(
+                req_ids=req_ids,
+                next_tokens=next_tokens,
+                request_seq_lens=request_seq_lens,
+                discard_sampled_tokens_req_indices=
+                discard_sampled_tokens_req_indices,
+                placeholder_req_id_to_index=placeholder_req_id_to_index,
+            )
+
+            # Return the model output to the executor; sampled_token_ids is
+            # filled in later by AsyncTPUModelRunnerOutput.get_output().
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=copy.deepcopy(
+                    self.input_batch.req_id_to_index),
+                sampled_token_ids=[],
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict=prompt_logprobs_dict,
+                pooler_output=[],
+                kv_connector_output=kv_connector_output,
+            )
+            async_model_runner_output = AsyncTPUModelRunnerOutput(
+                model_runner_output, next_tokens, num_reqs,
+                discard_sampled_tokens_req_indices)
+            return attn_metadata, async_model_runner_output
+
         if spec_decode_metadata is None:
             next_tokens = np.asarray(jax.device_get(next_tokens))
             selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
@@ -596,6 +792,30 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
         # For each scheduled token, what are the corresponding req index.
         req_indices = np.repeat(self.arange_cpu[:num_reqs],
                                 num_scheduled_tokens_per_req)
+
+        token_in_tpu_cur_input_indices = np.array([])
+        token_in_tpu_pre_next_tokens_indices = np.array([])
+        if (self.scheduler_config.async_scheduling
+                and self._pre_async_results is not None):
+            # If results from the previous async step exist, prepare the
+            # index arrays for the token substitution here. The actual
+            # substitution is performed on the TPU later in this function.
+            token_in_tpu_cur_input_indices_list = []
+            token_in_tpu_pre_next_tokens_indices_list = []
+            acc_cur_len = 0
+            for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+                acc_cur_len += num_scheduled_tokens_per_req[i]
+                assert req_id is not None
+                if req_id not in self._pre_async_results.placeholder_req_id_to_index:
+                    continue
+
+                # The placeholder is the last scheduled token of the request.
+                token_in_tpu_cur_input_indices_list.append(acc_cur_len - 1)
+                token_in_tpu_pre_next_tokens_indices_list.append(
+                    self._pre_async_results.placeholder_req_id_to_index[req_id]
+                )
+
+            if len(token_in_tpu_cur_input_indices_list) > 0:
+                token_in_tpu_cur_input_indices = np.array(
+                    token_in_tpu_cur_input_indices_list)
+                token_in_tpu_pre_next_tokens_indices = np.array(
+                    token_in_tpu_pre_next_tokens_indices_list)
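The async path relies on overlapping the device-to-host copy of `next_tokens` with host-side bookkeeping: the copy is started without blocking, and the later `jax.device_get()` in `get_output()` or `_modify_prev_results()` then resolves quickly. A standalone sketch of that pattern, assuming any JAX backend; the array is just a stand-in for `next_tokens`:

import jax
import jax.numpy as jnp
import numpy as np

next_tokens = jnp.arange(8, dtype=jnp.int32) * 3  # stand-in for sampler output

# Start the device-to-host transfer without blocking the Python thread.
next_tokens.copy_to_host_async()

# ... host-side work happens here (building ModelRunnerOutput, etc.) ...

# By the time the values are needed, the copy has usually finished, so this
# device_get returns quickly instead of stalling on the transfer.
next_tokens_cpu = np.asarray(jax.device_get(next_tokens))
print(next_tokens_cpu[:4])  # [0 3 6 9]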
 
         # Get batched arange.
         # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
@@ -701,6 +921,21 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
             self.mesh, (input_ids, positions, block_tables, query_start_loc,
                         seq_lens, logits_indices, request_distribution))
 
+        if self.scheduler_config.async_scheduling and len(
+                token_in_tpu_cur_input_indices) > 0:
+            assert self._pre_async_results is not None
+            idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
+            padded_token_in_tpu_cur_input_indices = np.pad(
+                token_in_tpu_cur_input_indices, (0, idx_pad_len),
+                mode='constant', constant_values=-1)
+            padded_token_in_tpu_pre_next_tokens_indices = np.pad(
+                token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len),
+                mode='constant', constant_values=-1)
+            with self.maybe_forbid_compile:
+                input_ids = self._substitute_placeholder_token_fn(
+                    input_ids, padded_token_in_tpu_cur_input_indices,
+                    padded_token_in_tpu_pre_next_tokens_indices,
+                    self._pre_async_results.next_tokens,
+                    len(token_in_tpu_cur_input_indices))
+
         if self.lora_config is not None:
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
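For reference, the end-to-end switch exercised by the tests in this patch is just the engine argument. A minimal usage sketch, with the model name and prompt style taken from the test file (the generated text itself will vary):

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          max_model_len=800,
          async_scheduling=1)  # 0 keeps the synchronous reference scheduler

outputs = llm.generate(["Keep repeating: a a a a"],
                       SamplingParams(temperature=0, max_tokens=16))
print(outputs[0].outputs[0].text)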