Skip to content

Commit 97d0131

Browse files
authored
Merge pull request #10564 from rakhmets/topic/cuda-cpy-relax-ctx
UCT/CUDA/CUDA_COPY: Relaxed CUDA context dependency in cuda_copy transport.
2 parents 4fcffef + 2938b57 commit 97d0131

File tree

6 files changed

+189
-27
lines changed

6 files changed

+189
-27
lines changed

src/uct/cuda/base/cuda_iface.h

+2-9
Original file line numberDiff line numberDiff line change
@@ -250,18 +250,11 @@ ucs_status_t uct_cuda_primary_ctx_retain(CUdevice cuda_device, int force,
250250

251251

252252
static UCS_F_ALWAYS_INLINE ucs_status_t
253-
uct_cuda_base_ctx_rsc_get(uct_cuda_iface_t *iface, uct_cuda_ctx_rsc_t **ctx_rsc_p)
253+
uct_cuda_base_ctx_rsc_get(uct_cuda_iface_t *iface, unsigned long long ctx_id,
254+
uct_cuda_ctx_rsc_t **ctx_rsc_p)
254255
{
255-
unsigned long long ctx_id;
256-
CUresult result;
257256
khiter_t iter;
258257

259-
result = uct_cuda_base_ctx_get_id(NULL, &ctx_id);
260-
if (ucs_unlikely(result != CUDA_SUCCESS)) {
261-
UCT_CUDADRV_LOG(cuCtxGetId, UCS_LOG_LEVEL_ERROR, result);
262-
return UCS_ERR_IO_ERROR;
263-
}
264-
265258
iter = kh_get(cuda_ctx_rscs, &iface->ctx_rscs, ctx_id);
266259
if (ucs_likely(iter != kh_end(&iface->ctx_rscs))) {
267260
*ctx_rsc_p = kh_value(&iface->ctx_rscs, iter);

src/uct/cuda/cuda_copy/cuda_copy_ep.c

+113-16
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,102 @@ uct_cuda_copy_get_mem_type(uct_md_h md, void *address, size_t length)
8787
return mem_info.type;
8888
}
8989

90-
static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ctx_rsc_get(
91-
uct_cuda_copy_iface_t *iface, uct_cuda_copy_ctx_rsc_t **ctx_rsc_p)
90+
static ucs_status_t
91+
uct_cuda_primary_ctx_push_first_active(CUdevice *cuda_device_p)
9292
{
93+
int num_devices, device_index;
94+
ucs_status_t status;
95+
CUdevice cuda_device;
96+
CUcontext cuda_ctx;
97+
98+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGetCount(&num_devices));
99+
if (status != UCS_OK) {
100+
return status;
101+
}
102+
103+
for (device_index = 0; device_index < num_devices; ++device_index) {
104+
status = UCT_CUDADRV_FUNC_LOG_ERR(
105+
cuDeviceGet(&cuda_device, device_index));
106+
if (status != UCS_OK) {
107+
return status;
108+
}
109+
110+
status = uct_cuda_primary_ctx_retain(cuda_device, 0, &cuda_ctx);
111+
if (status == UCS_OK) {
112+
/* Found active primary context */
113+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxPushCurrent(cuda_ctx));
114+
if (status != UCS_OK) {
115+
UCT_CUDADRV_FUNC_LOG_WARN(
116+
cuDevicePrimaryCtxRelease(cuda_device));
117+
return status;
118+
}
119+
120+
*cuda_device_p = cuda_device;
121+
return UCS_OK;
122+
} else if (status != UCS_ERR_NO_DEVICE) {
123+
return status;
124+
}
125+
}
126+
127+
return UCS_ERR_NO_DEVICE;
128+
}
129+
130+
static UCS_F_ALWAYS_INLINE void
131+
uct_cuda_primary_ctx_pop_and_release(CUdevice cuda_device)
132+
{
133+
if (ucs_likely(cuda_device == CU_DEVICE_INVALID)) {
134+
return;
135+
}
136+
137+
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
138+
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
139+
}
140+
141+
static UCS_F_ALWAYS_INLINE ucs_status_t
142+
uct_cuda_copy_ctx_rsc_get(uct_cuda_copy_iface_t *iface, CUdevice *cuda_device_p,
143+
uct_cuda_copy_ctx_rsc_t **ctx_rsc_p)
144+
{
145+
unsigned long long ctx_id;
146+
CUresult result;
147+
CUdevice cuda_device;
93148
ucs_status_t status;
94149
uct_cuda_ctx_rsc_t *ctx_rsc;
95150

96-
status = uct_cuda_base_ctx_rsc_get(&iface->super, &ctx_rsc);
151+
result = uct_cuda_base_ctx_get_id(NULL, &ctx_id);
152+
if (ucs_likely(result == CUDA_SUCCESS)) {
153+
/* If there is a current context, the CU_DEVICE_INVALID is returned in
154+
cuda_device_p */
155+
cuda_device = CU_DEVICE_INVALID;
156+
} else {
157+
/* Otherwise, the first active primary context found is pushed as a
158+
current context. The caller must pop, and release the primary context
159+
on the device returned in cuda_device_p. */
160+
status = uct_cuda_primary_ctx_push_first_active(&cuda_device);
161+
if (status != UCS_OK) {
162+
goto err;
163+
}
164+
165+
result = uct_cuda_base_ctx_get_id(NULL, &ctx_id);
166+
if (result != CUDA_SUCCESS) {
167+
UCT_CUDADRV_LOG(cuCtxGetId, UCS_LOG_LEVEL_ERROR, result);
168+
status = UCS_ERR_IO_ERROR;
169+
goto err_pop_and_release;
170+
}
171+
}
172+
173+
status = uct_cuda_base_ctx_rsc_get(&iface->super, ctx_id, &ctx_rsc);
97174
if (ucs_unlikely(status != UCS_OK)) {
98-
return status;
175+
goto err_pop_and_release;
99176
}
100177

178+
*cuda_device_p = cuda_device;
101179
*ctx_rsc_p = ucs_derived_of(ctx_rsc, uct_cuda_copy_ctx_rsc_t);
102180
return UCS_OK;
181+
182+
err_pop_and_release:
183+
uct_cuda_primary_ctx_pop_and_release(cuda_device);
184+
err:
185+
return status;
103186
}
104187

