@@ -42,21 +42,24 @@ class SCFLoop {
42
42
int64_t getStep () { return step; }
43
43
// Returns the stored lower-bound value; it is a null Value until one has
// been assigned to `lowerBound`.
mlir::Value getLowerBound() const { return lowerBound; }
44
44
// Returns the stored upper-bound value; it is a null Value until one has
// been assigned to `upperBound`.
mlir::Value getUpperBound() const { return upperBound; }
45
+ bool isCanonical () { return canonical; }
45
46
46
- int64_t findStepAndIV (mlir::Value &addr);
47
+ // Returns true if successfully finds both step and induction variable.
48
+ bool findStepAndIV ();
47
49
cir::CmpOp findCmpOp ();
48
50
mlir::Value findIVInitValue ();
49
51
void analysis ();
50
52
51
- mlir::Value plusConstant (mlir::Value V , mlir::Location loc, int addend);
53
+ mlir::Value plusConstant (mlir::Value v , mlir::Location loc, int addend);
52
54
void transferToSCFForOp ();
53
55
54
56
private:
55
57
cir::ForOp forOp;
56
58
cir::CmpOp cmpOp;
57
- mlir::Value IVAddr , lowerBound = nullptr , upperBound = nullptr ;
59
+ mlir::Value ivAddr , lowerBound = nullptr , upperBound = nullptr ;
58
60
mlir::ConversionPatternRewriter *rewriter;
59
61
int64_t step = 0 ;
62
+ bool canonical = true ;
60
63
};
61
64
62
65
class SCFWhileLoop {
@@ -86,47 +89,97 @@ class SCFDoLoop {
86
89
};
87
90
88
91
// Extracts the integer payload of a CIR constant as a sign-extended 64-bit
// value.
//
// The op's value attribute must be a cir::IntAttr: mlir::cast asserts on any
// other attribute kind, so callers are expected to check the attribute type
// first when it is not already known.
static int64_t getConstant(cir::ConstantOp op) {
  const auto intAttr = mlir::cast<cir::IntAttr>(op.getValue());
  return intAttr.getValue().getSExtValue();
}
93
96
94
- int64_t SCFLoop::findStepAndIV (mlir::Value &addr ) {
97
+ bool SCFLoop::findStepAndIV () {
95
98
auto *stepBlock =
96
99
(forOp.maybeGetStep () ? &forOp.maybeGetStep ()->front () : nullptr );
97
100
assert (stepBlock && " Can not find step block" );
98
101
99
- int64_t step = 0 ;
100
- mlir::Value IV = nullptr ;
101
- // Try to match "IV load addr; ++IV; store IV, addr" to find step.
102
- for (mlir::Operation &op : *stepBlock)
103
- if (auto loadOp = dyn_cast<cir::LoadOp>(op)) {
104
- addr = loadOp.getAddr ();
105
- IV = loadOp.getResult ();
106
- } else if (auto cop = dyn_cast<cir::ConstantOp>(op)) {
107
- if (step)
108
- llvm_unreachable (
109
- " Not support multiple constant in step calculation yet" );
110
- step = getConstant (cop);
111
- } else if (auto bop = dyn_cast<cir::BinOp>(op)) {
112
- if (bop.getLhs () != IV)
113
- llvm_unreachable (" Find BinOp not operate on IV" );
114
- if (bop.getKind () != cir::BinOpKind::Add)
115
- llvm_unreachable (
116
- " Not support BinOp other than Add in step calculation yet" );
117
- } else if (auto uop = dyn_cast<cir::UnaryOp>(op)) {
118
- if (uop.getInput () != IV)
119
- llvm_unreachable (" Find UnaryOp not operate on IV" );
120
- if (uop.getKind () == cir::UnaryOpKind::Inc)
121
- step = 1 ;
122
- else if (uop.getKind () == cir::UnaryOpKind::Dec)
123
- llvm_unreachable (" Not support decrement step yet" );
124
- } else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
125
- assert (storeOp.getAddr () == addr && " Can't find IV when lowering ForOp" );
126
- }
127
- assert (step && " Can't find step when lowering ForOp" );
102
+ // Try to match "iv = load addr; ++iv; store iv, addr; yield" to find step.
103
+ // We should match the exact pattern, in case there's something unexpected:
104
+ // we must rule out cases like `for (int i = 0; i < n; i++, printf("\n"))`.
105
+ auto &oplist = stepBlock->getOperations ();
106
+
107
+ auto iterator = oplist.begin ();
108
+
109
+ // We might find constants at beginning. Skip them.
110
+ // We could have hoisted them outside the for loop in previous passes, but
111
+ // it hasn't been done yet.
112
+ while (iterator != oplist.end () && isa<ConstantOp>(*iterator))
113
+ ++iterator;
114
+
115
+ if (iterator == oplist.end ())
116
+ return false ;
117
+
118
+ auto load = dyn_cast<LoadOp>(*iterator);
119
+ if (!load)
120
+ return false ;
121
+
122
+ // We assume this is the address of induction variable (IV). The operations
123
+ // that come next will check if that's true.
124
+ mlir::Value addr = load.getAddr ();
125
+ mlir::Value iv = load.getResult ();
126
+
127
+ // Then we try to match either "++IV" or "IV += n". Same for reversed loops.
128
+ if (++iterator == oplist.end ())
129
+ return false ;
130
+
131
+ mlir::Operation &arith = *iterator;
132
+
133
+ if (auto unary = dyn_cast<UnaryOp>(arith)) {
134
+ // Not operating on induction variable. Fail.
135
+ if (unary.getInput () != iv)
136
+ return false ;
137
+
138
+ if (unary.getKind () == UnaryOpKind::Inc)
139
+ step = 1 ;
140
+ else if (unary.getKind () == UnaryOpKind::Dec)
141
+ step = -1 ;
142
+ else
143
+ return false ;
144
+ }
128
145
129
- return step;
146
+ if (auto binary = dyn_cast<BinOp>(arith)) {
147
+ if (binary.getLhs () != iv)
148
+ return false ;
149
+
150
+ mlir::Value value = binary.getRhs ();
151
+ if (auto constValue = dyn_cast<ConstantOp>(value.getDefiningOp ());
152
+ isa<IntAttr>(constValue.getValue ()))
153
+ step = getConstant (constValue);
154
+
155
+ if (binary.getKind () == BinOpKind::Add)
156
+ ; // Nothing to do. Step has been calculated above.
157
+ else if (binary.getKind () == BinOpKind::Sub)
158
+ step = -step;
159
+ else
160
+ return false ;
161
+ }
162
+
163
+ // Check whether we immediately store this value into the appropriate place.
164
+ if (++iterator == oplist.end ())
165
+ return false ;
166
+
167
+ auto store = dyn_cast<StoreOp>(*iterator);
168
+ if (!store || store.getAddr () != addr ||
169
+ store.getValue () != arith.getResult (0 ))
170
+ return false ;
171
+
172
+ if (++iterator == oplist.end ())
173
+ return false ;
174
+
175
+ // Finally, this should precede a yield with nothing in between.
176
+ bool success = isa<YieldOp>(*iterator);
177
+
178
+ // Remember to update analysis information.
179
+ if (success)
180
+ ivAddr = addr;
181
+
182
+ return success;
130
183
}
131
184
132
185
static bool isIVLoad (mlir::Operation *op, mlir::Value IVAddr) {
@@ -143,7 +196,7 @@ static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
143
196
144
197
cir::CmpOp SCFLoop::findCmpOp () {
145
198
cmpOp = nullptr ;
146
- for (auto *user : IVAddr .getUsers ()) {
199
+ for (auto *user : ivAddr .getUsers ()) {
147
200
if (user->getParentRegion () != &forOp.getCond ())
148
201
continue ;
149
202
if (auto loadOp = dyn_cast<cir::LoadOp>(*user)) {
@@ -162,10 +215,10 @@ cir::CmpOp SCFLoop::findCmpOp() {
162
215
if (!mlir::isa<cir::IntType>(type))
163
216
llvm_unreachable (" Non-integer type IV is not supported" );
164
217
165
- auto lhsDefOp = cmpOp.getLhs ().getDefiningOp ();
218
+ auto * lhsDefOp = cmpOp.getLhs ().getDefiningOp ();
166
219
if (!lhsDefOp)
167
220
llvm_unreachable (" Can't find IV load" );
168
- if (!isIVLoad (lhsDefOp, IVAddr ))
221
+ if (!isIVLoad (lhsDefOp, ivAddr ))
169
222
llvm_unreachable (" cmpOp LHS is not IV" );
170
223
171
224
if (cmpOp.getKind () != cir::CmpOpKind::le &&
@@ -187,7 +240,7 @@ mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
187
240
// The operations before the loop have been transferred to MLIR.
188
241
// So we need to go through getRemappedValue to find the value.
189
242
mlir::Value SCFLoop::findIVInitValue () {
190
- auto remapAddr = rewriter->getRemappedValue (IVAddr );
243
+ auto remapAddr = rewriter->getRemappedValue (ivAddr );
191
244
if (!remapAddr)
192
245
return nullptr ;
193
246
if (!remapAddr.hasOneUse ())
@@ -199,7 +252,13 @@ mlir::Value SCFLoop::findIVInitValue() {
199
252
}
200
253
201
254
void SCFLoop::analysis () {
202
- step = findStepAndIV (IVAddr);
255
+ canonical = findStepAndIV ();
256
+ if (!canonical) {
257
+ mlir::emitError (forOp.getLoc (),
258
+ " cannot handle non-constant step for induction variable" );
259
+ return ;
260
+ }
261
+
203
262
cmpOp = findCmpOp ();
204
263
auto IVInit = findIVInitValue ();
205
264
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
@@ -237,7 +296,7 @@ void SCFLoop::transferToSCFForOp() {
237
296
llvm_unreachable (
238
297
" Not support lowering loop with break, continue or if yet" );
239
298
// Replace the IV usage to scf loop induction variable.
240
- if (isIVLoad (op, IVAddr )) {
299
+ if (isIVLoad (op, ivAddr )) {
241
300
// Replace CIR IV load with arith.addi scf.IV, 0.
242
301
// The replacement makes the SCF IV can be automatically propogated
243
302
// by OpAdaptor for individual IV user lowering.
@@ -293,6 +352,10 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
293
352
mlir::ConversionPatternRewriter &rewriter) const override {
294
353
SCFLoop loop (op, &rewriter);
295
354
loop.analysis ();
355
+ if (!loop.isCanonical ()) {
356
+ mlir::emitError (op.getLoc (), " cannot handle non-canonicalized loop" );
357
+ return mlir::failure ();
358
+ }
296
359
loop.transferToSCFForOp ();
297
360
rewriter.eraseOp (op);
298
361
return mlir::success ();
0 commit comments