
Requirements to pass WGMMA LHS operand in registers #4785

Open
chsigg opened this issue Sep 23, 2024 · 9 comments

@chsigg (Collaborator) commented Sep 23, 2024

NVIDIA is implementing an optimization to pass the LHS operand of WGMMA ops in registers. This allows element-wise prologues to pass their intermediate result directly to the WGMMA without first writing it to shared memory, as the current lowering does.
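
To make the pattern concrete, here is a minimal Triton sketch (the kernel name, shapes, and ReLU prologue are illustrative, not taken from NVIDIA's changes) of an element-wise prologue feeding `tl.dot`; the intermediate `a` is the operand that could be handed to the WGMMA in registers instead of being staged through shared memory:

```python
import triton
import triton.language as tl

@triton.jit
def prologue_dot_kernel(a_ptr, b_ptr, c_ptr,
                        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # Element-wise prologue on the LHS; its result feeds tl.dot directly.
    # Today this intermediate is written to shared memory before the WGMMA
    # consumes it; the optimization would keep it in registers.
    a = tl.maximum(a, 0.0)  # e.g. a fused ReLU
    acc = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```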

The OpenXLA team is currently reviewing NVIDIA's changes with the intent of eventually writing a PR against this repository. We heard through @ThomasRaoux and @gflegar that the Triton team is planning a similar feature, so it would be great to align on the requirements. @ThomasRaoux, how can we best achieve this? Are you far enough along in the planning phase that you could provide some feedback to @ggengnv? Or would you prefer that we do a round of reviews first?

@Jokeren (Contributor) commented Sep 23, 2024

> We heard through @ThomasRaoux and @gflegar that the Triton team is planning a similar feature

What feature are you referring to? I don't think anyone on our side is working on this.

@ThomasRaoux (Collaborator) commented:

@lezcano is currently working on a mixed-mode kernel that will require this support, but at this point there hasn't been much design work on the MMAv3-specific part yet, and there are a few steps before we get to that.
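
For reference, a hedged sketch of the mixed-mode case (the kernel name and scaling scheme are hypothetical): when the LHS is stored in a narrow integer type, the upcast itself is the element-wise prologue, since `tl.dot` needs both operands in a common floating-point type, and its result is exactly the register-resident operand this issue is about:

```python
import triton
import triton.language as tl

@triton.jit
def mixed_mode_dot(w_ptr, x_ptr, out_ptr, scale,
                   M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    w = tl.load(w_ptr + offs_m[:, None] * K + offs_k[None, :])  # e.g. int8 weights
    x = tl.load(x_ptr + offs_k[:, None] * N + offs_n[None, :])  # fp16 activations
    # Dequantize in registers: this upcast is the prologue whose output
    # would ideally flow into the WGMMA without a shared-memory round trip.
    w_f16 = w.to(tl.float16) * scale
    acc = tl.dot(w_f16, x)
    tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```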

Looking at the changes in the link, it seems some more work is needed to productize it, so maybe we should indeed join efforts. What kind of timeline did you have in mind for this work?

@ggengnv (Contributor) commented Sep 26, 2024

As an update, I've addressed the existing PR comments and split the original PR into two for ease of review.

I'm currently on leave until 10/9 and will be happy to address feedback and resume this effort once I return :)

@Moerafaat (Contributor) commented:

@ThomasRaoux @Jokeren we believe the first PR is ready to be looked at on your side. It currently lives in OpenXLA's repo here: https://github.com/openxla/triton/pull/18/files. It would be great if you could check it before we ask NVIDIA to open the PR against OAI Triton.
Part 2 is here: openxla#19 (it builds on top of PR 18, so those changes are included as well).
Let us know what you think.

@Jokeren (Contributor) commented Oct 14, 2024

We're currently refactoring a number of attributes related to the DotOperand layout.
See PRs such as #4891 (comment), #4895 (review), and #4856 (comment).
I believe that once those problems are fixed, we could have a cleaner and more general solution for mixed-precision problems.

@Jokeren (Contributor) commented Oct 14, 2024

We will also refer to your PRs in openxla/triton. If we're still missing something needed to solve your problems afterwards, please feel free to suggest changes. What do you think? @ThomasRaoux @Moerafaat

@ggengnv (Contributor) commented Oct 14, 2024

@Jokeren Thanks for providing context. For my reference, do you happen to have a tentative timeline for merging the PRs you linked above?

@Jokeren (Contributor) commented Oct 14, 2024

Probably by the end of next week.

@Moerafaat (Contributor) commented:

Adding a comment here to point to #5003, which is the first PR (it may also come to include the second PR once comments are addressed). Just echoing for visibility.
