diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index 625021a284e..09dc58266db 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -298,6 +298,15 @@ int MPIDI_OFI_am_rdma_read_ack_handler(void *am_hdr, void *data, MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req); int MPIDI_OFI_rndv_info_handler(void *am_hdr, void *data, MPI_Aint data_sz, uint32_t attr, MPIR_Request ** req); + +int MPIDI_OFI_mirror_get(void *origin_addr, MPI_Aint origin_count, MPI_Datatype origin_datatype, + int target_rank, MPI_Aint target_disp, MPI_Aint target_count, + MPI_Datatype target_datatype, MPIR_Win * win); +int MPIDI_OFI_get_handler(void *am_hdr, void *data, MPI_Aint data_sz, + uint32_t attr, MPIR_Request ** req); +int MPIDI_OFI_getack_handler(void *am_hdr, void *data, MPI_Aint data_sz, + uint32_t attr, MPIR_Request ** req); + int MPIDI_OFI_control_dispatch(void *buf); void MPIDI_OFI_index_datatypes(struct fid_ep *ep); int MPIDI_OFI_mr_key_allocator_init(void); @@ -307,6 +316,8 @@ void MPIDI_OFI_mr_key_allocator_destroy(void); int MPIDI_OFI_datatype_to_ofi(MPI_Datatype dt, enum fi_datatype *fi_dt); int MPIDI_OFI_op_to_ofi(MPI_Op op, enum fi_op *fi_op); +int MPIDI_OFI_rdmaread_poll(MPIX_Async_thing thing); + /* RMA */ #define MPIDI_OFI_INIT_CHUNK_CONTEXT(win,sigreq) \ do { \ diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 386d7c74735..41feb232fdc 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -1573,6 +1573,8 @@ static int am_init(int vci) if (vci == 0) { MPIDIG_am_reg_cb(MPIDI_OFI_AM_RDMA_READ_ACK, NULL, &MPIDI_OFI_am_rdma_read_ack_handler); MPIDIG_am_reg_cb(MPIDI_OFI_RNDV_INFO, NULL, &MPIDI_OFI_rndv_info_handler); + MPIDIG_am_reg_cb(MPIDI_OFI_GET_REQ, NULL, &MPIDI_OFI_get_handler); + MPIDIG_am_reg_cb(MPIDI_OFI_GET_ACK, NULL, &MPIDI_OFI_getack_handler); } } diff --git a/src/mpid/ch4/netmod/ofi/ofi_pre.h 
b/src/mpid/ch4/netmod/ofi/ofi_pre.h index c1f09ca7745..8dd404ac433 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_pre.h +++ b/src/mpid/ch4/netmod/ofi/ofi_pre.h @@ -234,10 +234,12 @@ typedef struct { typedef struct { MPIDI_OFI_RNDV_COMMON_FIELDS; + int num_nics; MPI_Aint sz_per_nic; union { struct { const void *data; + struct fid_mr *mr0; struct fid_mr **mrs; } send; struct { @@ -246,11 +248,14 @@ typedef struct { int copy_infly; /* need_pack */ } u; uint64_t remote_base; + uint64_t rkey0; /* avoid malloc when num_nics == 1 */ uint64_t *rkeys; MPI_Aint chunks_per_nic; MPI_Aint cur_chunk_index; int num_infly; bool all_issued; + int (*cmpl_cb) (void *context); /* context will be cast to (MPIR_Request *) */ + void *context; } recv; } u; } MPIDI_OFI_rndvread_t; @@ -381,6 +386,9 @@ typedef struct { struct MPIDI_OFI_win_request *syncQ; struct MPIDI_OFI_win_request *deferredQ; MPIDI_OFI_win_targetinfo_t *winfo; + void *mirror_buf; /* used in gpu fallback paths to avoid repeated host registration */ + MPL_pointer_attr_t base_attr; + MPL_pointer_attr_t mirror_attr; MPL_gavl_tree_t *dwin_target_mrs; /* MR key and address pairs registered to remote processes. * One AVL tree per process. */ diff --git a/src/mpid/ch4/netmod/ofi/ofi_rma.c b/src/mpid/ch4/netmod/ofi/ofi_rma.c index 9c1392c06b9..2106acb36b9 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_rma.c +++ b/src/mpid/ch4/netmod/ofi/ofi_rma.c @@ -488,3 +488,255 @@ int MPIDI_OFI_issue_deferred_rma(MPIR_Win * win) fn_fail: goto fn_exit; } + +/* -- active message fallback using mirror buffers -- */ + +/* assumptions: + * 1. both origin and target datatypes are contig + * 2. data_sz <= MPIDI_OFI_global.max_msg_size + */ + +/* Get using AM mirror buffer - + * 1. Origin send am MPIDI_OFI_GET_REQ + * 2. Target async localcopy to mirror buffer + * 3. Target send am MPIDI_OFI_GET_ACK + * 4. Origin RDMA read + * 5. 
Origin complete + */ + +struct get_context { + MPIR_Win *win; + int target_rank; + MPI_Aint data_sz; + void *origin_addr; + MPI_Aint target_offset; + MPIR_Request *req; +}; + +struct get_hdr { + uint64_t win_id; + int origin_rank; + void *origin_context; + MPI_Aint target_offset; + MPI_Aint data_sz; +}; + +/* origin side - issue AM req */ +int MPIDI_OFI_mirror_get(void *origin_addr, MPI_Aint origin_count, MPI_Datatype origin_datatype, + int target_rank, MPI_Aint target_disp, MPI_Aint target_count, + MPI_Datatype target_datatype, MPIR_Win * win) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_FUNC_ENTER; + + /* query target datatype */ + int is_contig; + MPIR_Datatype_is_contig(target_datatype, &is_contig); + + MPI_Aint data_sz; + MPIR_Datatype_get_size_macro(origin_datatype, data_sz); + data_sz *= origin_count; + + MPI_Aint origin_true_lb, target_true_lb; + MPIR_Datatype_get_true_lb(target_datatype, &target_true_lb); + MPIR_Datatype_get_true_lb(origin_datatype, &origin_true_lb); + + int vci = MPIDI_WIN(win, am_vci); + int vci_target = MPIDI_WIN_TARGET_VCI(win, target_rank); + + /* fill origin context */ + struct get_context *origin_context; + origin_context = MPL_malloc(sizeof(struct get_context), MPL_MEM_OTHER); + MPIR_ERR_CHKANDJUMP((origin_context == NULL), mpi_errno, MPI_ERR_OTHER, "**nomem"); + + origin_context->win = win; + origin_context->target_rank = target_rank; + origin_context->data_sz = data_sz; + origin_context->origin_addr = (char *) origin_addr + origin_true_lb; + origin_context->target_offset = target_disp * win->disp_unit + target_true_lb; + + /* allocate a request, used for reuse the code from ofi_rndv_read. 
*/ + MPIR_Request *req; + MPIDI_OFI_REQUEST_CREATE(req, MPIR_REQUEST_KIND__RMA, vci); + if (1) { + MPIDI_CH4_REQUEST_FREE(req); + } + origin_context->req = req; + + /* fill am_hdr */ + struct get_hdr am_hdr; + am_hdr.win_id = MPIDIG_WIN(win, win_id); + am_hdr.origin_rank = win->comm_ptr->rank; + am_hdr.origin_context = origin_context; + am_hdr.data_sz = origin_context->data_sz; + am_hdr.target_offset = origin_context->target_offset; + + MPIDIG_win_cmpl_cnts_incr(win, target_rank, NULL); + + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); + mpi_errno = MPIDI_NM_am_send_hdr(target_rank, win->comm_ptr, MPIDI_OFI_GET_REQ, + &am_hdr, sizeof(am_hdr), vci, vci_target); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); + MPIR_ERR_CHECK(mpi_errno); + + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + goto fn_exit; +} + +struct target_mirror_copy { + MPIR_Win *win; + int origin_rank; + void *origin_context; + int vci_origin; + int vci_target; + MPIR_gpu_req async_req; +}; + +static int target_mirror_copy_poll(MPIX_Async_thing thing); + +/* target side - AM callback */ +int MPIDI_OFI_get_handler(void *am_hdr, void *data, MPI_Aint in_data_sz, + uint32_t attr, MPIR_Request ** req) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_FUNC_ENTER; + + struct get_hdr *msg_hdr = am_hdr; + + MPIR_Win *win; + win = (MPIR_Win *) MPIDIU_map_lookup(MPIDI_global.win_map, msg_hdr->win_id); + MPIR_Assert(win); + + void *mirror_buf = MPIDI_OFI_WIN(win).mirror_buf; + void *mirror_attr = &(MPIDI_OFI_WIN(win).mirror_attr); + void *base_buf = win->base; + void *base_attr = &(MPIDI_OFI_WIN(win).base_attr); + + /* async localcopy */ + MPIR_gpu_req async_req; + int engine_type = MPIDI_OFI_gpu_get_send_engine_type(); + mpi_errno = MPIR_Ilocalcopy_gpu(base_buf, msg_hdr->data_sz, MPIR_BYTE_INTERNAL, + msg_hdr->target_offset, base_attr, + mirror_buf, msg_hdr->data_sz, MPIR_BYTE_INTERNAL, + msg_hdr->target_offset, mirror_attr, + engine_type, 1, &async_req); + MPIR_ERR_CHECK(mpi_errno); + + /* add async 
things */ + struct target_mirror_copy *p = MPL_malloc(sizeof(struct target_mirror_copy), MPL_MEM_OTHER); + p->win = win; + p->origin_rank = msg_hdr->origin_rank; + p->origin_context = msg_hdr->origin_context; + p->vci_origin = MPIDIG_AM_ATTR_SRC_VCI(attr); + p->vci_target = MPIDIG_AM_ATTR_DST_VCI(attr); + p->async_req = async_req; + + mpi_errno = MPIR_Async_things_add(target_mirror_copy_poll, p, NULL); + MPIR_ERR_CHECK(mpi_errno); + + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + goto fn_exit; +} + +struct getack_hdr { + void *origin_context; + uint64_t rkey; + uint64_t remote_base; +}; + +/* target side - async callback */ +static int target_mirror_copy_poll(MPIX_Async_thing thing) +{ + struct target_mirror_copy *p = MPIR_Async_thing_get_state(thing); + int is_done; + MPIR_async_test(&p->async_req, &is_done); + + if (is_done) { + /* send get_ack */ + struct getack_hdr am_hdr; + am_hdr.origin_context = p->origin_context; + am_hdr.rkey = fi_mr_key(MPIDI_OFI_WIN(p->win).mr); + am_hdr.remote_base = (uintptr_t) MPIDI_OFI_WIN(p->win).mirror_buf; + + int rc = MPIDI_NM_am_send_hdr(p->origin_rank, p->win->comm_ptr, MPIDI_OFI_GET_ACK, + &am_hdr, sizeof(am_hdr), p->vci_target, p->vci_origin); + MPIR_Assertp(rc == MPI_SUCCESS); + + MPL_free(p); + + return MPIX_ASYNC_DONE; + } + + return MPIX_ASYNC_NOPROGRESS; +} + +struct read_req { + char pad[MPIDI_REQUEST_HDR_SIZE]; + struct fi_context context[MPIDI_OFI_CONTEXT_STRUCTS]; + int event_id; + struct get_context *origin_context; +}; + +static int rdmaread_completion(void *context); + +/* origin side - AM callback */ +int MPIDI_OFI_getack_handler(void *am_hdr, void *data, MPI_Aint in_data_sz, + uint32_t attr, MPIR_Request ** req) +{ + int mpi_errno = MPI_SUCCESS; + + struct getack_hdr *msg_hdr = am_hdr; + struct get_context *origin_context = msg_hdr->origin_context; + MPIR_Win *win = origin_context->win; + int target_rank = origin_context->target_rank; + MPI_Aint target_offset = origin_context->target_offset; + + 
MPIDI_OFI_rndvread_t *p = &MPIDI_OFI_AMREQ_READ(origin_context->req); + p->buf = origin_context->origin_addr; + p->count = origin_context->data_sz; + p->datatype = MPIR_BYTE_INTERNAL; + + MPIR_GPU_query_pointer_attr(p->buf, &p->attr); + p->need_pack = MPL_gpu_attr_is_dev(&p->attr); + + p->data_sz = p->remote_data_sz = origin_context->data_sz; + p->vci_local = MPIDI_WIN(win, am_vci); + p->vci_remote = MPIDI_WIN_TARGET_VCI(win, target_rank); + p->av = MPIDIU_win_rank_to_av(win, target_rank, MPIDI_WIN(win, winattr)); + + p->num_nics = 1; + if (MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { + p->u.recv.remote_base = msg_hdr->remote_base + target_offset; + } else { + p->u.recv.remote_base = target_offset; + } + p->u.recv.rkeys = &p->u.recv.rkey0; + p->u.recv.rkey0 = msg_hdr->rkey; + p->u.recv.cmpl_cb = rdmaread_completion; + p->u.recv.context = origin_context; + + mpi_errno = MPIR_Async_things_add(MPIDI_OFI_rdmaread_poll, origin_context->req, NULL); + + return mpi_errno; +} + +static int rdmaread_completion(void *context) +{ + struct get_context *origin_context = context; + + MPIR_Win *win = origin_context->win; + int target_rank = origin_context->target_rank; + + MPIDIG_win_cmpl_cnts_decr(win, target_rank); + + MPIDI_Request_complete_fast(origin_context->req); + MPL_free(origin_context); + + return MPI_SUCCESS; +} diff --git a/src/mpid/ch4/netmod/ofi/ofi_rma.h b/src/mpid/ch4/netmod/ofi/ofi_rma.h index 1f8f38e2a71..c56dceea095 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_rma.h +++ b/src/mpid/ch4/netmod/ofi/ofi_rma.h @@ -525,14 +525,43 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_get(void *origin_addr, MPIDI_winattr_t winattr) { int mpi_errno = MPI_SUCCESS; - MPIR_FUNC_ENTER; + /* TODO: move pre-checks to ch4_rma */ + MPIDIG_RMA_OP_CHECK_SYNC(target_rank, win); + + /* check early exit */ + if (target_count == 0) + goto fn_exit; + + if (target_rank == win->comm_ptr->rank) { + MPI_Aint offset; + offset = win->disp_unit * target_disp; + mpi_errno = MPIR_Localcopy((char *) win->base + 
offset, target_count, target_datatype, + origin_addr, origin_count, origin_datatype); + MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + } + if (!MPIDI_OFI_ENABLE_RMA || !(winattr & MPIDI_WINATTR_NM_REACHABLE) || !MPIDI_OFI_gpu_rma_enabled(origin_addr)) { - MPIDI_OFI_register_am_bufs(); - mpi_errno = MPIDIG_mpi_get(origin_addr, origin_count, origin_datatype, target_rank, - target_disp, target_count, target_datatype, win); + MPI_Aint data_sz; + MPIDI_Datatype_check_size(target_datatype, target_count, data_sz); + bool good_size = (data_sz >= MPIDI_NM_am_eager_limit()); + int origin_is_contig, target_is_contig; + MPIR_Datatype_is_contig(origin_datatype, &origin_is_contig); + MPIR_Datatype_is_contig(target_datatype, &target_is_contig); + /* for now, only optimize for large contig data */ + if (origin_is_contig && target_is_contig && good_size && MPIR_CVAR_OFI_ENABLE_WIN_MIRROR) { + /* use mirror_buf optimization */ + mpi_errno = MPIDI_OFI_mirror_get(origin_addr, origin_count, origin_datatype, + target_rank, + target_disp, target_count, target_datatype, win); + } else { + MPIDI_OFI_register_am_bufs(); + mpi_errno = MPIDIG_mpi_get(origin_addr, origin_count, origin_datatype, target_rank, + target_disp, target_count, target_datatype, win); + } goto fn_exit; } @@ -546,6 +575,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_get(void *origin_addr, fn_exit: MPIR_FUNC_EXIT; return mpi_errno; + fn_fail: + goto fn_exit; } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_rput(const void *origin_addr, diff --git a/src/mpid/ch4/netmod/ofi/ofi_rndv_read.c b/src/mpid/ch4/netmod/ofi/ofi_rndv_read.c index 8b4fa5f8b1d..575483f19cf 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_rndv_read.c +++ b/src/mpid/ch4/netmod/ofi/ofi_rndv_read.c @@ -12,7 +12,7 @@ #define MPIDI_OFI_RNDVREAD_INFLY_CHUNKS 10 -static int rndvread_read_poll(MPIX_Async_thing thing); +static int rndvread_read_completion(void *); static int recv_issue_read(MPIR_Request * parent_request, int event_id, void *buf, MPI_Aint data_sz, MPI_Aint 
offset, MPIDI_av_entry_t * av, int vci_local, int vci_remote, int nic, @@ -38,10 +38,14 @@ int MPIDI_OFI_rndvread_send(MPIR_Request * sreq, int tag) MPIR_Type_get_true_extent_impl(p->datatype, &true_lb, &true_extent); p->u.send.data = MPIR_get_contig_ptr(p->buf, true_lb); - int num_nics = MPIDI_OFI_global.num_nics; - p->u.send.mrs = MPL_malloc((num_nics * sizeof(struct fid_mr *)), MPL_MEM_OTHER); + p->num_nics = MPIDI_OFI_global.num_nics; + if (p->num_nics == 1) { + p->u.send.mrs = &p->u.send.mr0; + } else { + p->u.send.mrs = MPL_malloc((p->num_nics * sizeof(struct fid_mr *)), MPL_MEM_OTHER); + } - int hdr_sz = sizeof(struct rdma_info) + num_nics * sizeof(uint64_t); + int hdr_sz = sizeof(struct rdma_info) + p->num_nics * sizeof(uint64_t); struct rdma_info *hdr = MPL_malloc(hdr_sz, MPL_MEM_OTHER); MPIR_Assertp(hdr); @@ -74,15 +78,16 @@ int MPIDI_OFI_rndvread_ack_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r MPIR_Request *sreq = MPIDI_OFI_RNDV_GET_CONTROL_REQ(r); MPIDI_OFI_rndvread_t *p = &MPIDI_OFI_AMREQ_READ(sreq); - int num_nics = MPIDI_OFI_global.num_nics; - for (int i = 0; i < num_nics; i++) { + for (int i = 0; i < p->num_nics; i++) { uint64_t key = fi_mr_key(p->u.send.mrs[i]); MPIDI_OFI_CALL(fi_close(&p->u.send.mrs[i]->fid), mr_unreg); if (!MPIDI_OFI_ENABLE_MR_PROV_KEY) { MPIDI_OFI_mr_key_free(MPIDI_OFI_LOCAL_MR_KEY, key); } } - MPL_free(p->u.send.mrs); + if (p->num_nics > 1) { + MPL_free(p->u.send.mrs); + } MPL_free(r); /* complete sreq */ @@ -112,8 +117,8 @@ int MPIDI_OFI_rndvread_recv(MPIR_Request * rreq, int tag, int vci_src, int vci_d } /* recv the mrs */ - int num_nics = MPIDI_OFI_global.num_nics; - MPI_Aint hdr_sz = sizeof(struct rdma_info) + num_nics * sizeof(uint64_t); + p->num_nics = MPIDI_OFI_global.num_nics; + MPI_Aint hdr_sz = sizeof(struct rdma_info) + p->num_nics * sizeof(uint64_t); mpi_errno = MPIDI_OFI_RNDV_recv_hdr(rreq, MPIDI_OFI_EVENT_RNDVREAD_RECV_MRS, hdr_sz, p->av, p->vci_local, p->vci_remote, p->match_bits); 
MPIR_ERR_CHECK(mpi_errno); @@ -134,46 +139,54 @@ int MPIDI_OFI_rndvread_recv_mrs_event(struct fi_cq_tagged_entry *wc, MPIR_Reques MPIDI_OFI_RNDV_update_count(rreq, hdr->data_sz); - int num_nics = MPIDI_OFI_global.num_nics; p->remote_data_sz = MPL_MIN(hdr->data_sz, p->data_sz); - p->u.recv.remote_base = hdr->base; - p->u.recv.rkeys = MPL_malloc(num_nics * sizeof(uint64_t), MPL_MEM_OTHER); - for (int i = 0; i < num_nics; i++) { + if (MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { + p->u.recv.remote_base = hdr->base; + } else { + p->u.recv.remote_base = 0; + } + if (p->num_nics == 1) { + p->u.recv.rkeys = &p->u.recv.rkey0; + } else { + p->u.recv.rkeys = MPL_malloc(p->num_nics * sizeof(uint64_t), MPL_MEM_OTHER); + } + for (int i = 0; i < p->num_nics; i++) { p->u.recv.rkeys[i] = hdr->rkeys[i]; } MPL_free(r); - /* setup chunks */ - p->u.recv.chunks_per_nic = get_chunks_per_nic(p->remote_data_sz, num_nics); - - p->u.recv.cur_chunk_index = 0; - p->u.recv.num_infly = 0; - /* issue fi_read */ - mpi_errno = MPIR_Async_things_add(rndvread_read_poll, rreq, NULL); + p->u.recv.cmpl_cb = rndvread_read_completion; + p->u.recv.context = rreq; + mpi_errno = MPIR_Async_things_add(MPIDI_OFI_rdmaread_poll, rreq, NULL); return mpi_errno; } -static int rndvread_read_poll(MPIX_Async_thing thing) +int MPIDI_OFI_rdmaread_poll(MPIX_Async_thing thing) { int ret = MPIX_ASYNC_NOPROGRESS; int mpi_errno = MPI_SUCCESS; MPIR_Request *rreq = MPIR_Async_thing_get_state(thing); MPIDI_OFI_rndvread_t *p = &MPIDI_OFI_AMREQ_READ(rreq); + /* setup chunks. FIXME(review): this now runs on every invocation of the poll callback; if the callback is re-invoked after returning MPIX_ASYNC_NOPROGRESS (e.g. after bailing out on the in-flight limit below), cur_chunk_index and num_infly are reset while reads are still outstanding. Presumably this must run once, before MPIR_Async_things_add, for every caller of this poll function -- confirm */ + p->u.recv.chunks_per_nic = get_chunks_per_nic(p->remote_data_sz, p->num_nics); + + p->u.recv.cur_chunk_index = 0; + p->u.recv.num_infly = 0; + /* CS required for genq pool and gpu imemcpy */ MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(p->vci_local)); - int num_nics = MPIDI_OFI_global.num_nics; - while (p->u.recv.cur_chunk_index < p->u.recv.chunks_per_nic * num_nics) { + while (p->u.recv.cur_chunk_index < p->u.recv.chunks_per_nic * p->num_nics) { if 
(p->u.recv.num_infly >= MPIDI_OFI_RNDVREAD_INFLY_CHUNKS) { goto fn_exit; } int nic; MPI_Aint total_offset, nic_offset, chunk_sz; - get_chunk_offsets(p->u.recv.cur_chunk_index, num_nics, + get_chunk_offsets(p->u.recv.cur_chunk_index, p->num_nics, p->u.recv.chunks_per_nic, p->remote_data_sz, &total_offset, &nic, &nic_offset, &chunk_sz); @@ -190,7 +203,10 @@ static int rndvread_read_poll(MPIX_Async_thing thing) read_buf = (char *) p->u.recv.u.data + total_offset; } uint64_t disp; - if (MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { + if (p->u.recv.remote_base) { + /* remote_base is nonzero either when MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS is ON, + * or single NIC (thus single mr) is used and it contains base offset. + */ disp = p->u.recv.remote_base + total_offset; } else { disp = nic_offset; @@ -347,13 +363,27 @@ static int check_recv_complete(MPIR_Request * rreq) MPIDI_OFI_rndvread_t *p = &MPIDI_OFI_AMREQ_READ(rreq); if (p->u.recv.all_issued && p->u.recv.num_infly == 0 && (!p->need_pack || p->u.recv.u.copy_infly == 0)) { - /* done. 
send ack */ - mpi_errno = MPIDI_OFI_RNDV_send_hdr(NULL, 0, p->av, p->vci_local, p->vci_remote, - p->match_bits); - /* complete request */ + /* done: hand the completion callback the context that was registered together with it (p->u.recv.context), not rreq -- the callback casts its argument to whatever type its registrant stored there */ + mpi_errno = p->u.recv.cmpl_cb(p->u.recv.context); + } + return mpi_errno; +} + +static int rndvread_read_completion(void *context) +{ + int mpi_errno = MPI_SUCCESS; + + MPIR_Request *rreq = context; + MPIDI_OFI_rndvread_t *p = &MPIDI_OFI_AMREQ_READ(rreq); + + /* send ack */ + mpi_errno = MPIDI_OFI_RNDV_send_hdr(NULL, 0, p->av, p->vci_local, p->vci_remote, p->match_bits); + /* complete request */ + if (p->num_nics > 1) { MPL_free(p->u.recv.rkeys); - MPIR_Datatype_release_if_not_builtin(p->datatype); - MPIDI_Request_complete_fast(rreq); } + MPIR_Datatype_release_if_not_builtin(p->datatype); + MPIDI_Request_complete_fast(rreq); + return mpi_errno; } diff --git a/src/mpid/ch4/netmod/ofi/ofi_win.c b/src/mpid/ch4/netmod/ofi/ofi_win.c index 9ac46837561..6d0429b9c9c 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_win.c +++ b/src/mpid/ch4/netmod/ofi/ofi_win.c @@ -7,6 +7,23 @@ #include "ofi_impl.h" #include "ofi_noinline.h" +/* +=== BEGIN_MPI_T_CVAR_INFO_BLOCK === + +cvars: + - name : MPIR_CVAR_OFI_ENABLE_WIN_MIRROR + category : CH4_OFI + type : boolean + default : false + class : none + verbosity : MPI_T_VERBOSITY_USER_BASIC + scope : MPI_T_SCOPE_LOCAL + description : >- + If enabled, allocate a host mirror buffer to avoid repeated host pack buffer registrations. 
+ +=== END_MPI_T_CVAR_INFO_BLOCK === +*/ + static void load_acc_hint(MPIR_Win * win); static void set_rma_fi_info(MPIR_Win * win, struct fi_info *finfo); static int win_allgather(MPIR_Win * win, void *base, int disp_unit); @@ -124,6 +141,15 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) MPIR_FUNC_ENTER; + /* dynamic window need fallback for most cases */ + if (win->create_flavor == MPI_WIN_FLAVOR_DYNAMIC && + (MPIDI_OFI_ENABLE_MR_ALLOCATED || !MPIDI_OFI_ENABLE_MR_REGISTER_NULL)) { + /* We may still do native atomics with collective attach, let's load acc_hint */ + load_acc_hint(win); + goto fn_exit; + } + + /* init mr_key */ if (!MPIDI_OFI_ENABLE_MR_PROV_KEY) { if (MPIDIG_WIN(win, info_args).optimized_mr && MPIDIG_WIN(win, info_args).accumulate_ordering == 0) { @@ -167,95 +193,132 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) * providers. For providers like CXI, FI_MR_ALLOCATED is not set but registration with * non-zero address is not supported. For such providers, registration is skipped by * using the MPIDI_OFI_ENABLE_MR_REGISTER_NULL capability set variable. 
*/ - int rc = 0, allrc = 0; MPIDI_OFI_WIN(win).mr = NULL; - if (base || (win->create_flavor == MPI_WIN_FLAVOR_DYNAMIC && - !MPIDI_OFI_ENABLE_MR_ALLOCATED && MPIDI_OFI_ENABLE_MR_REGISTER_NULL)) { - size_t len; - if (win->create_flavor == MPI_WIN_FLAVOR_DYNAMIC) { - len = UINTPTR_MAX - (uintptr_t) base; + bool need_mr, is_gpu, do_mirror; + need_mr = MPIDI_OFI_ENABLE_RMA; + + MPL_pointer_attr_t *attr = &MPIDI_OFI_WIN(win).base_attr; + if (base) { + MPIR_GPU_query_pointer_attr(base, attr); + is_gpu = MPL_gpu_attr_is_dev(attr); + if (is_gpu) { + if (MPIDI_OFI_ENABLE_HMEM && MPL_gpu_attr_is_strict_dev(attr)) { + do_mirror = false; + } else if (MPIR_CVAR_OFI_ENABLE_WIN_MIRROR && need_mr) { + do_mirror = true; + } else { + need_mr = false; + do_mirror = false; + } } else { - len = (size_t) win->size; + do_mirror = false; } + } else { + MPIR_Assert(win->create_flavor == MPI_WIN_FLAVOR_DYNAMIC || win->size == 0); + MPIR_Assert(MPIDI_OFI_ENABLE_RMA); + /* provider allows registering the whole address space */ + is_gpu = false; + do_mirror = false; + } - MPL_pointer_attr_t attr; - MPIR_GPU_query_pointer_attr(base, &attr); - if (MPL_gpu_attr_is_dev(&attr)) { - if (MPIDI_OFI_ENABLE_HMEM && MPL_gpu_attr_is_strict_dev(&attr)) { - mpi_errno = - MPIDI_OFI_register_memory(base, len, &attr, ctx_idx, MPIDI_OFI_WIN(win).mr_key, - &MPIDI_OFI_WIN(win).mr); - if (mpi_errno != MPI_SUCCESS) - rc = -1; - } else { - rc = -1; - } + void *addr; + uintptr_t len; + if (base) { + len = (uintptr_t) win->size; + } else { + len = UINTPTR_MAX - (uintptr_t) base; + } + if (do_mirror) { + int ret = MPL_gpu_malloc_host(&addr, len); + MPIR_ERR_CHKANDJUMP(ret || !addr, mpi_errno, MPI_ERR_OTHER, "**nomem"); + MPIDI_OFI_WIN(win).mirror_buf = addr; + MPIR_GPU_query_pointer_attr(addr, &MPIDI_OFI_WIN(win).mirror_attr); /* attributes of the host mirror itself, not of the device base it shadows */ + } else { + addr = base; + MPIDI_OFI_WIN(win).mirror_buf = NULL; + } + + if (need_mr) { + int rc; + if (is_gpu) { /* addr is the host mirror when do_mirror, else base; the MR must cover the buffer peers actually read (the GET_ACK advertises mirror_buf with this MR's key). NOTE(review): assumes MPIDI_OFI_register_memory accepts host-memory attributes -- confirm */ + rc = MPIDI_OFI_register_memory(addr, len, do_mirror ? &MPIDI_OFI_WIN(win).mirror_attr : attr, ctx_idx, 
MPIDI_OFI_WIN(win).mr_key, &MPIDI_OFI_WIN(win).mr); } else { - MPIDI_OFI_CALL_RETURN(fi_mr_reg(MPIDI_OFI_global.ctx[ctx_idx].domain, /* In: Domain Object */ - base, /* In: Lower memory address */ - len, /* In: Length */ - FI_REMOTE_READ | FI_REMOTE_WRITE, /* In: Expose MR for read */ - 0ULL, /* In: offset(not used) */ + MPIDI_OFI_CALL_RETURN(fi_mr_reg(MPIDI_OFI_global.ctx[ctx_idx].domain, base, len, FI_REMOTE_READ | FI_REMOTE_WRITE, 0ULL, /* In: offset(not used) */ MPIDI_OFI_WIN(win).mr_key, /* In: requested key */ 0ULL, /* In: flags */ &MPIDI_OFI_WIN(win).mr, /* Out: memregion object */ NULL), rc); /* In: context */ } - if (rc == 0) { + + if (rc == MPI_SUCCESS) { mpi_errno = MPIDI_OFI_mr_bind(MPIDI_OFI_global.prov_use[0], MPIDI_OFI_WIN(win).mr, MPIDI_OFI_WIN(win).ep, MPIDI_OFI_WIN(win).cmpl_cntr); MPIR_ERR_CHECK(mpi_errno); + } else { + MPIDI_OFI_WIN(win).mr = NULL; } - } else if (win->create_flavor == MPI_WIN_FLAVOR_DYNAMIC) { - /* We may still do native atomics with collective attach, let's load acc_hint */ - load_acc_hint(win); - goto fn_exit; - } else { - /* Do nothing */ } - /* Check if any process fails to register. If so, release local MR and force AM path. */ - MPIR_Allreduce(&rc, &allrc, 1, MPIR_INT_INTERNAL, MPI_MIN, comm_ptr, MPIR_COLL_ATTR_SYNC); - if (allrc < 0) { - if (rc >= 0 && MPIDI_OFI_WIN(win).mr) - MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_WIN(win).mr->fid), fi_close); - MPIDI_OFI_WIN(win).mr = NULL; - goto fn_exit; - } else { + /* collectively check whether MPIDI_WINATTR_NM_REACHABLE */ + /* NOTE: mirror_buf is not exposed to remote processes directly. 
+ * They are just for optimizing the fallback path */ + int got_mr = (!do_mirror && MPIDI_OFI_WIN(win).mr != NULL); + int min_got_mr; + MPIR_Allreduce(&got_mr, &min_got_mr, 1, MPIR_INT_INTERNAL, MPI_MIN, comm_ptr, + MPIR_COLL_ATTR_SYNC); + if (min_got_mr) { MPIDI_WIN(win, winattr) |= MPIDI_WINATTR_NM_REACHABLE; /* enable NM native RMA */ + } else { + /* everyone needs to agree on using the mirror buffer */ + /* FIXME: do we need to check this collectively? Only a process that actually created a mirror can be sanity-checked locally. */ + if (do_mirror) { + MPIR_Assert(MPIDI_OFI_WIN(win).mirror_buf && MPIDI_OFI_WIN(win).mr); + } } - MPIDI_OFI_WIN(win).winfo = MPL_malloc(sizeof(*winfo) * comm_ptr->local_size, MPL_MEM_RMA); + if (!min_got_mr && !do_mirror && MPIDI_OFI_WIN(win).mr) { + /* mr not needed */ + MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_WIN(win).mr->fid), fi_close); + MPIDI_OFI_WIN(win).mr = NULL; + } - winfo = MPIDI_OFI_WIN(win).winfo; - winfo[comm_ptr->rank].disp_unit = disp_unit; + if (min_got_mr) { + /* allgather winfo */ + MPIDI_OFI_WIN(win).winfo = MPL_malloc(sizeof(*winfo) * comm_ptr->local_size, MPL_MEM_RMA); - if ((MPIDI_OFI_ENABLE_MR_PROV_KEY || MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) && MPIDI_OFI_WIN(win).mr) { - /* MR_BASIC */ - MPIDI_OFI_WIN(win).mr_key = fi_mr_key(MPIDI_OFI_WIN(win).mr); - winfo[comm_ptr->rank].mr_key = MPIDI_OFI_WIN(win).mr_key; - winfo[comm_ptr->rank].base = (uintptr_t) base; - } + if (MPIDI_OFI_ENABLE_MR_PROV_KEY && MPIDI_OFI_WIN(win).mr) { + MPIDI_OFI_WIN(win).mr_key = fi_mr_key(MPIDI_OFI_WIN(win).mr); + } + winfo = MPIDI_OFI_WIN(win).winfo; + winfo[comm_ptr->rank].disp_unit = disp_unit; + if (MPIDI_OFI_WIN(win).mr) { + winfo[comm_ptr->rank].mr_key = MPIDI_OFI_WIN(win).mr_key; + winfo[comm_ptr->rank].base = (uintptr_t) addr; + } else { + winfo[comm_ptr->rank].mr_key = 0; + winfo[comm_ptr->rank].base = 0; + } + mpi_errno = 
MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, + winfo, sizeof(*winfo), MPIR_BYTE_INTERNAL, comm_ptr, + MPIR_COLL_ATTR_SYNC); + MPIR_ERR_CHECK(mpi_errno); - if (!MPIDI_OFI_ENABLE_MR_PROV_KEY && !MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { - first = winfo[0].disp_unit; - same_disp = 1; - for (i = 1; i < comm_ptr->local_size; i++) { - if (winfo[i].disp_unit != first) { - same_disp = 0; - break; + if (!MPIDI_OFI_ENABLE_MR_PROV_KEY && !MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { + first = winfo[0].disp_unit; + same_disp = 1; + for (i = 1; i < comm_ptr->local_size; i++) { + if (winfo[i].disp_unit != first) { + same_disp = 0; + break; + } + } + if (same_disp) { + MPL_free(MPIDI_OFI_WIN(win).winfo); + MPIDI_OFI_WIN(win).winfo = NULL; } - } - if (same_disp) { - MPL_free(MPIDI_OFI_WIN(win).winfo); - MPIDI_OFI_WIN(win).winfo = NULL; } } @@ -1089,6 +1152,10 @@ int MPIDI_OFI_mpi_win_free_hook(MPIR_Win * win) MPL_free(MPIDI_OFI_WIN(win).acc_hint); MPIDI_OFI_WIN(win).acc_hint = NULL; + if (MPIDI_OFI_WIN(win).mirror_buf) { + MPL_gpu_free_host(MPIDI_OFI_WIN(win).mirror_buf); + } + /* Free storage of per-attach memory regions for dynamic window */ if (win->create_flavor == MPI_WIN_FLAVOR_DYNAMIC && !MPIDI_OFI_WIN(win).mr && MPIDIG_WIN(win, info_args).coll_attach) { diff --git a/src/mpid/ch4/src/ch4_impl.h b/src/mpid/ch4/src/ch4_impl.h index bf68132c93c..bb9c8e65b13 100644 --- a/src/mpid/ch4/src/ch4_impl.h +++ b/src/mpid/ch4/src/ch4_impl.h @@ -656,15 +656,38 @@ MPL_STATIC_INLINE_PREFIX void MPIDIG_win_cmpl_cnts_incr(MPIR_Win * win, int targ MPIR_cc_inc(&target_ptr->local_cmpl_cnts); MPIR_cc_inc(&target_ptr->remote_cmpl_cnts); - - *local_cmpl_cnts_ptr = &target_ptr->local_cmpl_cnts; + if (local_cmpl_cnts_ptr) { + *local_cmpl_cnts_ptr = &target_ptr->local_cmpl_cnts; + } break; } default: MPIR_cc_inc(&MPIDIG_WIN(win, local_cmpl_cnts)); MPIR_cc_inc(&MPIDIG_WIN(win, remote_cmpl_cnts)); + if (local_cmpl_cnts_ptr) { + *local_cmpl_cnts_ptr = &MPIDIG_WIN(win, local_cmpl_cnts); + } + break; + } +} + +/* 
Use this if local_cmpl_cnts_ptr is NULL in MPIDIG_win_cmpl_cnts_incr */ +MPL_STATIC_INLINE_PREFIX void MPIDIG_win_cmpl_cnts_decr(MPIR_Win * win, int target_rank) +{ + switch (MPIDIG_WIN(win, sync).access_epoch_type) { + case MPIDIG_EPOTYPE_LOCK: + case MPIDIG_EPOTYPE_LOCK_ALL: + case MPIDIG_EPOTYPE_START: + { + MPIDIG_win_target_t *target_ptr = MPIDIG_win_target_get(win, target_rank); - *local_cmpl_cnts_ptr = &MPIDIG_WIN(win, local_cmpl_cnts); + MPIR_cc_dec(&target_ptr->local_cmpl_cnts); + MPIR_cc_dec(&target_ptr->remote_cmpl_cnts); + break; + } + default: + MPIR_cc_dec(&MPIDIG_WIN(win, local_cmpl_cnts)); + MPIR_cc_dec(&MPIDIG_WIN(win, remote_cmpl_cnts)); break; } } diff --git a/src/mpid/ch4/src/mpidig.h b/src/mpid/ch4/src/mpidig.h index db850ea931e..e38d87220c3 100644 --- a/src/mpid/ch4/src/mpidig.h +++ b/src/mpid/ch4/src/mpidig.h @@ -64,6 +64,8 @@ enum { MPIDI_OFI_AM_RDMA_READ_ACK, MPIDI_OFI_RNDV_INFO, + MPIDI_OFI_GET_REQ, + MPIDI_OFI_GET_ACK, MPIDIG_HANDLER_STATIC_MAX };