Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused functions #9

Merged
merged 2 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 21 additions & 309 deletions src/main/scala/uk/ac/ed/dal/structtensor/codegen/CodegenUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,6 @@ import compiler._
import Utils._

object CodegenUtils {
def MLIR_init_code(): String = s"""
"builtin.module"() ({
"func.func"() ({}) {function_type = (!llvm.ptr) -> i32, sym_name = "atoi", sym_visibility = "private"} : () -> ()
"func.func"() ({}) {function_type = (i32) -> (), sym_name = "srand", sym_visibility = "private"} : () -> ()
"func.func"() ({}) {function_type = () -> (i32), sym_name = "rand", sym_visibility = "private"} : () -> ()
"func.func"() ({}) {function_type = (i32) -> (), sym_name = "print_i32", sym_visibility = "private"} : () -> ()
"func.func"() ({}) {function_type = (i64) -> (), sym_name = "print_i64", sym_visibility = "private"} : () -> ()
"func.func"() ({}) {function_type = (f64) -> (), sym_name = "print_f64_cerr", sym_visibility = "private"} : () -> ()
"func.func"() ({}) {function_type = () -> (i64), sym_name = "timer", sym_visibility = "private"} : () -> ()
"func.func"() ({}) {function_type = (i64) -> (i64), sym_name = "timer_elapsed", sym_visibility = "private"} : () -> ()

"func.func"() ({
^bb0(%argc : i32, %argv : !llvm.ptr):
%const_val_0 = "arith.constant"() {"value" = 0 : index} : () -> index
%zi32 = "arith.constant"() {"value" = 0 : i32} : () -> i32
%zerof = "arith.constant"() {"value" = 0.0 : f64} : () -> f64
%onef = "arith.constant"() {"value" = 1.0 : f64} : () -> f64
%const_val_1 = "arith.constant"() {"value" = 1 : index} : () -> index
%oi1 = "arith.constant"() {"value" = 1 : i1} : () -> i1
%oi0 = "arith.constant"() {"value" = 0 : i1} : () -> i1
"func.call"(%zi32) {callee = @srand} : (i32) -> ()
%1000000 = "arith.constant"() {"value" = 1000000 : i32} : () -> i32
%f1000000 = "arith.constant"() {"value" = 1000000.0 : f64} : () -> f64
"""

def MLIR_read_argv(argv_names: Seq[String]): String = {
argv_names.zipWithIndex
.map {
case (name, i) => {
val id = i + 1
s"""
%argvv$id = llvm.getelementptr %argv[$id] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
%argv$id = "llvm.load"(%argvv$id) : (!llvm.ptr) -> !llvm.ptr
%${name}i32 = "func.call"(%argv$id) {callee = @atoi} : (!llvm.ptr) -> i32
%$name = arith.index_cast %${name}i32 : i32 to index
"""
}
}
.mkString("\n")
}

def CPP_read_argv(argv_names: Seq[String]): String = argv_names.zipWithIndex
.map { case (name, i) => s"const int $name = atoi(argv[${i + 1}]);" }
Expand Down Expand Up @@ -289,222 +249,6 @@ int main(int argc, char **argv){
}
}

def MLIR_generate_arith(
op: String,
i1: String,
i2: String,
output_var: String,
op_type: Int = 0
): String = {
val tp = if (op_type == 0) "i32" else "index"
op match {
case "%" =>
s"""$output_var = "arith.remui"($i1, $i2) : ($tp, $tp) -> $tp"""
case "+" =>
s"""$output_var = "arith.addi"($i1, $i2) : ($tp, $tp) -> $tp"""
case "-" =>
s"""$output_var = "arith.subi"($i1, $i2) : ($tp, $tp) -> $tp"""
case "*" =>
s"""$output_var = "arith.muli"($i1, $i2) : ($tp, $tp) -> $tp"""
case "/" =>
s"""$output_var = "arith.divui"($i1, $i2) : ($tp, $tp) -> $tp"""
case _ => throw new Exception("Invalid arithmetic operator")
}
}

