
Commit 8d856cf

[Profiling] Pull Over the TPU Profiler from vLLM + add profiling docs (#882)
Signed-off-by: Jacob Platin <[email protected]>
1 parent 42f7aac commit 8d856cf

File tree: 3 files changed, +220 -1 lines changed


docs/profiling.md

Lines changed: 105 additions & 0 deletions
# Profiling

There are currently three ways to profile your workload:

## Using `examples/tpu_profiling.py`

### vLLM TPU Profiling Script

This script is a utility for profiling the performance of the vLLM engine on TPU VMs. It uses the JAX profiler to capture detailed performance traces.

The profiling results can be visualized using tools like TensorBoard (with the `tensorboard-plugin-profile` package) or Perfetto UI.

### How to Use

#### Prerequisites
You must install the TensorBoard profile plugin to visualize the results:

```bash
pip install tensorboard-plugin-profile
```

#### Basic Command
The script is run from the command line, specifying the workload parameters and any necessary vLLM engine arguments.

```bash
python3 examples/tpu_profiling.py --model <your-model-name> [OPTIONS]
```

#### Key Arguments
* `--model`: (Required) The name or path of the model to profile.
* `--input-len`: The length of the input prompt, in tokens, per request.
* `--output-len`: The number of tokens to generate per request.
* `--batch-size`: The number of requests.
* `--profile-result-dir`: The directory where the JAX profiler output will be saved.
* The script also accepts all standard vLLM `EngineArgs` (e.g., `--tensor-parallel-size`, `--dtype`); see the third example below.

#### Examples

**1. Profile a Prefill Operation:**
To profile a single request with a long input prompt (e.g., 1024 tokens), set `--input-len` high and `--batch-size` to 1.

```bash
python3 examples/tpu_profiling.py \
    --model google/gemma-2b \
    --input-len 1024 \
    --output-len 1 \
    --batch-size 1
```

**2. Profile a Decoding Operation:**
To profile a large batch of single-token decoding steps, set `--input-len` and `--output-len` to 1 and use a large `--batch-size`.

```bash
python3 examples/tpu_profiling.py \
    --model google/gemma-2b \
    --input-len 1 \
    --output-len 1 \
    --batch-size 256
```
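**3. Profile with Engine Arguments (illustrative):**
Because the script also accepts standard vLLM `EngineArgs`, workload flags can be combined with engine flags. The values below are placeholders; adapt the model, parallelism degree, and output directory to your setup.

```bash
python3 examples/tpu_profiling.py \
    --model google/gemma-2b \
    --input-len 1024 \
    --output-len 128 \
    --batch-size 8 \
    --tensor-parallel-size 4 \
    --profile-result-dir profiles/gemma-2b-tp4
```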
## Using `PHASED_PROFILING_DIR`
If you set the following environment variable:

```
PHASED_PROFILING_DIR=<DESIRED PROFILING OUTPUT DIR>
```

we will automatically capture profiles during three phases of your workload (assuming they are encountered):
1. Prefill-heavy (the ratio of prefill tokens to total scheduled tokens for the given batch is >= 0.9)
2. Decode-heavy (the ratio of prefill tokens to total scheduled tokens for the given batch is <= 0.2)
3. Mixed (the ratio of prefill tokens to total scheduled tokens for the given batch is between 0.4 and 0.6)

To aid in your analysis, we will also log the batch composition for the profiled batches.
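For example, a hypothetical serving run that captures phased profiles (the output directory and model name are placeholders):

```bash
# Profiles for the prefill-heavy, decode-heavy, and mixed phases are written under this directory.
PHASED_PROFILING_DIR=/tmp/phased_profiles vllm serve <your-model-name>
```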
## Using `USE_JAX_PROFILER_SERVER`
If you set the following environment variable:

```
USE_JAX_PROFILER_SERVER=True
```

you can instead manually decide when to capture a profile and for how long. This can be helpful if your workload (e.g. E2E benchmarking) is large and profiling the entire run (i.e. using the method above) would generate a massive tracing file.

You can additionally set the desired profiling port (default is `9999`):

```
JAX_PROFILER_SERVER_PORT=XXXX
```

To use this approach, do the following (an illustrative sketch of these commands follows the list):

1. Run your typical `vllm serve` or `offline_inference` command (making sure to set `USE_JAX_PROFILER_SERVER=True`)
2. Run your benchmarking command (`python benchmark_serving.py...`)
3. Once the warmup has completed and your benchmark is running, start a new TensorBoard instance with your `logdir` set to the desired output location of your profiles (e.g. `tensorboard --logdir=profiles/llama3-mmlu/`)
4. Open the TensorBoard instance and navigate to the `profile` page (e.g. `http://localhost:6006/#profile`)
5. Click `Capture Profile` and, in the `Profile Service URL(s) or TPU name` box, enter `localhost:XXXX`, where `XXXX` is your `JAX_PROFILER_SERVER_PORT` (default is `9999`)
6. Enter the desired amount of time (in ms)
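For illustration, a sketch of this workflow as shell commands; the model name and profile directory are placeholders, and your benchmark arguments will differ:

```bash
# Terminal 1: serve the model with the JAX profiler server enabled (port 9999 by default).
USE_JAX_PROFILER_SERVER=True JAX_PROFILER_SERVER_PORT=9999 vllm serve <your-model-name>

# Terminal 2: run your usual benchmarking command against the server.
python benchmark_serving.py ...

# Terminal 3: once warmup has finished, point TensorBoard at your profile output directory,
# open http://localhost:6006/#profile, click "Capture Profile", enter localhost:9999 as the
# profile service URL, and set the desired capture duration in ms.
tensorboard --logdir=profiles/llama3-mmlu/
```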

examples/tpu_profiling.py

Lines changed: 114 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Implements profiling for vLLM on TPU VMs using the JAX profiler.
# NOTE: you will need the tensorboard-plugin-profile python package to
# visualize the results in TensorBoard.
# Please see docs/profiling.md for more details.
# Usage example for prefilling 1 request of 1024 tokens:
# python3 examples/tpu_profiling.py --input-len 1024 --output-len 1 --batch-size 1
# Usage example for decoding 256 requests of 1 token each:
# python3 examples/tpu_profiling.py --input-len 1 --output-len 1 --batch-size=256

import argparse
import dataclasses
import os
import time

import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.utils import FlexibleArgumentParser

DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000))
DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0))


def main(args: argparse.Namespace):
    print(args)

    # Profile
    profile_dir = args.profile_result_dir
    print(f"Profiling (results will be saved to '{profile_dir}')...")
    os.environ["VLLM_TORCH_PROFILER_DIR"] = profile_dir

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    sampling_params = SamplingParams(
        temperature=0.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: list[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def run_to_completion():
        start_time = time.perf_counter()
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        end_time = time.perf_counter()
        latency = end_time - start_time
        return latency

    # Warmup
    print("Warming up...")
    warmup_latencies = []
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        warmup_latencies.append(run_to_completion())
    print(f"Average warmup latency: {np.mean(warmup_latencies):.4f}s")

    # Enable tracing on server
    llm.start_profile()
    if DELAY_MS == 0:
        time.sleep(1.0)
    profile_latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profile iterations"):
        profile_latencies.append(run_to_completion())
    llm.stop_profile()
    print(f"Average profile latency: {np.mean(profile_latencies):.4f}s")

    return


def parse_args():
    parser = FlexibleArgumentParser(
        description="Benchmark the latency of processing a single batch of "
        "requests till completion.")
    parser.add_argument("--input-len", type=int, default=32)
    parser.add_argument("--output-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=5,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=1,
        help="Number of iterations to run for profiling.",
    )
    parser.add_argument(
        "--profile-result-dir",
        type=str,
        default="profiles",
        help=("path to save the JAX profiler output. Can be visualized "
              "with ui.perfetto.dev, Tensorboard, or XProf"),
    )

    parser = EngineArgs.add_cli_args(parser)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)

tpu_inference/runner/utils.py

Lines changed: 1 addition & 1 deletion
@@ -414,7 +414,7 @@ def step(self, batch_composition_stats: dict) -> None:
         have_seen_all_phases = all(self.inference_phase_seen.values())
         # We want to start profiling only after the first trial request
         is_past_initial_request = batch_composition_stats[
-            "num_reqs"] >= 1 and batch_composition_stats[
+            "num_reqs"] > 1 and batch_composition_stats[
                 "total_num_scheduled_tokens"] > 1
         if is_past_initial_request and (not have_seen_all_phases
                                         or self.current_phase != ""):
