Skip to content

Commit eb3529e

Browse files
committed
All major combination functionality
1 parent 74dc9d3 commit eb3529e

File tree

5 files changed

+171
-25
lines changed

5 files changed

+171
-25
lines changed

examples/sycl/pvc/pvc_gemm_fp8.cpp

+124-21
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@
105105
};
106106

107107
///////////////////////////////////////////////////////////////////////////////////////////////////
108-
108+
#define A_ROW
109+
#define B_COL
109110
template <
110111
class Gemm
111112
>
@@ -159,19 +160,33 @@
159160
//
160161

161162
bool verify(const Options &options) {
162-
163-
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
164-
using GmemTiledCopyB = XE_2D_U16x32x32_LD_N;
165-
166-
// Workgroup-level tile
167-
using TileShape = Shape<_256, _256, _32>;
168-
169-
using TiledMma =
170-
TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
171-
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
172-
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
173-
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;
174-
163+
#if defined(A_ROW) && defined(B_COL)
164+
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
165+
using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
166+
#endif
167+
#if defined(A_ROW) && defined(B_ROW)
168+
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
169+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_N;
170+
#endif
171+
172+
#if defined(A_COL) && defined(B_ROW)
173+
using GmemTiledCopyA = XE_2D_U16x16x16_LD_T;
174+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
175+
#endif
176+
177+
#if defined(A_COL) && defined(B_COL)
178+
using GmemTiledCopyA = XE_2D_U16x16x16_LD_T;
179+
using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
180+
#endif
181+
// Workgroup-level tile
182+
using TileShape = Shape<_256, _256, _32>;
183+
184+
using MMAAtom = MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>;
185+
using TiledMma = TiledMMA<MMAAtom,
186+
Layout<Shape<_8,_4,_1>, Stride<_4,_1,_0>>,
187+
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
188+
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>,
189+
_32>>;
175190
constexpr int PipelineStages = 3;
176191
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
177192
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
@@ -317,29 +332,47 @@
317332

318333
};
319334

320-
321335
struct TransformA {
322336
template <class RTensor, class Trait, class TransTensor>
323337
CUTE_HOST_DEVICE auto operator()(RTensor const& in, Trait trait, TransTensor& out) {
338+
#if defined(A_ROW)
324339
// auto mma_A = make_fragment_like<typename TiledMma::ValTypeA>(in);
325340
Layout A_selector = make_layout(make_shape(_8{}, _4{}, _2{}), make_stride(_2{},_16{},_1{}));
341+
// Layout A_selector = make_layout(make_shape(_8{}, _1{}, _2{}), make_stride(_2{},_16{}, _1{}));
342+
// Layout A_selector = make_layout(make_shape(_8{}, _2{}, _2{}), make_stride(_2{}, _16{}, _1{}));
326343
CUTLASS_PRAGMA_UNROLL
327344
for(int i = 0; i < size<1>(out); i++) {
328345
CUTLASS_PRAGMA_UNROLL
329346
for(int j =0; j < size<2>(out); j++) {
330347
CUTLASS_PRAGMA_UNROLL
331348
for(int v = 0; v < size<0>(out); v++) {
332349
out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeA*/>(in.data()[A_selector(v, i, j)]);
333-
// out(v, i, j) = 1;
350+
// out(v, i, j) = (bfloat16_t)(1.0f);
334351
}
335352
}
336353
}
354+
#endif
355+
#if defined(A_COL)
356+
Layout A_selector = make_layout(make_shape(_8{},_4{},_2{}), make_stride(_1{},_8{},_32{}));
357+
CUTLASS_PRAGMA_UNROLL
358+
for(int i = 0; i < size<1>(out); i++) {
359+
CUTLASS_PRAGMA_UNROLL
360+
for(int j =0; j < size<2>(out); j++) {
361+
CUTLASS_PRAGMA_UNROLL
362+
for(int v = 0; v < size<0>(out); v++) {
363+
out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeA*/>(in.data()[A_selector(v, i, j)]);
364+
// out(v, i, j) = (bfloat16_t)(1.0f);
365+
}
366+
}
367+
}
368+
#endif
337369
}
338370
};
339371

