Skip to content

Commit 0159a26

Browse files
authored
[InlineCost]: Add a new heuristic to branch folding for better inlining decisions.
Recursive functions are generally not inlined to avoid issues like infinite inlining or excessive code expansion. However, this conservative approach misses opportunities for optimization in cases where a recursive call is guaranteed to execute only once. This patch detects a scenario where a guarding branch condition of a recursive call will become false after the first iteration of the recursive function. If such a condition is met, and the recursion depth is confirmed to be one, the Inliner will now consider this recursive function for inlining. A new test case (`test/Transforms/Inline/inline-recursive-fn.ll`) has been added to verify this behaviour.
1 parent 8ea5eac commit 0159a26

File tree

2 files changed

+274
-0
lines changed

2 files changed

+274
-0
lines changed

llvm/lib/Analysis/InlineCost.cpp

+81
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/Analysis/BlockFrequencyInfo.h"
2121
#include "llvm/Analysis/CodeMetrics.h"
2222
#include "llvm/Analysis/ConstantFolding.h"
23+
#include "llvm/Analysis/DomConditionCache.h"
2324
#include "llvm/Analysis/EphemeralValuesCache.h"
2425
#include "llvm/Analysis/InstructionSimplify.h"
2526
#include "llvm/Analysis/LoopInfo.h"
@@ -262,6 +263,8 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
262263
// Cache the DataLayout since we use it a lot.
263264
const DataLayout &DL;
264265

266+
DominatorTree DT;
267+
265268
/// The OptimizationRemarkEmitter available for this compilation.
266269
OptimizationRemarkEmitter *ORE;
267270

@@ -444,6 +447,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
444447
bool canFoldInboundsGEP(GetElementPtrInst &I);
445448
bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset);
446449
bool simplifyCallSite(Function *F, CallBase &Call);
450+
bool simplifyCmpInstForRecCall(CmpInst &Cmp);
447451
bool simplifyInstruction(Instruction &I);
448452
bool simplifyIntrinsicCallIsConstant(CallBase &CB);
449453
bool simplifyIntrinsicCallObjectSize(CallBase &CB);
@@ -1676,6 +1680,79 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) {
16761680
return isGEPFree(I);
16771681
}
16781682

