Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion include/Finch/FinchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,43 @@ def Finch_LookupOp : Finch_Op<"lookup",[Pure, NoTerminator]> {
def Finch_StepperOp : Finch_Op<"stepper",[Pure, NoTerminator]> {
let summary = "Finch Stepper Looplets";
let description = [{
This operation defines the Stepper Looplet, which represents repetitive patterns in child looplets.
Stepper takes four function arguments that always return with finch.return:

1. **seek**: `coordinate:index -> position:index`
Returns the position of the child looplet whose range includes the given input coordinate.

2. **stop**: `position:index -> coordinate:index`
Returns the coordinate at which the child looplet at the given position ends.

3. **body**: `position:index -> child:looplet`
Returns the child looplet located at the specified position.

4. **next**: `position:index -> position:index`
Returns the subsequent position following the given position.

```mlir
%3 = finch.stepper %lb %ub :
%3 = finch.stepper
seek = {
^bb(%crd : index):
%pos = arith.addi %pos, %c1 : index
finch.return %pos : index
}
stop = {
^bb(%pos : index):
%stopcrd = arith.muli %pos, %shape : index
finch.return %stopcrd : index
}
body = {
^bb(%pos : index):
%run = finch.run %f1 : (f32) -> (!finch.looplet)
finch.return %run : !finch.looplet
}
next = {
^bb(%pos : index):
%nextpos = arith.addi %pos, %c1 : index
finch.return %nextpos : index
}
```
}];

Expand Down
62 changes: 59 additions & 3 deletions include/Finch/FinchPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,67 @@ def FinchLoopletSequence: Pass<"finch-looplet-sequence"> {
def FinchLoopletStepper: Pass<"finch-looplet-stepper"> {
let summary = "Compiler Pass for Stepper Looplets";
let description = [{
Compiler Pass for handling stepper looplets
Lowers finch.access of finch.stepper inside scf.for

`finch.stepper seek(crd):pos, stop(pos):crd, body(pos):looplet, next(pos):pos` can be thought as a function
p = initpos
f(x) = body(p) (if x < stop(p))
body(next(p)) (if stop(p) <= x < stop(next(p)))
body(next(next(p))) (if stop(next(p)) <= x < stop(next(next(p))))
...

if we iterate f(x) over [st, en),
p = seek(st)
f(x) = body(p) (if x in [st, min(en, stop(p))) )
body(next(p)) (if x in [min(en, stop(p)), min(en, stop(next(p)))))
...

A pseudocode of lowering a single stepper looks like :

```mlir

%0 = finch.stepper
seek = {
^bb(%crd:index):
...
}
stop = {
^bb(%pos:index):
...
}
body = {
^bb(%pos:index):
...
}
next = {
^bb(%pos:index):
...
}
scf.for %idx = %st to %en step %c1 {
%1 = finch.access %0, %idx : f32
}
```
to be filled.

to

```mlir
%initpos = stepper.seek(%st)
%1:2 = scf.while (%pos=%initpos, %idx=%st) : (index,index) -> (index,index) {
%cmp = arith.cmpi ult, %idx, %en : index
scf.condition(%cmp) %pos, %idx : index, index
} do {
^bb(%pos: index, %idx: index):
%stop = stepper.stop(%pos)
%body = stepper.body(%pos)
%end = arith.minui %2, %en : index
scf.for %idx2 = %idx to %end step %c1 {
%1 = finch.access %body, %idx2 : f32
}

%nextpos = stepper.next(%pos)
scf.yield %nextpos, %end
}
```

}];
}

Expand Down
1 change: 1 addition & 0 deletions lib/Finch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRFinch
FinchDialect.cpp
FinchOps.cpp
FinchRunPass.cpp
FinchStepperPass.cpp
FinchPasses.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
201 changes: 0 additions & 201 deletions lib/Finch/FinchPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ namespace mlir::finch {
#define GEN_PASS_DEF_FINCHSIMPLIFIER
#define GEN_PASS_DEF_FINCHINSTANTIATE
#define GEN_PASS_DEF_FINCHLOOPLETSEQUENCE
#define GEN_PASS_DEF_FINCHLOOPLETSTEPPER
#define GEN_PASS_DEF_FINCHLOOPLETLOOKUP
#include "Finch/FinchPasses.h.inc"

Expand Down Expand Up @@ -382,192 +381,6 @@ class FinchLoopletLookupRewriter : public OpRewritePattern<scf::ForOp> {



class FinchLoopletStepperRewriter : public OpRewritePattern<scf::ForOp> {
public:
using OpRewritePattern<scf::ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final {
auto indVar = forOp.getInductionVar();

//llvm::outs() << "(0)\n";
OpBuilder builder(forOp);
Location loc = forOp.getLoc();

// Collect all the steppers from accesses
IRMapping mapper;
SmallVector<finch::StepperOp, 4> stepperLooplets;
SmallVector<finch::AccessOp, 4> accessOps;
for (auto& accessOp : *forOp.getBody()) {
if (isa<mlir::finch::AccessOp>(accessOp)) {
Value accessVar = accessOp.getOperand(1);
if (accessVar == indVar) {
Operation* looplet = accessOp.getOperand(0).getDefiningOp();
if (isa<finch::StepperOp>(looplet)) {
// There can be multiple uses of this Stepper.
// We don't want to erase original Stepper when lowering
// because of other use.
// So everytime we lower Stepper, clone it.
//llvm::outs() << *looplet << "\n";
Operation* clonedStepper = rewriter.clone(*looplet);
stepperLooplets.push_back(cast<finch::StepperOp>(clonedStepper));
accessOps.push_back(cast<finch::AccessOp>(accessOp));
}
}
}
}
//llvm::outs() << "(0')\n";
//llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n";

if (stepperLooplets.empty()) {
return failure();
}

// Main Stepper Rewrite
Value loopLowerBound = forOp.getLowerBound();
Value loopUpperBound = forOp.getUpperBound();

//llvm::outs() << "(1)\n";
// Call Seek
SmallVector<Value, 4> seekPositions;
for (auto& stepperLooplet : stepperLooplets) {
Block &seekBlock = stepperLooplet.getRegion(0).front();

Operation* seekReturn = seekBlock.getTerminator();
Value seekPosition = seekReturn->getOperand(0);
rewriter.inlineBlockBefore(&seekBlock, forOp, ValueRange(loopLowerBound));
seekPositions.push_back(seekPosition);
rewriter.eraseOp(seekReturn);
}

// create while Op
seekPositions.push_back(loopLowerBound);
unsigned numIterArgs = seekPositions.size();
ValueRange iterArgs(seekPositions);
scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(
loc, iterArgs.getTypes(), iterArgs);


// fill condition
SmallVector<Location, 4> locations(numIterArgs, loc);
Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
iterArgs.getTypes(), locations);
rewriter.setInsertionPointToEnd(before);
Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
before->getArgument(numIterArgs-1),
loopUpperBound);
rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());


// after region of while op
Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
iterArgs.getTypes(), locations);

rewriter.setInsertionPointToEnd(after);
rewriter.moveOpBefore(forOp, after, after->end());


//llvm::outs() << "(2)\n";
// call stop then intersection
rewriter.setInsertionPoint(forOp);
SmallVector<Value, 4> stopCoords;
Value intersectUpperBound = loopUpperBound;
for (unsigned i = 0; i < stepperLooplets.size(); i++) {
auto stepperLooplet = stepperLooplets[i];
Block &stopBlock = stepperLooplet.getRegion(1).front();

// IDK why but this order is important.
// getTerminator -> inlineBlockBefore -> getOperand -> eraseOp
Operation* stopReturn = stopBlock.getTerminator();
rewriter.inlineBlockBefore(&stopBlock, forOp, after->getArgument(i));
Value stopCoord = stopReturn->getOperand(0);
rewriter.eraseOp(stopReturn);

//llvm::outs() << "(2-2)\n";
intersectUpperBound = rewriter.create<arith::MinUIOp>(
loc, intersectUpperBound, stopCoord);
stopCoords.push_back(stopCoord);
}
//llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n";
//llvm::outs() << numIterArgs << "\n";
forOp.setLowerBound(after->getArgument(numIterArgs-1));
forOp.setUpperBound(intersectUpperBound);



//llvm::outs() << "(3)\n";
//llvm::outs() << *(forOp->getBlock()->getParentOp()->getBlock()->getParentOp()) << "\n";

// call body and replace access
for (unsigned i = 0; i < stepperLooplets.size(); i++) {
auto stepperLooplet = stepperLooplets[i];
Block &bodyBlock = stepperLooplet.getRegion(2).front();
Operation* bodyReturn = bodyBlock.getTerminator();
Value bodyLooplet = bodyReturn->getOperand(0);
rewriter.inlineBlockBefore(&bodyBlock, forOp, after->getArgument(i));

//Operation* loopletOp = stepperLooplet;
//Operation* accessOp = mapper.lookupOrDefault(loopletOp);
//accessOp->setOperand(0, bodyLooplet);
accessOps[i].setOperand(0, bodyLooplet);
rewriter.eraseOp(bodyReturn);
}

//// current Upper Bound become next iteration's lower bound
rewriter.setInsertionPointToEnd(after);
Value nextCoord = intersectUpperBound;

//llvm::outs() << "(4)\n";
//// call next
SmallVector<Value,4> nextPositions;
Type indexType = rewriter.getIndexType();
for (unsigned i = 0; i < stepperLooplets.size(); i++) {
auto stepperLooplet = stepperLooplets[i];
auto currPos = after->getArgument(i);
auto stopCoord = stopCoords[i];

Block &nextBlock = stepperLooplet.getRegion(3).front();
Operation* nextReturn = nextBlock.getTerminator();
Value nextPos = nextReturn->getOperand(0);

rewriter.setInsertionPointToEnd(after);
Value eq = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, stopCoord, intersectUpperBound);

scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, indexType, eq, true);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
scf::YieldOp thenYieldOp = rewriter.create<scf::YieldOp>(loc, nextPos);
rewriter.inlineBlockBefore(&nextBlock, thenYieldOp, currPos);

rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
scf::YieldOp elseYieldOp = rewriter.create<scf::YieldOp>(loc, currPos);

nextPositions.push_back(ifOp.getResult(0));
rewriter.eraseOp(nextReturn);
}
nextPositions.push_back(nextCoord);
rewriter.setInsertionPointToEnd(after);
rewriter.create<scf::YieldOp>(loc, ValueRange(nextPositions));

// Todo:Build a chain
// %0 = tensor.empty()
// %1 = scf.for $i = $b0 to %b1 step %c1 iter_args(%v = %0) //forOp
// return %1
//
// vvv
//
// %0 = tensor.empty()
// %res:4 = scf.while iter_args(%pos1=%pos1_, %pos2=%pos2_, %idx=%idx_, %tensor=%0) {
// %2 = scf.for $i = $b0 to %b1 step %c1 iter_args(%v = %tensor) //newForOp2
// }
// return %res#3



return success();
}
};

class FinchInstantiate
: public impl::FinchInstantiateBase<FinchInstantiate> {
public:
Expand Down Expand Up @@ -631,20 +444,6 @@ class FinchLoopletLookup
}
};

class FinchLoopletStepper
: public impl::FinchLoopletStepperBase<FinchLoopletStepper> {
public:
using impl::FinchLoopletStepperBase<
FinchLoopletStepper>::FinchLoopletStepperBase;
void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<FinchLoopletStepperRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsGreedily(getOperation(), patternSet)))
signalPassFailure();
}
};


} // namespace
} // namespace mlir::finch
Loading