Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions include/Finch/FinchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,93 @@ def Finch_AssignOp : Finch_Op<"assign", [MemRefsNormalizable, AllTypesMatch<["in
}


// HIGH-LEVEL LAZY API


class LazyFinch_Op<string mnemonic, list<Trait> traits = []> :
Op<Finch_Dialect, mnemonic, traits>;


def LazyFinch_Immediate : LazyFinch_Op<"immediate", [Pure]> {
let summary = "Lazy Finch immediate op";
let description = [{
Logical AST expression for the literal value `val`.

```mlir
%0 = finch.immediate val : !finch.logicnode
```
}];
let arguments = (ins AnyType:$val);
let results = (outs LazyFinch_LogicNodeType);
}

def LazyFinch_Deferred : LazyFinch_Op<"deferred", [Pure]> {
let summary = "Lazy Finch deferred op";
// TODO
}

def LazyFinch_Field : LazyFinch_Op<"field", [Pure]> {
let summary = "Lazy Finch field op";
// TODO
}

def LazyFinch_Alias : LazyFinch_Op<"alias", [Pure]> {
let summary = "Lazy Finch alias op";
// TODO
}

def LazyFinch_Table : LazyFinch_Op<"table", [Pure]> {
let summary = "Lazy Finch table op";
let description = [{
Logical AST expression for a tensor object `val`, indexed by fields `idxs...`.

```mlir
%0 = finch.table val, idxs : !finch.logicnode
```
}];
let arguments = (ins AnyType:$val, Variadic<AnyType>:$idxs);
let results = (outs LazyFinch_LogicNodeType);
}

def LazyFinch_MapJoin : LazyFinch_Op<"mapjoin", [Pure]> {
let summary = "Lazy Finch mapjoin op";
let description = [{
Logical AST expression for mapping the function `op` across `args...`.
The order of fields in the mapjoin is `unique(vcat(map(getfields, args)...))`

```mlir
%0 = finch.mapjoin op, args : !finch.logicnode
```
}];
let arguments = (ins AnyType:$op, Variadic<AnyType>:$args);
let results = (outs LazyFinch_LogicNodeType);
}

def LazyFinch_Aggregate : LazyFinch_Op<"aggregate", [Pure]> {
let summary = "Lazy Finch aggregate op";
let description = [{
Logical AST statement that reduces `arg` using `op`, starting with `init`.
`idxs` are the dimensions to reduce. May happen in any order.

```mlir
%0 = finch.aggregate op, init, arg, idxs : !finch.logicnode
```
}];
let arguments = (ins AnyType:$op, AnyType:$init, AnyType:$arg, Variadic<AnyType>:$args);
let results = (outs LazyFinch_LogicNodeType);
}

def LazyFinch_Reformat : LazyFinch_Op<"reformat", [Pure]> {
let summary = "Lazy Finch reformat op";
let description = [{
Logical AST statement that reformats `arg` into the tensor `tns`.

```mlir
%0 = finch.reformat tns, arg : !finch.logicnode
```
}];
let arguments = (ins AnyType:$tns, AnyType:$arg);
let results = (outs LazyFinch_LogicNodeType);
}

#endif // FINCH_OPS
9 changes: 9 additions & 0 deletions include/Finch/FinchPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ def FinchLoopletPass: Pass<"finch-looplet-pass"> {
}


// HIGH-LEVEL LAZY API


def LazyFinchIsolateReformatsPass: Pass<"lazy-finch-isolate-reformats-pass"> {
let summary = "Isolate reformats";
let description = [{
Optimization pass for isolating reformats.
}];
}


#endif // FINCH_PASS
11 changes: 11 additions & 0 deletions include/Finch/FinchTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,15 @@ def AnyNumberOrIndex : AnyTypeOf<[AnyInteger, AnyFloat, Index]>;
def Looplet : AnyTypeOf<[Finch_LoopletType]>;
def LoopletOrNumber : AnyTypeOf<[Finch_LoopletType, AnyNumber]>;


// HIGH-LEVEL LAZY API


def LazyFinch_LogicNodeType : TypeDef<Finch_Dialect, "LogicNode", []> {
let summary = "Finch LogicNode type";
let description = "LogicNode type in Finch dialect";
let mnemonic = "logicnode";
}


#endif // FINCH_TYPES
44 changes: 44 additions & 0 deletions lib/Finch/FinchPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,51 @@ class FinchLoopletPass
};


class LazyFinchIsolateReformatsRewriter : public OpRewritePattern<finch::ReformatOp> {
public:
using OpRewritePattern<finch::ReformatOp>::OpRewritePattern;

LogicalResult matchAndRewrite(finch::ReformatOp op, PatternRewriter &rewriter) const {
Value tns = op.getOperand(0);
Value arg = op.getOperand(1);

Alias a = "#A##123"; // generate unique symbol somehow
finch::AliasOp alias = finch::AliasOp(a);

// rather a pseudocode
rewriter.replaceOpWithNewOp<memref::SubqueryOp>(op, alias, finch::ReformatOp(tns, arg));

return success();
}
}

// mirrors optimize function from optimize.jl
class LazyFinchLogicPass
: public impl::LazyFinchLogicPassBase<LazyFinchLogicPass> {
public:
using impl::LazyFinchLogicPassBase<
LazyFinchLogicPass>::LazyFinchLogicPassBase;
void runOnOperation() final {
RewritePatternSet patterns(&getContext());
//patterns.add<LazyFinchLiftSubqueriesRewriter>(&getContext());
patterns.add<LazyFinchIsolateReformatsRewriter>(&getContext());
//patterns.add<LazyFinchIsolateAggregatesRewriter>(&getContext());
//patterns.add<LazyFinchIsolateTablesRewriter>(&getContext());
//patterns.add<LazyFinchPrettyLabelsRewriter>(&getContext());
//patterns.add<LazyFinchPropagateCopyQueriesRewriter>(&getContext());
//patterns.add<LazyFinchPropagateTransposeQueriesRewriter>(&getContext());
//patterns.add<LazyFinchPropagateMapQueriesRewriter>(&getContext());
//patterns.add<LazyFinchPropagateFieldsRewriter>(&getContext());
//patterns.add<LazyFinchPropagateTransposeQueriesRewriter>(&getContext());
//patterns.add<LazyFinchSetLoopOrderRewriter>(&getContext());
//patterns.add<LazyFinchPushFieldsRewriter>(&getContext());
//patterns.add<LazyFinchConcordizeRewriter>(&getContext());
//...
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
signalPassFailure();
}
};


} // namespace
Expand Down