Skip to content

Commit 43e26d3

Browse files
authored
Merge pull request #17 from edin-dal/fix/using-symbols-in-operations
Symbols can be used as a part of the computation
2 parents eb12f45 + ac2d767 commit 43e26d3

File tree

5 files changed

+131
-0
lines changed

5 files changed

+131
-0
lines changed

examples/symbol-computation.stur

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
symbols: N
2+
A(i) := B(i) * N^-1()
3+
A:D(i) := (0 <= i < N)
4+
B:D(i) := (0 <= i < N)

src/main/scala/uk/ac/ed/dal/structtensor/codegen/Bodygen.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ object Bodygen {
7676
.distinctBy(_.name)
7777
.filter(_.kind == Tensor)
7878
.filterNot(only_lhs_heads.contains)
79+
.filterNot(t => symbols.contains(t.name.deinversifiedName.toVar))
7980
.filterNot(t => decimal_pattern.matches(t.name))
8081
.map(t =>
8182
alloc_and_gen_random_number(
@@ -135,6 +136,7 @@ extern "C"
135136
val decimal_pattern = """-?\d+(\.\d+)?""".r
136137
val tensor_to_str = all_tensors
137138
.filterNot(only_lhs_heads_not_in_output.contains)
139+
.filterNot(t => symbols.contains(t.name.deinversifiedName.toVar))
138140
.filterNot(t => decimal_pattern.matches(t.name.deinversifiedName))
139141
.map(t =>
140142
"double " + (if (t.vars.isEmpty) "&"
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
#include <iostream>
3+
#include <random>
4+
#include <algorithm>
5+
#include <chrono>
6+
7+
using namespace std;
8+
using namespace std::chrono;
9+
10+
int main(int argc, char **argv){
11+
srand(0);
12+
13+
14+
const int N = atoi(argv[1]);
15+
double *B = new double[N];
16+
for (size_t i = 0; i < N; ++i) {
17+
int flag1 = 0 <= i && N > i;
18+
if (flag1) {
19+
B[i] = (double) (rand() % 1000000) / 1e6;
20+
} else {
21+
B[i] = 0.0;
22+
}
23+
}
24+
double *A = new double[N];
25+
for (size_t i = 0; i < N; ++i) {
26+
A[i] = 0.0;
27+
}
28+
29+
long time_computation = 0, start_computation, end_computation;
30+
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
31+
for (int i = 0; i < N; ++i) {
32+
33+
A[i] += (1. / N * B[i]);
34+
}
35+
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
36+
time_computation = end_computation - start_computation;
37+
cout << time_computation << endl;
38+
long time_reconstruction = 0, start_reconstruction, end_reconstruction;
39+
start_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
40+
41+
end_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
42+
time_reconstruction = end_reconstruction - start_reconstruction;
43+
cout << time_reconstruction << endl;
44+
cerr << A[0] << endl;
45+
cerr << B[0] << endl;
46+
cerr << N << endl;
47+
delete[] A;
48+
delete[] B;
49+
return 0;
50+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
#include <iostream>
3+
#include <random>
4+
#include <algorithm>
5+
#include <chrono>
6+
7+
using namespace std;
8+
using namespace std::chrono;
9+
10+
extern "C"
11+
void fn(double * A, double * B, int N) {
12+
13+
14+
long time_computation = 0, start_computation, end_computation;
15+
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
16+
for (int i = 0; i < N; ++i) {
17+
18+
A[i] += (1. / N * B[i]);
19+
}
20+
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
21+
time_computation = end_computation - start_computation;
22+
cout << time_computation << endl;
23+
long time_reconstruction = 0, start_reconstruction, end_reconstruction;
24+
start_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
25+
26+
end_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
27+
time_reconstruction = end_reconstruction - start_reconstruction;
28+
cout << time_reconstruction << endl;
29+
30+
}

src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1886,4 +1886,49 @@ class CodegenTest extends AnyFlatSpec with Matchers {
18861886
val lines2 = file2.getLines().toList
18871887
lines2 should be(lines1)
18881888
}
1889+
1890+
it should "generate code for when there is a symbol used in the computation without the body" in {
1891+
Utils.cnt = 0
1892+
Main.main(
1893+
Array(
1894+
"-i",
1895+
"examples/symbol-computation.stur",
1896+
"-o",
1897+
"src/test/resources/test_outputs/symbol-computation_wo_body_test.cpp"
1898+
)
1899+
)
1900+
1901+
val file1 = scala.io.Source.fromFile(
1902+
"src/test/resources/correct_test_outputs/symbol-computation_wo_body.cpp"
1903+
)
1904+
val file2 = scala.io.Source.fromFile(
1905+
"src/test/resources/test_outputs/symbol-computation_wo_body_test.cpp"
1906+
)
1907+
val lines1 = file1.getLines().toList
1908+
val lines2 = file2.getLines().toList
1909+
lines2 should be(lines1)
1910+
}
1911+
1912+
it should "generate code for when there is a symbol used in the computation with the body" in {
1913+
Utils.cnt = 0
1914+
Main.main(
1915+
Array(
1916+
"-i",
1917+
"examples/symbol-computation.stur",
1918+
"-o",
1919+
"src/test/resources/test_outputs/symbol-computation_w_body_test.cpp",
1920+
"--init-tensors"
1921+
)
1922+
)
1923+
1924+
val file1 = scala.io.Source.fromFile(
1925+
"src/test/resources/correct_test_outputs/symbol-computation_w_body.cpp"
1926+
)
1927+
val file2 = scala.io.Source.fromFile(
1928+
"src/test/resources/test_outputs/symbol-computation_w_body_test.cpp"
1929+
)
1930+
val lines1 = file1.getLines().toList
1931+
val lines2 = file2.getLines().toList
1932+
lines2 should be(lines1)
1933+
}
18891934
}

0 commit comments

Comments
 (0)