diff --git a/examples/covar.stur b/examples/covar.stur new file mode 100644 index 0000000..b51fe11 --- /dev/null +++ b/examples/covar.stur @@ -0,0 +1,4 @@ +symbols: N, M +covar(j, k) := X(i, j) * X(i, k) +covar:D(i, j) := (0 <= i < N) * (0 <= j < N) +X:D(i, j) := (0 <= i < M) * (0 <= j < N) \ No newline at end of file diff --git a/examples/pseudo-lr-training.stur b/examples/pseudo-lr-training.stur new file mode 100644 index 0000000..13645aa --- /dev/null +++ b/examples/pseudo-lr-training.stur @@ -0,0 +1,13 @@ +symbols: N, P +outputs: DW, covar +covar(i, k) := X(j, i) * X(j, k) +T1(i) := covar(i, j) * w(i) +T2(i) := X(j, i) * y(j) +DW(i) := T1(i) + T2(i) +X:D(j, i) := (0 <= j) * (j < N) * (0 <= i) * (i < P) +y:D(j) := (0 <= j) * (j < N) +w:D(i) := (0 <= i) * (i < P) +covar:D(i, k) := (0 <= i) * (i < P) * (0 <= k) * (k < P) +T1:D(i) := (0 <= i) * (i < P) +T2:D(i) := (0 <= i) * (i < P) +DW:D(i) := (0 <= i) * (i < P) \ No newline at end of file diff --git a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala index 14dc00b..d1d98c6 100644 --- a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala +++ b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala @@ -41,6 +41,22 @@ object Compiler { } .forall(_ == true) + def isSubSelfOuterProduct(vars: Seq[Seq[Variable]]): Boolean = { + val all_intersects = + vars.tail.foldLeft(vars.head)((acc, cur) => acc.intersect(cur)) + val vars_wo_intersects = vars.map(_.diff(all_intersects)) + val cond1 = isPairwiseIntersectEmpty(vars_wo_intersects) + val all_intersect_indecies = vars.map(v => + v.zipWithIndex.collect { + case (variable, index) if all_intersects.contains(variable) => index + } + ) + val cond2 = all_intersect_indecies.tail.forall( + _.toSet == all_intersect_indecies.head.toSet + ) + cond1 && cond2 + } + def groupBySameName( exp: Exp, rest: Seq[Exp], @@ -64,7 +80,11 @@ object Compiler { isThereAnySameNameLeft = true, init_exp = real_init_exp ) - else if (isPairwiseIntersectEmpty(sameNameList.map(e => e.vars))) + else if ( + isPairwiseIntersectEmpty( + sameNameList.map(e => e.vars) + ) || isSubSelfOuterProduct(sameNameList.map(e => e.vars)) + ) (sameNameList, true) else groupBySameName( @@ -141,7 +161,7 @@ object Compiler { } else { val newVariables = sameNameExpsList.flatMap { case acc @ Access(_, vars, _) => vars - } // This one does not distinct since the intersection of the variables is empty + }.distinct val newHead = Access(getVar("sameNameProdHead"), newVariables, Tensor) val newBody = SoP(Seq(Prod(sameNameExpsList))) val newRule = Rule(newHead, newBody) @@ -240,9 +260,35 @@ object Compiler { (us, rm, c) } - def project(lhs: Access, rhs: Access): (Rule, Rule, Rule) = { + def checkIfItsOptimalProject( + lhs: Access, + rhs: Access, + ctx: Seq[(Rule, Rule, Rule, Rule)] + ): Boolean = { + val uncommon_vars = rhs.vars.diff(lhs.vars) + val vecMult = vectorizeComparisonMultiplication( + "=", + uncommon_vars, + uncommon_vars.redundancyVars + ) + val denormRM = locallyDenormalizeAndReturnBody(rhs.redundancyHead(), ctx) + denormRM.prods + .forall(p => p.exps.intersect(vecMult).toSet == vecMult.toSet) + } + + def project( + lhs: Access, + rhs: Access, + ctx: Seq[(Rule, Rule, Rule, Rule)] + ): (Rule, Rule, Rule) = { assert(lhs.vars.toSet.subsetOf(rhs.vars.toSet)) - if (lhs.vars.toSet == rhs.vars.toSet) { + if ( + lhs.vars.toSet == rhs.vars.toSet || checkIfItsOptimalProject( + lhs, + rhs, + ctx + ) + ) { val us = Rule(lhs.uniqueHead, SoP(Seq(Prod(Seq(rhs.uniqueHead))))) val rm = Rule(lhs.redundancyHead, SoP(Seq(Prod(Seq(rhs.redundancyHead))))) val c = Rule(lhs.compressedHead, SoP(Seq(Prod(Seq(rhs.compressedHead))))) @@ -292,72 +338,9 @@ object Compiler { .union(vars2) .toSet && isIntersectEmpty(vars1, vars2) ) { - val usBody = SoP( - Seq( - Prod( - Seq( - acc1.uniqueHead, - acc2.uniqueHead - ) ++ vectorizeComparisonMultiplication("<=", vars1, vars2) - ) - ) - ) - val us = Rule(lhs.uniqueHead, usBody) - - val rmBody = SoP( - Seq( - Prod(Seq(acc1.redundancyHead, acc2.redundancyHead)), - Prod( - Seq( - acc1.uniqueHead, - acc2.redundancyHead - ) ++ vectorizeComparisonMultiplication( - "=", - vars1, - vars1.redundancyVars - ) - ), - Prod( - Seq( - acc1.redundancyHead, - acc2.uniqueHead - ) ++ vectorizeComparisonMultiplication( - "=", - vars2, - vars2.redundancyVars - ) - ), - Prod( - Seq( - acc1.uniqueHead, - acc2.uniqueHead - ) ++ vectorizeComparisonMultiplication( - "=", - vars1, - vars2.redundancyVars - ) ++ vectorizeComparisonMultiplication( - "=", - vars2, - vars1.redundancyVars - ) ++ vectorizeComparisonMultiplication(">", vars1, vars2) - ) - ) - ) - val rm = Rule(lhs.redundancyHead, rmBody) - - val cBody = SoP( - Seq( - Prod( - Seq( - acc1.compressedHead, - acc2.compressedHead - ) ++ vectorizeComparisonMultiplication("<=", vars1, vars2) - ) - ) - ) - val c = Rule(lhs.compressedHead, cBody) - - (us, rm, c) + selfOuterProduct(lhs, Seq(acc1, acc2)) + } else if (name1 == name2 && isSubSelfOuterProduct(Seq(vars1, vars2))) { + selfOuterProduct(lhs, Seq(acc1, acc2)) } else if ( lhs.vars.toSet == vars1.union(vars2).toSet && isIntersectEmpty( vars1, @@ -862,7 +845,7 @@ object Compiler { SoP(Seq(Prod(accSeq.map(_.compressedHead())))) ) (us, rm, c) - } else { // we made sure in the normalization step that all the accesses have the same name and their variable pairwise intersection is empty + } else if (isPairwiseIntersectEmpty(accSeq.map(acc => acc.vars))) { // we made sure in the normalization step that all the accesses have the same name and their variable pairwise intersection is empty val (us2, rm2, c2) = selfOuterProduct2(lhs, accSeq) val allUniqueHeads = accSeq.map(_.uniqueHead()) @@ -907,6 +890,114 @@ object Compiler { val cBody = prodTimesSoP(Prod(accSeq.map(_.compressedHead())), c2.body) val c = Rule(lhs.compressedHead(), cBody) + (us, rm, c) + } else { // We assured that if the pairwise intersection is not empty, then the isSubSelfOuter function has returned true to be here. + val vars = accSeq.map(_.vars) + val all_intersect = + vars.tail.foldLeft(vars.head)((acc, cur) => acc.intersect(cur)) + + val (usSoPSeq, rmSoPSeq, cSoPSeq) = (0 to accSeq.length - 1) + .flatMap(i => { + (0 to accSeq.length - 1) + .combinations(i) + .map(indexList => { + val prodAndVectorizedMult = Prod( + accSeq.zipWithIndex + .collect { + case (acc, ind) if !indexList.contains(ind) => + vectorizeComparisonMultiplication( + "=", + all_intersect, + all_intersect.redundancyVars + ) :+ + acc.uniqueHead + case (acc, ind) if indexList.contains(ind) => + Seq(acc.redundancyHead) + } + .flatten + .toSeq + ) + + val (prodExps, vecMult) = prodAndVectorizedMult.exps.partition { + case _: Access => true + case _ => false + } + val prod = Prod(prodExps) + + val (uniqueAccesses, rest) = prod.exps.partition { + case _ @Access(_, _, UniqueSet) => true + case _ => false + } + + if (uniqueAccesses.length == 1) { + val us = SoP(Seq(prod)) + val rm = emptySoP() + val c11 = uniqueAccesses.collect { case u: Access => + Access( + u.name.substring(0, u.name.length() - 2) + "C", + u.vars, + CompressedTensor + ) + } + val c12 = rest.collect { case r: Access => + Access( + r.name.substring(0, r.name.length() - 2) + "C", + r.vars.drop(r.vars.length / 2), + CompressedTensor + ) + } + val c = SoP(Seq(Prod(c11 ++ c12 ++ rest))) + + (us, rm, c) + } else { + val all_unique_accesses_minus_intersect = + uniqueAccesses.collect { + case _ @Access(name, variables, kind) => + Access(name, variables.diff(all_intersect), kind) + } + + val (us2, rm2, c2) = + selfOuterProduct2(lhs, all_unique_accesses_minus_intersect) + + val us = prodTimesSoP(prod, us2.body) + val injectedMapRM = + injectRM( + rm2.body, + us2.body.prods.head, + all_unique_accesses_minus_intersect.map(_.vars).transpose + ) + val rm = + prodTimesSoP(prodAndVectorizedMult, injectedMapRM) + + val c11 = uniqueAccesses.collect { case u: Access => + Access( + u.name.substring(0, u.name.length() - 2) + "C", + u.vars, + CompressedTensor + ) + } + val c12 = rest.collect { case r: Access => + Access( + r.name.substring(0, r.name.length() - 2) + "C", + r.vars.drop(r.vars.length / 2), + CompressedTensor + ) + } + val c = prodTimesSoP(Prod(c11 ++ c12 ++ rest), c2.body) + + (us, rm, c) + } + }) + }) + .unzip3 + + val us = Rule(lhs.uniqueHead, concatSoP(usSoPSeq)) + val rm = Rule( + lhs.redundancyHead, + concatSoP(rmSoPSeq :+ SoP(Seq(Prod(accSeq.map(_.redundancyHead()))))) + ) + val c = Rule(lhs.compressedHead, concatSoP(cSoPSeq)) + (us, rm, c) } } @@ -1064,7 +1155,7 @@ object Compiler { } } else if (exps.length == 1) { exps.head match { - case access: Access => project(head, access) + case access: Access => project(head, access, ctx) case _ => throw new AssertionError("Expected an Access expression") } } else if (exps.length == 2) { diff --git a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala index 11c7298..d03e36d 100644 --- a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala +++ b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala @@ -592,9 +592,14 @@ object Optimizer { } }) - val all_variable_names = - (symbols.map(_.name) ++ boundVariables(Seq(rule))).distinct + val base_variables = + symbols.map(_.name) ++ rule.head.vars.map(_.name).distinct val finalBody = SoP(newBody.prods.map(p => { + val all_variable_names = base_variables ++ p.exps + .collect { case a: Access => a.vars } + .flatten + .distinct + .map(_.name) Prod( p.exps.filter(e => getVariables(e).map(_.name).forall(all_variable_names.contains(_)) diff --git a/src/test/resources/correct_test_outputs/PGLM_w_body_DataLayout.cpp b/src/test/resources/correct_test_outputs/PGLM_w_body_DataLayout.cpp index 92bd261..02c6aec 100644 --- a/src/test/resources/correct_test_outputs/PGLM_w_body_DataLayout.cpp +++ b/src/test/resources/correct_test_outputs/PGLM_w_body_DataLayout.cpp @@ -63,15 +63,15 @@ B2[j] += B[i][j]; long time_computation = 0, start_computation, end_computation; start_computation = duration_cast(system_clock::now().time_since_epoch()).count(); int i = 0; -for (int i30 = 0; i30 < W; ++i30) { +for (int i28 = 0; i28 < W; ++i28) { -A[i] += (B1[i30] * C[i30]); +A[i] += (B1[i28] * C[i28]); } for (int i = 1; i < W; ++i) { -int i31 = (i - 1); -if (i31 >= 0 && i31 < min({(W - 1), W})) { -A[i] += (B2[i31] * C[i31]); +int i29 = (i - 1); +if (i29 >= 0 && i29 < min({(W - 1), W})) { +A[i] += (B2[i29] * C[i29]); } } end_computation = duration_cast(system_clock::now().time_since_epoch()).count(); diff --git a/src/test/resources/correct_test_outputs/PGLM_wo_body_DataLayout.cpp b/src/test/resources/correct_test_outputs/PGLM_wo_body_DataLayout.cpp index ffefc50..365412b 100644 --- a/src/test/resources/correct_test_outputs/PGLM_wo_body_DataLayout.cpp +++ b/src/test/resources/correct_test_outputs/PGLM_wo_body_DataLayout.cpp @@ -27,15 +27,15 @@ B2[j] += B[i][j]; long time_computation = 0, start_computation, end_computation; start_computation = duration_cast(system_clock::now().time_since_epoch()).count(); int i = 0; -for (int i28 = 0; i28 < W; ++i28) { +for (int i26 = 0; i26 < W; ++i26) { -A[i] += (B1[i28] * C[i28]); +A[i] += (B1[i26] * C[i26]); } for (int i = 1; i < W; ++i) { -int i29 = (i - 1); -if (i29 >= 0 && i29 < min({(W - 1), W})) { -A[i] += (B2[i29] * C[i29]); +int i27 = (i - 1); +if (i27 >= 0 && i27 < min({(W - 1), W})) { +A[i] += (B2[i27] * C[i27]); } } end_computation = duration_cast(system_clock::now().time_since_epoch()).count(); diff --git a/src/test/resources/correct_test_outputs/covar_wo_body.cpp b/src/test/resources/correct_test_outputs/covar_wo_body.cpp new file mode 100644 index 0000000..fc46636 --- /dev/null +++ b/src/test/resources/correct_test_outputs/covar_wo_body.cpp @@ -0,0 +1,44 @@ + +#include +#include +#include +#include + +using namespace std; +using namespace std::chrono; + +extern "C" +void fn(double ** covar, double ** X, int N, int M) { + + +long time_computation = 0, start_computation, end_computation; +start_computation = duration_cast(system_clock::now().time_since_epoch()).count(); +for (int j = 0; j < N; ++j) { + +for (int k = 0; k < min({(j) + 1, N}); ++k) { + +for (int i = 0; i < M; ++i) { + +covar[j][k] += (X[i][k] * X[i][j]); +} +} +} +end_computation = duration_cast(system_clock::now().time_since_epoch()).count(); +time_computation = end_computation - start_computation; +cout << time_computation << endl; +long time_reconstruction = 0, start_reconstruction, end_reconstruction; +start_reconstruction = duration_cast(system_clock::now().time_since_epoch()).count(); +for (int j = 0; j < N; ++j) { + +int kp = j; +for (int k = max({(j) + 1, 0}); k < N; ++k) { + +int jp = k; +covar[j][k] = covar[jp][kp]; +} +} +end_reconstruction = duration_cast(system_clock::now().time_since_epoch()).count(); +time_reconstruction = end_reconstruction - start_reconstruction; +cout << time_reconstruction << endl; + +} \ No newline at end of file diff --git a/src/test/resources/correct_test_outputs/pseudo-lr-training_wo_body.cpp b/src/test/resources/correct_test_outputs/pseudo-lr-training_wo_body.cpp new file mode 100644 index 0000000..494c2c2 --- /dev/null +++ b/src/test/resources/correct_test_outputs/pseudo-lr-training_wo_body.cpp @@ -0,0 +1,66 @@ + +#include +#include +#include +#include + +using namespace std; +using namespace std::chrono; + +extern "C" +void fn(double ** covar, double ** X, double * w, double * y, double * DW, int N, int P) { + + +long time_computation = 0, start_computation, end_computation; +start_computation = duration_cast(system_clock::now().time_since_epoch()).count(); +for (int i = 0; i < P; ++i) { + +for (int k = 0; k < min({(i) + 1, P}); ++k) { + +for (int j = 0; j < N; ++j) { + +covar[i][k] += (X[j][k] * X[j][i]); +} +} +} +for (int i = 0; i < P; ++i) { + +for (int i19 = 0; i19 < min({(i) + 1, P}); ++i19) { + +DW[i] += (covar[i][i19] * w[i]); +} +} +for (int i = 0; i < P; ++i) { + +for (int i19 = max({(i) + 1, 0, i}); i19 < P; ++i19) { + +DW[i] += (w[i] * covar[i19][i]); +} +} +for (int i = 0; i < P; ++i) { + +for (int i20 = 0; i20 < N; ++i20) { + +DW[i] += (X[i20][i] * y[i20]); +} +} +end_computation = duration_cast(system_clock::now().time_since_epoch()).count(); +time_computation = end_computation - start_computation; +cout << time_computation << endl; +long time_reconstruction = 0, start_reconstruction, end_reconstruction; +start_reconstruction = duration_cast(system_clock::now().time_since_epoch()).count(); +for (int i = 0; i < P; ++i) { + +int kp = i; +for (int k = max({(i) + 1, 0}); k < P; ++k) { + +int ip = k; +covar[i][k] = covar[ip][kp]; +} +} + +end_reconstruction = duration_cast(system_clock::now().time_since_epoch()).count(); +time_reconstruction = end_reconstruction - start_reconstruction; +cout << time_reconstruction << endl; + +} \ No newline at end of file diff --git a/src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala b/src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala index ba6e1bf..636f0f0 100644 --- a/src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala +++ b/src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala @@ -2086,4 +2086,48 @@ class CodegenTest extends AnyFlatSpec with Matchers { val lines2 = file2.getLines().toList lines2 should be(lines1) } + + it should "generate code for when there is an optimal projection, like matrix covariance matrix without the body" in { + Utils.cnt = 0 + Main.main( + Array( + "-i", + "examples/covar.stur", + "-o", + "src/test/resources/test_outputs/covar_wo_body_test.cpp" + ) + ) + + val file1 = scala.io.Source.fromFile( + "src/test/resources/correct_test_outputs/covar_wo_body.cpp" + ) + val file2 = scala.io.Source.fromFile( + "src/test/resources/test_outputs/covar_wo_body_test.cpp" + ) + val lines1 = file1.getLines().toList + val lines2 = file2.getLines().toList + lines2 should be(lines1) + } + + it should "generate code for when there is an optimal projection, like pseudo lr training without the body" in { + Utils.cnt = 0 + Main.main( + Array( + "-i", + "examples/pseudo-lr-training.stur", + "-o", + "src/test/resources/test_outputs/pseudo-lr-training_wo_body_test.cpp" + ) + ) + + val file1 = scala.io.Source.fromFile( + "src/test/resources/correct_test_outputs/pseudo-lr-training_wo_body.cpp" + ) + val file2 = scala.io.Source.fromFile( + "src/test/resources/test_outputs/pseudo-lr-training_wo_body_test.cpp" + ) + val lines1 = file1.getLines().toList + val lines2 = file2.getLines().toList + lines2 should be(lines1) + } } diff --git a/src/test/scala/uk/ac/ed/dal/structtensor/compiler/CompilerTest.scala b/src/test/scala/uk/ac/ed/dal/structtensor/compiler/CompilerTest.scala index 4ecb5e8..1a8714b 100644 --- a/src/test/scala/uk/ac/ed/dal/structtensor/compiler/CompilerTest.scala +++ b/src/test/scala/uk/ac/ed/dal/structtensor/compiler/CompilerTest.scala @@ -503,7 +503,8 @@ class CompilerTest ) val res = Compiler.project( rule.head, - rule.body.prods.head.exps.head.asInstanceOf[Access] + rule.body.prods.head.exps.head.asInstanceOf[Access], + Seq() ) val (us, rm, cc) = res @@ -583,7 +584,32 @@ class CompilerTest ) val res = Compiler.project( rule.head, - rule.body.prods.head.exps.head.asInstanceOf[Access] + rule.body.prods.head.exps.head.asInstanceOf[Access], + Seq( + ( + emptyRule(), + Rule( + Access( + "b", + Seq(Variable("i"), Variable("j"), Variable("k")), + Tensor + ).redundancyHead(), + SoP( + Seq( + Prod( + Seq( + Comparison("=", Variable("i"), Variable("kp")), + Comparison("=", Variable("k"), Variable("ip")), + Comparison("=", Variable("j"), Variable("jp")) + ) + ) + ) + ) + ), + emptyRule(), + emptyRule() + ) + ) ) val (us, rm, cc) = res @@ -684,11 +710,392 @@ class CompilerTest assertThrows[AssertionError] { Compiler.project( rule.head, - rule.body.prods.head.exps.head.asInstanceOf[Access] + rule.body.prods.head.exps.head.asInstanceOf[Access], + Seq() ) } } + it should "project optimally when there is an optimal projection" in { + // This test is based on the following: + // symbols: N, M + // covar(j, k) := X(i, j) * X(i, k) + // covar:D(i, j) := (0 <= i < N) * (0 <= j < N) + // X:D(i, j) := (0 <= i < M) * (0 <= j < N) + + val lhs = Access("covar", List(Variable("j"), Variable("k")), Tensor) + val rhs = Access( + "sameNameProdHead2", + List(Variable("i"), Variable("k"), Variable("j")), + Tensor + ) + val ctx = List( + ( + Rule( + Access("X_US", List(Variable("i"), Variable("j")), UniqueSet), + SoP( + List( + Prod( + List( + Comparison("<=", ConstantInt(0), Variable("i")), + Comparison(">", Variable("M"), Variable("i")), + Comparison("<=", ConstantInt(0), Variable("j")), + Comparison(">", Variable("N"), Variable("j")) + ) + ) + ) + ) + ), + Rule( + Access( + "X_RM", + List(Variable("i"), Variable("j"), Variable("ip"), Variable("jp")), + RedundancyMap + ), + SoP(List()) + ), + Rule( + Access("X_C", List(Variable("i"), Variable("j")), CompressedTensor), + SoP( + List( + Prod( + List( + Access("X", List(Variable("i"), Variable("j")), Tensor), + Comparison("<=", ConstantInt(0), Variable("i")), + Comparison(">", Variable("M"), Variable("i")), + Comparison("<=", ConstantInt(0), Variable("j")), + Comparison(">", Variable("N"), Variable("j")) + ) + ) + ) + ) + ), + Rule( + Access("X", List(Variable("i"), Variable("j")), Tensor), + SoP( + List( + Prod( + List(Access("X", List(Variable("i"), Variable("j")), Tensor)) + ) + ) + ) + ) + ), + ( + Rule( + Access("X_US", List(Variable("i"), Variable("k")), UniqueSet), + SoP( + List( + Prod( + List( + Comparison("<=", ConstantInt(0), Variable("i")), + Comparison(">", Variable("M"), Variable("i")), + Comparison("<=", ConstantInt(0), Variable("k")), + Comparison(">", Variable("N"), Variable("k")) + ) + ) + ) + ) + ), + Rule( + Access( + "X_RM", + List(Variable("i"), Variable("k"), Variable("ip"), Variable("kp")), + RedundancyMap + ), + SoP(List()) + ), + Rule( + Access("X_C", List(Variable("i"), Variable("k")), CompressedTensor), + SoP( + List( + Prod( + List( + Access("X", List(Variable("i"), Variable("k")), Tensor), + Comparison("<=", ConstantInt(0), Variable("i")), + Comparison(">", Variable("M"), Variable("i")), + Comparison("<=", ConstantInt(0), Variable("k")), + Comparison(">", Variable("N"), Variable("k")) + ) + ) + ) + ) + ), + Rule( + Access("X", List(Variable("i"), Variable("k")), Tensor), + SoP( + List( + Prod( + List(Access("X", List(Variable("i"), Variable("k")), Tensor)) + ) + ) + ) + ) + ), + ( + Rule( + Access( + "sameNameProdHead2_US", + List(Variable("i"), Variable("k"), Variable("j")), + UniqueSet + ), + SoP( + Vector( + Prod( + List( + Access("X_US", List(Variable("i"), Variable("k")), UniqueSet), + Access("X_US", List(Variable("i"), Variable("j")), UniqueSet), + Comparison("<=", Variable("k"), Variable("j")) + ) + ), + Prod( + List( + Access( + "X_RM", + List( + Variable("i"), + Variable("k"), + Variable("ip"), + Variable("kp") + ), + RedundancyMap + ), + Access("X_US", List(Variable("i"), Variable("j")), UniqueSet) + ) + ), + Prod( + List( + Access("X_US", List(Variable("i"), Variable("k")), UniqueSet), + Access( + "X_RM", + List( + Variable("i"), + Variable("j"), + Variable("ip"), + Variable("jp") + ), + RedundancyMap + ) + ) + ) + ) + ) + ), + Rule( + Access( + "sameNameProdHead2_RM", + List( + Variable("i"), + Variable("k"), + Variable("j"), + Variable("ip"), + Variable("kp"), + Variable("jp") + ), + RedundancyMap + ), + SoP( + Vector( + Prod( + List( + Comparison("=", Variable("i"), Variable("ip")), + Access("X_US", List(Variable("i"), Variable("k")), UniqueSet), + Comparison("=", Variable("i"), Variable("ip")), + Access("X_US", List(Variable("i"), Variable("j")), UniqueSet), + Comparison("=", Variable("j"), Variable("kp")), + Comparison("=", Variable("k"), Variable("jp")), + Comparison("<", Variable("j"), Variable("k")) + ) + ), + Prod( + List( + Access( + "X_RM", + List( + Variable("i"), + Variable("k"), + Variable("ip"), + Variable("kp") + ), + RedundancyMap + ), + Access( + "X_RM", + List( + Variable("i"), + Variable("j"), + Variable("ip"), + Variable("jp") + ), + RedundancyMap + ) + ) + ) + ) + ) + ), + Rule( + Access( + "sameNameProdHead2_C", + List(Variable("i"), Variable("k"), Variable("j")), + CompressedTensor + ), + SoP( + Vector( + Prod( + List( + Access( + "X_C", + List(Variable("i"), Variable("k")), + CompressedTensor + ), + Access( + "X_C", + List(Variable("i"), Variable("j")), + CompressedTensor + ), + Comparison("<=", Variable("k"), Variable("j")) + ) + ), + Prod( + List( + Access( + "X_C", + List(Variable("i"), Variable("j")), + CompressedTensor + ), + Access( + "X_C", + List(Variable("ip"), Variable("kp")), + CompressedTensor + ), + Access( + "X_RM", + List( + Variable("i"), + Variable("k"), + Variable("ip"), + Variable("kp") + ), + RedundancyMap + ) + ) + ), + Prod( + List( + Access( + "X_C", + List(Variable("i"), Variable("k")), + CompressedTensor + ), + Access( + "X_C", + List(Variable("ip"), Variable("jp")), + CompressedTensor + ), + Access( + "X_RM", + List( + Variable("i"), + Variable("j"), + Variable("ip"), + Variable("jp") + ), + RedundancyMap + ) + ) + ) + ) + ) + ), + Rule( + Access( + "sameNameProdHead2", + List(Variable("i"), Variable("k"), Variable("j")), + Tensor + ), + SoP( + List( + Prod( + List( + Access("X", List(Variable("i"), Variable("k")), Tensor), + Access("X", List(Variable("i"), Variable("j")), Tensor) + ) + ) + ) + ) + ) + ) + ) + + val res = Compiler.project(lhs, rhs, ctx) + val (us, rm, cc) = res + + val expectedUS = Rule( + Access("covar_US", List(Variable("j"), Variable("k")), UniqueSet), + SoP( + List( + Prod( + List( + Access( + "sameNameProdHead2_US", + List(Variable("i"), Variable("k"), Variable("j")), + UniqueSet + ) + ) + ) + ) + ) + ) + val expectedRM = Rule( + Access( + "covar_RM", + List(Variable("j"), Variable("k"), Variable("jp"), Variable("kp")), + RedundancyMap + ), + SoP( + List( + Prod( + List( + Access( + "sameNameProdHead2_RM", + List( + Variable("i"), + Variable("k"), + Variable("j"), + Variable("ip"), + Variable("kp"), + Variable("jp") + ), + RedundancyMap + ) + ) + ) + ) + ) + ) + val expectedCC = Rule( + Access("covar_C", List(Variable("j"), Variable("k")), CompressedTensor), + SoP( + List( + Prod( + List( + Access( + "sameNameProdHead2_C", + List(Variable("i"), Variable("k"), Variable("j")), + CompressedTensor + ) + ) + ) + ) + ) + ) + + us shouldBe expectedUS + rm shouldBe expectedRM + cc shouldBe expectedCC + } + it should "be able to generate a vectorize comparison multiplication" in { Compiler.vectorizeComparisonMultiplication( "<=", @@ -747,31 +1154,39 @@ class CompilerTest Seq( Prod( Seq( - Access("b_RM", Seq(Variable("i"), Variable("ip")), RedundancyMap), + Access( + "b_RM", + Seq(Variable("i"), Variable("ip")), + RedundancyMap + ), Access("b_RM", Seq(Variable("j"), Variable("jp")), RedundancyMap) ) ), Prod( Seq( + Comparison("=", Variable("i"), Variable("ip")), Access("b_US", Seq(Variable("i")), UniqueSet), - Access("b_RM", Seq(Variable("j"), Variable("jp")), RedundancyMap), - Comparison("=", Variable("i"), Variable("ip")) + Access("b_RM", Seq(Variable("j"), Variable("jp")), RedundancyMap) ) ), Prod( Seq( - Access("b_RM", Seq(Variable("i"), Variable("ip")), RedundancyMap), - Access("b_US", Seq(Variable("j")), UniqueSet), - Comparison("=", Variable("j"), Variable("jp")) + Access( + "b_RM", + Seq(Variable("i"), Variable("ip")), + RedundancyMap + ), + Comparison("=", Variable("j"), Variable("jp")), + Access("b_US", Seq(Variable("j")), UniqueSet) ) ), Prod( Seq( Access("b_US", Seq(Variable("i")), UniqueSet), Access("b_US", Seq(Variable("j")), UniqueSet), - Comparison("=", Variable("i"), Variable("jp")), Comparison("=", Variable("j"), Variable("ip")), - Comparison(">", Variable("i"), Variable("j")) + Comparison("=", Variable("i"), Variable("jp")), + Comparison("<", Variable("j"), Variable("i")) ) ) ) @@ -2474,7 +2889,7 @@ class CompilerTest ) } - it should "compile a a computation, given all inputs and symbols" in { + it should "compile a computation, given all inputs and symbols" in { // first expression in PR2C.stur val computation = Rule( Access("A", Seq(Variable("i"), Variable("j")), Tensor), @@ -2596,9 +3011,9 @@ class CompilerTest Seq( Prod( Seq( - Comparison("=", Variable("i"), Variable("jp")), Comparison("=", Variable("j"), Variable("ip")), - Comparison(">", Variable("i"), Variable("j")), + Comparison("=", Variable("i"), Variable("jp")), + Comparison("<", Variable("j"), Variable("i")), Comparison("<=", ConstantInt(0), Variable("i")), Comparison(">", Variable("N"), Variable("i")), Comparison("<=", ConstantInt(0), Variable("j")), diff --git a/src/test/scala/uk/ac/ed/dal/structtensor/compiler/OptimizerTest.scala b/src/test/scala/uk/ac/ed/dal/structtensor/compiler/OptimizerTest.scala index 3e600c5..677fd4f 100644 --- a/src/test/scala/uk/ac/ed/dal/structtensor/compiler/OptimizerTest.scala +++ b/src/test/scala/uk/ac/ed/dal/structtensor/compiler/OptimizerTest.scala @@ -1151,25 +1151,20 @@ class OptimizerTest val actual = Optimizer.replaceEqualVariables(rule, Seq()) actual shouldBe Rule( - Access("A", Seq(Variable("i")), Tensor), + Access("A", List(Variable("i")), Tensor), SoP( - Seq( + List( Prod( - Seq( - Access("B", Seq(Variable("i"), Variable("i")), Tensor), - Access("C", Seq(Variable("i"), Variable("i")), Tensor), - Access("D", Seq(Variable("p"), Variable("q")), Tensor), - Comparison("=", Variable("i"), Variable("j")), - Comparison("=", Variable("i"), Variable("l")), - Comparison("=", Variable("j"), Variable("i")) + List( + Access("B", List(Variable("i"), Variable("i")), Tensor), + Access("C", List(Variable("i"), Variable("i")), Tensor), + Access("D", List(Variable("p"), Variable("q")), Tensor) ) ), Prod( - Seq( - Access("B", Seq(Variable("i"), Variable("i")), Tensor), - Access("C", Seq(Variable("p"), Variable("p")), Tensor), - Comparison("=", Variable("i"), Variable("j")), - Comparison("=", Variable("p"), Variable("l")) + List( + Access("B", List(Variable("i"), Variable("i")), Tensor), + Access("C", List(Variable("p"), Variable("p")), Tensor) ) ) )