@@ -72,6 +72,54 @@ TEST(scheduling, splitIndexStmt) {
72
72
ASSERT_TRUE (equals (a (i) = b (i), i2Forall.getStmt ()));
73
73
}
74
74
75
+ TEST (scheduling, fuseDenseLoops) {
76
+ auto dim = 4 ;
77
+ Tensor<int > A (" A" , {dim, dim, dim}, {Dense, Dense, Dense});
78
+ Tensor<int > B (" B" , {dim, dim, dim}, {Dense, Dense, Dense});
79
+ Tensor<int > expected (" expected" , {dim, dim, dim}, {Dense, Dense, Dense});
80
+ IndexVar f (" f" ), g (" g" );
81
+ for (int i = 0 ; i < dim; i++) {
82
+ for (int j = 0 ; j < dim; j++) {
83
+ for (int k = 0 ; k < dim; k++) {
84
+ A.insert ({i, j, k}, i + j + k);
85
+ B.insert ({i, j, k}, i + j + k);
86
+ expected.insert ({i, j, k}, 2 * (i + j + k));
87
+ }
88
+ }
89
+ }
90
+ A.pack ();
91
+ B.pack ();
92
+ expected.pack ();
93
+
94
+ // Helper function to evaluate the target statement and verify the results.
95
+ // It takes in a function that applies some scheduling transforms to the
96
+ // input IndexStmt, and applies to the point-wise tensor addition below.
97
+ // The test is structured this way as TACO does its best to avoid re-compilation
98
+ // whenever possible. I.e. changing the stmt that a tensor is compiled with
99
+ // doesn't cause compilation to occur again.
100
+ auto testFn = [&](std::function<IndexStmt (IndexStmt)> modifier) {
101
+ Tensor<int > C (" C" , {dim, dim, dim}, {Dense, Dense, Dense});
102
+ C (i, j, k) = A (i, j, k) + B (i, j, k);
103
+ auto stmt = C.getAssignment ().concretize ();
104
+ C.compile (modifier (stmt));
105
+ C.evaluate ();
106
+ ASSERT_TRUE (equals (C, expected)) << endl << C << endl << expected << endl;
107
+ };
108
+
109
+ // First, a sanity check with no transformations.
110
+ testFn ([](IndexStmt stmt) { return stmt; });
111
+ // Next, fuse the outer two loops. This tests the original bug in #355.
112
+ testFn ([&](IndexStmt stmt) {
113
+ return stmt.fuse (i, j, f);
114
+ });
115
+ // Lastly, fuse all of the loops into a single loop. This ensures that
116
+ // locators with a chain of ancestors have all of their dependencies
117
+ // generated in a valid ordering.
118
+ testFn ([&](IndexStmt stmt) {
119
+ return stmt.fuse (i, j, f).fuse (f, k, g);
120
+ });
121
+ }
122
+
75
123
TEST (scheduling, lowerDenseMatrixMul) {
76
124
Tensor<double > A (" A" , {4 , 4 }, {Dense, Dense});
77
125
Tensor<double > B (" B" , {4 , 4 }, {Dense, Dense});
0 commit comments