340372
struct TransformB {
341373
template <class RTensor, class Trait, class TransTensor>
342374
CUTE_HOST_DEVICE auto operator()(RTensor const& in, Trait trait, TransTensor& out) {
375+
#if defined(B_ROW) && defined(A_ROW)
343376
// auto mma_B = make_fragment_like<typename TiledMma::ValTypeB>(in);
344377
Layout B_selector = make_layout(make_shape(_16{}, make_shape(_2{}, _2{}), _2{}), make_stride(_4{}, make_stride(_1{}, _64{}) ,_2{}));
345378
CUTLASS_PRAGMA_UNROLL
@@ -349,10 +382,53 @@
349382
CUTLASS_PRAGMA_UNROLL
350383
for(int v = 0; v < size<0>(out); v++) {
351384
out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]);
352-
// out(v, i, j) = 1;
385+
// out(v, i, j) = (bfloat16_t)(1.0f);
386+
}
387+
}
388+
}
389+
#endif
390+
#if defined(B_ROW) && defined(A_COL)
391+
Layout B_selector = make_layout(make_shape(_16{}, make_shape(_2{}, _2{}), _2{}), make_stride(_2{}, make_stride(_1{}, _64{}) ,_32{}));
392+
CUTLASS_PRAGMA_UNROLL
393+
for(int i = 0; i < size<1>(out); i++) {
394+
CUTLASS_PRAGMA_UNROLL
395+
for(int j =0; j < size<2>(out); j++) {
396+
CUTLASS_PRAGMA_UNROLL
397+
for(int v = 0; v < size<0>(out); v++) {
398+
out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]);
399+
// out(v, i, j) = (bfloat16_t)(1.0f);
400+
}
401+
}
402+
}
403+
#endif
404+
#if defined(B_COL) && defined(A_COL)
405+
Layout B_selector = make_layout(make_shape(_16{}, _4{},_2{}), make_stride(_1{}, _32{},_16{}));
406+
CUTLASS_PRAGMA_UNROLL
407+
for(int i = 0; i < size<1>(out); i++) {
408+
CUTLASS_PRAGMA_UNROLL
409+
for(int j =0; j < size<2>(out); j++) {
410+
CUTLASS_PRAGMA_UNROLL
411+
for(int v = 0; v < size<0>(out); v++) {
412+
out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]);
413+
// out(v, i, j) = (bfloat16_t)(1.0f);
353414
}
354415
}
355416
}
417+
#endif
418+
#if defined(B_COL) && defined(A_ROW)
419+
Layout B_selector = make_layout(make_shape(_16{}, _4{}, _2{}), make_stride(_2{}, _32{},_1{}));
420+
CUTLASS_PRAGMA_UNROLL
421+
for(int i = 0; i < size<1>(out); i++) {
422+
CUTLASS_PRAGMA_UNROLL
423+
for(int j =0; j < size<2>(out); j++) {
424+
CUTLASS_PRAGMA_UNROLL
425+
for(int v = 0; v < size<0>(out); v++) {
426+
out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]);
427+
// out(v, i, j) = (bfloat16_t)(1.0f);
428+
}
429+
}
430+
}
431+
#endif
356432
}
357433
};
358434

@@ -398,16 +474,39 @@
398474
using ElementInputB = cutlass::float_e4m3_t; // <- data type of elements in input matrix B
399475
using ElementOutput = float; // <- data type of elements in output matrix D
400476

401-
using LayoutA = cutlass::layout::RowMajor;
402-
using LayoutB = cutlass::layout::RowMajor;
403477
using LayoutC = cutlass::layout::RowMajor;
404478
using LayoutD = cutlass::layout::RowMajor;
405479

