Skip to content

Commit 02e7f2a

Browse files
authored
Merge pull request #27 from edin-dal/fix/self-inner-product
Self inner product bug fixed
2 parents e4ae1d2 + ada4766 commit 02e7f2a

File tree

5 files changed

+80
-8
lines changed

5 files changed

+80
-8
lines changed

examples/self-inner-product.stur

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
symbols: N
2+
A() := B(i) * B(i)
3+
B:D(i) := (0 <= i) * (i < N)

project/build.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
sbt.version=1.7.0
1+
sbt.version=1.10.0

src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ object Compiler {
4545
val all_intersects =
4646
vars.tail.foldLeft(vars.head)((acc, cur) => acc.intersect(cur))
4747
val vars_wo_intersects = vars.map(_.diff(all_intersects))
48+
val cond0 = vars_wo_intersects.forall(_.length > 0)
4849
val cond1 = isPairwiseIntersectEmpty(vars_wo_intersects)
4950
val all_intersect_indecies = vars.map(v =>
5051
v.zipWithIndex.collect {
@@ -54,7 +55,7 @@ object Compiler {
5455
val cond2 = all_intersect_indecies.tail.forall(
5556
_.toSet == all_intersect_indecies.head.toSet
5657
)
57-
cond1 && cond2
58+
cond0 && cond1 && cond2
5859
}
5960

6061
def groupBySameName(
@@ -290,8 +291,10 @@ object Compiler {
290291
)
291292
) {
292293
val us = Rule(lhs.uniqueHead(), SoP(Seq(Prod(Seq(rhs.uniqueHead())))))
293-
val rm = Rule(lhs.redundancyHead(), SoP(Seq(Prod(Seq(rhs.redundancyHead())))))
294-
val c = Rule(lhs.compressedHead(), SoP(Seq(Prod(Seq(rhs.compressedHead())))))
294+
val rm =
295+
Rule(lhs.redundancyHead(), SoP(Seq(Prod(Seq(rhs.redundancyHead())))))
296+
val c =
297+
Rule(lhs.compressedHead(), SoP(Seq(Prod(Seq(rhs.compressedHead())))))
295298
(us, rm, c)
296299
} else {
297300
val us = Rule(
@@ -305,7 +308,10 @@ object Compiler {
305308
Seq(
306309
Prod(Seq(rhs.compressedHead())),
307310
Prod(
308-
Seq(rhs.redundancyHead(), rhs.vars2RedundancyVars().compressedHead())
311+
Seq(
312+
rhs.redundancyHead(),
313+
rhs.vars2RedundancyVars().compressedHead()
314+
)
309315
)
310316
)
311317
)
@@ -1071,7 +1077,10 @@ object Compiler {
10711077
val rm = Rule(lhs.redundancyHead(), rmBody)
10721078

10731079
val cBody = SoP(
1074-
Seq(Prod(Seq(acc1.compressedHead())), Prod(Seq(acc2.compressedHead())))
1080+
Seq(
1081+
Prod(Seq(acc1.compressedHead())),
1082+
Prod(Seq(acc2.compressedHead()))
1083+
)
10751084
)
10761085
val c = Rule(lhs.compressedHead(), cBody)
10771086

@@ -1088,12 +1097,18 @@ object Compiler {
10881097
val us = Rule(lhs.uniqueHead(), usBody)
10891098

10901099
val rmBody = SoP(
1091-
Seq(Prod(Seq(acc1.redundancyHead())), Prod(Seq(acc2.redundancyHead())))
1100+
Seq(
1101+
Prod(Seq(acc1.redundancyHead())),
1102+
Prod(Seq(acc2.redundancyHead()))
1103+
)
10921104
)
10931105
val rm = Rule(lhs.redundancyHead(), rmBody)
10941106

10951107
val cBody = SoP(
1096-
Seq(Prod(Seq(acc1.compressedHead())), Prod(Seq(acc2.compressedHead())))
1108+
Seq(
1109+
Prod(Seq(acc1.compressedHead())),
1110+
Prod(Seq(acc2.compressedHead()))
1111+
)
10971112
)
10981113
val c = Rule(lhs.compressedHead(), cBody)
10991114

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
#include <iostream>
3+
#include <random>
4+
#include <algorithm>
5+
#include <chrono>
6+
7+
using namespace std;
8+
using namespace std::chrono;
9+
10+
extern "C"
11+
void fn(double & A, double * B, int N) {
12+
13+
14+
long time_computation = 0, start_computation, end_computation;
15+
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
16+
{
17+
for (int i = 0; i < N; ++i) {
18+
19+
A += (B[i] * B[i]);
20+
}
21+
}
22+
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
23+
time_computation = end_computation - start_computation;
24+
cout << time_computation << endl;
25+
long time_reconstruction = 0, start_reconstruction, end_reconstruction;
26+
start_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
27+
28+
end_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
29+
time_reconstruction = end_reconstruction - start_reconstruction;
30+
cout << time_reconstruction << endl;
31+
32+
}

src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2196,4 +2196,26 @@ class CodegenTest extends AnyFlatSpec with Matchers {
21962196
val lines2 = file2.getLines().toList
21972197
lines2 should be(lines1)
21982198
}
2199+
2200+
it should "generate correct code for self inner product without the body" in {
2201+
Utils.cnt = 0
2202+
Main.main(
2203+
Array(
2204+
"-i",
2205+
"examples/self-inner-product.stur",
2206+
"-o",
2207+
"src/test/resources/test_outputs/self-inner-product_wo_body_test.cpp"
2208+
)
2209+
)
2210+
2211+
val file1 = scala.io.Source.fromFile(
2212+
"src/test/resources/correct_test_outputs/self-inner-product_wo_body.cpp"
2213+
)
2214+
val file2 = scala.io.Source.fromFile(
2215+
"src/test/resources/test_outputs/self-inner-product_wo_body_test.cpp"
2216+
)
2217+
val lines1 = file1.getLines().toList
2218+
val lines2 = file2.getLines().toList
2219+
lines2 should be(lines1)
2220+
}
21992221
}

0 commit comments

Comments
 (0)