Skip to content

Commit ae8dd63

Browse files
authored
[flang][cuda] Add interface and lowering for all_sync (#134001)
1 parent e25187b commit ae8dd63

File tree

4 files changed

+36
-0
lines changed

4 files changed

+36
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

+1
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ struct IntrinsicLibrary {
441441
fir::ExtendedValue genUbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
442442
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
443443
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
444+
/// Lowering entry point registered for the CUDA Fortran `all_sync` intrinsic.
mlir::Value genVoteAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
444445

445446
/// Implement all conversion functions like DBLE, the first argument is
446447
/// the value to convert. There may be an additional KIND arguments that

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ static constexpr IntrinsicHandler handlers[]{
260260
&I::genAll,
261261
{{{"mask", asAddr}, {"dim", asValue}}},
262262
/*isElemental=*/false},
263+
{"all_sync",
264+
&I::genVoteAllSync,
265+
{{{"mask", asValue}, {"pred", asValue}}},
266+
/*isElemental=*/false},
263267
{"allocated",
264268
&I::genAllocated,
265269
{{{"array", asInquired}, {"scalar", asInquired}}},
@@ -6495,6 +6499,21 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
64956499
return value;
64966500
}
64976501

6502+
// ALL_SYNC
6503+
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
6504+
llvm::ArrayRef<mlir::Value> args) {
6505+
assert(args.size() == 2);
6506+
6507+
llvm::StringRef funcName = "llvm.nvvm.vote.all.sync";
6508+
mlir::MLIRContext *context = builder.getContext();
6509+
mlir::Type i32Ty = builder.getI32Type();
6510+
mlir::FunctionType ftype =
6511+
mlir::FunctionType::get(context, {i32Ty, i32Ty}, {i32Ty});
6512+
auto funcOp = builder.createFunction(loc, funcName, ftype);
6513+
llvm::SmallVector<mlir::Value> filteredArgs;
6514+
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
6515+
}
6516+
64986517
// MATCH_ANY_SYNC
64996518
mlir::Value
65006519
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,

flang/module/cudadevice.f90

+7
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,13 @@ attributes(device) integer function match_any_syncjd(mask, val)
10151015
end function
10161016
end interface
10171017

1018+
! Generic interface for the all_sync device function: two integer arguments
! passed by value (mask, pred) and a default-integer result.
interface all_sync
  attributes(device) integer function all_sync(mask, pred)
    ! NOTE(review): ignore_tkr presumably relaxes argument matching for the
    ! listed dummies (d on mask, t and d on pred) - confirm against the flang
    ! directive documentation.
    !dir$ ignore_tkr(d) mask, (td) pred
    integer, value :: mask
    integer, value :: pred
  end function all_sync
end interface all_sync
1024+
10181025
! LDCG
10191026
interface __ldcg
10201027
attributes(device) pure integer(4) function __ldcg_i4(x) bind(c)

flang/test/Lower/CUDA/cuda-device-proc.cuf

+9
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,15 @@ end
296296
! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
297297
! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
298298

299+
attributes(device) subroutine testVote()
  ! Lowering test driver: performs a single all_sync call so that the CHECK
  ! lines following this subroutine can match the generated fir.call.
  ! mask and v32 are never assigned; presumably only the emitted IR matters
  ! here, not the runtime values.
  integer :: a, mask, v32
  a = all_sync(mask, v32)
end subroutine
304+
305+
! CHECK-LABEL: func.func @_QPtestvote()
306+
! CHECK: fir.call @llvm.nvvm.vote.all.sync
307+
299308

300309
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
301310
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)

0 commit comments

Comments
 (0)