@@ -510,6 +510,7 @@ struct get_context {
510510 MPI_Aint data_sz ;
511511 void * origin_addr ;
512512 MPI_Aint target_offset ;
513+ MPIR_Request * req ;
513514};
514515
515516struct get_hdr {
@@ -554,6 +555,14 @@ int MPIDI_OFI_mirror_get(void *origin_addr, MPI_Aint origin_count, MPI_Datatype
554555 origin_context -> origin_addr = (char * ) origin_addr + origin_true_lb ;
555556 origin_context -> target_offset = target_disp * win -> disp_unit + target_true_lb ;
556557
558+ /* allocate a request, used for reuse the code from ofi_rndv_read. */
559+ MPIR_Request * req ;
560+ MPIDI_OFI_REQUEST_CREATE (req , MPIR_REQUEST_KIND__RMA , vci );
561+ if (1 ) {
562+ MPIDI_CH4_REQUEST_FREE (req );
563+ }
564+ origin_context -> req = req ;
565+
557566 /* fill am_hdr */
558567 struct get_hdr am_hdr ;
559568 am_hdr .win_id = MPIDIG_WIN (win , win_id );
@@ -638,6 +647,7 @@ int MPIDI_OFI_get_handler(void *am_hdr, void *data, MPI_Aint in_data_sz,
638647struct getack_hdr {
639648 void * origin_context ;
640649 uint64_t rkey ;
650+ uint64_t remote_base ;
641651};
642652
643653/* target side - async callback */
@@ -652,6 +662,7 @@ static int target_mirror_copy_poll(MPIX_Async_thing thing)
652662 struct getack_hdr am_hdr ;
653663 am_hdr .origin_context = p -> origin_context ;
654664 am_hdr .rkey = fi_mr_key (MPIDI_OFI_WIN (p -> win ).mr );
665+ am_hdr .remote_base = (uintptr_t ) MPIDI_OFI_WIN (p -> win ).mirror_buf ;
655666
656667 int rc = MPIDI_NM_am_send_hdr (p -> origin_rank , p -> win -> comm_ptr , MPIDI_OFI_GET_ACK ,
657668 & am_hdr , sizeof (am_hdr ), p -> vci_target , p -> vci_origin );
@@ -672,6 +683,8 @@ struct read_req {
672683 struct get_context * origin_context ;
673684};
674685
686+ static int rdmaread_completion (void * context );
687+
675688/* origin side - AM callback */
676689int MPIDI_OFI_getack_handler (void * am_hdr , void * data , MPI_Aint in_data_sz ,
677690 uint32_t attr , MPIR_Request * * req )
@@ -682,56 +695,48 @@ int MPIDI_OFI_getack_handler(void *am_hdr, void *data, MPI_Aint in_data_sz,
682695 struct get_context * origin_context = msg_hdr -> origin_context ;
683696 MPIR_Win * win = origin_context -> win ;
684697 int target_rank = origin_context -> target_rank ;
685- void * buf = origin_context -> origin_addr ;
686- MPI_Aint len = origin_context -> data_sz ;
687698 MPI_Aint target_offset = origin_context -> target_offset ;
688699
689- /* rdma read */
690- /* FIXME: potentially we need allocate a staging buffer */
691- void * desc = NULL ;
692- int nic_target = MPIDI_OFI_get_pref_nic (win -> comm_ptr , target_rank );
693-
694- MPL_pointer_attr_t bufattr ;
695- MPIR_GPU_query_pointer_attr (buf , & bufattr );
696- if (MPL_gpu_attr_is_strict_dev (& bufattr )) {
697- MPIDI_OFI_gpu_rma_register (buf , len , & bufattr , win , nic_target , & desc );
698- }
699-
700- int vci = MPIDI_WIN (win , am_vci );
701- int vci_target = MPIDI_WIN_TARGET_VCI (win , target_rank );
700+ MPIDI_OFI_rndvread_t * p = & MPIDI_OFI_AMREQ_READ (origin_context -> req );
701+ p -> buf = origin_context -> origin_addr ;
702+ p -> count = origin_context -> data_sz ;
703+ p -> datatype = MPIR_BYTE_INTERNAL ;
702704
703- MPIDI_av_entry_t * av = MPIDIU_win_rank_to_av ( win , target_rank , MPIDI_WIN ( win , winattr ) );
704- fi_addr_t addr = MPIDI_OFI_av_to_phys ( av , vci , 0 , vci_target , nic_target );
705+ MPIR_GPU_query_pointer_attr ( p -> buf , & p -> attr );
706+ p -> need_pack = MPL_gpu_attr_is_dev ( & p -> attr );
705707
706- struct read_req * r = MPL_malloc (sizeof (struct read_req ), MPL_MEM_OTHER );
707- MPIR_ERR_CHKANDJUMP (!r , mpi_errno , MPI_ERR_OTHER , "**nomem" );
708+ p -> data_sz = p -> remote_data_sz = origin_context -> data_sz ;
709+ p -> vci_local = MPIDI_WIN (win , am_vci );
710+ p -> vci_remote = MPIDI_WIN_TARGET_VCI (win , target_rank );
711+ p -> av = MPIDIU_win_rank_to_av (win , target_rank , MPIDI_WIN (win , winattr ));
708712
709- r -> event_id = MPIDI_OFI_EVENT_MIRROR_READ ;
710- r -> origin_context = origin_context ;
713+ p -> num_nics = 1 ;
714+ if (MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS ) {
715+ p -> u .recv .remote_base = msg_hdr -> remote_base + target_offset ;
716+ } else {
717+ p -> u .recv .remote_base = target_offset ;
718+ }
719+ p -> u .recv .rkeys = & p -> u .recv .rkey0 ;
720+ p -> u .recv .rkey0 = msg_hdr -> rkey ;
721+ p -> u .recv .cmpl_cb = rdmaread_completion ;
722+ p -> u .recv .context = origin_context ;
711723
712- MPID_THREAD_CS_ENTER (VCI , MPIDI_VCI_LOCK (vci ));
713- MPIDI_OFI_CALL_RETRY (fi_read (MPIDI_OFI_WIN (win ).ep , buf , len , desc ,
714- addr , target_offset , msg_hdr -> rkey , (void * ) & r -> context ),
715- vci , rdma_readfrom );
716- /* Complete signal request to inform completion to user. */
717- MPID_THREAD_CS_EXIT (VCI , MPIDI_VCI_LOCK (vci ));
724+ mpi_errno = MPIR_Async_things_add (MPIDI_OFI_rdmaread_poll , origin_context -> req , NULL );
718725
719- fn_exit :
720726 return mpi_errno ;
721- fn_fail :
722- goto fn_exit ;
723727}
724728
725- /* origin side - rdma event callback */
726- int MPIDI_OFI_mirror_read_event (struct fi_cq_tagged_entry * wc , MPIR_Request * req )
729+ static int rdmaread_completion (void * context )
727730{
728- struct read_req * r = ( void * ) req ;
729- struct get_context * origin_context = r -> origin_context ;
731+ struct get_context * origin_context = context ;
732+
730733 MPIR_Win * win = origin_context -> win ;
731734 int target_rank = origin_context -> target_rank ;
732735
733736 MPIDIG_win_cmpl_cnts_decr (win , target_rank );
734737
738+ MPIDI_Request_complete_fast (origin_context -> req );
735739 MPL_free (origin_context );
736- MPL_free (r );
740+
741+ return MPI_SUCCESS ;
737742}
0 commit comments