|
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
| 2 | + |
#include <cstdlib>
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
| 6 | + |
/**
 * Panic wrapper for unwinding CUTLASS errors.
 *
 * Prints the CUTLASS status string and the invocation line, then terminates
 * the process. Wrapped in do { } while (0) so the expansion is a single
 * statement: it is safe under an un-braced if/else, and the caller's
 * trailing ';' does not produce a stray empty statement.
 */
#define CUTLASS_CHECK(status)                                             \
  do {                                                                    \
    cutlass::Status error = (status);                                     \
    if (error != cutlass::Status::kSuccess) {                             \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \
                << " at: " << __LINE__ << std::endl;                      \
      exit(EXIT_FAILURE);                                                 \
    }                                                                     \
  } while (0)
| 19 | + |
///////////////////////////////////////////////////////////////////////////////////////////////////

// The code section below describes datatype for input, output matrices and
// computation between elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue =
    ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = float; // <- data type of elements in input matrix A
using ElementInputB = float; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

// The code section below describes matrix layout of input and output matrices.
// Row Major for Matrix A, Matrix B, and the output Matrix C/D.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// This code section describes whether you want to use tensor cores or regular
// SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;

// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<128, 256, 16>; // <- threadblock tile M = 128, N =
                                            // 256, K = 16
// This code section describes tile size a warp will compute
using ShapeMMAWarp =
    cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16
// This code section describes the size of MMA op
using ShapeMMAOp =
    cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8

// This code section describes how threadblocks are scheduled on GPU.
// The identity swizzle maps threadblock indices to tile coordinates
// directly (no reordering of the launch grid).
using SwizzleThreadBlock =
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// This code section describes the epilogue part of the kernel,
// D = alpha * accumulator + beta * C
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, // <- data type of output matrix
    128 /
        cutlass::sizeof_bits<
            ElementOutput>::value, // <- the number of elements per vectorized
                                   // memory access (128 bits; 4 floats here).
                                   // This becomes the vector width of
                                   // math instructions in the epilogue too
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha/beta in linear
                             // combination function

// Number of pipeline stages (software pipelining depth of the mainloop)
constexpr int NumStages = 3;

// The CUTLASS device-level GEMM instantiated from the configuration above.
using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA,
    LayoutInputA,
    ElementInputB,
    LayoutInputB,
    ElementOutput,
    LayoutOutput,
    ElementAccumulator,
    MMAOp,
    SmArch,
    ShapeMMAThreadBlock,
    ShapeMMAWarp,
    ShapeMMAOp,
    EpilogueOp,
    SwizzleThreadBlock,
    NumStages>;
| 91 | + |
| 92 | +void gemm_kernel(float* a, float* b, float* c, int m, int n, int k) { |
| 93 | + cutlass::gemm::GemmCoord problem_size{m, n, k}; |
| 94 | + cutlass::TensorRef tensor_a{a, LayoutInputA{k}}; |
| 95 | + cutlass::TensorRef tensor_b{b, LayoutInputB{n}}; |
| 96 | + cutlass::TensorRef tensor_c{c, LayoutOutput{n}}; |
| 97 | + cutlass::TensorRef tensor_d{c, LayoutOutput{n}}; |
| 98 | + |
| 99 | + // Initialize alpha and beta for dot product computation |
| 100 | + ElementComputeEpilogue alpha = ElementComputeEpilogue(1.0f); |
| 101 | + ElementComputeEpilogue beta = ElementComputeEpilogue(0.0f); |
| 102 | + |
| 103 | + // Split K dimension into 1 partitions |
| 104 | + int split_k_slices = 1; |
| 105 | + |
| 106 | + // Create a tuple of gemm kernel arguments. This is later passed as arguments |
| 107 | + // to launch instantiated CUTLASS kernel |
| 108 | + typename Gemm::Arguments arguments{ |
| 109 | + problem_size, // <- problem size of matrix multiplication |
| 110 | + tensor_a, // <- reference to matrix A on device |
| 111 | + tensor_b, // <- reference to matrix B on device |
| 112 | + tensor_c, // <- reference to matrix C on device |
| 113 | + tensor_d, // <- reference to matrix D on device |
| 114 | + {alpha, beta}, // <- tuple of alpha and beta |
| 115 | + split_k_slices}; // <- k-dimension split factor |
| 116 | + |
| 117 | + // Using the arguments, query for extra workspace required for matrix |
| 118 | + // multiplication computation |
| 119 | + size_t workspace_size = Gemm::get_workspace_size(arguments); |
| 120 | + |
| 121 | + // printf("workspace size: %d\n", workspace_size); |
| 122 | + if (workspace_size != 0) { |
| 123 | + exit(EXIT_FAILURE); |
| 124 | + } |
| 125 | + // Allocate workspace memory |
| 126 | + // cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); |
| 127 | + |
| 128 | + Gemm gemm_op; |
| 129 | + |
| 130 | + // Instantiate CUTLASS kernel depending on templates |
| 131 | + cutlass::Status status = gemm_op.can_implement(arguments); |
| 132 | + CUTLASS_CHECK(status); |
| 133 | + |
| 134 | + status = gemm_op.initialize(arguments, nullptr); // workspace.get()); |
| 135 | + CUTLASS_CHECK(status); |
| 136 | + |
| 137 | + status = gemm_op(); |
| 138 | + CUTLASS_CHECK(status); |
| 139 | +} |