[mlir][mpi] Lowering MPI_Allreduce #133133
Conversation
@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

Adding lowering of MPI_Allreduce.

FYI: @tkarna @mofeing

Patch is 21.33 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/133133.diff

5 Files Affected:
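For context, mpi.allreduce takes a send buffer, a receive buffer, and a reduction kind, and optionally returns an !mpi.retval error code. The usage below is taken from the test added in this patch:

  // Element-wise MPI_SUM reduction across all ranks; the result is written
  // into the receive buffer (here the same memref as the send buffer).
  mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>

When the optional retval result is present, the lowering forwards the i32 returned by the generated MPI_Allreduce call; otherwise the op is erased and only the call remains.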
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..f2837e71df060 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -246,12 +246,7 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpMaxloc,
MPI_OpReplace
]> {
- let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mpi";
}
-def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
#endif // MLIR_DIALECT_MPI_IR_MPI_TD
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index db28bd09678f8..a8267b115b9e6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
- MPI_OpClassAttr : $op
+ MPI_OpClassEnum : $op
);
let results = (outs Optional<MPI_Retval>:$retval);
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d91f9512ccb8f..4e0f59305a647 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
}
+std::pair<Value, Value> getRawPtrAndSize(const Location loc,
+ ConversionPatternRewriter &rewriter,
+ Value memRef, Type elType) {
+ Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+ Value dataPtr =
+ rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
+ Value offset = rewriter.create<LLVM::ExtractValueOp>(
+ loc, rewriter.getI64Type(), memRef, 2);
+ Value resPtr =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+ Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+ ArrayRef<int64_t>{3, 0});
+ size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+ return {resPtr, size};
+}
+
/// When lowering the mpi dialect to functions calls certain details
/// differ between various MPI implementations. This class will provide
/// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
/// type.
virtual Value getDataType(const Location loc,
ConversionPatternRewriter &rewriter, Type type) = 0;
+
+ /// Gets or creates an MPI_Op value which corresponds to the given
+ /// enum value.
+ virtual Value getMPIOp(const Location loc,
+ ConversionPatternRewriter &rewriter,
+ mpi::MPI_OpClassEnum opAttr) = 0;
};
//===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
static constexpr int MPI_UINT16_T = 0x4c00023c;
static constexpr int MPI_UINT32_T = 0x4c00043d;
static constexpr int MPI_UINT64_T = 0x4c00083e;
+ static constexpr int MPI_MAX = 0x58000001;
+ static constexpr int MPI_MIN = 0x58000002;
+ static constexpr int MPI_SUM = 0x58000003;
+ static constexpr int MPI_PROD = 0x58000004;
+ static constexpr int MPI_LAND = 0x58000005;
+ static constexpr int MPI_BAND = 0x58000006;
+ static constexpr int MPI_LOR = 0x58000007;
+ static constexpr int MPI_BOR = 0x58000008;
+ static constexpr int MPI_LXOR = 0x58000009;
+ static constexpr int MPI_BXOR = 0x5800000a;
+ static constexpr int MPI_MINLOC = 0x5800000b;
+ static constexpr int MPI_MAXLOC = 0x5800000c;
+ static constexpr int MPI_REPLACE = 0x5800000d;
+ static constexpr int MPI_NO_OP = 0x5800000e;
public:
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
assert(false && "unsupported type");
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
}
+
+ Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+ mpi::MPI_OpClassEnum opAttr) override {
+ int32_t op = MPI_NO_OP;
+ switch (opAttr) {
+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+ op = MPI_NO_OP;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAX:
+ op = MPI_MAX;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MIN:
+ op = MPI_MIN;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_SUM:
+ op = MPI_SUM;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_PROD:
+ op = MPI_PROD;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LAND:
+ op = MPI_LAND;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BAND:
+ op = MPI_BAND;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LOR:
+ op = MPI_LOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BOR:
+ op = MPI_BOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LXOR:
+ op = MPI_LXOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BXOR:
+ op = MPI_BXOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
+ op = MPI_MINLOC;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+ op = MPI_MAXLOC;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
+ op = MPI_REPLACE;
+ break;
+ }
+ return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+ }
};
//===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {
auto context = rewriter.getContext();
// get external opaque struct pointer type
- auto commStructT =
+ auto typeStructT =
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
// make sure global op definition exists
- getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
+ getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
// get address of symbol
return rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context),
SymbolRefAttr::get(context, mtype));
}
+
+ Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+ mpi::MPI_OpClassEnum opAttr) override {
+ StringRef op;
+ switch (opAttr) {
+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+ op = "ompi_mpi_no_op";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAX:
+ op = "ompi_mpi_max";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MIN:
+ op = "ompi_mpi_min";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_SUM:
+ op = "ompi_mpi_sum";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_PROD:
+ op = "ompi_mpi_prod";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LAND:
+ op = "ompi_mpi_land";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BAND:
+ op = "ompi_mpi_band";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LOR:
+ op = "ompi_mpi_lor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BOR:
+ op = "ompi_mpi_bor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LXOR:
+ op = "ompi_mpi_lxor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BXOR:
+ op = "ompi_mpi_bxor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
+ op = "ompi_mpi_minloc";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+ op = "ompi_mpi_maxloc";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
+ op = "ompi_mpi_replace";
+ break;
+ }
+ auto context = rewriter.getContext();
+ // get external opaque struct pointer type
+ auto opStructT =
+ LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
+ // make sure global op definition exists
+ getOrDefineExternalStruct(loc, rewriter, op, opStructT);
+ // get address of symbol
+ return rewriter.create<LLVM::AddressOfOp>(
+ loc, LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, op));
+ }
};
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
- Type i64 = rewriter.getI64Type();
- Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();
// ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
// get MPI_COMM_WORLD, dataType and pointer
- Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
- dataPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
- Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+ auto [dataPtr, size] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
- Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();
// ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
- Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
- dataPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
- Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+ auto [dataPtr, size] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +604,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AllReduceOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *context = rewriter.getContext();
+ Type i32 = rewriter.getI32Type();
+ Type elemType = op.getSendbuf().getType().getElementType();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ auto mpiTraits = MPIImplTraits::get(moduleOp);
+ auto [sendPtr, sendSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+ auto [recvPtr, recvSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+ Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+ Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+ Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
+ // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+ auto funcType = LLVM::LLVMFunctionType::get(
+ i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+ commWorld.getType()});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
+
+ // replace op with function call
+ auto funcCall = rewriter.create<LLVM::CallOp>(
+ loc, funcDecl,
+ ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
+
+ if (op.getRetval())
+ rewriter.replaceOp(op, funcCall.getResult());
+ else
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
- SendOpLowering, RecvOpLowering>(converter);
+ SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
similarity index 78%
rename from mlir/test/Conversion/MPIToLLVM/ops.mlir
rename to mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 3c1b344efd50b..249ef195e8f5c 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -1,13 +1,13 @@
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
// COM: Test MPICH ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
@@ -73,7 +73,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
- // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+ // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+ // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
+ // CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
return
@@ -83,7 +98,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// -----
// COM: Test OpenMPI ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -91,7 +106,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
+module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
@@ -157,6 +172,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+ // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
+ // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i...
[truncated]
// CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
Check for llvm.call here?
Good catch! Will add tomorrow.
done
static constexpr int MPI_MINLOC = 0x5800000b;
static constexpr int MPI_MAXLOC = 0x5800000c;
static constexpr int MPI_REPLACE = 0x5800000d;
static constexpr int MPI_NO_OP = 0x5800000e;
What do you think about reusing the same values as in the upcoming MPI 5.0 ABI?
Do current MPI implementations support this yet?
No, each implementation uses its own values; you can check some of them here: https://github.com/JuliaParallel/MPI.jl/tree/master/src/api
I guess that in the future they will change to the ones in MPI 5; that's why I suggested it, but it's not mandatory for now.
Yes, there is agreement that we'll go for explicit specialization for now until the ABI is broadly available.
@mofeing As I said before, MPICH already supports the MPI-5 ABI (you must enable it with configure, as it is not the default). I will sync Mukautuva to the final MPI-5 ABI in a few days.
#133280 (comment)
That's great to know!
Great.
Once there are at least two popular implementations supporting it, we should switch to the ABI.
ping @AntonLydike @Dinistro
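For illustration, here is how MPI_SUM is materialized under the two ABIs this lowering currently supports, taken from the updated test expectations (SSA names here are illustrative; the MPICH value 1476395011 is the implementation-specific handle 0x58000003, exactly the kind of constant that would change if we later switch to the standardized ABI handles):

  // MPICH ABI: predefined ops are integer handles, so MPI_SUM lowers to a constant.
  %op_mpich = llvm.mlir.constant(1476395011 : i32) : i32
  // OpenMPI ABI: predefined ops are external ompi_predefined_op_t globals,
  // so MPI_SUM lowers to the address of the corresponding symbol.
  %op_ompi = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr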
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/198/builds/3283
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/41/builds/5853
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/199/builds/2489
Lowering of mpi.allreduce to LLVM function call
Adding lowering of MPI_Allreduce.
FYI: @tkarna @mofeing
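In short, the op is rewritten into a call to the C routine int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm). A rough before/after sketch under the OpenMPI ABI (SSA names are illustrative; the call shape matches the test expectations above):

  // Before: the dialect-level op.
  mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>

  // After: pointers and the element count are extracted from the memref
  // descriptor, then the runtime call is emitted.
  %ret = llvm.call @MPI_Allreduce(%sendptr, %recvptr, %count, %dtype, %op, %comm)
      : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32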