|
105 | 105 | };
|
106 | 106 |
|
107 | 107 | ///////////////////////////////////////////////////////////////////////////////////////////////////
|
108 |
| - |
| 108 | +#define A_ROW |
| 109 | +#define B_COL |
109 | 110 | template <
|
110 | 111 | class Gemm
|
111 | 112 | >
|
|
159 | 160 | //
|
160 | 161 |
|
161 | 162 | bool verify(const Options &options) {
|
162 |
| - |
163 |
| - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; |
164 |
| - using GmemTiledCopyB = XE_2D_U16x32x32_LD_N; |
165 |
| - |
166 |
| - // Workgroup-level tile |
167 |
| - using TileShape = Shape<_256, _256, _32>; |
168 |
| - |
169 |
| - using TiledMma = |
170 |
| - TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, |
171 |
| - Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>, |
172 |
| - Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>, |
173 |
| - Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>; |
174 |
| - |
| 163 | + #if defined(A_ROW) && defined(B_COL) |
| 164 | + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; |
| 165 | + using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; |
| 166 | + #endif |
| 167 | + #if defined(A_ROW) && defined(B_ROW) |
| 168 | + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; |
| 169 | + using GmemTiledCopyB = XE_2D_U16x32x32_LD_N; |
| 170 | + #endif |
| 171 | + |
| 172 | + #if defined(A_COL) && defined(B_ROW) |
| 173 | + using GmemTiledCopyA = XE_2D_U16x16x16_LD_T; |
| 174 | + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; |
| 175 | + #endif |
| 176 | + |
| 177 | + #if defined(A_COL) && defined(B_COL) |
| 178 | + using GmemTiledCopyA = XE_2D_U16x16x16_LD_T; |
| 179 | + using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; |
| 180 | + #endif |
| 181 | + // Workgroup-level tile |
| 182 | + using TileShape = Shape<_256, _256, _32>; |
| 183 | + |
| 184 | + using MMAAtom = MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>; |
| 185 | + using TiledMma = TiledMMA<MMAAtom, |
| 186 | + Layout<Shape<_8,_4,_1>, Stride<_4,_1,_0>>, |
| 187 | + Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>, |
| 188 | + Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, |
| 189 | + _32>>; |
175 | 190 | constexpr int PipelineStages = 3;
|
176 | 191 | using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
|
177 | 192 | using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
|
|
317 | 332 |
|
318 | 333 | };
|
319 | 334 |
|
320 |
| - |
321 | 335 | struct TransformA {
|
322 | 336 | template <class RTensor, class Trait, class TransTensor>
|
323 | 337 | CUTE_HOST_DEVICE auto operator()(RTensor const& in, Trait trait, TransTensor& out) {
|
| 338 | + #if defined(A_ROW) |
324 | 339 | // auto mma_A = make_fragment_like<typename TiledMma::ValTypeA>(in);
|
325 | 340 | Layout A_selector = make_layout(make_shape(_8{}, _4{}, _2{}), make_stride(_2{},_16{},_1{}));
|
| 341 | + // Layout A_selector = make_layout(make_shape(_8{}, _1{}, _2{}), make_stride(_2{},_16{}, _1{})); |
| 342 | + // Layout A_selector = make_layout(make_shape(_8{}, _2{}, _2{}), make_stride(_2{}, _16{}, _1{})); |
326 | 343 | CUTLASS_PRAGMA_UNROLL
|
327 | 344 | for(int i = 0; i < size<1>(out); i++) {
|
328 | 345 | CUTLASS_PRAGMA_UNROLL
|
329 | 346 | for(int j =0; j < size<2>(out); j++) {
|
330 | 347 | CUTLASS_PRAGMA_UNROLL
|
331 | 348 | for(int v = 0; v < size<0>(out); v++) {
|
332 | 349 | out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeA*/>(in.data()[A_selector(v, i, j)]);
|
333 |
| - // out(v, i, j) = 1; |
| 350 | + // out(v, i, j) = (bfloat16_t)(1.0f); |
334 | 351 | }
|
335 | 352 | }
|
336 | 353 | }
|
| 354 | + #endif |
| 355 | + #if defined(A_COL) |
| 356 | + Layout A_selector = make_layout(make_shape(_8{},_4{},_2{}), make_stride(_1{},_8{},_32{})); |
| 357 | + CUTLASS_PRAGMA_UNROLL |
| 358 | + for(int i = 0; i < size<1>(out); i++) { |
| 359 | + CUTLASS_PRAGMA_UNROLL |
| 360 | + for(int j =0; j < size<2>(out); j++) { |
| 361 | + CUTLASS_PRAGMA_UNROLL |
| 362 | + for(int v = 0; v < size<0>(out); v++) { |
| 363 | + out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeA*/>(in.data()[A_selector(v, i, j)]); |
| 364 | + // out(v, i, j) = (bfloat16_t)(1.0f); |
| 365 | + } |
| 366 | + } |
| 367 | + } |
| 368 | + #endif |
337 | 369 | }
|
338 | 370 | };
|
339 | 371 |
|
340 | 372 | struct TransformB {
|
341 | 373 | template <class RTensor, class Trait, class TransTensor>
|
342 | 374 | CUTE_HOST_DEVICE auto operator()(RTensor const& in, Trait trait, TransTensor& out) {
|
| 375 | + #if defined(B_ROW) && defined(A_ROW) |
343 | 376 | // auto mma_B = make_fragment_like<typename TiledMma::ValTypeB>(in);
|
344 | 377 | Layout B_selector = make_layout(make_shape(_16{}, make_shape(_2{}, _2{}), _2{}), make_stride(_4{}, make_stride(_1{}, _64{}) ,_2{}));
|
345 | 378 | CUTLASS_PRAGMA_UNROLL
|
|
349 | 382 | CUTLASS_PRAGMA_UNROLL
|
350 | 383 | for(int v = 0; v < size<0>(out); v++) {
|
351 | 384 | out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]);
|
352 |
| - // out(v, i, j) = 1; |
| 385 | + // out(v, i, j) = (bfloat16_t)(1.0f); |
| 386 | + } |
| 387 | + } |
| 388 | + } |
| 389 | + #endif |
| 390 | + #if defined(B_ROW) && defined(A_COL) |
| 391 | + Layout B_selector = make_layout(make_shape(_16{}, make_shape(_2{}, _2{}), _2{}), make_stride(_2{}, make_stride(_1{}, _64{}) ,_32{})); |
| 392 | + CUTLASS_PRAGMA_UNROLL |
| 393 | + for(int i = 0; i < size<1>(out); i++) { |
| 394 | + CUTLASS_PRAGMA_UNROLL |
| 395 | + for(int j =0; j < size<2>(out); j++) { |
| 396 | + CUTLASS_PRAGMA_UNROLL |
| 397 | + for(int v = 0; v < size<0>(out); v++) { |
| 398 | + out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]); |
| 399 | + // out(v, i, j) = (bfloat16_t)(1.0f); |
| 400 | + } |
| 401 | + } |
| 402 | + } |
| 403 | + #endif |
| 404 | + #if defined(B_COL) && defined(A_COL) |
| 405 | + Layout B_selector = make_layout(make_shape(_16{}, _4{},_2{}), make_stride(_1{}, _32{},_16{})); |
| 406 | + CUTLASS_PRAGMA_UNROLL |
| 407 | + for(int i = 0; i < size<1>(out); i++) { |
| 408 | + CUTLASS_PRAGMA_UNROLL |
| 409 | + for(int j =0; j < size<2>(out); j++) { |
| 410 | + CUTLASS_PRAGMA_UNROLL |
| 411 | + for(int v = 0; v < size<0>(out); v++) { |
| 412 | + out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]); |
| 413 | + // out(v, i, j) = (bfloat16_t)(1.0f); |
353 | 414 | }
|
354 | 415 | }
|
355 | 416 | }
|
| 417 | + #endif |
| 418 | + #if defined(B_COL) && defined(A_ROW) |
| 419 | + Layout B_selector = make_layout(make_shape(_16{}, _4{}, _2{}), make_stride(_2{}, _32{},_1{})); |
| 420 | + CUTLASS_PRAGMA_UNROLL |
| 421 | + for(int i = 0; i < size<1>(out); i++) { |
| 422 | + CUTLASS_PRAGMA_UNROLL |
| 423 | + for(int j =0; j < size<2>(out); j++) { |
| 424 | + CUTLASS_PRAGMA_UNROLL |
| 425 | + for(int v = 0; v < size<0>(out); v++) { |
| 426 | + out(v, i, j) = static_cast<cutlass::bfloat16_t/*typename TiledMma::ValTypeB*/>(in.data()[B_selector(v, i, j)]); |
| 427 | + // out(v, i, j) = (bfloat16_t)(1.0f); |
| 428 | + } |
| 429 | + } |
| 430 | + } |
| 431 | + #endif |
356 | 432 | }
|
357 | 433 | };
|
358 | 434 |
|
|
398 | 474 | using ElementInputB = cutlass::float_e4m3_t; // <- data type of elements in input matrix B
|
399 | 475 | using ElementOutput = float; // <- data type of elements in output matrix D
|
400 | 476 |
|
401 |
| - using LayoutA = cutlass::layout::RowMajor; |
402 |
| - using LayoutB = cutlass::layout::RowMajor; |
403 | 477 | using LayoutC = cutlass::layout::RowMajor;
|
404 | 478 | using LayoutD = cutlass::layout::RowMajor;
|
405 | 479 |
|
406 | 480 | // Note: XE_2D_U8x32x32_LD_V is incompatible with our bf16 MMA atoms
|
| 481 | + // 2.8tflops U8x32x32NLD_N |
| 482 | + // 1.4tflops U8x16x32NLD_N |
| 483 | + // 0.7tflops U8x 8x32NLD_N |
| 484 | + #if defined(A_COL) && defined(B_ROW) |
| 485 | + using LayoutA = cutlass::layout::ColumnMajor; |
| 486 | + using LayoutB = cutlass::layout::RowMajor; |
| 487 | + using GmemTiledCopyA = XE_2D_U8x16x32_LD_T; |
| 488 | + using GmemTiledCopyB = XE_2D_U8x32x32_LD_N; |
| 489 | + #endif |
| 490 | + #if defined(A_ROW) && defined(B_ROW) |
| 491 | + using LayoutA = cutlass::layout::RowMajor; |
| 492 | + using LayoutB = cutlass::layout::RowMajor; |
407 | 493 | using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
|
408 | 494 | using GmemTiledCopyB = XE_2D_U8x32x32_LD_N;
|
409 |
| - static_assert(sizeof(ElementInputA) == 1, "ElementA width must match GmemTiledCopyA U8"); |
410 |
| - |
| 495 | + #endif |
| 496 | + |
#if defined(A_COL) && defined(B_COL)
  // Configuration: A and B are both column-major; both operands use the
  // XE_2D_U8x16x32_LD_T copy atom. The condition uses logical `&&` to match
  // the sibling layout-selection conditions in this file.
  using LayoutA = cutlass::layout::ColumnMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using GmemTiledCopyA = XE_2D_U8x16x32_LD_T;
  using GmemTiledCopyB = XE_2D_U8x16x32_LD_T;
#endif
| 503 | + |
| 504 | + #if defined(A_ROW) && defined(B_COL) |
| 505 | + using LayoutA = cutlass::layout::RowMajor; |
| 506 | + using LayoutB = cutlass::layout::ColumnMajor; |
| 507 | + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; |
| 508 | + using GmemTiledCopyB = XE_2D_U8x16x32_LD_T; |
| 509 | + #endif |
411 | 510 | // Workgroup-level tile
|
412 | 511 | using TileShape = Shape<_256, _256, _32>;
|
413 | 512 |
|
|
435 | 534 | FusionCallBacks,
|
436 | 535 | XE_2D_U32x8x16_LD_N,
|
437 | 536 | void, void,
|
| 537 | + #if defined(B_COL) |
| 538 | + XE_2D_U32x8x16_ST_N, |
| 539 | + #else |
438 | 540 | void,
|
| 541 | + #endif |
439 | 542 | void, void>;
|
440 | 543 |
|
441 | 544 | // Mainloop
|
|
0 commit comments