Skip to content

Commit 9b3772d

Browse files
ANIKET-SHIVAMhwu36
andauthored
Hopper Grouped GEMM support for FP8 Accum (#2123)
* Add support for fp8accum, with profiler extension * Update .gitignore * contri --------- Co-authored-by: Haicheng Wu <[email protected]>
1 parent b84e980 commit 9b3772d

9 files changed

+914
-71
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# PyCache files
22
__pycache__/
3-
cutlass_library.egg-info/
3+
cutlass_library.egg-info/
4+
/build*

CONTRIBUTORS.md

+65-8
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS")
22

3-
[README](./README.md#documentation) > **Active Developers**
3+
[README](./README.md#documentation) > **Contributors**
44

55
# CUTLASS Developers **
66

7-
Andrew Kerr (CUTLASS founding member)<br />
7+
Andrew Kerr<br />
8+
Paul Springer<br />
89
Dustyn Blasig<br />
910
Albert Xu<br />
1011
Junkai Wu<br />
1112
Xiuxia Zhang<br />
12-
Haicheng Wu (CUTLASS founding member)<br />
13+
Haicheng Wu<br />
1314
Jack Yang<br />
14-
Pradeep Ramani (CUTLASS 3.x founding member)<br />
15+
Pradeep Ramani<br />
1516
Aditya Atluri<br />
1617
Han Li<br />
1718
Nick Zhao<br />
@@ -20,15 +21,15 @@ Yu-Jung Chen<br />
2021
Markus Hoehnerbach<br />
2122
Honghao Lu<br />
2223
Mihir Awatramani<br />
23-
Hao Sheng<br />
24+
Hao Sheng<br />
2425
Zekun Fan<br />
25-
Aniket Shivam<br />
26+
Aniket Shivam<br />
2627
Siyu Liu<br />
2728
Richard Cai<br />
2829
Vikas Gupta<br />
2930
Ethan Yan<br />
30-
Vijay Thakkar (CUTLASS 3.x and CuTe founding member)<br />
31-
Cris Cecka (CuTe and CUTLASS 3.x founding member)<br />
31+
Vijay Thakkar<br />
32+
Cris Cecka<br />
3233
Lawrence Ryan<br />
3334
Qun Song<br />
3435
Daniel Ricketts<br />
@@ -69,5 +70,61 @@ Shreya Gaur<br />
6970

7071
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
7172

73+
74+
# CUTE Developers
75+
76+
Cris Cecka<br />
77+
Vijay Thakkar<br />
78+
79+
7280
# CUTLASS Product Manager
81+
7382
Matthew Nicely<br />
83+
84+
85+
# Former CUTLASS Developers
86+
87+
Manish Gupta<br />
88+
Duane Merrill<br />
89+
Piotr Majcher<br />
90+
Naila Farooqui<br />
91+
Mark Hoemmen<br />
92+
Rawn Henry<br />
93+
Jin Wang<br />
94+
Timmy Liu<br />
95+
Manikandan Ananth<br />
96+
David Tanner<br />
97+
98+
99+
# Acknowledgements
100+
101+
Tri Dao<br />
102+
Jay Shah<br />
103+
Timothy Costa<br />
104+
Julien Demouth<br />
105+
Brian Fahs<br />
106+
Michael Garland<br />
107+
Michael Goldfarb<br />
108+
Mostafa Hagog<br />
109+
Fei Hu<br />
110+
Alan Kaatz<br />
111+
Tina Li<br />
112+
Wei Liu<br />
113+
Tim Martin<br />
114+
Kevin Siu<br />
115+
Markus Tavenrath<br />
116+
John Tran<br />
117+
Vicki Wang<br />
118+
Fung Xie<br />
119+
Yang Xu<br />
120+
Scott Yokim<br />
121+
Girish Bharambe<br />
122+
Luke Durant<br />
123+
Carter Edwards<br />
124+
Olivier Giroux<br />
125+
Stephen Jones<br />
126+
Rishkul Kulkarni<br />
127+
Bryce Lelbach<br />
128+
Joel McCormack<br />
129+
Kyrylo Perelygin<br />
130+
Sean Treichler<br />

include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl

+8-5
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,6 @@ struct CollectiveBuilder<
234234
KernelPtrArrayTmaWarpSpecializedCooperative,
235235
KernelPtrArrayTmaWarpSpecializedPingpong>);
236236
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
237-
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
238-
"KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now.");
239237

240238
// For fp32 types, map to tf32 MMA value type
241239
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@@ -267,12 +265,17 @@ struct CollectiveBuilder<
267265

268266
static constexpr int PipelineStages = detail::compute_stage_count_or_override<Sm90ReducedSmemCapacityBytes,
269267
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
268+
/* For FP8 use a separate mainloop compared to other datatypes */
270269
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
271-
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
272-
/* For FP8 use a separate mainloop compared to other datatypes */
270+
cute::conditional_t<IsFP8Input,
271+
MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
272+
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
273+
>,
273274
cute::conditional_t<IsFP8Input,
274275
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
275-
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
276+
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
277+
>
278+
>;
276279

277280
using SmemCopyAtomA = void;
278281
using SmemCopyAtomB = void;

include/cutlass/gemm/collective/collective_mma.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@
4646
#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp"
4747
#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
4848
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
49-
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp"
49+
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp"
5050
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
51+
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp"
5152
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
5253

5354
#if !defined(__CUDACC_RTC__)

0 commit comments

Comments
 (0)