Skip to content

Commit

Permalink
Merge pull request #17 from edin-dal/fix/using-symbols-in-operations
Browse files Browse the repository at this point in the history
Symbols can be used as a part of the computation
  • Loading branch information
mtghorbani authored Jul 11, 2024
2 parents eb12f45 + ac2d767 commit 43e26d3
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/symbol-computation.stur
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
symbols: N
A(i) := B(i) * N^-1()
A:D(i) := (0 <= i < N)
B:D(i) := (0 <= i < N)
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ object Bodygen {
.distinctBy(_.name)
.filter(_.kind == Tensor)
.filterNot(only_lhs_heads.contains)
.filterNot(t => symbols.contains(t.name.deinversifiedName.toVar))
.filterNot(t => decimal_pattern.matches(t.name))
.map(t =>
alloc_and_gen_random_number(
Expand Down Expand Up @@ -135,6 +136,7 @@ extern "C"
val decimal_pattern = """-?\d+(\.\d+)?""".r
val tensor_to_str = all_tensors
.filterNot(only_lhs_heads_not_in_output.contains)
.filterNot(t => symbols.contains(t.name.deinversifiedName.toVar))
.filterNot(t => decimal_pattern.matches(t.name.deinversifiedName))
.map(t =>
"double " + (if (t.vars.isEmpty) "&"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

#include <iostream>
#include <random>
#include <algorithm>
#include <chrono>

using namespace std;
using namespace std::chrono;

int main(int argc, char **argv){
srand(0);


const int N = atoi(argv[1]);
double *B = new double[N];
for (size_t i = 0; i < N; ++i) {
int flag1 = 0 <= i && N > i;
if (flag1) {
B[i] = (double) (rand() % 1000000) / 1e6;
} else {
B[i] = 0.0;
}
}
double *A = new double[N];
for (size_t i = 0; i < N; ++i) {
A[i] = 0.0;
}

long time_computation = 0, start_computation, end_computation;
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
for (int i = 0; i < N; ++i) {

A[i] += (1. / N * B[i]);
}
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
time_computation = end_computation - start_computation;
cout << time_computation << endl;
long time_reconstruction = 0, start_reconstruction, end_reconstruction;
start_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();

end_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
time_reconstruction = end_reconstruction - start_reconstruction;
cout << time_reconstruction << endl;
cerr << A[0] << endl;
cerr << B[0] << endl;
cerr << N << endl;
delete[] A;
delete[] B;
return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

#include <iostream>
#include <random>
#include <algorithm>
#include <chrono>

using namespace std;
using namespace std::chrono;

extern "C"
void fn(double * A, double * B, int N) {


long time_computation = 0, start_computation, end_computation;
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
for (int i = 0; i < N; ++i) {

A[i] += (1. / N * B[i]);
}
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
time_computation = end_computation - start_computation;
cout << time_computation << endl;
long time_reconstruction = 0, start_reconstruction, end_reconstruction;
start_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();

end_reconstruction = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
time_reconstruction = end_reconstruction - start_reconstruction;
cout << time_reconstruction << endl;

}
45 changes: 45 additions & 0 deletions src/test/scala/uk/ac/ed/dal/structtensor/codegen/CodegenTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1886,4 +1886,49 @@ class CodegenTest extends AnyFlatSpec with Matchers {
val lines2 = file2.getLines().toList
lines2 should be(lines1)
}

it should "generate code for when there is a symbol used in the computation without the body" in {
Utils.cnt = 0
Main.main(
Array(
"-i",
"examples/symbol-computation.stur",
"-o",
"src/test/resources/test_outputs/symbol-computation_wo_body_test.cpp"
)
)

val file1 = scala.io.Source.fromFile(
"src/test/resources/correct_test_outputs/symbol-computation_wo_body.cpp"
)
val file2 = scala.io.Source.fromFile(
"src/test/resources/test_outputs/symbol-computation_wo_body_test.cpp"
)
val lines1 = file1.getLines().toList
val lines2 = file2.getLines().toList
lines2 should be(lines1)
}

it should "generate code for when there is a symbol used in the computation with the body" in {
Utils.cnt = 0
Main.main(
Array(
"-i",
"examples/symbol-computation.stur",
"-o",
"src/test/resources/test_outputs/symbol-computation_w_body_test.cpp",
"--init-tensors"
)
)

val file1 = scala.io.Source.fromFile(
"src/test/resources/correct_test_outputs/symbol-computation_w_body.cpp"
)
val file2 = scala.io.Source.fromFile(
"src/test/resources/test_outputs/symbol-computation_w_body_test.cpp"
)
val lines1 = file1.getLines().toList
val lines2 = file2.getLines().toList
lines2 should be(lines1)
}
}

0 comments on commit 43e26d3

Please sign in to comment.