Skip to content

Commit

Permalink
Merge pull request #23 from edin-dal/feature/evolved-self-outer-produ…
Browse files Browse the repository at this point in the history
…ct-inference

Feature/evolved self outer product inference
  • Loading branch information
mtghorbani authored Jul 18, 2024
2 parents 1a07598 + 0ae3c34 commit cb78ba1
Show file tree
Hide file tree
Showing 11 changed files with 789 additions and 112 deletions.
4 changes: 4 additions & 0 deletions examples/covar.stur
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
symbols: N, M
covar(j, k) := X(i, j) * X(i, k)
covar:D(i, j) := (0 <= i < N) * (0 <= j < N)
X:D(i, j) := (0 <= i < M) * (0 <= j < N)
13 changes: 13 additions & 0 deletions examples/pseudo-lr-training.stur
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
symbols: N, P
outputs: DW, covar
covar(i, k) := X(j, i) * X(j, k)
T1(i) := covar(i, j) * w(i)
T2(i) := X(j, i) * y(j)
DW(i) := T1(i) + T2(i)
X:D(j, i) := (0 <= j) * (j < N) * (0 <= i) * (i < P)
y:D(j) := (0 <= j) * (j < N)
w:D(i) := (0 <= i) * (i < P)
covar:D(i, k) := (0 <= i) * (i < P) * (0 <= k) * (k < P)
T1:D(i) := (0 <= i) * (i < P)
T2:D(i) := (0 <= i) * (i < P)
DW:D(i) := (0 <= i) * (i < P)
235 changes: 163 additions & 72 deletions src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ object Compiler {
}
.forall(_ == true)

def isSubSelfOuterProduct(vars: Seq[Seq[Variable]]): Boolean = {
val all_intersects =
vars.tail.foldLeft(vars.head)((acc, cur) => acc.intersect(cur))
val vars_wo_intersects = vars.map(_.diff(all_intersects))
val cond1 = isPairwiseIntersectEmpty(vars_wo_intersects)
val all_intersect_indecies = vars.map(v =>
v.zipWithIndex.collect {
case (variable, index) if all_intersects.contains(variable) => index
}
)
val cond2 = all_intersect_indecies.tail.forall(
_.toSet == all_intersect_indecies.head.toSet
)
cond1 && cond2
}