406480
// Note: XE_2D_U8x32x32_LD_V is incompatible with our bf16 MMA atoms
481+
// 2.8tflops U8x32x32NLD_N
482+
// 1.4tflops U8x16x32NLD_N
483+
// 0.7tflops U8x 8x32NLD_N
484+
#if defined(A_COL) && defined(B_ROW)
485+
using LayoutA = cutlass::layout::ColumnMajor;
486+
using LayoutB = cutlass::layout::RowMajor;
487+
using GmemTiledCopyA = XE_2D_U8x16x32_LD_T;
488+
using GmemTiledCopyB = XE_2D_U8x32x32_LD_N;
489+
#endif
490+
#if defined(A_ROW) && defined(B_ROW)
491+
using LayoutA = cutlass::layout::RowMajor;
492+
using LayoutB = cutlass::layout::RowMajor;
407493
using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
408494
using GmemTiledCopyB = XE_2D_U8x32x32_LD_N;
409-
static_assert(sizeof(ElementInputA) == 1, "ElementA width must match GmemTiledCopyA U8");
410-
495+
#endif
496+
497+
#if defined(A_COL) && defined(B_COL)
  // A and B both column-major: both operands use the transposed U8 block load.
  // Uses logical `&&`, matching the three sibling layout conditions.
  using LayoutA = cutlass::layout::ColumnMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using GmemTiledCopyA = XE_2D_U8x16x32_LD_T;
  using GmemTiledCopyB = XE_2D_U8x16x32_LD_T;
#endif
503+
504+
#if defined(A_ROW) && defined(B_COL)
505+
using LayoutA = cutlass::layout::RowMajor;
506+
using LayoutB = cutlass::layout::ColumnMajor;
507+
using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
508+
using GmemTiledCopyB = XE_2D_U8x16x32_LD_T;
509+
#endif
411510
// Workgroup-level tile
412511
using TileShape = Shape<_256, _256, _32>;
413512

@@ -435,7 +534,11 @@
435534
FusionCallBacks,
436535
XE_2D_U32x8x16_LD_N,
437536
void, void,
537+
#if defined(B_COL)
538+
XE_2D_U32x8x16_ST_N,
539+
#else
438540
void,
541+
#endif
439542
void, void>;
440543

441544
// Mainloop

include/cute/arch/xe_copy_2B.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,27 @@ struct XE_2D_U16x16x16_LD_T {
778778
}
779779
};
780780

781+
struct XE_2D_U8x16x32_LD_T {
782+
using BlockShape = Shape<_32, _16>;
783+
using inst_dtype = uint32_t;
784+
785+
static constexpr bool is_transpose = true;
786+
787+
template <class T>
788+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
789+
int height, int pitch, intel::coord_t coord,
790+
T *dst) {
791+
#if defined(SYCL_INTEL_TARGET)
792+
static_assert(sizeof(T) == 1, "Expected T to have size 1");
793+
*reinterpret_cast<intel::uint8 *>(dst) =
794+
__builtin_IB_subgroup_block_read_flat_transpose_u32_k8(
795+
(long)(baseoffset), width - 1, height - 1, pitch - 1, coord);
796+
#else
797+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
798+
#endif
799+
}
800+
};
801+
781802
struct XE_2D_U16x1x16_ST_N {
782803
using BlockShape = Shape<_1, _16>;
783804

include/cute/atom/copy_traits_xe.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -1667,6 +1667,26 @@ struct Copy_Traits_<XE_2D_U16x16x16_LD_T, args_t...>
16671667
: XE_2D_LD_Unpack<XE_2D_U16x16x16_LD_T, args_t...>(args...) {}
16681668
};
16691669

