Skip to content

Commit f964d00

Browse files
committed
Various fixes including better temporary inserts
1 parent f4408d9 commit f964d00

File tree

5 files changed, with 39 additions and 24 deletions.

src/index_notation/kernel.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,10 @@ void unpackResults(size_t numResults, const vector<void*> arguments,
6565
num *= ((int*)tensorData->indices[i][0])[0];
6666
} else if (modeType.getName() == Sparse.getName()) {
6767
auto size = ((int*)tensorData->indices[i][0])[num];
68-
Array pos = Array(type<int>(), tensorData->indices[i][0], num+1, Array::UserOwns);
69-
Array idx = Array(type<int>(), tensorData->indices[i][1], size, Array::UserOwns);
68+
Array pos = Array(type<int>(), tensorData->indices[i][0],
69+
num+1, Array::UserOwns);
70+
Array idx = Array(type<int>(), tensorData->indices[i][1],
71+
size, Array::UserOwns);
7072
modeIndices.push_back(ModeIndex({pos, idx}));
7173
num = size;
7274
} else {

src/index_notation/transformations.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -644,19 +644,33 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) {
644644
}
645645

646646
TensorVar A = Aaccess.getTensorVar();
647+
if (A.getFormat().getModeFormats()[0].getName() != "dense" ||
648+
A.getFormat().getModeFormats()[1].getName() != "compressed" ||
649+
A.getFormat().getModeOrdering()[0] != 0 ||
650+
A.getFormat().getModeOrdering()[1] != 1) {
651+
return stmt;
652+
}
653+
647654
TensorVar B = Baccess.getTensorVar();
648-
TensorVar C = Caccess.getTensorVar();
655+
if (B.getFormat().getModeFormats()[0].getName() != "dense" ||
656+
B.getFormat().getModeFormats()[1].getName() != "compressed" ||
657+
B.getFormat().getModeOrdering()[0] != 0 ||
658+
B.getFormat().getModeOrdering()[1] != 1) {
659+
return stmt;
660+
}
649661

650-
if (A.getFormat() != CSR ||
651-
B.getFormat() != CSR ||
652-
C.getFormat() != CSR) {
662+
TensorVar C = Caccess.getTensorVar();
663+
if (C.getFormat().getModeFormats()[0].getName() != "dense" ||
664+
C.getFormat().getModeFormats()[1].getName() != "compressed" ||
665+
C.getFormat().getModeOrdering()[0] != 0 ||
666+
C.getFormat().getModeOrdering()[1] != 1) {
653667
return stmt;
654668
}
655669

656670
// It's an SpMM statement so return an optimized SpMM statement
657671
TensorVar w("w",
658672
Type(Float64, {A.getType().getShape().getDimension(1)}),
659-
dense);
673+
taco::dense);
660674
return forall(i,
661675
where(forall(j,
662676
A(i,j) = w(j)),

src/lower/lowerer_impl.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ static bool hasStores(Stmt stmt) {
102102
return stmt.defined() && FindStores().hasStores(stmt);
103103
}
104104

105-
Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
106-
bool compute) {
105+
Stmt
106+
LowererImpl::lower(IndexStmt stmt, string name, bool assemble, bool compute)
107+
{
107108
this->assemble = assemble;
108109
this->compute = compute;
109110

@@ -125,7 +126,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
125126
// Create iterators
126127
iterators = Iterators::make(stmt, tensorVars, &indexVars);
127128

128-
vector<Access> inputAccesses, resultAccesses;
129+
vector<Access> inputAccesses, resultAccesses;
129130
set<Access> reducedAccesses;
130131
inputAccesses = getArgumentAccesses(stmt);
131132
std::tie(resultAccesses, reducedAccesses) = getResultAccesses(stmt);

test/tests-lower.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ using taco::error::expr_transposition;
4545
#include "taco/lower/mode_format_dense.h"
4646
taco::ModeFormat dense(std::make_shared<taco::DenseModeFormat>());
4747

48-
static const Dimension n, m, o;
48+
static const Dimension n;
4949
static const Type vectype(Float64, {n});
50-
static const Type mattype(Float64, {n,m});
51-
static const Type tentype(Float64, {n,m,o});
50+
static const Type mattype(Float64, {n,n});
51+
static const Type tentype(Float64, {n,n,n});
5252

5353
static TensorVar alpha("alpha", Float64);
5454
static TensorVar beta("beta", Float64);
@@ -282,17 +282,14 @@ TEST_P(lower, compile) {
282282
}
283283

284284
{
285-
SCOPED_TRACE("Separate Assembly and Compute\n" +
286-
toString(taco::lower(stmt,"assemble",true,false)) + "\n" +
287-
toString(taco::lower(stmt,"compute",false,true)));
285+
SCOPED_TRACE("Separate Assembly and Compute\n");
288286
ASSERT_TRUE(kernel.assemble(arguments));
289287
ASSERT_TRUE(kernel.compute(arguments));
290288
verifyResults(results, arguments, varsFormatted, expected);
291289
}
292290

293291
{
294-
SCOPED_TRACE("Fused Assembly and Compute\n" +
295-
toString(taco::lower(stmt,"evaluate",true,true)));
292+
SCOPED_TRACE("Fused Assembly and Compute\n");
296293
ASSERT_TRUE(kernel(arguments));
297294
verifyResults(results, arguments, varsFormatted, expected);
298295
}
@@ -734,8 +731,8 @@ TEST_STMT(DISABLED_where_spmm,
734731
forall(j,
735732
w(j) += B(i,k) * C(k,j))))),
736733
Values(
737-
Formats({{A,Format({dense,dense})},
738-
{B,Format({dense,dense})}, {C,Format({dense,dense})}}),
734+
// Formats({{A,Format({dense,dense})},
735+
// {B,Format({dense,dense})}, {C,Format({dense,dense})}}),
739736
Formats({{A,Format({dense,sparse})},
740737
{B,Format({dense,sparse})}, {C,Format({dense,sparse})}})
741738
),
@@ -1584,4 +1581,3 @@ TEST_STMT(vector_not,
15841581
{{a, {{{1}, 1.0}, {{2}, 1.0}, {{3}, 1.0}}}})
15851582
}
15861583
)
1587-

tools/taco.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -811,15 +811,17 @@ int main(int argc, char* argv[]) {
811811
else {
812812
if (newLower) {
813813
IndexStmt stmt = makeConcrete(tensor.getAssignment());
814-
if (printConcrete) {
815-
cout << stmt << endl;
816-
}
817814

818815
string reason;
819816
stmt = reorderLoopsTopologically(stmt);
820817
stmt = insertTemporaries(stmt);
821818
taco_uassert(stmt != IndexStmt()) << reason;
822819
stmt = parallelizeOuterLoop(stmt);
820+
821+
if (printConcrete) {
822+
cout << stmt << endl;
823+
}
824+
823825
compute = lower(stmt, "compute", false, true);
824826
assemble = lower(stmt, "assemble", true, false);
825827
evaluate = lower(stmt, "evaluate", true, true);

0 commit comments

Comments
 (0)