def groupBySameName(
exp: Exp,
rest: Seq[Exp],
Expand All @@ -64,7 +80,11 @@ object Compiler {
isThereAnySameNameLeft = true,
init_exp = real_init_exp
)
else if (isPairwiseIntersectEmpty(sameNameList.map(e => e.vars)))
else if (
isPairwiseIntersectEmpty(
sameNameList.map(e => e.vars)
) || isSubSelfOuterProduct(sameNameList.map(e => e.vars))
)
(sameNameList, true)
else
groupBySameName(
Expand Down Expand Up @@ -141,7 +161,7 @@ object Compiler {
} else {
val newVariables = sameNameExpsList.flatMap {
case acc @ Access(_, vars, _) => vars
} // This one does not distinct since the intersection of the variables is empty
}.distinct
val newHead = Access(getVar("sameNameProdHead"), newVariables, Tensor)
val newBody = SoP(Seq(Prod(sameNameExpsList)))
val newRule = Rule(newHead, newBody)
Expand Down Expand Up @@ -240,9 +260,35 @@ object Compiler {
(us, rm, c)
}

def project(lhs: Access, rhs: Access): (Rule, Rule, Rule) = {
def checkIfItsOptimalProject(
lhs: Access,
rhs: Access,
ctx: Seq[(Rule, Rule, Rule, Rule)]
): Boolean = {
val uncommon_vars = rhs.vars.diff(lhs.vars)
val vecMult = vectorizeComparisonMultiplication(
"=",
uncommon_vars,
uncommon_vars.redundancyVars
)
val denormRM = locallyDenormalizeAndReturnBody(rhs.redundancyHead(), ctx)
denormRM.prods
.forall(p => p.exps.intersect(vecMult).toSet == vecMult.toSet)
}

def project(
lhs: Access,
rhs: Access,
ctx: Seq[(Rule, Rule, Rule, Rule)]
): (Rule, Rule, Rule) = {
assert(lhs.vars.toSet.subsetOf(rhs.vars.toSet))
if (lhs.vars.toSet == rhs.vars.toSet) {
if (
lhs.vars.toSet == rhs.vars.toSet || checkIfItsOptimalProject(
lhs,
rhs,
ctx
)
) {
val us = Rule(lhs.uniqueHead, SoP(Seq(Prod(Seq(rhs.uniqueHead)))))
val rm = Rule(lhs.redundancyHead, SoP(Seq(Prod(Seq(rhs.redundancyHead)))))
val c = Rule(lhs.compressedHead, SoP(Seq(Prod(Seq(rhs.compressedHead)))))
Expand Down Expand Up @@ -292,72 +338,9 @@ object Compiler {
.union(vars2)
.toSet && isIntersectEmpty(vars1, vars2)
) {
val usBody = SoP(
Seq(
Prod(
Seq(
acc1.uniqueHead,
acc2.uniqueHead
) ++ vectorizeComparisonMultiplication("<=", vars1, vars2)
)
)
)
val us = Rule(lhs.uniqueHead, usBody)

val rmBody = SoP(
Seq(
Prod(Seq(acc1.redundancyHead, acc2.redundancyHead)),
Prod(
Seq(
acc1.uniqueHead,
acc2.redundancyHead
) ++ vectorizeComparisonMultiplication(
"=",
vars1,
vars1.redundancyVars
)
),
Prod(
Seq(
acc1.redundancyHead,
acc2.uniqueHead
) ++ vectorizeComparisonMultiplication(
"=",
vars2,
vars2.redundancyVars
)
),
Prod(
Seq(
acc1.uniqueHead,
acc2.uniqueHead
) ++ vectorizeComparisonMultiplication(
"=",
vars1,
vars2.redundancyVars
) ++ vectorizeComparisonMultiplication(
"=",
vars2,
vars1.redundancyVars
) ++ vectorizeComparisonMultiplication(">", vars1, vars2)
)
)
)
val rm = Rule(lhs.redundancyHead, rmBody)

val cBody = SoP(
Seq(
Prod(
Seq(
acc1.compressedHead,
acc2.compressedHead
) ++ vectorizeComparisonMultiplication("<=", vars1, vars2)
)
)
)
val c = Rule(lhs.compressedHead, cBody)

(us, rm, c)
selfOuterProduct(lhs, Seq(acc1, acc2))
} else if (name1 == name2 && isSubSelfOuterProduct(Seq(vars1, vars2))) {
selfOuterProduct(lhs, Seq(acc1, acc2))
} else if (
lhs.vars.toSet == vars1.union(vars2).toSet && isIntersectEmpty(
vars1,
Expand Down Expand Up @@ -862,7 +845,7 @@ object Compiler {
SoP(Seq(Prod(accSeq.map(_.compressedHead()))))
)
(us, rm, c)
} else { // we made sure in the normalization step that all the accesses have the same name and their variable pairwise intersection is empty
} else if (isPairwiseIntersectEmpty(accSeq.map(acc => acc.vars))) { // we made sure in the normalization step that all the accesses have the same name and their variable pairwise intersection is empty
val (us2, rm2, c2) = selfOuterProduct2(lhs, accSeq)
val allUniqueHeads = accSeq.map(_.uniqueHead())

Expand Down Expand Up @@ -907,6 +890,114 @@ object Compiler {
val cBody = prodTimesSoP(Prod(accSeq.map(_.compressedHead())), c2.body)
val c = Rule(lhs.compressedHead(), cBody)

(us, rm, c)
} else { // We assured that if the pairwise intersection is not empty, then the isSubSelfOuter function has returned true to be here.
val vars = accSeq.map(_.vars)
val all_intersect =
vars.tail.foldLeft(vars.head)((acc, cur) => acc.intersect(cur))

val (usSoPSeq, rmSoPSeq, cSoPSeq) = (0 to accSeq.length - 1)
.flatMap(i => {
(0 to accSeq.length - 1)
.combinations(i)
.map(indexList => {
val prodAndVectorizedMult = Prod(
accSeq.zipWithIndex
.collect {
case (acc, ind) if !indexList.contains(ind) =>
vectorizeComparisonMultiplication(
"=",
all_intersect,
all_intersect.redundancyVars
) :+
acc.uniqueHead
case (acc, ind) if indexList.contains(ind) =>
Seq(acc.redundancyHead)
}
.flatten
.toSeq
)

val (prodExps, vecMult) = prodAndVectorizedMult.exps.partition {
case _: Access => true
case _ => false
}
val prod = Prod(prodExps)

val (uniqueAccesses, rest) = prod.exps.partition {
case _ @Access(_, _, UniqueSet) => true
case _ => false
}

if (uniqueAccesses.length == 1) {
val us = SoP(Seq(prod))
val rm = emptySoP()
val c11 = uniqueAccesses.collect { case u: Access =>
Access(
u.name.substring(0, u.name.length() - 2) + "C",
u.vars,
CompressedTensor
)
}
val c12 = rest.collect { case r: Access =>
Access(
r.name.substring(0, r.name.length() - 2) + "C",
r.vars.drop(r.vars.length / 2),
CompressedTensor
)
}
val c = SoP(Seq(Prod(c11 ++ c12 ++ rest)))

(us, rm, c)
} else {
val all_unique_accesses_minus_intersect =
uniqueAccesses.collect {
case _ @Access(name, variables, kind) =>
Access(name, variables.diff(all_intersect), kind)
}

val (us2, rm2, c2) =
selfOuterProduct2(lhs, all_unique_accesses_minus_intersect)

val us = prodTimesSoP(prod, us2.body)
val injectedMapRM =
injectRM(
rm2.body,
us2.body.prods.head,
all_unique_accesses_minus_intersect.map(_.vars).transpose
)
val rm =
prodTimesSoP(prodAndVectorizedMult, injectedMapRM)

val c11 = uniqueAccesses.collect { case u: Access =>
Access(
u.name.substring(0, u.name.length() - 2) + "C",
u.vars,
CompressedTensor
)
}
val c12 = rest.collect { case r: Access =>
Access(
r.name.substring(0, r.name.length() - 2) + "C",
r.vars.drop(r.vars.length / 2),
CompressedTensor
)
}
val c = prodTimesSoP(Prod(c11 ++ c12 ++ rest), c2.body)

(us, rm, c)
}
})
})
.unzip3

val us = Rule(lhs.uniqueHead, concatSoP(usSoPSeq))
val rm = Rule(
lhs.redundancyHead,
concatSoP(rmSoPSeq :+ SoP(Seq(Prod(accSeq.map(_.redundancyHead())))))
)
val c = Rule(lhs.compressedHead, concatSoP(cSoPSeq))

(us, rm, c)
}
}
Expand Down Expand Up @@ -1064,7 +1155,7 @@ object Compiler {
}
} else if (exps.length == 1) {
exps.head match {
case access: Access => project(head, access)
case access: Access => project(head, access, ctx)
case _ => throw new AssertionError("Expected an Access expression")
}
} else if (exps.length == 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,14 @@ object Optimizer {
}
})

val all_variable_names =
(symbols.map(_.name) ++ boundVariables(Seq(rule))).distinct
val base_variables =
symbols.map(_.name) ++ rule.head.vars.map(_.name).distinct
val finalBody = SoP(newBody.prods.map(p => {
val all_variable_names = base_variables ++ p.exps
.collect { case a: Access => a.vars }
.flatten
.distinct
.map(_.name)
Prod(
p.exps.filter(e =>
getVariables(e).map(_.name).forall(all_variable_names.contains(_))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ B2[j] += B[i][j];
long time_computation = 0, start_computation, end_computation;
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
int i = 0;
for (int i30 = 0; i30 < W; ++i30) {
for (int i28 = 0; i28 < W; ++i28) {

A[i] += (B1[i30] * C[i30]);
A[i] += (B1[i28] * C[i28]);
}
for (int i = 1; i < W; ++i) {

int i31 = (i - 1);
if (i31 >= 0 && i31 < min({(W - 1), W})) {
A[i] += (B2[i31] * C[i31]);
int i29 = (i - 1);
if (i29 >= 0 && i29 < min({(W - 1), W})) {
A[i] += (B2[i29] * C[i29]);
}
}
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ B2[j] += B[i][j];
long time_computation = 0, start_computation, end_computation;
start_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
int i = 0;
for (int i28 = 0; i28 < W; ++i28) {
for (int i26 = 0; i26 < W; ++i26) {

A[i] += (B1[i28] * C[i28]);
A[i] += (B1[i26] * C[i26]);
}
for (int i = 1; i < W; ++i) {

int i29 = (i - 1);
if (i29 >= 0 && i29 < min({(W - 1), W})) {
A[i] += (B2[i29] * C[i29]);
int i27 = (i - 1);
if (i27 >= 0 && i27 < min({(W - 1), W})) {
A[i] += (B2[i27] * C[i27]);
}
}
end_computation = duration_cast<microseconds>(system_clock::now().time_since_epoch()).count();
Expand Down
Loading

0 comments on commit cb78ba1

Please sign in to comment.