Conversation
Add the persistent-scheduling FP8 paged-attention kernel and a PS-focused regression harness so split-reduce behavior can be exercised against the Gluon reference. Signed-off-by: sixifang <sixifang@amd.com> Made-with: Cursor
Pull request overview
This PR updates the paged-attention (PA) decode workflow to focus on persistent-scheduling (PS) mode, adding a PS-only regression harness and significantly expanding the PS implementation in the FP8 PA decode kernel.
Changes:
- Replaced the prior PA test with a PS-only regression harness that compares FlyDSL PS vs Torch and Gluon across a case matrix.
- Implemented/extended persistent-scheduling PA decode in
kernels/pa_decode_fp8.py, including metadata expansion for block-split partials and a sliding-window path.
Reviewed changes
Copilot reviewed 1 out of 2 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/kernels/test_pa.py | New PS-only correctness/perf harness for FlyDSL PS vs Torch/Gluon with multi-case execution and CSV output. |
| kernels/pa_decode_fp8.py | New/expanded PS kernel compilation + launch APIs, metadata expansion for 1024→4×256 block splits, and a sliding-window PS path. |
```python
torch.set_default_device("cuda")
torch.set_printoptions(sci_mode=False)
```
torch.set_default_device("cuda") at import time changes global Torch state for the entire pytest process, which can unintentionally affect other tests (including CPU-only ones) and makes failures harder to diagnose. Prefer using an explicit device variable (as you already do later) and passing it to tensor creation, or set the default device inside the specific test function and restore it afterward.
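One way to avoid leaking the global default is a save-and-restore context manager; this is a generic sketch (the `scoped_setting` helper is hypothetical, not a FlyDSL or Torch API):

```python
from contextlib import contextmanager

@contextmanager
def scoped_setting(getter, setter, value):
    """Temporarily apply a global setting, restoring the prior value on exit."""
    previous = getter()
    setter(value)
    try:
        yield
    finally:
        # Restore even if the body raised, so other tests see clean state.
        setter(previous)
```

For Torch specifically this could be used as `with scoped_setting(torch.get_default_device, torch.set_default_device, "cuda"): ...`, assuming a Torch version that provides `torch.get_default_device`.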
```python
def compare_arrays(
    arr1: np.ndarray,
    arr2: np.ndarray,
    k: int = 5,
    thresholds: List[float] = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1],
) -> Dict[str, object]:
```
compare_arrays uses a mutable list (thresholds) as a default argument. Even if you don’t currently mutate it, this is a shared object across calls and is easy to accidentally mutate later. Prefer thresholds: Sequence[float] = (...) (tuple) or thresholds: Optional[Sequence[float]] = None with an internal default.
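A minimal sketch of the suggested fix, using an immutable tuple default so no mutable state is shared between calls (signature trimmed to the relevant parts):

```python
from typing import Dict, Optional, Sequence

# Module-level immutable default: safe to share across all calls.
DEFAULT_THRESHOLDS: Sequence[float] = (0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1)

def compare_arrays(
    arr1,
    arr2,
    k: int = 5,
    thresholds: Optional[Sequence[float]] = None,
) -> Dict[str, object]:
    # Resolve the default inside the body; since the module-level default is
    # a tuple, calls can never observe each other's mutations.
    if thresholds is None:
        thresholds = DEFAULT_THRESHOLDS
    return {"k": k, "thresholds": tuple(thresholds)}
```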
```python
output_file = (
    f"run_pa_decode_ps_test.{output_tag}.block_size_{block_sizes[0]}.triton.{TRITON_VERSION}.csv"
)
results_df.to_csv(output_file, index=False)
print(f"\nResults saved to {output_file}")
```
parse_arg_and_run_test always writes a CSV (results_df.to_csv(...)) even when executed under pytest. Writing files from tests can pollute the workspace and cause issues in hermetic CI environments. Consider only writing the CSV under __main__ execution, or writing to a pytest-provided temp directory (e.g., via a tmp_path fixture) when running under pytest.
Suggested change:

```diff
-output_file = (
-    f"run_pa_decode_ps_test.{output_tag}.block_size_{block_sizes[0]}.triton.{TRITON_VERSION}.csv"
-)
-results_df.to_csv(output_file, index=False)
-print(f"\nResults saved to {output_file}")
+if not running_via_pytest:
+    output_file = (
+        f"run_pa_decode_ps_test.{output_tag}.block_size_{block_sizes[0]}.triton.{TRITON_VERSION}.csv"
+    )
+    results_df.to_csv(output_file, index=False)
+    print(f"\nResults saved to {output_file}")
```
```python
"""FlyDSL Paged Attention Decode with Persistent Scheduling — FP8.

Supports kv_block_size=16 (original) and kv_block_size=1024 (trans_v required).
Extends pa_decode_sw_fp8.py with persistent scheduling (PS) mode:
- Grid = (num_SM, 1, 4) so each CTA handles one 256-token sub-tile of a 1024-token KV page
- Outer work loop iterates over pre-computed worklist from get_pa_metadata_v1
- Inner KV loop iterates pages from kv_page_indices instead of block_tables
- Supports split-reduce for load balancing across CUs

Contains:
- build_pa_decode_module(): main decode dot-product kernel
- build_ps_reduce_kernel(): fixed-partition-count reduce kernel
- build_v2_reduce_kernel(): dynamic-partition-count reduce kernel
Requires: aiter's get_pa_metadata_v1 (module_pa_metadata.so)
"""
```
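The split-reduce mentioned in the docstring combines per-partition softmax partials into one result; a scalar sketch of the underlying math (names are illustrative — the real reduce kernels operate on the exp_sums/max_logits/temporary_output tensors):

```python
import math

def merge_softmax_partials(partials):
    """Merge (max_logit, exp_sum, partial_output) triples from KV partitions.

    Each partition holds its running max m_i, its exp-sum s_i, and an output
    already normalized by s_i. Rescaling each partition by exp(m_i - m_global)
    before summing reproduces the softmax over the full sequence.
    """
    m_global = max(m for m, _, _ in partials)
    total_sum = 0.0
    accum = 0.0
    for m, s, out in partials:
        scale = math.exp(m - m_global)
        total_sum += s * scale
        accum += out * s * scale
    return accum / total_sum
```

Because the rescaling cancels each partition's local max, the merged result is independent of how the KV sequence was split, which is what lets the scheduler balance partitions across CUs freely.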
This file no longer includes the repository’s standard SPDX header comment. Most files under kernels/ start with an Apache-2.0 SPDX identifier (e.g., kernels/preshuffle_gemm.py:1). To stay consistent with repo licensing and tooling, add the appropriate # SPDX-License-Identifier: ... header at the top of this file.
kernels/pa_decode_fp8.py (Outdated)

```python
# Note: waves_per_eu=4 causes agpr=0 regression on current build (0c1805f).
# Leave empty to let LLVM decide — gets agpr=128, vgpr=96, ~203us.
CompilationContext._compile_hints.data = {}
```
Directly assigning CompilationContext._compile_hints.data = {} mutates global compiler state and can clobber hints set by other callers (and is not thread-safe). The codebase provides a context manager for this (python/flydsl/compiler/kernel_function.py:228-240). Prefer removing this global assignment, or using with CompilationContext.compile_hints({}): ... around the compilation/launch where the hint scope is well-defined.
Suggested change:

```diff
-CompilationContext._compile_hints.data = {}
```
kernels/pa_decode_fp8.py (Outdated)

```python
CompilationContext._compile_hints.data = {}
```
Same issue as earlier: CompilationContext._compile_hints.data = {} mutates global compile hints and can interfere with other compilations. Please scope compile-hint changes using CompilationContext.compile_hints(...) (see python/flydsl/compiler/kernel_function.py:228-240) or avoid resetting hints globally here.
Suggested change:

```diff
-CompilationContext._compile_hints.data = {}
```
```python
if sliding_window > 0:
    # Launch one CTA per 256-token tile inside each 1024-token physical block:
    # grid = (batch, kv_heads, max_context_partition_num * 4).
    batch_size = context_lengths.shape[0]
    head_size = query.shape[-1]
    eqgs = query_length * query_group_size
    physical_partition_size = KV_BLOCK_SIZE
    context_partition_size = KV_COMPUTE_BLOCK

    if max_context_partition_num == 0:
        max_context_partition_num = (
            (sliding_window + physical_partition_size - 1) // physical_partition_size
        ) + 1
    total_context_partition_num = max_context_partition_num * TILES_PER_BLOCK

    if exp_sums is None:
        exp_sums = torch.empty(batch_size, num_kv_heads, total_context_partition_num, eqgs,
                               device=dev, dtype=torch.float32)
    if max_logits is None:
        max_logits = torch.full((batch_size, num_kv_heads, total_context_partition_num, eqgs),
                                float('-inf'), device=dev, dtype=torch.float32)
    if temporary_output is None:
        temporary_output = torch.zeros(batch_size, num_kv_heads, total_context_partition_num,
                                       eqgs, head_size, device=dev, dtype=torch.bfloat16)

    compiled_sw = compile_pa_decode_ps_sw(
        sliding_window=sliding_window,
        softmax_scale=softmax_scale, trans_v=trans_v, query_group_size=query_group_size,
        per_token_kv=per_token_kv, query_length=query_length,
        query_input_dtype=query_input_dtype)

    compiled_sw['launch'](
        exp_sums, max_logits, temporary_output,
        query, key_cache, value_cache,
        block_tables, context_lengths,
        query_scale, key_scale, value_scale,
        query.stride(0), query.stride(1),
        key_cache.stride(0), key_cache.stride(1),
        value_cache.stride(0), value_cache.stride(1),
        exp_sums.stride(0), exp_sums.stride(1), exp_sums.stride(2),
        temporary_output.stride(0), temporary_output.stride(1),
        temporary_output.stride(2), temporary_output.stride(3),
        block_tables.stride(0),
        stride_ks_block, stride_ks_head,
```
In the sliding_window > 0 path, block_tables is used unconditionally (block_tables.stride(0) and passed into the kernel), but the function signature allows block_tables: torch.Tensor = None. Calling this API with sliding_window > 0 and leaving block_tables as None will raise an AttributeError. Add an explicit validation early in this branch (e.g., raise a clear ValueError when block_tables is None, or compute/derive the needed mapping internally).
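A hedged sketch of the validation the review asks for (argument names follow the snippet above; the exact signature and placement are the PR author's to decide):

```python
def check_sliding_window_args(sliding_window: int, block_tables) -> None:
    """Fail fast instead of raising AttributeError deep inside the launch path."""
    if sliding_window > 0 and block_tables is None:
        raise ValueError(
            "sliding_window > 0 requires block_tables: the sliding-window PS "
            "path reads block_tables.stride(0) and passes it to the kernel."
        )
```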
```python
import numpy as np
import pandas as pd
import pytest
import torch
import triton

import aiter
from aiter import dtypes
from aiter import per_tensor_quant, pertoken_quant
from aiter.ops.attention import pa_decode_gluon as _pa_decode_gluon  # noqa: F401
from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits
from aiter.test_common import checkAllclose

REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

try:
    from triton.experimental import gluon  # noqa: F401
    from triton.experimental.gluon import language as gl  # noqa: F401
    HAS_GLUON = True
except ImportError:
    HAS_GLUON = False
    print("Warning: Triton Gluon is unavailable; Gluon reference checks will fail.")

try:
    from kernels.pa_decode_fp8 import (
        get_pa_metadata as flydsl_get_pa_metadata,
        pa_decode_ps_launch as flydsl_ps_launch,
    )
    HAS_FLYDSL_PS = True
except ImportError as exc:
    HAS_FLYDSL_PS = False
    print(f"Warning: FlyDSL PA decode PS not available: {exc}")

pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]

sys.path.insert(0, 'build-fly/python_packages'); sys.path.insert(1, '.')
os.environ['FLYDSL_RUNTIME_ENABLE_CACHE'] = '1'
logging.basicConfig(level=logging.INFO, format='%(message)s')

aiter = pytest.importorskip("aiter", reason="aiter is not installed, skipping PA tests")

from tests.test_common import run_perftest, verify_output, checkAllclose
from kernels.pa_decode_fp8 import build_pa_decode_module, BLOCK_THREADS, QUERY_GROUP_SIZE, HEAD_SIZE, KV_COMPUTE_BLOCK
import kernels.pa_decode_fp8 as _pa
from flydsl.compiler.kernel_function import CompilationContext
from flydsl._mlir import ir as _ir
import flydsl.compiler as flyc, flydsl.expr as fx
from flydsl.expr import arith
from flydsl.expr.typing import T
from aiter.ops.triton.gluon.pa_decode_gluon import (
    pa_decode_gluon, get_recommended_splits,
    _paged_attention_decode_v2_reduce_kernel_wrapper,
)
from aiter import per_tensor_quant, dtypes as aiter_dtypes

CPSZ = 256; QG = QUERY_GROUP_SIZE
fp8 = torch.float8_e4m3fnuz; bf16 = torch.bfloat16; dev = 'cuda'
torch.set_default_device("cuda")
torch.set_printoptions(sci_mode=False)
```
This test module has unguarded GPU- and dependency-specific setup at import/collection time (e.g., import aiter, import triton, and calling torch.set_default_device("cuda")) but no module-level skip/import gates. In environments without optional deps (or without a CUDA/ROCm device), pytest collection will error instead of skipping. Other GPU tests in this repo follow a consistent pattern (e.g., tests/kernels/test_blockscale_preshuffle_gemm.py:17,37-38) using pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] and if not torch.cuda.is_available(): pytest.skip(..., allow_module_level=True), plus pytest.importorskip for optional deps. Aligning with that pattern will keep this long-running harness out of default runs and make collection robust.
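A stdlib-only helper for the gating pattern the review describes; `pytest.importorskip` and module-level `pytest.skip(..., allow_module_level=True)` are real pytest APIs, while `optional_dep_available` is an illustrative name:

```python
import importlib.util

def optional_dep_available(name: str) -> bool:
    """Return True when an optional dependency can be imported.

    In the test module this would back a guard such as:
        if not torch.cuda.is_available():
            pytest.skip("no GPU available", allow_module_level=True)
        aiter = pytest.importorskip("aiter")
    so collection degrades to a skip instead of an ImportError.
    """
    return importlib.util.find_spec(name) is not None
```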
CI failed @fsx950223 |
Signed-off-by: sixifang <sixifang@amd.com> Made-with: Cursor
…ction error on navi runner Agent-Logs-Url: https://github.com/ROCm/FlyDSL/sessions/cea81057-b6be-4480-a0ca-dcad1210c224 Co-authored-by: fsx950223 <17592563+fsx950223@users.noreply.github.com>
Fix CI: conditional aiter import in test_pa.py to unblock navi runner collection
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist