Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][mpi] Lowering MPI_Allreduce #133133

Merged
merged 5 commits into from
Mar 31, 2025
Merged

[mlir][mpi] Lowering MPI_Allreduce #133133

merged 5 commits into from
Mar 31, 2025

Conversation

fschlimb
Copy link
Contributor

Adding lowering of MPI_Allreduce.

FYI: @tkarna @mofeing

@llvmbot llvmbot added the mlir label Mar 26, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

Adding lowering of MPI_Allreduce.

FYI: @tkarna @mofeing


Patch is 21.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133133.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.td (-5)
  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+1-1)
  • (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+201-22)
  • (renamed) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+36-5)
  • (renamed) mlir/test/Dialect/MPI/mpiops.mlir (+4-4)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..f2837e71df060 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -246,12 +246,7 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
       MPI_OpMaxloc,
       MPI_OpReplace
     ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::mpi";
 }
 
-def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 #endif // MLIR_DIALECT_MPI_IR_MPI_TD
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index db28bd09678f8..a8267b115b9e6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
   let arguments = (
     ins AnyMemRef : $sendbuf,
     AnyMemRef : $recvbuf,
-    MPI_OpClassAttr : $op
+    MPI_OpClassEnum : $op
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d91f9512ccb8f..4e0f59305a647 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
       moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
 }
 
+std::pair<Value, Value> getRawPtrAndSize(const Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         Value memRef, Type elType) {
+  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  Value dataPtr =
+      rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
+  Value offset = rewriter.create<LLVM::ExtractValueOp>(
+      loc, rewriter.getI64Type(), memRef, 2);
+  Value resPtr =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+  Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+                                                     ArrayRef<int64_t>{3, 0});
+  size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+  return {resPtr, size};
+}
+
 /// When lowering the mpi dialect to functions calls certain details
 /// differ between various MPI implementations. This class will provide
 /// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
   /// type.
   virtual Value getDataType(const Location loc,
                             ConversionPatternRewriter &rewriter, Type type) = 0;
+
+  /// Gets or creates an MPI_Op value which corresponds to the given
+  /// enum value.
+  virtual Value getMPIOp(const Location loc,
+                         ConversionPatternRewriter &rewriter,
+                         mpi::MPI_OpClassEnum opAttr) = 0;
 };
 
 //===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
   static constexpr int MPI_UINT16_T = 0x4c00023c;
   static constexpr int MPI_UINT32_T = 0x4c00043d;
   static constexpr int MPI_UINT64_T = 0x4c00083e;
+  static constexpr int MPI_MAX = 0x58000001;
+  static constexpr int MPI_MIN = 0x58000002;
+  static constexpr int MPI_SUM = 0x58000003;
+  static constexpr int MPI_PROD = 0x58000004;
+  static constexpr int MPI_LAND = 0x58000005;
+  static constexpr int MPI_BAND = 0x58000006;
+  static constexpr int MPI_LOR = 0x58000007;
+  static constexpr int MPI_BOR = 0x58000008;
+  static constexpr int MPI_LXOR = 0x58000009;
+  static constexpr int MPI_BXOR = 0x5800000a;
+  static constexpr int MPI_MINLOC = 0x5800000b;
+  static constexpr int MPI_MAXLOC = 0x5800000c;
+  static constexpr int MPI_REPLACE = 0x5800000d;
+  static constexpr int MPI_NO_OP = 0x5800000e;
 
 public:
   using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
       assert(false && "unsupported type");
     return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
   }
