Skip to content

Commit 46a40a0

Browse files
committed
lower: properly fix #355
This commit properly fixes #355 by ensuring that duplicate locators are not generated by different codepaths. The bug is usually masked by the ir::simplify call, which removes the extra locators in most situations.
1 parent 0ede002 commit 46a40a0

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

src/lower/lowerer_impl_imperative.cpp

+27-13
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,9 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
605605
inParallelLoopDepth++;
606606
}
607607

608+
// Record that we might have some fresh locators that need to be recovered.
609+
std::vector<Iterator> freshLocateIterators;
610+
608611
// Recover any available parents that were not recoverable previously
609612
vector<Stmt> recoverySteps;
610613
for (const IndexVar& varToRecover : provGraph.newlyRecoverableParents(forall.getIndexVar(), definedIndexVars)) {
@@ -634,17 +637,16 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
634637
// the accessors for those locator variables as part of the recovery process.
635638
// This is necessary after a fuse transformation, for example: If we fuse
636639
// two index variables (i, j) into f, then after we've generated the loop for
637-
// f, all locate accessors for i and j are now available for use.
640+
// f, all locate accessors for i and j are now available for use. So, remember
641+
// that we have some new locate iterators that should be recovered.
638642
std::vector<Iterator> itersForVar;
639643
for (auto& iters : iterators.levelIterators()) {
640644
// Collect all level iterators that have locate and iterate over
641645
// the recovered index variable.
642646
if (iters.second.getIndexVar() == varToRecover && iters.second.hasLocate()) {
643-
itersForVar.push_back(iters.second);
647+
freshLocateIterators.push_back(iters.second);
644648
}
645649
}
646-
// Finally, declare all of the collected iterators' position access variables.
647-
recoverySteps.push_back(this->declLocatePosVars(itersForVar));
648650

649651
// place underived guard
650652
std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
@@ -799,7 +801,15 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
799801
}
800802
// Emit dimension coordinate iteration loop
801803
else if (iterator.isDimensionIterator()) {
802-
loops = lowerForallDimension(forall, point.locators(),
804+
// A proper fix to #355. Adding information that those locate iterators are now ready is the
805+
// correct way to recover them, rather than blindly duplicating the emitted locators.
806+
auto locatorsCopy = std::vector<Iterator>(point.locators());
807+
for (auto it : freshLocateIterators) {
808+
if (!util::contains(locatorsCopy, it)) {
809+
locatorsCopy.push_back(it);
810+
}
811+
}
812+
loops = lowerForallDimension(forall, locatorsCopy,
803813
inserters, appenders, reducedAccesses, recoveryStmt);
804814
}
805815
// Emit position iteration loop
@@ -1772,14 +1782,19 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
17721782
const set<Access>& reducedAccesses) {
17731783
Stmt initVals = resizeAndInitValues(appenders, reducedAccesses);
17741784

1775-
// Inserter positions
1776-
Stmt declInserterPosVars = declLocatePosVars(inserters);
1777-
1778-
// Locate positions
1779-
Stmt declLocatorPosVars = declLocatePosVars(locators);
1785+
// There can be overlaps between the inserters and locators, which results in
1786+
// duplicate emission of variable declarations. We'll deduplicate them here.
1787+
std::vector<Iterator> itersWithLocators;
1788+
for (auto it : inserters) {
1789+
if (!util::contains(itersWithLocators, it)) { itersWithLocators.push_back(it); }
1790+
}
1791+
for (auto it : locators) {
1792+
if (!util::contains(itersWithLocators, it)) { itersWithLocators.push_back(it); }
1793+
}
1794+
auto declPosVars = declLocatePosVars(itersWithLocators);
17801795

17811796
if (captureNextLocatePos) {
1782-
capturedLocatePos = Block::make(declInserterPosVars, declLocatorPosVars);
1797+
capturedLocatePos = declPosVars;
17831798
captureNextLocatePos = false;
17841799
}
17851800

@@ -1792,8 +1807,7 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
17921807
// TODO: Emit code to insert coordinates
17931808

17941809
return Block::make(initVals,
1795-
declInserterPosVars,
1796-
declLocatorPosVars,
1810+
declPosVars,
17971811
body,
17981812
appendCoords);
17991813
}

0 commit comments

Comments
 (0)