Skip to content

Commit 468ad7f

Browse files
Merge pull request #362 from rohany/fuse-bug-2
lower: fix a bug causing undefined variables when applying fuse
2 parents 864b65d + 6e57653 commit 468ad7f

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

src/lower/lowerer_impl.cpp

+18-1
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,24 @@ Stmt LowererImpl::lowerForall(Forall forall)
417417
Expr recoveredValue = provGraph.recoverVariable(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
418418
taco_iassert(indexVarToExprMap.count(varToRecover));
419419
recoverySteps.push_back(VarDecl::make(indexVarToExprMap[varToRecover], recoveredValue));
420+
421+
// After we've recovered this index variable, some iterators are now
422+
// accessible for use when declaring locator access variables. So, generate
423+
// the accessors for those locator variables as part of the recovery process.
424+
// This is necessary after a fuse transformation, for example: If we fuse
425+
// two index variables (i, j) into f, then after we've generated the loop for
426+
// f, all locate accessors for i and j are now available for use.
427+
std::vector<Iterator> itersForVar;
428+
for (auto& iters : iterators.levelIterators()) {
429+
// Collect all level iterators that have locate and iterate over
430+
// the recovered index variable.
431+
if (iters.second.getIndexVar() == varToRecover && iters.second.hasLocate()) {
432+
itersForVar.push_back(iters.second);
433+
}
434+
}
435+
// Finally, declare all of the collected iterators' position access variables.
436+
recoverySteps.push_back(this->declLocatePosVars(itersForVar));
437+
420438
// place underived guard
421439
std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
422440
if (forallNeedsUnderivedGuards && underivedBounds.count(varToRecover) &&
@@ -2275,7 +2293,6 @@ Stmt LowererImpl::declLocatePosVars(vector<Iterator> locators) {
22752293
if (locateIterator.isLeaf()) {
22762294
break;
22772295
}
2278-
22792296
locateIterator = locateIterator.getChild();
22802297
} while (accessibleIterators.contains(locateIterator));
22812298
}

test/tests-scheduling.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,54 @@ TEST(scheduling, splitIndexStmt) {
7272
ASSERT_TRUE(equals(a(i) = b(i), i2Forall.getStmt()));
7373
}
7474

75+
TEST(scheduling, fuseDenseLoops) {
76+
auto dim = 4;
77+
Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
78+
Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
79+
Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
80+
IndexVar f("f"), g("g");
81+
for (int i = 0; i < dim; i++) {
82+
for (int j = 0; j < dim; j++) {
83+
for (int k = 0; k < dim; k++) {
84+
A.insert({i, j, k}, i + j + k);
85+
B.insert({i, j, k}, i + j + k);
86+
expected.insert({i, j, k}, 2 * (i + j + k));
87+
}
88+
}
89+
}
90+
A.pack();
91+
B.pack();
92+
expected.pack();
93+
94+
// Helper function to evaluate the target statement and verify the results.
95+
// It takes in a function that applies some scheduling transforms to the
96+
// input IndexStmt, and applies to the point-wise tensor addition below.
97+
// The test is structured this way as TACO does its best to avoid re-compilation
98+
// whenever possible. I.e. changing the stmt that a tensor is compiled with
99+
// doesn't cause compilation to occur again.
100+
auto testFn = [&](std::function<IndexStmt(IndexStmt)> modifier) {
101+
Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
102+
C(i, j, k) = A(i, j, k) + B(i, j, k);
103+
auto stmt = C.getAssignment().concretize();
104+
C.compile(modifier(stmt));
105+
C.evaluate();
106+
ASSERT_TRUE(equals(C, expected)) << endl << C << endl << expected << endl;
107+
};
108+
109+
// First, a sanity check with no transformations.
110+
testFn([](IndexStmt stmt) { return stmt; });
111+
// Next, fuse the outer two loops. This tests the original bug in #355.
112+
testFn([&](IndexStmt stmt) {
113+
return stmt.fuse(i, j, f);
114+
});
115+
// Lastly, fuse all of the loops into a single loop. This ensures that
116+
// locators with a chain of ancestors have all of their dependencies
117+
// generated in a valid ordering.
118+
testFn([&](IndexStmt stmt) {
119+
return stmt.fuse(i, j, f).fuse(f, k, g);
120+
});
121+
}
122+
75123
TEST(scheduling, lowerDenseMatrixMul) {
76124
Tensor<double> A("A", {4, 4}, {Dense, Dense});
77125
Tensor<double> B("B", {4, 4}, {Dense, Dense});

test/tests-transpose.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,26 @@ TEST(DISABLED_lower, transpose3) {
7676
&reason));
7777
ASSERT_EQ(error::expr_transposition, reason);
7878
}
79+
80+
// denseIterationTranspose tests a dense iteration that contain a transposition
81+
// of one of the tensors.
82+
TEST(lower, denseIterationTranspose) {
83+
auto dim = 4;
84+
Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
85+
Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
86+
Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
87+
Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
88+
for (int i = 0; i < dim; i++) {
89+
for (int j = 0; j < dim; j++) {
90+
for (int k = 0; k < dim; k++) {
91+
A.insert({i, j, k}, i + j + k);
92+
B.insert({i, j, k}, i + j + k);
93+
expected.insert({i, j, k}, 2 * (i + j + k));
94+
}
95+
}
96+
}
97+
A.pack(); B.pack(); expected.pack();
98+
C(i, j, k) = A(i, j, k) + B(k, j, i);
99+
C.evaluate();
100+
ASSERT_TRUE(equals(C, expected));
101+
}

0 commit comments

Comments
 (0)