9
9
#include < set>
10
10
#include < map>
11
11
#include < utility>
12
+ #include < functional>
12
13
14
+ #include " taco/util/name_generator.h"
13
15
#include " taco/format.h"
14
16
#include " taco/error.h"
15
17
#include " taco/util/intrusive_ptr.h"
22
24
#include " taco/ir_tags.h"
23
25
#include " taco/lower/iterator.h"
24
26
#include " taco/index_notation/provenance_graph.h"
27
+ #include " taco/index_notation/properties.h"
25
28
26
29
namespace taco {
27
30
@@ -38,6 +41,8 @@ class IndexExpr;
38
41
class Assignment ;
39
42
class Access ;
40
43
44
+ class IterationAlgebra ;
45
+
41
46
struct AccessNode ;
42
47
struct AccessWindow ;
43
48
struct LiteralNode ;
@@ -48,8 +53,10 @@ struct SubNode;
48
53
struct MulNode ;
49
54
struct DivNode ;
50
55
struct CastNode ;
56
+ struct CallNode ;
51
57
struct CallIntrinsicNode ;
52
58
struct ReductionNode ;
59
+ struct IndexVarNode ;
53
60
54
61
struct AssignmentNode ;
55
62
struct YieldNode ;
@@ -262,6 +269,11 @@ class Access : public IndexExpr {
262
269
Assignment operator +=(const IndexExpr&);
263
270
264
271
typedef AccessNode Node;
272
+
273
+ // Equality and comparison are overridden on Access to perform a deep
274
+ // comparison of the access rather than a pointer check.
275
+ friend bool operator ==(const Access& a, const Access& b);
276
+ friend bool operator <(const Access& a, const Access &b);
265
277
};
266
278
267
279
@@ -289,11 +301,14 @@ class Literal : public IndexExpr {
289
301
Literal (std::complex<float >);
290
302
Literal (std::complex<double >);
291
303
292
- static IndexExpr zero (Datatype);
304
+ static Literal zero (Datatype);
293
305
294
306
// / Returns the literal value.
295
307
template <typename T> T getVal () const ;
296
308
309
+ // / Returns an untyped pointer to the literal value
310
+ void * getValPtr ();
311
+
297
312
typedef LiteralNode Node;
298
313
};
299
314
@@ -413,6 +428,26 @@ class Cast : public IndexExpr {
413
428
typedef CastNode Node;
414
429
};
415
430
431
+ // / A call to an operator
432
+ class Call : public IndexExpr {
433
+ public:
434
+ Call () = default ;
435
+ Call (const CallNode*);
436
+ Call (const CallNode*, std::string name);
437
+
438
+ const std::vector<IndexExpr>& getArgs () const ;
439
+ const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc () const ;
440
+ const IterationAlgebra& getAlgebra () const ;
441
+ const std::vector<Property>& getProperties () const ;
442
+ const std::string getName () const ;
443
+ const std::map<std::vector<int >, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs () const ;
444
+ const std::vector<int >& getDefinedArgs () const ;
445
+
446
+ typedef CallNode Node;
447
+
448
+ private:
449
+ std::string name;
450
+ };
416
451
417
452
// / A call to an intrinsic.
418
453
// / ```
@@ -433,6 +468,8 @@ class CallIntrinsic : public IndexExpr {
433
468
typedef CallIntrinsicNode Node;
434
469
};
435
470
471
+ std::ostream& operator <<(std::ostream&, const IndexVar&);
472
+
436
473
// / Create calls to various intrinsics.
437
474
IndexExpr mod (IndexExpr, IndexExpr);
438
475
IndexExpr abs (IndexExpr);
@@ -871,17 +908,27 @@ class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public Index
871
908
872
909
// / Index variables are used to index into tensors in index expressions, and
873
910
// / they represent iteration over the tensor modes they index into.
874
- class IndexVar : public util ::Comparable<IndexVar>, public IndexVarInterface {
911
+ class IndexVar : public IndexExpr , public IndexVarInterface {
912
+
875
913
public:
876
914
IndexVar ();
877
915
~IndexVar () = default ;
878
916
IndexVar (const std::string& name);
917
+ IndexVar (const std::string& name, const Datatype& type);
918
+ IndexVar (const IndexVarNode *);
879
919
880
920
// / Returns the name of the index variable.
881
921
std::string getName () const ;
882
922
923
+ // Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
883
924
friend bool operator ==(const IndexVar&, const IndexVar&);
884
925
friend bool operator <(const IndexVar&, const IndexVar&);
926
+ friend bool operator !=(const IndexVar&, const IndexVar&);
927
+ friend bool operator >=(const IndexVar&, const IndexVar&);
928
+ friend bool operator <=(const IndexVar&, const IndexVar&);
929
+ friend bool operator >(const IndexVar&, const IndexVar&);
930
+
931
+ typedef IndexVarNode Node;
885
932
886
933
// / Indexing into an IndexVar returns a window into it.
887
934
WindowedIndexVar operator ()(int lo, int hi);
@@ -927,11 +974,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
927
974
class TensorVar : public util ::Comparable<TensorVar> {
928
975
public:
929
976
TensorVar ();
930
- TensorVar (const Type& type);
931
- TensorVar (const std::string& name, const Type& type);
932
- TensorVar (const Type& type, const Format& format);
933
- TensorVar (const std::string& name, const Type& type, const Format& format);
934
- TensorVar (const int &id, const std::string& name, const Type& type, const Format& format);
977
+ TensorVar (const Type& type, const Literal& fill = Literal());
978
+ TensorVar (const std::string& name, const Type& type, const Literal& fill = Literal());
979
+ TensorVar (const Type& type, const Format& format, const Literal& fill = Literal());
980
+ TensorVar (const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
981
+ TensorVar (const int &id, const std::string& name, const Type& type, const Format& format,
982
+ const Literal& fill = Literal());
935
983
936
984
// / Returns the ID of the tensor variable.
937
985
int getId () const ;
@@ -952,6 +1000,12 @@ class TensorVar : public util::Comparable<TensorVar> {
952
1000
// / and execute it's expression.
953
1001
const Schedule& getSchedule () const ;
954
1002
1003
+ // / Gets the fill value of the tensor variable. May be left undefined.
1004
+ const Literal& getFill () const ;
1005
+
1006
+ // / Set the fill value of the tensor variable
1007
+ void setFill (const Literal& fill);
1008
+
955
1009
// / Set the name of the tensor variable.
956
1010
void setName (std::string name);
957
1011
@@ -1008,7 +1062,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
1008
1062
bool isReductionNotation (IndexStmt, std::string* reason=nullptr );
1009
1063
1010
1064
// / Check whether the statement is in the concrete index notation dialect.
1011
- // / This means every index variable has a forall node, there are no reduction
1065
+ // / This means every index variable has a forall node, each index variable used
1066
+ // / for computation is under a forall node for that variable, there are no reduction
1012
1067
// / nodes, and that every reduction variable use is nested inside a compound
1013
1068
// / assignment statement. You can optionally pass in a pointer to a string
1014
1069
// / that the reason why it is not concrete notation is printed to.
@@ -1030,7 +1085,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
1030
1085
// / Returns the input tensors to the index statement, in the order they appear.
1031
1086
std::vector<TensorVar> getArguments (IndexStmt stmt);
1032
1087
1033
- // / Returns the temporaries in the index statement, in the order they appear.
1088
+ // / Returns true iff all of the loops over free variables come before all of the loops over
1089
+ // / reduction variables. Therefore, this returns true if the reduction controlled by the loops
1090
+ // / does not a scatter.
1091
+ bool allForFreeLoopsBeforeAllReductionLoops (IndexStmt stmt);
1092
+
1093
+ // / Returns the temporaries in the index statement, in the order they appear.
1034
1094
std::vector<TensorVar> getTemporaries (IndexStmt stmt);
1035
1095
1036
1096
// [Olivia]
@@ -1070,7 +1130,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
1070
1130
// / zero and then propagating and removing zeroes.
1071
1131
IndexStmt zero (IndexStmt, const std::set<Access>& zeroed);
1072
1132
1073
- // / Create an `other` tensor with the given name and format,
1133
+ // / Infers the fill value of the input expression by applying properties if possible. If unable
1134
+ // / to successfully infer the fill value of the result, returns the empty IndexExpr
1135
+ IndexExpr inferFill (IndexExpr);
1136
+
1137
+ // / Returns true if there are no forall nodes in the indexStmt. Used to check
1138
+ // / if the last loop is being lowered.
1139
+ bool hasNoForAlls (IndexStmt);
1140
+
1141
+ // / Create an `other` tensor with the given name and format,
1074
1142
// / and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
1075
1143
// / and otherwise returns other(indexVars) = tensor(indexVars).
1076
1144
IndexStmt generatePackStmt (TensorVar tensor,
0 commit comments