Skip to content

Commit 2ebfbc7

Browse files
committed
fix producer consumer interchange
when the producer is at the end of the assignment, the argument packing needs to change to match the transformed index statement
1 parent 270daf8 commit 2ebfbc7

File tree

3 files changed

+143
-1
lines changed

3 files changed

+143
-1
lines changed

include/taco/tensor.h

+1
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ class TensorBase {
429429

430430
/// Compute the given expression and put the values in the tensor storage.
431431
void compute();
432+
void compute(IndexStmt stmt);
432433

433434
/// Compile, assemble and compute as needed.
434435
void evaluate();

src/tensor.cpp

+87
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,41 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
775775
return getOperands.arguments;
776776
}
777777

778+
// Collect the tensors accessed by `stmt`, recording each distinct operand
// exactly once. `operands` receives the TensorVars in first-visit order —
// the order in which the generated kernel expects its arguments packed —
// and the returned map associates each TensorVar with its backing
// TensorBase. Tensors backing index sets of accesses are included as well.
static inline map<TensorVar, TensorBase> getTensors(const IndexStmt& stmt, vector<TensorVar>& operands) {
  struct GetOperands : public IndexNotationVisitor {
    using IndexNotationVisitor::visit;
    vector<TensorVar>& operands;          // distinct operands, in visit order
    map<TensorVar, TensorBase> arguments; // operand -> backing tensor

    explicit GetOperands(vector<TensorVar>& operands) : operands(operands) {}

    void visit(const AccessNode* node) {
      if (!isa<AccessTensorNode>(node)) {
        return; // temporary ignore
      }
      // NOTE(review): the original also built an unused `Access ac` and
      // asserted isa<AccessTensorNode>(node), which the guard above already
      // guarantees; both removed as dead code.

      if (!util::contains(arguments, node->tensorVar)) {
        arguments.insert({node->tensorVar, to<AccessTensorNode>(node)->tensor});
        operands.push_back(node->tensorVar);
      }

      // Also add any tensors backing index sets of tensor accesses.
      for (auto& p : node->indexSetModes) {
        auto tv = p.second.tensor.getTensorVar();
        if (!util::contains(arguments, tv)) {
          arguments.insert({tv, p.second.tensor});
          operands.push_back(tv);
        }
      }
    }
  };
  GetOperands getOperands(operands);
  stmt.accept(&getOperands);
  return getOperands.arguments;
}
812+
778813
static inline
779814
vector<void*> packArguments(const TensorBase& tensor) {
780815
vector<void*> arguments;
@@ -805,6 +840,35 @@ vector<void*> packArguments(const TensorBase& tensor) {
805840
return arguments;
806841
}
807842

843+
static inline
844+
vector<void*> packArguments(const TensorBase& tensor, const IndexStmt stmt) {
845+
vector<void*> arguments;
846+
847+
// Pack the result tensor
848+
arguments.push_back(tensor.getStorage());
849+
850+
// Pack any index sets on the result tensor at the front of the arguments list.
851+
auto lhs = getNode(tensor.getAssignment().getLhs());
852+
// We check isa<AccessNode> rather than isa<AccessTensorNode> to catch cases
853+
// where the underlying access is represented with the base AccessNode class.
854+
if (isa<AccessNode>(lhs)) {
855+
auto indexSetModes = to<AccessNode>(lhs)->indexSetModes;
856+
for (auto& it : indexSetModes) {
857+
arguments.push_back(it.second.tensor.getStorage());
858+
}
859+
}
860+
861+
// Pack operand tensors
862+
std::vector<TensorVar> operands;
863+
auto tensors = getTensors(stmt, operands);
864+
for (auto& operand : operands) {
865+
taco_iassert(util::contains(tensors, operand));
866+
arguments.push_back(tensors.at(operand).getStorage());
867+
}
868+
869+
return arguments;
870+
}
871+
808872
void TensorBase::assemble() {
809873
taco_uassert(!needsCompile()) << error::assemble_without_compile;
810874
if (!needsAssemble()) {
@@ -849,6 +913,29 @@ void TensorBase::compute() {
849913
}
850914
}
851915

916+
// Compute this tensor's values using the kernel compiled from `stmt`.
// Unlike compute(), the kernel arguments are packed in the operand order
// of the (possibly transformed) index statement rather than the order of
// the assignment's right-hand side.
void TensorBase::compute(IndexStmt stmt) {
  taco_uassert(!needsCompile()) << error::compute_without_compile;
  if (!needsCompute()) {
    return;
  }
  setNeedsCompute(false);

  // Bring every RHS operand's values up to date before running the kernel,
  // and drop this tensor from their dependent lists.
  auto rhsOperands = getTensors(getAssignment().getRhs());
  for (auto& entry : rhsOperands) {
    entry.second.syncValues();
    entry.second.removeDependentTensor(*this);
  }

  auto args = packArguments(*this, stmt);
  this->content->module->callFuncPacked("compute", args.data());

  if (content->assembleWhileCompute) {
    // The kernel also assembled the output; record its sizes now.
    setNeedsAssemble(false);
    taco_tensor_t* tensorData = static_cast<taco_tensor_t*>(args[0]);
    content->valuesSize = unpackTensorData(*tensorData, *this);
  }
}
938+
852939
void TensorBase::evaluate() {
853940
this->compile();
854941
if (!getAssignment().getOperator().defined()) {

test/tests-workspaces.cpp

+55-1
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ TEST(workspaces, tile_dotProduct_3) {
652652

653653
TEST(workspaces, loopfuse) {
654654
int N = 16;
655+
float SPARSITY = 0.3;
655656
Tensor<double> A("A", {N, N}, Format{Dense, Dense});
656657
Tensor<double> B("B", {N, N}, Format{Dense, Sparse});
657658
Tensor<double> C("C", {N, N}, Format{Dense, Dense});
@@ -660,7 +661,9 @@ TEST(workspaces, loopfuse) {
660661

661662
for (int i = 0; i < N; i++) {
662663
for (int j = 0; j < N; j++) {
663-
B.insert({i, j}, (double) i);
664+
float rand_float = (float) rand() / (float) RAND_MAX;
665+
if (rand_float < SPARSITY)
666+
B.insert({i, j}, (double) i);
664667
C.insert({i, j}, (double) j);
665668
E.insert({i, j}, (double) i*j);
666669
D.insert({i, j}, (double) i*j);
@@ -703,6 +706,57 @@ TEST(workspaces, loopfuse) {
703706
}
704707

705708

709+
// Exercise loopfuse with a reversed loop order: the producer ends up at the
// end of the assignment, so argument packing must follow the transformed
// index statement (compile/compute are called with the scheduled stmt).
TEST(workspaces, loopreversefuse) {
  int N = 16;
  float SPARSITY = 0.3;
  Tensor<double> A("A", {N, N}, Format{Dense, Dense});
  Tensor<double> B("B", {N, N}, Format{Dense, Sparse});
  Tensor<double> C("C", {N, N}, Format{Dense, Dense});
  Tensor<double> D("D", {N, N}, Format{Dense, Dense});
  Tensor<double> E("E", {N, N}, Format{Dense, Dense});

  // Populate the operands; B is sparse, keeping an entry only when the
  // random draw falls below SPARSITY.
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
      float draw = (float) rand() / (float) RAND_MAX;
      if (draw < SPARSITY) {
        B.insert({i, j}, (double) draw);
      }
      C.insert({i, j}, (double) j);
      E.insert({i, j}, (double) i*j);
      D.insert({i, j}, (double) i*j);
    }
  }

  IndexVar i("i"), j("j"), k("k"), l("l"), m("m");
  A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m);

  IndexStmt stmt = A.getAssignment().concretize();

  std::cout << stmt << endl;
  vector<int> path1;
  // Reverse the loop order, then fuse at depth 2 along the (empty) path.
  stmt = stmt
    .reorder({m,k,l,i,j})
    .loopfuse(2, false, path1)
    ;
  stmt = stmt
    .parallelize(m, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces)
    ;

  stmt = stmt.concretize();
  cout << "final stmt: " << stmt << endl;
  printCodeToFile("loopreversefuse", stmt);

  // Compile and compute against the scheduled statement.
  A.compile(stmt);
  B.pack();
  A.assemble();
  A.compute(stmt);

  // Reference result computed with the default (unscheduled) pipeline.
  Tensor<double> expected("expected", {N, N}, Format{Dense, Dense});
  expected(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m);
  expected.compile();
  expected.assemble();
  expected.compute();
  ASSERT_TENSOR_EQ(expected, A);
}
706760

707761
TEST(workspaces, loopcontractfuse) {
708762
int N = 16;

0 commit comments

Comments
 (0)