Skip to content

Commit f4408d9

Browse files
committed
Adds call to use inserTemporaries
1 parent 4f4ac27 commit f4408d9

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

src/index_notation/transformations.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,9 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) {
654654
}
655655

656656
// It's an SpMM statement so return an optimized SpMM statement
657-
TensorVar w("w", Type(Float64, {Dimension()}), dense);
657+
TensorVar w("w",
658+
Type(Float64, {A.getType().getShape().getDimension(1)}),
659+
dense);
658660
return forall(i,
659661
where(forall(j,
660662
A(i,j) = w(j)),

src/tensor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ void TensorBase::compile(bool assembleWhileCompute) {
461461
IndexStmt stmt = makeConcrete(assignment);
462462
string reason;
463463
stmt = reorderLoopsTopologically(stmt);
464+
stmt = insertTemporaries(stmt);
464465
taco_uassert(stmt != IndexStmt()) << reason;
465466
stmt = parallelizeOuterLoop(stmt);
466467
content->assembleFunc = lower(stmt, "assemble", true, false);

tools/taco.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,7 @@ int main(int argc, char* argv[]) {
817817

818818
string reason;
819819
stmt = reorderLoopsTopologically(stmt);
820+
stmt = insertTemporaries(stmt);
820821
taco_uassert(stmt != IndexStmt()) << reason;
821822
stmt = parallelizeOuterLoop(stmt);
822823
compute = lower(stmt, "compute", false, true);

0 commit comments

Comments
 (0)