diff --git a/include/Finch/FinchOps.td b/include/Finch/FinchOps.td index 142ee49..ee30d68 100644 --- a/include/Finch/FinchOps.td +++ b/include/Finch/FinchOps.td @@ -150,8 +150,43 @@ def Finch_LookupOp : Finch_Op<"lookup",[Pure, NoTerminator]> { def Finch_StepperOp : Finch_Op<"stepper",[Pure, NoTerminator]> { let summary = "Finch Stepper Looplets"; let description = [{ + This operation defines the Stepper Looplet, which represents repetitive patterns in child looplets. + Stepper takes four function arguments that always return with finch.return: + + 1. **seek**: `coordinate:index -> position:index` + Returns the position of the child looplet whose range includes the given input coordinate. + + 2. **stop**: `position:index -> coordinate:index` + Returns the coordinate at which the child looplet at the given position ends. + + 3. **body**: `position:index -> child:looplet` + Returns the child looplet located at the specified position. + + 4. **next**: `position:index -> position:index` + Returns the subsequent position following the given position. 
+ ```mlir - %3 = finch.stepper %lb %ub : + %3 = finch.stepper + seek = { + ^bb(%crd : index): + %pos = arith.addi %crd, %c1 : index + finch.return %pos : index + } + stop = { + ^bb(%pos : index): + %stopcrd = arith.muli %pos, %shape : index + finch.return %stopcrd : index + } + body = { + ^bb(%pos : index): + %run = finch.run %f1 : (f32) -> (!finch.looplet) + finch.return %run : !finch.looplet + } + next = { + ^bb(%pos : index): + %nextpos = arith.addi %pos, %c1 : index + finch.return %nextpos : index + } ``` }]; diff --git a/include/Finch/FinchPasses.td b/include/Finch/FinchPasses.td index 0c40152..fc373f0 100644 --- a/include/Finch/FinchPasses.td +++ b/include/Finch/FinchPasses.td @@ -68,11 +68,67 @@ def FinchLoopletSequence: Pass<"finch-looplet-sequence"> { def FinchLoopletStepper: Pass<"finch-looplet-stepper"> { let summary = "Compiler Pass for Stepper Looplets"; let description = [{ - Compiler Pass for handling stepper looplets + Lowers finch.access of finch.stepper inside scf.for + + `finch.stepper seek(crd):pos, stop(pos):crd, body(pos):looplet, next(pos):pos` can be thought of as a function + p = initpos + f(x) = body(p) (if x < stop(p)) + body(next(p)) (if stop(p) <= x < stop(next(p))) + body(next(next(p))) (if stop(next(p)) <= x < stop(next(next(p)))) + ... + + if we iterate f(x) over [st, en), + p = seek(st) + f(x) = body(p) (if x in [st, min(en, stop(p))) ) + body(next(p)) (if x in [min(en, stop(p)), min(en, stop(next(p))))) + ... + + Pseudocode for lowering a single stepper looks like: + + ```mlir + + %0 = finch.stepper + seek = { + ^bb(%crd:index): + ... + } + stop = { + ^bb(%pos:index): + ... + } + body = { + ^bb(%pos:index): + ... + } + next = { + ^bb(%pos:index): + ... + } + scf.for %idx = %st to %en step %c1 { + %1 = finch.access %0, %idx : f32 + } ``` - to be filled. 
- + to + + ```mlir + %initpos = stepper.seek(%st) + %1:2 = scf.while (%pos=%initpos, %idx=%st) : (index,index) -> (index,index) { + %cmp = arith.cmpi ult, %idx, %en : index + scf.condition(%cmp) %pos, %idx : index, index + } do { + ^bb(%pos: index, %idx: index): + %stop = stepper.stop(%pos) + %body = stepper.body(%pos) + %end = arith.minui %stop, %en : index + scf.for %idx2 = %idx to %end step %c1 { + %1 = finch.access %body, %idx2 : f32 + } + + %nextpos = stepper.next(%pos) + scf.yield %nextpos, %end + } ``` + }]; } diff --git a/lib/Finch/CMakeLists.txt b/lib/Finch/CMakeLists.txt index ce1028b..efdff18 100644 --- a/lib/Finch/CMakeLists.txt +++ b/lib/Finch/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRFinch FinchDialect.cpp FinchOps.cpp FinchRunPass.cpp + FinchStepperPass.cpp FinchPasses.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Finch/FinchPasses.cpp b/lib/Finch/FinchPasses.cpp index e1858d5..1782f16 100644 --- a/lib/Finch/FinchPasses.cpp +++ b/lib/Finch/FinchPasses.cpp @@ -26,7 +26,6 @@ namespace mlir::finch { #define GEN_PASS_DEF_FINCHSIMPLIFIER #define GEN_PASS_DEF_FINCHINSTANTIATE #define GEN_PASS_DEF_FINCHLOOPLETSEQUENCE -#define GEN_PASS_DEF_FINCHLOOPLETSTEPPER #define GEN_PASS_DEF_FINCHLOOPLETLOOKUP #include "Finch/FinchPasses.h.inc" @@ -382,192 +381,6 @@ class FinchLoopletLookupRewriter : public OpRewritePattern { -class FinchLoopletStepperRewriter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ForOp forOp, - PatternRewriter &rewriter) const final { - auto indVar = forOp.getInductionVar(); - - //llvm::outs() << "(0)\n"; - OpBuilder builder(forOp); - Location loc = forOp.getLoc(); - - // Collect all the steppers from accesses - IRMapping mapper; - SmallVector stepperLooplets; - SmallVector accessOps; - for (auto& accessOp : *forOp.getBody()) { - if (isa(accessOp)) { - Value accessVar = accessOp.getOperand(1); - if (accessVar == indVar) { - Operation* looplet = 
accessOp.getOperand(0).getDefiningOp(); - if (isa(looplet)) { - // There can be multiple uses of this Stepper. - // We don't want to erase original Stepper when lowering - // because of other use. - // So everytime we lower Stepper, clone it. - //llvm::outs() << *looplet << "\n"; - Operation* clonedStepper = rewriter.clone(*looplet); - stepperLooplets.push_back(cast(clonedStepper)); - accessOps.push_back(cast(accessOp)); - } - } - } - } - //llvm::outs() << "(0')\n"; - //llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n"; - - if (stepperLooplets.empty()) { - return failure(); - } - - // Main Stepper Rewrite - Value loopLowerBound = forOp.getLowerBound(); - Value loopUpperBound = forOp.getUpperBound(); - - //llvm::outs() << "(1)\n"; - // Call Seek - SmallVector seekPositions; - for (auto& stepperLooplet : stepperLooplets) { - Block &seekBlock = stepperLooplet.getRegion(0).front(); - - Operation* seekReturn = seekBlock.getTerminator(); - Value seekPosition = seekReturn->getOperand(0); - rewriter.inlineBlockBefore(&seekBlock, forOp, ValueRange(loopLowerBound)); - seekPositions.push_back(seekPosition); - rewriter.eraseOp(seekReturn); - } - - // create while Op - seekPositions.push_back(loopLowerBound); - unsigned numIterArgs = seekPositions.size(); - ValueRange iterArgs(seekPositions); - scf::WhileOp whileOp = rewriter.create( - loc, iterArgs.getTypes(), iterArgs); - - - // fill condition - SmallVector locations(numIterArgs, loc); - Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, - iterArgs.getTypes(), locations); - rewriter.setInsertionPointToEnd(before); - Value cond = rewriter.create(loc, arith::CmpIPredicate::ult, - before->getArgument(numIterArgs-1), - loopUpperBound); - rewriter.create(loc, cond, before->getArguments()); - - - // after region of while op - Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, - iterArgs.getTypes(), locations); - - rewriter.setInsertionPointToEnd(after); - 
rewriter.moveOpBefore(forOp, after, after->end()); - - - //llvm::outs() << "(2)\n"; - // call stop then intersection - rewriter.setInsertionPoint(forOp); - SmallVector stopCoords; - Value intersectUpperBound = loopUpperBound; - for (unsigned i = 0; i < stepperLooplets.size(); i++) { - auto stepperLooplet = stepperLooplets[i]; - Block &stopBlock = stepperLooplet.getRegion(1).front(); - - // IDK why but this order is important. - // getTerminator -> inlineBlockBefore -> getOperand -> eraseOp - Operation* stopReturn = stopBlock.getTerminator(); - rewriter.inlineBlockBefore(&stopBlock, forOp, after->getArgument(i)); - Value stopCoord = stopReturn->getOperand(0); - rewriter.eraseOp(stopReturn); - - //llvm::outs() << "(2-2)\n"; - intersectUpperBound = rewriter.create( - loc, intersectUpperBound, stopCoord); - stopCoords.push_back(stopCoord); - } - //llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n"; - //llvm::outs() << numIterArgs << "\n"; - forOp.setLowerBound(after->getArgument(numIterArgs-1)); - forOp.setUpperBound(intersectUpperBound); - - - - //llvm::outs() << "(3)\n"; - //llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n"; - - // call body and replace access - for (unsigned i = 0; i < stepperLooplets.size(); i++) { - auto stepperLooplet = stepperLooplets[i]; - Block &bodyBlock = stepperLooplet.getRegion(2).front(); - Operation* bodyReturn = bodyBlock.getTerminator(); - Value bodyLooplet = bodyReturn->getOperand(0); - rewriter.inlineBlockBefore(&bodyBlock, forOp, after->getArgument(i)); - - //Operation* loopletOp = stepperLooplet; - //Operation* accessOp = mapper.lookupOrDefault(loopletOp); - //accessOp->setOperand(0, bodyLooplet); - accessOps[i].setOperand(0, bodyLooplet); - rewriter.eraseOp(bodyReturn); - } - - //// current Upper Bound become next iteration's lower bound - rewriter.setInsertionPointToEnd(after); - Value nextCoord = intersectUpperBound; - - //llvm::outs() << "(4)\n"; - //// 
call next - SmallVector nextPositions; - Type indexType = rewriter.getIndexType(); - for (unsigned i = 0; i < stepperLooplets.size(); i++) { - auto stepperLooplet = stepperLooplets[i]; - auto currPos = after->getArgument(i); - auto stopCoord = stopCoords[i]; - - Block &nextBlock = stepperLooplet.getRegion(3).front(); - Operation* nextReturn = nextBlock.getTerminator(); - Value nextPos = nextReturn->getOperand(0); - - rewriter.setInsertionPointToEnd(after); - Value eq = rewriter.create( - loc, arith::CmpIPredicate::eq, stopCoord, intersectUpperBound); - - scf::IfOp ifOp = rewriter.create(loc, indexType, eq, true); - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - scf::YieldOp thenYieldOp = rewriter.create(loc, nextPos); - rewriter.inlineBlockBefore(&nextBlock, thenYieldOp, currPos); - - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - scf::YieldOp elseYieldOp = rewriter.create(loc, currPos); - - nextPositions.push_back(ifOp.getResult(0)); - rewriter.eraseOp(nextReturn); - } - nextPositions.push_back(nextCoord); - rewriter.setInsertionPointToEnd(after); - rewriter.create(loc, ValueRange(nextPositions)); - - // Todo:Build a chain - // %0 = tensor.empty() - // %1 = scf.for $i = $b0 to %b1 step %c1 iter_args(%v = %0) //forOp - // return %1 - // - // vvv - // - // %0 = tensor.empty() - // %res:4 = scf.while iter_args(%pos1=%pos1_, %pos2=%pos2_, %idx=%idx_, %tensor=%0) { - // %2 = scf.for $i = $b0 to %b1 step %c1 iter_args(%v = %tensor) //newForOp2 - // } - // return %res#3 - - - - return success(); - } -}; - class FinchInstantiate : public impl::FinchInstantiateBase { public: @@ -631,20 +444,6 @@ class FinchLoopletLookup } }; -class FinchLoopletStepper - : public impl::FinchLoopletStepperBase { -public: - using impl::FinchLoopletStepperBase< - FinchLoopletStepper>::FinchLoopletStepperBase; - void runOnOperation() final { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - FrozenRewritePatternSet 
patternSet(std::move(patterns)); - if (failed(applyPatternsGreedily(getOperation(), patternSet))) - signalPassFailure(); - } -}; - } // namespace } // namespace mlir::finch diff --git a/lib/Finch/FinchStepperPass.cpp b/lib/Finch/FinchStepperPass.cpp new file mode 100644 index 0000000..5d9145f --- /dev/null +++ b/lib/Finch/FinchStepperPass.cpp @@ -0,0 +1,230 @@ +//===- FinchStepperPasses.cpp - Finch Stepper passes -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" + +#include "Finch/FinchPasses.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::finch { +#define GEN_PASS_DEF_FINCHLOOPLETSTEPPER +#include "Finch/FinchPasses.h.inc" + +namespace { + +class FinchLoopletStepperRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + auto indVar = forOp.getInductionVar(); + + //llvm::outs() << "(0)\n"; + OpBuilder builder(forOp); + Location loc = forOp.getLoc(); + + // Collect all the steppers from accesses + IRMapping mapper; + SmallVector stepperLooplets; + SmallVector accessOps; + for (auto& accessOp : *forOp.getBody()) { + if (isa(accessOp)) { + Value accessVar = accessOp.getOperand(1); + if (accessVar == indVar) { + Operation* looplet = 
accessOp.getOperand(0).getDefiningOp(); + if (isa(looplet)) { + // There can be multiple uses of this Stepper. + // We don't want to erase original Stepper when lowering + // because of other use. + // So everytime we lower Stepper, clone it. + //llvm::outs() << *looplet << "\n"; + Operation* clonedStepper = rewriter.clone(*looplet); + stepperLooplets.push_back(cast(clonedStepper)); + accessOps.push_back(cast(accessOp)); + } + } + } + } + //llvm::outs() << "(0')\n"; + //llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n"; + + if (stepperLooplets.empty()) { + return failure(); + } + + // Main Stepper Rewrite + Value loopLowerBound = forOp.getLowerBound(); + Value loopUpperBound = forOp.getUpperBound(); + + //llvm::outs() << "(1)\n"; + // Call Seek + SmallVector seekPositions; + for (auto& stepperLooplet : stepperLooplets) { + Block &seekBlock = stepperLooplet.getRegion(0).front(); + + Operation* seekReturn = seekBlock.getTerminator(); + rewriter.inlineBlockBefore(&seekBlock, forOp, ValueRange(loopLowerBound)); + Value seekPosition = seekReturn->getOperand(0); + seekPositions.push_back(seekPosition); + rewriter.eraseOp(seekReturn); + } + + // create while Op + seekPositions.push_back(loopLowerBound); + unsigned numIterArgs = seekPositions.size(); + ValueRange iterArgs(seekPositions); + scf::WhileOp whileOp = rewriter.create( + loc, iterArgs.getTypes(), iterArgs); + + + // fill condition + SmallVector locations(numIterArgs, loc); + Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, + iterArgs.getTypes(), locations); + rewriter.setInsertionPointToEnd(before); + Value cond = rewriter.create(loc, arith::CmpIPredicate::ult, + before->getArgument(numIterArgs-1), + loopUpperBound); + rewriter.create(loc, cond, before->getArguments()); + + + // after region of while op + Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, + iterArgs.getTypes(), locations); + + rewriter.setInsertionPointToEnd(after); + 
rewriter.moveOpBefore(forOp, after, after->end()); + + + //llvm::outs() << "(2)\n"; + // call stop then intersection + rewriter.setInsertionPoint(forOp); + SmallVector stopCoords; + Value intersectUpperBound = loopUpperBound; + for (unsigned i = 0; i < stepperLooplets.size(); i++) { + auto stepperLooplet = stepperLooplets[i]; + Block &stopBlock = stepperLooplet.getRegion(1).front(); + + // IDK why but this order is important. + // getTerminator -> inlineBlockBefore -> getOperand -> eraseOp + Operation* stopReturn = stopBlock.getTerminator(); + rewriter.inlineBlockBefore(&stopBlock, forOp, after->getArgument(i)); + Value stopCoord = stopReturn->getOperand(0); + rewriter.eraseOp(stopReturn); + + //llvm::outs() << "(2-2)\n"; + intersectUpperBound = rewriter.create( + loc, intersectUpperBound, stopCoord); + stopCoords.push_back(stopCoord); + } + //llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n"; + //llvm::outs() << numIterArgs << "\n"; + forOp.setLowerBound(after->getArgument(numIterArgs-1)); + forOp.setUpperBound(intersectUpperBound); + + + + //llvm::outs() << "(3)\n"; + //llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n"; + + // call body and replace access + for (unsigned i = 0; i < stepperLooplets.size(); i++) { + auto stepperLooplet = stepperLooplets[i]; + Block &bodyBlock = stepperLooplet.getRegion(2).front(); + Operation* bodyReturn = bodyBlock.getTerminator(); + Value bodyLooplet = bodyReturn->getOperand(0); + rewriter.inlineBlockBefore(&bodyBlock, forOp, after->getArgument(i)); + + //Operation* loopletOp = stepperLooplet; + //Operation* accessOp = mapper.lookupOrDefault(loopletOp); + //accessOp->setOperand(0, bodyLooplet); + accessOps[i].setOperand(0, bodyLooplet); + rewriter.eraseOp(bodyReturn); + } + + //// current Upper Bound become next iteration's lower bound + rewriter.setInsertionPointToEnd(after); + Value nextCoord = intersectUpperBound; + + //llvm::outs() << "(4)\n"; + //// 
call next + SmallVector nextPositions; + Type indexType = rewriter.getIndexType(); + for (unsigned i = 0; i < stepperLooplets.size(); i++) { + auto stepperLooplet = stepperLooplets[i]; + auto currPos = after->getArgument(i); + auto stopCoord = stopCoords[i]; + + Block &nextBlock = stepperLooplet.getRegion(3).front(); + Operation* nextReturn = nextBlock.getTerminator(); + Value nextPos = nextReturn->getOperand(0); + + rewriter.setInsertionPointToEnd(after); + Value eq = rewriter.create( + loc, arith::CmpIPredicate::eq, stopCoord, intersectUpperBound); + + scf::IfOp ifOp = rewriter.create(loc, indexType, eq, true); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + scf::YieldOp thenYieldOp = rewriter.create(loc, nextPos); + rewriter.inlineBlockBefore(&nextBlock, thenYieldOp, currPos); + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + scf::YieldOp elseYieldOp = rewriter.create(loc, currPos); + + nextPositions.push_back(ifOp.getResult(0)); + rewriter.eraseOp(nextReturn); + } + nextPositions.push_back(nextCoord); + rewriter.setInsertionPointToEnd(after); + rewriter.create(loc, ValueRange(nextPositions)); + + // Todo:Build a chain + // %0 = tensor.empty() + // %1 = scf.for $i = $b0 to %b1 step %c1 iter_args(%v = %0) //forOp + // return %1 + // + // vvv + // + // %0 = tensor.empty() + // %res:4 = scf.while iter_args(%pos1=%pos1_, %pos2=%pos2_, %idx=%idx_, %tensor=%0) { + // %2 = scf.for $i = $b0 to %b1 step %c1 iter_args(%v = %tensor) //newForOp2 + // } + // return %res#3 + + + + return success(); + } +}; + +class FinchLoopletStepper + : public impl::FinchLoopletStepperBase { +public: + using impl::FinchLoopletStepperBase< + FinchLoopletStepper>::FinchLoopletStepperBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + 
+ +} // namespace +} // namespace mlir::finch diff --git a/test/Finch/looplet-stepper.mlir b/test/Finch/looplet-stepper.mlir index c65593f..4251869 100644 --- a/test/Finch/looplet-stepper.mlir +++ b/test/Finch/looplet-stepper.mlir @@ -10,51 +10,53 @@ // CHECK-LABEL: func.func @test1( // CHECK-SAME: %[[VAL_0:.*]]: index, // CHECK-SAME: %[[VAL_1:.*]]: index) -> f32 { -// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_4:.*]] = arith.constant 4.000000e+00 : f32 -// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = arith.constant 4.000000e+00 : f32 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index // CHECK: %[[VAL_6:.*]] = memref.alloc() : memref -// CHECK: memref.store %[[VAL_3]], %[[VAL_6]][] : memref -// CHECK: %[[VAL_7:.*]]:2 = scf.while (%[[VAL_8:.*]] = %[[VAL_2]], %[[VAL_9:.*]] = %[[VAL_0]]) : (index, index) -> (index, index) { +// CHECK: memref.store %[[VAL_2]], %[[VAL_6]][] : memref +// CHECK: %[[VAL_7:.*]]:2 = scf.while (%[[VAL_8:.*]] = %[[VAL_0]], %[[VAL_9:.*]] = %[[VAL_0]]) : (index, index) -> (index, index) { // CHECK: %[[VAL_10:.*]] = arith.cmpi ult, %[[VAL_9]], %[[VAL_1]] : index // CHECK: scf.condition(%[[VAL_10]]) %[[VAL_8]], %[[VAL_9]] : index, index // CHECK: } do { // CHECK: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index): -// CHECK: %[[VAL_13:.*]] = arith.minui %[[VAL_1]], %[[VAL_11]] : index -// CHECK: %[[VAL_14:.*]] = finch.run %[[VAL_4]] : (f32) -> (!finch.looplet) -// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_5]] { -// CHECK: %[[VAL_16:.*]] = finch.access %[[VAL_14]], %[[VAL_15]] : f32 -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_6]][] : memref -// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32 -// CHECK: memref.store %[[VAL_18]], %[[VAL_6]][] : memref +// CHECK: %[[VAL_13:.*]] 
= arith.addi %[[VAL_11]], %[[VAL_5]] : index +// CHECK: %[[VAL_14:.*]] = arith.minui %[[VAL_1]], %[[VAL_13]] : index +// CHECK: %[[VAL_15:.*]] = finch.run %[[VAL_3]] : (f32) -> (!finch.looplet) +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] { +// CHECK: %[[VAL_17:.*]] = finch.access %[[VAL_15]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]][] : memref +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 +// CHECK: memref.store %[[VAL_19]], %[[VAL_6]][] : memref // CHECK: } -// CHECK: %[[VAL_19:.*]] = arith.cmpi eq, %[[VAL_11]], %[[VAL_13]] : index -// CHECK: %[[VAL_20:.*]] = scf.if %[[VAL_19]] -> (index) { -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index -// CHECK: scf.yield %[[VAL_21]] : index +// CHECK: %[[VAL_20:.*]] = arith.cmpi eq, %[[VAL_13]], %[[VAL_14]] : index +// CHECK: %[[VAL_21:.*]] = scf.if %[[VAL_20]] -> (index) { +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_11]], %[[VAL_4]] : index +// CHECK: scf.yield %[[VAL_22]] : index // CHECK: } else { // CHECK: scf.yield %[[VAL_11]] : index // CHECK: } -// CHECK: scf.yield %[[VAL_20]], %[[VAL_13]] : index, index +// CHECK: scf.yield %[[VAL_21]], %[[VAL_14]] : index, index // CHECK: } -// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_6]][] : memref -// CHECK: return %[[VAL_22]] : f32 +// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_6]][] : memref +// CHECK: return %[[VAL_23]] : f32 // CHECK: } func.func @test1(%b0:index, %b1:index) -> f32{ %f0 = arith.constant 0.0 : f32 %f1 = arith.constant 4.0 : f32 %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index %step = finch.stepper seek={ ^bb0(%idx : index): - %firstpos = arith.constant 0 : index - finch.return %firstpos : index + finch.return %idx : index } stop={ ^bb(%pos : index): - finch.return %pos : index + %crd = arith.addi %pos, %c2 : index + finch.return %crd : index } body={ ^bb(%pos : index):