
Commit 891d76c

committed
pytests
1 parent 885a3f2 commit 891d76c

File tree

1 file changed: +234 -0 lines changed


tests/e2e/test_async_scheduler.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
from __future__ import annotations

import random
import string
import time

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0,
                          max_tokens=120,
                          ignore_eos=True,
                          repetition_penalty=1,
                          frequency_penalty=0,
                          presence_penalty=0,
                          min_p=0,
                          logprobs=None)


@pytest.fixture
def model_name():
    return "Qwen/Qwen2.5-1.5B-Instruct"


def get_performance_test_prompts():
    """
    Generates a list of prompts, each with a fixed word count.

    Returns:
        A list of num_prompts strings, where each prompt contains
        roughly input_len_words words.
    """
    num_prompts = 500
    input_len_words = 120
    prompts = []

    # For example, with w = 's' the generated prompt will be
    # "Keep repeating: s s s ..."
    num_repetitions = input_len_words
    prefix = "Keep repeating: "

    for _ in range(num_prompts):
        # 1. Pick a random lowercase letter
        w = random.choice(list(string.ascii_lowercase))

        # 2. Create the string of repeated words
        # This will have (num_repetitions) words
        repeating_part = " ".join([w] * num_repetitions)

        # 3. Combine with the prefix (if any)
        print(f"{prefix}{repeating_part}")
        prompts.append(f"{prefix}{repeating_part}")

    return prompts


def get_correctness_test_prompts():
    """
    Returns a static list of prompts designed to test a model's
    ability to follow complex instructions and ensure correctness.

    Returns:
        A list of strings, where each string is a test prompt.
    """

    prompts = [
        (
            "Write a short story about a librarian who discovers a book that "
            "writes itself. Write it in 1900s English style. Make sure there "
            "are no mistakes. This is my homework and I want perfection."
        ),
        (
            "Compose a poem about the sound of a city at night. Write it in "
            "Shakespeare style. Make sure there are no mistakes. This is my "
            "homework and I want perfection."
        ),
        (
            "Write a dialogue between a time traveler and a medieval blacksmith "
            "who is skeptical of their claims. Make sure there are no mistakes."
        ),
        (
            "Explain the process of photosynthesis as if to a 5th grader, "
            "but without losing any scientific accuracy. Every step must be "
            "correct and in the right order. I will be checking this against "
            "a textbook."
        ),
        (
            "Write a Python function that finds the median of a list of numbers. "
            "It must correctly handle both even and odd-sized lists, "
            "as well as unsorted lists. Provide a perfect, bug-free "
            "implementation. I will be running unit tests on it."
        ),
        (
            "List the first 10 presidents of the United States. Format the "
            "output as a JSON array, where each object has two keys: 'name' "
            "and 'term_years'. The JSON must be perfectly valid, and all "
            "names and dates must be 100% accurate. This is for a production "
            "system."
        )
    ]

    return prompts


def _test_performance_helper(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
    min_speedup: float
):
    '''
    Helper function to test async scheduler decoding performance.
    Compares timing between a reference LLM and an async LLM using Qwen2.5-1.5B.
    '''

    with monkeypatch.context():
        # Generate the prompts used for performance testing
        test_prompts = get_performance_test_prompts()  # num_prompts=500, input_len_words=120

        # Test reference LLM timing
        ref_llm = LLM(model=model_name,
                      max_model_len=800,
                      max_num_seqs=24,
                      max_num_batched_tokens=512,
                      enable_prefix_caching=False)

        start_time = time.time()
        _ = ref_llm.generate(test_prompts, sampling_config)
        ref_time = time.time() - start_time

        del ref_llm
        # Waiting for TPUs to be released
        time.sleep(10)

        # Test async LLM timing
        async_llm = LLM(model=model_name,
                        max_model_len=800,
                        max_num_seqs=24,
                        max_num_batched_tokens=512,
                        enable_prefix_caching=False,
                        async_scheduling=1)

        start_time = time.time()
        _ = async_llm.generate(test_prompts, sampling_config)
        async_time = time.time() - start_time

        del async_llm
        # Waiting for TPUs to be released
        time.sleep(10)

        speedup = ref_time / async_time
        print(f"Reference LLM time: {ref_time:.2f}s")
        print(f"Async LLM time: {async_time:.2f}s")
        print(f"Speedup: {speedup:.2f}x")

        assert speedup >= min_speedup, (
            f"Expected at least {min_speedup}x speedup for async scheduler, "
            f"got {speedup:.2f}x")


def test_performance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Test that async scheduler decoding provides a significant performance
    improvement. Compares timing between a reference LLM and an async LLM
    using Qwen2.5-1.5B. Expects the async LLM to be at least 1.3x faster
    than the reference LLM.
    '''
    min_speedup = 1.3
    _test_performance_helper(
        monkeypatch, sampling_config, model_name, min_speedup)


def _test_correctness_helper(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Helper function to test async scheduler correctness.
    The outputs of the original LLM and the async LLM should be the same
    when using async scheduler decoding.

    Known Edge Case (KV Cache Swapping):
    In this case, even though the temperature is set to 0,
    the output is still slightly different every time.
    This is expected behaviour, as the normal scheduler behaves
    the same way; hence, it is difficult to design a test
    for such a scenario.
    '''
    with monkeypatch.context():
        test_prompts = get_correctness_test_prompts()

        ref_llm = LLM(model=model_name,
                      max_model_len=1024,
                      max_num_seqs=100)
        ref_outputs = ref_llm.generate(test_prompts, sampling_config)

        del ref_llm

        # Waiting for TPUs to be released.
        time.sleep(10)

        async_llm = LLM(model=model_name,
                        max_model_len=1024,
                        max_num_seqs=100,
                        async_scheduling=1)
        async_outputs = async_llm.generate(test_prompts, sampling_config)

        matches = 0
        misses = 0
        for ref_output, async_output in zip(ref_outputs, async_outputs):
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"async_output: {async_output.outputs[0].text}")
            if ref_output.outputs[0].text == async_output.outputs[0].text:
                matches += 1
            else:
                misses += 1

        assert misses == 0
        del async_outputs
        del async_llm

        # Waiting for TPUs to be released.
        time.sleep(10)


def test_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    The outputs of the original LLM and the async LLM should be the same
    when using the async scheduler.
    '''

    _test_correctness_helper(
        monkeypatch, sampling_config, model_name)
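
Usage note: a minimal, hypothetical sketch of invoking these end-to-end tests, assuming vLLM and the target TPU environment are available and the repository root is the working directory; the runner filename and the -s flag (which only surfaces the timing printouts) are illustrative, not part of this commit.

# run_async_scheduler_tests.py -- hypothetical runner script
import sys

import pytest

if __name__ == "__main__":
    # -s keeps the timing/speedup printouts visible in the console.
    sys.exit(pytest.main(["-s", "tests/e2e/test_async_scheduler.py"]))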
