[TRITON] Add Attention support to the bench_models benchmarking script #2274

lucas-santos-amd wants to merge 2 commits into main from
Conversation
@cagrikymk Please take a look at the unified attention benchmark when you can. I converted it to pure Python and added BW/TFLOPS calculations for the reduction kernel.
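For context, a rough sketch of how a TFLOPS figure for an attention forward pass is typically derived (this is an illustration, not the PR's actual code — function and parameter names here are assumptions): the two batched matmuls, QKᵀ and softmax(QKᵀ)·V, each cost 2·M·N·K FLOPs per head per batch.

```python
# Hypothetical sketch of an attention-forward TFLOPS calculation.
# Names (batch, heads, seq_q, seq_k, head_dim) are illustrative, not from the PR.

def attn_fwd_flops(batch: int, heads: int, seq_q: int, seq_k: int, head_dim: int) -> int:
    """Two matmuls (Q @ K^T and P @ V), each 2*M*N*K FLOPs per head per batch."""
    qk = 2 * seq_q * seq_k * head_dim   # Q @ K^T
    pv = 2 * seq_q * seq_k * head_dim   # softmax(QK^T) @ V
    return batch * heads * (qk + pv)

def tflops(flops: int, time_ms: float) -> float:
    """Convert a FLOP count and elapsed milliseconds into TFLOP/s."""
    return flops / (time_ms * 1e-3) / 1e12

flops = attn_fwd_flops(batch=1, heads=8, seq_q=1024, seq_k=1024, head_dim=64)
print(flops)                     # 2147483648 (= 4 * 1 * 8 * 1024 * 1024 * 64)
print(tflops(flops, time_ms=1.0))
```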
```python
def build_args(self) -> str:
    shape = self._shape
    return (
        f"-mode fwd -causal true --layout bshd --dtype bf16 -b {self._batch_size} "
```
Please double check that bshd is the proper layout for this benchmark. You can also choose the thd layout.
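For readers unfamiliar with the two layouts mentioned above, here is a hedged sketch (the PR does not spell this out, and the names below are illustrative): bshd is the dense (batch, seqlen, heads, head_dim) layout with padding, while thd packs variable-length sequences along a single token axis as (total_tokens, heads, head_dim), accompanied by a cumulative-sequence-length offset array.

```python
# Illustrative shapes only; these names are assumptions, not identifiers from the PR.
batch, seqlen, heads, head_dim = 2, 4, 8, 64

# bshd: one dense tensor, every sequence padded to the same length.
bshd_shape = (batch, seqlen, heads, head_dim)

# thd: ragged sequences packed along the token axis, plus cu_seqlens offsets
# so kernels can find each sequence's start and end.
seqlens = [3, 4]
total_tokens = sum(seqlens)
thd_shape = (total_tokens, heads, head_dim)

cu_seqlens = [0]
for s in seqlens:
    cu_seqlens.append(cu_seqlens[-1] + s)

print(bshd_shape)   # (2, 4, 8, 64)
print(thd_shape)    # (7, 8, 64)
print(cu_seqlens)   # [0, 3, 7]
```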
```python
    return RmsnormKernelHandler()
if kernel == "rope":
    return RopeKernelHandler()
if kernel == "mha":
```
Suggestion:

I think the _get_handler function is becoming hard to maintain. I don't know if we can do anything better... Something like this comes to mind:

```python
_HANDLER_RULES: list[tuple[Callable[[str], bool], type[KernelHandler]]] = [
    (lambda k: "moe" in k, MoeKernelHandler),
    (lambda k: "gemm" in k, GemmKernelHandler),
    (lambda k: k == "rmsnorm", RmsnormKernelHandler),
    (lambda k: k == "rope", RopeKernelHandler),
    (lambda k: k == "mha", MhaKernelHandler),
    (lambda k: k == "mla", MlaKernelHandler),
    (lambda k: k == "unified_attention", UnifiedAttnKernelHandler),
]

def _get_handler(kernel: str) -> KernelHandler:
    for predicate, handler_cls in _HANDLER_RULES:
        if predicate(kernel):
            return handler_cls()
    raise ValueError(f"Kernel {kernel} not supported")
```

Maybe it's more complex than the if statements... Please use your own judgement.
```python
for shape in shapes:
    s = shape.copy()
    s["hq"] = max(shape["hq"] // self._tp, 1)
    s["hkv"] = max(shape["hkv"] // self._tp, 1)
```
Suggestion:

This computation is repeated over and over:

```python
s[key] = max(s[key] // self._tp, 1)
```

in places like:

```python
return [{**s, "Dim2": max(s["Dim2"] // self._tp, 1)} for s in shapes]
```

```python
s["num_heads"] = max(s["num_heads"] // self._tp, 1)
s["num_kv_heads"] = max(s["num_kv_heads"] // self._tp, 1)
```

```python
s["hq"] = max(shape["hq"] // self._tp, 1)
s["hkv"] = max(shape["hkv"] // self._tp, 1)
```

(the hq/hkv pair appears three times)

I'd add some utility methods to the KernelHandler class to avoid the repetition:

```python
def shard(self, value: int) -> int:
    return max(value // self._tp, 1)

def shard_keys(self, s: dict, keys: list[str]) -> None:
    for key in keys:
        s[key] = self.shard(s[key])
```

Usage examples:

```python
return [{**s, "Dim2": self.shard(s["Dim2"])} for s in shapes]
```

```python
self.shard_keys(s, ["num_heads", "num_kv_heads"])
```

```python
self.shard_keys(s, ["hq", "hkv"])
```

```python
        dtype=torch.int32,
        device="cuda",
    )
    values = torch.arange(0, num_blocks, dtype=torch.int32)
```
Suggestion:

Please try to avoid moving data from host to GPU with `.to("cuda")`. I'd try to generate the values tensor directly on the GPU:

```diff
- values = torch.arange(0, num_blocks, dtype=torch.int32)
- values = values[torch.randperm(num_blocks)]
+ values = torch.arange(0, num_blocks, dtype=torch.int32, device="cuda")
+ values = values[torch.randperm(num_blocks, device="cuda")]
  block_tables = (
      values[: num_seqs * max_num_blocks_per_seq]
      .view(num_seqs, max_num_blocks_per_seq)
      .contiguous()
-     .to("cuda")
  )
```

```python
kv_lens = torch.tensor(kv_lens, dtype=torch.int32, device="cuda")

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(
```
I think this code can be removed:

```python
block_tables = torch.randint(
    0,
    num_blocks,
    (num_seqs, max_num_blocks_per_seq),
    dtype=torch.int32,
    device="cuda",
)
```

block_tables is immediately overwritten, and the data generated by this statement is never used as far as I can see.
```diff
@@ -0,0 +1,578 @@
+import math
```
I did a quick review of op_tests/op_benchmarks/triton/bench_unified_attention.py and couldn't see anything seriously wrong with it. However, a thumbs up from @cagrikymk is mandatory IMHO (he doesn't need to check the entire PR, just bench_unified_attention.py).
brunomazzottiamd left a comment
@lucas-santos-amd, I have some suggestions, but no blockers. It would be nice to have a thumbs up from Cagri regarding op_tests/op_benchmarks/triton/bench_unified_attention.py.
There is a related PR: #2190 @juuso-oskari. It also contains some benchmarking-related updates.

What's the best course of action in your opinion? Wait for Juuso's PR to get merged and then integrate his benchmark into the uber-benchmark?

We would probably prefer this way if possible. We have some other features implemented on top of our branch (#2300) while waiting for #2190 to be merged.

We will do it like this then.
Motivation
Technical Details
New help text of bench_models.py:

Description of output CSV file:

*MLA benchmark only reports Time (ms).
**RoPE reports only total floating-point operations, not throughput (TFLOPS).
Test Plan
Test Result
Submission Checklist