Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LHS Registers Part 1 - DotOp Hoisting and SMEM-RF Copy Lowering #18

Open
wants to merge 12 commits into
base: llvm-head
Choose a base branch
from

Conversation

ggengnv
Copy link

@ggengnv ggengnv commented Sep 23, 2024

(Part 2: #19)
Part 1 of "WGMMA with LHS operand in registers" feature.

Hopper has two kinds of WGMMAs, "SS" (both operands in shmem) and "RS" (LHS operand A in registers).
In cases where we apply elementwise operations on A before WGMMA, Triton previously will copy A from global memory (GMEM) into registers (RF), perform the elementwise ops, and then copy to shared memory (SMEM) to perform SS WGMMA.

This PR adds an optimization for the case above to use RS GEMM. This requires the following changes:

  • In TritonGPU OptimizeDotOperands pass, add optimization to change SS GEMM into RS GEMM, where doing so is possible and beneficial
  • Add TritonGPU -> LLVM lowering for copying from SMEM to RF in MMA v3 dotOperand layout

Being without pipelining, this PR is not expected to see perf gains. Pipelining for MMAv3 operand in registers is added in Part 2.

@ggengnv
Copy link
Author

ggengnv commented Sep 23, 2024

@gflegar @chsigg @vwbaker @Moerafaat This is part 1 of the two PRs split from #17 as suggested by @vwbaker. Please assign reviewers/review, thanks
And for part 2 as well, please: #19

@ggengnv
Copy link
Author

ggengnv commented Sep 23, 2024

Addressed all comments in the original PR that are relevant to part 1 in this PR instead.

@@ -11,6 +11,8 @@
import pytest
import torch
import os
os.environ['TRITON_ALWAYS_COMPILE'] = '1'
os.environ['MLIR_ENABLE_DUMP'] = '1'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like these were leftover from debugging

@Moerafaat
Copy link
Member

Ops, seems that updates we have to maintain the Triton integration will cause PR diffs to break because of force-updates. We might need to figure out a better way to handle this as we didn't intend for this repo to accept incoming PRs. Apologies for this, but you will need to rebase again for the diff to include the proper changes.

@ggengnv ggengnv force-pushed the lhs-reg-hoist branch 3 times, most recently from 1b95c9a to 942dad4 Compare September 25, 2024 22:08
@ggengnv
Copy link
Author

ggengnv commented Sep 25, 2024

@Moerafaat np - I've reapplied my changes on the new main

Copy link

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nb. I have not reviewed OptimizeDotOperands.cpp, @ThomasRaoux should probably review that part once you put up a PR against the triton repo.

First of all, hats off if you've managed to fix SharedToDotOperand for kWidth != 2! It was an absolute madness of indices.

Now, @Jokeren is working on completely removing this devilish path in favour of the cleaner and more correct conversion via linear layouts. I think a more long-term solution here would be to implement linear layout support for DotOperand for Hopper, but we can do this in a follow-up PR.

Comment on lines +557 to +558
// To unify the ordering conventions would potentially require touching
// `ConvertLayoutOpToLLVM.cpp`, `ElementwiseOpToLLVM.cpp`, `MMAv2.cpp`,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be fixed (at last!) in triton-lang#4951. Hopefully it'll get merged today, tomorrow at latest.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just landed: triton-lang#4979

Comment on lines +468 to +472
if (isHopperWidthChange) {
vecWidth = 4 / mmaElemBytes;
} else {
vecWidth = 4 / elemBytes;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean vecWidth = kWidth here? This way, it would work for kWidth > 4.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just left this comment below (#18 (comment)) which I think also applies to this question.

In what case would we have kWidth > 4 though? For Ampere and Hopper we should have 4 as the maximum right (with int8/fp8 dtype)

// width-changing casting is done later in DotOp Layout... then, in the case of
// Hopper, the number of bytes held by each thread after loading will no longer
// be 32B. Hence this flag is required to stipulate different logic.
bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also applies to Ampere, right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, and it makes me wonder now why this was working for Ampere before without this flag.
I'll inspect this code a bit more and come up with an explanation.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The answer is that it just didn't work before. That's one of the reasons why we started migrating everything to Linear Layouts, and it's that code written manually has very subtle bugs like the ones you fixed in this PR :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update - I was testing int8 -> f16 for Ampere and was surprised at first to find that it was working correctly.
I've now come up with an explanation. You probably already know most of the following, but I'd just like to confirm we're on the same page:

If I have AxB where A is int8 and B is f16, with A cast to f16 before MMA, AccelerateMatmul.cpp will compute kWidth for both operands as 4. This is due to the logic in computeOrigBitWidth, which calculates the kWidth using the bitwidth of the smallest element along the dot-op chain. In my case this smallest element is int8, so kWidth = 32 / 8 = 4.

But the eventual layout right before MMA is f16, so the kWidth should be 2. So, it seems at first glance that with kWidth = 4, the results should be wrong. For example, for operand A, f16 layout expects thread 0 to hold elements

(0, {0, 1}), (8, {0, 1}), (0, {8, 9}), (8, {8, 9})  # (m_index, {k_indices...})

...but in reality, we load it with kWidth = 4, meaning

(0, {0, 1, 2, 3}), (8, {0, 1, 2, 3}), (0, {8, 9, 10, 11}), (8, {8, 9, 10, 11})

(I've attached Ampere dotOp layouts for int8 and f16 at the end)

So each "rep" of A here doesn't actually mean a "rep" of the actual f16 MMA instruction, but instead corresponds to 2 reps.

OTOH, B is also loaded with kWidth = 4, but with vecWidth = 2, so the first rep of B is loaded in this order:
(n_offset=0, k_offset={0, 1}), (n_offset=0, k_offset={8, 9})
and the second rep...
(n_offset=0, k_offset={2, 3}), (n_offset=0, k_offset={10, 11})

To match the ordering of elements in A and B, the lowering of int8 -> f16 will reorder the values of A, so that every 16 values of A can be split into 2x 8 values, i.e. 2 reps:

(m_offset=0, k_offset={0, 1}), (8, {0, 1}), (0, {8, 9}), (8, {8, 9})
(m_offset=0, k_offset={2, 3}), (8, {2, 3}), (0, {10, 11}), (8, {10, 11})

I think the key observation here was that the element ordering along K doesn't have to match what the PTX doc prescribes, as long as it's consistent between A and B.

To conclude, the logic here works for Ampere thanks to reorderValues (which has logic for 8b <> 16b and 16b <> 32b). OTOH, for WGMMA, I'd have to previously set kWidth = 2 and have special logic here to load the correct number of values; this is because operand B is always in shmem and doesn't have the kWidth/vecWidth logic above for Ampere's B.

My thought is that the reordering trick might allow for more vectorization and so might actually be worth implementing for Hopper. For this PR though, things look to be functionally correct for both Hopper and Ampere, so should I just leave the logic here as-is for now? (or unless there are other things, which I'm not aware of, that didn't work for Ampere that my Hopper fix here could extend to?)


Operand A layout for mma.m16n8k16 with int8:
image

Operand A layout for mma.m16n8k16 with f16/bf16:
image

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a pretty neat analysis!

The issue here is that I pretty much hit these bugs when doing fp4 x bf16, so, following the logic for fp8 x bf16, I chose kWidth=8, but then everything breaks as these computations break with kWidth > 4. Using kWidth=1 and kWidth=4 I was hitting a similar bug as well.

In general, the issue with most of these computations is that they are done in terms of element sizes (which is a tensor-level concept) instead of kWidth, which is a layout-level concept. For example, the same vectorisation optimisations could be done by looking at the kWidth themselves. cc @Jokeren who is trying to clean-up all this.

That being said, we have what we have, so our current attack route to clean all this mess is using LinearLayouts. The idea moving forward is to implement LinearLayout conversions for all the layouts we have. LinearLayouts are easy to prove to be correct, and all these optimisations can be implemented rather cleanly at a layout level.

With all this I want to say that it's probably fine for this code not to support Hopper with kWidth > 4, as we are aiming to delete it anyway in favour of our LinearLayout path.

// matK (k-index of the "warp matrix")
// quadK (k-index of the "quad" in the core matrix)
// quadM (m-index of the "quad" in the core matrix)
// vecIdx (index of the element in the quad; this is always along the k-dim)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you need to support kWidth in this file? Or are you expecting that kWidth = 4 / elemSize at this stage? If so, could you add an assert?

Note that this needn't be the case, and we could have a larger kWidth, and have this code emit kWidth / (4 / elemSize) wgmma ops.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, I guess that it's just the lhs that supports funny kWidths, so at this stage the kWidth should agree. It would be good to assert that though.

Comment on lines +671 to +672
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth;
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably write 2 * kWidth rather than 64 / mmaBitwidth for both Hopper and Ampere, as this is in both cases the number of elements along K.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point I believe kWidth may not equal 32 / mmaBitwidth for Ampere.

In the example in my comment above, operand B will have dtype = f16 but kWidth = 4. I think that in this case, matShapeK should be calculated based on the dtype bitwidth, so that the element ordering is correct.

int bitwidth,
int opIdx) const {
assert(isAmpere() || isHopper());
auto rank = shape.size();
auto warpsPerCTA = getWarpsPerCTA();
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the n value here in Hopper?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, after having a thought about this, I guess that in all these places that depend on instrN, it's fine to leave n = 8, as the n pattern repeats with period 8 for the different tile sizes.

It would be good to leave a note somewhere in the code and refer to that note in all the places where we use this "trick" (here, in the SharedEncodingAttr, in the shared to dot operand mma, etc).

Copy link
Author

@ggengnv ggengnv Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Hopper actually opIdx should always be 0 since WGMMA doesn't support operand B in registers, and so we shouldn't ever use the n value.

I can add an assert here for opIdx == 0

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right! Adding an assert would be great.

@lezcano
Copy link

lezcano commented Oct 24, 2024

Also, could you add some tests and open a PR against the triton repo?
Edit. I just saw you do have an mlir test

@lezcano
Copy link

lezcano commented Oct 28, 2024

@ggengnv do you know roughly when will you be able to rebase this on top of main and address the review? I think we would like to incorporate this work upstream sooner than later.

If you are busy, I can help with the rebasing and addressing the review.

@ggengnv
Copy link
Author

ggengnv commented Oct 28, 2024

@lezcano Hey, sorry for taking a bit on this. I was busy last week but I'm rebasing this now. Aiming to get this done soon today; will keep you updated.

@ggengnv
Copy link
Author

ggengnv commented Oct 28, 2024

PR transferred to triton-lang: triton-lang#5003

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants