Workaround for matmul kernel crash with i8xf32 operands. #12
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
I signed the CLA.
This change, as mentioned in the title, would only work around the issue, not fix it. Effectively, it removes mixed-precision behavior for any matmuls with s8. Also, the current change would regress s8 x . The ideal way we would hope to handle the issue is to fix the limitations of Triton during its lowering to LLVM, and still allow proper mixed-precision mma to happen.
Could you elaborate on what you mean by "removes mixed-precision behavior for any matmuls with s8"? I ask because the lowered code contains a cast from i8 to f32 before feeding data to the tf32 mma op, which is necessary since the other operand is already f32. Could you also clarify what you mean by "the current change would regress s8 x "? Perhaps you could provide an example to illustrate this point? Thank you.
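(For concreteness, here is a minimal sketch of the operand combination being discussed. It is not taken from this PR: the kernel name, block shapes, and launch code are illustrative assumptions, and whether the Python frontend accepts an int8 x float32 tl.dot directly depends on the Triton revision, so treat it as an illustration of the i8 x f32 case rather than a guaranteed reproducer.)

```python
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        # a is int8 and b is float32: this is the i8 x f32 mixed-precision
        # combination the PR title refers to.
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        acc = tl.dot(a, b, acc)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


def reproduce(M=128, N=128, K=128):
    # Row-major inputs whose shapes are exact multiples of the block sizes,
    # so the kernel above can skip boundary masks.
    a = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")
    b = torch.randn((K, N), dtype=torch.float32, device="cuda")
    c = torch.empty((M, N), dtype=torch.float32, device="cuda")
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (M // BLOCK_M, N // BLOCK_N)
    matmul_kernel[grid](a, b, c, M, N, K,
                        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return c
```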
My apologies for replying late. Regarding s8 x : we can consider the example of s8 x f16. I haven't looked deeply into the performance impact, but it is clear that the change is not local; as you can see, it will impact other use cases. I'm not sure what the performance impact is (it would be nice if you could profile it). The constraints could be made tighter to only match on s8 x f32 combinations, but that would still be working around the issue.
The BlockedToMMA pass creates a layout with kWidth=4 when one operand is i8. However, the TritonGPU-to-LLVM lowering pass does not support lowering the other operand, f32, with kWidth=4, which causes a segmentation fault. To work around this, if the operands' minBitWidth is 8 and maxBitWidth is 32, we use a minBitWidth of 16 instead of 8, creating a layout with kWidth=2 for both the i8 and f32 operands.
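As an illustration of the clamp described above, here is a small sketch in plain Python rather than the actual C++ of the BlockedToMMA pass; the relation kWidth = 32 / minBitWidth is an assumption inferred from the numbers quoted in this description (an 8-bit operand giving kWidth=4, a 16-bit one giving kWidth=2).

```python
def k_width_with_workaround(a_bits: int, b_bits: int) -> int:
    """Illustrative only: derive kWidth from the operands' bit widths.

    Assumes kWidth = 32 // minBitWidth, consistent with the description
    above (8-bit operand -> kWidth 4, 16-bit operand -> kWidth 2).
    """
    min_bits = min(a_bits, b_bits)
    max_bits = max(a_bits, b_bits)
    # Workaround: when an 8-bit operand is paired with a 32-bit operand,
    # treat the narrow operand as 16-bit so both sides get kWidth = 2,
    # avoiding the unsupported f32-with-kWidth=4 lowering.
    if min_bits == 8 and max_bits == 32:
        min_bits = 16
    return 32 // min_bits
```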
Force-pushed from 3ff8088 to b313a8b.
Thank you for the details. I think I understand the issue with the proposed workaround. I have updated this PR with changes that should not affect other mixed-precision matrix multiplications; I verified that the i8xf16 kWidth remains 4 with this workaround. The issue stems from the LLVM lowering pass not supporting f32 with kWidth=4 when lowering for Ampere tensor cores. I am not familiar with Ampere tensor cores and cannot estimate the effort required to fix the issue in the lowering pass.
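To make that verification concrete, here is a quick check using the same illustrative helper as in the sketch above (still assuming kWidth = 32 / minBitWidth; the actual pass works on MLIR types, so this only mirrors the arithmetic):

```python
def k_width_with_workaround(a_bits: int, b_bits: int) -> int:
    min_bits, max_bits = min(a_bits, b_bits), max(a_bits, b_bits)
    if min_bits == 8 and max_bits == 32:  # only the i8 x f32 pairing is clamped
        min_bits = 16
    return 32 // min_bits


assert k_width_with_workaround(8, 16) == 4  # i8 x f16: kWidth stays 4
assert k_width_with_workaround(8, 32) == 2  # i8 x f32: clamped to kWidth 2
```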
Thank you for the modifications. There are currently discussions about whether we will proceed with a workaround or not. I will get back to you once there is a decision.
Unfortunately, a workaround is not something we can accept for this issue; we would need a proper fix here. We already have a different workaround internally, and the performance benefits we would gain from this do not outweigh the cost of maintaining a patch on top of upstream.
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these rules.
I have run pre-commit run --from-ref origin/main --to-ref HEAD.
Select one of the following: /python/test for end-to-end tests.
Select one of the following: lit tests.