1683+
// Simplify \p Cmp if RHS is const and we can ValueTrack LHS.
1684+
// This handles the case only when the Cmp instruction is guarding a recursive
1685+
// call that will cause the Cmp to fail/succeed for the recursive call.
1686+
bool CallAnalyzer::simplifyCmpInstForRecCall(CmpInst &Cmp) {
1687+
// Bail out if LHS is not a function argument or RHS is NOT const:
1688+
if (!isa<Argument>(Cmp.getOperand(0)) || !isa<Constant>(Cmp.getOperand(1)))
1689+
return false;
1690+
auto *CmpOp = Cmp.getOperand(0);
1691+
Function *F = Cmp.getFunction();
1692+
// Iterate over the users of the function to check if it's a recursive
1693+
// function:
1694+
for (auto *U : F->users()) {
1695+
CallInst *Call = dyn_cast<CallInst>(U);
1696+
if (!Call || Call->getFunction() != F || Call->getCalledFunction() != F)
1697+
continue;
1698+
auto *CallBB = Call->getParent();
1699+
auto *Predecessor = CallBB->getSinglePredecessor();
1700+
// Only handle the case when the callsite has a single predecessor:
1701+
if (!Predecessor)
1702+
continue;
1703+
1704+
auto *Br = dyn_cast<BranchInst>(Predecessor->getTerminator());
1705+
if (!Br || Br->isUnconditional())
1706+
continue;
1707+
// Check if the Br condition is the same Cmp instr we are investigating:
1708+
if (Br->getCondition() != &Cmp)
1709+
continue;
1710+
// Check if there are any arg of the recursive callsite is affecting the cmp
1711+
// instr:
1712+
bool ArgFound = false;
1713+
Value *FuncArg = nullptr, *CallArg = nullptr;
1714+
for (unsigned ArgNum = 0;
1715+
ArgNum < F->arg_size() && ArgNum < Call->arg_size(); ArgNum++) {
1716+
FuncArg = F->getArg(ArgNum);
1717+
CallArg = Call->getArgOperand(ArgNum);
1718+
if (FuncArg == CmpOp && CallArg != CmpOp) {
1719+
ArgFound = true;
1720+
break;
1721+
}
1722+
}
1723+
if (!ArgFound)
1724+
continue;
1725+
// Now we have a recursive call that is guarded by a cmp instruction.
1726+
// Check if this cmp can be simplified:
1727+
SimplifyQuery SQ(DL, dyn_cast<Instruction>(CallArg));
1728+
DomConditionCache DC;
1729+
DC.registerBranch(Br);
1730+
SQ.DC = &DC;
1731+
if (DT.root_size() == 0) {
1732+
// Dominator tree was never constructed for any function yet.
1733+
DT.recalculate(*F);
1734+
} else if (DT.getRoot()->getParent() != F) {
1735+
// Dominator tree was constructed for a different function, recalculate
1736+
// it for the current function.
1737+
DT.recalculate(*F);
1738+
}
1739+
SQ.DT = &DT;
1740+
Value *SimplifiedInstruction = llvm::simplifyInstructionWithOperands(
1741+
cast<CmpInst>(&Cmp), {CallArg, Cmp.getOperand(1)}, SQ);
1742+
if (auto *ConstVal = dyn_cast_or_null<ConstantInt>(SimplifiedInstruction)) {
1743+
bool IsTrueSuccessor = CallBB == Br->getSuccessor(0);
1744+
// Make sure that the BB of the recursive call is NOT the next successor
1745+
// of the icmp. In other words, make sure that the recursion depth is 1.
1746+
if ((ConstVal->isOne() && !IsTrueSuccessor) ||
1747+
(ConstVal->isZero() && IsTrueSuccessor)) {
1748+
SimplifiedValues[&Cmp] = ConstVal;
1749+
return true;
1750+
}
1751+
}
1752+
}
1753+
return false;
1754+
}
1755+
16791756
/// Simplify \p I if its operands are constants and update SimplifiedValues.
16801757
bool CallAnalyzer::simplifyInstruction(Instruction &I) {
16811758
SmallVector<Constant *> COps;
@@ -2060,6 +2137,10 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
20602137
if (simplifyInstruction(I))
20612138
return true;
20622139

2140+
// Try to handle comparison that can be simplified using ValueTracking.
2141+
if (simplifyCmpInstForRecCall(I))
2142+
return true;
2143+
20632144
if (I.getOpcode() == Instruction::FCmp)
20642145
return false;
20652146

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -passes='inline,instcombine' < %s | FileCheck %s
3+
4+
define float @inline_rec_true_successor(float %x, float %scale) {
5+
; CHECK-LABEL: define float @inline_rec_true_successor(
6+
; CHECK-SAME: float [[X:%.*]], float [[SCALE:%.*]]) {
7+
; CHECK-NEXT: [[ENTRY:.*:]]
8+
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[X]], 0.000000e+00
9+
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
10+
; CHECK: [[COMMON_RET18:.*]]:
11+
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_TRUE_SUCCESSOR_EXIT:.*]] ], [ [[MUL:%.*]], %[[IF_END]] ]
12+
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
13+
; CHECK: [[IF_THEN]]:
14+
; CHECK-NEXT: br i1 false, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
15+
; CHECK: [[IF_THEN_I]]:
16+
; CHECK-NEXT: br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
17+
; CHECK: [[IF_END_I]]:
18+
; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[X]]
19+
; CHECK-NEXT: [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
20+
; CHECK-NEXT: br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
21+
; CHECK: [[INLINE_REC_TRUE_SUCCESSOR_EXIT]]:
22+
; CHECK-NEXT: [[COMMON_RET18_OP_I]] = phi float [ poison, %[[IF_THEN_I]] ], [ [[MUL_I]], %[[IF_END_I]] ]
23+
; CHECK-NEXT: br label %[[COMMON_RET18]]
24+
; CHECK: [[IF_END]]:
25+
; CHECK-NEXT: [[MUL]] = fmul float [[X]], [[SCALE]]
26+
; CHECK-NEXT: br label %[[COMMON_RET18]]
27+
;
28+
entry:
29+
%cmp = fcmp olt float %x, 0.000000e+00
30+
br i1 %cmp, label %if.then, label %if.end
31+
32+
common.ret18: ; preds = %if.then, %if.end
33+
%common.ret18.op = phi float [ %call, %if.then ], [ %mul, %if.end ]
34+
ret float %common.ret18.op
35+
36+
if.then: ; preds = %entry
37+
%fneg = fneg float %x
38+
%call = tail call float @inline_rec_true_successor(float %fneg, float %scale)
39+
br label %common.ret18
40+
41+
if.end: ; preds = %entry
42+
%mul = fmul float %x, %scale
43+
br label %common.ret18
44+
}
45+
46+
; Same as previous test except that the recursive callsite is in the false successor
47+
define float @inline_rec_false_successor(float %x, float %scale) {
48+
; CHECK-LABEL: define float @inline_rec_false_successor(
49+
; CHECK-SAME: float [[Y:%.*]], float [[SCALE:%.*]]) {
50+
; CHECK-NEXT: [[ENTRY:.*:]]
51+
; CHECK-NEXT: [[CMP:%.*]] = fcmp uge float [[Y]], 0.000000e+00
52+
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
53+
; CHECK: [[COMMON_RET18:.*]]:
54+
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_FALSE_SUCCESSOR_EXIT:.*]] ]
55+
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
56+
; CHECK: [[IF_THEN]]:
57+
; CHECK-NEXT: [[MUL]] = fmul float [[Y]], [[SCALE]]
58+
; CHECK-NEXT: br label %[[COMMON_RET18]]
59+
; CHECK: [[IF_END]]:
60+
; CHECK-NEXT: br i1 true, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
61+
; CHECK: [[IF_THEN_I]]:
62+
; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[Y]]
63+
; CHECK-NEXT: [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
64+
; CHECK-NEXT: br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
65+
; CHECK: [[IF_END_I]]:
66+
; CHECK-NEXT: br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
67+
; CHECK: [[INLINE_REC_FALSE_SUCCESSOR_EXIT]]:
68+
; CHECK-NEXT: [[COMMON_RET18_OP_I]] = phi float [ [[MUL_I]], %[[IF_THEN_I]] ], [ poison, %[[IF_END_I]] ]
69+
; CHECK-NEXT: br label %[[COMMON_RET18]]
70+
;
71+
entry:
72+
%cmp = fcmp uge float %x, 0.000000e+00
73+
br i1 %cmp, label %if.then, label %if.end
74+
75+
common.ret18: ; preds = %if.then, %if.end
76+
%common.ret18.op = phi float [ %mul, %if.then ], [ %call, %if.end ]
77+
ret float %common.ret18.op
78+
79+
if.then: ; preds = %entry
80+
%mul = fmul float %x, %scale
81+
br label %common.ret18
82+
83+
if.end: ; preds = %entry
84+
%fneg = fneg float %x
85+
%call = tail call float @inline_rec_false_successor(float %fneg, float %scale)
86+
br label %common.ret18
87+
}
88+
89+
; Test the case where the branch condition is a plain value (a function argument) rather than a cmp instruction
90+
define float @inline_rec_no_cmp(i1 %flag, float %scale) {
91+
; CHECK-LABEL: define float @inline_rec_no_cmp(
92+
; CHECK-SAME: i1 [[FLAG:%.*]], float [[SCALE:%.*]]) {
93+
; CHECK-NEXT: [[ENTRY:.*:]]
94+
; CHECK-NEXT: br i1 [[FLAG]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
95+
; CHECK: [[IF_THEN]]:
96+
; CHECK-NEXT: [[SUM:%.*]] = fadd float [[SCALE]], 5.000000e+00
97+
; CHECK-NEXT: [[SUM1:%.*]] = fadd float [[SUM]], [[SCALE]]
98+
; CHECK-NEXT: br label %[[COMMON_RET:.*]]
99+
; CHECK: [[IF_END]]:
100+
; CHECK-NEXT: [[SUM2:%.*]] = fadd float [[SCALE]], 5.000000e+00
101+
; CHECK-NEXT: br label %[[COMMON_RET]]
102+
; CHECK: [[COMMON_RET]]:
103+
; CHECK-NEXT: [[COMMON_RET_RES:%.*]] = phi float [ [[SUM1]], %[[IF_THEN]] ], [ [[SUM2]], %[[IF_END]] ]
104+
; CHECK-NEXT: ret float [[COMMON_RET_RES]]
105+
;
106+
entry:
107+
br i1 %flag, label %if.then, label %if.end
108+
if.then:
109+
%res = tail call float @inline_rec_no_cmp(i1 false, float %scale)
110+
%sum1 = fadd float %res, %scale
111+
br label %common.ret
112+
if.end:
113+
%sum2 = fadd float %scale, 5.000000e+00
114+
br label %common.ret
115+
common.ret:
116+
%common.ret.res = phi float [ %sum1, %if.then ], [ %sum2, %if.end ]
117+
ret float %common.ret.res
118+
}
119+
120+
define float @no_inline_rec(float %x, float %scale) {
121+
; CHECK-LABEL: define float @no_inline_rec(
122+
; CHECK-SAME: float [[Z:%.*]], float [[SCALE:%.*]]) {
123+
; CHECK-NEXT: [[ENTRY:.*:]]
124+
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[Z]], 5.000000e+00
125+
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
126+
; CHECK: [[COMMON_RET18:.*]]:
127+
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[FNEG1:%.*]], %[[IF_THEN]] ], [ [[MUL:%.*]], %[[IF_END]] ]
128+
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
129+
; CHECK: [[IF_THEN]]:
130+
; CHECK-NEXT: [[FADD:%.*]] = fadd float [[Z]], 5.000000e+00
131+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @no_inline_rec(float [[FADD]], float [[SCALE]])
132+
; CHECK-NEXT: [[FNEG1]] = fneg float [[CALL]]
133+
; CHECK-NEXT: br label %[[COMMON_RET18]]
134+
; CHECK: [[IF_END]]:
135+
; CHECK-NEXT: [[MUL]] = fmul float [[Z]], [[SCALE]]
136+
; CHECK-NEXT: br label %[[COMMON_RET18]]
137+
;
138+
entry:
139+
%cmp = fcmp olt float %x, 5.000000e+00
140+
br i1 %cmp, label %if.then, label %if.end
141+
142+
common.ret18: ; preds = %if.then, %if.end
143+
%common.ret18.op = phi float [ %fneg1, %if.then ], [ %mul, %if.end ]
144+
ret float %common.ret18.op
145+
146+
if.then: ; preds = %entry
147+
%fadd = fadd float %x, 5.000000e+00
148+
%call = tail call float @no_inline_rec(float %fadd, float %scale)
149+
%fneg1 = fneg float %call
150+
br label %common.ret18
151+
152+
if.end: ; preds = %entry
153+
%mul = fmul float %x, %scale
154+
br label %common.ret18
155+
}
156+
157+
; Test when the fcmp can be simplified but the recursion depth is NOT 1,
158+
; so the recursive call will not be inlined.
159+
define float @no_inline_rec_depth_not_1(float %x, float %scale) {
160+
; CHECK-LABEL: define float @no_inline_rec_depth_not_1(
161+
; CHECK-SAME: float [[X:%.*]], float [[SCALE:%.*]]) {
162+
; CHECK-NEXT: [[ENTRY:.*:]]
163+
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[X]], 0.000000e+00
164+
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
165+
; CHECK: [[COMMON_RET18:.*]]:
166+
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[CALL:%.*]], %[[IF_THEN]] ], [ [[MUL:%.*]], %[[IF_END]] ]
167+
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
168+
; CHECK: [[IF_THEN]]:
169+
; CHECK-NEXT: [[CALL]] = tail call float @no_inline_rec_depth_not_1(float [[X]], float [[SCALE]])
170+
; CHECK-NEXT: br label %[[COMMON_RET18]]
171+
; CHECK: [[IF_END]]:
172+
; CHECK-NEXT: [[MUL]] = fmul float [[X]], [[SCALE]]
173+
; CHECK-NEXT: br label %[[COMMON_RET18]]
174+
;
175+
entry:
176+
%cmp = fcmp olt float %x, 0.000000e+00
177+
br i1 %cmp, label %if.then, label %if.end
178+
179+
common.ret18: ; preds = %if.then, %if.end
180+
%common.ret18.op = phi float [ %call, %if.then ], [ %mul, %if.end ]
181+
ret float %common.ret18.op
182+
183+
if.then: ; preds = %entry
184+
%fneg1 = fneg float %x
185+
%fneg = fneg float %fneg1
186+
%call = tail call float @no_inline_rec_depth_not_1(float %fneg, float %scale)
187+
br label %common.ret18
188+
189+
if.end: ; preds = %entry
190+
%mul = fmul float %x, %scale
191+
br label %common.ret18
192+
}
193+

0 commit comments

Comments
 (0)