@@ -67,6 +67,32 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
   return c;
 }

+static DEVICE_INLINE void st_flag_release(int32_t& flag, int32_t* flag_addr) {
+#if defined(USE_ROCM)
+  __atomic_store_n(flag_addr, flag, __ATOMIC_RELEASE);
+#elif __CUDA_ARCH__ >= 700
+  asm volatile(
+      "st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#else
+  __threadfence_system();
+  asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#endif
+}
+
+static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
+#if defined(USE_ROCM)
+  flag = __atomic_load_n(flag_addr, __ATOMIC_ACQUIRE);
+#elif __CUDA_ARCH__ >= 700
+  asm volatile("ld.global.acquire.sys.b32 %0, [%1];"
+               : "=r"(flag)
+               : "l"(flag_addr));
+#else
+  asm volatile("ld.global.volatile.b32 %0, [%1];"
+               : "=r"(flag)
+               : "l"(flag_addr));
+#endif
+}
+
 template <int32_t kWorldSize, bool has_acc>
 #if defined(USE_ROCM)
 __launch_bounds__(512)
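For readers less comfortable with inline PTX: on NVIDIA hardware the two helpers above behave like system-scope release/acquire atomics on an `int32_t` flag. A rough equivalent using libcu++ is sketched below; it assumes a toolkit recent enough to ship `cuda::atomic_ref` in `<cuda/atomic>`, the `_ref` names are hypothetical, and the snippet is illustrative only, not part of this change.

```cpp
// Illustrative libcu++ equivalent of st_flag_release / ld_flag_acquire
// (assumes <cuda/atomic> with cuda::atomic_ref; hypothetical names).
#include <cstdint>
#include <cuda/atomic>

__device__ __forceinline__ void st_flag_release_ref(int32_t flag, int32_t* flag_addr) {
  // System-scope release store: prior writes become visible to any agent
  // that later acquire-loads this flag.
  cuda::atomic_ref<int32_t, cuda::thread_scope_system> ref(*flag_addr);
  ref.store(flag, cuda::memory_order_release);
}

__device__ __forceinline__ void ld_flag_acquire_ref(int32_t& flag, int32_t* flag_addr) {
  // System-scope acquire load: once the expected flag is observed, the
  // writes released before it was stored are visible to this thread.
  cuda::atomic_ref<int32_t, cuda::thread_scope_system> ref(*flag_addr);
  flag = ref.load(cuda::memory_order_acquire);
}
```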
@@ -106,25 +132,17 @@ __launch_bounds__(512)
 #endif
   }
   // Synchronize the ranks.
-  volatile int32_t* barrier_d = barriers[rank];
   if (threadIdx.x < kWorldSize) {
     // The 1st block notifies the other ranks.
     if (blockIdx.x == 0) {
-#if defined(USE_ROCM)
-      __atomic_store_n(barriers[threadIdx.x] + rank, flag, __ATOMIC_RELEASE);
-#else
-      barriers[threadIdx.x][rank] = flag;
-#endif
+      st_flag_release(flag, barriers[threadIdx.x] + rank);
     }

     // Busy-wait until all ranks are ready.
-#if defined(USE_ROCM)
-    while (__atomic_load_n(barrier_d + threadIdx.x, __ATOMIC_ACQUIRE) != flag) {
-    }
-#else
-    while (barrier_d[threadIdx.x] != flag) {
-    }
-#endif
+    int32_t rank_barrier = 0;
+    do {
+      ld_flag_acquire(rank_barrier, barriers[rank] + threadIdx.x);
+    } while (rank_barrier != flag);
   }

   // Make sure we can move on...
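The code that replaces the `#ifdef`'d spin loops is a flag barrier: block 0 of each rank release-stores the current epoch flag into every peer's mailbox, and each rank then acquire-loads its own mailbox until all peers have posted that flag. A minimal CPU analogue of the same handshake, using hypothetical names and `std::atomic` in place of the device system-scope atomics (one loop per rank rather than one thread per peer), is shown below.

```cpp
// CPU sketch of the flag barrier used above (illustrative only).
#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

constexpr int kWorldSize = 4;
// barriers[r][p] is rank r's mailbox slot for peer p (zero-initialized).
std::atomic<int32_t> barriers[kWorldSize][kWorldSize];

void rank_barrier(int rank, int32_t flag) {
  // Notify every peer; release so writes made before the barrier become
  // visible together with the flag.
  for (int peer = 0; peer < kWorldSize; ++peer) {
    barriers[peer][rank].store(flag, std::memory_order_release);
  }
  // Spin until every peer has posted the same flag into our mailbox; acquire
  // so the peers' prior writes are visible once the flag is observed.
  for (int peer = 0; peer < kWorldSize; ++peer) {
    while (barriers[rank][peer].load(std::memory_order_acquire) != flag) {
    }
  }
}

int main() {
  std::vector<std::thread> ranks;
  for (int r = 0; r < kWorldSize; ++r) {
    ranks.emplace_back(rank_barrier, r, /*flag=*/1);
  }
  for (auto& t : ranks) {
    t.join();
  }
}
```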
@@ -179,25 +197,15 @@ __launch_bounds__(512)
     // notify all other blocks this blockIdx is ready
     const int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize;

-#if defined(USE_ROCM)
-    __atomic_store_n(
-        barriers[threadIdx.x] + flag_block_offset + rank,
-        flag,
-        __ATOMIC_RELEASE);
-#else
-    barriers[threadIdx.x][flag_block_offset + rank] = flag;
-#endif
+    st_flag_release(flag, barriers[threadIdx.x] + flag_block_offset + rank);
+
+    int32_t rank_barrier = 0;

     // busy-wait until all ranks are ready
-#if defined(USE_ROCM)
-    while (__atomic_load_n(
-               barrier_d + flag_block_offset + threadIdx.x, __ATOMIC_ACQUIRE) !=
-           flag) {
-    }
-#else
-    while (barrier_d[flag_block_offset + threadIdx.x] != flag) {
-    }
-#endif
+    do {
+      ld_flag_acquire(
+          rank_barrier, barriers[rank] + flag_block_offset + threadIdx.x);
+    } while (rank_barrier != flag);
   }
 }

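The `flag_block_offset` arithmetic reserves the first `kWorldSize` slots of each rank's barrier buffer for the block-0 "rank ready" flags, then gives every block its own group of `kWorldSize` per-rank slots. A small host-side illustration of that indexing, assuming an 8-way world size purely for concreteness:

```cpp
// Illustration of the barrier-flag slot layout assumed above (not part of the diff).
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t kWorldSize = 8;
  // Slots [0, kWorldSize) are the "rank ready" flags written by blockIdx.x == 0.
  // Each subsequent group of kWorldSize slots belongs to one block.
  for (int32_t block = 0; block < 3; ++block) {
    const int32_t flag_block_offset = kWorldSize + block * kWorldSize;
    for (int32_t rank = 0; rank < kWorldSize; ++rank) {
      std::printf("block %d, rank %d -> slot %d\n", block, rank, flag_block_offset + rank);
    }
  }
  return 0;
}
```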
@@ -319,32 +327,6 @@ at::Tensor car_tensor() {
       at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA));
 }

-static DEVICE_INLINE void st_flag_release(int32_t& flag, int32_t* flag_addr) {
-#if defined(USE_ROCM)
-  __atomic_store_n(flag_addr, flag, __ATOMIC_RELEASE);
-#elif __CUDA_ARCH__ >= 700
-  asm volatile(
-      "st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
-#else
-  __threadfence_system();
-  asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
-#endif
-}
-
-static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
-#if defined(USE_ROCM)
-  flag = __atomic_load_n(flag_addr, __ATOMIC_ACQUIRE);
-#elif __CUDA_ARCH__ >= 700
-  asm volatile("ld.global.acquire.sys.b32 %0, [%1];"
-               : "=r"(flag)
-               : "l"(flag_addr));
-#else
-  asm volatile("ld.global.volatile.b32 %0, [%1];"
-               : "=r"(flag)
-               : "l"(flag_addr));
-#endif
-}
-
 template <int32_t kWorldSize, bool split_last_dim>
 #if defined(USE_ROCM)
 __launch_bounds__(512) __global__ void reduce_scatter(
@@ -472,25 +454,17 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
   int32_t N_start = N_per_rank * rank;

   // Synchronize the ranks.
-  volatile int32_t* barrier_d = barriers[rank];
   if (threadIdx.x < kWorldSize) {
     // The 1st block notifies the other ranks.
     if (blockIdx.x == 0) {
-#if defined(USE_ROCM)
-      __atomic_store_n(barriers[threadIdx.x] + rank, flag, __ATOMIC_RELEASE);
-#else
-      barriers[threadIdx.x][rank] = flag;
-#endif
+      st_flag_release(flag, barriers[threadIdx.x] + rank);
     }

     // Busy-wait until all ranks are ready.
-#if defined(USE_ROCM)
-    while (__atomic_load_n(barrier_d + threadIdx.x, __ATOMIC_ACQUIRE) != flag) {
-    }
-#else
-    while (barrier_d[threadIdx.x] != flag) {
-    }
-#endif
+    int32_t rank_flag = 0;
+    do {
+      ld_flag_acquire(rank_flag, barriers[rank] + threadIdx.x);
+    } while (rank_flag != flag);
   }

   __syncthreads();