Skip to content

lower: fix a bug causing undefined variables when applying fuse #362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,24 @@ Stmt LowererImpl::lowerForall(Forall forall)
Expr recoveredValue = provGraph.recoverVariable(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
taco_iassert(indexVarToExprMap.count(varToRecover));
recoverySteps.push_back(VarDecl::make(indexVarToExprMap[varToRecover], recoveredValue));

// After we've recovered this index variable, some iterators are now
// accessible for use when declaring locator access variables. So, generate
// the accessors for those locator variables as part of the recovery process.
// This is necessary after a fuse transformation, for example: If we fuse
// two index variables (i, j) into f, then after we've generated the loop for
// f, all locate accessors for i and j are now available for use.
std::vector<Iterator> itersForVar;
for (auto& iters : iterators.levelIterators()) {
// Collect all level iterators that have locate and iterate over
// the recovered index variable.
if (iters.second.getIndexVar() == varToRecover && iters.second.hasLocate()) {
itersForVar.push_back(iters.second);
}
}
// Finally, declare all of the collected iterators' position access variables.
recoverySteps.push_back(this->declLocatePosVars(itersForVar));

// place underived guard
std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
if (forallNeedsUnderivedGuards && underivedBounds.count(varToRecover) &&
Expand Down Expand Up @@ -2275,7 +2293,6 @@ Stmt LowererImpl::declLocatePosVars(vector<Iterator> locators) {
if (locateIterator.isLeaf()) {
break;
}

locateIterator = locateIterator.getChild();
} while (accessibleIterators.contains(locateIterator));
}
Expand Down
48 changes: 48 additions & 0 deletions test/tests-scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,54 @@ TEST(scheduling, splitIndexStmt) {
ASSERT_TRUE(equals(a(i) = b(i), i2Forall.getStmt()));
}

TEST(scheduling, fuseDenseLoops) {
auto dim = 4;
Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
IndexVar f("f"), g("g");
for (int i = 0; i < dim; i++) {
for (int j = 0; j < dim; j++) {
for (int k = 0; k < dim; k++) {
A.insert({i, j, k}, i + j + k);
B.insert({i, j, k}, i + j + k);
expected.insert({i, j, k}, 2 * (i + j + k));
}
}
}
A.pack();
B.pack();
expected.pack();

// Helper function to evaluate the target statement and verify the results.
// It takes in a function that applies some scheduling transforms to the
// input IndexStmt, and applies to the point-wise tensor addition below.
// The test is structured this way as TACO does its best to avoid re-compilation
// whenever possible. I.e. changing the stmt that a tensor is compiled with
// doesn't cause compilation to occur again.
auto testFn = [&](std::function<IndexStmt(IndexStmt)> modifier) {
Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
C(i, j, k) = A(i, j, k) + B(i, j, k);
auto stmt = C.getAssignment().concretize();
C.compile(modifier(stmt));
C.evaluate();
ASSERT_TRUE(equals(C, expected)) << endl << C << endl << expected << endl;
};

// First, a sanity check with no transformations.
testFn([](IndexStmt stmt) { return stmt; });
// Next, fuse the outer two loops. This tests the original bug in #355.
testFn([&](IndexStmt stmt) {
return stmt.fuse(i, j, f);
});
// Lastly, fuse all of the loops into a single loop. This ensures that
// locators with a chain of ancestors have all of their dependencies
// generated in a valid ordering.
testFn([&](IndexStmt stmt) {
return stmt.fuse(i, j, f).fuse(f, k, g);
});
}

TEST(scheduling, lowerDenseMatrixMul) {
Tensor<double> A("A", {4, 4}, {Dense, Dense});
Tensor<double> B("B", {4, 4}, {Dense, Dense});
Expand Down
23 changes: 23 additions & 0 deletions test/tests-transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,26 @@ TEST(DISABLED_lower, transpose3) {
&reason));
ASSERT_EQ(error::expr_transposition, reason);
}

// denseIterationTranspose tests a dense iteration that contain a transposition
// of one of the tensors.
TEST(lower, denseIterationTranspose) {
auto dim = 4;
Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
for (int i = 0; i < dim; i++) {
for (int j = 0; j < dim; j++) {
for (int k = 0; k < dim; k++) {
A.insert({i, j, k}, i + j + k);
B.insert({i, j, k}, i + j + k);
expected.insert({i, j, k}, 2 * (i + j + k));
}
}
}
A.pack(); B.pack(); expected.pack();
C(i, j, k) = A(i, j, k) + B(k, j, i);
C.evaluate();
ASSERT_TRUE(equals(C, expected));
}