def MLIR_convert_index(index: Index, op_type: Int = 0): (String, String) = {
index match {
case Variable(name) => (s"%$name", "")
case Arithmetic(op, i1, i2) => {
val (arith_var1, arith_code1) = MLIR_convert_index(i1, op_type)
val (arith_var2, arith_code2) = MLIR_convert_index(i2, op_type)
val arith_var = getVar("%arith_var")
val arith_code =
MLIR_generate_arith(op, arith_var1, arith_var2, arith_var, op_type)
(arith_var, arith_code1 + "\n" + arith_code2 + "\n" + arith_code + "\n")
}
case ConstantInt(i) => {
val tp = if (op_type == 0) "i32" else "index"
val var_name = getVar("%consti")
(
var_name,
s"""$var_name = "arith.constant"() {value = $i : $tp} : () -> $tp"""
)
}
case ConstantDouble(d) => {
val tp = "f64"
val var_name = getVar("%constf")
(
var_name,
s"""$var_name = "arith.constant"() {value = $d : $tp} : () -> $tp"""
)
}
case _ => throw new Exception("Invalid index")
}
}

def MLIR_comp_to_predicate(op: String): String = op match {
case "<" => "2" // "slt"
case ">" => "4" // "sgt"
case "<=" => "3" // "sle"
case ">=" => "5" // "sge"
case "==" => "0" // "eq"
case "=" => "0" // "eq"
case "!=" => "1" // "ne"
case _ => throw new Exception("Invalid comparison operator")
}

def MLIR_convert_condition(condition: SoP): (String, String) = {
val orFlag = getVar("%orFlag")
val andFlag = getVar("%andFlag")
val orFlagCode =
s"""$orFlag = "arith.constant"() {value = 0 : i1} : () -> i1"""
val andFlagCode =
s"""$andFlag = "arith.constant"() {value = 1 : i1} : () -> i1"""
val (flag, res) = condition.prods.foldLeft((orFlag, ""))((acc, p) => {
val (flag2, res2) = p.exps.foldLeft((andFlag, ""))((acc2, e) => {
e match {
case Comparison(op, index, variable) => {
val f = getVar("%andFlag")
val cmp_var = getVar("%cmpFlag")
val pred = MLIR_comp_to_predicate(op)
val (arith_var, arith_code) = MLIR_convert_index(index, 1)
val cmp_code =
s"""$cmp_var = "arith.cmpi"($arith_var, %${variable.name}) {predicate = $pred : i64} : (index, index) -> i1"""
val and_code =
s"""$f = "arith.andi"(${acc2._1}, $cmp_var) : (i1, i1) -> i1"""
(
f,
acc2._2 + "\n" + arith_code + "\n" + cmp_code + "\n" + and_code + "\n"
)
}
case _ => acc2
}
})
val fflag = getVar("%orFlag")
val or_code =
s"""$fflag = "arith.ori"(${acc._1}, $flag2) : (i1, i1) -> i1"""
(fflag, acc._2 + "\n" + res2 + "\n" + or_code + "\n")
})
(flag, orFlagCode + "\n" + andFlagCode + "\n" + res)
}

