Skip to content

Commit 774c5fd

Browse files
authored
[V1] fix torch profiling for V1 offline scenarios (vllm-project#18445)
Signed-off-by: Divakar Verma <[email protected]>
1 parent 9a21e33 commit 774c5fd

File tree

4 files changed: 23 additions (+23) and 51 deletions (-51)

4 files changed

+23
-51
lines changed

benchmarks/benchmark_latency.py

Lines changed: 10 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,12 @@
66
import json
77
import os
88
import time
9-
from pathlib import Path
109
from typing import Any, Optional
1110

1211
import numpy as np
13-
import torch
1412
from tqdm import tqdm
1513

14+
import vllm.envs as envs
1615
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
1716
from vllm import LLM, SamplingParams
1817
from vllm.engine.arg_utils import EngineArgs
@@ -80,17 +79,9 @@ def llm_generate():
8079

8180
def run_to_completion(profile_dir: Optional[str] = None):
8281
if profile_dir:
83-
with torch.profiler.profile(
84-
activities=[
85-
torch.profiler.ProfilerActivity.CPU,
86-
torch.profiler.ProfilerActivity.CUDA,
87-
],
88-
on_trace_ready=torch.profiler.tensorboard_trace_handler(
89-
str(profile_dir)
90-
),
91-
) as p:
92-
llm_generate()
93-
print(p.key_averages().table(sort_by="self_cuda_time_total"))
82+
llm.start_profile()
83+
llm_generate()
84+
llm.stop_profile()
9485
else:
9586
start_time = time.perf_counter()
9687
llm_generate()
@@ -103,11 +94,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
10394
run_to_completion(profile_dir=None)
10495

10596
if args.profile:
106-
profile_dir = args.profile_result_dir
107-
if not profile_dir:
108-
profile_dir = (
109-
Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
110-
)
97+
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
11198
print(f"Profiling (results will be saved to '{profile_dir}')...")
11299
run_to_completion(profile_dir=profile_dir)
113100
return
@@ -164,15 +151,6 @@ def run_to_completion(profile_dir: Optional[str] = None):
164151
action="store_true",
165152
help="profile the generation process of a single batch",
166153
)
167-
parser.add_argument(
168-
"--profile-result-dir",
169-
type=str,
170-
default=None,
171-
help=(
172-
"path to save the pytorch profiler output. Can be visualized "
173-
"with ui.perfetto.dev or Tensorboard."
174-
),
175-
)
176154
parser.add_argument(
177155
"--output-json",
178156
type=str,
@@ -193,4 +171,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
193171
# numbers. We need to disable prefix caching by default.
194172
parser.set_defaults(enable_prefix_caching=False)
195173
args = parser.parse_args()
174+
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
175+
raise OSError(
176+
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
177+
"Please set it to a valid path to use torch profiler."
178+
)
196179
main(args)

vllm/benchmarks/latency.py

Lines changed: 9 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,12 @@
66
import json
77
import os
88
import time
9-
from pathlib import Path
109
from typing import Any, Optional
1110

1211
import numpy as np
13-
import torch
1412
from tqdm import tqdm
1513

14+
import vllm.envs as envs
1615
from vllm import LLM, SamplingParams
1716
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
1817
write_to_json)
@@ -59,13 +58,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
5958
action="store_true",
6059
help="profile the generation process of a single batch",
6160
)
62-
parser.add_argument(
63-
"--profile-result-dir",
64-
type=str,
65-
default=None,
66-
help=("path to save the pytorch profiler output. Can be visualized "
67-
"with ui.perfetto.dev or Tensorboard."),
68-
)
6961
parser.add_argument(
7062
"--output-json",
7163
type=str,
@@ -87,7 +79,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
8779

8880
def main(args: argparse.Namespace):
8981
print(args)
90-
82+
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
83+
raise OSError(
84+
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
85+
"Please set it to a valid path to use torch profiler.")
9186
engine_args = EngineArgs.from_cli_args(args)
9287

9388
# NOTE(woosuk): If the request cannot be processed in a single batch,
@@ -131,16 +126,9 @@ def llm_generate():
131126

132127
def run_to_completion(profile_dir: Optional[str] = None):
133128
if profile_dir:
134-
with torch.profiler.profile(
135-
activities=[
136-
torch.profiler.ProfilerActivity.CPU,
137-
torch.profiler.ProfilerActivity.CUDA,
138-
],
139-
on_trace_ready=torch.profiler.tensorboard_trace_handler(
140-
str(profile_dir)),
141-
) as p:
142-
llm_generate()
143-
print(p.key_averages().table(sort_by="self_cuda_time_total"))
129+
llm.start_profile()
130+
llm_generate()
131+
llm.stop_profile()
144132
else:
145133
start_time = time.perf_counter()
146134
llm_generate()
@@ -153,10 +141,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
153141
run_to_completion(profile_dir=None)
154142

155143
if args.profile:
156-
profile_dir = args.profile_result_dir
157-
if not profile_dir:
158-
profile_dir = (Path(".") / "vllm_benchmark_result" /
159-
f"latency_result_{time.time()}")
144+
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
160145
print(f"Profiling (results will be saved to '{profile_dir}')...")
161146
run_to_completion(profile_dir=profile_dir)
162147
return

vllm/v1/worker/gpu_worker.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -292,6 +292,8 @@ def profile(self, is_start: bool = True):
292292
self.profiler.start()
293293
else:
294294
self.profiler.stop()
295+
print(self.profiler.key_averages().table(
296+
sort_by="self_cuda_time_total"))
295297

296298
def execute_dummy_batch(self) -> None:
297299
self.model_runner._dummy_run(1)

vllm/worker/worker.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,8 @@ def stop_profile(self):
128128
if self.profiler is None:
129129
raise RuntimeError("Profiler is not enabled.")
130130
self.profiler.stop()
131+
print(
132+
self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
131133

132134
def sleep(self, level: int = 1) -> None:
133135
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

0 commit comments

Comments
 (0)