105188
static UCS_F_ALWAYS_INLINE ucs_status_t
@@ -108,6 +191,7 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
108191
{
109192
uct_cuda_copy_iface_t *iface = ucs_derived_of(tl_ep->iface, uct_cuda_copy_iface_t);
110193
uct_base_iface_t *base_iface = ucs_derived_of(tl_ep->iface, uct_base_iface_t);
194+
CUdevice cuda_device;
111195
uct_cuda_event_desc_t *cuda_event;
112196
uct_cuda_queue_desc_t *q_desc;
113197
ucs_status_t status;
@@ -121,9 +205,9 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
121205
return UCS_OK;
122206
}
123207

124-
status = uct_cuda_copy_ctx_rsc_get(iface, &ctx_rsc);
208+
status = uct_cuda_copy_ctx_rsc_get(iface, &cuda_device, &ctx_rsc);
125209
if (ucs_unlikely(status != UCS_OK)) {
126-
return status;
210+
goto out;
127211
}
128212

129213
src_type = uct_cuda_copy_get_mem_type(base_iface->md, src, length);
@@ -135,25 +219,27 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
135219
ucs_error("stream for src %s dst %s not available",
136220
ucs_memory_type_names[src_type],
137221
ucs_memory_type_names[dst_type]);
138-
return UCS_ERR_IO_ERROR;
222+
status = UCS_ERR_IO_ERROR;
223+
goto out_pop_and_release;
139224
}
140225

141226
cuda_event = ucs_mpool_get(&ctx_rsc->super.event_mp);
142227
if (ucs_unlikely(cuda_event == NULL)) {
143228
ucs_error("failed to allocate cuda event object");
144-
return UCS_ERR_NO_MEMORY;
229+
status = UCS_ERR_NO_MEMORY;
230+
goto out_pop_and_release;
145231
}
146232