def MLIR_alloc_and_gen_random_number(
head: Access,
dims: Seq[Dim],
sopCond: SoP
): String = {
dims match {
case Nil => {
val rval1 = getVar("rval")
val rval2 = getVar("rval")
val rval3 = getVar("rval")
s"""
%$rval1 = "func.call"() {callee = @rand} : () -> i32
%$rval2 = "arith.remui"(%$rval1, %1000000) : (i32, i32) -> i32
%$rval3 = "arith.sitofp"(%$rval2) : (i32) -> f64
%${head.name} = "arith.divf"(%$rval3, %f1000000) : (f64, f64) -> f64
"""
}
case _ => {
val var_name = head.name
val vars = head.vars
val iter_seq = vars.map(_.name)
val dimsAndCode: Seq[(String, String)] =
dims.map(d => MLIR_convert_index(d, 1))
val dimensions = dimsAndCode.map(_._1)
val dimensions_code = dimsAndCode.map(_._2).mkString("\n")
val c0 =
s"%$var_name = memref.alloc(${dimensions.mkString(", ")}) : memref<${"?x" * dimensions.length}f64>"

val (flag, c11) = MLIR_convert_condition(sopCond)

val c1 = dimensions
.zip(iter_seq)
.map {
case (dim, i) => {
s"""
"scf.for"(%const_val_0, $dim, %const_val_1) ({
^bb0(%$i: index):
"""
}
}
.mkString("\n")
val ivars = iter_seq.map(e => s"%$e").mkString(", ")
val qvars = dimensions.map(e => s"?").mkString("x") + "x"
val index_vars = dimensions.map(e => s"index").mkString(", ")

val rval1 = getVar("rval")
val rval2 = getVar("rval")
val rval3 = getVar("rval")
val rval4 = getVar("rval")
val c2 = s"""
"scf.if"(${flag}) ({
%$rval1 = "func.call"() {callee = @rand} : () -> i32
%$rval2 = "arith.remui"(%$rval1, %1000000) : (i32, i32) -> i32
%$rval3 = "arith.sitofp"(%$rval2) : (i32) -> f64
%$rval4 = "arith.divf"(%$rval3, %f1000000) : (f64, f64) -> f64
"memref.store"(%$rval4, %$var_name, $ivars): (f64, memref<${qvars}f64>, $index_vars) -> ()
"scf.yield"() : () -> ()
}, {
"memref.store"(%zerof, %$var_name, $ivars): (f64, memref<${qvars}f64>, $index_vars) -> ()
"scf.yield"() : () -> ()
}) : (i1) -> ()
"""
val c3 = dimensions
.map(dim => s"""
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
""").mkString("\n")
s"$dimensions_code\n$c0\n$c1\n$c11\n$c2\n$c3"
}
}
}

def MLIR_alloc_and_gen_zero(head: Access, dims: Seq[Dim]): String = {
dims match {
case Nil =>
s"""%${head.name} = "arith.constant"() {value = 0.0 : f64} : () -> f64"""
case _ => {
val var_name = head.name
val vars = head.vars
val iter_seq = vars.map(_.name)
val dimsAndCode: Seq[(String, String)] =
dims.map(d => MLIR_convert_index(d, 1))
val dimensions = dimsAndCode.map(_._1)
val dimensions_code = dimsAndCode.map(_._2).mkString("\n")
val c0 =
s"%$var_name = memref.alloc(${dimensions.mkString(", ")}) : memref<${"?x" * dimensions.length}f64>"

val c1 = dimensions
.zip(iter_seq)
.map {
case (dim, i) => {
s"""
"scf.for"(%const_val_0, $dim, %const_val_1) ({
^bb0(%$i: index):
"""
}
}
.mkString("\n")
val ivars = iter_seq.map(e => s"%$e").mkString(", ")
val qvars = dimensions.map(e => s"?").mkString("x") + "x"
val index_vars = dimensions.map(e => s"index").mkString(", ")

val c2 = s"""
"memref.store"(%zerof, %$var_name, $ivars): (f64, memref<${qvars}f64>, $index_vars) -> ()
"""

val c3 = dimensions
.map(dim => s"""
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
""").mkString("\n")
s"$dimensions_code\n$c0\n$c1\n$c2\n$c3"
}
}
}

