Skip to content

Commit 9cd03e7

Browse files
committed
[CIR][ThroughMLIR] Fix ForOp handling
Currently the ForOp handling ignores everything except load, store and arithmetic operations. It also does not detect whether the step and the induction variable have actually been assigned. That might result in wrong behaviour:

```cpp
// Example 1: ignores printf
for (int i = 0; i < n; i++, printf("\n"));

// Example 2: only increments once
for (int i = 0; i < n; i++, i++);
```

I chose to rewrite the detection to do an exact match of the instruction sequence. Though this recognizes fewer loops, it preserves soundness.
1 parent 5df5009 commit 9cd03e7

File tree

2 files changed

+107
-43
lines changed

2 files changed

+107
-43
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 97 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,22 @@ class SCFLoop {
4343
mlir::Value getLowerBound() { return lowerBound; }
4444
mlir::Value getUpperBound() { return upperBound; }
4545

46-
int64_t findStepAndIV(mlir::Value &addr);
46+
// Returns true if successfully finds both step and induction variable.
47+
bool findStepAndIV();
4748
cir::CmpOp findCmpOp();
4849
mlir::Value findIVInitValue();
4950
void analysis();
5051

51-
mlir::Value plusConstant(mlir::Value V, mlir::Location loc, int addend);
52+
mlir::Value plusConstant(mlir::Value v, mlir::Location loc, int addend);
5253
void transferToSCFForOp();
5354

5455
private:
5556
cir::ForOp forOp;
5657
cir::CmpOp cmpOp;
57-
mlir::Value IVAddr, lowerBound = nullptr, upperBound = nullptr;
58+
mlir::Value ivAddr, lowerBound = nullptr, upperBound = nullptr;
5859
mlir::ConversionPatternRewriter *rewriter;
5960
int64_t step = 0;
61+
bool canonical = true;
6062
};
6163

