diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index 8ec58d243..a07ef27f8 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -28,6 +28,8 @@ dependencies = [ [project.optional-dependencies] gpu = [ "jax[cuda12]>=0.7.2", + "torch>=2.0.0", + "cuda-tile", # cuTile Python from PyPI (requires CUDA Toolkit 13.1+) ] tpu = [ diff --git a/skyrl-tx/tests/README.md b/skyrl-tx/tests/README.md index 747e748b0..972ce4a98 100644 --- a/skyrl-tx/tests/README.md +++ b/skyrl-tx/tests/README.md @@ -5,3 +5,6 @@ uv run --extra dev --extra tinker pytest -v ## Run Specific tests uv run --extra dev --extra tinker pytest -v -s tests/models/test_qwen3_generate.py::test_qwen3_generate_speed + +## Run Cutile tests +TX_USE_CUTILE_LORA=1 uv run pytest tests/cutile/test_cutile_lora_equivalence.py -v \ No newline at end of file diff --git a/skyrl-tx/tests/cutile/test_cutile_lora_equivalence.py b/skyrl-tx/tests/cutile/test_cutile_lora_equivalence.py new file mode 100644 index 000000000..bbcf8b984 --- /dev/null +++ b/skyrl-tx/tests/cutile/test_cutile_lora_equivalence.py @@ -0,0 +1,504 @@ +""" +Test equivalence between cutile-based LoRA computation and ragged_dot. + +This test suite verifies that the cutile implementation produces numerically +equivalent results to the existing JAX ragged_dot implementation for LoRA +expert parallelism computation. + +Phase 1 Scope: Single-GPU forward pass only (no group_offset, no gradients). +""" + +import pytest +import numpy as np + +try: + import jax + import jax.numpy as jnp + from jax import random +except ImportError: + pytest.skip("JAX not available", allow_module_level=True) + +try: + import torch +except ImportError: + pytest.skip("PyTorch not available", allow_module_level=True) + +try: + from tx.kernels.cutile_lora import cutile_ragged_dot + from tx.kernels import CUTILE_AVAILABLE +except ImportError: + CUTILE_AVAILABLE = False + cutile_ragged_dot = None + pytestmark = pytest.mark.skip("Cutile implementation not available") + +from tx.layers.util import ragged_dot + + +# ============================================================================ +# Test Utilities +# ============================================================================ + + +def generate_ragged_test_case( + m: int, + d: int, + out_features: int, + num_experts: int, + seed: int = 42, + distribution: str = "balanced", + dtype=jnp.float32, +): + """Generate synthetic test case for ragged_dot equivalence testing. 
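+
+    A quick illustrative call (shapes here are arbitrary; any m, d, out_features,
+    num_experts combination behaves the same way):
+
+        lhs, rhs, gs = generate_ragged_test_case(
+            m=256, d=128, out_features=64, num_experts=4, distribution="imbalanced"
+        )
+        # lhs.shape == (256, 128), rhs.shape == (4, 128, 64), int(gs.sum()) == 256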
+ + Args: + m: Total number of tokens + d: Hidden dimension (input features) + out_features: Output dimension + num_experts: Number of expert groups + seed: Random seed for reproducibility + distribution: Token distribution strategy: + - "balanced": Equal tokens per expert + - "imbalanced": Random uneven distribution + - "sparse": Some experts have zero tokens + dtype: Data type for arrays (jnp.float32, jnp.float16, jnp.bfloat16) + + Returns: + Tuple of (lhs, rhs, group_sizes) where: + lhs: [m, d] input tokens + rhs: [num_experts, d, out_features] expert weights + group_sizes: [num_experts] number of tokens per expert + """ + key = random.PRNGKey(seed) + key_lhs, key_rhs, key_sizes = random.split(key, 3) + + # Generate input tokens + lhs = random.normal(key_lhs, (m, d), dtype=dtype) + + # Generate expert weights + rhs = random.normal(key_rhs, (num_experts, d, out_features), dtype=dtype) * 0.02 + + # Generate group sizes based on distribution strategy + if distribution == "balanced": + base_size = m // num_experts + remainder = m % num_experts + group_sizes = jnp.array([base_size + (1 if i < remainder else 0) for i in range(num_experts)], dtype=jnp.int32) + elif distribution == "imbalanced": + # Random distribution ensuring sum equals m + random_weights = random.uniform(key_sizes, (num_experts,)) + random_weights = random_weights / random_weights.sum() + group_sizes = (random_weights * m).astype(jnp.int32) + # Adjust to ensure exact sum + diff = m - group_sizes.sum() + group_sizes = group_sizes.at[0].add(diff) + elif distribution == "sparse": + # Some experts get zero tokens + active_experts = num_experts // 2 + active_sizes = jnp.array([m // active_experts] * active_experts, dtype=jnp.int32) + # Adjust for remainder + active_sizes = active_sizes.at[0].add(m - active_sizes.sum()) + group_sizes = jnp.concatenate([active_sizes, jnp.zeros(num_experts - active_experts, dtype=jnp.int32)]) + else: + raise ValueError(f"Unknown distribution: {distribution}") + + return lhs, rhs, group_sizes + + +def compare_outputs( + cutile_out: jax.Array, + ragged_out: jax.Array, + rtol: float = 1e-3, + atol: float = 1e-5, + verbose: bool = True, +): + """Compare cutile and ragged_dot outputs with detailed error reporting. 
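+
+    The pass/fail decision is np.testing.assert_allclose(cutile_out, ragged_out,
+    rtol=rtol, atol=atol), i.e. every element must satisfy
+    |cutile - ragged| <= atol + rtol * |ragged|; the printed statistics are
+    informational only. Typical use:
+
+        compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5)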
+ + Args: + cutile_out: Output from cutile implementation + ragged_out: Output from ragged_dot + rtol: Relative tolerance + atol: Absolute tolerance + verbose: Print detailed error statistics + + Raises: + AssertionError: If outputs don't match within tolerance + """ + # Ensure both are JAX arrays + if isinstance(cutile_out, torch.Tensor): + # This shouldn't happen, but handle gracefully + cutile_out = jnp.array(cutile_out.detach().cpu().numpy()) + + # Check shapes match + assert ( + cutile_out.shape == ragged_out.shape + ), f"Shape mismatch: cutile {cutile_out.shape} vs ragged {ragged_out.shape}" + + # Check dtypes match + assert ( + cutile_out.dtype == ragged_out.dtype + ), f"Dtype mismatch: cutile {cutile_out.dtype} vs ragged {ragged_out.dtype}" + + # Compute element-wise differences + abs_diff = jnp.abs(cutile_out - ragged_out) + rel_diff = abs_diff / (jnp.abs(ragged_out) + 1e-8) + + # Statistics + max_abs_diff = jnp.max(abs_diff) + max_rel_diff = jnp.max(rel_diff) + mean_abs_diff = jnp.mean(abs_diff) + mean_rel_diff = jnp.mean(rel_diff) + + if verbose: + print("\nNumerical Comparison:") + print(f" Max absolute diff: {max_abs_diff:.6e}") + print(f" Max relative diff: {max_rel_diff:.6e}") + print(f" Mean absolute diff: {mean_abs_diff:.6e}") + print(f" Mean relative diff: {mean_rel_diff:.6e}") + print(f" Tolerance: rtol={rtol}, atol={atol}") + + # Check if within tolerance + try: + np.testing.assert_allclose( + cutile_out, + ragged_out, + rtol=rtol, + atol=atol, + err_msg=f"Outputs differ beyond tolerance (max rel diff: {max_rel_diff:.6e})", + ) + if verbose: + print(" ✓ PASS: Within tolerance") + except AssertionError: + # Find indices of largest errors + error_mask = rel_diff > rtol + num_errors = jnp.sum(error_mask) + print(f" ✗ FAIL: {num_errors}/{cutile_out.size} elements exceed tolerance") + + # Show a few worst offenders + flat_rel_diff = rel_diff.flatten() + worst_indices = jnp.argsort(flat_rel_diff)[-5:] + print("\n Worst 5 errors:") + for idx in worst_indices: + i = int(idx) + print( + f" Index {i}: cutile={float(cutile_out.flatten()[i]):.6e}, " + f"ragged={float(ragged_out.flatten()[i]):.6e}, " + f"rel_diff={float(flat_rel_diff[i]):.6e}" + ) + raise + + +def benchmark_both( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + num_runs: int = 100, + warmup: int = 10, +): + """Benchmark latency of ragged_dot vs cutile. 
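+
+    Timing is wall-clock per call with jax.block_until_ready() after every iteration,
+    so the numbers include Python dispatch and, for the cutile path, the JAX<->Torch
+    DLPack crossing inside cutile_ragged_dot, not just kernel time. Typical use:
+
+        ragged_ms, cutile_ms, speedup = benchmark_both(lhs, rhs, group_sizes, num_runs=50)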
+ + Args: + lhs: Input tokens + rhs: Expert weights + group_sizes: Group sizes + num_runs: Number of benchmark iterations + warmup: Number of warmup iterations + + Returns: + Tuple of (ragged_time_ms, cutile_time_ms, speedup) + """ + import time + + # Warmup ragged_dot + for _ in range(warmup): + _ = ragged_dot(lhs, rhs, group_sizes) + jax.block_until_ready(_) + + # Benchmark ragged_dot + start = time.perf_counter() + for _ in range(num_runs): + out = ragged_dot(lhs, rhs, group_sizes) + jax.block_until_ready(out) + ragged_time = (time.perf_counter() - start) * 1000 / num_runs + + # Warmup cutile + for _ in range(warmup): + _ = cutile_ragged_dot(lhs, rhs, group_sizes) + jax.block_until_ready(_) + + # Benchmark cutile + start = time.perf_counter() + for _ in range(num_runs): + out = cutile_ragged_dot(lhs, rhs, group_sizes) + jax.block_until_ready(out) + cutile_time = (time.perf_counter() - start) * 1000 / num_runs + + speedup = ragged_time / cutile_time + + print("\nBenchmark Results:") + print(f" ragged_dot: {ragged_time:.3f} ms") + print(f" cutile: {cutile_time:.3f} ms") + print(f" Speedup: {speedup:.2f}x") + + return ragged_time, cutile_time, speedup + + +# ============================================================================ +# Test Cases +# ============================================================================ + + +@pytest.mark.skipif(not CUTILE_AVAILABLE, reason="Cutile not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestCutileRaggedDotEquivalence: + """Test cutile_ragged_dot matches ragged_dot output.""" + + def test_basic_single_expert(self): + """Simple case: all tokens to one expert.""" + m, d, out_features, num_experts = 128, 64, 32, 1 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=42, distribution="balanced" + ) + + # Run both implementations + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + # Compare + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + def test_multiple_experts_balanced(self): + """Multiple experts with equal token distribution.""" + m, d, out_features, num_experts = 256, 128, 128, 4 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=42, distribution="balanced" + ) + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + def test_multiple_experts_imbalanced(self): + """Realistic case: uneven token distribution.""" + m, d, out_features, num_experts = 512, 256, 256, 8 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=123, distribution="imbalanced" + ) + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + def test_empty_expert_groups(self): + """Edge case: some experts have zero tokens.""" + m, d, out_features, num_experts = 256, 128, 128, 8 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=456, distribution="sparse" + ) + + # Verify we actually have empty groups + assert jnp.any(group_sizes == 0), "Test case should have empty groups" + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + 
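+    # Note: with the padding scheme in tx/kernels/cutile_lora.py, a zero-size group
+    # rounds up to zero TILE_M tiles, so empty experts contribute no work at all;
+    # the test above checks that this matches ragged_dot, which likewise produces
+    # nothing for empty groups.
+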
def test_many_small_groups(self): + """Many experts with few tokens each.""" + m, d, out_features, num_experts = 512, 128, 128, 64 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=789, distribution="balanced" + ) + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + def test_few_large_groups(self): + """Few experts with many tokens each.""" + m, d, out_features, num_experts = 2048, 256, 256, 4 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=101, distribution="balanced" + ) + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + @pytest.mark.parametrize( + "d,out_features", + [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + ], + ) + def test_various_dimensions(self, d, out_features): + """Test different hidden dimensions.""" + m, num_experts = 256, 4 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=202, distribution="balanced" + ) + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + @pytest.mark.parametrize( + "dtype", + [ + jnp.float32, + pytest.param(jnp.float16, marks=pytest.mark.xfail(reason="fp16 may have numerical issues")), + pytest.param(jnp.bfloat16, marks=pytest.mark.xfail(reason="bf16 may have numerical issues")), + ], + ) + def test_dtype_preservation(self, dtype): + """Verify dtype (fp32, fp16, bf16) preserved.""" + m, d, out_features, num_experts = 256, 128, 128, 4 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=303, distribution="balanced", dtype=dtype + ) + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + # Check dtype preserved + assert ( + cutile_out.dtype == ragged_out.dtype == dtype + ), f"Dtype not preserved: input={dtype}, ragged={ragged_out.dtype}, cutile={cutile_out.dtype}" + + # Looser tolerance for fp16/bf16 + rtol = 1e-2 if dtype in [jnp.float16, jnp.bfloat16] else 1e-3 + compare_outputs(cutile_out, ragged_out, rtol=rtol, atol=1e-4) + + def test_device_placement(self): + """Ensure output stays on same GPU.""" + m, d, out_features, num_experts = 256, 128, 128, 4 + + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=404, distribution="balanced" + ) + + # Ensure inputs are on GPU + device = jax.devices("gpu")[0] + lhs = jax.device_put(lhs, device) + rhs = jax.device_put(rhs, device) + group_sizes = jax.device_put(group_sizes, device) + + ragged_out = ragged_dot(lhs, rhs, group_sizes) + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + # Check device placement + assert ( + cutile_out.device == ragged_out.device == device + ), f"Device mismatch: input={device}, ragged={ragged_out.device}, cutile={cutile_out.device}" + + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5) + + def test_benchmark_performance(self, capsys): + """Benchmark cutile vs ragged_dot performance across multiple realistic configurations.""" + # Realistic configurations based on common LLM architectures + # Format: (m, d, out_features, num_experts, description) + configs = [ + (1024, 512, 512, 16, "Small (original)"), + 
(2048, 1024, 1024, 16, "Medium (Qwen-0.6B scale)"), + (4096, 1536, 1536, 32, "Large (Qwen2.5-1.5B scale)"), + (4096, 2048, 2048, 32, "Large+ (2B scale)"), + (8192, 4096, 4096, 64, "XLarge (Llama 3 8B scale)"), + ] + + results = [] + + # Temporarily disable output capturing to show benchmark results + with capsys.disabled(): + print(f"\n{'='*80}") + print(f"{'CUTILE vs RAGGED_DOT BENCHMARK SUITE':^80}") + print(f"{'='*80}") + + for m, d, out_features, num_experts, desc in configs: + lhs, rhs, group_sizes = generate_ragged_test_case( + m, d, out_features, num_experts, seed=505, distribution="imbalanced" + ) + + print(f"\n{desc}") + print(f" Config: {m} tokens × {d} hidden → {out_features} out, {num_experts} experts") + print(f" {'-'*76}") + + ragged_time, cutile_time, speedup = benchmark_both(lhs, rhs, group_sizes, num_runs=50, warmup=5) + + # Store results + results.append( + { + "config": desc, + "m": m, + "d": d, + "num_experts": num_experts, + "ragged_time": ragged_time, + "cutile_time": cutile_time, + "speedup": speedup, + } + ) + + # Basic sanity check + assert cutile_time > 0, f"Cutile execution failed for {desc}" + assert ragged_time > 0, f"Ragged_dot execution failed for {desc}" + + # Print summary table + print(f"\n{'='*80}") + print(f"{'SUMMARY':^80}") + print(f"{'='*80}") + print(f"{'Config':<20} {'Tokens':>8} {'Hidden':>8} {'Experts':>8} {'Speedup':>10}") + print(f"{'-'*80}") + for r in results: + status = "✓" if r["speedup"] > 1.0 else "⚠" + print( + f"{r['config']:<20} {r['m']:>8} {r['d']:>8} {r['num_experts']:>8} " + f"{status} {r['speedup']:>7.2f}x" + ) + print(f"{'='*80}\n") + + +# ============================================================================ +# CLI for Manual Testing +# ============================================================================ + +if __name__ == "__main__": + """Run basic equivalence test from command line.""" + print("Running basic cutile vs ragged_dot equivalence test...") + + # Simple test case + m, d, out_features, num_experts = 256, 128, 128, 4 + lhs, rhs, group_sizes = generate_ragged_test_case(m, d, out_features, num_experts, seed=42, distribution="balanced") + + print("\nTest configuration:") + print(f" Tokens (m): {m}") + print(f" Hidden dim (d): {d}") + print(f" Output dim: {out_features}") + print(f" Num experts: {num_experts}") + print(f" Group sizes: {group_sizes}") + + # Run both + print("\nRunning ragged_dot...") + ragged_out = ragged_dot(lhs, rhs, group_sizes) + + print("Running cutile...") + cutile_out = cutile_ragged_dot(lhs, rhs, group_sizes) + + # Compare + print("\nComparing outputs...") + compare_outputs(cutile_out, ragged_out, rtol=1e-3, atol=1e-5, verbose=True) + + print("\n✓ Basic test PASSED!") diff --git a/skyrl-tx/tests/cutile/time_cutile_parts.py b/skyrl-tx/tests/cutile/time_cutile_parts.py new file mode 100644 index 000000000..4cb938c01 --- /dev/null +++ b/skyrl-tx/tests/cutile/time_cutile_parts.py @@ -0,0 +1,174 @@ +""" +Time breakdown for cutile LoRA path: +- pad groups (_pad_groups_to_tile_m) +- cutile kernel launch (launch_cutile_lora_gemm) +- combined (pad + launch) + +Uses CUDA events for accurate GPU timing. 
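+
+Run it directly on a machine with an NVIDIA GPU and the `gpu` extra installed
+(jax[cuda12], torch, cuda-tile). An illustrative invocation from the skyrl-tx
+directory (the exact extras and paths may differ in your setup):
+
+    TX_USE_CUTILE_LORA=1 uv run --extra gpu python tests/cutile/time_cutile_parts.py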
+""" + +import sys + +try: + import jax + import jax.numpy as jnp + from jax import random +except ImportError: + print("JAX not available") + sys.exit(0) + +try: + import torch +except ImportError: + print("PyTorch not available") + sys.exit(0) + +try: + # IMPORTANT: match your test import style + from tx.kernels import CUTILE_AVAILABLE + from tx.kernels.cutile_lora import _pad_groups_to_tile_m + from tx.kernels.cutile_lora_kernels import launch_cutile_lora_gemm + from tx.kernels.cutile_config import config as default_config +except ImportError as e: + print(f"Cutile implementation not available: {e}") + sys.exit(0) + + +def cuda_time_ms(fn, iters=200, warmup=20) -> float: + """Time fn() with CUDA events; fn must enqueue CUDA work.""" + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / iters + + +def make_case(m, d, out_features, num_experts, seed=505, distribution="imbalanced", dtype=jnp.float16): + key = random.PRNGKey(seed) + key_lhs, key_rhs, key_sizes = random.split(key, 3) + + lhs = random.normal(key_lhs, (m, d), dtype=dtype) + rhs = random.normal(key_rhs, (num_experts, d, out_features), dtype=dtype) * 0.02 + + if distribution == "balanced": + base = m // num_experts + rem = m % num_experts + gs = jnp.array([base + (1 if i < rem else 0) for i in range(num_experts)], dtype=jnp.int32) + elif distribution == "imbalanced": + w = random.uniform(key_sizes, (num_experts,)) + w = w / w.sum() + gs = (w * m).astype(jnp.int32) + diff = m - gs.sum() + gs = gs.at[0].add(diff) + elif distribution == "sparse": + active = max(1, num_experts // 2) + active_sizes = jnp.array([m // active] * active, dtype=jnp.int32) + active_sizes = active_sizes.at[0].add(m - active_sizes.sum()) + gs = jnp.concatenate([active_sizes, jnp.zeros(num_experts - active, dtype=jnp.int32)]) + else: + raise ValueError(distribution) + + return lhs, rhs, gs + + +def main(): + if not CUTILE_AVAILABLE: + print("CUTILE_AVAILABLE is False") + return + if not torch.cuda.is_available(): + print("CUDA not available") + return + + # Pick one config first + m, d, out_features, num_experts = 2048, 1024, 1024, 16 + dtype = jnp.float16 # try jnp.bfloat16 too if you support it + + lhs_j, rhs_j, gs_j = make_case(m, d, out_features, num_experts, dtype=dtype) + + # Put on GPU (JAX) + dev = jax.devices("gpu")[0] + lhs_j = jax.device_put(lhs_j, dev) + rhs_j = jax.device_put(rhs_j, dev) + gs_j = jax.device_put(gs_j, dev) + + # Convert to torch via DLPack (same as your wrapper) + lhs_t = torch.from_dlpack(lhs_j) + rhs_t = torch.from_dlpack(rhs_j) + gs_t = torch.from_dlpack(gs_j) + + # Make sure dtypes are what you expect + if gs_t.dtype not in (torch.int32, torch.int64): + gs_t = gs_t.to(torch.int32) + + TILE_M = int(getattr(default_config, "tile_m")) + TILE_N = int(getattr(default_config, "tile_n")) + TILE_K = int(getattr(default_config, "tile_k")) + + print(f"Config: m={m}, d={d}, out={out_features}, E={num_experts}, dtype={lhs_t.dtype}") + print(f"TILE_M/N/K = {TILE_M}/{TILE_N}/{TILE_K}") + print(f"rhs contiguous={rhs_t.is_contiguous()} stride={rhs_t.stride()}") + + # --------- 1) pad-only ---------- + def do_pad(): + _pad_groups_to_tile_m(lhs_t, gs_t, tile_m=TILE_M) + + pad_ms = cuda_time_ms(do_pad) + + # Prepare once for launch-only timing + lhs_padded, expert_ids_per_tile = _pad_groups_to_tile_m(lhs_t, gs_t, 
tile_m=TILE_M) + m_padded = lhs_padded.shape[0] + out_t = torch.empty((m_padded, out_features), device=lhs_t.device, dtype=lhs_t.dtype) + + # --------- 2) launch-only ---------- + def do_launch(): + launch_cutile_lora_gemm( + lhs_padded, + rhs_t, + out_t, + expert_ids_per_tile, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + ) + + launch_ms = cuda_time_ms(do_launch) + + # --------- 3) combined ---------- + def do_combined(): + lp, ept = _pad_groups_to_tile_m(lhs_t, gs_t, tile_m=TILE_M) + ob = torch.empty((lp.shape[0], out_features), device=lhs_t.device, dtype=lhs_t.dtype) + launch_cutile_lora_gemm( + lp, + rhs_t, + ob, + ept, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + ) + + combined_ms = cuda_time_ms(do_combined) + + print("\n=== CUDA-event timing breakdown ===") + print(f"pad_groups: {pad_ms:.3f} ms") + print(f"cutile_launch: {launch_ms:.3f} ms") + print(f"combined: {combined_ms:.3f} ms") + print(f"pad fraction: {100.0 * pad_ms / max(combined_ms, 1e-9):.1f}%") + print(f"launch frac: {100.0 * launch_ms / max(combined_ms, 1e-9):.1f}%") + print(f"(pad+launch): {pad_ms + launch_ms:.3f} ms (rough expected)") + + # Optional sanity: ensure nothing is silently syncing on CPU + # If pad_ms is surprisingly large, it’s likely CPU sync from .cpu()/.tolist()/.item(). + + +if __name__ == "__main__": + main() diff --git a/skyrl-tx/tx/kernels/__init__.py b/skyrl-tx/tx/kernels/__init__.py new file mode 100644 index 000000000..1dfb9c6d9 --- /dev/null +++ b/skyrl-tx/tx/kernels/__init__.py @@ -0,0 +1,34 @@ +""" +Cutile-based kernel implementations for SkyRL-tx. + +This package provides optimized CUDA kernels using NVIDIA's cuTile (cuda-tile on PyPI) +for LoRA expert parallelism computation. +""" + +import os + +# Feature flag for cutile LoRA +USE_CUTILE_LORA = os.environ.get("TX_USE_CUTILE_LORA", "0") == "1" + +# Try to import cutile implementation +if USE_CUTILE_LORA: + try: + from .cutile_lora import cutile_ragged_dot + + CUTILE_AVAILABLE = True + except ImportError as e: + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Cutile not available, falling back to ragged_dot: {e}") + CUTILE_AVAILABLE = False + cutile_ragged_dot = None +else: + CUTILE_AVAILABLE = False + cutile_ragged_dot = None + +__all__ = [ + "USE_CUTILE_LORA", + "CUTILE_AVAILABLE", + "cutile_ragged_dot", +] diff --git a/skyrl-tx/tx/kernels/cutile_config.py b/skyrl-tx/tx/kernels/cutile_config.py new file mode 100644 index 000000000..fe04f5cc1 --- /dev/null +++ b/skyrl-tx/tx/kernels/cutile_config.py @@ -0,0 +1,35 @@ +""" +Configuration for cutile kernel parameters. + +Tile sizes and other kernel configuration options. 
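+
+All values can be overridden through environment variables, which are read once at
+import time when the module-level `config` instance is created. For example:
+
+    TX_CUTILE_TILE_M=64 TX_CUTILE_TILE_N=64 TX_CUTILE_TILE_K=32 \
+        TX_USE_CUTILE_LORA=1 uv run pytest tests/cutile -v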
+""" + +from dataclasses import dataclass +import os + + +@dataclass +class CutileConfig: + """Configuration for cutile LoRA kernels.""" + + # Tile sizes (must be powers of 2 for efficient computation) + tile_m: int = 128 # M dimension (rows/tokens) + tile_n: int = 128 # N dimension (columns/output features) + tile_k: int = 64 # K dimension (inner/reduction) + + # Block swizzling for cache locality + group_size_m: int = 8 + + @classmethod + def from_env(cls): + """Create config from environment variables.""" + return cls( + tile_m=int(os.environ.get("TX_CUTILE_TILE_M", 128)), + tile_n=int(os.environ.get("TX_CUTILE_TILE_N", 128)), + tile_k=int(os.environ.get("TX_CUTILE_TILE_K", 64)), + group_size_m=int(os.environ.get("TX_CUTILE_GROUP_SIZE_M", 8)), + ) + + +# Global config instance +config = CutileConfig.from_env() diff --git a/skyrl-tx/tx/kernels/cutile_lora.py b/skyrl-tx/tx/kernels/cutile_lora.py new file mode 100644 index 000000000..c0dfe262c --- /dev/null +++ b/skyrl-tx/tx/kernels/cutile_lora.py @@ -0,0 +1,229 @@ +""" +JAX-PyTorch interop for cutile LoRA kernels (optimized wrapper). + +Key changes vs prior version: +- Assumes ragged_dot-style contiguous grouping implied by `group_sizes` + (tokens are already laid out group-by-group in order). +- Removes sort/unsort and avoids building per-token expert_ids. +- Pads each group to TILE_M boundary (linear-time, no O(m log m) sort). +- Uses fewer allocations and zero-fills only padded tails. + +Note: +- This still crosses JAX<->Torch each call (DLPack). For benchmarking kernel-only, + prefer a pure-torch harness using torch.cuda.Event timing. +""" + +from __future__ import annotations + +try: + import jax +except ImportError as e: + raise ImportError("JAX is required for cutile LoRA") from e + +try: + import torch +except ImportError as e: + raise ImportError("PyTorch is required for cutile LoRA") from e + +from .cutile_lora_kernels import ( + launch_cutile_lora_gemm, + CUTILE_AVAILABLE, +) +from .cutile_config import config as default_config + + +# ----------------------------------------------------------------------------- +# DLPack helpers +# ----------------------------------------------------------------------------- + + +def jax_to_torch(jax_arr: "jax.Array") -> "torch.Tensor": + device_str = str(jax_arr.device).lower() + if "gpu" not in device_str and "cuda" not in device_str: + raise ValueError(f"Expected GPU array, got device: {jax_arr.device}") + try: + return torch.from_dlpack(jax_arr) # zero-copy when devices match + except Exception as e: + raise RuntimeError(f"DLPack conversion JAX->Torch failed: {e}") from e + + +def torch_to_jax(torch_tensor: "torch.Tensor") -> "jax.Array": + if not torch_tensor.is_cuda: + raise ValueError(f"Expected CUDA tensor, got device: {torch_tensor.device}") + try: + return jax.dlpack.from_dlpack(torch_tensor) + except Exception as e: + raise RuntimeError(f"DLPack conversion Torch->JAX failed: {e}") from e + + +# ----------------------------------------------------------------------------- +# Group padding (no sort/unsort) +# ----------------------------------------------------------------------------- + +# tx/kernels/cutile_lora.py (near top-level) +_PAD_PLAN_CACHE = {} + + +def _make_pad_plan(group_sizes_cpu, tile_m): + ps_cpu = [((g + tile_m - 1) // tile_m) * tile_m for g in group_sizes_cpu] + m_padded = sum(ps_cpu) + + expert_ids_list = [] + for e, p in enumerate(ps_cpu): + expert_ids_list.extend([e] * (p // tile_m)) + + return ps_cpu, m_padded, expert_ids_list + + +def _pad_groups_to_tile_m( + 
lhs: torch.Tensor, # [m, d], groups contiguous in order + group_sizes: torch.Tensor, # [E], int32/int64 on CUDA + tile_m: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + lhs_padded: [m_padded, d] + expert_ids_per_tile: [num_tiles_total] int32 + """ + if group_sizes.numel() == 0: + # Degenerate: no experts + m, d = lhs.shape + lhs_padded = lhs.new_empty((0, d)) + expert_ids_per_tile = torch.empty((0,), device=lhs.device, dtype=torch.int32) + return lhs_padded, expert_ids_per_tile + + # Make sure group_sizes is on same device + if group_sizes.device != lhs.device: + group_sizes = group_sizes.to(device=lhs.device) + + # Compute padded sizes: ceil(gs / tile_m) * tile_m + # Keep this on GPU to avoid sync; we only pull small scalars when looping over E. + gs = group_sizes.to(dtype=torch.int64) + padded_sizes = ((gs + tile_m - 1) // tile_m) * tile_m # [E] + # E is small; pulling scalar is fine + m_padded = int(padded_sizes.sum().item()) + + m, d = lhs.shape + lhs_padded = torch.empty((m_padded, d), device=lhs.device, dtype=lhs.dtype) + + # Build expert_ids_per_tile on CPU list then upload (E is small, num_tiles ~ m/tile_m) + # If this becomes a bottleneck, you can build it on GPU with a small kernel later. + expert_ids_list: list[int] = [] + + in_off = 0 + out_off = 0 + + # Loop over experts (small) + # E = int(gs.numel()) + # Pull to CPU once to avoid many device->host syncs + gs_cpu = tuple(group_sizes.detach().cpu().tolist()) + key = (tile_m, gs_cpu, lhs.dtype) + + plan = _PAD_PLAN_CACHE.get(key) + if plan is None: + ps_cpu, m_padded, expert_ids_list = _make_pad_plan(gs_cpu, tile_m) + expert_ids_per_tile = torch.tensor(expert_ids_list, device=lhs.device, dtype=torch.int32) + plan = (ps_cpu, m_padded, expert_ids_per_tile) + _PAD_PLAN_CACHE[key] = plan + + ps_cpu, m_padded, expert_ids_per_tile = plan + + m, d = lhs.shape + lhs_padded = torch.empty((m_padded, d), device=lhs.device, dtype=lhs.dtype) + + in_off = 0 + out_off = 0 + for e, (g, p) in enumerate(zip(gs_cpu, ps_cpu)): + if g: + lhs_padded[out_off : out_off + g].copy_(lhs[in_off : in_off + g]) + tail = p - g + if tail: + lhs_padded[out_off + g : out_off + p].zero_() + in_off += g + out_off += p + + return lhs_padded, expert_ids_per_tile + + +# ----------------------------------------------------------------------------- +# Main API +# ----------------------------------------------------------------------------- + + +def cutile_ragged_dot( + lhs: "jax.Array", # [m, d] + rhs: "jax.Array", # [E, d, out] + group_sizes: "jax.Array", # [E] + precision=None, # ignored + preferred_element_type=None, # ignored + group_offset: "jax.Array | None" = None, # not supported in this wrapper +) -> "jax.Array": + """ + Optimized drop-in replacement for ragged_dot when group_sizes implies + contiguous groups in lhs order. + + Output order matches input order (no sort/unsort). + """ + if not CUTILE_AVAILABLE: + raise RuntimeError( + "Cutile not available. Install with:\n" + " pip install cuda-tile\n" + "Note: CUDA Toolkit 13.1+ is required (install separately)\n" + "Or set TX_USE_CUTILE_LORA=0 to use ragged_dot" + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA not available. 
Cutile requires NVIDIA GPU.") + if group_offset is not None: + raise NotImplementedError("group_offset not supported in this Phase-1 cutile implementation.") + + # Convert JAX arrays to Torch (zero-copy) + lhs_t = jax_to_torch(lhs) + rhs_t = jax_to_torch(rhs) + gs_t = jax_to_torch(group_sizes) + + # Basic validation + if lhs_t.ndim != 2: + raise ValueError(f"lhs must be [m,d], got shape {tuple(lhs_t.shape)}") + if rhs_t.ndim != 3: + raise ValueError(f"rhs must be [E,d,out], got shape {tuple(rhs_t.shape)}") + if gs_t.ndim != 1: + raise ValueError(f"group_sizes must be [E], got shape {tuple(gs_t.shape)}") + + m, d = lhs_t.shape + E, d2, out_features = rhs_t.shape + if d2 != d: + raise ValueError(f"rhs d ({d2}) must match lhs d ({d})") + if gs_t.numel() != E: + raise ValueError(f"group_sizes len ({gs_t.numel()}) must equal num experts ({E})") + + # Optional: ensure integer type + if gs_t.dtype not in (torch.int32, torch.int64): + gs_t = gs_t.to(torch.int32) + + # Pad groups to TILE_M boundary (no sort/unsort) + TILE_M = int(getattr(default_config, "tile_m")) + TILE_N = int(getattr(default_config, "tile_n")) + TILE_K = int(getattr(default_config, "tile_k")) + + lhs_padded, expert_ids_per_tile = _pad_groups_to_tile_m(lhs_t, gs_t, tile_m=TILE_M) + m_padded = lhs_padded.shape[0] + + # Allocate output (use empty; kernel will write all valid tiles) + out_t = torch.empty((m_padded, out_features), device=lhs_t.device, dtype=lhs_t.dtype) + + # Launch kernel + launch_cutile_lora_gemm( + lhs_padded, + rhs_t, + out_t, + expert_ids_per_tile, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + ) + + # Slice back to original m (order preserved) + out_trim = out_t[:m] + + # Convert back to JAX (zero-copy) + return torch_to_jax(out_trim) diff --git a/skyrl-tx/tx/kernels/cutile_lora_kernels.py b/skyrl-tx/tx/kernels/cutile_lora_kernels.py new file mode 100644 index 000000000..2b9c33414 --- /dev/null +++ b/skyrl-tx/tx/kernels/cutile_lora_kernels.py @@ -0,0 +1,252 @@ +""" +Cutile CUDA kernel implementations for LoRA expert parallelism. + +This module contains the actual CUDA kernels using NVIDIA's cuTile (cuda-tile on PyPI). +""" + +from typing import Tuple + +try: + import torch +except ImportError: + raise ImportError("PyTorch is required for cutile kernels") + +try: + from cuda import tile as ct + from cuda.tile import Constant as ConstInt + + CUTILE_AVAILABLE = True +except ImportError: + CUTILE_AVAILABLE = False + + # Define dummy types for syntax checking + class DummyModule: + def __getattr__(self, name): + return lambda *args, **kwargs: None + + ct = DummyModule() + ConstInt = int + + +from .cutile_config import config as default_config + + +# ============================================================================ +# Token Sorting Utilities +# ============================================================================ + + +def lora_align_tile_size( + hidden_states: torch.Tensor, + expert_ids: torch.Tensor, + tile_m: int = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sort tokens by expert assignment and pad to tile boundaries. + + Adapted from moe_align_tile_size_torch in cutile-moe.py. 
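+
+    Note: the cutile_ragged_dot wrapper in tx/kernels/cutile_lora.py does not call this
+    helper; it assumes tokens already arrive grouped contiguously (ragged_dot layout)
+    and pads them with the sort-free _pad_groups_to_tile_m instead. This sort-based
+    variant is intended for inputs that carry an arbitrary per-token expert_ids
+    assignment.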
+ + Args: + hidden_states: Input tokens [m, d] + expert_ids: Expert assignment per token [m] + tile_m: Tile size for M dimension (default from config) + + Returns: + Tuple of: + - sorted_hidden_states: Tokens sorted by expert [m_padded, d] + - sorted_token_ids: Original indices of sorted tokens [m] + - sorted_expert_ids: Expert ID per tile [num_tiles] + """ + if tile_m is None: + tile_m = default_config.tile_m + + m, d = hidden_states.shape + device = hidden_states.device + dtype = hidden_states.dtype + + # Sort tokens by expert assignment + _, sorted_token_ids = torch.sort(expert_ids) + sorted_hidden_states = hidden_states[sorted_token_ids] + + # Count tokens per expert + num_experts = int(expert_ids.max().item()) + 1 + expert_counts = torch.bincount(expert_ids, minlength=num_experts) + + # Compute padding needed per expert to align to tile_m + expert_counts_padded = torch.ceil(expert_counts.float() / tile_m).long() * tile_m + total_padded = expert_counts_padded.sum().item() + + # Create padded tensor + sorted_hidden_states_padded = torch.zeros(total_padded, d, dtype=dtype, device=device) + sorted_hidden_states_padded[:m] = sorted_hidden_states + + # Create expert ID per tile + sorted_expert_ids_per_tile = [] + for expert_id in range(num_experts): + expert_tokens_padded = expert_counts_padded[expert_id].item() + num_tiles = expert_tokens_padded // tile_m + sorted_expert_ids_per_tile.extend([expert_id] * num_tiles) + + sorted_expert_ids_per_tile = torch.tensor(sorted_expert_ids_per_tile, dtype=torch.int32, device=device) + + return sorted_hidden_states_padded, sorted_token_ids, sorted_expert_ids_per_tile + + +# ============================================================================ +# 2D Swizzling Utility +# ============================================================================ + + +def swizzle_2d( + M: int, + N: int, + TILE_M: int, + TILE_N: int, + GROUP_SIZE_M: int, +) -> Tuple[int, int]: + """Compute 2D block swizzling for better cache locality. + + This function must be called from within a cutile kernel context. + Matches the reference implementation from cutile-moe.py. 
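+
+    Worked example (illustrative numbers): with num_bid_m=4, num_bid_n=2 and
+    GROUP_SIZE_M=2, linear block ids 0..7 map to (bid_m, bid_n) as
+    0->(0,0), 1->(1,0), 2->(0,1), 3->(1,1), 4->(2,0), 5->(3,0), 6->(2,1), 7->(3,1):
+    consecutive blocks first walk down GROUP_SIZE_M row-tiles within one column of
+    N-tiles, then revisit the same rows in the next column, so each group touches
+    only a small set of A row-tiles and B column-tiles and gets better L2 reuse.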
+ + Args: + M: Total rows + N: Total columns + TILE_M: Tile size in M dimension + TILE_N: Tile size in N dimension + GROUP_SIZE_M: Number of M blocks to group together + + Returns: + Tuple of (bid_m, bid_n) - block indices in M and N dimensions + """ + bid = ct.bid(axis=0) + num_bid_m = ct.cdiv(M, TILE_M) + num_bid_n = ct.cdiv(N, TILE_N) + num_bid_in_group = GROUP_SIZE_M * num_bid_n + group_id = bid // num_bid_in_group + first_bid_m = group_id * GROUP_SIZE_M + # Handle edge case when remaining blocks < GROUP_SIZE_M + # Match reference implementation: use min() - cutile can handle Python min in this context + remaining = num_bid_m - first_bid_m + group_size_m = min(remaining, GROUP_SIZE_M) # Python min works here + bid_m = first_bid_m + (bid % group_size_m) + bid_n = (bid % num_bid_in_group) // group_size_m + return bid_m, bid_n + + +# ============================================================================ +# Cutile Kernel +# ============================================================================ + +if CUTILE_AVAILABLE: + + @ct.kernel + def cutile_lora_gemm_kernel( + hidden_states: torch.Tensor, # [M, K] + # [E, K, N] (if you can store it this way) + weights: torch.Tensor, + output: torch.Tensor, # [M, N] + expert_ids_per_tile: torch.Tensor, + TILE_M: ConstInt, + TILE_N: ConstInt, + TILE_K: ConstInt, + ): + M = hidden_states.shape[0] + K = hidden_states.shape[1] + N = output.shape[1] + + bid_m, bid_n = swizzle_2d(M, N, TILE_M, TILE_N, GROUP_SIZE_M=8) + + start_m = bid_m * TILE_M + start_n = bid_n * TILE_N + + expert_id = ct.load(expert_ids_per_tile, index=bid_m, shape=()) + zero = ct.PaddingMode.ZERO + + acc = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32) + + # Hoist aranges if you still need them (not needed with ct.load offset+shape) + for start_k in range(0, K, TILE_K): + a = ct.load( + hidden_states, + (start_m, start_k), + shape=(TILE_M, TILE_K), + order=(1, 0), + padding_mode=zero, + ) + b = ct.load( + weights, + (expert_id, start_k, start_n), + shape=(1, TILE_K, TILE_N), + order=(0, 2, 1), + padding_mode=zero, + ).reshape((TILE_K, TILE_N)) + + acc = ct.mma(a, b, acc) + + out_tile = ct.astype(acc, output.dtype) + output_m_indices = start_m + ct.arange(TILE_M, dtype=ct.int32) + output_n_indices = start_n + ct.arange(TILE_N, dtype=ct.int32) + ct.scatter( + output, + (output_m_indices[:, None], output_n_indices[None, :]), + out_tile, + ) + + +# ============================================================================ +# Kernel Launch +# ============================================================================ + + +def launch_cutile_lora_gemm( + sorted_hidden_states: torch.Tensor, + weights: torch.Tensor, + output: torch.Tensor, + sorted_expert_ids_per_tile: torch.Tensor, + TILE_M: int = None, + TILE_N: int = None, + TILE_K: int = None, +): + """Launch cutile kernel for LoRA expert computation. + + Args: + sorted_hidden_states: Sorted and padded tokens [m_padded, d] + weights: Expert weights [num_experts, d, out_features] + output: Output buffer [m_padded, out_features] + sorted_expert_ids_per_tile: Expert ID per tile [num_tiles] + TILE_M, TILE_N, TILE_K: Tile sizes (default from config) + """ + if not CUTILE_AVAILABLE: + raise RuntimeError("Cutile not available. 
Cannot run CUDA kernels.") + + # Use config defaults if not specified + if TILE_M is None: + TILE_M = default_config.tile_m + if TILE_N is None: + TILE_N = default_config.tile_n + if TILE_K is None: + TILE_K = default_config.tile_k + + m_padded, d = sorted_hidden_states.shape + out_features = weights.shape[2] + + # Compute grid dimensions + grid_m = ct.cdiv(m_padded, TILE_M) + grid_n = ct.cdiv(out_features, TILE_N) + grid = (grid_m * grid_n,) + + # Launch kernel + ct.launch( + torch.cuda.current_stream(), + grid, + cutile_lora_gemm_kernel, + ( + sorted_hidden_states, + weights, + output, + sorted_expert_ids_per_tile, + TILE_M, + TILE_N, + TILE_K, + ), + ) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 9bb1ac808..5824198e2 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -91,8 +91,8 @@ def apply_lora( intermediate = self.lora_A[...][adapter_indices_sorted, x_sorted, :] else: # Linear path: x @ A - intermediate = jax.lax.ragged_dot(x_sorted, self.lora_A[...], group_sizes) - lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes) + intermediate = ragged_dot(x_sorted, self.lora_A[...], group_sizes) + lora_output_sorted = ragged_dot(intermediate, self.lora_B[...], group_sizes) # Unsort, reshape, scale lora_output = lora_output_sorted[unsort_indices].reshape(batch_size, seq_len, -1) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..04f33d8b7 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -3,6 +3,23 @@ from jax import lax from jax import numpy as jnp from jax.sharding import get_abstract_mesh, PartitionSpec +import os +import logging + +# Cutile integration (Phase 1) +USE_CUTILE_LORA = os.environ.get("TX_USE_CUTILE_LORA", "0") == "1" +_cutile_ragged_dot = None + +if USE_CUTILE_LORA: + try: + from tx.kernels.cutile_lora import cutile_ragged_dot as _cutile_ragged_dot + + logger = logging.getLogger(__name__) + logger.info("Cutile LoRA enabled (TX_USE_CUTILE_LORA=1)") + except ImportError as e: + logger = logging.getLogger(__name__) + logger.warning(f"Cutile LoRA requested but not available: {e}") + USE_CUTILE_LORA = False def ragged_dot( @@ -17,7 +34,20 @@ def ragged_dot( When group_offset is specified, rhs contains groups [offset, offset + g_local). Tokens outside this range are routed to boundary groups and masked to zero. + + Phase 1 Cutile Integration: + - If TX_USE_CUTILE_LORA=1 and group_offset is None, uses cutile kernels + - Falls back to JAX ragged_dot otherwise """ + # Phase 1: Cutile only for single-GPU (no group_offset) + if USE_CUTILE_LORA and _cutile_ragged_dot is not None and group_offset is None: + try: + return _cutile_ragged_dot(lhs, rhs, group_sizes, precision, preferred_element_type) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error(f"Cutile failed, falling back to ragged_dot: {e}") + # Fall through to JAX implementation + if group_offset is None: return lax.ragged_dot( lhs,