Commit 27e898c

Fixed precompute transformation and attempt at fixing tensor-compiler#355. Also generate more optimized attribute query code for parallel sparse tensor addition
1 parent 45ca20e commit 27e898c

6 files changed: +585 -110 lines


include/taco/index_notation/index_notation.h

Lines changed: 44 additions & 26 deletions
@@ -34,6 +34,7 @@ class WindowedIndexVar;
 class IndexSetVar;
 class TensorVar;
 
+class IndexStmt;
 class IndexExpr;
 class Assignment;
 class Access;
@@ -63,6 +64,14 @@ struct SuchThatNode;
 class IndexExprVisitorStrict;
 class IndexStmtVisitorStrict;
 
+/// Return true if the index expression is of the given subtype. The subtypes
+/// are Assignment, Forall, Where, Sequence, and Multi.
+template <typename SubType> bool isa(IndexExpr);
+
+/// Casts the index expression to the given subtype. Assumes SubType is a
+/// subtype; the subtypes are Assignment, Forall, Where, Sequence, and Multi.
+template <typename SubType> SubType to(IndexExpr);
+
 /// A tensor index expression describes a tensor computation as a scalar
 /// expression where tensors are indexed by index variables (`IndexVar`). The
 /// index variables range over the tensor dimensions they index, and the scalar
@@ -161,6 +170,12 @@ class IndexExpr : public util::IntrusivePtr<const IndexExprNode> {
   /// Returns the schedule of the index expression.
   const Schedule& getSchedule() const;
 
+  /// Casts index expression to specified subtype.
+  template <typename SubType>
+  SubType as() {
+    return to<SubType>(*this);
+  }
+
   /// Visit the index expression's sub-expressions.
   void accept(IndexExprVisitorStrict *) const;
 
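For orientation, a minimal usage sketch of the casting API this hunk touches. It assumes an IndexExpr obtained elsewhere and that Access is one of the instantiated subtypes; everything apart from isa, to, and the new as<>() shorthand is hypothetical:

#include "taco/index_notation/index_notation.h"
using namespace taco;

void inspect(IndexExpr expr) {
  // Guard with isa<> first; to<> and the new as<>() shorthand assume the
  // cast is valid.
  if (isa<Access>(expr)) {
    Access viaFree   = to<Access>(expr);   // existing free-function cast
    Access viaMember = expr.as<Access>();  // new member shorthand
    (void)viaFree;
    (void)viaMember;
  }
}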
@@ -204,14 +219,6 @@ IndexExpr operator*(const IndexExpr&, const IndexExpr&);
 /// ```
 IndexExpr operator/(const IndexExpr&, const IndexExpr&);
 
-/// Return true if the index statement is of the given subtype. The subtypes
-/// are Assignment, Forall, Where, Sequence, and Multi.
-template <typename SubType> bool isa(IndexExpr);
-
-/// Casts the index statement to the given subtype. Assumes S is a subtype and
-/// the subtypes are Assignment, Forall, Where, Sequence, and Multi.
-template <typename SubType> SubType to(IndexExpr);
-
 
 /// An index expression that represents a tensor access, such as `A(i,j))`.
 /// Access expressions are returned when calling the overloaded operator() on
@@ -514,6 +521,14 @@ class Reduction : public IndexExpr {
 /// Create a summation index expression.
 Reduction sum(IndexVar i, IndexExpr expr);
 
+/// Return true if the index statement is of the given subtype. The subtypes
+/// are Assignment, Forall, Where, Multi, and Sequence.
+template <typename SubType> bool isa(IndexStmt);
+
+/// Casts the index statement to the given subtype. Assumes SubType is a
+/// subtype; the subtypes are Assignment, Forall, Where, Multi, and Sequence.
+template <typename SubType> SubType to(IndexStmt);
+
 /// An index statement computes a tensor. The index statements are
 /// assignment, forall, where, multi, and sequence.
 class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
@@ -633,9 +648,9 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
   ///
   /// Preconditions:
   /// The index variable supplied to the coord transformation must be in
-  /// position space. The index variable supplied to the pos transformation
-  /// must be in coordinate space. The pos transformation also takes an
-  /// input to indicate which position space to use. This input must appear in the computation
+  /// position space. The index variable supplied to the pos transformation must
+  /// be in coordinate space. The pos transformation also takes an input to
+  /// indicate which position space to use. This input must appear in the computation
   /// expression and also be indexed by this index variable. In the case that this
   /// index variable is derived from multiple index variables, these variables must appear
   /// directly nested in the mode ordering of this datastructure. This allows for
@@ -661,28 +676,38 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
   /// to the pos transformation.
   IndexStmt fuse(IndexVar i, IndexVar j, IndexVar f) const;
 
-  /// The precompute transformation is described in kjolstad2019
-  /// allows us to leverage scratchpad memories and
-  /// reorder computations to increase locality
+  /// The precompute transformation is described in kjolstad2019
+  /// allows us to leverage scratchpad memories and
+  /// reorder computations to increase locality
   IndexStmt precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorVar workspace) const;
 
   /// bound specifies a compile-time constraint on an index variable's
   /// iteration space that allows knowledge of the
   /// size or structured sparsity pattern of the inputs to be
-  /// incorporated during bounds propagatio
+  /// incorporated during bounds propagation
   ///
   /// Preconditions:
-  /// The precondition for bound is that the computation bounds supplied are correct
-  /// given the inputs that this code will be run on.
+  /// The precondition for bound is that the computation bounds supplied are
+  /// correct given the inputs that this code will be run on.
   IndexStmt bound(IndexVar i, IndexVar i1, size_t bound, BoundType bound_type) const;
 
-  /// The unroll
-  /// primitive unrolls the corresponding loop by a statically-known
+  /// The unroll primitive unrolls the corresponding loop by a statically-known
   /// integer number of iterations
   /// Preconditions: unrollFactor is a positive nonzero integer
   IndexStmt unroll(IndexVar i, size_t unrollFactor) const;
 
+  /// The assemble primitive specifies whether a result tensor should be
+  /// assembled by appending or inserting nonzeros into the result tensor.
+  /// In the latter case, the transformation inserts additional loops to
+  /// precompute statistics about the result tensor that are required for
+  /// preallocating memory and coordinating insertions of nonzeros.
   IndexStmt assemble(TensorVar result, AssembleStrategy strategy) const;
+
+  /// Casts index statement to specified subtype.
+  template <typename SubType>
+  SubType as() {
+    return to<SubType>(*this);
+  }
 };
 
 /// Check if two index statements are isomorphic.
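The primitives documented in this hunk chain through the usual scheduling workflow. A minimal sketch, assuming stmt, expr, the index variables, and the workspace ws already exist, and that BoundType::MaxExact is taco's exact upper-bound variant; only the signatures come from the declarations above:

// Stage expr into workspace ws (loop i renamed iw in the producer),
// assert a compile-time bound on i, and unroll the workspace loop 4x.
IndexStmt scheduled = stmt.precompute(expr, i, iw, ws)
                          .bound(i, ibound, 1024, BoundType::MaxExact)
                          .unroll(iw, 4);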
@@ -694,13 +719,6 @@ bool equals(IndexStmt, IndexStmt);
 /// Print the index statement.
 std::ostream& operator<<(std::ostream&, const IndexStmt&);
 
-/// Return true if the index statement is of the given subtype. The subtypes
-/// are Assignment, Forall, Where, Multi, and Sequence.
-template <typename SubType> bool isa(IndexStmt);
-
-/// Casts the index statement to the given subtype. Assumes S is a subtype and
-/// the subtypes are Assignment, Forall, Where, Multi, and Sequence.
-template <typename SubType> SubType to(IndexStmt);
 
 /// An assignment statement assigns an index expression to the locations in a
 /// tensor given by an lhs access expression.

src/index_notation/transformations.cpp

Lines changed: 116 additions & 43 deletions
@@ -13,6 +13,9 @@
 #include <iostream>
 #include <algorithm>
 #include <limits>
+#include <set>
+#include <map>
+#include <vector>
 
 using namespace std;
 
@@ -171,6 +174,12 @@ static bool containsExpr(Assignment assignment, IndexExpr expr) {
     IndexExpr expr;
     bool contains = false;
 
+    void visit(const AccessNode* node) {
+      if (equals(IndexExpr(node), expr)) {
+        contains = true;
+      }
+    }
+
     void visit(const UnaryExprNode* node) {
       if (equals(IndexExpr(node), expr)) {
         contains = true;
@@ -213,6 +222,60 @@ static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
   return assignment;
 }
 
+static IndexStmt eliminateRedundantReductions(IndexStmt stmt,
+    const std::set<TensorVar>* const candidates = nullptr) {
+
+  struct ReduceToAssign : public IndexNotationRewriter {
+    using IndexNotationRewriter::visit;
+
+    const std::set<TensorVar>* const candidates;
+    std::map<TensorVar,std::set<IndexVar>> availableVars;
+
+    ReduceToAssign(const std::set<TensorVar>* const candidates) :
+        candidates(candidates) {}
+
+    IndexStmt rewrite(IndexStmt stmt) {
+      for (const auto& result : getResults(stmt)) {
+        availableVars[result] = {};
+      }
+      return IndexNotationRewriter::rewrite(stmt);
+    }
+
+    void visit(const ForallNode* op) {
+      for (auto& it : availableVars) {
+        it.second.insert(op->indexVar);
+      }
+      IndexNotationRewriter::visit(op);
+      for (auto& it : availableVars) {
+        it.second.erase(op->indexVar);
+      }
+    }
+
+    void visit(const WhereNode* op) {
+      const auto workspaces = getResults(op->producer);
+      for (const auto& workspace : workspaces) {
+        availableVars[workspace] = {};
+      }
+      IndexNotationRewriter::visit(op);
+      for (const auto& workspace : workspaces) {
+        availableVars.erase(workspace);
+      }
+    }
+
+    void visit(const AssignmentNode* op) {
+      const auto result = op->lhs.getTensorVar();
+      if (op->op.defined() &&
+          util::toSet(op->lhs.getIndexVars()) == availableVars[result] &&
+          (!candidates || util::contains(*candidates, result))) {
+        stmt = Assignment(op->lhs, op->rhs);
+        return;
+      }
+      stmt = op;
+    }
+  };
+  return ReduceToAssign(candidates).rewrite(stmt);
+}
+
 IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
   INIT_REASON(reason);
 
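The helper's effect is easiest to see on a small example: when an assignment's left-hand side is indexed by every loop variable currently in scope, each output location is written exactly once and the compound operator is redundant. A sketch with hypothetical tensors a, b, B and variables i, j:

// forall(i, a(i) += b(i))               -->  forall(i, a(i) = b(i))
// forall(i, forall(j, a(i) += B(i,j)))  stays unchanged: the lhs vars {i}
//                                       differ from the available vars
//                                       {i, j}, so the += accumulates over j.
IndexStmt simplifiedAll  = eliminateRedundantReductions(stmt);
IndexStmt simplifiedSome = eliminateRedundantReductions(stmt, &candidates);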
@@ -229,30 +292,68 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
 
     Precompute precompute;
 
-    void visit(const ForallNode* node) {
-      Forall foralli(node);
+    void visit(const ForallNode* op) {
+      Forall foralli(op);
       IndexVar i = precompute.geti();
+      IndexVar j = foralli.getIndexVar();
 
-      if (foralli.getIndexVar() == i) {
+      Assignment assign = getAssignmentContainingExpr(foralli,
+                                                      precompute.getExpr());
+      if (j == i && assign.defined()) {
         IndexStmt s = foralli.getStmt();
         TensorVar ws = precompute.getWorkspace();
         IndexExpr e = precompute.getExpr();
         IndexVar iw = precompute.getiw();
 
         IndexStmt consumer = forall(i, replace(s, {{e, ws(i)}}));
-        IndexStmt producer = forall(iw, ws(iw) = replace(e, {{i,iw}}));
+        IndexStmt producer = forall(iw, Assignment(ws(iw), replace(e, {{i,iw}}),
+                                                   assign.getOperator()));
         Where where(consumer, producer);
 
         stmt = where;
         return;
       }
-      IndexNotationRewriter::visit(node);
-    }
 
+      IndexStmt s = rewrite(op->stmt);
+      if (s == op->stmt) {
+        stmt = op;
+        return;
+      } else if (isa<Where>(s)) {
+        Where body = to<Where>(s);
+        const auto consumerHasJ =
+            util::contains(body.getConsumer().getIndexVars(), j);
+        const auto producerHasJ =
+            util::contains(body.getProducer().getIndexVars(), j);
+        if (consumerHasJ && !producerHasJ) {
+          const auto producer = body.getProducer();
+          const auto consumer = Forall(op->indexVar, body.getConsumer(),
+                                       op->parallel_unit,
+                                       op->output_race_strategy,
+                                       op->unrollFactor);
+          stmt = Where(consumer, producer);
+          return;
+        } else if (producerHasJ && !consumerHasJ) {
+          const auto producer = Forall(op->indexVar, body.getProducer(),
+                                       op->parallel_unit,
+                                       op->output_race_strategy,
+                                       op->unrollFactor);
+          const auto consumer = body.getConsumer();
+          stmt = Where(consumer, producer);
+          return;
+        }
+      }
+      stmt = Forall(op->indexVar, s, op->parallel_unit,
+                    op->output_race_strategy, op->unrollFactor);
+    }
   };
   PrecomputeRewriter rewriter;
   rewriter.precompute = *this;
-  return rewriter.rewrite(stmt);
+  stmt = rewriter.rewrite(stmt);
+
+  // Convert redundant reductions to assignments
+  stmt = eliminateRedundantReductions(stmt);
+
+  return stmt;
 }
 
 void Precompute::print(std::ostream& os) const {
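Two fixes are visible in this hunk. First, the producer now inherits the reduction operator of the assignment that contains the precomputed expression, so a += consumer gets an accumulating workspace rather than a plain overwrite. Second, when the rewrite fires deeper in a loop nest, the resulting where is hoisted past any enclosing forall whose index variable occurs on only one side. A rough sketch of the hoisting, with hypothetical statements:

// Before: the producer is re-evaluated for every i although it never uses i.
//   forall(i, where(consumer(i), producer))
// After: the producer runs once, outside the i loop.
//   where(forall(i, consumer(i)), producer)
// Symmetrically, when only the producer uses i, the forall moves onto the
// producer and the consumer is hoisted instead.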
@@ -506,23 +607,24 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
       Iterators iterators(foralli, tensorVars);
       definedIndexVars.insert(foralli.getIndexVar());
       MergeLattice lattice = MergeLattice::make(foralli, iterators, provGraph, definedIndexVars);
-      // Precondition 3: No parallelization of variables under a reduction
+      // Precondition 1: No parallelization of variables under a reduction
       // variable (ie MergePoint has at least 1 result iterators)
-      if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces && lattice.results().empty()
-          && lattice != MergeLattice({MergePoint({iterators.modeIterator(foralli.getIndexVar())}, {}, {})})) {
+      if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces &&
+          (lattice.results().empty() || lattice.results()[0].getIndexVar() != foralli.getIndexVar()) &&
+          lattice != MergeLattice({MergePoint({iterators.modeIterator(foralli.getIndexVar())}, {}, {})})) {
         reason = "Precondition failed: Free variables cannot be dominated by reduction variables in the iteration graph, "
                  "as this causes scatter behavior and we do not yet emit parallel synchronization constructs";
         return;
       }
 
       if (foralli.getIndexVar() == i) {
-        // Precondition 1: No coiteration of node (ie Merge Lattice has only 1 iterator)
+        // Precondition 2: No coiteration of mode (ie Merge Lattice has only 1 iterator)
         if (lattice.iterators().size() != 1) {
          reason = "Precondition failed: The loop must not merge tensor dimensions, that is, it must be a for loop;";
           return;
         }
 
-        // Precondition 2: Every result iterator must have insert capability
+        // Precondition 3: Every result iterator must have insert capability
         for (Iterator iterator : lattice.results()) {
           if (util::contains(assembledByUngroupedInsert, iterator.getTensor())) {
             for (Iterator it = iterator; !it.isRoot(); it = it.getParent()) {
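The reworked guard no longer treats every non-empty result set as safe: under OutputRaceStrategy::NoRaces, the loop is now also rejected when the first result iterator tracks a different index variable than the one being parallelized, since distinct iterations would then scatter into shared output coordinates. A condensed restatement of the predicate above (no new API, just the same expression factored out):

// Parallelization is racy unless the loop's own variable drives the result:
bool racy = lattice.results().empty() ||
            lattice.results()[0].getIndexVar() != foralli.getIndexVar();
// Such loops scatter into the output and would need synchronization
// constructs that taco does not yet emit under NoRaces.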
@@ -923,37 +1025,8 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
   }
 
   // Convert redundant reductions to assignments
-  struct ReduceToAssign : public IndexNotationRewriter {
-    using IndexNotationRewriter::visit;
-
-    const std::set<TensorVar>& insertedResults;
-    std::set<IndexVar> availableVars;
-
-    ReduceToAssign(const std::set<TensorVar>& insertedResults) :
-        insertedResults(insertedResults) {}
-
-    void visit(const ForallNode* op) {
-      availableVars.insert(op->indexVar);
-      IndexNotationRewriter::visit(op);
-      availableVars.erase(op->indexVar);
-    }
-
-    void visit(const AssignmentNode* op) {
-      std::set<IndexVar> accessVars;
-      for (const auto& index : op->lhs.getIndexVars()) {
-        accessVars.insert(index);
-      }
-
-      if (op->op.defined() && accessVars == availableVars &&
-          util::contains(insertedResults, op->lhs.getTensorVar())) {
-        stmt = new AssignmentNode(op->lhs, op->rhs, IndexExpr());
-        return;
-      }
-
-      stmt = op;
-    }
-  };
-  loweredQueries = ReduceToAssign(insertedResults).rewrite(loweredQueries);
+  loweredQueries = eliminateRedundantReductions(loweredQueries,
+                                                &insertedResults);
 
   // Inline definitions of temporaries into their corresponding uses, as long
   // as the temporaries are not the results of reductions
