Skip to content

Commit 7d84d5d

Browse files
Merge pull request #412 from RawnH/array_algebra
Array algebra
2 parents bfdaa71 + b1f7f88 commit 7d84d5d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+5124
-309
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ if(NOT EXISTS "${TACO_PROJECT_DIR}/python_bindings/pybind11/CMakeLists.txt")
129129
endif()
130130

131131
if(PYTHON)
132+
add_subdirectory(python_bindings)
132133
message("-- Will build Python extension")
133134
add_definitions(-DPYTHON)
134135
endif(PYTHON)

include/taco.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "taco/tensor.h"
55
#include "taco/format.h"
6+
#include "taco/index_notation/tensor_operator.h"
67
#include "taco/index_notation/index_notation.h"
78

89
#endif

include/taco/index_notation/index_notation.h

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#include <set>
1010
#include <map>
1111
#include <utility>
12+
#include <functional>
1213

14+
#include "taco/util/name_generator.h"
1315
#include "taco/format.h"
1416
#include "taco/error.h"
1517
#include "taco/util/intrusive_ptr.h"
@@ -22,6 +24,7 @@
2224
#include "taco/ir_tags.h"
2325
#include "taco/lower/iterator.h"
2426
#include "taco/index_notation/provenance_graph.h"
27+
#include "taco/index_notation/properties.h"
2528

