diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_blockscaled_rcgrouped.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_blockscaled_rcgrouped.cu
new file mode 100644
index 0000000000..5fb925bc28
--- /dev/null
+++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_blockscaled_rcgrouped.cu
@@ -0,0 +1,885 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+
+
+/*! \file
+    \brief Ragged Contiguous Blockscaled Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture.
+
+    This example demonstrates an implementation of Ragged Contiguous Grouped GEMM using a TMA +
+    Blackwell SM100 TensorOp-based warp-specialized kernel for narrow precisions (MXFP8 in this example)
+    with Scale Factors (In and Out).
+    For this example all scheduling work is performed on the device.
+
+    To run this example:
+
+      $ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_blockscaled_rcgrouped --m=128 --k=128 --groups=10
+
+    The command above sizes all 10 groups with the given extents.
+    Note that m and k remain consistent across groups; n is randomized per group when it is not provided through the args.
+    Alpha and beta values are randomized across the different groups.
+
+    To run this example for a set of problems using the benchmark option:
+
+      $ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_blockscaled_rcgrouped --benchmark=./test_benchmark.txt
+
+    Where the test_benchmark.txt may look as such:
+      0 256x512x256
+      1 256x128x256
+      2 256x256x256 and so on
+    Note that one must keep m and k consistent across groups in the benchmark file.
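+
+    A note on the data layout this example configures (a sketch, matching the host code further
+    below in this file): operand A and its scale factors SFA are each one contiguous allocation
+    spanning all groups, shaped (m, k, groups), while B, SFB, C and D are passed as per-group
+    pointer arrays. The per-group N extents (the "tokens per expert") are supplied alongside the
+    problem shape, mirroring the argument construction in args_from_options():
+
+      {options.m, options.n, options.k, options.groups, tokens_per_expert.get()}    // problem shape
+      {block_A.device_data(), ptr_B.get(), block_SFA.device_data(), ptr_SFB.get()}  // A/SFA single tensors, B/SFB pointer arrays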
+*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "helper.h" +using namespace cute; + +using ProblemShape = cutlass::gemm::MoEProblemShape>; // per group +using ElementInput = cutlass::float_e4m3_t; // Element type for Input matrix operands +using ElementSF = cutlass::float_ue8m0_t; // Element type for SF matrix operands +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 16; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 16; // Alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = ElementC; // Element type for D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Alignment of D matrix in units of elements (up to 16 bytes) +using ElementAccumulator = float; // Element type for internal accumulation + +using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands +constexpr int OutputSFVectorSize = 16; +using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor< + cutlass::epilogue::thread::SiLu, + OutputSFVectorSize, + ElementD, + ElementAccumulator, + ElementSFD, + LayoutC, + ElementC>; + +// Core kernel configurations +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Runtime 
Cluster Shape +using ClusterShape = Shape; + +// Different configs for 1SM and 2SM MMA kernel +struct MMA1SMConfig { + using MmaTileShape = Shape<_128,_256,_128>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch +}; + +struct MMA2SMConfig { + using MmaTileShape = Shape<_256,_256,_128>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +}; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + Shape<_128,_64>, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, // Set ElementC as void here to run kernel as void-C case + ElementD, LayoutC *, AlignmentD, + typename MMA1SMConfig::EpilogueSchedule + // , FusionOperation // Enable for SF Output +>::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule +>::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; +using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; +using Gemm = Gemm1SM; + +using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + typename MMA2SMConfig::MmaTileShape, ClusterShape, + Shape<_128,_64>, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, // Set ElementC as void here to run kernel as void-C case + ElementD, LayoutC *, AlignmentD, + typename MMA2SMConfig::EpilogueSchedule + // , FusionOperation // Enable for SF Output +>::CollectiveOp; +using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + typename MMA2SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue2SM::SharedStorage))>, + typename MMA2SMConfig::KernelSchedule +>::CollectiveOp; +using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop2SM, + CollectiveEpilogue2SM +>; +using Gemm2SM = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + OutputSFVectorSize, + cute::is_same_v ? 
cute::UMMA::Major::K : cute::UMMA::Major::MN + >; +using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; +using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF; + +// Host-side allocations +std::vector alpha_host; +std::vector beta_host; + +using HostTensorA = cutlass::HostTensor; +using HostTensorB = cutlass::HostTensor; +using HostTensorSF = cutlass::HostTensor; +using HostTensorC = cutlass::HostTensor; +using HostTensorD = cutlass::HostTensor; + +HostTensorA block_A; +HostTensorSF block_SFA; +std::vector block_B; +std::vector block_SFB; +std::vector block_C; +std::vector block_D; +std::vector block_SFD; +std::vector block_ref_D; + +// Device-side allocations +cutlass::DeviceAllocation tokens_per_expert; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFD; +cutlass::DeviceAllocation ptr_ref_D; + +StrideA stride_A; +LayoutSFA layout_SFA; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; +// A matrix wide constant value to scale the output matrix +// Avoids generating small FP4 values. +// NormConst is a single device-side constant value, its not per-batch or per-group +cutlass::DeviceAllocation norm_constant_device; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + bool verification = true; + bool use_pdl = false; + + float alpha = FLT_MAX; + float beta = FLT_MAX; + float norm_constant = 1.0; + int warmup = 1000; + int iterations = 1000; + int m = 1024, n = 2048, k = 512, groups = 10; + dim3 cluster_shape = dim3(2,1,1); + dim3 cluster_shape_fallback = dim3(2,1,1); + RasterOrderOptions raster_order = RasterOrderOptions::AlongN; + int max_sm_count = INT_MAX; + std::string benchmark_path; + std::vector tokens_per_expert_host; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + if (cmd.check_cmd_line_flag("use_pdl")) { + use_pdl = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX); + cmd.get_cmd_line_argument("beta", beta, FLT_MAX); + cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0)); + cmd.get_cmd_line_argument("warmup", warmup); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("cluster_m", 
cluster_shape.x); + cmd.get_cmd_line_argument("cluster_n", cluster_shape.y); + cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x); + cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y); + cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + tokens_per_expert_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + m = cmd_line_m; + k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + + for (int i = groups; i > 0; i--) { + int n = cmd_line_n; + if (n < 0) { + n = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + tokens_per_expert_host.push_back(n); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + extent.at(i) = std::atoi(tokens.at(i).c_str()); + } + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + tokens_per_expert_host.push_back(extent.n()); + } + groups = static_cast(problem_sizes_host.size()); + m = get<0>(problem_sizes_host.at(0)); + k = get<2>(problem_sizes_host.at(0)); + + return true; + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "92_blackwell_moe_gemm_blockscaled_rcgrouped\n\n" + << " Blackwell Block Scaled Narrow Precision Ragged Contiguous Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --norm_constant= Epilogue scalar normalization constant for the output matrix\n\n" + << " --cluster_m= and --cluster_n= Sets the X,Y dims of the preferred cluster shape\n" + << " --cluster_fallback_m= and --cluster_fallback_n= Sets the X,Y dims of the fallback cluster shape\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --max_sm_count= Run kernels using only these number of SMs\n" + << " --no_verif Do not run (host-side) verification kernels\n" + << " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "92_blackwell_moe_gemm_blockscaled_rcgrouped" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_B = 
cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B)))); + block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB))))); + block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C)))); + block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD))))); + block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + } + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + uint64_t seed = 2020; + + // Setting up tokens_per_expert array + tokens_per_expert.reset(options.tokens_per_expert_host.size()); + tokens_per_expert.copy_from_host(options.tokens_per_expert_host.data()); + + // + // Assign pointers + // + + std::vector ptr_B_host(options.groups); + std::vector ptr_SFB_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_SFD_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, options.groups)); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, options.groups}); + auto layout_A = make_layout(make_shape(options.m, options.k, options.groups), stride_A); + block_A.reset(cutlass::make_Coord(size(layout_A))); + + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + initialize_block(block_A.host_view(), seed + 2022); + initialize_block(block_SFA.host_view(), seed + 2024); + + block_A.sync_device(); + block_SFA.sync_device(); + + for (int32_t i = 0; i < options.groups; ++i) { + + initialize_block(block_B.at(i).host_view(), seed + 2022); + initialize_block(block_C.at(i).host_view(), seed + 2023); + initialize_block(block_SFB.at(i).host_view(), seed + 2025); + + block_B.at(i).sync_device(); + block_C.at(i).sync_device(); + block_SFB.at(i).sync_device(); + + ptr_B_host.at(i) = block_B.at(i).device_data(); + ptr_SFB_host.at(i) = block_SFB.at(i).device_data(); + ptr_C_host.at(i) = block_C.at(i).device_data(); + ptr_D_host.at(i) = block_D.at(i).device_data(); + ptr_SFD_host.at(i) = block_SFD.at(i).device_data(); + + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? 
static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(ptr_SFB_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_SFD.reset(options.groups); + ptr_SFD.copy_from_host(ptr_SFD_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + norm_constant_device.reset(1); + norm_constant_device.copy_from_host(&options.norm_constant); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options &options) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count); + + if (!is_static_v) { + if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 && + (options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) { + std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl; + exit(-1); + } + hw_info.cluster_shape = options.cluster_shape; + hw_info.cluster_shape_fallback = options.cluster_shape_fallback; + } + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + // If alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. 
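+  // A short sketch of the two scaling conventions handled by the branches below: a scalar alpha/beta
+  // uses a null pointer array and a zero group stride, while per-group scaling passes a device
+  // pointer array with a group stride of 1.
+  //   fusion_args.alpha = options.alpha; fusion_args.alpha_ptr_array = nullptr;            fusion_args.dAlpha = {_0{}, _0{}, 0};
+  //   fusion_args.alpha = 0;             fusion_args.alpha_ptr_array = alpha_device.get(); fusion_args.dAlpha = {_0{}, _0{}, 1};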
+ if (options.alpha != FLT_MAX){ + // Single alpha for all groups + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.dAlpha = {_0{}, _0{}, 0}; + } + else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + // Only one alpha per each group + fusion_args.dAlpha = {_0{}, _0{}, 1}; + } + if (options.beta != FLT_MAX) { + // Single beta for all groups + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dBeta = {_0{}, _0{}, 0}; + } + else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + // Only one beta per each group + fusion_args.dBeta = {_0{}, _0{}, 1}; + } + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.m, options.n, options.k, options.groups, tokens_per_expert.get()}, + {block_A.device_data(), ptr_B.get(), + block_SFA.device_data(), ptr_SFB.get()}, + {fusion_args, ptr_C.get(), nullptr, ptr_D.get(), nullptr}, + hw_info, scheduler + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()) + size_t(1) * i * size(layout_A), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data() + size_t(1) * i * size(filter_zeros(layout_SFA)), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB); + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C); + auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D); + + cutlass::reference::host::GettEpilogueParams< + float, float, + ElementAccumulator, ElementAccumulator, + decltype(tensor_C), decltype(tensor_ref_D) + > epilogue_params{}; + + epilogue_params.C = tensor_C; + epilogue_params.D = tensor_ref_D; + epilogue_params.alpha = alpha_host.at(i); + epilogue_params.beta = beta_host.at(i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + block_D.at(i).sync_host(); + // Check if output from CUTLASS kernel and reference kernel are equal 
or not + passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view()); + + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + if (options.verification) { + std::cout << " Host-side verification is now running - may be very slow for large cases." << std::endl; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + } + else { + std::cout << " Verification is turned off for this run." << std::endl; + } + + // Run profiling loop + if (options.iterations > 0) { + for (int iter = 0; iter < options.warmup; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + } + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || + ((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8))) { + std::cerr << "This example requires CUDA 12.8 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + allocate(options); + initialize(options); + + // + // Evaluate CUTLASS kernels + // + + std::cout << "Running kernel with 1SM MMA config:" << std::endl; + run(options); + std::cout << "Running kernel with 2SM MMA config:" << std::endl; + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/92_blackwell_moe_gemm/CMakeLists.txt b/examples/92_blackwell_moe_gemm/CMakeLists.txt index dbdd531319..02359a2aa6 100644 --- a/examples/92_blackwell_moe_gemm/CMakeLists.txt +++ b/examples/92_blackwell_moe_gemm/CMakeLists.txt @@ -37,6 +37,7 @@ set(TEST_DEEPSEEK_B_FP4 --m=7168 --n=1 --k=512 --groups=256 --iterations=0) set(TEST_IRREGULAR_MNK_FP4 --m=4080 --n=9 --k=4160 --groups=8 --iterations=0) set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_SMALL --m=2048 --n=512 --k=8192 --groups=2 --iterations=0) # Fixed problem sizes if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( @@ -62,6 +63,13 @@ cutlass_example_add_executable( TEST_FIXED ) +cutlass_example_add_executable( + 92_blackwell_moe_gemm_blockscaled_rcgrouped + 92_blackwell_moe_gemm_blockscaled_rcgrouped.cu + TEST_COMMAND_OPTIONS + TEST_FIXED_SMALL +) + cutlass_example_add_executable( 92_blackwell_moe_gemm_fp4_regular 92_blackwell_moe_gemm_fp4_regular.cu diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index f5e0ed70cd..6be3f518ee 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -241,8 +241,9 @@ struct CollectiveBuilder< static constexpr uint32_t AccumulatorPipelineStageCount = (MMA_N == 256) ? 1 : 2; static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v; // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler. 
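+  // Ragged Contiguous (RC) Grouped GEMM: A keeps a single non-pointer stride (one contiguous tensor
+  // spanning all groups) while B still carries a per-group pointer stride, so it is detected
+  // separately from the fully grouped case below.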
- static constexpr bool IsGroupGemm = !cute::is_same_v; - static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); + static constexpr bool IsGroupGemm = !(cute::is_same_v) && !(cute::is_same_v); + static constexpr bool IsRCGroupGemm = (cute::is_same_v) && !(cute::is_same_v); + static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< ClusterShape_MNK, @@ -265,13 +266,21 @@ struct CollectiveBuilder< using DispatchPolicy = cute::conditional_t, - cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockScaled< + cute::conditional_t, + cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + > + >, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockScaled< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 15a89fa945..bfd6731d18 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -56,6 +56,7 @@ #include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp" #include "cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp" diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp new file mode 100644 index 0000000000..2b8d71791d --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp @@ -0,0 +1,1293 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/detail/collective/moe_stride_utils.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100RCGroupGemmTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100RCGroupGemmTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. 
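+  // TiledMMA_SF mirrors the main TiledMma but is used only to construct the TMA atoms and
+  // partitions for the SFB operand (see TileShape_SF and the tma_load_sfb construction below).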
+ using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). 
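+  // For example, with the 8 scheduler stages used for grouped kernels and a hypothetical 4 mainloop
+  // stages, this would reserve 8 + 4 + 2 = 14 descriptor slots per SM.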
+ constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB 
must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TmaInternalElementA = cute::conditional_t; + using TmaInternalElementB = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
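+  // Note that TensorMapStorage above holds descriptors only for B and SFB: these are the operands
+  // whose tensormaps are updated per group on the device, whereas A and SFA are contiguous across
+  // groups and are covered by a single descriptor.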
+ using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + ArrayElementB const** ptr_B{nullptr}; + ElementSF const* ptr_SFA{nullptr}; + ElementSF const** ptr_SFB{nullptr}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const* ptr_A; + ArrayElementB const** ptr_B; + ElementSF const* ptr_SFA; + ElementSF const** ptr_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , 
runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(TileShape{})); + auto init_N = int32_t(size<1>(TileShape{})); + auto init_K_A = int32_t(size<2>(TileShape{})); + auto init_K_B = int32_t(size<2>(TileShape{})); + auto init_L = 1; + + // Tensor pointers will be fixed before the first access + auto ptr_A_first_batch = recast_ptr(args.ptr_A); + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_K_A = get<2>(problem_shape_MNK); + + auto shape_a = make_shape(init_M, init_K_A, problem_shapes.groups()); + InternalStrideA stride_a = cutlass::make_internal_packed_stride(InternalStrideA{}, shape_a); + InternalStrideB stride_b = InternalStrideB{}; + + InternalLayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K_A, problem_shapes.groups())); + InternalLayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K_B, 1)); + + // Batches/Groups are managed by using appropriate pointers to input matrices. 
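+    // For A, the group mode is folded into one tensor of shape (M, K, groups) with a packed batch
+    // stride, so a single descriptor covers every group. B starts from a null pointer here; its
+    // descriptor is rewritten on the device with the correct per-group pointer before the first load.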
+ Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M, init_K_A, problem_shapes.groups()), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N, init_K_B, init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + // Tensor pointers will be fixed before the first access + ElementSF const* ptr_SFA_first_batch = recast_ptr(args.ptr_SFA); + ElementSF const* ptr_SFB_first_batch = nullptr; + + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + // Cluster layout for TMA construction of SFB + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + args.ptr_A, + reinterpret_cast(args.ptr_B), + args.ptr_SFA, + reinterpret_cast(args.ptr_SFB), + }; + } + + struct TensorMaps : cute::aligned_struct<256, _0> { + cute::TmaDescriptor tma_desc_b; + cute::TmaDescriptor tma_desc_sfb; + }; + + template + static 
size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + // Allocate gmem space for input tensormaps per each SM. + return (sm_count * sizeof(TensorMaps) * NumTmaDescriptorsPerSm); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + /// Set up the data needed by this collective for load. 
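+  /// When tensormap updates are synchronous (IsTensorMapUpdateAsync == false), the returned tuple
+  /// additionally carries the per-SM B/SFB tensormap pointers produced by tensormaps_init().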
+  /// Returned tuple elements contain:
+  /// gA_mkl - The tiled tma tensor for input A
+  /// gB_nkl - The tiled tma tensor for input B
+  /// tAgA_mkl - partitioned gmem tensor for A
+  /// tBgB_nkl - partitioned gmem tensor for B
+  /// tAsA - partitioned smem tensor for A
+  /// tBsB - partitioned smem tensor for B
+  /// tAgSFA_mkl - partitioned gmem tensor for SFA
+  /// tBgSFB_nkl - partitioned gmem tensor for SFB
+  /// tAsSFA - partitioned smem tensor for SFA
+  /// tBsSFB - partitioned smem tensor for SFB
+  /// mcast_mask_a - tma multicast mask for A
+  /// mcast_mask_b - tma multicast mask for B
+  /// mcast_mask_sfa - tma multicast mask for SFA
+  /// mcast_mask_sfb - tma multicast mask for SFB
+  template <class ProblemShape_MNKL>
+  CUTLASS_DEVICE auto
+  load_init(
+      ProblemShape_MNKL const& problem_shape_MNKL,
+      Params const& params,
+      TensorStorage& shared_tensors,
+      TensorMapStorage& shared_tensormaps,
+      int32_t const sm_count, int32_t const sm_idx,
+      int32_t num_groups,
+      int32_t init_group) const {
+    using X = Underscore;
+
+    // Separate out problem shape for convenience
+    auto [M,N,K,L] = problem_shape_MNKL;
+    // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads
+    // we are managing TMA descriptors to change batches, we need to neglect the L mode
+    const int32_t mock_L = 1;
+
+    // Represent the full tensors -- get these from TMA
+    Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,num_groups));
+    Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L));
+
+    // Tile the tensors and defer the slice
+    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});    // (BLK_M, BLK_K, m, k, l)
+    Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{});    // (BLK_N, BLK_K, n, k, l)
+
+    // Represent the full tensor of Scale factors
+    InternalLayoutSFA layout_SFA{};
+    InternalLayoutSFB layout_SFB{};
+
+    layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, num_groups));
+    layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
+
+    Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA));
+    auto mSFB_nkl = [=](){
+      if constexpr (IsCtaN192) {
+        Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB));
+        auto x = stride<0,1>(mSFB_tmp);
+        auto y = ceil_div(shape<0,1>(mSFB_tmp), 4);
+        auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp),
+            make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp));
+        auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp),
+            make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp));
+        return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride));
+      }
+      else if constexpr (IsCtaN64) {
+        Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB));
+        auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp),
+            make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp));
+        auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp),
+            make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp));
+        return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride));
+      }
+      else {
+        return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB));
+      }
+    }();
+
+    Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});    // (TILE_M,TILE_K,m,k,l)
+    Tensor gSFB_nkl = local_tile(mSFB_nkl,
TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + auto ret = cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb); // multicast masks + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } + } + + /// Set up the data needed by this collective 
for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
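+      // The a_format_ / b_format_ fields are 3-bit format codes inside the UMMA instruction
+      // descriptor; masking with 0b111 keeps only those bits, letting a runtime-data-type build
+      // select the operand encoding (e.g. one of the FP8 variants) at launch time without recompiling.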
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TensorMapB, class TensorMapSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, + int curr_batch) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, curr_batch); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, curr_batch); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline + // to do the synchronization at once. 
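+      // A single transaction barrier per stage is therefore expected to cover all four TMA loads
+      // (A, B, SFA, SFB); once it flips, the consumer may read the scale factors along with the
+      // A/B data for that stage.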
+ + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto [tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + if constexpr (IsOverlappingAccum) { + // first 
iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + + TensorMaps* gmem_tensormap = &(reinterpret_cast(mainloop_params.tensormaps)[sm_idx * 
NumTmaDescriptorsPerSm]); + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + copy(recast(pB_tensormap), recast(sB_tensormap)); + + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + + __syncwarp(); + + struct TensorMapArray { + + TensorMaps *tensor_maps; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(void* tensormaps) : tensor_maps(reinterpret_cast(tensormaps)) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(&tensor_maps[idx].tma_desc_b, &tensor_maps[idx].tma_desc_sfb); + } + }; + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(gmem_tensormap); + } else { + return cute::make_tuple(&gmem_tensormap->tma_desc_b, &gmem_tensormap->tma_desc_sfb); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + TmaInternalElementB const* ptr_B = nullptr; + auto internal_shape_b = make_shape(static_cast(N), static_cast(K), 1); + InternalStrideB stride_b = cutlass::make_internal_packed_stride(InternalStrideB{}, internal_shape_b); + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), stride_b); + + ElementSF const* ptr_SF = nullptr; + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + Tensor tensor_sfb = make_tensor(ptr_SF, layout_SFB); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v) / 8; + } + + 
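+    // Commit the recomputed shape/stride arrays for this group's B and SFB into the SMEM-resident
+    // descriptor copies. Only B and SFB are rewritten per group; A and SFA keep a single descriptor
+    // whose batch mode indexes the group directly, since M and K stay fixed across groups.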
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch + ) { + if (cute::elect_one_sync()) { + + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps + ) { + + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + } + + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_{}; + LayoutSFB layout_SFB_{}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 4076e52e13..9c42b4647c 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -1192,6 +1192,21 @@ struct MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized { using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; }; +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100RCGroupGemmTmaUmmaWarpSpecializedBlockScaled { + constexpr 
static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm100; +}; + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_,