Skip to content

Allow UCC to be used with sessions #13311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ompi/communicator/comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -2528,7 +2528,7 @@ int ompi_comm_determine_first_auto ( ompi_communicator_t* intercomm )
/********************************************************************************/
int ompi_comm_dump ( ompi_communicator_t *comm )
{
opal_output(0, "Dumping information for comm_cid %s\n", ompi_comm_print_cid (comm));
opal_output(0, "Dumping information for comm_cid %s : %d\n", ompi_comm_print_cid (comm), ompi_comm_get_local_cid(comm));
opal_output(0," f2c index:%d cube_dim: %d\n", comm->c_f_to_c_index,
comm->c_cube_dim);
opal_output(0," Local group: size = %d my_rank = %d\n",
Expand All @@ -2539,13 +2539,17 @@ int ompi_comm_dump ( ompi_communicator_t *comm )
/* Display flags */
if ( OMPI_COMM_IS_INTER(comm) )
opal_output(0," inter-comm,");
else
opal_output(0," intra-comm,");
if ( OMPI_COMM_IS_CART(comm))
opal_output(0," topo-cart");
else if ( OMPI_COMM_IS_GRAPH(comm))
opal_output(0," topo-graph");
else if ( OMPI_COMM_IS_DIST_GRAPH(comm))
opal_output(0," topo-dist-graph");
opal_output(0,"\n");
else
opal_output(0, " no topo");
opal_output(0,"\n");

if (OMPI_COMM_IS_INTER(comm)) {
opal_output(0," Remote group size:%d\n", comm->c_remote_group->grp_proc_count);
Expand Down
33 changes: 20 additions & 13 deletions ompi/mca/coll/ucc/coll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "ompi_config.h"
#include "coll_ucc.h"
#include "coll_ucc_common.h"
#include "coll_ucc_dtypes.h"
#include "ompi/mca/coll/base/coll_tags.h"
#include "ompi/mca/pml/pml.h"
Expand Down Expand Up @@ -219,7 +220,8 @@ static ucc_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen,
}


static int mca_coll_ucc_init_ctx() {
static int mca_coll_ucc_init_ctx(ompi_communicator_t* comm)
{
mca_coll_ucc_component_t *cm = &mca_coll_ucc_component;
char str_buf[256];
ompi_attribute_fn_ptr_union_t del_fn;
Expand Down Expand Up @@ -270,9 +272,9 @@ static int mca_coll_ucc_init_ctx() {
ctx_params.oob.allgather = oob_allgather;
ctx_params.oob.req_test = oob_allgather_test;
ctx_params.oob.req_free = oob_allgather_free;
ctx_params.oob.coll_info = (void*)MPI_COMM_WORLD;
ctx_params.oob.n_oob_eps = ompi_comm_size(&ompi_mpi_comm_world.comm);
ctx_params.oob.oob_ep = ompi_comm_rank(&ompi_mpi_comm_world.comm);
ctx_params.oob.coll_info = (void*)comm;
ctx_params.oob.n_oob_eps = ompi_comm_size(comm);
ctx_params.oob.oob_ep = ompi_comm_rank(comm);
if (UCC_OK != ucc_context_config_read(cm->ucc_lib, NULL, &ctx_config)) {
UCC_ERROR("UCC context config read failed");
goto cleanup_lib;
Expand Down Expand Up @@ -329,7 +331,7 @@ static int mca_coll_ucc_init_ctx() {
return OMPI_ERROR;
}

uint64_t rank_map_cb(uint64_t ep, void *cb_ctx)
static uint64_t rank_map_cb(uint64_t ep, void *cb_ctx)
{
struct ompi_communicator_t *comm = cb_ctx;

Expand Down Expand Up @@ -433,8 +435,7 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
ucc_team_params_t team_params = {
.mask = UCC_TEAM_PARAM_FIELD_EP_MAP |
UCC_TEAM_PARAM_FIELD_EP |
UCC_TEAM_PARAM_FIELD_EP_RANGE |
UCC_TEAM_PARAM_FIELD_ID,
UCC_TEAM_PARAM_FIELD_EP_RANGE,
.ep_map = {
.type = (comm == &ompi_mpi_comm_world.comm) ?
UCC_EP_MAP_FULL : UCC_EP_MAP_CB,
Expand All @@ -443,12 +444,18 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
.cb.cb_ctx = (void*)comm
},
.ep = ompi_comm_rank(comm),
.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG,
.id = ompi_comm_get_local_cid(comm)
.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG
};
UCC_VERBOSE(2, "creating ucc_team for comm %p, comm_id %llu, comm_size %d",
(void*)comm, (long long unsigned)team_params.id,
ompi_comm_size(comm));
if (OMPI_COMM_IS_GLOBAL_INDEX(comm)) {
team_params.mask |= UCC_TEAM_PARAM_FIELD_ID;
team_params.id = ompi_comm_get_local_cid(comm);
UCC_VERBOSE(2, "creating ucc_team for comm %p, comm_id %llu, comm_size %d",
(void*)comm, (long long unsigned)team_params.id,
ompi_comm_size(comm));
} else {
UCC_VERBOSE(2, "creating ucc_team for comm %p, comm_id not provided, comm_size %d",
(void*)comm, ompi_comm_size(comm));
}

if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1,
&team_params, &ucc_module->ucc_team)) {
Expand Down Expand Up @@ -555,7 +562,7 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority)
}

if (!cm->libucc_initialized) {
if (OMPI_SUCCESS != mca_coll_ucc_init_ctx()) {
if (OMPI_SUCCESS != mca_coll_ucc_init_ctx(comm)) {
cm->ucc_enable = 0;
return NULL;
}
Expand Down