236 changes: 236 additions & 0 deletions tests/e2e/test_async_scheduler.py
@@ -0,0 +1,236 @@
from __future__ import annotations

import random
import string
import time

import pytest
from vllm import LLM, SamplingParams

@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0,
                          max_tokens=120,
                          ignore_eos=True,
                          repetition_penalty=1,
                          frequency_penalty=0,
                          presence_penalty=0,
                          min_p=0,
                          logprobs=None)


@pytest.fixture
def model_name():
    return "Qwen/Qwen2.5-1.5B-Instruct"

def get_performance_test_prompts():
    """
    Generates a list of prompts, each with a fixed word count.

    Returns:
        A list of num_prompts strings, where each prompt contains
        input_len_words repeated words after the "Keep repeating: " prefix.
    """
    num_prompts = 500
    input_len_words = 120
    prompts = []

    # For example, if w = 's', the generated prompt will be
    # "Keep repeating: s s s ..."
    num_repetitions = input_len_words
    prefix = "Keep repeating: "

    for _ in range(num_prompts):
        # 1. Pick a random lowercase letter
        w = random.choice(list(string.ascii_lowercase))

        # 2. Create the string of repeated words
        #    This will have (num_repetitions) words
        repeating_part = " ".join([w] * num_repetitions)

        # 3. Combine with the prefix
        prompts.append(f"{prefix}{repeating_part}")

    return prompts

def get_correctness_test_prompts():
    """
    Returns a static list of prompts designed to test a model's
    ability to follow complex instructions and ensure correctness.

    Returns:
        A list of strings, where each string is a test prompt.
    """

    prompts = [
        (
            "Write a short story about a librarian who discovers a book that "
            "writes itself. Write it in 1900s English style. Make sure there "
            "are no mistakes. This is my homework and I want perfection."
        ),
        (
            "Compose a poem about the sound of a city at night. Write it in "
            "Shakespearean style. Make sure there are no mistakes. This is my "
            "homework and I want perfection."
        ),
        (
            "Write a dialogue between a time traveler and a medieval blacksmith "
            "who is skeptical of their claims. Make sure there are no mistakes."
        ),
        (
            "Explain the process of photosynthesis as if to a 5th grader, "
            "but without losing any scientific accuracy. Every step must be "
            "correct and in the right order. I will be checking this against a textbook."
        ),
        (
            "Write a Python function that finds the median of a list of numbers. "
            "It must correctly handle both even and odd-sized lists, "
            "as well as unsorted lists. Provide a perfect, bug-free "
            "implementation. I will be running unit tests on it."
        ),
        (
            "List the first 10 presidents of the United States. Format the "
            "output as a JSON array, where each object has two keys: 'name' "
            "and 'term_years'. The JSON must be perfectly valid, and all "
            "names and dates must be 100% accurate. This is for a production system."
        ),
    ]

    return prompts

def _test_performance_helper(
        monkeypatch: pytest.MonkeyPatch,
        sampling_config: SamplingParams,
        model_name: str,
        min_speedup: float
):
    '''
    Helper function to test async scheduler decoding performance.
    Compares timing between reference LLM and async LLM using Qwen2.5-1.5B.
    '''

    with monkeypatch.context():
        # Generate the performance-test prompts
        # (num_prompts=500, input_len_words=120).
        test_prompts = get_performance_test_prompts()

        # Test reference LLM timing
        ref_llm = LLM(model=model_name,
                      max_model_len=800,
                      max_num_seqs=24,
                      max_num_batched_tokens=512,
                      enable_prefix_caching=False,
                      async_scheduling=0)

        start_time = time.time()
        _ = ref_llm.generate(test_prompts, sampling_config)
        ref_time = time.time() - start_time

        del ref_llm
        # Waiting for TPUs to be released
        time.sleep(10)

        # Test async LLM timing with the same engine configuration
        async_llm = LLM(model=model_name,
                        max_model_len=800,
                        max_num_seqs=24,
                        max_num_batched_tokens=512,
                        enable_prefix_caching=False,
                        async_scheduling=1)

        start_time = time.time()
        _ = async_llm.generate(test_prompts, sampling_config)
        async_time = time.time() - start_time

        del async_llm
        # Waiting for TPUs to be released
        time.sleep(10)

        speedup = ref_time / async_time
        print(f"Reference LLM time: {ref_time:.2f}s")
        print(f"Async LLM time: {async_time:.2f}s")
        print(f"Speedup: {speedup:.2f}x")

        assert speedup >= min_speedup, (
            f"Expected at least {min_speedup}x speedup for async scheduler, "
            f"got {speedup:.2f}x")

def test_performance(
        monkeypatch: pytest.MonkeyPatch,
        sampling_config: SamplingParams,
        model_name: str,
):
    '''
    Test that async scheduler decoding provides significant performance improvement.
    Compares timing between reference LLM and async LLM using Qwen2.5-1.5B.
    Expects async_llm to be at least 1.3x faster than ref_llm.
    '''
    min_speedup = 1.3
    _test_performance_helper(
        monkeypatch, sampling_config, model_name, min_speedup)


def _test_correctness_helper(
        monkeypatch: pytest.MonkeyPatch,
        sampling_config: SamplingParams,
        model_name: str,
):
    '''
    Helper function to test async scheduler correctness.
    The outputs of the original LLM and the async LLM should be
    identical when async scheduler decoding is used.

    Known edge case (KV cache swapping):
    In this case, even though the temperature is set to 0, the output
    still differs slightly from run to run. This is expected behaviour,
    as the normal scheduler behaves the same way, which makes it
    difficult to design a test for this scenario.
    '''
    with monkeypatch.context():
        test_prompts = get_correctness_test_prompts()

        ref_llm = LLM(model=model_name,
                      max_model_len=1024,
                      max_num_seqs=100,
                      async_scheduling=0)
        ref_outputs = ref_llm.generate(test_prompts, sampling_config)

        del ref_llm

        # Waiting for TPUs to be released.
        time.sleep(10)

        async_llm = LLM(model=model_name,
                        max_model_len=1024,
                        max_num_seqs=100,
                        async_scheduling=1)
        async_outputs = async_llm.generate(test_prompts, sampling_config)

        matches = 0
        misses = 0
        for ref_output, async_output in zip(ref_outputs, async_outputs):
            # Print both outputs so mismatches are easy to inspect in the logs.
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"async_output: {async_output.outputs[0].text}")
            if ref_output.outputs[0].text == async_output.outputs[0].text:
                matches += 1
            else:
                misses += 1

        assert misses == 0
        del async_llm
        del async_outputs

        # Waiting for TPUs to be released.
        time.sleep(10)


def test_correctness(
        monkeypatch: pytest.MonkeyPatch,
        sampling_config: SamplingParams,
        model_name: str,
):
    '''
    The outputs of the original LLM and the async LLM should be the same
    when using the async scheduler.
    '''

    _test_correctness_helper(
        monkeypatch, sampling_config, model_name)

37 changes: 37 additions & 0 deletions tpu_inference/runner/compilation_manager.py
@@ -75,6 +75,8 @@ def capture_model(self) -> None:
        self._precompile_backbone_text_only()
        if self.runner.is_multimodal_model:
            self._precompile_backbone_with_inputs_embeds()
        if self.runner.scheduler_config.async_scheduling:
            self._precompile_substitute_placeholder_token()
        self._precompile_select_from_array()
        self._precompile_compute_logits()
        self._precompile_disagg_utils()
@@ -148,6 +150,41 @@ def model_fn_wrapper(
            num_tokens=num_tokens,
        )

    def _precompile_substitute_placeholder_token(self) -> None:
        """Precompiles the token substitution function for all expected input shapes.

        It iterates through all potential padded token lengths
        (`num_tokens_paddings`) and request batch sizes (`num_reqs_paddings`)
        that the scheduler is expected to handle, ensuring a compiled version
        is ready for each combination.
        """

        for num_tokens in self.runner.num_tokens_paddings:
            padded_token_in_tpu_cur_input_indices = np.zeros((num_tokens, ),
                                                             dtype=np.int32)
            padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
                (num_tokens, ), dtype=jnp.int32)
            for num_reqs in self.runner.num_reqs_paddings:
                input_ids = self._create_dummy_tensor((num_tokens, ),
                                                      jnp.int32)
                # Needs to align with the sampling output.
                next_tokens = self._create_dummy_tensor(
                    (num_reqs, ),
                    jnp.int32,
                    sharding=NamedSharding(self.runner.mesh, PartitionSpec()))
                placeholder_num = 1
                self._run_compilation(
                    "_substitute_placeholder_token_fn",
                    self.runner._substitute_placeholder_token_fn,
                    input_ids,
                    padded_token_in_tpu_cur_input_indices,
                    padded_token_in_tpu_pre_next_tokens_indices,
                    next_tokens,
                    placeholder_num,
                    num_tokens=num_tokens,
                    num_reqs=num_reqs,
                )

    def _precompile_backbone_text_only(self) -> None:
        for num_tokens in self.runner.num_tokens_paddings:
            input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
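Note: the diff above precompiles `self.runner._substitute_placeholder_token_fn` for every padding combination but does not include the function body itself. As a rough sketch only, assuming the function scatters the freshly sampled `next_tokens` into the placeholder slots of the padded `input_ids` buffer; the name `substitute_placeholder_tokens`, the jit setup, and the exact argument semantics below are assumptions, not the repository's implementation:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("placeholder_num", ))
def substitute_placeholder_tokens(input_ids, cur_input_indices,
                                  pre_next_tokens_indices, next_tokens,
                                  placeholder_num):
    # Positions in the padded input_ids buffer that still hold placeholder
    # tokens; only the first `placeholder_num` entries are meaningful.
    rows = cur_input_indices[:placeholder_num]
    # Gather the sampled token for each of those positions from the
    # previous step's sampling output.
    vals = next_tokens[pre_next_tokens_indices[:placeholder_num]]
    # Scatter the sampled tokens into the input buffer.
    return input_ids.at[rows].set(vals)
```

Each `(num_tokens, num_reqs)` padding combination yields a distinct input shape and therefore its own XLA compilation, which is what the precompile loop above warms up ahead of serving.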