@@ -87,19 +87,102 @@ uct_cuda_copy_get_mem_type(uct_md_h md, void *address, size_t length)
87
87
return mem_info .type ;
88
88
}
89
89
90
- static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ctx_rsc_get (
91
- uct_cuda_copy_iface_t * iface , uct_cuda_copy_ctx_rsc_t * * ctx_rsc_p )
90
+ static ucs_status_t
91
+ uct_cuda_primary_ctx_push_first_active ( CUdevice * cuda_device_p )
92
92
{
93
+ int num_devices , device_index ;
94
+ ucs_status_t status ;
95
+ CUdevice cuda_device ;
96
+ CUcontext cuda_ctx ;
97
+
98
+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuDeviceGetCount (& num_devices ));
99
+ if (status != UCS_OK ) {
100
+ return status ;
101
+ }
102
+
103
+ for (device_index = 0 ; device_index < num_devices ; ++ device_index ) {
104
+ status = UCT_CUDADRV_FUNC_LOG_ERR (
105
+ cuDeviceGet (& cuda_device , device_index ));
106
+ if (status != UCS_OK ) {
107
+ return status ;
108
+ }
109
+
110
+ status = uct_cuda_primary_ctx_retain (cuda_device , 0 , & cuda_ctx );
111
+ if (status == UCS_OK ) {
112
+ /* Found active primary context */
113
+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (cuda_ctx ));
114
+ if (status != UCS_OK ) {
115
+ UCT_CUDADRV_FUNC_LOG_WARN (
116
+ cuDevicePrimaryCtxRelease (cuda_device ));
117
+ return status ;
118
+ }
119
+
120
+ * cuda_device_p = cuda_device ;
121
+ return UCS_OK ;
122
+ } else if (status != UCS_ERR_NO_DEVICE ) {
123
+ return status ;
124
+ }
125
+ }
126
+
127
+ return UCS_ERR_NO_DEVICE ;
128
+ }
129
+
130
+ static UCS_F_ALWAYS_INLINE void
131
+ uct_cuda_primary_ctx_pop_and_release (CUdevice cuda_device )
132
+ {
133
+ if (ucs_likely (cuda_device == CU_DEVICE_INVALID )) {
134
+ return ;
135
+ }
136
+
137
+ UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
138
+ UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (cuda_device ));
139
+ }
140
+
141
+ static UCS_F_ALWAYS_INLINE ucs_status_t
142
+ uct_cuda_copy_ctx_rsc_get (uct_cuda_copy_iface_t * iface , CUdevice * cuda_device_p ,
143
+ uct_cuda_copy_ctx_rsc_t * * ctx_rsc_p )
144
+ {
145
+ unsigned long long ctx_id ;
146
+ CUresult result ;
147
+ CUdevice cuda_device ;
93
148
ucs_status_t status ;
94
149
uct_cuda_ctx_rsc_t * ctx_rsc ;
95
150
96
- status = uct_cuda_base_ctx_rsc_get (& iface -> super , & ctx_rsc );
151
+ result = uct_cuda_base_ctx_get_id (NULL , & ctx_id );
152
+ if (ucs_likely (result == CUDA_SUCCESS )) {
153
+ /* If there is a current context, the CU_DEVICE_INVALID is returned in
154
+ cuda_device_p */
155
+ cuda_device = CU_DEVICE_INVALID ;
156
+ } else {
157
+ /* Otherwise, the first active primary context found is pushed as a
158
+ current context. The caller must pop, and release the primary context
159
+ on the device returned in cuda_device_p. */
160
+ status = uct_cuda_primary_ctx_push_first_active (& cuda_device );
161
+ if (status != UCS_OK ) {
162
+ goto err ;
163
+ }
164
+
165
+ result = uct_cuda_base_ctx_get_id (NULL , & ctx_id );
166
+ if (result != CUDA_SUCCESS ) {
167
+ UCT_CUDADRV_LOG (cuCtxGetId , UCS_LOG_LEVEL_ERROR , result );
168
+ status = UCS_ERR_IO_ERROR ;
169
+ goto err_pop_and_release ;
170
+ }
171
+ }
172
+
173
+ status = uct_cuda_base_ctx_rsc_get (& iface -> super , ctx_id , & ctx_rsc );
97
174
if (ucs_unlikely (status != UCS_OK )) {
98
- return status ;
175
+ goto err_pop_and_release ;
99
176
}
100
177
178
+ * cuda_device_p = cuda_device ;
101
179
* ctx_rsc_p = ucs_derived_of (ctx_rsc , uct_cuda_copy_ctx_rsc_t );
102
180
return UCS_OK ;
181
+
182
+ err_pop_and_release :
183
+ uct_cuda_primary_ctx_pop_and_release (cuda_device );
184
+ err :
185
+ return status ;
103
186
}
104
187
105
188
static UCS_F_ALWAYS_INLINE ucs_status_t
@@ -108,6 +191,7 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
108
191
{
109
192
uct_cuda_copy_iface_t * iface = ucs_derived_of (tl_ep -> iface , uct_cuda_copy_iface_t );
110
193
uct_base_iface_t * base_iface = ucs_derived_of (tl_ep -> iface , uct_base_iface_t );
194
+ CUdevice cuda_device ;
111
195
uct_cuda_event_desc_t * cuda_event ;
112
196
uct_cuda_queue_desc_t * q_desc ;
113
197
ucs_status_t status ;
@@ -121,9 +205,9 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
121
205
return UCS_OK ;
122
206
}
123
207
124
- status = uct_cuda_copy_ctx_rsc_get (iface , & ctx_rsc );
208
+ status = uct_cuda_copy_ctx_rsc_get (iface , & cuda_device , & ctx_rsc );
125
209
if (ucs_unlikely (status != UCS_OK )) {
126
- return status ;
210
+ goto out ;
127
211
}
128
212
129
213
src_type = uct_cuda_copy_get_mem_type (base_iface -> md , src , length );
@@ -135,25 +219,27 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
135
219
ucs_error ("stream for src %s dst %s not available" ,
136
220
ucs_memory_type_names [src_type ],
137
221
ucs_memory_type_names [dst_type ]);
138
- return UCS_ERR_IO_ERROR ;
222
+ status = UCS_ERR_IO_ERROR ;
223
+ goto out_pop_and_release ;
139
224
}
140
225
141
226
cuda_event = ucs_mpool_get (& ctx_rsc -> super .event_mp );
142
227
if (ucs_unlikely (cuda_event == NULL )) {
143
228
ucs_error ("failed to allocate cuda event object" );
144
- return UCS_ERR_NO_MEMORY ;
229
+ status = UCS_ERR_NO_MEMORY ;
230
+ goto out_pop_and_release ;
145
231
}
146
232
147
233
status = UCT_CUDADRV_FUNC_LOG_ERR (
148
234
cuMemcpyAsync ((CUdeviceptr )dst , (CUdeviceptr )src , length , * stream ));
149
235
if (ucs_unlikely (UCS_OK != status )) {
150
- return status ;
236
+ goto out_pop_and_release ;
151
237
}
152
238
153
239
status = UCT_CUDADRV_FUNC_LOG_ERR (
154
240
cuEventRecord (cuda_event -> event , * stream ));
155
241
if (ucs_unlikely (UCS_OK != status )) {
156
- return status ;
242
+ goto out_pop_and_release ;
157
243
}
158
244
159
245
if (ucs_queue_is_empty (event_q )) {
@@ -169,7 +255,12 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
169
255
ucs_trace ("cuda async issued: %p dst:%p[%s], src:%p[%s] len:%ld" ,
170
256
cuda_event , dst , ucs_memory_type_names [dst_type ], src ,
171
257
ucs_memory_type_names [src_type ], length );
172
- return UCS_INPROGRESS ;
258
+ status = UCS_INPROGRESS ;
259
+
260
+ out_pop_and_release :
261
+ uct_cuda_primary_ctx_pop_and_release (cuda_device );
262
+ out :
263
+ return status ;
173
264
}
174
265
175
266
UCS_PROFILE_FUNC (ucs_status_t , uct_cuda_copy_ep_get_zcopy ,
@@ -219,27 +310,33 @@ static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_rma_short(
219
310
{
220
311
uct_cuda_copy_iface_t * iface = ucs_derived_of (tl_ep -> iface ,
221
312
uct_cuda_copy_iface_t );
313
+ CUdevice cuda_device ;
222
314
uct_cuda_copy_ctx_rsc_t * ctx_rsc ;
223
315
ucs_status_t status ;
224
316
CUstream * stream ;
225
317
226
- status = uct_cuda_copy_ctx_rsc_get (iface , & ctx_rsc );
318
+ status = uct_cuda_copy_ctx_rsc_get (iface , & cuda_device , & ctx_rsc );
227
319
if (ucs_unlikely (status != UCS_OK )) {
228
- return status ;
320
+ goto out ;
229
321
}
230
322
231
323
stream = & ctx_rsc -> short_stream ;
232
324
status = uct_cuda_base_init_stream (stream );
233
325
if (ucs_unlikely (status != UCS_OK )) {
234
- return status ;
326
+ goto out_pop_and_release ;
235
327
}
236
328
237
329
status = UCT_CUDADRV_FUNC_LOG_ERR (cuMemcpyAsync (dst , src , length , * stream ));
238
330
if (ucs_unlikely (status != UCS_OK )) {
239
- return status ;
331
+ goto out_pop_and_release ;
240
332
}
241
333
242
- return UCT_CUDADRV_FUNC_LOG_ERR (cuStreamSynchronize (* stream ));
334
+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuStreamSynchronize (* stream ));
335
+
336
+ out_pop_and_release :
337
+ uct_cuda_primary_ctx_pop_and_release (cuda_device );
338
+ out :
339
+ return status ;
243
340
}
244
341
245
342
UCS_PROFILE_FUNC (ucs_status_t , uct_cuda_copy_ep_put_short ,
0 commit comments