@@ -605,6 +605,9 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
605
605
inParallelLoopDepth++;
606
606
}
607
607
608
+ // Record that we might have some fresh locators that need to be recovered.
609
+ std::vector<Iterator> freshLocateIterators;
610
+
608
611
// Recover any available parents that were not recoverable previously
609
612
vector<Stmt> recoverySteps;
610
613
for (const IndexVar& varToRecover : provGraph.newlyRecoverableParents (forall.getIndexVar (), definedIndexVars)) {
@@ -634,17 +637,16 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
634
637
// the accessors for those locator variables as part of the recovery process.
635
638
// This is necessary after a fuse transformation, for example: If we fuse
636
639
// two index variables (i, j) into f, then after we've generated the loop for
637
- // f, all locate accessors for i and j are now available for use.
640
+ // f, all locate accessors for i and j are now available for use. So, remember
641
+ // that we have some new locate iterators that should be recovered.
638
642
std::vector<Iterator> itersForVar;
639
643
for (auto & iters : iterators.levelIterators ()) {
640
644
// Collect all level iterators that have locate and iterate over
641
645
// the recovered index variable.
642
646
if (iters.second .getIndexVar () == varToRecover && iters.second .hasLocate ()) {
643
- itersForVar .push_back (iters.second );
647
+ freshLocateIterators .push_back (iters.second );
644
648
}
645
649
}
646
- // Finally, declare all of the collected iterators' position access variables.
647
- recoverySteps.push_back (this ->declLocatePosVars (itersForVar));
648
650
649
651
// place underived guard
650
652
std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds (varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
@@ -799,7 +801,15 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
799
801
}
800
802
// Emit dimension coordinate iteration loop
801
803
else if (iterator.isDimensionIterator ()) {
802
- loops = lowerForallDimension (forall, point.locators (),
804
+ // A proper fix to #355. Adding information that those locate iterators are now ready is the
805
+ // correct way to recover them, rather than blindly duplicating the emitted locators.
806
+ auto locatorsCopy = std::vector<Iterator>(point.locators ());
807
+ for (auto it : freshLocateIterators) {
808
+ if (!util::contains (locatorsCopy, it)) {
809
+ locatorsCopy.push_back (it);
810
+ }
811
+ }
812
+ loops = lowerForallDimension (forall, locatorsCopy,
803
813
inserters, appenders, reducedAccesses, recoveryStmt);
804
814
}
805
815
// Emit position iteration loop
@@ -1772,14 +1782,19 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
1772
1782
const set<Access>& reducedAccesses) {
1773
1783
Stmt initVals = resizeAndInitValues (appenders, reducedAccesses);
1774
1784
1775
- // Inserter positions
1776
- Stmt declInserterPosVars = declLocatePosVars (inserters);
1777
-
1778
- // Locate positions
1779
- Stmt declLocatorPosVars = declLocatePosVars (locators);
1785
+ // There can be overlaps between the inserters and locators, which results in
1786
+ // duplicate variable declarations being emitted. We deduplicate them here.
1787
+ std::vector<Iterator> itersWithLocators;
1788
+ for (auto it : inserters) {
1789
+ if (!util::contains (itersWithLocators, it)) { itersWithLocators.push_back (it); }
1790
+ }
1791
+ for (auto it : locators) {
1792
+ if (!util::contains (itersWithLocators, it)) { itersWithLocators.push_back (it); }
1793
+ }
1794
+ auto declPosVars = declLocatePosVars (itersWithLocators);
1780
1795
1781
1796
if (captureNextLocatePos) {
1782
- capturedLocatePos = Block::make (declInserterPosVars, declLocatorPosVars) ;
1797
+ capturedLocatePos = declPosVars ;
1783
1798
captureNextLocatePos = false ;
1784
1799
}
1785
1800
@@ -1792,8 +1807,7 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
1792
1807
// TODO: Emit code to insert coordinates
1793
1808
1794
1809
return Block::make (initVals,
1795
- declInserterPosVars,
1796
- declLocatorPosVars,
1810
+ declPosVars,
1797
1811
body,
1798
1812
appendCoords);
1799
1813
}
0 commit comments