Skip to content

Commit 64dd174

Browse files
committed
ch4/ofi: use MPIDI_OFI_rdmaread_poll in MPIDI_OFI_mirror_get
A direct fi_read does not work when the origin buffer is a device buffer or uses a non-contiguous datatype. The asynchronous operation MPIDI_OFI_rdmaread_poll performs a pipelined read that supports both device buffers and non-contiguous datatypes.
1 parent 17b563c commit 64dd174

File tree

4 files changed

+40
-41
lines changed

4 files changed

+40
-41
lines changed

src/mpid/ch4/netmod/ofi/ofi_events.c

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -479,10 +479,6 @@ int MPIDI_OFI_dispatch_function(int vci, struct fi_cq_tagged_entry *wc, MPIR_Req
479479
mpi_errno = MPIDI_OFI_rndvwrite_ack_event(wc, req);
480480
break;
481481

482-
case MPIDI_OFI_EVENT_MIRROR_READ:
483-
mpi_errno = MPIDI_OFI_mirror_read_event(wc, req);
484-
break;
485-
486482
case MPIDI_OFI_EVENT_CHUNK_DONE:
487483
mpi_errno = chunk_done_event(vci, wc, req);
488484
break;

src/mpid/ch4/netmod/ofi/ofi_events.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ int MPIDI_OFI_rndvread_ack_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r
2525
int MPIDI_OFI_rndvwrite_recv_mrs_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r);
2626
int MPIDI_OFI_rndvwrite_write_chunk_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r);
2727
int MPIDI_OFI_rndvwrite_ack_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r);
28-
int MPIDI_OFI_mirror_read_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r);
2928