6264
class SCFWhileLoop {
@@ -86,47 +88,96 @@ class SCFDoLoop {
8688
};
8789

8890
static int64_t getConstant(cir::ConstantOp op) {
89-
auto attr = op->getAttrs().front().getValue();
90-
const auto IntAttr = mlir::dyn_cast<cir::IntAttr>(attr);
91-
return IntAttr.getValue().getSExtValue();
91+
auto attr = op.getValue();
92+
const auto intAttr = mlir::cast<cir::IntAttr>(attr);
93+
return intAttr.getValue().getSExtValue();
9294
}
9395

94-
int64_t SCFLoop::findStepAndIV(mlir::Value &addr) {
96+
bool SCFLoop::findStepAndIV() {
9597
auto *stepBlock =
9698
(forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr);
9799
assert(stepBlock && "Can not find step block");
98100

99-
int64_t step = 0;
100-
mlir::Value IV = nullptr;
101-
// Try to match "IV load addr; ++IV; store IV, addr" to find step.
102-
for (mlir::Operation &op : *stepBlock)
103-
if (auto loadOp = dyn_cast<cir::LoadOp>(op)) {
104-
addr = loadOp.getAddr();
105-
IV = loadOp.getResult();
106-
} else if (auto cop = dyn_cast<cir::ConstantOp>(op)) {
107-
if (step)
108-
llvm_unreachable(
109-
"Not support multiple constant in step calculation yet");
110-
step = getConstant(cop);
111-
} else if (auto bop = dyn_cast<cir::BinOp>(op)) {
112-
if (bop.getLhs() != IV)
113-
llvm_unreachable("Find BinOp not operate on IV");
114-
if (bop.getKind() != cir::BinOpKind::Add)
115-
llvm_unreachable(
116-
"Not support BinOp other than Add in step calculation yet");
117-
} else if (auto uop = dyn_cast<cir::UnaryOp>(op)) {
118-
if (uop.getInput() != IV)
119-
llvm_unreachable("Find UnaryOp not operate on IV");
120-
if (uop.getKind() == cir::UnaryOpKind::Inc)
121-
step = 1;
122-
else if (uop.getKind() == cir::UnaryOpKind::Dec)
123-
llvm_unreachable("Not support decrement step yet");
124-
} else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
125-
assert(storeOp.getAddr() == addr && "Can't find IV when lowering ForOp");
126-
}
127-
assert(step && "Can't find step when lowering ForOp");
128101

129-
return step;
102+
// Try to match "iv = load addr; ++iv; store iv, addr; yield" to find step.
103+
// We should match the exact pattern, in case there's something unexpected:
104+
// we must rule out cases like `for (int i = 0; i < n; i++, printf("\n"))`.
105+
auto &oplist = stepBlock->getOperations();
106+
107+
auto iterator = oplist.begin();
108+
109+
// We might find constants at beginning. Skip them.
110+
// We could have hoisted them outside the for loop in previous passes, but
111+
// it hasn't been done yet.
112+
while (iterator != oplist.end() && isa<ConstantOp>(*iterator))
113+
++iterator;
114+
115+
if (iterator == oplist.end())
116+
return false;
117+
118+
auto load = dyn_cast<LoadOp>(*iterator);
119+
if (!load)
120+
return false;
121+
122+
// We assume this is the address of induction variable (IV). The operations that come
123+
// next will check if that's true.
124+
mlir::Value addr = load.getAddr();
125+
mlir::Value iv = load.getResult();
126+
127+
// Then we try to match either "++IV" or "IV += n". Same for reversed loops.
128+
if (++iterator == oplist.end())
129+
return false;
130+
131+
mlir::Operation &arith = *iterator;
132+
133+
if (auto unary = dyn_cast<UnaryOp>(arith)) {
134+
// Not operating on induction variable. Fail.
135+
if (unary.getInput() != iv)
136+
return false;
137+
138+
if (unary.getKind() == UnaryOpKind::Inc)
139+
step = 1;
140+
else if (unary.getKind() == UnaryOpKind::Dec)
141+
step = -1;
142+
else
143+
return false;
144+
}
145+
146+
if (auto binary = dyn_cast<BinOp>(arith)) {
147+
if (binary.getLhs() != iv)
148+
return false;
149+
150+
mlir::Value value = binary.getRhs();
151+
if (auto constValue = dyn_cast<ConstantOp>(value.getDefiningOp()); isa<IntAttr>(constValue.getValue()))
152+
step = getConstant(constValue);
153+
154+
if (binary.getKind() == BinOpKind::Add)
155+
; // Nothing to do. Step has been calculated above.
156+
else if (binary.getKind() == BinOpKind::Sub)
157+
step = -step;
158+
else
159+
return false;
160+
}
161+
162+
// Check whether we immediately store this value into the appropriate place.
163+
if (++iterator == oplist.end())
164+
return false;
165+
166+
auto store = dyn_cast<StoreOp>(*iterator);
167+
if (!store || store.getAddr() != addr || store.getValue() != arith.getResult(0))
168+
return false;
169+
170+
if (++iterator == oplist.end())
171+
return false;
172+
173+
// Finally, this should precede a yield with nothing in between.
174+
bool success = isa<YieldOp>(*iterator);
175+
176+
// Remember to update analysis information.
177+
if (success)
178+
ivAddr = addr;
179+
180+
return success;
130181
}
131182

132183
static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
@@ -143,7 +194,7 @@ static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
143194

144195
cir::CmpOp SCFLoop::findCmpOp() {
145196
cmpOp = nullptr;
146-
for (auto *user : IVAddr.getUsers()) {
197+
for (auto *user : ivAddr.getUsers()) {
147198
if (user->getParentRegion() != &forOp.getCond())
148199
continue;
149200
if (auto loadOp = dyn_cast<cir::LoadOp>(*user)) {
@@ -162,10 +213,10 @@ cir::CmpOp SCFLoop::findCmpOp() {
162213
if (!mlir::isa<cir::IntType>(type))
163214
llvm_unreachable("Non-integer type IV is not supported");
164215

165-
auto lhsDefOp = cmpOp.getLhs().getDefiningOp();
216+
auto *lhsDefOp = cmpOp.getLhs().getDefiningOp();
166217
if (!lhsDefOp)
167218
llvm_unreachable("Can't find IV load");
168-
if (!isIVLoad(lhsDefOp, IVAddr))
219+
if (!isIVLoad(lhsDefOp, ivAddr))
169220
llvm_unreachable("cmpOp LHS is not IV");
170221

171222
if (cmpOp.getKind() != cir::CmpOpKind::le &&
@@ -187,7 +238,7 @@ mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
187238
// The operations before the loop have been transferred to MLIR.
188239
// So we need to go through getRemappedValue to find the value.
189240
mlir::Value SCFLoop::findIVInitValue() {
190-
auto remapAddr = rewriter->getRemappedValue(IVAddr);
241+
auto remapAddr = rewriter->getRemappedValue(ivAddr);
191242
if (!remapAddr)
192243
return nullptr;
193244
if (!remapAddr.hasOneUse())
@@ -199,7 +250,10 @@ mlir::Value SCFLoop::findIVInitValue() {
199250
}
200251

201252
void SCFLoop::analysis() {
202-
step = findStepAndIV(IVAddr);
253+
canonical = findStepAndIV();
254+
if (!canonical)
255+
llvm_unreachable("Non-canonical for loops are not yet handled");
256+
203257
cmpOp = findCmpOp();
204258
auto IVInit = findIVInitValue();
205259
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
@@ -237,7 +291,7 @@ void SCFLoop::transferToSCFForOp() {
237291
llvm_unreachable(
238292
"Not support lowering loop with break, continue or if yet");
239293
// Replace the IV usage to scf loop induction variable.
240-
if (isIVLoad(op, IVAddr)) {
294+
if (isIVLoad(op, ivAddr)) {
241295
// Replace CIR IV load with arith.addi scf.IV, 0.
242296
// The replacement makes the SCF IV can be automatically propogated
243297
// by OpAdaptor for individual IV user lowering.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
// XFAIL: *
4+
5+
void f();
6+
7+
void reject() {
8+
for (int i = 0; i < 100; i++, f());
9+
for (int i = 0; i < 100; i++, i++);
10+
}

0 commit comments

Comments
 (0)