[triton] re-write gemm/grouped_gemm triton backend for gfx942. #235
[triton] re-write gemm/grouped_gemm triton backend for gfx942. #235
Conversation
There was a problem hiding this comment.
Pull request overview
This PR introduces Triton persistent kernel implementations for GEMM and grouped GEMM operations on AMD gfx942 hardware, supporting both BF16/FP16 and FP8 data types with multiple scaling granularities (tensorwise, rowwise, blockwise). The implementation aims to eliminate CPU-GPU synchronization overhead through persistent kernels while maintaining compatibility with existing CK and HIPBLASLT backends.
Changes:
- Added comprehensive Triton kernel implementations for GEMM/grouped GEMM operations supporting BF16/FP16 and FP8 dtypes
- Extended test coverage to include the new Triton backend across all test cases
- Updated dispatch logic to support backend selection (csrc/triton) via environment variables
- Enhanced benchmark workflow to measure Triton backend performance alongside existing backends
Reviewed changes
Copilot reviewed 15 out of 16 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| tests/pytorch/ops/test_grouped_gemm_fp8.py | Added BackendType.TRITON to test parametrization for FP8 grouped GEMM tests |
| tests/pytorch/ops/test_grouped_gemm.py | Added BackendType.TRITON to test parametrization for BF16 grouped GEMM tests |
| primus_turbo/triton/grouped_gemm/grouped_gemm_kernel.py | New persistent kernel implementations for BF16/FP16 grouped GEMM (forward and backward) |
| primus_turbo/triton/grouped_gemm/grouped_gemm_fp8_kernel.py | New persistent kernel implementations for FP8 grouped GEMM with tensorwise/rowwise/blockwise scaling |
| primus_turbo/triton/gemm/gemm_kernel.py | Rewrote BF16/FP16 GEMM kernel to use persistent design with StreamK grid computation |
| primus_turbo/triton/gemm/gemm_fp8_kernel.py | New persistent kernel implementations for FP8 GEMM with three scaling granularities |
| primus_turbo/pytorch/ops/grouped_gemm.py | Renamed gemm_impl to gemm_csrc_impl and updated import structure |
| primus_turbo/pytorch/ops/gemm_fp8.py | Added contiguity checks in backward passes to prevent kernel failures |
| primus_turbo/pytorch/ops/gemm.py | Added dispatch logic to route between csrc and Triton backends based on configuration |
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_impl.py | Added Triton backend registration for BF16/FP16 grouped GEMM operations |
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py | Added Triton backend registration for FP8 grouped GEMM operations |
| primus_turbo/pytorch/kernels/gemm/gemm_triton_impl.py | Updated triton_op implementation to call new gemm_triton_kernel directly |
| primus_turbo/pytorch/kernels/gemm/gemm_fp8_impl.py | Added Triton backend registration for FP8 GEMM operations |
| primus_turbo/pytorch/kernels/gemm/gemm_csrc_impl.py | Renamed gemm_impl function to gemm_csrc_impl for clarity |
| .github/workflows/benchmark.yaml | Extended benchmark workflow with Triton backend steps for all GEMM variants |
Comments suppressed due to low confidence (1)
primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_impl.py:331
- Empty line removed at end of function. While this doesn't affect functionality, it reduces consistency with surrounding code style which typically maintains blank lines between function definitions and module-level statements.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 15 out of 16 changed files in this pull request and generated 7 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 16 out of 17 changed files in this pull request and generated 7 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
MI300 bench result for this PR: Grouped GEMM BF16: |
e9264b8 to
e8243a6
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 17 out of 18 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
tests/pytorch/ops/test_gemm_fp8.py
Outdated
| @pytest.mark.parametrize("format", [Format.E4M3, Format.E5M2]) | ||
| @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||
| @pytest.mark.parametrize("backend", [None, BackendType.CK, BackendType.HIPBLASLT]) | ||
| @pytest.mark.parametrize("backend", [BackendType.TRITON, BackendType.CK, BackendType.HIPBLASLT]) |
There was a problem hiding this comment.
The removal of None from the backend parametrization means that tests no longer cover the default backend selection behavior. Consider whether None should be retained to test the default backend path, or if this is intentionally removed because default backend selection is tested elsewhere.
94c6b9f to
7986f8f
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 17 out of 18 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| pytest.skip( | ||
| "Triton persistent kernel uses BLOCK_K=64 / BLOCK_M=256 / BLOCK_N=256; " | ||
| "small dimensions cause illegal memory access in pytest environment" | ||
| ) |
There was a problem hiding this comment.
triton kernel will have illegal memory access if min(m, n, k) < 64. So I bypass these test cases.
No description provided.