@@ -183,4 +183,93 @@ struct CudaStreams {
183183 }
184184};
185185
186+ struct CudaStreamsBarrier {
187+ private:
188+ std::vector<cudaEvent_t> _events;
189+ CudaStreams _streams;
190+
191+ CudaStreamsBarrier (const CudaStreamsBarrier &) {} // Prevent copy-construction
192+ CudaStreamsBarrier &operator =(const CudaStreamsBarrier &) {
193+ return *this ;
194+ } // Prevent assignment
195+ public:
196+ void create_on (const CudaStreams &streams) {
197+ _streams = streams;
198+
199+ GPU_ASSERT (streams.count () > 1 , " CudaStreamsFirstWaitsWorkersBarrier: "
200+ " Attempted to create on single GPU" );
201+ _events.resize (streams.count ());
202+ for (int i = 0 ; i < streams.count (); i++) {
203+ _events[i] = cuda_create_event (streams.gpu_index (i));
204+ }
205+ }
206+
207+ CudaStreamsBarrier (){};
208+
209+ void local_streams_wait_for_stream_0 (const CudaStreams &user_streams) {
210+ GPU_ASSERT (!_events.empty (),
211+ " CudaStreamsBarrier: must call create_on before use" );
212+ GPU_ASSERT (user_streams.gpu_index (0 ) == _streams.gpu_index (0 ),
213+ " CudaStreamsBarrier: synchronization can only be performed on "
214+ " the GPUs the barrier was initially created on." );
215+
216+ cuda_event_record (_events[0 ], user_streams.stream (0 ),
217+ user_streams.gpu_index (0 ));
218+ for (int j = 1 ; j < user_streams.count (); j++) {
219+ GPU_ASSERT (user_streams.gpu_index (j) == _streams.gpu_index (j),
220+ " CudaStreamsBarrier: synchronization can only be performed on "
221+ " the GPUs the barrier was initially created on." );
222+ cuda_stream_wait_event (user_streams.stream (j), _events[0 ],
223+ user_streams.gpu_index (j));
224+ }
225+ }
226+
227+ void stream_0_wait_for_local_streams (const CudaStreams &user_streams) {
228+ GPU_ASSERT (
229+ !_events.empty (),
230+ " CudaStreamsFirstWaitsWorkersBarrier: must call create_on before use" );
231+ GPU_ASSERT (
232+ user_streams.count () <= _events.size (),
233+ " CudaStreamsFirstWaitsWorkersBarrier: trying to synchronize too many "
234+ " streams. "
235+ " The barrier was created on a LUT that had %lu active streams, while "
236+ " the user stream set has %u streams" ,
237+ _events.size (), user_streams.count ());
238+
239+ if (user_streams.count () > 1 ) {
240+ // Worker GPUs record their events
241+ for (int j = 1 ; j < user_streams.count (); j++) {
242+ GPU_ASSERT (_streams.gpu_index (j) == user_streams.gpu_index (j),
243+ " CudaStreamsBarrier: The user stream "
244+ " set GPU[%d]=%u while the LUT stream set GPU[%d]=%u" ,
245+ j, user_streams.gpu_index (j), j, _streams.gpu_index (j));
246+
247+ cuda_event_record (_events[j], user_streams.stream (j),
248+ user_streams.gpu_index (j));
249+ }
250+
251+ // GPU 0 waits for all workers
252+ for (int j = 1 ; j < user_streams.count (); j++) {
253+ cuda_stream_wait_event (user_streams.stream (0 ), _events[j],
254+ user_streams.gpu_index (0 ));
255+ }
256+ }
257+ }
258+
259+ void release () {
260+ for (int j = 0 ; j < _streams.count (); j++) {
261+ cuda_event_destroy (_events[j], _streams.gpu_index (j));
262+ }
263+
264+ _events.clear ();
265+ }
266+
267+ ~CudaStreamsBarrier () {
268+ GPU_ASSERT (_events.empty (),
269+ " CudaStreamsBarrier: must "
270+ " call release before destruction: events size = %lu" ,
271+ _events.size ());
272+ }
273+ };
274+
186275#endif
0 commit comments