[flang][cuda] Fix pred type for vote functions #134166

Merged
1 commit merged into llvm:main on Apr 3, 2025


clementval
Contributor

No description provided.

@clementval requested a review from wangzpgi April 2, 2025 22:43
@llvmbot added the flang and flang:fir-hlfir labels Apr 2, 2025
@llvmbot
Member

llvmbot commented Apr 2, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/134166.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+9-5)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+5-4)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8aed288d128b6..4988b6bfb3d3f 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6508,12 +6508,13 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
 }
 
 static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
-                               llvm::StringRef funcName,
+                               llvm::StringRef funcName, mlir::Type resTy,
                                llvm::ArrayRef<mlir::Value> args) {
   mlir::MLIRContext *context = builder.getContext();
   mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type i1Ty = builder.getI1Type();
   mlir::FunctionType ftype =
-      mlir::FunctionType::get(context, {i32Ty, i32Ty}, {i32Ty});
+      mlir::FunctionType::get(context, {i32Ty, i1Ty}, {resTy});
   auto funcOp = builder.createFunction(loc, funcName, ftype);
   llvm::SmallVector<mlir::Value> filteredArgs;
   return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
@@ -6523,14 +6524,16 @@ static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
 mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
                                              llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
-  return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync", args);
+  return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync",
+                     builder.getI1Type(), args);
 }
 
 // ANY_SYNC
 mlir::Value IntrinsicLibrary::genVoteAnySync(mlir::Type resultType,
                                              llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
-  return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync", args);
+  return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync",
+                     builder.getI1Type(), args);
 }
 
 // BALLOT_SYNC
@@ -6538,7 +6541,8 @@ mlir::Value
 IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
                                     llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
-  return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync", args);
+  return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync",
+                     builder.getI32Type(), args);
 }
 
 // MATCH_ANY_SYNC
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 6a7fee73f338a..a4a4750dd61e6 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -297,10 +297,11 @@ end
 ! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
 
 attributes(device) subroutine testVote()
-  integer :: a, ipred, mask, v32
-  a = all_sync(mask, v32)
-  a = any_sync(mask, v32)
-  a = ballot_sync(mask, v32)
+  integer :: a, ipred, mask
+  logical(4) :: pred
+  a = all_sync(mask, pred)
+  a = any_sync(mask, pred)
+  a = ballot_sync(mask, pred)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtestvote()
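
Note on the types in this diff: the signatures now passed to `genVoteSync` are `(i32, i1) -> i1` for `vote.all.sync` and `vote.any.sync`, and `(i32, i1) -> i32` for `vote.ballot.sync`. The host-side C++ sketch below is only an illustrative model of why those result types differ; it is not flang, MLIR, or NVVM code. The names `WarpPreds`, `laneActive`, `voteAll`, `voteAny`, and `voteBallot` are hypothetical, and a warp size of 32 is assumed.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Assumption: 32 lanes per warp, one predicate per lane.
constexpr unsigned kWarpSize = 32;
using WarpPreds = std::array<bool, kWarpSize>;

// Hypothetical helper: true when `lane` is selected by the member mask.
inline bool laneActive(std::uint32_t mask, unsigned lane) {
  return ((mask >> lane) & 1u) != 0;
}

// Models all_sync: a single boolean (i1), true when every
// participating lane has a true predicate.
inline bool voteAll(std::uint32_t mask, const WarpPreds &pred) {
  for (unsigned lane = 0; lane < kWarpSize; ++lane)
    if (laneActive(mask, lane) && !pred[lane])
      return false;
  return true;
}

// Models any_sync: a single boolean (i1), true when at least one
// participating lane has a true predicate.
inline bool voteAny(std::uint32_t mask, const WarpPreds &pred) {
  for (unsigned lane = 0; lane < kWarpSize; ++lane)
    if (laneActive(mask, lane) && pred[lane])
      return true;
  return false;
}

// Models ballot_sync: a 32-bit word (i32) with bit N set when
// participating lane N has a true predicate.
inline std::uint32_t voteBallot(std::uint32_t mask, const WarpPreds &pred) {
  std::uint32_t result = 0;
  for (unsigned lane = 0; lane < kWarpSize; ++lane)
    if (laneActive(mask, lane) && pred[lane])
      result |= (std::uint32_t{1} << lane);
  return result;
}
```

In this model the predicate is a per-lane boolean in all three cases, which corresponds to the `i1` second argument; only the ballot needs a full 32-bit result because it reports one bit per lane.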

@clementval merged commit 3e59ff2 into llvm:main Apr 3, 2025
14 checks passed
@clementval deleted the cuf_fix_vote branch April 3, 2025 17:33