147233
status = UCT_CUDADRV_FUNC_LOG_ERR(
148234
cuMemcpyAsync((CUdeviceptr)dst, (CUdeviceptr)src, length, *stream));
149235
if (ucs_unlikely(UCS_OK != status)) {
150-
return status;
236+
goto out_pop_and_release;
151237
}
152238

153239
status = UCT_CUDADRV_FUNC_LOG_ERR(
154240
cuEventRecord(cuda_event->event, *stream));
155241
if (ucs_unlikely(UCS_OK != status)) {
156-
return status;
242+
goto out_pop_and_release;
157243
}
158244

159245
if (ucs_queue_is_empty(event_q)) {
@@ -169,7 +255,12 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
169255
ucs_trace("cuda async issued: %p dst:%p[%s], src:%p[%s] len:%ld",
170256
cuda_event, dst, ucs_memory_type_names[dst_type], src,
171257
ucs_memory_type_names[src_type], length);
172-
return UCS_INPROGRESS;
258+
status = UCS_INPROGRESS;
259+
260+
out_pop_and_release:
261+
uct_cuda_primary_ctx_pop_and_release(cuda_device);
262+
out:
263+
return status;
173264
}
174265

175266
UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_copy_ep_get_zcopy,
@@ -219,27 +310,33 @@ static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_rma_short(
219310
{
220311
uct_cuda_copy_iface_t *iface = ucs_derived_of(tl_ep->iface,
221312
uct_cuda_copy_iface_t);
313+
CUdevice cuda_device;
222314
uct_cuda_copy_ctx_rsc_t *ctx_rsc;
223315
ucs_status_t status;
224316
CUstream *stream;
225317

226-
status = uct_cuda_copy_ctx_rsc_get(iface, &ctx_rsc);
318+
status = uct_cuda_copy_ctx_rsc_get(iface, &cuda_device, &ctx_rsc);
227319
if (ucs_unlikely(status != UCS_OK)) {
228-
return status;
320+
goto out;
229321
}
230322

231323
stream = &ctx_rsc->short_stream;
232324
status = uct_cuda_base_init_stream(stream);
233325
if (ucs_unlikely(status != UCS_OK)) {
234-
return status;
326+
goto out_pop_and_release;
235327
}
236328

237329
status = UCT_CUDADRV_FUNC_LOG_ERR(cuMemcpyAsync(dst, src, length, *stream));
238330
if (ucs_unlikely(status != UCS_OK)) {
239-
return status;
331+
goto out_pop_and_release;
240332
}
241333

242-
return UCT_CUDADRV_FUNC_LOG_ERR(cuStreamSynchronize(*stream));
334+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuStreamSynchronize(*stream));
335+
336+
out_pop_and_release:
337+
uct_cuda_primary_ctx_pop_and_release(cuda_device);
338+
out:
339+
return status;
243340
}
244341

245342
UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_copy_ep_put_short,

src/uct/cuda/cuda_ipc/cuda_ipc_ep.c