+
+  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+                 mpi::MPI_OpClassEnum opAttr) override {
+    int32_t op = MPI_NO_OP;
+    switch (opAttr) {
+    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+      op = MPI_NO_OP;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAX:
+      op = MPI_MAX;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MIN:
+      op = MPI_MIN;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_SUM:
+      op = MPI_SUM;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_PROD:
+      op = MPI_PROD;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LAND:
+      op = MPI_LAND;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BAND:
+      op = MPI_BAND;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LOR:
+      op = MPI_LOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BOR:
+      op = MPI_BOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LXOR:
+      op = MPI_LXOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BXOR:
+      op = MPI_BXOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+      op = MPI_MINLOC;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+      op = MPI_MAXLOC;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+      op = MPI_REPLACE;
+      break;
+    }
+    return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {
 
     auto context = rewriter.getContext();
     // get external opaque struct pointer type
-    auto commStructT =
+    auto typeStructT =
         LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
     // make sure global op definition exists
-    getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
+    getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
     // get address of symbol
     return rewriter.create<LLVM::AddressOfOp>(
         loc, LLVM::LLVMPointerType::get(context),
         SymbolRefAttr::get(context, mtype));
   }
+
+  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+                 mpi::MPI_OpClassEnum opAttr) override {
+    StringRef op;
+    switch (opAttr) {
+    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+      op = "ompi_mpi_no_op";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAX:
+      op = "ompi_mpi_max";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MIN:
+      op = "ompi_mpi_min";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_SUM:
+      op = "ompi_mpi_sum";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_PROD:
+      op = "ompi_mpi_prod";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LAND:
+      op = "ompi_mpi_land";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BAND:
+      op = "ompi_mpi_band";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LOR:
+      op = "ompi_mpi_lor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BOR:
+      op = "ompi_mpi_bor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LXOR:
+      op = "ompi_mpi_lxor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BXOR:
+      op = "ompi_mpi_bxor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+      op = "ompi_mpi_minloc";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+      op = "ompi_mpi_maxloc";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+      op = "ompi_mpi_replace";
+      break;
+    }
+    auto context = rewriter.getContext();
+    // get external opaque struct pointer type
+    auto opStructT =
+        LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
+    // make sure global op definition exists
+    getOrDefineExternalStruct(loc, rewriter, op, opStructT);
+    // get address of symbol
+    return rewriter.create<LLVM::AddressOfOp>(
+        loc, LLVM::LLVMPointerType::get(context),
+        SymbolRefAttr::get(context, op));
+  }
 };
 
 std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     Location loc = op.getLoc();
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
-    Type i64 = rewriter.getI64Type();
-    Value memRef = adaptor.getRef();
     Type elemType = op.getRef().getType().getElementType();
 
     // ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     // get MPI_COMM_WORLD, dataType and pointer
-    Value dataPtr =
-        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
-    Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
-    dataPtr =
-        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
-    Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
-                                                       ArrayRef<int64_t>{3, 0});
-    size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+    auto [dataPtr, size] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
-    Value memRef = adaptor.getRef();
     Type elemType = op.getRef().getType().getElementType();
 
     // ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     // get MPI_COMM_WORLD, dataType, status_ignore and pointer
