Skip to content

Commit 161f9ae

Browse files
committed
[CIR][ThroughMLIR] Fix ForOp handling
1 parent 5df5009 commit 161f9ae

File tree

3 files changed

+122
-43
lines changed

3 files changed

+122
-43
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 106 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,24 @@ class SCFLoop {
4242
int64_t getStep() { return step; }
4343
mlir::Value getLowerBound() { return lowerBound; }
4444
mlir::Value getUpperBound() { return upperBound; }
45+
bool isCanonical() { return canonical; }
4546

46-
int64_t findStepAndIV(mlir::Value &addr);
47+
// Returns true if successfully finds both step and induction variable.
48+
bool findStepAndIV();
4749
cir::CmpOp findCmpOp();
4850
mlir::Value findIVInitValue();
4951
void analysis();
5052

51-
mlir::Value plusConstant(mlir::Value V, mlir::Location loc, int addend);
53+
mlir::Value plusConstant(mlir::Value v, mlir::Location loc, int addend);
5254
void transferToSCFForOp();
5355

5456
private:
5557
cir::ForOp forOp;
5658
cir::CmpOp cmpOp;
57-
mlir::Value IVAddr, lowerBound = nullptr, upperBound = nullptr;
59+
mlir::Value ivAddr, lowerBound = nullptr, upperBound = nullptr;
5860
mlir::ConversionPatternRewriter *rewriter;
5961
int64_t step = 0;
62+
bool canonical = true;
6063
};
6164

