
Commit bf7327a

xw285cornell authored and facebook-github-bot committed
fix volatile synchronization with acquire/release (#3728)
Summary:
Pull Request resolved: #3728

X-link: facebookresearch/FBGEMM#811

Following https://fb.workplace.com/groups/pytorch.dev/permalink/1731892800722526/ - fixing volatile synchronization with explicit acquire/release syntax.

Reviewed By: ngimel

Differential Revision: D70080262

fbshipit-source-id: 3ca4e11ca94a8d0541294dbeceee9bcd5ab06690
1 parent 6ac5a7f commit bf7327a
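For context, the change swaps volatile loads/stores of the barrier flags for explicit release stores and acquire loads: volatile only forces the compiler to emit the memory access, it does not by itself order the surrounding reads and writes across the system, which is what the inter-rank handshake relies on. Below is a minimal, self-contained sketch of the same handshake written with libcu++ (cuda::atomic_ref) rather than the inline-PTX helpers added in this diff; the kernel and variable names are illustrative only, and it assumes a recent CUDA toolkit with libcu++ available:

#include <cuda/atomic>
#include <cstdint>

// One thread publishes a flag; every thread then spins until it observes it.
// The release store guarantees the producer's earlier writes become visible
// before the flag itself, and the acquire load guarantees the consumer sees
// those writes once it observes the expected flag value.
__global__ void spin_wait_demo(int32_t* flag_addr, int32_t flag) {
  cuda::atomic_ref<int32_t, cuda::thread_scope_system> flag_ref(*flag_addr);
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    flag_ref.store(flag, cuda::memory_order_release);  // signal peers
  }
  while (flag_ref.load(cuda::memory_order_acquire) != flag) {
    // busy-wait, mirroring the do/while + ld_flag_acquire loop in car.cu
  }
}

On Volta and newer this should lower to essentially the same ld.acquire.sys / st.release.sys instructions that the hand-written st_flag_release / ld_flag_acquire helpers in the diff emit, while the ROCm path reaches the equivalent orderings through __atomic_store_n / __atomic_load_n.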

File tree

1 file changed: +43 -69 lines changed

  • fbgemm_gpu/experimental/gen_ai/src/comm


fbgemm_gpu/experimental/gen_ai/src/comm/car.cu

+43 -69
@@ -67,6 +67,32 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
   return c;
 }
 
+static DEVICE_INLINE void st_flag_release(int32_t& flag, int32_t* flag_addr) {
+#if defined(USE_ROCM)
+  __atomic_store_n(flag_addr, flag, __ATOMIC_RELEASE);
+#elif __CUDA_ARCH__ >= 700
+  asm volatile(
+      "st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#else
+  __threadfence_system();
+  asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#endif
+}
+
+static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
+#if defined(USE_ROCM)
+  flag = __atomic_load_n(flag_addr, __ATOMIC_ACQUIRE);
+#elif __CUDA_ARCH__ >= 700
+  asm volatile("ld.global.acquire.sys.b32 %0, [%1];"
+               : "=r"(flag)
+               : "l"(flag_addr));
+#else
+  asm volatile("ld.global.volatile.b32 %0, [%1];"
+               : "=r"(flag)
+               : "l"(flag_addr));
+#endif
+}
+
 template <int32_t kWorldSize, bool has_acc>
 #if defined(USE_ROCM)
 __launch_bounds__(512)
@@ -106,25 +132,17 @@ __launch_bounds__(512)
 #endif
   }
   // Synchronize the ranks.
-  volatile int32_t* barrier_d = barriers[rank];
   if (threadIdx.x < kWorldSize) {
     // The 1st block notifies the other ranks.
     if (blockIdx.x == 0) {
-#if defined(USE_ROCM)
-      __atomic_store_n(barriers[threadIdx.x] + rank, flag, __ATOMIC_RELEASE);
-#else
-      barriers[threadIdx.x][rank] = flag;
-#endif
+      st_flag_release(flag, barriers[threadIdx.x] + rank);
     }
 
     // Busy-wait until all ranks are ready.
-#if defined(USE_ROCM)
-    while (__atomic_load_n(barrier_d + threadIdx.x, __ATOMIC_ACQUIRE) != flag) {
-    }
-#else
-    while (barrier_d[threadIdx.x] != flag) {
-    }
-#endif
+    int32_t rank_barrier = 0;
+    do {
+      ld_flag_acquire(rank_barrier, barriers[rank] + threadIdx.x);
+    } while (rank_barrier != flag);
   }
 
   // Make sure we can move on...
@@ -179,25 +197,15 @@ __launch_bounds__(512)
     // notify all other blocks this blockIdx is ready
     const int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize;
 
-#if defined(USE_ROCM)
-    __atomic_store_n(
-        barriers[threadIdx.x] + flag_block_offset + rank,
-        flag,
-        __ATOMIC_RELEASE);
-#else
-    barriers[threadIdx.x][flag_block_offset + rank] = flag;
-#endif
+    st_flag_release(flag, barriers[threadIdx.x] + flag_block_offset + rank);
+
+    int32_t rank_barrier = 0;
 
     // busy-wait until all ranks are ready
-#if defined(USE_ROCM)
-    while (__atomic_load_n(
-               barrier_d + flag_block_offset + threadIdx.x, __ATOMIC_ACQUIRE) !=
-           flag) {
-    }
-#else
-    while (barrier_d[flag_block_offset + threadIdx.x] != flag) {
-    }
-#endif
+    do {
+      ld_flag_acquire(
+          rank_barrier, barriers[rank] + flag_block_offset + threadIdx.x);
+    } while (rank_barrier != flag);
   }
 }

@@ -319,32 +327,6 @@ at::Tensor car_tensor() {
       at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA));
 }
 
-static DEVICE_INLINE void st_flag_release(int32_t& flag, int32_t* flag_addr) {
-#if defined(USE_ROCM)
-  __atomic_store_n(flag_addr, flag, __ATOMIC_RELEASE);
-#elif __CUDA_ARCH__ >= 700
-  asm volatile(
-      "st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
-#else
-  __threadfence_system();
-  asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
-#endif
-}
-
-static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
-#if defined(USE_ROCM)
-  flag = __atomic_load_n(flag_addr, __ATOMIC_ACQUIRE);
-#elif __CUDA_ARCH__ >= 700
-  asm volatile("ld.global.acquire.sys.b32 %0, [%1];"
-               : "=r"(flag)
-               : "l"(flag_addr));
-#else
-  asm volatile("ld.global.volatile.b32 %0, [%1];"
-               : "=r"(flag)
-               : "l"(flag_addr));
-#endif
-}
-
 template <int32_t kWorldSize, bool split_last_dim>
 #if defined(USE_ROCM)
 __launch_bounds__(512) __global__ void reduce_scatter(
@@ -472,25 +454,17 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
   int32_t N_start = N_per_rank * rank;
 
   // Synchronize the ranks.
-  volatile int32_t* barrier_d = barriers[rank];
   if (threadIdx.x < kWorldSize) {
     // The 1st block notifies the other ranks.
     if (blockIdx.x == 0) {
-#if defined(USE_ROCM)
-      __atomic_store_n(barriers[threadIdx.x] + rank, flag, __ATOMIC_RELEASE);
-#else
-      barriers[threadIdx.x][rank] = flag;
-#endif
+      st_flag_release(flag, barriers[threadIdx.x] + rank);
    }
 
     // Busy-wait until all ranks are ready.
-#if defined(USE_ROCM)
-    while (__atomic_load_n(barrier_d + threadIdx.x, __ATOMIC_ACQUIRE) != flag) {
-    }
-#else
-    while (barrier_d[threadIdx.x] != flag) {
-    }
-#endif
+    int32_t rank_flag = 0;
+    do {
+      ld_flag_acquire(rank_flag, barriers[rank] + threadIdx.x);
+    } while (rank_flag != flag);
   }
 
   __syncthreads();