1670+
template <class... args_t>
1671+
struct Copy_Traits_<XE_2D_U8x16x32_LD_T, args_t...>
1672+
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...> {
1673+
using ThrID = Layout<_16>;
1674+
// Map from (src-thr,src-val) to bit
1675+
// TODO(joe): Not convinced that changing from <_16, _256> should be required here
1676+
// but get_logical_layout assumes get<1,0>(layout.shape) is the type size
1677+
using SrcLayout = Layout<Shape <_16,Shape <_8,_32>>,
1678+
Stride< _0,Stride<_1,_64>>>;
1679+
// Map from (dst-thr,dst-val) to bit
1680+
using DstLayout = Layout<Shape < _16,Shape <_8,_32>>,
1681+
Stride<_256,Stride<_1, _8>>>;
1682+
// Reference map from (thr,val) to bit
1683+
using RefLayout = DstLayout;
1684+
1685+
template <class... ArgT>
1686+
Copy_Traits_(ArgT... args)
1687+
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...>(args...) {}
1688+
};
1689+
16701690
// template<class... args_t>
16711691
// struct Copy_Traits<XE_2D_U32x16x1_LD_T, args_t...>
16721692
// : XE_2D_LD_Unpack<XE_2D_U32x16x1_LD_T, args_t...> {
@@ -2251,6 +2271,7 @@ COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_V)
22512271
COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_V)
22522272
COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_V)
22532273
COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_T)
2274+
COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_T)
22542275
COPY_TRAIT_LD_DEF(XE_2D_TF32x16x16_LD_N)
22552276
COPY_TRAIT_LD_DEF(XE_2D_TF32x32x16_LD_N)
22562277
COPY_TRAIT_LD_DEF(XE_2D_U4x32x64_LD_N)

include/cutlass/epilogue/collective/xe_epilogue.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,16 @@ class CollectiveEpilogue<
378378
auto synchronize = [&] () {};
379379

380380
// 32 x 64
381+
// if(cute::thread0()) {
382+
// print("accumulators: ");print(accumulators);print("\n");
383+
// }
381384
if constexpr(!is_same_v<CopyOpR2G_, XE_2D_U32x8x16_ST_N>) {
382385
auto D = make_tensor(make_gmem_ptr(params.ptr_D), make_layout(make_shape(4096, 4096), make_stride(4096, 1)));
383386
for(int i = 0; i < size<1>(accumulators); i++) {
384387
for(int j = 0; j < size<2>(accumulators); j++) {
385388
for(int v = 0; v < size<0>(accumulators); v++) {
386389
D(v + i * 8 + m_sg * 32 + BlockIdxY() * 256 , BlockIdxX() * 256 + n_sg * 64 + (thread_idx % 16) * 2 + (j % 2) + (j / 2) * 32) = accumulators(v, i, j);
390+
// D(v + i * 8 + m_sg * 16 + BlockIdxY() * 128 , BlockIdxX() * 256 + n_sg * 64 + (thread_idx % 16) * 2 + (j % 2) + (j / 2) * 32) = accumulators(v, i, j);
387391
}
388392
}
389393
}

include/cutlass/gemm/collective/xe_mma.hpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,6 @@ struct CollectiveMma<MainloopIntelPVC<Stages, Schedule>, TileShape_, ElementA_,
197197
auto pAgA = thr_prefetch_A.partition_S(gA);
198198
auto pBgB = thr_prefetch_B.partition_S(gB);
199199

200-
TransformA transformA{};
201-
TransformB transformB{};
202-
203200
#if CUTLASS_ENABLE_DEBUG_PRINTS
204201
#define PRINT(x) print(#x ": "); print(x); print("\n");
205202
if (cute::thread(LOG_THREAD, LOG_GROUP)) {
@@ -228,7 +225,7 @@ struct CollectiveMma<MainloopIntelPVC<Stages, Schedule>, TileShape_, ElementA_,
228225
const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start));
229226
constexpr int barrier_scope = 2;
230227
int prefetch_k = 0;
231-
228+
232229
CUTLASS_PRAGMA_UNROLL
233230
for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) {
234231
prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));

0 commit comments

Comments
 (0)