NVFP4 Random Hadamard Transform (butterfly permutation-based)#509
NVFP4 Random Hadamard Transform (butterfly permutation-based)#509matthiasdiener merged 49 commits intodevfrom
Conversation
Remove TODO regarding userbuffers
Userbuffer Enablement for ROCm
| namespace transformer_engine { | ||
| namespace { | ||
|
|
||
| constexpr int kThreadsPerWarp = 32; |
There was a problem hiding this comment.
It also seems unused on ROCm now so whole namespace could be guarded
| static constexpr int kHadamardDim = 16; | ||
| static constexpr int kWarpSize = 64; | ||
| static constexpr int kThreadsPerWHT = 4; | ||
| static constexpr int kElemsPerThread = 4; | ||
| static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 | ||
| static constexpr int kWarpsPerBlock = 4; | ||
| static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 | ||
| static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 | ||
| static constexpr float kHadamardScale = 0.25f; |
There was a problem hiding this comment.
These do not seem like arbitrary tuning knobs. Some comment describing the layout scheme here could be helpful
There was a problem hiding this comment.
I think I can help answer partial of your questions.
kHadamardDim is the dimension of the hadamard transform matrix. In this specific case, the hadamard transform matrix is of size 16x16.
And kHadamardScale is 1/sqrt(hadamard matrix dim)
For the tiling constants:
kWarpSize is the number of threads per warp (or wavefront in our amd platform)
kThreadsPerWHT is how many threads are needed for one 16-point hadamard transform. Here it's set to be 4, which means that each thread will manage 4 inputs
kRowsPerWarp is defined as kWarpSize/kThreadsPerWHT probably because Matthias assign one warp (64 threads) to deal with a 2D data block of size 16x16 at the same time. So one row of 16 input data can be handle by 4 threads
But regarding why those tiling parameters are chosen like this, I'm not quite sure either
There was a problem hiding this comment.
Thanks @wangye805 for answering.
I've added a comment regarding these comments in cf2c8f6. These values aren't tuning knobs, they're determined by the problem structure. kThreadsPerWHT=4 follows from the Kronecker decomposition H16 = H4 x H4, where we reshape the 16-element vector into a 4×4 matrix with one column per thread. This means each thread hold 4 values and the cross-thread butterfly stages use ds_swizzle for the H4.
Given that and a 64-wide wavefront, the rest follows: kRowsPerWarp = 64/4 = 16 rows per wavefront, and kWarpsPerBlock = 4 gives 64 rows per block
| block_lam=fmaxf(block_lam,__shfl_xor(block_lam,off)); | ||
|
|
||
| if (lane_id == 0) | ||
| atomicMaxFloat(amax_out, block_lam); |
There was a problem hiding this comment.
This seems correct, but from a performance perspective, did you consider a hierarchical/two-pass reduction instead of atomically combining block-local amax values into global memory? Since it is only one atomic per block, I can see the simplicity argument, but I was curious about the tradeoff.
There was a problem hiding this comment.
Like you said, for this kernel, the atomic contention should be relatively small. The two-stage reduction requires a workspace allocation plus another kernel launch. We can revisit if profiling shows this as a bottleneck.
bafafea to
2772834
Compare
Description
Implements RHT via a butterfly permutation-based algorithm for NVFP4.
Has similar restrictions as upstream:
Fixes https://github.com/ROCm/frameworks-internal/issues/15732
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: