Skip to content

Commit

Permalink
Update cuda runtime files for TMA Multicast and cluster (#3793)
Browse files Browse the repository at this point in the history
1. Add distributed arrive for mbarriers. It allows synchronization of
CTAs in the same CGA.
2. Fix clusterSync.
3. Add multicast variants of tma tensor loads. These functions do not
have the L2 cache hint argument.

This PR adds the functions necessary for implementing
#3689.
  • Loading branch information
rdspring1 authored Jan 29, 2025
1 parent 075f97f commit 9eb6121
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 2 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ list(APPEND NVFUSER_RUNTIME_FILES
${NVFUSER_ROOT}/runtime/block_sync_default.cu
${NVFUSER_ROOT}/runtime/block_welford_outer.cu
${NVFUSER_ROOT}/runtime/broadcast.cu
${NVFUSER_ROOT}/runtime/cluster.cu
${NVFUSER_ROOT}/runtime/complex_number.cu
${NVFUSER_ROOT}/runtime/fp16_support.cu
${NVFUSER_ROOT}/runtime/fp8_support.cu
Expand Down
1 change: 1 addition & 0 deletions csrc/runtime/compiled_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
#include <nvfuser_resources/block_sync_default.h>
#include <nvfuser_resources/block_welford_outer.h>
#include <nvfuser_resources/broadcast.h>
#include <nvfuser_resources/cluster.h>
#include <nvfuser_resources/complex_number.h>
#include <nvfuser_resources/fp16_support.h>
#include <nvfuser_resources/fp8_support.h>
Expand Down
4 changes: 2 additions & 2 deletions runtime/cluster.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ void clusterWait() {

// Synchronize threads in cluster
// Full barrier across all threads of all CTAs in the cluster (CGA):
// arrive at the cluster barrier, then block until every CTA in the
// cluster has arrived. Requires cluster support (SM90+).
// NOTE(review): the flattened diff showed stale pre-rename calls
// (cluster_arrive/cluster_wait) alongside the renamed ones; only the
// renamed pair belongs in the final source.
void clusterSync() {
  clusterArrive();
  clusterWait();
}

// Returns the dim3 grid size in terms of number of clusters.
Expand Down
10 changes: 10 additions & 0 deletions runtime/mbarrier.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ __device__ inline uint64_t arriveExpectTX(
: "r"(smem_barrier_ptr), "r"(tx_count));
return state;
}

// Arrive on an mbarrier that resides in the shared memory of another CTA
// of the same cluster (CGA).
// - smem_barrier_ptr: 32-bit shared-memory address of the LOCAL CTA's copy
//   of the mbarrier; `mapa.shared::cluster` maps it to the address of the
//   corresponding location in the shared memory of CTA `cta_id`.
// - cta_id: rank of the target CTA within the cluster.
// The `mbarrier.arrive.shared::cluster` then performs the arrive on that
// remote barrier; the returned arrival state is discarded (`_` sink).
// Placed inside an architecture guard (see the #endif below) since cluster
// instructions require Hopper-class hardware — confirm guard covers SM90+.
__device__ inline void arrive(uint32_t smem_barrier_ptr, uint32_t cta_id) {
asm volatile(
"{.reg .b32 remaddr32;\n"
"mapa.shared::cluster.u32 remaddr32, %0, %1;\n"
"mbarrier.arrive.shared::cluster.b64 _, [remaddr32];\n"
"}"
:
: "r"(smem_barrier_ptr), "r"(cta_id));
}
#endif

__device__ inline void wait(uint32_t smem_barrier_ptr, uint64_t state) {
Expand Down
95 changes: 95 additions & 0 deletions runtime/memory.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,23 @@ __device__ inline void cpAsyncBulkTensorTileG2S(
: "memory");
}

// 1D TMA tensor-tile load from global memory into shared::cluster memory,
// multicast to the CTAs of the cluster selected by `cta_mask`.
// - src.descriptor: TMA descriptor for the global tensor (passed as u64)
// - src.mbarrier:   shared-memory mbarrier; transaction completion is
//                   signaled via mbarrier::complete_tx::bytes
// - src.crds[0]:    1D tile coordinate
// - smem_addr:      destination shared-memory address
// - cta_mask:       16-bit mask of destination CTAs within the cluster
// Unlike the non-multicast cpAsyncBulkTensorTileG2S above, this variant
// takes no L2 cache-hint operand.
__device__ inline void cpAsyncBulkTensorTileG2SMulticast(
const CpAsyncBulkTensorTileG2SIndex<1>& src,
uint32_t smem_addr,
uint16_t cta_mask) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(src.descriptor);
asm volatile(
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%3}], [%2], %4;"
:
: "r"(smem_addr),
"l"(gmem_int_desc),
"r"(src.mbarrier),
"r"(src.crds[0]),
"h"(cta_mask)
: "memory");
}

__device__ inline void cpAsyncBulkTensorTileG2S(
const CpAsyncBulkTensorTileG2SIndex<2>& src,
uint32_t smem_addr) {
Expand All @@ -148,6 +165,24 @@ __device__ inline void cpAsyncBulkTensorTileG2S(
: "memory");
}

// 2D TMA tensor-tile load from global memory into shared::cluster memory,
// multicast to the CTAs of the cluster selected by `cta_mask`.
// - src.descriptor: TMA descriptor for the global tensor (passed as u64)
// - src.mbarrier:   shared-memory mbarrier; transaction completion is
//                   signaled via mbarrier::complete_tx::bytes
// - src.crds[0..1]: 2D tile coordinates
// - smem_addr:      destination shared-memory address
// - cta_mask:       16-bit mask of destination CTAs within the cluster
// Unlike the non-multicast cpAsyncBulkTensorTileG2S, this variant takes
// no L2 cache-hint operand.
__device__ inline void cpAsyncBulkTensorTileG2SMulticast(
const CpAsyncBulkTensorTileG2SIndex<2>& src,
uint32_t smem_addr,
uint16_t cta_mask) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(src.descriptor);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%3, %4}], [%2], %5;"
:
: "r"(smem_addr),
"l"(gmem_int_desc),
"r"(src.mbarrier),
"r"(src.crds[0]),
"r"(src.crds[1]),
"h"(cta_mask)
: "memory");
}

__device__ inline void cpAsyncBulkTensorTileG2S(
const CpAsyncBulkTensorTileG2SIndex<3>& src,
uint32_t smem_addr) {
Expand All @@ -165,6 +200,25 @@ __device__ inline void cpAsyncBulkTensorTileG2S(
: "memory");
}

// 3D TMA tensor-tile load from global memory into shared::cluster memory,
// multicast to the CTAs of the cluster selected by `cta_mask`.
// - src.descriptor: TMA descriptor for the global tensor (passed as u64)
// - src.mbarrier:   shared-memory mbarrier; transaction completion is
//                   signaled via mbarrier::complete_tx::bytes
// - src.crds[0..2]: 3D tile coordinates
// - smem_addr:      destination shared-memory address
// - cta_mask:       16-bit mask of destination CTAs within the cluster
// Unlike the non-multicast cpAsyncBulkTensorTileG2S, this variant takes
// no L2 cache-hint operand.
__device__ inline void cpAsyncBulkTensorTileG2SMulticast(
    const CpAsyncBulkTensorTileG2SIndex<3>& src,
    uint32_t smem_addr,
    uint16_t cta_mask) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(src.descriptor);
  // Fix: the PTX multicast qualifier is ".multicast::cluster" — the
  // ".multicast_cluster" spelling is not valid PTX (the 1D/2D variants in
  // this file already use the correct form).
  asm volatile(
      "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%3, %4, %5}], [%2], %6;"
      :
      : "r"(smem_addr),
        "l"(gmem_int_desc),
        "r"(src.mbarrier),
        "r"(src.crds[0]),
        "r"(src.crds[1]),
        "r"(src.crds[2]),
        "h"(cta_mask)
      : "memory");
}

__device__ inline void cpAsyncBulkTensorTileG2S(
const CpAsyncBulkTensorTileG2SIndex<4>& src,
uint32_t smem_addr) {
Expand All @@ -183,6 +237,26 @@ __device__ inline void cpAsyncBulkTensorTileG2S(
: "memory");
}

// 4D TMA tensor-tile load from global memory into shared::cluster memory,
// multicast to the CTAs of the cluster selected by `cta_mask`.
// - src.descriptor: TMA descriptor for the global tensor (passed as u64)
// - src.mbarrier:   shared-memory mbarrier; transaction completion is
//                   signaled via mbarrier::complete_tx::bytes
// - src.crds[0..3]: 4D tile coordinates
// - smem_addr:      destination shared-memory address
// - cta_mask:       16-bit mask of destination CTAs within the cluster
// Unlike the non-multicast cpAsyncBulkTensorTileG2S, this variant takes
// no L2 cache-hint operand.
__device__ inline void cpAsyncBulkTensorTileG2SMulticast(
    const CpAsyncBulkTensorTileG2SIndex<4>& src,
    uint32_t smem_addr,
    uint16_t cta_mask) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(src.descriptor);
  // Fix: the PTX multicast qualifier is ".multicast::cluster" — the
  // ".multicast_cluster" spelling is not valid PTX (the 1D/2D variants in
  // this file already use the correct form).
  asm volatile(
      "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
      :
      : "r"(smem_addr),
        "l"(gmem_int_desc),
        "r"(src.mbarrier),
        "r"(src.crds[0]),
        "r"(src.crds[1]),
        "r"(src.crds[2]),
        "r"(src.crds[3]),
        "h"(cta_mask)
      : "memory");
}

__device__ inline void cpAsyncBulkTensorTileG2S(
const CpAsyncBulkTensorTileG2SIndex<5>& src,
uint32_t smem_addr) {
Expand All @@ -202,6 +276,27 @@ __device__ inline void cpAsyncBulkTensorTileG2S(
: "memory");
}

// 5D TMA tensor-tile load from global memory into shared::cluster memory,
// multicast to the CTAs of the cluster selected by `cta_mask`.
// - src.descriptor: TMA descriptor for the global tensor (passed as u64)
// - src.mbarrier:   shared-memory mbarrier; transaction completion is
//                   signaled via mbarrier::complete_tx::bytes
// - src.crds[0..4]: 5D tile coordinates
// - smem_addr:      destination shared-memory address
// - cta_mask:       16-bit mask of destination CTAs within the cluster
// Unlike the non-multicast cpAsyncBulkTensorTileG2S, this variant takes
// no L2 cache-hint operand.
__device__ inline void cpAsyncBulkTensorTileG2SMulticast(
    const CpAsyncBulkTensorTileG2SIndex<5>& src,
    uint32_t smem_addr,
    uint16_t cta_mask) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(src.descriptor);
  // Fix: the PTX multicast qualifier is ".multicast::cluster" — the
  // ".multicast_cluster" spelling is not valid PTX (the 1D/2D variants in
  // this file already use the correct form).
  asm volatile(
      "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
      :
      : "r"(smem_addr),
        "l"(gmem_int_desc),
        "r"(src.mbarrier),
        "r"(src.crds[0]),
        "r"(src.crds[1]),
        "r"(src.crds[2]),
        "r"(src.crds[3]),
        "r"(src.crds[4]),
        "h"(cta_mask)
      : "memory");
}

// TMA Stores:

template <int dim>
Expand Down

0 comments on commit 9eb6121

Please sign in to comment.