From 59d3341bfa3bc3a0b0e1109016f8175e540ad326 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 12:06:18 -0800 Subject: [PATCH 1/4] Add separate JIT vs post-JIT timing measurement to memory benchmark The benchmark now sends multiple requests per test: the first request triggers JIT compilation and subsequent requests measure actual post-JIT runtime. This provides clearer insight into compilation overhead vs steady-state performance. - Add jit_e2e_sec field to capture first request (JIT) time - Rename client_e2e_sec to post_jit_e2e_sec for clarity - Add --num-measurement-iters flag (default: 3) for post-JIT iterations - Update CSV output and summary tables to show both metrics Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/benchmarks/benchmark_memory.py | 179 +++++++++++++++++------- 1 file changed, 126 insertions(+), 53 deletions(-) diff --git a/skyrl-tx/benchmarks/benchmark_memory.py b/skyrl-tx/benchmarks/benchmark_memory.py index 4b1fc5756c..13ba47a70e 100644 --- a/skyrl-tx/benchmarks/benchmark_memory.py +++ b/skyrl-tx/benchmarks/benchmark_memory.py @@ -10,6 +10,8 @@ Features: - Test sampling, training, or both modes - Sweep across multiple batch sizes and sequence lengths + - Separate JIT compilation time vs post-JIT runtime measurement + - Configurable number of post-JIT measurement iterations for averaging - Early termination: skips remaining batch sizes if one fails (e.g., OOM) - GPU memory monitoring via nvidia-smi polling - Per-test server logs with JIT compilation time extraction @@ -39,10 +41,14 @@ uv run --extra tinker python benchmarks/benchmark_memory.py \\ --backend-config '{"loss_chunk_size": 512, "enforce_eager": true}' + # Run with multiple measurement iterations for more accurate post-JIT timing + uv run --extra tinker python benchmarks/benchmark_memory.py \\ + --num-measurement-iters 3 --batch-sizes 8 --seq-lens 4096 + Output directory (default: /tmp/skyrl_tx_memory_benchmark/): tx_memory_benchmark_{experiment_name}_{timestamp}/ 
config.json # Full benchmark configuration (JSON) - results.csv # Results table (mode, batch, seq, status, peak_mem, e2e_time) + results.csv # Results table (mode, batch, seq, status, peak_mem, jit_time, post_jit_time) tinker.db # SQLite database used by tinker API server_*.log # Server stdout/stderr for each test run xla_dump_*/ # XLA HLO graphs per test (if --dump-xla enabled) @@ -96,6 +102,7 @@ class BenchmarkConfig: test_mode: Literal["sample", "train", "both"] = "both" batch_sizes: list[int] = field(default_factory=lambda: [4, 8, 16, 32]) seq_lens: list[int] = field(default_factory=lambda: [8192]) + num_measurement_iters: int = 3 # Number of post-JIT measurement iterations server_only: bool = False # Runtime configuration @@ -131,7 +138,8 @@ class TestResult: status: Literal["PASS", "FAIL", "ERROR"] peak_gpu_mem_mib: int jit_logs: list[str] - client_e2e_sec: float | None + jit_e2e_sec: float | None # First request (includes JIT compilation) + post_jit_e2e_sec: float | None # Average of subsequent requests (post-JIT) error_message: str | None = None @@ -440,8 +448,14 @@ def _make_datum(self, seq_len: int) -> types.Datum: }, ) - def _test_sample(self, service_client, server: ServerManager, batch_size: int, seq_len: int) -> tuple[bool, float]: - """Execute sampling test.""" + def _test_sample( + self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_warmup_iters: int = 1 + ) -> tuple[bool, float, float]: + """Execute sampling test with warmup iterations. 
+ + Returns: + Tuple of (success, jit_time, post_jit_avg_time) + """ sampling_client = service_client.create_sampling_client(base_model=self.config.base_model) # Build prompt - half prompt, half generation @@ -451,46 +465,88 @@ def _test_sample(self, service_client, server: ServerManager, batch_size: int, s prompt_tokens = (base_tokens * ((prompt_len // len(base_tokens)) + 1))[:prompt_len] prompt = types.ModelInput.from_ints(prompt_tokens) - start_time = time.time() - request = sampling_client.sample( - prompt=prompt, - sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len, seed=42), - num_samples=batch_size, - ) - # Poll with small timeout to allow server aliveness checks - while True: - try: - result = request.result(timeout=5) - break - except TimeoutError: - if not server.is_alive(): - raise RuntimeError("Server crashed during test") - elapsed = time.time() - start_time - - return len(result.sequences) == batch_size, elapsed + def run_sample() -> tuple[bool, float]: + """Run a single sample request and return (success, elapsed_time).""" + start_time = time.time() + request = sampling_client.sample( + prompt=prompt, + sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len, seed=42), + num_samples=batch_size, + ) + # Poll with small timeout to allow server aliveness checks + while True: + try: + result = request.result(timeout=5) + break + except TimeoutError: + if not server.is_alive(): + raise RuntimeError("Server crashed during test") + elapsed = time.time() - start_time + return len(result.sequences) == batch_size, elapsed + + # First request triggers JIT compilation + print(" Running warmup request (JIT compilation)...") + success, jit_time = run_sample() + if not success: + return False, jit_time, 0.0 + + # Subsequent requests measure post-JIT performance + post_jit_times = [] + for i in range(num_warmup_iters): + print(f" Running measurement request {i + 1}/{num_warmup_iters}...") + success, elapsed = run_sample() + if 
not success: + return False, jit_time, 0.0 + post_jit_times.append(elapsed) + + avg_post_jit_time = sum(post_jit_times) / len(post_jit_times) if post_jit_times else 0.0 + return True, jit_time, avg_post_jit_time def _test_forward_backward( - self, service_client, server: ServerManager, batch_size: int, seq_len: int - ) -> tuple[bool, float]: - """Execute forward-backward test.""" + self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_warmup_iters: int = 1 + ) -> tuple[bool, float, float]: + """Execute forward-backward test with warmup iterations. + + Returns: + Tuple of (success, jit_time, post_jit_avg_time) + """ training_client = service_client.create_lora_training_client(base_model=self.config.base_model) # Create training data data = [self._make_datum(seq_len) for _ in range(batch_size)] - start_time = time.time() - fwdbwd_future = training_client.forward_backward(data, "cross_entropy") - # Poll with small timeout to allow server aliveness checks - while True: - try: - result = fwdbwd_future.result(timeout=5) - break - except TimeoutError: - if not server.is_alive(): - raise RuntimeError("Server crashed during test") - elapsed = time.time() - start_time - - return len(result.loss_fn_outputs) == batch_size, elapsed + def run_forward_backward() -> tuple[bool, float]: + """Run a single forward-backward request and return (success, elapsed_time).""" + start_time = time.time() + fwdbwd_future = training_client.forward_backward(data, "cross_entropy") + # Poll with small timeout to allow server aliveness checks + while True: + try: + result = fwdbwd_future.result(timeout=5) + break + except TimeoutError: + if not server.is_alive(): + raise RuntimeError("Server crashed during test") + elapsed = time.time() - start_time + return len(result.loss_fn_outputs) == batch_size, elapsed + + # First request triggers JIT compilation + print(" Running warmup request (JIT compilation)...") + success, jit_time = run_forward_backward() + if not success: + 
return False, jit_time, 0.0 + + # Subsequent requests measure post-JIT performance + post_jit_times = [] + for i in range(num_warmup_iters): + print(f" Running measurement request {i + 1}/{num_warmup_iters}...") + success, elapsed = run_forward_backward() + if not success: + return False, jit_time, 0.0 + post_jit_times.append(elapsed) + + avg_post_jit_time = sum(post_jit_times) / len(post_jit_times) if post_jit_times else 0.0 + return True, jit_time, avg_post_jit_time def run_single_test(self, batch_size: int, seq_len: int, mode: str) -> TestResult: """Run a single benchmark test with given parameters.""" @@ -505,7 +561,8 @@ def run_single_test(self, batch_size: int, seq_len: int, mode: str) -> TestResul status="ERROR", peak_gpu_mem_mib=0, jit_logs=[], - client_e2e_sec=None, + jit_e2e_sec=None, + post_jit_e2e_sec=None, error_message=None, ) @@ -534,14 +591,19 @@ def run_single_test(self, batch_size: int, seq_len: int, mode: str) -> TestResul try: print(f" Running {mode} test...") if mode == "sample": - success, elapsed = self._test_sample(service_client, server, batch_size, seq_len) + success, jit_time, post_jit_time = self._test_sample( + service_client, server, batch_size, seq_len, self.config.num_measurement_iters + ) else: - success, elapsed = self._test_forward_backward(service_client, server, batch_size, seq_len) + success, jit_time, post_jit_time = self._test_forward_backward( + service_client, server, batch_size, seq_len, self.config.num_measurement_iters + ) # Collect results result.peak_gpu_mem_mib = gpu_monitor.stop() result.jit_logs = server.get_jit_logs() - result.client_e2e_sec = elapsed + result.jit_e2e_sec = jit_time + result.post_jit_e2e_sec = post_jit_time result.status = "PASS" if success else "FAIL" finally: # Close client to stop heartbeat thread before server shutdown @@ -589,7 +651,10 @@ def run_all_tests(self, results_writer: ResultsWriter | None = None) -> list[Tes status_color = "\033[32m" if result.status == "PASS" else "\033[31m" 
print(f"Result: {status_color}{result.status}\033[0m") print(f"Peak GPU Memory: {result.peak_gpu_mem_mib} MiB") - print(f"Client E2E Time: {result.client_e2e_sec or 'N/A'}s") + jit_str = f"{result.jit_e2e_sec:.2f}s" if result.jit_e2e_sec else "N/A" + post_jit_str = f"{result.post_jit_e2e_sec:.2f}s" if result.post_jit_e2e_sec else "N/A" + print(f"JIT Time (1st request): {jit_str}") + print(f"Post-JIT Time (avg): {post_jit_str}") if result.error_message: print(f"Error: {result.error_message}") @@ -607,7 +672,7 @@ def run_all_tests(self, results_writer: ResultsWriter | None = None) -> list[Tes class ResultsWriter: """Write benchmark results to CSV incrementally.""" - CSV_HEADER = ["mode", "batch_size", "seq_len", "status", "peak_gpu_mem_mib", "client_e2e_sec"] + CSV_HEADER = ["mode", "batch_size", "seq_len", "status", "peak_gpu_mem_mib", "jit_e2e_sec", "post_jit_e2e_sec"] def __init__(self, output_path: Path): self.output_path = output_path @@ -627,7 +692,8 @@ def append(self, result: TestResult) -> None: result.seq_len, result.status, result.peak_gpu_mem_mib, - f"{result.client_e2e_sec:.2f}" if result.client_e2e_sec else "", + f"{result.jit_e2e_sec:.2f}" if result.jit_e2e_sec else "", + f"{result.post_jit_e2e_sec:.2f}" if result.post_jit_e2e_sec else "", ] ) @@ -641,9 +707,9 @@ def __init__(self, results: list[TestResult], output_path: str): def print_summary(self) -> None: """Print human-readable summary to terminal.""" - print("\n" + "=" * 70) + print("\n" + "=" * 85) print("BENCHMARK SUMMARY") - print("=" * 70) + print("=" * 85) # Group by mode by_mode: dict[str, list[TestResult]] = {} @@ -652,16 +718,17 @@ def print_summary(self) -> None: for mode, mode_results in by_mode.items(): print(f"\n{mode.upper()} Results:") - print("-" * 58) - print(f"{'Batch':>8} {'SeqLen':>8} {'Status':>8} {'PeakMem':>12} {'E2E(s)':>10}") - print("-" * 58) + print("-" * 75) + print(f"{'Batch':>8} {'SeqLen':>8} {'Status':>8} {'PeakMem':>12} {'JIT(s)':>10} {'PostJIT(s)':>12}") + 
print("-" * 75) for r in mode_results: - e2e = f"{r.client_e2e_sec:.2f}" if r.client_e2e_sec else "N/A" + jit = f"{r.jit_e2e_sec:.2f}" if r.jit_e2e_sec else "N/A" + post_jit = f"{r.post_jit_e2e_sec:.2f}" if r.post_jit_e2e_sec else "N/A" status_color = "\033[32m" if r.status == "PASS" else "\033[31m" print( f"{r.batch_size:>8} {r.seq_len:>8} {status_color}{r.status:>8}\033[0m " - f"{r.peak_gpu_mem_mib:>10} MiB {e2e:>10}" + f"{r.peak_gpu_mem_mib:>10} MiB {jit:>10} {post_jit:>12}" ) # Summary statistics @@ -669,11 +736,11 @@ def print_summary(self) -> None: failed = sum(1 for r in self.results if r.status in ("FAIL", "ERROR")) max_mem = max((r.peak_gpu_mem_mib for r in self.results if r.status == "PASS"), default=0) - print("\n" + "=" * 70) + print("\n" + "=" * 85) print(f"Total: {len(self.results)} tests | Passed: {passed} | Failed: {failed}") print(f"Peak Memory (successful tests): {max_mem} MiB") print(f"Results saved to: {Path(self.output_path).resolve()}") - print("=" * 70) + print("=" * 85) # Print JIT compilation logs print("\n" + "=" * 70) @@ -744,6 +811,12 @@ def parse_args() -> argparse.Namespace: type=lambda s: [int(x) for x in s.split(",")], help="Comma-separated sequence lengths to test", ) + test_group.add_argument( + "--num-measurement-iters", + type=int, + default=3, + help="Number of post-JIT measurement iterations (default: 3)", + ) test_group.add_argument( "--server-only", action="store_true", From cf65e6031f70ad18dbd2a67500aa9ffbe919ce30 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 12:36:49 -0800 Subject: [PATCH 2/4] Rename num_warmup_iters to num_measurement_iters and refactor Address PR feedback: - Rename parameter to match CLI flag name for consistency - Remove default value since it's always passed from caller - Extract common timing logic into _run_timed_requests helper method - Use random tokens for each request Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/benchmarks/benchmark_memory.py | 106 ++++++++++++------------ 1 file 
changed, 51 insertions(+), 55 deletions(-) diff --git a/skyrl-tx/benchmarks/benchmark_memory.py b/skyrl-tx/benchmarks/benchmark_memory.py index 13ba47a70e..1b5501320e 100644 --- a/skyrl-tx/benchmarks/benchmark_memory.py +++ b/skyrl-tx/benchmarks/benchmark_memory.py @@ -64,6 +64,7 @@ import csv import json import os +import random import re import signal import subprocess @@ -73,7 +74,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Literal +from typing import Callable, Literal import httpx import tinker @@ -434,9 +435,10 @@ def __init__(self, config: BenchmarkConfig): self.config = config self.tokenizer = AutoTokenizer.from_pretrained(config.base_model) - def _make_datum(self, seq_len: int) -> types.Datum: - """Create a training datum with specified sequence length.""" - all_tokens = list(range(1, seq_len + 1)) + def _make_datum(self, seq_len: int, rng: random.Random) -> types.Datum: + """Create a training datum with random tokens.""" + vocab_size = self.tokenizer.vocab_size + all_tokens = [rng.randint(1, vocab_size - 1) for _ in range(seq_len)] target_tokens = all_tokens[1:] + [self.tokenizer.eos_token_id] weights = [1.0] * seq_len @@ -448,29 +450,59 @@ def _make_datum(self, seq_len: int) -> types.Datum: }, ) - def _test_sample( - self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_warmup_iters: int = 1 + def _run_timed_requests( + self, + run_single_request: Callable[[], tuple[bool, float]], + num_measurement_iters: int, ) -> tuple[bool, float, float]: - """Execute sampling test with warmup iterations. + """Run timed requests with JIT warmup and measurement iterations. 
+ + Args: + run_single_request: Callable that executes a single request and returns (success, elapsed_time) + num_measurement_iters: Number of post-JIT measurement iterations Returns: Tuple of (success, jit_time, post_jit_avg_time) """ + # First request triggers JIT compilation + print(" Running warmup request (JIT compilation)...") + success, jit_time = run_single_request() + if not success: + return False, jit_time, 0.0 + + # Subsequent requests measure post-JIT performance + post_jit_times = [] + for i in range(num_measurement_iters): + print(f" Running measurement request {i + 1}/{num_measurement_iters}...") + success, elapsed = run_single_request() + if not success: + return False, jit_time, 0.0 + post_jit_times.append(elapsed) + + avg_post_jit_time = sum(post_jit_times) / len(post_jit_times) if post_jit_times else 0.0 + return True, jit_time, avg_post_jit_time + + def _test_sample( + self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_measurement_iters: int + ) -> tuple[bool, float, float]: + """Execute sampling test with warmup iterations.""" sampling_client = service_client.create_sampling_client(base_model=self.config.base_model) + vocab_size = self.tokenizer.vocab_size + rng = random.Random(42) - # Build prompt - half prompt, half generation + # Half prompt, half generation prompt_len = seq_len // 2 gen_len = seq_len - prompt_len - base_tokens = self.tokenizer.encode("Hello, how are you doing today? 
", add_special_tokens=True) - prompt_tokens = (base_tokens * ((prompt_len // len(base_tokens)) + 1))[:prompt_len] - prompt = types.ModelInput.from_ints(prompt_tokens) def run_sample() -> tuple[bool, float]: """Run a single sample request and return (success, elapsed_time).""" + prompt_tokens = [rng.randint(1, vocab_size - 1) for _ in range(prompt_len)] + prompt = types.ModelInput.from_ints(prompt_tokens) + start_time = time.time() request = sampling_client.sample( prompt=prompt, - sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len, seed=42), + sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len), num_samples=batch_size, ) # Poll with small timeout to allow server aliveness checks @@ -484,39 +516,19 @@ def run_sample() -> tuple[bool, float]: elapsed = time.time() - start_time return len(result.sequences) == batch_size, elapsed - # First request triggers JIT compilation - print(" Running warmup request (JIT compilation)...") - success, jit_time = run_sample() - if not success: - return False, jit_time, 0.0 - - # Subsequent requests measure post-JIT performance - post_jit_times = [] - for i in range(num_warmup_iters): - print(f" Running measurement request {i + 1}/{num_warmup_iters}...") - success, elapsed = run_sample() - if not success: - return False, jit_time, 0.0 - post_jit_times.append(elapsed) - - avg_post_jit_time = sum(post_jit_times) / len(post_jit_times) if post_jit_times else 0.0 - return True, jit_time, avg_post_jit_time + return self._run_timed_requests(run_sample, num_measurement_iters) def _test_forward_backward( - self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_warmup_iters: int = 1 + self, service_client, server: ServerManager, batch_size: int, seq_len: int, num_measurement_iters: int ) -> tuple[bool, float, float]: - """Execute forward-backward test with warmup iterations. 
- - Returns: - Tuple of (success, jit_time, post_jit_avg_time) - """ + """Execute forward-backward test with warmup iterations.""" training_client = service_client.create_lora_training_client(base_model=self.config.base_model) - - # Create training data - data = [self._make_datum(seq_len) for _ in range(batch_size)] + rng = random.Random(42) def run_forward_backward() -> tuple[bool, float]: """Run a single forward-backward request and return (success, elapsed_time).""" + data = [self._make_datum(seq_len, rng) for _ in range(batch_size)] + start_time = time.time() fwdbwd_future = training_client.forward_backward(data, "cross_entropy") # Poll with small timeout to allow server aliveness checks @@ -530,23 +542,7 @@ def run_forward_backward() -> tuple[bool, float]: elapsed = time.time() - start_time return len(result.loss_fn_outputs) == batch_size, elapsed - # First request triggers JIT compilation - print(" Running warmup request (JIT compilation)...") - success, jit_time = run_forward_backward() - if not success: - return False, jit_time, 0.0 - - # Subsequent requests measure post-JIT performance - post_jit_times = [] - for i in range(num_warmup_iters): - print(f" Running measurement request {i + 1}/{num_warmup_iters}...") - success, elapsed = run_forward_backward() - if not success: - return False, jit_time, 0.0 - post_jit_times.append(elapsed) - - avg_post_jit_time = sum(post_jit_times) / len(post_jit_times) if post_jit_times else 0.0 - return True, jit_time, avg_post_jit_time + return self._run_timed_requests(run_forward_backward, num_measurement_iters) def run_single_test(self, batch_size: int, seq_len: int, mode: str) -> TestResult: """Run a single benchmark test with given parameters.""" From d44cfbc2f0334804cc03d42c59fd60b3f62f8ae2 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 13:26:56 -0800 Subject: [PATCH 3/4] Update skyrl-tx/benchmarks/benchmark_memory.py Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com> --- skyrl-tx/benchmarks/benchmark_memory.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skyrl-tx/benchmarks/benchmark_memory.py b/skyrl-tx/benchmarks/benchmark_memory.py index 1b5501320e..45ab298b7f 100644 --- a/skyrl-tx/benchmarks/benchmark_memory.py +++ b/skyrl-tx/benchmarks/benchmark_memory.py @@ -502,8 +502,7 @@ def run_sample() -> tuple[bool, float]: start_time = time.time() request = sampling_client.sample( prompt=prompt, - sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len), - num_samples=batch_size, + sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len, seed=42), ) # Poll with small timeout to allow server aliveness checks while True: From f1486f0600f8b3c7bfdc08f4d39a254f633dbc09 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 14:16:07 -0800 Subject: [PATCH 4/4] Restore num_samples=batch_size in sampling request --- skyrl-tx/benchmarks/benchmark_memory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/benchmarks/benchmark_memory.py b/skyrl-tx/benchmarks/benchmark_memory.py index 45ab298b7f..f36c2616f9 100644 --- a/skyrl-tx/benchmarks/benchmark_memory.py +++ b/skyrl-tx/benchmarks/benchmark_memory.py @@ -503,6 +503,7 @@ def run_sample() -> tuple[bool, float]: request = sampling_client.sample( prompt=prompt, sampling_params=types.SamplingParams(temperature=0.7, max_tokens=gen_len, seed=42), + num_samples=batch_size, ) # Poll with small timeout to allow server aliveness checks while True: