
Flydsl pa #338 (Open)

fsx950223 wants to merge 13 commits into main from flydsl_pa

Conversation

@fsx950223 (Contributor)

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

sixifang added 2 commits April 2, 2026 08:13
Add the persistent-scheduling FP8 paged-attention kernel and a PS-focused regression harness so split-reduce behavior can be exercised against the Gluon reference.

Signed-off-by: sixifang <sixifang@amd.com>
Made-with: Cursor
Copilot AI review requested due to automatic review settings April 2, 2026 09:46

Copilot AI left a comment


Pull request overview

This PR updates the paged-attention (PA) decode workflow to focus on persistent-scheduling (PS) mode, adding a PS-only regression harness and significantly expanding the PS implementation in the FP8 PA decode kernel.

Changes:

  • Replaced the prior PA test with a PS-only regression harness that compares FlyDSL PS vs Torch and Gluon across a case matrix.
  • Implemented/extended persistent-scheduling PA decode in kernels/pa_decode_fp8.py, including metadata expansion for block-split partials and a sliding-window path.

Reviewed changes

Copilot reviewed 1 out of 2 changed files in this pull request and generated 8 comments.

| File | Description |
| --- | --- |
| tests/kernels/test_pa.py | New PS-only correctness/perf harness for FlyDSL PS vs Torch/Gluon with multi-case execution and CSV output. |
| kernels/pa_decode_fp8.py | New/expanded PS kernel compilation + launch APIs, metadata expansion for 1024→4×256 block splits, and a sliding-window PS path. |


Comment on lines +49 to +50
```python
torch.set_default_device("cuda")
torch.set_printoptions(sci_mode=False)
```
Copilot AI Apr 2, 2026

torch.set_default_device("cuda") at import time changes global Torch state for the entire pytest process, which can unintentionally affect other tests (including CPU-only ones) and makes failures harder to diagnose. Prefer using an explicit device variable (as you already do later) and passing it to tensor creation, or set the default device inside the specific test function and restore it afterward.
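The save/restore discipline this comment asks for can be sketched without touching Torch at all. `scoped_default_device` and the module-level state below are illustrative stand-ins for `torch.set_default_device`/`torch.get_default_device`, not FlyDSL or Torch code; the point is only that the override is scoped and always restored.

```python
import contextlib

# Hypothetical stand-in for the global default-device state that
# torch.set_default_device mutates in the real test module.
_DEFAULT_DEVICE = "cpu"

def get_default_device():
    return _DEFAULT_DEVICE

def set_default_device(device):
    global _DEFAULT_DEVICE
    _DEFAULT_DEVICE = device

@contextlib.contextmanager
def scoped_default_device(device):
    # Save, override, and always restore, so one test cannot leak
    # device state into the rest of the pytest process.
    prev = get_default_device()
    set_default_device(device)
    try:
        yield
    finally:
        set_default_device(prev)
```

A pytest fixture wrapping this context manager would give each test the CUDA default without mutating global state at import time.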

Comment on lines +86 to +91
```python
def compare_arrays(
    arr1: np.ndarray,
    arr2: np.ndarray,
    k: int = 5,
    thresholds: List[float] = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1],
) -> Dict[str, object]:
```
Copilot AI Apr 2, 2026

compare_arrays uses a mutable list (thresholds) as a default argument. Even if you don’t currently mutate it, this is a shared object across calls and is easy to accidentally mutate later. Prefer thresholds: Sequence[float] = (...) (tuple) or thresholds: Optional[Sequence[float]] = None with an internal default.
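A minimal sketch of the suggested signature. The body here is a placeholder (the real `compare_arrays` in the PR reports top-k diffs per threshold); only the tuple-default pattern is the point.

```python
from typing import Dict, Optional, Sequence

import numpy as np

# Immutable module-level default: a tuple cannot be accidentally
# mutated and shared across calls the way a list default can.
DEFAULT_THRESHOLDS: Sequence[float] = (0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1)

def compare_arrays(
    arr1: np.ndarray,
    arr2: np.ndarray,
    k: int = 5,
    thresholds: Optional[Sequence[float]] = None,
) -> Dict[str, object]:
    if thresholds is None:
        thresholds = DEFAULT_THRESHOLDS
    # Placeholder body: the actual helper computes per-threshold
    # error counts and top-k diffs.
    diff = np.abs(arr1.astype(np.float64) - arr2.astype(np.float64))
    return {"max_abs_diff": float(diff.max()), "thresholds": tuple(thresholds)}
```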

Comment on lines +1111 to +1115
```python
output_file = (
    f"run_pa_decode_ps_test.{output_tag}.block_size_{block_sizes[0]}.triton.{TRITON_VERSION}.csv"
)
results_df.to_csv(output_file, index=False)
print(f"\nResults saved to {output_file}")
```
Copilot AI Apr 2, 2026

parse_arg_and_run_test always writes a CSV (results_df.to_csv(...)) even when executed under pytest. Writing files from tests can pollute the workspace and cause issues in hermetic CI environments. Consider only writing the CSV under __main__ execution, or writing to a pytest-provided temp directory (e.g., via a tmp_path fixture) when running under pytest.

Suggested change

```diff
-output_file = (
-    f"run_pa_decode_ps_test.{output_tag}.block_size_{block_sizes[0]}.triton.{TRITON_VERSION}.csv"
-)
-results_df.to_csv(output_file, index=False)
-print(f"\nResults saved to {output_file}")
+if not running_via_pytest:
+    output_file = (
+        f"run_pa_decode_ps_test.{output_tag}.block_size_{block_sizes[0]}.triton.{TRITON_VERSION}.csv"
+    )
+    results_df.to_csv(output_file, index=False)
+    print(f"\nResults saved to {output_file}")
```
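One way this could be wired up, as a sketch: `resolve_csv_path` is a hypothetical helper (not in the PR) that returns `None` when the write should be skipped, and the `"pytest" in sys.modules` check is one common way to define a `running_via_pytest` flag, not necessarily this repo's convention.

```python
import sys
from pathlib import Path
from typing import Optional

def resolve_csv_path(filename: str, out_dir: Optional[Path] = None) -> Optional[Path]:
    # out_dir would come from a pytest tmp_path fixture; None means a
    # plain `python test_pa.py` run.
    if out_dir is not None:
        return out_dir / filename
    if "pytest" in sys.modules:
        # Hermetic CI: skip the write rather than pollute the workspace.
        return None
    return Path(filename)
```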

Comment on lines +1 to 10
```python
"""FlyDSL Paged Attention Decode with Persistent Scheduling — FP8.

Supports kv_block_size=16 (original) and kv_block_size=1024 (trans_v required).
Extends pa_decode_sw_fp8.py with persistent scheduling (PS) mode:
- Grid = (num_SM, 1, 4) so each CTA handles one 256-token sub-tile of a 1024-token KV page
- Outer work loop iterates over pre-computed worklist from get_pa_metadata_v1
- Inner KV loop iterates pages from kv_page_indices instead of block_tables
- Supports split-reduce for load balancing across CUs

Contains:
- build_pa_decode_module(): main decode dot-product kernel
- build_ps_reduce_kernel(): fixed-partition-count reduce kernel
- build_v2_reduce_kernel(): dynamic-partition-count reduce kernel
Requires: aiter's get_pa_metadata_v1 (module_pa_metadata.so)
"""
```
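The 1024-token-page to 4×256-tile split described in this docstring can be sketched as a worklist expansion. `expand_block_splits` is an illustrative helper (not the PR's `get_pa_metadata_v1`), while `KV_BLOCK_SIZE`, `KV_COMPUTE_BLOCK`, and `TILES_PER_BLOCK` mirror constants used elsewhere in the PR.

```python
KV_BLOCK_SIZE = 1024      # physical KV page size in tokens
KV_COMPUTE_BLOCK = 256    # per-CTA compute tile in tokens
TILES_PER_BLOCK = KV_BLOCK_SIZE // KV_COMPUTE_BLOCK  # 4 sub-tiles per page

def expand_block_splits(block_ids):
    # Each 1024-token page yields four (page, tile) work items; a tile's
    # token offset within its page is tile * KV_COMPUTE_BLOCK, which is
    # what lets grid dim z = 4 assign one CTA per 256-token sub-tile.
    return [(b, t) for b in block_ids for t in range(TILES_PER_BLOCK)]
```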
Copilot AI Apr 2, 2026

This file no longer includes the repository’s standard SPDX header comment. Most files under kernels/ start with an Apache-2.0 SPDX identifier (e.g., kernels/preshuffle_gemm.py:1). To stay consistent with repo licensing and tooling, add the appropriate # SPDX-License-Identifier: ... header at the top of this file.


```python
# Note: waves_per_eu=4 causes agpr=0 regression on current build (0c1805f).
# Leave empty to let LLVM decide — gets agpr=128, vgpr=96, ~203us.
CompilationContext._compile_hints.data = {}
```
Copilot AI Apr 2, 2026

Directly assigning CompilationContext._compile_hints.data = {} mutates global compiler state and can clobber hints set by other callers (and is not thread-safe). The codebase provides a context manager for this (python/flydsl/compiler/kernel_function.py:228-240). Prefer removing this global assignment, or using with CompilationContext.compile_hints({}): ... around the compilation/launch where the hint scope is well-defined.

Suggested change

```diff
-CompilationContext._compile_hints.data = {}
```

Comment on lines 1630 to 1631
```python
CompilationContext._compile_hints.data = {}
```

Copilot AI Apr 2, 2026

Same issue as earlier: CompilationContext._compile_hints.data = {} mutates global compile hints and can interfere with other compilations. Please scope compile-hint changes using CompilationContext.compile_hints(...) (see python/flydsl/compiler/kernel_function.py:228-240) or avoid resetting hints globally here.

Suggested change

```diff
-CompilationContext._compile_hints.data = {}
```

Comment on lines +1465 to +1508
```python
if sliding_window > 0:
    # Launch one CTA per 256-token tile inside each 1024-token physical block:
    # grid = (batch, kv_heads, max_context_partition_num * 4).
    batch_size = context_lengths.shape[0]
    head_size = query.shape[-1]
    eqgs = query_length * query_group_size
    physical_partition_size = KV_BLOCK_SIZE
    context_partition_size = KV_COMPUTE_BLOCK

    if max_context_partition_num == 0:
        max_context_partition_num = (
            (sliding_window + physical_partition_size - 1) // physical_partition_size
        ) + 1
    total_context_partition_num = max_context_partition_num * TILES_PER_BLOCK

    if exp_sums is None:
        exp_sums = torch.empty(batch_size, num_kv_heads, total_context_partition_num, eqgs,
                               device=dev, dtype=torch.float32)
    if max_logits is None:
        max_logits = torch.full((batch_size, num_kv_heads, total_context_partition_num, eqgs),
                                float('-inf'), device=dev, dtype=torch.float32)
    if temporary_output is None:
        temporary_output = torch.zeros(batch_size, num_kv_heads, total_context_partition_num,
                                       eqgs, head_size, device=dev, dtype=torch.bfloat16)

    compiled_sw = compile_pa_decode_ps_sw(
        sliding_window=sliding_window,
        softmax_scale=softmax_scale, trans_v=trans_v, query_group_size=query_group_size,
        per_token_kv=per_token_kv, query_length=query_length,
        query_input_dtype=query_input_dtype)

    compiled_sw['launch'](
        exp_sums, max_logits, temporary_output,
        query, key_cache, value_cache,
        block_tables, context_lengths,
        query_scale, key_scale, value_scale,
        query.stride(0), query.stride(1),
        key_cache.stride(0), key_cache.stride(1),
        value_cache.stride(0), value_cache.stride(1),
        exp_sums.stride(0), exp_sums.stride(1), exp_sums.stride(2),
        temporary_output.stride(0), temporary_output.stride(1),
        temporary_output.stride(2), temporary_output.stride(3),
        block_tables.stride(0),
        stride_ks_block, stride_ks_head,
```
Copilot AI Apr 2, 2026

In the sliding_window > 0 path, block_tables is used unconditionally (block_tables.stride(0) and passed into the kernel), but the function signature allows block_tables: torch.Tensor = None. Calling this API with sliding_window > 0 and leaving block_tables as None will raise an AttributeError. Add an explicit validation early in this branch (e.g., raise a clear ValueError when block_tables is None, or compute/derive the needed mapping internally).
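A sketch of the early validation the comment asks for; `validate_sw_inputs` is a hypothetical helper name, not something in the PR.

```python
def validate_sw_inputs(sliding_window: int, block_tables) -> None:
    # Fail fast with a clear message instead of an AttributeError deep
    # inside the launch when block_tables.stride(0) is read.
    if sliding_window > 0 and block_tables is None:
        raise ValueError(
            "block_tables is required when sliding_window > 0: the "
            "sliding-window PS path reads block_tables.stride(0)"
        )
```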

Comment on lines +14 to +50
```python
import numpy as np
import pandas as pd
import pytest
import torch
import triton

import aiter
from aiter import dtypes
from aiter import per_tensor_quant, pertoken_quant
from aiter.ops.attention import pa_decode_gluon as _pa_decode_gluon  # noqa: F401
from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits
from aiter.test_common import checkAllclose

REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

try:
    from triton.experimental import gluon  # noqa: F401
    from triton.experimental.gluon import language as gl  # noqa: F401
    HAS_GLUON = True
except ImportError:
    HAS_GLUON = False
    print("Warning: Triton Gluon is unavailable; Gluon reference checks will fail.")

try:
    from kernels.pa_decode_fp8 import (
        get_pa_metadata as flydsl_get_pa_metadata,
        pa_decode_ps_launch as flydsl_ps_launch,
    )
    HAS_FLYDSL_PS = True
except ImportError as exc:
    HAS_FLYDSL_PS = False
    print(f"Warning: FlyDSL PA decode PS not available: {exc}")

pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]

sys.path.insert(0, 'build-fly/python_packages'); sys.path.insert(1, '.')
os.environ['FLYDSL_RUNTIME_ENABLE_CACHE'] = '1'
logging.basicConfig(level=logging.INFO, format='%(message)s')

aiter = pytest.importorskip("aiter", reason="aiter is not installed, skipping PA tests")

from tests.test_common import run_perftest, verify_output, checkAllclose
from kernels.pa_decode_fp8 import build_pa_decode_module, BLOCK_THREADS, QUERY_GROUP_SIZE, HEAD_SIZE, KV_COMPUTE_BLOCK
import kernels.pa_decode_fp8 as _pa
from flydsl.compiler.kernel_function import CompilationContext
from flydsl._mlir import ir as _ir
import flydsl.compiler as flyc, flydsl.expr as fx
from flydsl.expr import arith
from flydsl.expr.typing import T
from aiter.ops.triton.gluon.pa_decode_gluon import (
    pa_decode_gluon, get_recommended_splits,
    _paged_attention_decode_v2_reduce_kernel_wrapper,
)
from aiter import per_tensor_quant, dtypes as aiter_dtypes

CPSZ = 256; QG = QUERY_GROUP_SIZE
fp8 = torch.float8_e4m3fnuz; bf16 = torch.bfloat16; dev = 'cuda'
torch.set_default_device("cuda")
torch.set_printoptions(sci_mode=False)
```
Copilot AI Apr 2, 2026

This test module has unguarded GPU- and dependency-specific setup at import/collection time (e.g., import aiter, import triton, and calling torch.set_default_device("cuda")) but no module-level skip/import gates. In environments without optional deps (or without a CUDA/ROCm device), pytest collection will error instead of skipping. Other GPU tests in this repo follow a consistent pattern (e.g., tests/kernels/test_blockscale_preshuffle_gemm.py:17,37-38) using pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] and if not torch.cuda.is_available(): pytest.skip(..., allow_module_level=True), plus pytest.importorskip for optional deps. Aligning with that pattern will keep this long-running harness out of default runs and make collection robust.

@coderfeli (Collaborator)

CI failed @fsx950223

sixifang and others added 11 commits April 3, 2026 04:00
Signed-off-by: sixifang <sixifang@amd.com>
Made-with: Cursor
…ction error on navi runner

Agent-Logs-Url: https://github.com/ROCm/FlyDSL/sessions/cea81057-b6be-4480-a0ca-dcad1210c224

Co-authored-by: fsx950223 <17592563+fsx950223@users.noreply.github.com>
Fix CI: conditional aiter import in test_pa.py to unblock navi runner collection