3029
MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_cqe_get_source(struct fi_cq_tagged_entry *wc, bool has_err)
3130
{

src/mpid/ch4/netmod/ofi/ofi_rma.c

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ struct get_context {
510510
MPI_Aint data_sz;
511511
void *origin_addr;
512512
MPI_Aint target_offset;
513+
MPIR_Request *req;
513514
};
514515

515516
struct get_hdr {
@@ -554,6 +555,14 @@ int MPIDI_OFI_mirror_get(void *origin_addr, MPI_Aint origin_count, MPI_Datatype
554555
origin_context->origin_addr = (char *) origin_addr + origin_true_lb;
555556
origin_context->target_offset = target_disp * win->disp_unit + target_true_lb;
556557

558+
/* allocate a request so that the code from ofi_rndv_read can be reused. */
559+
MPIR_Request *req;
560+
MPIDI_OFI_REQUEST_CREATE(req, MPIR_REQUEST_KIND__RMA, vci);
561+
if (1) {
562+
MPIDI_CH4_REQUEST_FREE(req);
563+
}
564+
origin_context->req = req;
565+
557566
/* fill am_hdr */
558567
struct get_hdr am_hdr;
559568
am_hdr.win_id = MPIDIG_WIN(win, win_id);
@@ -638,6 +647,7 @@ int MPIDI_OFI_get_handler(void *am_hdr, void *data, MPI_Aint in_data_sz,
638647
struct getack_hdr {
639648
void *origin_context;
640649
uint64_t rkey;
650+
uint64_t remote_base;
641651
};
642652

643653
/* target side - async callback */
@@ -652,6 +662,7 @@ static int target_mirror_copy_poll(MPIX_Async_thing thing)
652662
struct getack_hdr am_hdr;
653663
am_hdr.origin_context = p->origin_context;
654664
am_hdr.rkey = fi_mr_key(MPIDI_OFI_WIN(p->win).mr);
665+
am_hdr.remote_base = (uintptr_t) MPIDI_OFI_WIN(p->win).mirror_buf;
655666

656667
int rc = MPIDI_NM_am_send_hdr(p->origin_rank, p->win->comm_ptr, MPIDI_OFI_GET_ACK,
657668
&am_hdr, sizeof(am_hdr), p->vci_target, p->vci_origin);
@@ -672,6 +683,8 @@ struct read_req {
672683
struct get_context *origin_context;
673684
};
674685

686+
static int rdmaread_completion(void *context);
687+
675688
/* origin side - AM callback */
676689
int MPIDI_OFI_getack_handler(void *am_hdr, void *data, MPI_Aint in_data_sz,
677690
uint32_t attr, MPIR_Request ** req)
@@ -682,56 +695,48 @@ int MPIDI_OFI_getack_handler(void *am_hdr, void *data, MPI_Aint in_data_sz,
682695
struct get_context *origin_context = msg_hdr->origin_context;
683696
MPIR_Win *win = origin_context->win;
684697
int target_rank = origin_context->target_rank;
685-
void *buf = origin_context->origin_addr;
686-
MPI_Aint len = origin_context->data_sz;
687698
MPI_Aint target_offset = origin_context->target_offset;
688699

689-
/* rdma read */
690-
/* FIXME: potentially we need allocate a staging buffer */
691-
void *desc = NULL;
692-
int nic_target = MPIDI_OFI_get_pref_nic(win->comm_ptr, target_rank);
693-
694-
MPL_pointer_attr_t bufattr;
695-
MPIR_GPU_query_pointer_attr(buf, &bufattr);
696-
if (MPL_gpu_attr_is_strict_dev(&bufattr)) {
697-
MPIDI_OFI_gpu_rma_register(buf, len, &bufattr, win, nic_target, &desc);
698-
}
699-
700-
int vci = MPIDI_WIN(win, am_vci);
701-
int vci_target = MPIDI_WIN_TARGET_VCI(win, target_rank);
700+
MPIDI_OFI_rndvread_t *p = &MPIDI_OFI_AMREQ_READ(origin_context->req);
701+
p->buf = origin_context->origin_addr;
702+
p->count = origin_context->data_sz;
703+
p->datatype = MPIR_BYTE_INTERNAL;
702704

703-
MPIDI_av_entry_t *av = MPIDIU_win_rank_to_av(win, target_rank, MPIDI_WIN(win, winattr));
704-
fi_addr_t addr = MPIDI_OFI_av_to_phys(av, vci, 0, vci_target, nic_target);
705+
MPIR_GPU_query_pointer_attr(p->buf, &p->attr);
706+
p->need_pack = MPL_gpu_attr_is_dev(&p->attr);
705707

706-
struct read_req *r = MPL_malloc(sizeof(struct read_req), MPL_MEM_OTHER);
707-
MPIR_ERR_CHKANDJUMP(!r, mpi_errno, MPI_ERR_OTHER, "**nomem");
708+
p->data_sz = p->remote_data_sz = origin_context->data_sz;
709+
p->vci_local = MPIDI_WIN(win, am_vci);
710+
p->vci_remote = MPIDI_WIN_TARGET_VCI(win, target_rank);
711+
p->av = MPIDIU_win_rank_to_av(win, target_rank, MPIDI_WIN(win, winattr));
708712

709-
r->event_id = MPIDI_OFI_EVENT_MIRROR_READ;
710-
r->origin_context = origin_context;
713+
p->num_nics = 1;
714+
if (MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) {
715+
p->u.recv.remote_base = msg_hdr->remote_base + target_offset;
716+
} else {
717+
p->u.recv.remote_base = target_offset;
718+
}
719+
p->u.recv.rkeys = &p->u.recv.rkey0;
720+
p->u.recv.rkey0 = msg_hdr->rkey;
721+
p->u.recv.cmpl_cb = rdmaread_completion;
722+
p->u.recv.context = origin_context;
711723

712-
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci));
713-
MPIDI_OFI_CALL_RETRY(fi_read(MPIDI_OFI_WIN(win).ep, buf, len, desc,
714-
addr, target_offset, msg_hdr->rkey, (void *) &r->context),
715-
vci, rdma_readfrom);
716-
/* Complete signal request to inform completion to user. */
717-
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci));
724+
mpi_errno = MPIR_Async_things_add(MPIDI_OFI_rdmaread_poll, origin_context->req, NULL);
718725

719-
fn_exit:
720726
return mpi_errno;
721-
fn_fail:
722-
goto fn_exit;
723727
}
724728

725-
/* origin side - rdma event callback */
726-
int MPIDI_OFI_mirror_read_event(struct fi_cq_tagged_entry *wc, MPIR_Request * req)
729+
static int rdmaread_completion(void *context)
727730
{
728-
struct read_req *r = (void *) req;
729-
struct get_context *origin_context = r->origin_context;
731+
struct get_context *origin_context = context;
732+
730733
MPIR_Win *win = origin_context->win;
731734
int target_rank = origin_context->target_rank;
732735

733736
MPIDIG_win_cmpl_cnts_decr(win, target_rank);
734737

738+
MPIDI_Request_complete_fast(origin_context->req);
735739
MPL_free(origin_context);
736-
MPL_free(r);
740+
741+
return MPI_SUCCESS;
737742
}

src/mpid/ch4/netmod/ofi/ofi_types.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ enum {
208208
MPIDI_OFI_EVENT_RNDVWRITE_RECV_MRS,
209209
MPIDI_OFI_EVENT_RNDVWRITE_WRITE_CHUNK,
210210
MPIDI_OFI_EVENT_RNDVWRITE_ACK,
211-
MPIDI_OFI_EVENT_MIRROR_READ,
212211
MPIDI_OFI_EVENT_CHUNK_DONE,
213212
MPIDI_OFI_EVENT_INJECT_EMU,
214213
MPIDI_OFI_EVENT_DYNPROC_DONE,

0 commit comments

Comments
 (0)