+9-1
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,18 @@ int uct_cuda_ipc_ep_is_connected(const uct_ep_h tl_ep,
6060
static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_ipc_ctx_rsc_get(
6161
uct_cuda_ipc_iface_t *iface, uct_cuda_ipc_ctx_rsc_t **ctx_rsc_p)
6262
{
63+
unsigned long long ctx_id;
64+
CUresult result;
6365
ucs_status_t status;
6466
uct_cuda_ctx_rsc_t *ctx_rsc;
6567

66-
status = uct_cuda_base_ctx_rsc_get(&iface->super, &ctx_rsc);
68+
result = uct_cuda_base_ctx_get_id(NULL, &ctx_id);
69+
if (ucs_unlikely(result != CUDA_SUCCESS)) {
70+
UCT_CUDADRV_LOG(cuCtxGetId, UCS_LOG_LEVEL_ERROR, result);
71+
return UCS_ERR_IO_ERROR;
72+
}
73+
74+
status = uct_cuda_base_ctx_rsc_get(&iface->super, ctx_id, &ctx_rsc);
6775
if (ucs_unlikely(status != UCS_OK)) {
6876
return status;
6977
}

test/gtest/uct/cuda/test_switch_cuda_device.cc

+56
Original file line numberDiff line numberDiff line change
@@ -428,3 +428,59 @@ UCS_TEST_P(test_mem_alloc_device, same_device_cuda_fabric_implicit,
428428
}
429429

430430
_UCT_MD_INSTANTIATE_TEST_CASE(test_mem_alloc_device, cuda_cpy);
431+
432+
class test_p2p_no_current_cuda_ctx : public uct_p2p_rma_test {
433+
public:
434+
void test_xfer_on_thread(send_func_t send, size_t length, unsigned flags);
435+
};
436+
437+
void test_p2p_no_current_cuda_ctx::test_xfer_on_thread(send_func_t send,
438+
size_t length,
439+
unsigned flags)
440+
{
441+
mapped_buffer sendbuf(length, SEED1, sender());
442+
mapped_buffer recvbuf(length, SEED2, receiver(), 0, UCS_MEMORY_TYPE_CUDA);
443+
444+
std::exception_ptr thread_exception;
445+
std::thread([&]() {
446+
try {
447+
blocking_send(send, sender_ep(), sendbuf, recvbuf, true);
448+
} catch (...) {
449+
thread_exception = std::current_exception();
450+
}
451+
}).join();
452+
453+
if (thread_exception) {
454+
std::rethrow_exception(thread_exception);
455+
}
456+
457+
check_buf(sendbuf, recvbuf, flags);
458+
}
459+
460+
UCS_TEST_P(test_p2p_no_current_cuda_ctx, put_short)
461+
{
462+
test_xfer_on_thread(static_cast<send_func_t>(&uct_p2p_rma_test::put_short),
463+
1, TEST_UCT_FLAG_SEND_ZCOPY);
464+
}
465+
466+
UCS_TEST_P(test_p2p_no_current_cuda_ctx, get_short)
467+
{
468+
test_xfer_on_thread(static_cast<send_func_t>(&uct_p2p_rma_test::get_short),
469+
1, TEST_UCT_FLAG_RECV_ZCOPY);
470+
}
471+
472+
UCS_TEST_P(test_p2p_no_current_cuda_ctx, put_zcopy)
473+
{
474+
test_xfer_on_thread(static_cast<send_func_t>(&uct_p2p_rma_test::put_zcopy),
475+
sender().iface_attr().cap.put.min_zcopy + 1,
476+
TEST_UCT_FLAG_SEND_ZCOPY);
477+
}
478+
479+
UCS_TEST_P(test_p2p_no_current_cuda_ctx, get_zcopy)
480+
{
481+
test_xfer_on_thread(static_cast<send_func_t>(&uct_p2p_rma_test::get_zcopy),
482+
sender().iface_attr().cap.get.min_zcopy + 1,
483+
TEST_UCT_FLAG_RECV_ZCOPY);
484+
}
485+
486+
_UCT_INSTANTIATE_TEST_CASE(test_p2p_no_current_cuda_ctx, cuda_copy)

test/gtest/uct/test_p2p_rma.cc

+6-1
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,13 @@ void uct_p2p_rma_test::test_xfer(send_func_t send, size_t length,
7878

7979
mapped_buffer sendbuf(length, SEED1, sender(), 1, src_mem_type);
8080
mapped_buffer recvbuf(length, SEED2, receiver(), 3, mem_type);
81-
8281
blocking_send(send, sender_ep(), sendbuf, recvbuf, true);
82+
check_buf(sendbuf, recvbuf, flags);
83+
}
84+
85+
void uct_p2p_rma_test::check_buf(mapped_buffer &sendbuf, mapped_buffer &recvbuf,
86+
unsigned flags)
87+
{
8388
if (flags & TEST_UCT_FLAG_SEND_ZCOPY) {
8489
sendbuf.memset(0);
8590
wait_for_remote();

test/gtest/uct/test_p2p_rma.h

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class uct_p2p_rma_test : public uct_p2p_test {
3737

3838
virtual void test_xfer(send_func_t send, size_t length,
3939
unsigned flags, ucs_memory_type_t mem_type);
40+
41+
void
42+
check_buf(mapped_buffer &sendbuf, mapped_buffer &recvbuf, unsigned flags);
4043
};
4144

4245
#endif

0 commit comments

Comments
 (0)