diff --git a/include/Finch/FinchOps.td b/include/Finch/FinchOps.td index e6cdb29..7618859 100644 --- a/include/Finch/FinchOps.td +++ b/include/Finch/FinchOps.td @@ -225,4 +225,93 @@ def Finch_AssignOp : Finch_Op<"assign", [MemRefsNormalizable, AllTypesMatch<["in } +// HIGH-LEVEL LAZY API + + +class LazyFinch_Op traits = []> : + Op; + + +def LazyFinch_Immediate : LazyFinch_Op<"immediate", [Pure]> { + let summary = "Lazy Finch immediate op"; + let description = [{ + Logical AST expression for the literal value `val`. + + ```mlir + %0 = finch.immediate val : !finch.logicnode + ``` + }]; + let arguments = (ins AnyType:$val); + let results = (outs LazyFinch_LogicNodeType); +} + +def LazyFinch_Deferred : LazyFinch_Op<"deferred", [Pure]> { + let summary = "Lazy Finch deferred op"; + // TODO +} + +def LazyFinch_Field : LazyFinch_Op<"field", [Pure]> { + let summary = "Lazy Finch field op"; + // TODO +} + +def LazyFinch_Alias : LazyFinch_Op<"alias", [Pure]> { + let summary = "Lazy Finch alias op"; + // TODO +} + +def LazyFinch_Table : LazyFinch_Op<"table", [Pure]> { + let summary = "Lazy Finch table op"; + let description = [{ + Logical AST expression for a tensor object `val`, indexed by fields `idxs...`. + + ```mlir + %0 = finch.table val, idxs : !finch.logicnode + ``` + }]; + let arguments = (ins AnyType:$val, Variadic:$idxs); + let results = (outs LazyFinch_LogicNodeType); +} + +def LazyFinch_MapJoin : LazyFinch_Op<"mapjoin", [Pure]> { + let summary = "Lazy Finch mapjoin op"; + let description = [{ + Logical AST expression for mapping the function `op` across `args...`. + The order of fields in the mapjoin is `unique(vcat(map(getfields, args)...))` + + ```mlir + %0 = finch.mapjoin op, args : !finch.logicnode + ``` + }]; + let arguments = (ins AnyType:$op, Variadic:$args); + let results = (outs LazyFinch_LogicNodeType); +} + +def LazyFinch_Aggregate : LazyFinch_Op<"aggregate", [Pure]> { + let summary = "Lazy Finch aggregate op"; + let description = [{ + Logical AST statement that reduces `arg` using `op`, starting with `init`. + `idxs` are the dimensions to reduce. May happen in any order. + + ```mlir + %0 = finch.aggregate op, init, arg, idxs : !finch.logicnode + ``` + }]; + let arguments = (ins AnyType:$op, AnyType:$init, AnyType:$arg, Variadic:$args); + let results = (outs LazyFinch_LogicNodeType); +} + +def LazyFinch_Reformat : LazyFinch_Op<"reformat", [Pure]> { + let summary = "Lazy Finch reformat op"; + let description = [{ + Logical AST statement that reformats `arg` into the tensor `tns`. + + ```mlir + %0 = finch.reformat tns, arg : !finch.logicnode + ``` + }]; + let arguments = (ins AnyType:$tns, AnyType:$arg); + let results = (outs LazyFinch_LogicNodeType); +} + #endif // FINCH_OPS diff --git a/include/Finch/FinchPasses.td b/include/Finch/FinchPasses.td index 4e0407e..1d58f22 100644 --- a/include/Finch/FinchPasses.td +++ b/include/Finch/FinchPasses.td @@ -121,6 +121,15 @@ def FinchLoopletPass: Pass<"finch-looplet-pass"> { } +// HIGH-LEVEL LAZY API + + +def LazyFinchIsolateReformatsPass: Pass<"lazy-finch-isolate-reformats-pass"> { + let summary = "Isolate reformats"; + let description = [{ + Optimization pass for isolating reformats. + }]; +} #endif // FINCH_PASS diff --git a/include/Finch/FinchTypes.td b/include/Finch/FinchTypes.td index aef8689..6c42f3d 100644 --- a/include/Finch/FinchTypes.td +++ b/include/Finch/FinchTypes.td @@ -45,4 +45,15 @@ def AnyNumberOrIndex : AnyTypeOf<[AnyInteger, AnyFloat, Index]>; def Looplet : AnyTypeOf<[Finch_LoopletType]>; def LoopletOrNumber : AnyTypeOf<[Finch_LoopletType, AnyNumber]>; + +// HIGH-LEVEL LAZY API + + +def LazyFinch_LogicNodeType : TypeDef { + let summary = "Finch LogicNode type"; + let description = "LogicNode type in Finch dialect"; + let mnemonic = "logicnode"; +} + + #endif // FINCH_TYPES diff --git a/lib/Finch/FinchPasses.cpp b/lib/Finch/FinchPasses.cpp index 9bfabba..550dc84 100644 --- a/lib/Finch/FinchPasses.cpp +++ b/lib/Finch/FinchPasses.cpp @@ -753,7 +753,51 @@ class FinchLoopletPass }; +class LazyFinchIsolateReformatsRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(finch::ReformatOp op, PatternRewriter &rewriter) const { + Value tns = op.getOperand(0); + Value arg = op.getOperand(1); + + Alias a = "#A##123"; // generate unique symbol somehow + finch::AliasOp alias = finch::AliasOp(a); + + // rather a pseudocode + rewriter.replaceOpWithNewOp(op, alias, finch::ReformatOp(tns, arg)); + return success(); + } +} + +// mirrors optimize function from optimize.jl +class LazyFinchLogicPass + : public impl::LazyFinchLogicPassBase { +public: + using impl::LazyFinchLogicPassBase< + LazyFinchLogicPass>::LazyFinchLogicPassBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + //patterns.add(&getContext()); + patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //patterns.add(&getContext()); + //... + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; } // namespace