Skip to content

Commit

Permalink
Merge pull request #27 from edin-dal/fix/self-inner-product
Browse files Browse the repository at this point in the history
Self inner product bug fixed
  • Loading branch information
mtghorbani authored Oct 28, 2024
2 parents e4ae1d2 + ada4766 commit 02e7f2a
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 8 deletions.
3 changes: 3 additions & 0 deletions examples/self-inner-product.stur
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
symbols: N
A() := B(i) * B(i)
B:D(i) := (0 <= i) * (i < N)
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.7.0
sbt.version=1.10.0
29 changes: 22 additions & 7 deletions src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ object Compiler {
val all_intersects =
vars.tail.foldLeft(vars.head)((acc, cur) => acc.intersect(cur))
val vars_wo_intersects = vars.map(_.diff(all_intersects))
val cond0 = vars_wo_intersects.forall(_.length > 0)
val cond1 = isPairwiseIntersectEmpty(vars_wo_intersects)
val all_intersect_indecies = vars.map(v =>
v.zipWithIndex.collect {
Expand All @@ -54,7 +55,7 @@ object Compiler {
val cond2 = all_intersect_indecies.tail.forall(
_.toSet == all_intersect_indecies.head.toSet
)
cond1 && cond2
cond0 && cond1 && cond2
}

def groupBySameName(
Expand Down Expand Up @@ -290,8 +291,10 @@ object Compiler {
)
) {
val us = Rule(lhs.uniqueHead(), SoP(Seq(Prod(Seq(rhs.uniqueHead())))))
val rm = Rule(lhs.redundancyHead(), SoP(Seq(Prod(Seq(rhs.redundancyHead())))))
val c = Rule(lhs.compressedHead(), SoP(Seq(Prod(Seq(rhs.compressedHead())))))
val rm =
Rule(lhs.redundancyHead(), SoP(Seq(Prod(Seq(rhs.redundancyHead())))))
val c =
Rule(lhs.compressedHead(), SoP(Seq(Prod(Seq(rhs.compressedHead())))))
(us, rm, c)
} else {
val us = Rule(
Expand All @@ -305,7 +308,10 @@ object Compiler {
Seq(
Prod(Seq(rhs.compressedHead())),
Prod(
Seq(rhs.redundancyHead(), rhs.vars2RedundancyVars().compressedHead())
Seq(
rhs.redundancyHead(),
rhs.vars2RedundancyVars().compressedHead()
)
)
)
)
Expand Down Expand Up @@ -1071,7 +1077,10 @@ object Compiler {
val rm = Rule(lhs.redundancyHead(), rmBody)

val cBody = SoP(
Seq(Prod(Seq(acc1.compressedHead())), Prod(Seq(acc2.compressedHead())))
Seq(
Prod(Seq(acc1.compressedHead())),
Prod(Seq(acc2.compressedHead()))
)
)
val c = Rule(lhs.compressedHead(), cBody)

Expand All @@ -1088,12 +1097,18 @@ object Compiler {
val us = Rule(lhs.uniqueHead(), usBody)

val rmBody = SoP(
Seq(Prod(Seq(acc1.redundancyHead())), Prod(Seq(acc2.redundancyHead())))
Seq(
Prod(Seq(acc1.redundancyHead())),
Prod(Seq(acc2.redundancyHead()))
)
)
val rm = Rule(lhs.redundancyHead(), rmBody)

val cBody = SoP(
Seq(Prod(Seq(acc1.compressedHead())), Prod(Seq(acc2.compressedHead())))
Seq(
Prod(Seq(acc1.compressedHead())),
Prod(Seq(acc2.compressedHead()))
)
)
val c = Rule(lhs.compressedHead(), cBody)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

#include <iostream>
#include <random>
#include <algorithm>
#include <chrono>

using namespace std;
using namespace std::chrono;

extern "C"
void fn(double & A, double * B, int N) {


long time_computation = 0, start_computation, end_computation;
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
{
for (int i = 0; i < N; ++i) {

A += (B[i] * B[i]);
}
}
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
time_computation = end_computation - start_computation;
cout << time_computation << endl;
long time_reconstruction = 0, start_reconstruction, end_reconstruction;
start_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();

end_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
time_reconstruction = end_reconstruction - start_reconstruction;
cout << time_reconstruction << endl;

}
22 changes: 22 additions & 0 deletions src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2196,4 +2196,26 @@ class CodegenTest extends AnyFlatSpec with Matchers {
val lines2 = file2.getLines().toList
lines2 should be(lines1)
}

it should "generate correct code for self inner product without the body" in {
Utils.cnt = 0
Main.main(
Array(
"-i",
"examples/self-inner-product.stur",
"-o",
"src/test/resources/test_outputs/self-inner-product_wo_body_test.cpp"
)
)

val file1 = scala.io.Source.fromFile(
"src/test/resources/correct_test_outputs/self-inner-product_wo_body.cpp"
)
val file2 = scala.io.Source.fromFile(
"src/test/resources/test_outputs/self-inner-product_wo_body_test.cpp"
)
val lines1 = file1.getLines().toList
val lines2 = file2.getLines().toList
lines2 should be(lines1)
}
}

0 comments on commit 02e7f2a

Please sign in to comment.