[Triton] DS FP4/FP8 Triton fusion and GEMM optimization #119
base: main
Making the BMM use fp4 weights
…ABLE_RMSNORM_QUANT_FUSION
Enable DSR1 FP8 Optimizations
ChuanLi1101
left a comment
Overall LGTM, approved for benchmark testing.
| # For Triton FP8 Blockscale GEMM is mostly slower then AITER GEMM, we turn off Triton FP8 GEMM | ||
| # from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle as gemm_a8w8_blockscale_bpreshuffle_triton | ||
| except: | ||
| gemm_afp4wfp4_preshuffle = None |
Suggestion: Use specific exceptions and add logging:
if use_triton_gemm():
    try:
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
    except ImportError as e:
        logger.warning(f"Triton FP4 GEMM not available: {e}")
        gemm_afp4wfp4_preshuffle = None
| try: | ||
| from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat | ||
| from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat | ||
| except: |
if use_triton_gemm():
    try:
        from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat
        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat
    except ImportError as e:
        logger.debug(f"Triton fused GEMM split_cat not available: {e}")
        fused_gemm_afp4wfp4_preshuffle_split_cat = None
        fused_gemm_a8w8_blockscale_preshuffle_split_cat = None
| from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle | ||
| from aiter.ops.triton.gemm_a16w8_blockscale import gemm_a16w8_blockscale_preshuffle | ||
| except: | ||
| gemm_afp4wfp4_preshuffle = None |
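Add logging here as well? This needs `except ImportError as e:` in place of the bare `except:` so that `e` is bound:
    logger.warning(f"Triton GEMM kernels not available: {e}. Ensure AITER is up-to-date.")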
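The three import suggestions above share one pattern: attempt an optional import, log the failure, and fall back to `None`. A minimal sketch of that pattern as a reusable helper follows. The helper name `optional_import` is hypothetical and is not part of ATOM or AITER; the demo uses the standard-library `math` module so it runs without AITER installed.

```python
import importlib
import logging

logger = logging.getLogger(__name__)


def optional_import(module_name, *attr_names):
    """Return the named attributes of a module, or None for each on failure.

    Hypothetical helper for illustration only. It catches ImportError
    (module missing) and AttributeError (symbol missing in an older build)
    instead of using a bare `except:`, and logs the reason.
    """
    try:
        module = importlib.import_module(module_name)
        return tuple(getattr(module, name) for name in attr_names)
    except (ImportError, AttributeError) as e:
        logger.warning("Optional kernels from %s not available: %s", module_name, e)
        return tuple(None for _ in attr_names)


# Demo with a module that always exists and one that never does.
(sqrt,) = optional_import("math", "sqrt")
(missing_kernel,) = optional_import("no_such_module_for_demo", "some_kernel")

# In the PR's context the call might look like this (not executed here):
# (gemm_afp4wfp4_preshuffle,) = optional_import(
#     "aiter.ops.triton.gemm_afp4wfp4", "gemm_afp4wfp4_preshuffle"
# )
```

Callers then check `if gemm_afp4wfp4_preshuffle is not None:` before dispatching to the Triton path, which is the same contract the existing `= None` fallbacks already imply.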
| shuffle=(m >= 32), | ||
| ) | ||
| if m >= 32: |
Use a module constant instead of the literal 32?
In both files, import or define:
    from atom.models.deepseek_v2 import MXFP4_QUANT_BLOCK_SIZE
Then use:
    if m >= MXFP4_QUANT_BLOCK_SIZE:
        x_scale = x_scale.view(torch.uint8).view(x_scale.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1)
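To make the shape arithmetic in the suggested `view(x_scale.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1)` explicit, here is a small pure-Python sketch that computes the resulting shape for a contiguous 2-D scale tensor. It assumes the constant is 32 (the literal in the diff) and that reinterpreting the scales as `torch.uint8` leaves the shape unchanged (one byte per scale). The function name `regrouped_scale_shape` is hypothetical.

```python
MXFP4_QUANT_BLOCK_SIZE = 32  # assumed value, taken from the literal 32 in the diff


def regrouped_scale_shape(rows, cols, block=MXFP4_QUANT_BLOCK_SIZE):
    """Shape that view(rows // block, -1) yields for a contiguous (rows, cols) tensor.

    The -1 dimension is inferred as total_elements / leading_dim, so the
    call only succeeds when that division is exact.
    """
    lead = rows // block
    total = rows * cols
    if lead == 0:
        raise ValueError(f"rows={rows} is smaller than block={block}")
    if total % lead != 0:
        raise ValueError(f"{total} elements cannot be split into {lead} rows")
    return lead, total // lead


print(regrouped_scale_shape(64, 8))
```

The `m >= MXFP4_QUANT_BLOCK_SIZE` guard rules out the zero-row case; whether the row count is always an exact multiple of the block size depends on how the quantization kernel pads its scale output, which this sketch does not assume.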
| # return x.view(h, b, d // 2), x_scales.view(h, b, d // 32) | ||
| def mxfp4_to_f32(x, is_threed): |
This duplicates the helper of the same purpose in aiter.
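For readers unfamiliar with what a helper like `mxfp4_to_f32` has to do, the following is a minimal pure-Python sketch of MXFP4 decoding for a single block. It is not the PR's or AITER's implementation. It relies on the E2M1 element encoding (1 sign bit, 2 exponent bits, 1 mantissa bit) and the E8M0 shared scale (a power of two with bias 127) from the OCP Microscaling format, with 32 elements per scale as in the `d // 32` comment in the diff. The nibble order (low nibble first) is an assumption, and the E8M0 NaN encoding (0xFF) is not handled.

```python
# Magnitudes representable by the 3 low bits of an E2M1 (FP4) code.
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 32  # elements sharing one E8M0 scale


def decode_fp4(code):
    """Decode one 4-bit E2M1 code; bit 3 is the sign."""
    magnitude = E2M1_VALUES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def decode_e8m0(scale_byte):
    """Decode an E8M0 scale: an unsigned exponent with bias 127."""
    return 2.0 ** (scale_byte - 127)


def mxfp4_block_to_floats(packed, scale_byte):
    """Decode 16 packed bytes (two FP4 values each) into 32 floats.

    Assumption: the even-indexed element sits in the low nibble.
    """
    if len(packed) != BLOCK // 2:
        raise ValueError(f"expected {BLOCK // 2} bytes, got {len(packed)}")
    scale = decode_e8m0(scale_byte)
    out = []
    for byte in packed:
        out.append(decode_fp4(byte & 0xF) * scale)
        out.append(decode_fp4(byte >> 4) * scale)
    return out


# Each byte 0x21 holds code 1 (0.5) in the low nibble and code 2 (1.0) in the high one.
values = mxfp4_block_to_floats(bytes([0x21] * 16), scale_byte=128)
```

This also shows where the `d // 2` and `d // 32` shapes in the diff come from: two FP4 elements per byte, and one scale byte per 32 elements.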
This PR is co-authored by @k50112113, @omuhamma (#61) and @farlukas (#116)
This PR provides Triton fusion/GEMM optimizations for DS FP4 and FP8.
For now, please use the following AITER branch for testing, as some of the required PRs have not yet been merged into AITER main:
https://github.com/ROCm/aiter/tree/shaoclee/atom_triton_tmp_0106
The required AITER PRs include:
To activate the optimizations in ATOM, the following environment variables are required:
The following command, together with the environment variables above, was used to obtain the e2e performance results:
Client command:
DS FP8 performance comparisons and uplift

DS FP4 performance comparisons and uplift
