skyrl-tx/benchmarks/benchmark_memory.py (199 changes: 134 additions & 65 deletions)
@@ -10,6 +10,8 @@
Features:
- Test sampling, training, or both modes
- Sweep across multiple batch sizes and sequence lengths
- Separate JIT compilation time vs post-JIT runtime measurement
- Configurable number of post-JIT measurement iterations for averaging
- Early termination: skips remaining batch sizes if one fails (e.g., OOM)
- GPU memory monitoring via nvidia-smi polling
- Per-test server logs with JIT compilation time extraction
@@ -39,10 +41,14 @@
uv run --extra tinker python benchmarks/benchmark_memory.py \\
--backend-config '{"loss_chunk_size": 512, "enforce_eager": true}'

# Run with multiple measurement iterations for more accurate post-JIT timing
uv run --extra tinker python benchmarks/benchmark_memory.py \\
--num-measurement-iters 3 --batch-sizes 8 --seq-lens 4096

Output directory (default: /tmp/skyrl_tx_memory_benchmark/):
tx_memory_benchmark_{experiment_name}_{timestamp}/
config.json # Full benchmark configuration (JSON)
results.csv # Results table (mode, batch, seq, status, peak_mem, e2e_time)
results.csv # Results table (mode, batch, seq, status, peak_mem, jit_time, post_jit_time)
tinker.db # SQLite database used by tinker API
server_*.log # Server stdout/stderr for each test run
xla_dump_*/ # XLA HLO graphs per test (if --dump-xla enabled)
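For downstream analysis, results.csv is plain CSV; a minimal loading sketch using only the standard library (column names match CSV_HEADER in ResultsWriter below; the run directory name is a placeholder, real runs embed the experiment name and a timestamp):

import csv
from pathlib import Path

# Placeholder path; real runs create a timestamped directory under
# /tmp/skyrl_tx_memory_benchmark/.
run_dir = Path("/tmp/skyrl_tx_memory_benchmark/tx_memory_benchmark_example_20240101_000000")

with (run_dir / "results.csv").open() as f:
    for row in csv.DictReader(f):
        # Header: mode, batch_size, seq_len, status, peak_gpu_mem_mib,
        # jit_e2e_sec, post_jit_e2e_sec
        print(row["mode"], row["batch_size"], row["seq_len"], row["status"],
              row["peak_gpu_mem_mib"], row["jit_e2e_sec"], row["post_jit_e2e_sec"])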
@@ -58,6 +64,7 @@
import csv
import json
import os
import random
import re
import signal
import subprocess
@@ -67,7 +74,7 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Literal
from typing import Callable, Literal

import httpx
import tinker
@@ -96,6 +103,7 @@ class BenchmarkConfig:
test_mode: Literal["sample", "train", "both"] = "both"
batch_sizes: list[int] = field(default_factory=lambda: [4, 8, 16, 32])
seq_lens: list[int] = field(default_factory=lambda: [8192])
num_measurement_iters: int = 3 # Number of post-JIT measurement iterations
server_only: bool = False

# Runtime configuration
@@ -131,7 +139,8 @@ class TestResult:
status: Literal["PASS", "FAIL", "ERROR"]
peak_gpu_mem_mib: int
jit_logs: list[str]
client_e2e_sec: float | None
jit_e2e_sec: float | None # First request (includes JIT compilation)
post_jit_e2e_sec: float | None # Average of subsequent requests (post-JIT)
error_message: str | None = None


@@ -426,9 +435,10 @@ def __init__(self, config: BenchmarkConfig):
self.config = config
self.tokenizer = AutoTokenizer.from_pretrained(config.base_model)

def _make_datum(self, seq_len: int) -> types.Datum:
"""Create a training datum with specified sequence length."""
all_tokens = list(range(1, seq_len + 1))
def _make_datum(self, seq_len: int, rng: random.Random) -> types.Datum:
"""Create a training datum with random tokens."""
vocab_size = self.tokenizer.vocab_size
all_tokens = [rng.randint(1, vocab_size - 1) for _ in range(seq_len)]
target_tokens = all_tokens[1:] + [self.tokenizer.eos_token_id]
weights = [1.0] * seq_len

@@ -440,57 +450,99 @@ def _make_datum(self, seq_len: int) -> types.Datum:
},
)

def _test_sample(self, service_client, server: ServerManager, batch_size: int, seq_len: int) -> tuple[bool, float]:
"""Execute sampling test."""
def _run_timed_requests(
self,
run_single_request: Callable[[], tuple[bool, float]],
num_measurement_iters: int,
) -> tuple[bool, float, float]:
"""Run timed requests with JIT warmup and measurement iterations.

Args:
run_single_request: Callable that executes a single request and returns (success, elapsed_time)
num_measurement_iters: Number of post-JIT measurement iterations

Returns:
Tuple of (success, jit_time, post_jit_avg_time)
"""
# First request triggers JIT compilation
print(" Running warmup request (JIT compilation)...")
success, jit_time = run_single_request()
if not success:
return False, jit_time, 0.0

# Subsequent requests measure post-JIT performance
post_jit_times = []
for i in range(num_measurement_iters):
print(f" Running measurement request {i + 1}/{num_measurement_iters}...")
success, elapsed = run_single_request()
if not success:
return False, jit_time, 0.0
post_jit_times.append(elapsed)

avg_post_jit_time = sum(post_jit_times) / len(post_jit_times) if post_jit_times else 0.0
return True, jit_time, avg_post_jit_time
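
This helper encodes a standard JIT-aware timing protocol: the first request is timed separately because it pays the compilation cost, and only the subsequent requests are averaged. A self-contained sketch of the same pattern with a sleep-based stand-in for a request (purely illustrative, not part of the benchmark):

import time

def fake_request() -> tuple[bool, float]:
    # Stand-in for a real sample/forward_backward call; a real first
    # call would also trigger compilation and take much longer.
    start = time.time()
    time.sleep(0.01)
    return True, time.time() - start

ok, jit_time = fake_request()  # warmup: would include JIT compilation
post_jit = [fake_request()[1] for _ in range(3)]  # measurement iterations
print(f"jit={jit_time:.3f}s post_jit_avg={sum(post_jit) / len(post_jit):.3f}s")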

def _test_sample(
self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_measurement_iters: int
) -> tuple[bool, float, float]:
"""Execute sampling test with warmup iterations."""
sampling_client = service_client.create_sampling_client(base_model=self.config.base_model)
vocab_size = self.tokenizer.vocab_size
rng = random.Random(42)

# Build prompt - half prompt, half generation
# Half prompt, half generation
prompt_len = seq_len // 2
gen_len = seq_len - prompt_len
base_tokens = self.tokenizer.encode("Hello, how are you doing today? ", add_special_tokens=True)
prompt_tokens = (base_tokens * ((prompt_len // len(base_tokens)) + 1))[:prompt_len]
prompt = types.ModelInput.from_ints(prompt_tokens)

start_time = time.time()
request = sampling_client.sample(
prompt=prompt,
sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len, seed=42),
num_samples=batch_size,
)
# Poll with small timeout to allow server aliveness checks
while True:
try:
result = request.result(timeout=5)
break
except TimeoutError:
if not server.is_alive():
raise RuntimeError("Server crashed during test")
elapsed = time.time() - start_time

return len(result.sequences) == batch_size, elapsed
def run_sample() -> tuple[bool, float]:
"""Run a single sample request and return (success, elapsed_time)."""
prompt_tokens = [rng.randint(1, vocab_size - 1) for _ in range(prompt_len)]
prompt = types.ModelInput.from_ints(prompt_tokens)

start_time = time.time()
request = sampling_client.sample(
prompt=prompt,
sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len, seed=42),
num_samples=batch_size,
)
# Poll with small timeout to allow server aliveness checks
while True:
try:
result = request.result(timeout=5)
break
except TimeoutError:
if not server.is_alive():
raise RuntimeError("Server crashed during test")
elapsed = time.time() - start_time
return len(result.sequences) == batch_size, elapsed

return self._run_timed_requests(run_sample, num_measurement_iters)

def _test_forward_backward(
self, service_client, server: ServerManager, batch_size: int, seq_len: int
) -> tuple[bool, float]:
"""Execute forward-backward test."""
self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_measurement_iters: int
) -> tuple[bool, float, float]:
"""Execute forward-backward test with warmup iterations."""
training_client = service_client.create_lora_training_client(base_model=self.config.base_model)

# Create training data
data = [self._make_datum(seq_len) for _ in range(batch_size)]

start_time = time.time()
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
# Poll with small timeout to allow server aliveness checks
while True:
try:
result = fwdbwd_future.result(timeout=5)
break
except TimeoutError:
if not server.is_alive():
raise RuntimeError("Server crashed during test")
elapsed = time.time() - start_time

return len(result.loss_fn_outputs) == batch_size, elapsed
rng = random.Random(42)

def run_forward_backward() -> tuple[bool, float]:
"""Run a single forward-backward request and return (success, elapsed_time)."""
data = [self._make_datum(seq_len, rng) for _ in range(batch_size)]

start_time = time.time()
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
# Poll with small timeout to allow server aliveness checks
while True:
try:
result = fwdbwd_future.result(timeout=5)
break
except TimeoutError:
if not server.is_alive():
raise RuntimeError("Server crashed during test")
elapsed = time.time() - start_time
return len(result.loss_fn_outputs) == batch_size, elapsed

return self._run_timed_requests(run_forward_backward, num_measurement_iters)
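
Both test paths share the poll-with-liveness pattern seen above: block on the future in short slices so a crashed server surfaces as an immediate RuntimeError instead of a hang until the request deadline. A generic sketch of that pattern (the future and is_alive names are placeholders, not benchmark APIs):

import concurrent.futures

def wait_with_liveness(future, is_alive, poll_sec: float = 5.0):
    """Block on a future, checking server liveness between poll timeouts."""
    while True:
        try:
            return future.result(timeout=poll_sec)
        except concurrent.futures.TimeoutError:
            if not is_alive():
                raise RuntimeError("Server crashed during test")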

def run_single_test(self, batch_size: int, seq_len: int, mode: str) -> TestResult:
"""Run a single benchmark test with given parameters."""
@@ -505,7 +557,8 @@ def run_single_test(self, batch_size: int, seq_len: int, mode: str) -> TestResult:
status="ERROR",
peak_gpu_mem_mib=0,
jit_logs=[],
client_e2e_sec=None,
jit_e2e_sec=None,
post_jit_e2e_sec=None,
error_message=None,
)

@@ -534,14 +587,19 @@ def run_single_test(self, batch_size: int, seq_len: int, mode: str) -> TestResult:
try:
print(f" Running {mode} test...")
if mode == "sample":
success, elapsed = self._test_sample(service_client, server, batch_size, seq_len)
success, jit_time, post_jit_time = self._test_sample(
service_client, server, batch_size, seq_len, self.config.num_measurement_iters
)
else:
success, elapsed = self._test_forward_backward(service_client, server, batch_size, seq_len)
success, jit_time, post_jit_time = self._test_forward_backward(
service_client, server, batch_size, seq_len, self.config.num_measurement_iters
)

# Collect results
result.peak_gpu_mem_mib = gpu_monitor.stop()
result.jit_logs = server.get_jit_logs()
result.client_e2e_sec = elapsed
result.jit_e2e_sec = jit_time
result.post_jit_e2e_sec = post_jit_time
result.status = "PASS" if success else "FAIL"
finally:
# Close client to stop heartbeat thread before server shutdown
@@ -589,7 +647,10 @@ def run_all_tests(self, results_writer: ResultsWriter | None = None) -> list[TestResult]:
status_color = "\033[32m" if result.status == "PASS" else "\033[31m"
print(f"Result: {status_color}{result.status}\033[0m")
print(f"Peak GPU Memory: {result.peak_gpu_mem_mib} MiB")
print(f"Client E2E Time: {result.client_e2e_sec or 'N/A'}s")
jit_str = f"{result.jit_e2e_sec:.2f}s" if result.jit_e2e_sec else "N/A"
post_jit_str = f"{result.post_jit_e2e_sec:.2f}s" if result.post_jit_e2e_sec else "N/A"
print(f"JIT Time (1st request): {jit_str}")
print(f"Post-JIT Time (avg): {post_jit_str}")
if result.error_message:
print(f"Error: {result.error_message}")

@@ -607,7 +668,7 @@ def run_all_tests(self, results_writer: ResultsWriter | None = None) -> list[TestResult]:
class ResultsWriter:
"""Write benchmark results to CSV incrementally."""

CSV_HEADER = ["mode", "batch_size", "seq_len", "status", "peak_gpu_mem_mib", "client_e2e_sec"]
CSV_HEADER = ["mode", "batch_size", "seq_len", "status", "peak_gpu_mem_mib", "jit_e2e_sec", "post_jit_e2e_sec"]

def __init__(self, output_path: Path):
self.output_path = output_path
@@ -627,7 +688,8 @@ def append(self, result: TestResult) -> None:
result.seq_len,
result.status,
result.peak_gpu_mem_mib,
f"{result.client_e2e_sec:.2f}" if result.client_e2e_sec else "",
f"{result.jit_e2e_sec:.2f}" if result.jit_e2e_sec else "",
f"{result.post_jit_e2e_sec:.2f}" if result.post_jit_e2e_sec else "",
]
)

@@ -641,9 +703,9 @@ def __init__(self, results: list[TestResult], output_path: str):

def print_summary(self) -> None:
"""Print human-readable summary to terminal."""
print("\n" + "=" * 70)
print("\n" + "=" * 85)
print("BENCHMARK SUMMARY")
print("=" * 70)
print("=" * 85)

# Group by mode
by_mode: dict[str, list[TestResult]] = {}
@@ -652,28 +714,29 @@ def print_summary(self) -> None:

for mode, mode_results in by_mode.items():
print(f"\n{mode.upper()} Results:")
print("-" * 58)
print(f"{'Batch':>8} {'SeqLen':>8} {'Status':>8} {'PeakMem':>12} {'E2E(s)':>10}")
print("-" * 58)
print("-" * 75)
print(f"{'Batch':>8} {'SeqLen':>8} {'Status':>8} {'PeakMem':>12} {'JIT(s)':>10} {'PostJIT(s)':>12}")
print("-" * 75)

for r in mode_results:
e2e = f"{r.client_e2e_sec:.2f}" if r.client_e2e_sec else "N/A"
jit = f"{r.jit_e2e_sec:.2f}" if r.jit_e2e_sec else "N/A"
post_jit = f"{r.post_jit_e2e_sec:.2f}" if r.post_jit_e2e_sec else "N/A"
status_color = "\033[32m" if r.status == "PASS" else "\033[31m"
print(
f"{r.batch_size:>8} {r.seq_len:>8} {status_color}{r.status:>8}\033[0m "
f"{r.peak_gpu_mem_mib:>10} MiB {e2e:>10}"
f"{r.peak_gpu_mem_mib:>10} MiB {jit:>10} {post_jit:>12}"
)

# Summary statistics
passed = sum(1 for r in self.results if r.status == "PASS")
failed = sum(1 for r in self.results if r.status in ("FAIL", "ERROR"))
max_mem = max((r.peak_gpu_mem_mib for r in self.results if r.status == "PASS"), default=0)

print("\n" + "=" * 70)
print("\n" + "=" * 85)
print(f"Total: {len(self.results)} tests | Passed: {passed} | Failed: {failed}")
print(f"Peak Memory (successful tests): {max_mem} MiB")
print(f"Results saved to: {Path(self.output_path).resolve()}")
print("=" * 70)
print("=" * 85)

# Print JIT compilation logs
print("\n" + "=" * 70)
@@ -744,6 +807,12 @@ def parse_args() -> argparse.Namespace:
type=lambda s: [int(x) for x in s.split(",")],
help="Comma-separated sequence lengths to test",
)
test_group.add_argument(
"--num-measurement-iters",
type=int,
default=3,
help="Number of post-JIT measurement iterations (default: 3)",
)
test_group.add_argument(
"--server-only",
action="store_true",