2629
namespace taco {
2730

@@ -38,6 +41,8 @@ class IndexExpr;
3841
class Assignment;
3942
class Access;
4043

44+
class IterationAlgebra;
45+
4146
struct AccessNode;
4247
struct AccessWindow;
4348
struct LiteralNode;
@@ -48,8 +53,10 @@ struct SubNode;
4853
struct MulNode;
4954
struct DivNode;
5055
struct CastNode;
56+
struct CallNode;
5157
struct CallIntrinsicNode;
5258
struct ReductionNode;
59+
struct IndexVarNode;
5360

5461
struct AssignmentNode;
5562
struct YieldNode;
@@ -262,6 +269,11 @@ class Access : public IndexExpr {
262269
Assignment operator+=(const IndexExpr&);
263270

264271
typedef AccessNode Node;
272+
273+
// Equality and comparison are overridden on Access to perform a deep
274+
// comparison of the access rather than a pointer check.
275+
friend bool operator==(const Access& a, const Access& b);
276+
friend bool operator<(const Access& a, const Access &b);
265277
};
266278

267279

@@ -289,11 +301,14 @@ class Literal : public IndexExpr {
289301
Literal(std::complex<float>);
290302
Literal(std::complex<double>);
291303

292-
static IndexExpr zero(Datatype);
304+
static Literal zero(Datatype);
293305

294306
/// Returns the literal value.
295307
template <typename T> T getVal() const;
296308

309+
/// Returns an untyped pointer to the literal value
310+
void* getValPtr();
311+
297312
typedef LiteralNode Node;
298313
};
299314

@@ -413,6 +428,26 @@ class Cast : public IndexExpr {
413428
typedef CastNode Node;
414429
};
415430

431+
/// A call to an operator
432+
class Call: public IndexExpr {
433+
public:
434+
Call() = default;
435+
Call(const CallNode*);
436+
Call(const CallNode*, std::string name);
437+
438+
const std::vector<IndexExpr>& getArgs() const;
439+
const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc() const;
440+
const IterationAlgebra& getAlgebra() const;
441+
const std::vector<Property>& getProperties() const;
442+
const std::string getName() const;
443+
const std::map<std::vector<int>, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs() const;
444+
const std::vector<int>& getDefinedArgs() const;
445+
446+
typedef CallNode Node;
447+
448+
private:
449+
std::string name;
450+
};
416451

417452
/// A call to an intrinsic.
418453
/// ```
@@ -433,6 +468,8 @@ class CallIntrinsic : public IndexExpr {
433468
typedef CallIntrinsicNode Node;
434469
};
435470

471+
std::ostream& operator<<(std::ostream&, const IndexVar&);
472+
436473
/// Create calls to various intrinsics.
437474
IndexExpr mod(IndexExpr, IndexExpr);
438475
IndexExpr abs(IndexExpr);
@@ -871,17 +908,27 @@ class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public Index
871908

872909
/// Index variables are used to index into tensors in index expressions, and
873910
/// they represent iteration over the tensor modes they index into.
874-
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
911+
class IndexVar : public IndexExpr, public IndexVarInterface {
912+
875913
public:
876914
IndexVar();
877915
~IndexVar() = default;
878916
IndexVar(const std::string& name);
917+
IndexVar(const std::string& name, const Datatype& type);
918+
IndexVar(const IndexVarNode *);
879919

880920
/// Returns the name of the index variable.
881921
std::string getName() const;
882922

923+
// Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
883924
friend bool operator==(const IndexVar&, const IndexVar&);
884925
friend bool operator<(const IndexVar&, const IndexVar&);
926+
friend bool operator!=(const IndexVar&, const IndexVar&);
927+
friend bool operator>=(const IndexVar&, const IndexVar&);
928+
friend bool operator<=(const IndexVar&, const IndexVar&);
929+
friend bool operator>(const IndexVar&, const IndexVar&);
930+
931+
typedef IndexVarNode Node;
885932

886933
/// Indexing into an IndexVar returns a window into it.
887934
WindowedIndexVar operator()(int lo, int hi);
@@ -927,11 +974,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
927974
class TensorVar : public util::Comparable<TensorVar> {
928975
public:
929976
TensorVar();
930-
TensorVar(const Type& type);
931-
TensorVar(const std::string& name, const Type& type);
932-
TensorVar(const Type& type, const Format& format);
933-
TensorVar(const std::string& name, const Type& type, const Format& format);
934-
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format);
977+
TensorVar(const Type& type, const Literal& fill = Literal());
978+
TensorVar(const std::string& name, const Type& type, const Literal& fill = Literal());
979+
TensorVar(const Type& type, const Format& format, const Literal& fill = Literal());
980+
TensorVar(const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
981+
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format,
982+
const Literal& fill = Literal());
935983

936984
/// Returns the ID of the tensor variable.
937985
int getId() const;
@@ -952,6 +1000,12 @@ class TensorVar : public util::Comparable<TensorVar> {
9521000
/// and execute it's expression.
9531001
const Schedule& getSchedule() const;
9541002

1003+
/// Gets the fill value of the tensor variable. May be left undefined.
1004+
const Literal& getFill() const;
1005+
1006+
/// Set the fill value of the tensor variable
1007+
void setFill(const Literal& fill);
1008+
9551009
/// Set the name of the tensor variable.
9561010
void setName(std::string name);
9571011

@@ -1008,7 +1062,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
10081062
bool isReductionNotation(IndexStmt, std::string* reason=nullptr);
10091063

10101064
/// Check whether the statement is in the concrete index notation dialect.
1011-
/// This means every index variable has a forall node, there are no reduction
1065+
/// This means every index variable has a forall node, each index variable used
1066+
/// for computation is under a forall node for that variable, there are no reduction
10121067
/// nodes, and that every reduction variable use is nested inside a compound
10131068
/// assignment statement. You can optionally pass in a pointer to a string
10141069
/// that the reason why it is not concrete notation is printed to.
@@ -1030,7 +1085,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
10301085
/// Returns the input tensors to the index statement, in the order they appear.
10311086
std::vector<TensorVar> getArguments(IndexStmt stmt);
10321087

1033-
/// Returns the temporaries in the index statement, in the order they appear.
1088+
/// Returns true iff all of the loops over free variables come before all of the loops over
1089+
/// reduction variables. Therefore, this returns true if the reduction controlled by the loops
1090+
/// does not a scatter.
1091+
bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt);
1092+
1093+
/// Returns the temporaries in the index statement, in the order they appear.
10341094
std::vector<TensorVar> getTemporaries(IndexStmt stmt);
10351095

10361096
// [Olivia]
@@ -1070,7 +1130,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
10701130
/// zero and then propagating and removing zeroes.
10711131
IndexStmt zero(IndexStmt, const std::set<Access>& zeroed);
10721132

1073-
/// Create an `other` tensor with the given name and format,
1133+
/// Infers the fill value of the input expression by applying properties if possible. If unable
1134+
/// to successfully infer the fill value of the result, returns the empty IndexExpr
1135+
IndexExpr inferFill(IndexExpr);
1136+
1137+
/// Returns true if there are no forall nodes in the indexStmt. Used to check
1138+
/// if the last loop is being lowered.
1139+
bool hasNoForAlls(IndexStmt);
1140+
1141+
/// Create an `other` tensor with the given name and format,
10741142
/// and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
10751143
/// and otherwise returns other(indexVars) = tensor(indexVars).
10761144
IndexStmt generatePackStmt(TensorVar tensor,

include/taco/index_notation/index_notation_nodes.h

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33

44
#include <vector>
55
#include <memory>
6+
#include <numeric>
67

78
#include "taco/type.h"
9+
#include "taco/util/collections.h"
10+
#include "taco/util/comparable.h"
811
#include "taco/index_notation/index_notation.h"
912
#include "taco/index_notation/index_notation_nodes_abstract.h"
1013
#include "taco/index_notation/index_notation_visitor.h"
1114
#include "taco/index_notation/intrinsic.h"
1215
#include "taco/util/strings.h"
16+
#include "iteration_algebra.h"
17+
#include "properties.h"
1318

1419
namespace taco {
1520

@@ -23,6 +28,12 @@ struct AccessWindow {
2328
friend bool operator==(const AccessWindow& a, const AccessWindow& b) {
2429
return a.lo == b.lo && a.hi == b.hi;
2530
}
31+
friend bool operator<(const AccessWindow& a, const AccessWindow& b) {
32+
if (a.lo != b.lo) {
33+
return a.lo < b.lo;
34+
}
35+
return a.hi < b.hi;
36+
}
2637
};
2738

2839
struct AccessNode : public IndexExprNode {
@@ -68,7 +79,6 @@ struct LiteralNode : public IndexExprNode {
6879
void* val;
6980
};
7081

71-
7282
struct UnaryExprNode : public IndexExprNode {
7383
IndexExpr a;
7484

@@ -188,6 +198,57 @@ struct CallIntrinsicNode : public IndexExprNode {
188198
std::vector<IndexExpr> args;
189199
};
190200

201+
struct CallNode : public IndexExprNode {
202+
typedef std::function<ir::Expr(const std::vector<ir::Expr>&)> OpImpl;
203+
typedef std::function<IterationAlgebra(const std::vector<IndexExpr>&)> AlgebraImpl;
204+
205+
CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
206+
const IterationAlgebra& iterAlg,
207+
const std::vector<Property>& properties,
208+
const std::map<std::vector<int>, OpImpl>& regionDefinitions,
209+
const std::vector<int>& definedRegions);
210+
211+
CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
212+
const IterationAlgebra& iterAlg,
213+
const std::vector<Property>& properties,
214+
const std::map<std::vector<int>, OpImpl>& regionDefinitions);
215+
216+
void accept(IndexExprVisitorStrict* v) const {
217+
v->visit(this);
218+
}
219+
220+
std::string name;
221+
std::vector<IndexExpr> args;
222+
OpImpl defaultLowerFunc;
223+
IterationAlgebra iterAlg;
224+
std::vector<Property> properties;
225+
std::map<std::vector<int>, OpImpl> regionDefinitions;
226+
227+
// Needed to track which inputs have been exhausted so the lowerer can know which lower func to use
228+
std::vector<int> definedRegions;
229+
230+
private:
231+
static Datatype inferReturnType(OpImpl f, const std::vector<IndexExpr>& inputs) {
232+
std::function<ir::Expr(IndexExpr)> getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); };
233+
std::vector<ir::Expr> exprs = util::map(inputs, getExprs);
234+
235+
if(exprs.empty()) {
236+
return taco::Datatype();
237+
}
238+
239+
return f(exprs).type();
240+
}
241+
242+
static std::vector<int> definedIndices(std::vector<IndexExpr> args) {
243+
std::vector<int> v;
244+
for(int i = 0; i < (int) args.size(); ++i) {
245+
if(args[i].defined()) {
246+
v.push_back(i);
247+
}
248+
}
249+
return v;
250+
}
251+
};
191252

192253
struct ReductionNode : public IndexExprNode {
193254
ReductionNode(IndexExpr op, IndexVar var, IndexExpr a);
@@ -202,6 +263,27 @@ struct ReductionNode : public IndexExprNode {
202263
IndexExpr a;
203264
};
204265

266+
struct IndexVarNode : public IndexExprNode, public util::Comparable<IndexVarNode> {
267+
IndexVarNode() = delete;
268+
IndexVarNode(const std::string& name, const Datatype& type);
269+
270+
void accept(IndexExprVisitorStrict* v) const {
271+
v->visit(this);
272+
}
273+
274+
std::string getName() const;
275+
276+
friend bool operator==(const IndexVarNode& a, const IndexVarNode& b);
277+
friend bool operator<(const IndexVarNode& a, const IndexVarNode& b);
278+
279+
private:
280+
struct Content;
281+
std::shared_ptr<Content> content;
282+
};
283+
284+
struct IndexVarNode::Content {
285+
std::string name;
286+
};
205287

206288
// Index Statements
207289
struct AssignmentNode : public IndexStmtNode {

include/taco/index_notation/index_notation_printer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict {
2525
void visit(const MulNode*);
2626
void visit(const DivNode*);
2727
void visit(const CastNode*);
28+
void visit(const CallNode*);
2829
void visit(const CallIntrinsicNode*);
2930
void visit(const ReductionNode*);
31+
void visit(const IndexVarNode*);
3032

3133
// Tensor Expressions
3234
void visit(const AssignmentNode*);

include/taco/index_notation/index_notation_rewriter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ class IndexExprRewriterStrict : public IndexExprVisitorStrict {
3232
virtual void visit(const MulNode* op) = 0;
3333
virtual void visit(const DivNode* op) = 0;
3434
virtual void visit(const CastNode* op) = 0;
35+
virtual void visit(const CallNode* op) = 0;
3536
virtual void visit(const CallIntrinsicNode* op) = 0;
3637
virtual void visit(const ReductionNode* op) = 0;
38+
virtual void visit(const IndexVarNode* op) = 0;
3739
};
3840

3941

@@ -93,8 +95,10 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict {
9395
virtual void visit(const MulNode* op);
9496
virtual void visit(const DivNode* op);
9597
virtual void visit(const CastNode* op);
98+
virtual void visit(const CallNode* op);
9699
virtual void visit(const CallIntrinsicNode* op);
97100
virtual void visit(const ReductionNode* op);
101+
virtual void visit(const IndexVarNode* op);
98102

99103
virtual void visit(const AssignmentNode* op);
100104
virtual void visit(const YieldNode* op);

0 commit comments

Comments
 (0)