-    Value dataPtr =
-        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
-    Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
-    dataPtr =
-        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
-    Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
-                                                       ArrayRef<int64_t>{3, 0});
-    size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+    auto [dataPtr, size] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +604,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AllReduceOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    Type elemType = op.getSendbuf().getType().getElementType();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+    Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+    Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+    // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
+    //                    MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+              commWorld.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
+
+    // replace op with function call
+    auto funcCall = rewriter.create<LLVM::CallOp>(
+        loc, funcDecl,
+        ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
+
+    if (op.getRetval())
+      rewriter.replaceOp(op, funcCall.getResult());
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertToLLVMPatternInterface implementation
 //===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
 void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
   patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
-               SendOpLowering, RecvOpLowering>(converter);
+               SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
similarity index 78%
rename from mlir/test/Conversion/MPIToLLVM/ops.mlir
rename to mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 3c1b344efd50b..249ef195e8f5c 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
 
 // COM: Test MPICH ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 
   // CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
   func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
@@ -73,7 +73,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+    // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+    // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+    // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
+    // CHECK: llvm.call @MPI_Finalize() : () -> i32
     %3 = mpi.finalize : !mpi.retval
 
     return
@@ -83,7 +98,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
 // -----
 
 // COM: Test OpenMPI ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -91,7 +106,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
 // CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
 // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
+module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
 
   // CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
   func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
@@ -157,6 +172,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
+    // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+    // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+    // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
+    // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+    mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
     // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i...
[truncated]

@fschlimb fschlimb requested a review from jeffhammond March 26, 2025 17:55
// CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check for llvm.call here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Will add tomorrow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

static constexpr int MPI_MINLOC = 0x5800000b;
static constexpr int MPI_MAXLOC = 0x5800000c;
static constexpr int MPI_REPLACE = 0x5800000d;
static constexpr int MPI_NO_OP = 0x5800000e;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you think about reusing the same values as in the incoming MPI 5.0 ABI?

https://github.com/mpi-forum/mpi-abi-stubs/blob/e89a80017a3fe9a05d903ced2564c6342d678165/mpi.h#L47-L62

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are current MPI implementations supporting this yet?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, each implementation uses their own values. you can check some of them in here https://github.com/JuliaParallel/MPI.jl/tree/master/src/api

i guess that in the future they will change to the ones in MPI 5, that's why i suggested but it's not mandatory for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there is agreement that we'll go for explicit specialization for now until the ABI is broadly available.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mofeing As I said already, MPICH supports the MPI-5 ABI already (you must enable it with configure, as it is not the default). I will sync Mukautuva to the final MPI-5 ABI in a few days.
#133280 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's great to know!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great.
Once there are at least 2 popular implementations supporting it, we should switch to the ABI.

@fschlimb
Copy link
Contributor Author

fschlimb commented Mar 28, 2025

ping @AntonLydike @Dinistro

@fschlimb fschlimb merged commit 1dee125 into llvm:main Mar 31, 2025
10 of 11 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Mar 31, 2025

LLVM Buildbot has detected a new failure on builder clang-aarch64-sve2-vla running on linaro-g4-01 while building mlir at step 7 "ninja check 1".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/198/builds/3283

Here is the relevant piece of the build log for the reference
Step 7 (ninja check 1) failure: stage 1 checked (failure)
...
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/llvm/config.py:520: note: using ld64.lld: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/bin/ld64.lld
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/llvm/config.py:520: note: using wasm-ld: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/bin/wasm-ld
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/discovery.py:276: warning: input '/home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/runtimes/runtimes-bins/compiler-rt/test/interception/Unit' contained no tests
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/discovery.py:276: warning: input '/home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/runtimes/runtimes-bins/compiler-rt/test/sanitizer_common/Unit' contained no tests
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/llvm/config.py:520: note: using ld.lld: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/bin/ld.lld
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/llvm/config.py:520: note: using lld-link: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/bin/lld-link
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/llvm/config.py:520: note: using ld64.lld: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/bin/ld64.lld
llvm-lit: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/llvm/llvm/utils/lit/lit/llvm/config.py:520: note: using wasm-ld: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla/stage1/bin/wasm-ld
-- Testing: 96509 tests, 48 workers --
UNRESOLVED: Flang :: Driver/slp-vectorize.ll (1 of 96509)
******************** TEST 'Flang :: Driver/slp-vectorize.ll' FAILED ********************
Test has no 'RUN:' line
********************
PASS: ThreadSanitizer-aarch64 :: compare_exchange_acquire_fence.cpp (2 of 96509)
PASS: Clang :: Driver/clang_f_opts.c (3 of 96509)
PASS: UBSan-AddressSanitizer-aarch64 :: TestCases/ImplicitConversion/signed-integer-truncation-ignorelist.c (4 of 96509)
PASS: UBSan-ThreadSanitizer-aarch64 :: TestCases/ImplicitConversion/unsigned-integer-truncation-ignorelist.c (5 of 96509)
PASS: libFuzzer-aarch64-default-Linux :: sigusr.test (6 of 96509)
PASS: UBSan-ThreadSanitizer-aarch64 :: TestCases/ImplicitConversion/signed-integer-truncation-ignorelist.c (7 of 96509)
PASS: LLVM :: CodeGen/ARM/build-attributes.ll (8 of 96509)
PASS: Clang :: Driver/linux-ld.c (9 of 96509)
PASS: Clang :: OpenMP/target_teams_distribute_parallel_for_simd_codegen_registration.cpp (10 of 96509)
PASS: libFuzzer-aarch64-default-Linux :: msan.test (11 of 96509)
PASS: LLVM :: CodeGen/AMDGPU/sched-group-barrier-pipeline-solver.mir (12 of 96509)
PASS: MemorySanitizer-AARCH64 :: release_origin.c (13 of 96509)
PASS: libFuzzer-aarch64-default-Linux :: sigint.test (14 of 96509)
PASS: libFuzzer-aarch64-default-Linux :: merge-sigusr.test (15 of 96509)
PASS: LLVM :: CodeGen/Hexagon/isel/isel-tfrrp.ll (16 of 96509)
PASS: ScudoStandalone-Unit :: ./ScudoUnitTest-aarch64-Test/53/71 (17 of 96509)
PASS: ScudoStandalone-Unit-GwpAsanTorture :: ./ScudoUnitTest-aarch64-Test/53/71 (18 of 96509)
PASS: Clang :: Preprocessor/predefined-arch-macros.c (19 of 96509)
PASS: Clang :: Analysis/runtime-regression.c (20 of 96509)
PASS: ThreadSanitizer-aarch64 :: deadlock_detector_stress_test.cpp (21 of 96509)
PASS: HWAddressSanitizer-aarch64 :: TestCases/Linux/create-thread-stress.cpp (22 of 96509)
PASS: Clang :: CodeGen/AArch64/sve-intrinsics/acle_sve_reinterpret-bfloat.c (23 of 96509)
PASS: Clang :: Preprocessor/aarch64-target-features.c (24 of 96509)
PASS: Clang :: Analysis/a_flaky_crash.cpp (25 of 96509)
PASS: Clang :: Preprocessor/arm-target-features.c (26 of 96509)
PASS: libFuzzer-aarch64-default-Linux :: fuzzer-timeout.test (27 of 96509)
PASS: Profile-aarch64 :: Posix/instrprof-value-prof-shared.test (28 of 96509)
PASS: LLVM :: CodeGen/AMDGPU/memintrinsic-unroll.ll (29 of 96509)
PASS: Clang :: Driver/arm-cortex-cpus-1.c (30 of 96509)
PASS: LLVM :: CodeGen/RISCV/attributes.ll (31 of 96509)
PASS: Clang :: CodeGen/AArch64/sve-intrinsics/acle_sve_reinterpret.c (32 of 96509)
PASS: Clang :: Driver/arm-cortex-cpus-2.c (33 of 96509)
PASS: ThreadSanitizer-aarch64 :: force_background_thread.cpp (34 of 96509)
PASS: LLVM-Unit :: Support/./SupportTests/26/48 (35 of 96509)
PASS: Clang :: OpenMP/target_defaultmap_codegen_01.cpp (36 of 96509)
PASS: Clang :: OpenMP/target_update_codegen.cpp (37 of 96509)

@llvm-ci
Copy link
Collaborator

llvm-ci commented Mar 31, 2025

LLVM Buildbot has detected a new failure on builder clang-aarch64-sve-vla-2stage running on linaro-g3-01 while building mlir at step 12 "ninja check 2".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/41/builds/5853

Here is the relevant piece of the build log for the reference
Step 12 (ninja check 2) failure: stage 2 checked (failure)
...
PASS: Flang :: Driver/print-resource-dir.F90 (25172 of 97491)
PASS: Flang :: Driver/parse-error.ll (25173 of 97491)
PASS: Clangd Unit Tests :: ./ClangdTests/80/81 (25174 of 97491)
PASS: Flang :: Driver/predefined-macros-x86.f90 (25175 of 97491)
PASS: Flang :: Driver/override-triple.ll (25176 of 97491)
PASS: Flang :: Driver/implicit-none.f90 (25177 of 97491)
PASS: Flang :: Driver/parse-fir-error.ll (25178 of 97491)
PASS: Flang :: Driver/macro-def-undef.F90 (25179 of 97491)
PASS: Flang :: Driver/phases.f90 (25180 of 97491)
UNRESOLVED: Flang :: Driver/slp-vectorize.ll (25181 of 97491)
******************** TEST 'Flang :: Driver/slp-vectorize.ll' FAILED ********************
Test has no 'RUN:' line
********************
PASS: Flang :: Driver/bbc-openmp-version-macro.f90 (25182 of 97491)
PASS: Flang :: Driver/parse-ir-error.f95 (25183 of 97491)
PASS: Flang :: Driver/predefined-macros-compiler-version.F90 (25184 of 97491)
PASS: Flang :: Driver/include-header.f90 (25185 of 97491)
PASS: Flang :: Driver/mlir-pass-pipeline.f90 (25186 of 97491)
PASS: Clangd Unit Tests :: ./ClangdTests/71/81 (25187 of 97491)
PASS: Flang :: Driver/fd-lines-as.f90 (25188 of 97491)
PASS: Flang :: Driver/linker-flags.f90 (25189 of 97491)
PASS: Flang :: Driver/pthread.f90 (25190 of 97491)
PASS: DataFlowSanitizer-aarch64 :: origin_ldst.c (25191 of 97491)
PASS: Flang :: Driver/print-pipeline-passes.f90 (25192 of 97491)
PASS: Flang :: Driver/mlink-builtin-bc.f90 (25193 of 97491)
PASS: Flang :: Driver/missing-arg.f90 (25194 of 97491)
PASS: Flang :: Driver/pass-plugin-not-found.f90 (25195 of 97491)
PASS: Flang :: Driver/scanning-error.f95 (25196 of 97491)
PASS: Flang :: Driver/print-target-triple.f90 (25197 of 97491)
PASS: Flang :: Driver/supported-suffices/f08-suffix.f08 (25198 of 97491)
PASS: Flang :: Driver/pp-fixed-form.f90 (25199 of 97491)
PASS: Flang :: Driver/std2018-wrong.f90 (25200 of 97491)
PASS: Flang :: Driver/lto-bc.f90 (25201 of 97491)
PASS: Flang :: Driver/supported-suffices/f03-suffix.f03 (25202 of 97491)
PASS: Flang :: Driver/target-gpu-features.f90 (25203 of 97491)
PASS: Flang :: Driver/tco-code-gen-llvm.fir (25204 of 97491)
PASS: Flang :: Driver/target.f90 (25205 of 97491)
PASS: Flang :: Driver/q-unused-arguments.f90 (25206 of 97491)
PASS: Clangd Unit Tests :: ./ClangdTests/68/81 (25207 of 97491)
PASS: Flang :: Driver/unsupported-vscale-max-min.f90 (25208 of 97491)
PASS: Flang :: Driver/mllvm.f90 (25209 of 97491)
PASS: Flang :: Driver/config-file.f90 (25210 of 97491)
PASS: Flang :: Driver/unparse-with-modules.f90 (25211 of 97491)
PASS: Flang :: Driver/multiple-input-files.f90 (25212 of 97491)
PASS: Flang :: Driver/no-duplicate-main.f90 (25213 of 97491)
PASS: Flang :: Driver/target-machine-error.f90 (25214 of 97491)
PASS: Flang :: Driver/input-from-stdin/input-from-stdin.f90 (25215 of 97491)
PASS: Flang :: Driver/unparse-use-analyzed.f95 (25216 of 97491)
PASS: Flang :: Driver/prescanner-diag.f90 (25217 of 97491)

@llvm-ci
Copy link
Collaborator

llvm-ci commented Mar 31, 2025

LLVM Buildbot has detected a new failure on builder clang-aarch64-sve2-vla-2stage running on linaro-g4-02 while building mlir at step 12 "ninja check 2".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/199/builds/2489

Here is the relevant piece of the build log, for reference:
Step 12 (ninja check 2) failure: stage 2 checked (failure)
...
PASS: Flang :: Driver/macro-def-undef.F90 (24741 of 96446)
PASS: Flang :: Driver/include-header.f90 (24742 of 96446)
PASS: Flang :: Driver/print-effective-triple.f90 (24743 of 96446)
PASS: Flang :: Driver/predefined-macros-compiler-version.F90 (24744 of 96446)
PASS: Flang :: Driver/print-resource-dir.F90 (24745 of 96446)
PASS: Flang :: Driver/mlink-builtin-bc.f90 (24746 of 96446)
PASS: Flang :: Driver/parse-fir-error.ll (24747 of 96446)
PASS: Flang :: Driver/phases.f90 (24748 of 96446)
PASS: Flang :: Driver/missing-arg.f90 (24749 of 96446)
UNRESOLVED: Flang :: Driver/slp-vectorize.ll (24750 of 96446)
******************** TEST 'Flang :: Driver/slp-vectorize.ll' FAILED ********************
Test has no 'RUN:' line
********************
PASS: Flang :: Driver/parse-error.ll (24751 of 96446)
PASS: Flang :: Driver/override-triple.ll (24752 of 96446)
PASS: Flang :: Driver/print-pipeline-passes.f90 (24753 of 96446)
PASS: Flang :: Driver/mlir-pass-pipeline.f90 (24754 of 96446)
PASS: Flang :: Driver/parse-ir-error.f95 (24755 of 96446)
PASS: Flang :: Driver/print-target-triple.f90 (24756 of 96446)
PASS: Flang :: Driver/lto-bc.f90 (24757 of 96446)
PASS: Flang :: Driver/pthread.f90 (24758 of 96446)
PASS: Flang :: Driver/pp-fixed-form.f90 (24759 of 96446)
PASS: Flang :: Driver/scanning-error.f95 (24760 of 96446)
PASS: Flang :: Driver/std2018-wrong.f90 (24761 of 96446)
PASS: Flang :: Driver/supported-suffices/f08-suffix.f08 (24762 of 96446)
PASS: Flang :: Driver/supported-suffices/f03-suffix.f03 (24763 of 96446)
PASS: Flang :: Driver/pass-plugin-not-found.f90 (24764 of 96446)
PASS: Flang :: Driver/lto-flags.f90 (24765 of 96446)
PASS: Flang :: Driver/target-gpu-features.f90 (24766 of 96446)
PASS: Flang :: Driver/tco-code-gen-llvm.fir (24767 of 96446)
PASS: Flang :: Driver/fixed-line-length.f90 (24768 of 96446)
PASS: Flang :: Driver/q-unused-arguments.f90 (24769 of 96446)
PASS: Flang :: Driver/target.f90 (24770 of 96446)
PASS: Flang :: Driver/multiple-input-files.f90 (24771 of 96446)
PASS: Flang :: Driver/mllvm.f90 (24772 of 96446)
PASS: Flang :: Driver/unsupported-vscale-max-min.f90 (24773 of 96446)
PASS: Flang :: Driver/unparse-with-modules.f90 (24774 of 96446)
PASS: Flang :: Driver/fsave-optimization-record.f90 (24775 of 96446)
PASS: Flang :: Driver/no-duplicate-main.f90 (24776 of 96446)
PASS: Flang :: Driver/input-from-stdin/input-from-stdin.f90 (24777 of 96446)
PASS: Flang :: Driver/target-machine-error.f90 (24778 of 96446)
PASS: Flang :: Driver/std2018.f90 (24779 of 96446)
PASS: Flang :: Driver/unparse-use-analyzed.f95 (24780 of 96446)
PASS: Flang :: Driver/optimization-remark-invalid.f90 (24781 of 96446)
PASS: Flang :: Driver/save-temps-use-module.f90 (24782 of 96446)
PASS: Flang :: Driver/response-file.f90 (24783 of 96446)
PASS: Flang :: Driver/prescanner-diag.f90 (24784 of 96446)
PASS: Flang :: Driver/save-temps.f90 (24785 of 96446)
PASS: Flang :: Driver/target-cpu-features-invalid.f90 (24786 of 96446)

SchrodingerZhu pushed a commit to SchrodingerZhu/llvm-project that referenced this pull request Mar 31, 2025
Lowering of mpi.all_reduce to LLVM function call
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants