@@ -13,6 +13,9 @@
 #include <iostream>
 #include <algorithm>
 #include <limits>
+#include <set>
+#include <map>
+#include <vector>
 
 
 using namespace std;
@@ -171,6 +174,12 @@ static bool containsExpr(Assignment assignment, IndexExpr expr) {
     IndexExpr expr;
     bool contains = false;
 
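+    // Also match expressions that are bare tensor accesses, so they can be
+    // found and precomputed too.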
+    void visit(const AccessNode* node) {
+      if (equals(IndexExpr(node), expr)) {
+        contains = true;
+      }
+    }
+
     void visit(const UnaryExprNode* node) {
       if (equals(IndexExpr(node), expr)) {
         contains = true;
@@ -213,6 +222,60 @@ static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
   return assignment;
 }
 
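+// Rewrites reduction assignments (e.g. +=) into plain assignments when the
+// left-hand side is indexed by every loop variable in scope, meaning each
+// output element is written exactly once. If `candidates` is non-null, only
+// results in that set are rewritten.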
+static IndexStmt eliminateRedundantReductions(IndexStmt stmt,
+    const std::set<TensorVar>* const candidates = nullptr) {
+
+  struct ReduceToAssign : public IndexNotationRewriter {
+    using IndexNotationRewriter::visit;
+
+    const std::set<TensorVar>* const candidates;
+    std::map<TensorVar,std::set<IndexVar>> availableVars;
+
+    ReduceToAssign(const std::set<TensorVar>* const candidates) :
+        candidates(candidates) {}
+
+    IndexStmt rewrite(IndexStmt stmt) {
+      for (const auto& result : getResults(stmt)) {
+        availableVars[result] = {};
+      }
+      return IndexNotationRewriter::rewrite(stmt);
+    }
+
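+    // While inside a forall, its index variable is available to every result
+    // currently in scope.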
+    void visit(const ForallNode* op) {
+      for (auto& it : availableVars) {
+        it.second.insert(op->indexVar);
+      }
+      IndexNotationRewriter::visit(op);
+      for (auto& it : availableVars) {
+        it.second.erase(op->indexVar);
+      }
+    }
+
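+    // A workspace exists only inside its where statement, and loops enclosing
+    // the where do not index it, so it starts from an empty scope and is
+    // dropped on exit.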
+    void visit(const WhereNode* op) {
+      const auto workspaces = getResults(op->producer);
+      for (const auto& workspace : workspaces) {
+        availableVars[workspace] = {};
+      }
+      IndexNotationRewriter::visit(op);
+      for (const auto& workspace : workspaces) {
+        availableVars.erase(workspace);
+      }
+    }
+
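+    // A reduction is redundant when the left-hand side is indexed by every
+    // available loop variable, since each element is then written exactly once.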
+    void visit(const AssignmentNode* op) {
+      const auto result = op->lhs.getTensorVar();
+      if (op->op.defined() &&
+          util::toSet(op->lhs.getIndexVars()) == availableVars[result] &&
+          (!candidates || util::contains(*candidates, result))) {
+        stmt = Assignment(op->lhs, op->rhs);
+        return;
+      }
+      stmt = op;
+    }
+  };
+  return ReduceToAssign(candidates).rewrite(stmt);
+}
+
 IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
   INIT_REASON(reason);
@@ -229,30 +292,68 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
 
     Precompute precompute;
 
-    void visit(const ForallNode* node) {
-      Forall foralli(node);
+    void visit(const ForallNode* op) {
+      Forall foralli(op);
       IndexVar i = precompute.geti();
+      IndexVar j = foralli.getIndexVar();
 
-      if (foralli.getIndexVar() == i) {
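+      // Only transform this loop when it actually contains the expression
+      // being precomputed.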
+      Assignment assign = getAssignmentContainingExpr(foralli,
+                                                      precompute.getExpr());
+      if (j == i && assign.defined()) {
         IndexStmt s = foralli.getStmt();
         TensorVar ws = precompute.getWorkspace();
         IndexExpr e = precompute.getExpr();
         IndexVar iw = precompute.getiw();
 
         IndexStmt consumer = forall(i, replace(s, {{e, ws(i)}}));
-        IndexStmt producer = forall(iw, ws(iw) = replace(e, {{i,iw}}));
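+        // Keep the reduction operator from the original assignment so the
+        // workspace accumulates instead of overwriting.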
+        IndexStmt producer = forall(iw, Assignment(ws(iw), replace(e, {{i,iw}}),
+                                                   assign.getOperator()));
         Where where(consumer, producer);
 
         stmt = where;
         return;
       }
-      IndexNotationRewriter::visit(node);
-    }
 
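+      // This is not the precompute loop: rewrite the body, and if it became a
+      // where whose producer or consumer does not use this loop variable,
+      // hoist that side out of the loop.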
+      IndexStmt s = rewrite(op->stmt);
+      if (s == op->stmt) {
+        stmt = op;
+        return;
+      } else if (isa<Where>(s)) {
+        Where body = to<Where>(s);
+        const auto consumerHasJ =
+            util::contains(body.getConsumer().getIndexVars(), j);
+        const auto producerHasJ =
+            util::contains(body.getProducer().getIndexVars(), j);
+        if (consumerHasJ && !producerHasJ) {
+          const auto producer = body.getProducer();
+          const auto consumer = Forall(op->indexVar, body.getConsumer(),
+                                       op->parallel_unit,
+                                       op->output_race_strategy,
+                                       op->unrollFactor);
+          stmt = Where(consumer, producer);
+          return;
+        } else if (producerHasJ && !consumerHasJ) {
+          const auto producer = Forall(op->indexVar, body.getProducer(),
+                                       op->parallel_unit,
+                                       op->output_race_strategy,
+                                       op->unrollFactor);
+          const auto consumer = body.getConsumer();
+          stmt = Where(consumer, producer);
+          return;
+        }
+      }
+      stmt = Forall(op->indexVar, s, op->parallel_unit,
+                    op->output_race_strategy, op->unrollFactor);
+    }
   };
   PrecomputeRewriter rewriter;
   rewriter.precompute = *this;
-  return rewriter.rewrite(stmt);
+  stmt = rewriter.rewrite(stmt);
+
+  // Convert redundant reductions to assignments
+  stmt = eliminateRedundantReductions(stmt);
+
+  return stmt;
 }
 
 void Precompute::print(std::ostream& os) const {
@@ -506,23 +607,24 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
       Iterators iterators(foralli, tensorVars);
       definedIndexVars.insert(foralli.getIndexVar());
       MergeLattice lattice = MergeLattice::make(foralli, iterators, provGraph, definedIndexVars);
-      // Precondition 3: No parallelization of variables under a reduction
+      // Precondition 1: No parallelization of variables under a reduction
       // variable (ie MergePoint has at least 1 result iterators)
-      if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces && lattice.results().empty()
-          && lattice != MergeLattice({MergePoint({iterators.modeIterator(foralli.getIndexVar())}, {}, {})})) {
+      if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces &&
+          (lattice.results().empty() || lattice.results()[0].getIndexVar() != foralli.getIndexVar()) &&
+          lattice != MergeLattice({MergePoint({iterators.modeIterator(foralli.getIndexVar())}, {}, {})})) {
         reason = "Precondition failed: Free variables cannot be dominated by reduction variables in the iteration graph, "
                  "as this causes scatter behavior and we do not yet emit parallel synchronization constructs";
         return;
       }
 
       if (foralli.getIndexVar() == i) {
-        // Precondition 1: No coiteration of node (ie Merge Lattice has only 1 iterator)
+        // Precondition 2: No coiteration of mode (ie Merge Lattice has only 1 iterator)
         if (lattice.iterators().size() != 1) {
           reason = "Precondition failed: The loop must not merge tensor dimensions, that is, it must be a for loop;";
           return;
         }
 
-        // Precondition 2: Every result iterator must have insert capability
+        // Precondition 3: Every result iterator must have insert capability
         for (Iterator iterator : lattice.results()) {
           if (util::contains(assembledByUngroupedInsert, iterator.getTensor())) {
             for (Iterator it = iterator; !it.isRoot(); it = it.getParent()) {
@@ -923,37 +1025,8 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
   }
 
   // Convert redundant reductions to assignments
-  struct ReduceToAssign : public IndexNotationRewriter {
-    using IndexNotationRewriter::visit;
-
-    const std::set<TensorVar>& insertedResults;
-    std::set<IndexVar> availableVars;
-
-    ReduceToAssign(const std::set<TensorVar>& insertedResults) :
-        insertedResults(insertedResults) {}
-
-    void visit(const ForallNode* op) {
-      availableVars.insert(op->indexVar);
-      IndexNotationRewriter::visit(op);
-      availableVars.erase(op->indexVar);
-    }
-
-    void visit(const AssignmentNode* op) {
-      std::set<IndexVar> accessVars;
-      for (const auto& index : op->lhs.getIndexVars()) {
-        accessVars.insert(index);
-      }
-
-      if (op->op.defined() && accessVars == availableVars &&
-          util::contains(insertedResults, op->lhs.getTensorVar())) {
-        stmt = new AssignmentNode(op->lhs, op->rhs, IndexExpr());
-        return;
-      }
-
-      stmt = op;
-    }
-  };
-  loweredQueries = ReduceToAssign(insertedResults).rewrite(loweredQueries);
+  loweredQueries = eliminateRedundantReductions(loweredQueries,
+                                                &insertedResults);
 
   // Inline definitions of temporaries into their corresponding uses, as long
   // as the temporaries are not the results of reductions