6265
class SCFWhileLoop {
@@ -86,47 +89,97 @@ class SCFDoLoop {
8689
};
8790

8891
static int64_t getConstant(cir::ConstantOp op) {
89-
auto attr = op->getAttrs().front().getValue();
90-
const auto IntAttr = mlir::dyn_cast<cir::IntAttr>(attr);
91-
return IntAttr.getValue().getSExtValue();
92+
auto attr = op.getValue();
93+
const auto intAttr = mlir::cast<cir::IntAttr>(attr);
94+
return intAttr.getValue().getSExtValue();
9295
}
9396

94-
int64_t SCFLoop::findStepAndIV(mlir::Value &addr) {
97+
bool SCFLoop::findStepAndIV() {
9598
auto *stepBlock =
9699
(forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr);
97100
assert(stepBlock && "Can not find step block");
98101

99-
int64_t step = 0;
100-
mlir::Value IV = nullptr;
101-
// Try to match "IV load addr; ++IV; store IV, addr" to find step.
102-
for (mlir::Operation &op : *stepBlock)
103-
if (auto loadOp = dyn_cast<cir::LoadOp>(op)) {
104-
addr = loadOp.getAddr();
105-
IV = loadOp.getResult();
106-
} else if (auto cop = dyn_cast<cir::ConstantOp>(op)) {
107-
if (step)
108-
llvm_unreachable(
109-
"Not support multiple constant in step calculation yet");
110-
step = getConstant(cop);
111-
} else if (auto bop = dyn_cast<cir::BinOp>(op)) {
112-
if (bop.getLhs() != IV)
113-
llvm_unreachable("Find BinOp not operate on IV");
114-
if (bop.getKind() != cir::BinOpKind::Add)
115-
llvm_unreachable(
116-
"Not support BinOp other than Add in step calculation yet");
117-
} else if (auto uop = dyn_cast<cir::UnaryOp>(op)) {
118-
if (uop.getInput() != IV)
119-
llvm_unreachable("Find UnaryOp not operate on IV");
120-
if (uop.getKind() == cir::UnaryOpKind::Inc)
121-
step = 1;
122-
else if (uop.getKind() == cir::UnaryOpKind::Dec)
123-
llvm_unreachable("Not support decrement step yet");
124-
} else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
125-
assert(storeOp.getAddr() == addr && "Can't find IV when lowering ForOp");
126-
}
127-
assert(step && "Can't find step when lowering ForOp");
102+
// Try to match "iv = load addr; ++iv; store iv, addr; yield" to find step.
103+
// We should match the exact pattern, in case there's something unexpected:
104+
// we must rule out cases like `for (int i = 0; i < n; i++, printf("\n"))`.
105+
auto &oplist = stepBlock->getOperations();
106+
107+
auto iterator = oplist.begin();
108+
109+
// We might find constants at beginning. Skip them.
110+
// We could have hoisted them outside the for loop in previous passes, but
111+
// it hasn't been done yet.
112+
while (iterator != oplist.end() && isa<ConstantOp>(*iterator))
113+
++iterator;
114+
115+
if (iterator == oplist.end())
116+
return false;
117+
118+
auto load = dyn_cast<LoadOp>(*iterator);
119+
if (!load)
120+
return false;
121+
122+
// We assume this is the address of induction variable (IV). The operations
123+
// that come next will check if that's true.
124+
mlir::Value addr = load.getAddr();
125+
mlir::Value iv = load.getResult();
126+
127+
// Then we try to match either "++IV" or "IV += n". Same for reversed loops.
128+
if (++iterator == oplist.end())
129+
return false;
130+
131+
mlir::Operation &arith = *iterator;
132+
133+
if (auto unary = dyn_cast<UnaryOp>(arith)) {
134+
// Not operating on induction variable. Fail.
135+
if (unary.getInput() != iv)
136+
return false;
137+
138+
if (unary.getKind() == UnaryOpKind::Inc)
139+
step = 1;
140+
else if (unary.getKind() == UnaryOpKind::Dec)
141+
step = -1;
142+
else
143+
return false;
144+
}
128145

129-
return step;
146+
if (auto binary = dyn_cast<BinOp>(arith)) {
147+
if (binary.getLhs() != iv)
148+
return false;
149+
150+
mlir::Value value = binary.getRhs();
151+
if (auto constValue = dyn_cast<ConstantOp>(value.getDefiningOp());
152+
isa<IntAttr>(constValue.getValue()))
153+
step = getConstant(constValue);
154+
155+
if (binary.getKind() == BinOpKind::Add)
156+
; // Nothing to do. Step has been calculated above.
157+
else if (binary.getKind() == BinOpKind::Sub)
158+
step = -step;
159+
else
160+
return false;
161+
}
162+
163+
// Check whether we immediately store this value into the appropriate place.
164+
if (++iterator == oplist.end())
165+
return false;
166+
167+
auto store = dyn_cast<StoreOp>(*iterator);
168+
if (!store || store.getAddr() != addr ||
169+
store.getValue() != arith.getResult(0))
170+
return false;
171+
172+
if (++iterator == oplist.end())
173+
return false;
174+
175+
// Finally, this should precede a yield with nothing in between.
176+
bool success = isa<YieldOp>(*iterator);
177+
178+
// Remember to update analysis information.
179+
if (success)
180+
ivAddr = addr;
181+
182+
return success;
130183
}
131184

132185
static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
@@ -143,7 +196,7 @@ static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
143196

144197
cir::CmpOp SCFLoop::findCmpOp() {
145198
cmpOp = nullptr;
146-
for (auto *user : IVAddr.getUsers()) {
199+
for (auto *user : ivAddr.getUsers()) {
147200
if (user->getParentRegion() != &forOp.getCond())
148201
continue;
149202
if (auto loadOp = dyn_cast<cir::LoadOp>(*user)) {
@@ -162,10 +215,10 @@ cir::CmpOp SCFLoop::findCmpOp() {
162215
if (!mlir::isa<cir::IntType>(type))
163216
llvm_unreachable("Non-integer type IV is not supported");
164217

165-
auto lhsDefOp = cmpOp.getLhs().getDefiningOp();
218+
auto *lhsDefOp = cmpOp.getLhs().getDefiningOp();
166219
if (!lhsDefOp)
167220
llvm_unreachable("Can't find IV load");
168-
if (!isIVLoad(lhsDefOp, IVAddr))
221+
if (!isIVLoad(lhsDefOp, ivAddr))
169222
llvm_unreachable("cmpOp LHS is not IV");
170223

171224
if (cmpOp.getKind() != cir::CmpOpKind::le &&
@@ -187,7 +240,7 @@ mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
187240
// The operations before the loop have been transferred to MLIR.
188241
// So we need to go through getRemappedValue to find the value.
189242
mlir::Value SCFLoop::findIVInitValue() {
190-
auto remapAddr = rewriter->getRemappedValue(IVAddr);
243+
auto remapAddr = rewriter->getRemappedValue(ivAddr);
191244
if (!remapAddr)
192245
return nullptr;
193246
if (!remapAddr.hasOneUse())
@@ -199,7 +252,13 @@ mlir::Value SCFLoop::findIVInitValue() {
199252
}
200253

201254
void SCFLoop::analysis() {
202-
step = findStepAndIV(IVAddr);
255+
canonical = findStepAndIV();
256+
if (!canonical) {
257+
mlir::emitError(forOp.getLoc(),
258+
"cannot handle non-constant step for induction variable");
259+
return;
260+
}
261+
203262
cmpOp = findCmpOp();
204263
auto IVInit = findIVInitValue();
205264
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
@@ -237,7 +296,7 @@ void SCFLoop::transferToSCFForOp() {
237296
llvm_unreachable(
238297
"Not support lowering loop with break, continue or if yet");
239298
// Replace the IV usage to scf loop induction variable.
240-
if (isIVLoad(op, IVAddr)) {
299+
if (isIVLoad(op, ivAddr)) {
241300
// Replace CIR IV load with arith.addi scf.IV, 0.
242301
// The replacement makes the SCF IV can be automatically propogated
243302
// by OpAdaptor for individual IV user lowering.
@@ -293,6 +352,10 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
293352
mlir::ConversionPatternRewriter &rewriter) const override {
294353
SCFLoop loop(op, &rewriter);
295354
loop.analysis();
355+
if (!loop.isCanonical()) {
356+
mlir::emitError(op.getLoc(), "cannot handle non-canonicalized loop");
357+
return mlir::failure();
358+
}
296359
loop.transferToSCFForOp();
297360
rewriter.eraseOp(op);
298361
return mlir::success();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s
2+
3+
void f();
4+
5+
void reject() {
6+
for (int i = 0; i < 100; i++, f());
7+
// CHECK: cannot handle non-constant step for induction variable
8+
// CHECK: cannot handle non-canonicalized loop
9+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s
2+
3+
void reject() {
4+
for (int i = 0; i < 100; i++, i++);
5+
// CHECK: cannot handle non-constant step for induction variable
6+
// CHECK: cannot handle non-canonicalized loop
7+
}

0 commit comments

Comments
 (0)