
[TRITON] Add Attention support to the bench_models benchmarking script#2274

Open
lucas-santos-amd wants to merge 2 commits into main from lusantos/add_attention_to_bench_models

Conversation

@lucas-santos-amd
Contributor

@lucas-santos-amd lucas-santos-amd commented Mar 13, 2026

Motivation

  • Add standalone Unified Attention Benchmark.
  • Add support for MHA, MLA and Unified Attention to the bench_models benchmarking script.

Technical Details

New help text of bench_models.py:

usage: bench_models.py [-h] [--batch_size BATCH_SIZE [BATCH_SIZE ...]]
                       [--seq_len SEQ_LEN [SEQ_LEN ...]] [--TP {1,2,4,8}]
                       [--metric {throughput,bandwidth,time}] [--model MODEL]
                       [--layout {TN,TT,NN,NT}] [--output_file OUTPUT_FILE]
 
Model benchmarking tool
 
options:
  -h, --help            show this help message and exit
  --batch_size BATCH_SIZE [BATCH_SIZE ...]
                        Batch size(s) to sweep. Accepts:
                          Single value:            --batch_size 1
                          Multiple values:         --batch_size 1 2 4
                          Range start:stop:step:   --batch_size 1:8:2
                          Combinations of values and ranges are also accepted.
                        Default: 1.
  --seq_len SEQ_LEN [SEQ_LEN ...]
                        Sequence length(s) to sweep. Accepts:
                          Single value:            --seq_len 512
                          Multiple values:         --seq_len 256 512 1024
                          Range start:stop:step:   --seq_len 128:1024:128
                          Combinations of values and ranges are also accepted.
                        For non-attention kernels, the product batch_size x seq_len is passed as M.
                        Default: 4096.
  --TP {1,2,4,8}        Tensor parallel size. Default: 8.
  --metric {throughput,bandwidth,time}
                        Metric to report (throughput=TFLOPS, bandwidth=GB/s, time=ms). Default: throughput. RoPE reports total flops (total floating-point operations) in a separate column (see note in output). MLA benchmark only reports time (ms).
  --model MODEL         model name filter: case-insensitive regex matched against model name (default: all models). e.g. 'llama3' to include only Llama3 family, 'llama|qwen' to include both Llama and Qwen families, '^(?!.*deepseek)' to exclude DeepSeek family.
                        Available models: Llama3 405B, Llama3 70B, Llama3 8B, GPT-OSS 120B, DeepSeek-R1, Llama4 Maverick, Qwen3-235B-A22B.
  --layout {TN,TT,NN,NT}
                        GEMM layout. Default: TN.
  --output_file OUTPUT_FILE
                        Name for the CSV output file. Default: bench_results.
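
The `start:stop:step` sweep syntax above can be expanded with a small helper along these lines (a sketch; `parse_sweep_values` is a hypothetical name, and I am assuming Python's exclusive-stop `range` semantics -- the real parser in `bench_models.py` may treat the stop value inclusively):

```python
def parse_sweep_values(tokens):
    """Expand CLI sweep tokens into a flat list of integers.

    Each token is either a single integer ("4") or a range
    "start:stop:step" (assumed exclusive of stop, like range()).
    Tokens may mix both forms, matching the help text above.
    """
    values = []
    for token in tokens:
        if ":" in token:
            start, stop, step = (int(part) for part in token.split(":"))
            values.extend(range(start, stop, step))
        else:
            values.append(int(token))
    return values
```

With this, `--batch_size 16 1:8:2` would expand to `[16, 1, 3, 5, 7]`.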

Description of the output CSV file:

| CSV column | Description | Type | Relevant kernels |
| --- | --- | --- | --- |
| Model | LLM model name | String | All |
| Kernel | Kernel name | String | All |
| batch_size | Input batch size | Integer | MHA, MLA, Unified Attention |
| seq_len | Input sequence length | Integer | RoPE, MHA, MLA, Unified Attention |
| B | Batched GEMM batch size | Integer | Batched GEMM |
| M | M dimension | Integer | GEMM, batched GEMM, RMSNorm, MoE |
| N | N dimension | Integer | GEMM, batched GEMM, RMSNorm |
| K | K dimension | Integer | GEMM, batched GEMM |
| gemm_layout | GEMM layout | String | GEMM, batched GEMM |
| hq | Number of Q heads | Integer | RoPE, MHA, MLA, Unified Attention |
| hkv | Number of KV heads | Integer | RoPE, MHA, MLA, Unified Attention |
| dqk | Head size of Q and K | Integer | MHA, MLA, Unified Attention |
| dv | Head size of V | Integer | MHA, MLA, Unified Attention |
| rotary_dim | RoPE rotary dimension | Integer | RoPE |
| rotate_style | RoPE rotate style | String | RoPE |
| experts | Number of experts | Integer | MoE |
| moe_dim1 | First MoE dimension (i.e. hidden_size) | Integer | MoE |
| moe_dim2 | Second MoE dimension (i.e. moe_intermediate_size*2) | Integer | MoE |
| topk | Number of experts per token | Integer | MoE |
| Performance metric | Time in ms, throughput in TFLOPS, or bandwidth in GB/s | Decimal | GEMM, batched GEMM, RMSNorm, MoE, MHA, Unified Attention |
| time(ms) | *Time in ms | Decimal | MLA |
| total_flops(tflops) | **Total floating-point operations | Decimal | RoPE |

*The MLA benchmark only reports time (ms).
**RoPE reports only total floating-point operations, not throughput (TFLOPS).
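
The `--model` regex filter described in the help text can be sketched as follows (a minimal illustration; `filter_models` is a hypothetical helper name, and the model list is taken from the help text above):

```python
import re

MODELS = [
    "Llama3 405B", "Llama3 70B", "Llama3 8B", "GPT-OSS 120B",
    "DeepSeek-R1", "Llama4 Maverick", "Qwen3-235B-A22B",
]

def filter_models(pattern, models=MODELS):
    """Keep models whose name matches the case-insensitive regex,
    mirroring the documented --model behavior."""
    rx = re.compile(pattern, re.IGNORECASE)
    return [m for m in models if rx.search(m)]
```

For example, `filter_models("llama3")` keeps only the Llama3 family, while the negative lookahead `^(?!.*deepseek)` excludes DeepSeek models.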

Test Plan

Test Result

Submission Checklist

@lucas-santos-amd lucas-santos-amd self-assigned this Mar 13, 2026
@lucas-santos-amd lucas-santos-amd added enhancement New feature or request triton labels Mar 13, 2026
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2274 --add-label <label>`

@lucas-santos-amd lucas-santos-amd changed the title [TRITON] Add Attention support to bench_models.py [TRITON] Add Attention support to the bench_models benchmarking script Mar 13, 2026
@lucas-santos-amd lucas-santos-amd marked this pull request as ready for review March 13, 2026 19:14
@lucas-santos-amd lucas-santos-amd force-pushed the lusantos/add_attention_to_bench_models branch from 766780b to f94f542 Compare March 13, 2026 19:37
@lucas-santos-amd
Contributor Author

@cagrikymk Please take a look at the unified attention benchmark when you can. I converted it to pure Python and added BW/TFLOPS calculations for the reduction kernel.

def build_args(self) -> str:
    shape = self._shape
    return (
        f"-mode fwd -causal true --layout bshd --dtype bf16 -b {self._batch_size} "
Contributor

Please double-check whether the bshd layout is the proper one for this benchmark. You can also choose the thd layout.

    return RmsnormKernelHandler()
if kernel == "rope":
    return RopeKernelHandler()
if kernel == "mha":
Contributor

Suggestion:

I think the _get_handler function is becoming hard to maintain. I don't know if we can do anything better...

Something like this comes to my mind:

_HANDLER_RULES: list[tuple[Callable[[str], bool], type[KernelHandler]]] = [
    (lambda k: "moe" in k, MoeKernelHandler),
    (lambda k: "gemm" in k, GemmKernelHandler),
    (lambda k: k == "rmsnorm", RmsnormKernelHandler),
    (lambda k: k == "rope", RopeKernelHandler),
    (lambda k: k == "mha", MhaKernelHandler),
    (lambda k: k == "mla", MlaKernelHandler),
    (lambda k: k == "unified_attention", UnifiedAttnKernelHandler),
]

def _get_handler(kernel: str) -> KernelHandler:
    for predicate, handler_cls in _HANDLER_RULES:
        if predicate(kernel):
            return handler_cls()
    raise ValueError(f"Kernel {kernel} not supported")

Maybe it's more complex than the if statements... Please use your own judgement.

for shape in shapes:
    s = shape.copy()
    s["hq"] = max(shape["hq"] // self._tp, 1)
    s["hkv"] = max(shape["hkv"] // self._tp, 1)
Contributor

Suggestion:

This computation is repeated over and over:

s[key] = max(s[key] // self._tp, 1)

return [{**s, "Dim2": max(s["Dim2"] // self._tp, 1)} for s in shapes]

s["num_heads"] = max(s["num_heads"] // self._tp, 1)
s["num_kv_heads"] = max(s["num_kv_heads"] // self._tp, 1)

s["hq"] = max(shape["hq"] // self._tp, 1)
s["hkv"] = max(shape["hkv"] // self._tp, 1)

s["hq"] = max(shape["hq"] // self._tp, 1)
s["hkv"] = max(shape["hkv"] // self._tp, 1)

s["hq"] = max(shape["hq"] // self._tp, 1)
s["hkv"] = max(shape["hkv"] // self._tp, 1)

I'd add some utility methods to the KernelHandler class, to avoid repetition:

def _shard(self, value: int) -> int:
    return max(value // self._tp, 1)

def _shard_keys(self, s: dict, keys: list[str]) -> None:
    for key in keys:
        s[key] = self._shard(s[key])

Usage examples:

return [{**s, "Dim2": self._shard(s["Dim2"])} for s in shapes]

self._shard_keys(s, ["num_heads", "num_kv_heads"])

self._shard_keys(s, ["hq", "hkv"])
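
A minimal runnable version of the proposed utilities (the KernelHandler shown here is a stub with only the sharding logic; the clamp to 1 keeps head counts valid when the TP degree exceeds the head count):

```python
class KernelHandler:
    def __init__(self, tp: int):
        self._tp = tp  # tensor-parallel degree

    def _shard(self, value: int) -> int:
        # Divide a head/dimension count across TP ranks, never below 1.
        return max(value // self._tp, 1)

    def _shard_keys(self, s: dict, keys: list) -> None:
        # Shard several shape entries in place.
        for key in keys:
            s[key] = self._shard(s[key])

handler = KernelHandler(tp=8)
shape = {"hq": 64, "hkv": 8}
handler._shard_keys(shape, ["hq", "hkv"])
```

After the call, `shape` becomes `{"hq": 8, "hkv": 1}` for TP=8.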

    dtype=torch.int32,
    device="cuda",
)
values = torch.arange(0, num_blocks, dtype=torch.int32)
Contributor

Suggestion:

Please try to avoid moving data from host to GPU with .to("cuda"). I'd try to generate the values tensor on the GPU:

-    values = torch.arange(0, num_blocks, dtype=torch.int32)
-    values = values[torch.randperm(num_blocks)]
+    values = torch.arange(0, num_blocks, dtype=torch.int32, device="cuda")
+    values = values[torch.randperm(num_blocks, device="cuda")]
     block_tables = (
         values[: num_seqs * max_num_blocks_per_seq]
         .view(num_seqs, max_num_blocks_per_seq)
         .contiguous()
-        .to("cuda")
     )

kv_lens = torch.tensor(kv_lens, dtype=torch.int32, device="cuda")

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(
Contributor

I think this code can be removed:

    block_tables = torch.randint(
        0,
        num_blocks,
        (num_seqs, max_num_blocks_per_seq),
        dtype=torch.int32,
        device="cuda",
    )

block_tables is immediately overwritten, and the data generated by this statement is never used as far as I can see.
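
The point of the permutation approach (versus torch.randint) is that each KV-cache block id appears at most once, so the paged layout has no aliased blocks. A dependency-free sketch of the same construction (hypothetical helper; random.sample plays the role of torch.randperm plus slicing):

```python
import random

def make_block_tables(num_seqs, max_kv_len, block_size, num_blocks, seed=0):
    """Build per-sequence block tables with distinct block ids."""
    # Ceiling division: blocks needed to cover the longest sequence.
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    needed = num_seqs * max_num_blocks_per_seq
    assert needed <= num_blocks, "not enough cache blocks"
    rng = random.Random(seed)
    # Distinct block ids in random order, like randperm then a slice.
    ids = rng.sample(range(num_blocks), needed)
    return [ids[i * max_num_blocks_per_seq:(i + 1) * max_num_blocks_per_seq]
            for i in range(num_seqs)]
```

In the benchmark itself the same shape would be produced directly on the GPU with torch.randperm(num_blocks, device="cuda") to avoid the host-to-device copy the comment above discusses.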

@@ -0,0 +1,578 @@
import math
Contributor

I did a quick review of op_tests/op_benchmarks/triton/bench_unified_attention.py and couldn't see anything seriously wrong with it. However, a thumbs up from @cagrikymk is mandatory IMHO (he doesn't need to check the entire PR, just bench_unified_attention.py).

Contributor

@brunomazzottiamd brunomazzottiamd left a comment

@lucas-santos-amd, I have some suggestions, but no blockers. It would be nice to have a thumbs up from Cagri regarding op_tests/op_benchmarks/triton/bench_unified_attention.py.

@cagrikymk
Contributor

There is a related PR: #2190 @juuso-oskari

It also contains some benchmarking related updates.

@brunomazzottiamd
Contributor

> There is a related PR: #2190 @juuso-oskari
>
> It also contains some benchmarking related updates.

What's the best course of action in your opinion? Wait for Juuso's PR to get merged and then integrate his benchmark into the uber-benchmark?

@Chi-Chu319
Contributor

> There is a related PR: #2190 @juuso-oskari
> It also contains some benchmarking related updates.
>
> What's the best course of action in your opinion? Wait for Juuso's PR to get merged and then integrate his benchmark into the uber-benchmark?

We would probably prefer this way if possible. We have some other features implemented on top of our branch (#2300) while waiting for #2190 to be merged

@lucas-santos-amd
Contributor Author

> There is a related PR: #2190 @juuso-oskari
> It also contains some benchmarking related updates.
>
> What's the best course of action in your opinion? Wait for Juuso's PR to get merged and then integrate his benchmark into the uber-benchmark?
>
> We would probably prefer this way if possible. We have some other features implemented on top of our branch (#2300) while waiting for #2190 to be merged

We will do it like this then


Labels

ci:all enhancement New feature or request triton

5 participants