def CPP_printerr(access: Access): String = access.vars.isEmpty match {
case true => s"cerr << ${access.name} << endl;"
case false =>
Expand All @@ -519,14 +263,6 @@ int main(int argc, char **argv){
.mkString("][")}]);"""
}

def MLIR_printerr(access: Access): String = {
val last = getVar("%last")
s"""
$last = "memref.load"(%${access.name}${", %const_val_0" * access.vars.length}) : (memref<${"?x" * access.vars.length}f64>${", index" * access.vars.length}) -> f64
"func.call"($last) {callee = @print_f64_cerr} : (f64) -> ()
"""
}

def CPP_free(var_name: String, dims: Seq[Dim]) = {
val dimensions = dims.map(C_convert_index(_))
val c0 = dimensions.init.zipWithIndex
Expand All @@ -544,31 +280,17 @@ $last = "memref.load"(%${access.name}${", %const_val_0" * access.vars.length}) :
s"$c0$c1$c2"
}

def MLIR_free(var_name: String, dims: Seq[Dim]) = {
val dim_str = "?x" * dims.length
s""""memref.dealloc"(%$var_name) : (memref<${dim_str}f64>) -> ()"""
}

def MLIR_return(): String = """"func.return"() : () -> ()
}) {function_type = (i32, !llvm.ptr) -> (), sym_name = "main", sym_visibility = "private"} : () -> ()
}) : () -> ()
"""

def init_code(lang: String): String = lang.toUpperCase() match {
case "C" => C_init_code()
case "CPP" => CPP_init_code()
case "MLIR" => MLIR_init_code()
case "SNIPPETS" => MLIR_init_code()
case _ => throw new Exception("Unknown code language")
case "C" => C_init_code()
case "CPP" => CPP_init_code()
case _ => throw new Exception("Unknown code language")
}

def read_argv(lang: String, argv_names: Seq[String]): String =
lang.toUpperCase() match {
case "C" => C_read_argv(argv_names)
case "CPP" => CPP_read_argv(argv_names)
case "MLIR" => MLIR_read_argv(argv_names)
case "SNIPPETS" => MLIR_read_argv(argv_names)
case _ => throw new Exception("Unknown code language")
case "C" => C_read_argv(argv_names)
case "CPP" => CPP_read_argv(argv_names)
case _ => throw new Exception("Unknown code language")
}

def alloc_and_gen_random_number(
Expand All @@ -577,45 +299,35 @@ $last = "memref.load"(%${access.name}${", %const_val_0" * access.vars.length}) :
dims: Seq[Dim],
sopCond: SoP
): String = lang.toUpperCase() match {
case "C" => C_alloc_and_gen_random_number(head, dims, sopCond)
case "CPP" => CPP_alloc_and_gen_random_number(head, dims, sopCond)
case "MLIR" => MLIR_alloc_and_gen_random_number(head, dims, sopCond)
case "SNIPPETS" => MLIR_alloc_and_gen_random_number(head, dims, sopCond)
case _ => throw new Exception("Unknown code language")
case "C" => C_alloc_and_gen_random_number(head, dims, sopCond)
case "CPP" => CPP_alloc_and_gen_random_number(head, dims, sopCond)
case _ => throw new Exception("Unknown code language")
}

def alloc_and_gen_zero(lang: String, head: Access, dims: Seq[Dim]): String =
lang.toUpperCase() match {
case "C" => C_alloc_and_gen_zero(head, dims)
case "CPP" => CPP_alloc_and_gen_zero(head, dims)
case "MLIR" => MLIR_alloc_and_gen_zero(head, dims)
case "SNIPPETS" => MLIR_alloc_and_gen_zero(head, dims)
case _ => throw new Exception("Unknown code language")
case "C" => C_alloc_and_gen_zero(head, dims)
case "CPP" => CPP_alloc_and_gen_zero(head, dims)
case _ => throw new Exception("Unknown code language")
}

def printerr(lang: String, head: Access): String = lang.toUpperCase() match {
case "C" => C_printerr(head)
case "CPP" => CPP_printerr(head)
case "MLIR" => MLIR_printerr(head)
case "SNIPPETS" => MLIR_printerr(head)
case _ => throw new Exception("Unknown code language")
case "C" => C_printerr(head)
case "CPP" => CPP_printerr(head)
case _ => throw new Exception("Unknown code language")
}

def free(lang: String, var_name: String, dims: Seq[Dim]) =
lang.toUpperCase() match {
case "C" => C_free(var_name)
case "CPP" => CPP_free(var_name, dims)
case "MLIR" => MLIR_free(var_name, dims)
case "SNIPPETS" => MLIR_free(var_name, dims)
case _ => throw new Exception("Unknown code language")
case "C" => C_free(var_name)
case "CPP" => CPP_free(var_name, dims)
case _ => throw new Exception("Unknown code language")
}

def return_code(lang: String): String = lang.toUpperCase() match {
case "C" => C_return()
case "CPP" => CPP_return()
case "MLIR" => MLIR_return()
case "SNIPPETS" => MLIR_return()
case _ => throw new Exception("Unknown code language")
case "C" => C_return()
case "CPP" => CPP_return()
case _ => throw new Exception("Unknown code language")
}

}
Loading