args = new ArrayList<>();
+ for (AST.Expr expr: callExpr.args) {
+ boolean indexed = compileExpr(expr);
+ if (indexed)
+ codeIndexedLoad();
+ var arg = top();
+ if (!(arg instanceof Operand.TempRegisterOperand) ) {
+ var origArg = pop();
+ arg = createTemp(origArg.type);
+ code(new Instruction.Move(origArg, arg));
+ }
+ args.add((Operand.RegisterOperand) arg);
+ }
+ // Simulate the actions on the stack
+ for (int i = 0; i < args.size(); i++)
+ pop();
+ Operand.TempRegisterOperand ret = null;
+ if (callExpr.callee.type instanceof Type.TypeFunction tf &&
+ !(tf.returnType instanceof Type.TypeVoid)) {
+ ret = createTemp(tf.returnType);
+ //assert ret.regnum-maxLocalReg == returnStackPos;
+ }
+ code(new Instruction.Call(returnStackPos, ret, calleeType, args.toArray(new Operand.RegisterOperand[args.size()])));
+ return false;
+ }
+
+ private Type.TypeStruct getStructType(Type t) {
+ if (t instanceof Type.TypeStruct typeStruct) {
+ return typeStruct;
+ }
+ else if (t instanceof Type.TypeNullable ptr &&
+ ptr.baseType instanceof Type.TypeStruct typeStruct) {
+ return typeStruct;
+ }
+ else
+ throw new CompilerException("Unexpected type: " + t);
+ }
+
+ private boolean compileFieldExpr(AST.FieldExpr fieldExpr) {
+ Type.TypeStruct typeStruct = getStructType(fieldExpr.object.type);
+ int fieldIndex = typeStruct.getFieldIndex(fieldExpr.fieldName);
+ if (fieldIndex < 0)
+ throw new CompilerException("Field " + fieldExpr.fieldName + " not found");
+ boolean indexed = compileExpr(fieldExpr.object);
+ if (indexed)
+ codeIndexedLoad();
+ pushOperand(new Operand.LoadFieldOperand(pop(), fieldExpr.fieldName, fieldIndex));
+ return true;
+ }
+
+ private boolean compileArrayIndexExpr(AST.ArrayIndexExpr arrayIndexExpr) {
+ compileExpr(arrayIndexExpr.array);
+ boolean indexed = compileExpr(arrayIndexExpr.expr);
+ if (indexed)
+ codeIndexedLoad();
+ Operand index = pop();
+ Operand array = pop();
+ pushOperand(new Operand.LoadIndexedOperand(array, index));
+ return true;
+ }
+
+ private boolean compileSetFieldExpr(AST.SetFieldExpr setFieldExpr) {
+ Type.TypeStruct structType = (Type.TypeStruct) setFieldExpr.objectType;
+ int fieldIndex = structType.getFieldIndex(setFieldExpr.fieldName);
+ if (fieldIndex == -1)
+ throw new CompilerException("Field " + setFieldExpr.fieldName + " not found in struct " + structType.name);
+ pushOperand(new Operand.LoadFieldOperand(top(), setFieldExpr.fieldName, fieldIndex));
+ boolean indexed = compileExpr(setFieldExpr.value);
+ if (indexed)
+ codeIndexedLoad();
+ codeIndexedStore();
+ return false;
+ }
+
+ private void codeNew(Type type) {
+ var temp = createTemp(type);
+ if (type instanceof Type.TypeArray typeArray) {
+ code(new Instruction.NewArray(typeArray, temp));
+ }
+ else if (type instanceof Type.TypeStruct typeStruct) {
+ code(new Instruction.NewStruct(typeStruct, temp));
+ }
+ else
+ throw new CompilerException("Unexpected type: " + type);
+ }
+
+ private void codeStoreAppend() {
+ var operand = pop();
+ code(new Instruction.AStoreAppend((Operand.RegisterOperand) top(), operand));
+ }
+
+ private boolean compileNewExpr(AST.NewExpr newExpr) {
+ codeNew(newExpr.type);
+ if (newExpr.initExprList != null && !newExpr.initExprList.isEmpty()) {
+ if (newExpr.type instanceof Type.TypeArray) {
+ for (AST.Expr expr : newExpr.initExprList) {
+ // Maybe have specific AST similar to how we have SetFieldExpr?
+ boolean indexed = compileExpr(expr);
+ if (indexed)
+ codeIndexedLoad();
+ codeStoreAppend();
+ }
+ }
+ else if (newExpr.type instanceof Type.TypeStruct) {
+ for (AST.Expr expr : newExpr.initExprList) {
+ compileExpr(expr);
+ }
+ }
+ }
+ return false;
+ }
+
+ private boolean compileSymbolExpr(AST.NameExpr symbolExpr) {
+ if (symbolExpr.type instanceof Type.TypeFunction functionType)
+ pushOperand(new Operand.LocalFunctionOperand(functionType));
+ else {
+ Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol;
+ pushLocal(varSymbol.reg);
+ }
+ return false;
+ }
+
+ private boolean compileBinaryExpr(AST.BinaryExpr binaryExpr) {
+ String opCode = null;
+ boolean indexed = compileExpr(binaryExpr.expr1);
+ if (indexed)
+ codeIndexedLoad();
+ indexed = compileExpr(binaryExpr.expr2);
+ if (indexed)
+ codeIndexedLoad();
+ opCode = binaryExpr.op.str;
+ Operand right = pop();
+ Operand left = pop();
+ if (left instanceof Operand.ConstantOperand leftconstant &&
+ right instanceof Operand.ConstantOperand rightconstant) {
+ long value = 0;
+ switch (opCode) {
+ case "+": value = leftconstant.value + rightconstant.value; break;
+ case "-": value = leftconstant.value - rightconstant.value; break;
+ case "*": value = leftconstant.value * rightconstant.value; break;
+ case "/": value = leftconstant.value / rightconstant.value; break;
+ case "%": value = leftconstant.value % rightconstant.value; break;
+ case "==": value = leftconstant.value == rightconstant.value ? 1 : 0; break;
+ case "!=": value = leftconstant.value != rightconstant.value ? 1 : 0; break;
+ case "<": value = leftconstant.value < rightconstant.value ? 1: 0; break;
+ case ">": value = leftconstant.value > rightconstant.value ? 1 : 0; break;
+ case "<=": value = leftconstant.value <= rightconstant.value ? 1 : 0; break;
+ case ">=": value = leftconstant.value <= rightconstant.value ? 1 : 0; break;
+ default: throw new CompilerException("Invalid binary op");
+ }
+ pushConstant(value, leftconstant.type);
+ }
+ else {
+ var temp = createTemp(binaryExpr.type);
+ code(new Instruction.Binary(opCode, temp, left, right));
+ }
+ return false;
+ }
+
+ private boolean compileUnaryExpr(AST.UnaryExpr unaryExpr) {
+ String opCode;
+ boolean indexed = compileExpr(unaryExpr.expr);
+ if (indexed)
+ codeIndexedLoad();
+ opCode = unaryExpr.op.str;
+ Operand top = pop();
+ if (top instanceof Operand.ConstantOperand constant) {
+ switch (opCode) {
+ case "-": pushConstant(-constant.value, constant.type); break;
+ // Maybe below we should explicitly set Int
+ case "!": pushConstant(constant.value == 0?1:0, constant.type); break;
+ default: throw new CompilerException("Invalid unary op");
+ }
+ }
+ else {
+ var temp = createTemp(unaryExpr.type);
+ code(new Instruction.Unary(opCode, temp, top));
+ }
+ return false;
+ }
+
+ private boolean compileConstantExpr(AST.LiteralExpr constantExpr) {
+ pushConstant(constantExpr.value.num.intValue(), constantExpr.type);
+ return false;
+ }
+
+ private void pushConstant(long value, Type type) {
+ pushOperand(new Operand.ConstantOperand(value, type));
+ }
+
+ private Operand.TempRegisterOperand createTemp(Type type) {
+ var offset = virtualStack.size()+maxLocalReg;
+ var id = nextReg++;
+ var name = "%t" + id;
+ var tempRegister = new Operand.TempRegisterOperand(offset, id, name, type);
+ pushOperand(tempRegister);
+ if (maxStackSize < virtualStack.size())
+ maxStackSize = virtualStack.size();
+ return tempRegister;
+ }
+
+ private void pushLocal(Register reg) {
+ pushOperand(new Operand.LocalRegisterOperand(reg));
+ }
+
+ private void pushOperand(Operand operand) {
+ virtualStack.add(operand);
+ }
+
+ private Operand pop() {
+ return virtualStack.removeLast();
+ }
+
+ private Operand top() {
+ return virtualStack.getLast();
+ }
+
+ private void codeIndexedLoad() {
+ Operand indexed = pop();
+ var temp = createTemp(indexed.type);
+ if (indexed instanceof Operand.LoadIndexedOperand loadIndexedOperand) {
+ code(new Instruction.ArrayLoad(loadIndexedOperand, temp));
+ }
+ else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) {
+ code(new Instruction.GetField(loadFieldOperand, temp));
+ }
+ else
+ code(new Instruction.Move(indexed, temp));
+ }
+
+ private void codeIndexedStore() {
+ Operand value = pop();
+ Operand indexed = pop();
+ if (indexed instanceof Operand.LoadIndexedOperand loadIndexedOperand) {
+ code(new Instruction.ArrayStore(value, loadIndexedOperand));
+ }
+ else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) {
+ code(new Instruction.SetField(value, loadFieldOperand));
+ }
+ else
+ code(new Instruction.Move(value, indexed));
+ }
+
+ private boolean vstackEmpty() {
+ return virtualStack.isEmpty();
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/DominatorTree.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/DominatorTree.java
new file mode 100644
index 0000000..ec13103
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/DominatorTree.java
@@ -0,0 +1,238 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.function.Consumer;
+
+/**
+ * The dominator tree construction algorithm is based on figure 9.24,
+ * chapter 9, p 532, of Engineering a Compiler.
+ *
+ * The algorithm is also described in the paper 'A Simple, Fast
+ * Dominance Algorithm' by Keith D. Cooper, Timothy J. Harvey and
+ * Ken Kennedy.
+ */
+public class DominatorTree {
+ BasicBlock entry;
+ // List of basic blocks reachable from _entry block, including the _entry
+ List blocks;
+
+ int preorder;
+ int rpostorder;
+
+ /**
+ * Builds a Dominator Tree.
+ *
+ * @param entry The entry block
+ */
+ public DominatorTree(BasicBlock entry) {
+ this.entry = entry;
+ blocks = findAllBlocks(entry);
+ calculateDominatorTree();
+ populateTree();
+ setDepth();
+ calculateDominanceFrontiers();
+ }
+
+ /**
+ * Utility to locate all the basic blocks, order does not matter.
+ */
+ private static List findAllBlocks(BasicBlock root) {
+ List nodes = new ArrayList<>();
+ postOrderWalk(root, (n) -> nodes.add(n), new HashSet<>());
+ return nodes;
+ }
+
+ static void postOrderWalk(BasicBlock n, Consumer consumer, HashSet visited) {
+ visited.add(n);
+ /* For each successor node */
+ for (BasicBlock s : n.successors) {
+ if (!visited.contains(s))
+ postOrderWalk(s, consumer, visited);
+ }
+ consumer.accept(n);
+ }
+
+ private void calculateDominatorTree() {
+ resetDomInfo();
+ annotateBlocksWithRPO();
+ sortBlocksByRPO();
+
+ // Set IDom entry for root to itself
+ entry.idom = entry;
+ boolean changed = true;
+ while (changed) {
+ changed = false;
+ // for all nodes, b, in reverse postorder (except root)
+ for (BasicBlock bb : blocks) {
+ if (bb == entry) // skip root
+ continue;
+ // NewIDom = first (processed) predecessor of b, pick one
+ BasicBlock firstPred = findFirstPredecessorWithIdom(bb);
+ assert firstPred != null;
+ BasicBlock newIDom = firstPred;
+ // for all other predecessors, p, of b
+ for (BasicBlock predecessor : bb.predecessors) {
+ if (predecessor == firstPred) continue; // all other predecessors
+ if (predecessor.idom != null) {
+ // i.e. IDoms[p] calculated
+ newIDom = intersect(predecessor, newIDom);
+ }
+ }
+ if (bb.idom != newIDom) {
+ bb.idom = newIDom;
+ changed = true;
+ }
+ }
+ }
+ }
+
+ private void resetDomInfo() {
+ for (BasicBlock bb : blocks)
+ bb.resetDomInfo();
+ }
+
+ /**
+ * Assign rpo number to all the basic blocks.
+ * The rpo number defines the Reverse Post Order traversal of blocks.
+ * The Dominance calculator requires the rpo number.
+ */
+ private void annotateBlocksWithRPO() {
+ preorder = 1;
+ rpostorder = blocks.size();
+ for (BasicBlock n : blocks) n.resetRPO();
+ postOrderWalkSetRPO(entry);
+ }
+
+ // compute rpo using a depth first search
+ private void postOrderWalkSetRPO(BasicBlock n) {
+ n.pre = preorder++;
+ for (BasicBlock s : n.successors) {
+ if (s.pre == 0)
+ postOrderWalkSetRPO(s);
+ }
+ n.rpo = rpostorder--;
+ }
+
+ /**
+ * Reverse post order gives a topological sort order
+ */
+ private void sortBlocksByRPO() {
+ blocks.sort(Comparator.comparingInt(n -> n.rpo));
+ }
+
+ /**
+ * Finds nearest common ancestor
+ *
+ * The algorithm starts at the two nodes whose sets are being intersected, and walks
+ * upward from each toward the root. By comparing the nodes with their RPO numbers
+ * the algorithm finds the common ancestor - the immediate dominator of i and j.
+ */
+ private BasicBlock intersect(BasicBlock i, BasicBlock j) {
+ BasicBlock finger1 = i;
+ BasicBlock finger2 = j;
+ while (finger1 != finger2) {
+ while (finger1.rpo > finger2.rpo) {
+ finger1 = finger1.idom;
+ assert finger1 != null;
+ }
+ while (finger2.rpo > finger1.rpo) {
+ finger2 = finger2.idom;
+ assert finger2 != null;
+ }
+ }
+ return finger1;
+ }
+
+ /**
+ * Look for the first predecessor whose immediate dominator has been calculated.
+ * Because of the order in which this search occurs, we will always find at least 1
+ * such predecessor.
+ */
+ private BasicBlock findFirstPredecessorWithIdom(BasicBlock n) {
+ for (BasicBlock p : n.predecessors) {
+ if (p.idom != null) return p;
+ }
+ return null;
+ }
+
+ /**
+ * Setup the dominator tree.
+ * Each block gets the list of blocks it strictly dominates.
+ */
+ private void populateTree() {
+ for (BasicBlock block : blocks) {
+ BasicBlock idom = block.idom;
+ if (idom == block) // root
+ continue;
+ // add edge from idom to n
+ idom.dominatedChildren.add(block);
+ }
+ }
+
+ /**
+ * Sets the dominator depth on each block
+ */
+ private void setDepth() {
+ entry.domDepth = 1;
+ setDepth_(entry);
+ }
+
+ /**
+ * Sets the dominator depth on each block
+ */
+ private void setDepth_(BasicBlock block) {
+ BasicBlock idom = block.idom;
+ if (idom != block) {
+ assert idom.domDepth > 0;
+ block.domDepth = idom.domDepth + 1;
+ } else {
+ assert idom.domDepth == 1;
+ assert idom.domDepth == block.domDepth;
+ }
+ for (BasicBlock child : block.dominatedChildren)
+ setDepth_(child);
+ }
+
+ /**
+ * Calculates dominance-frontiers for nodes
+ */
+ private void calculateDominanceFrontiers() {
+ // Dominance-Frontier Algorithm - fig 5 in 'A Simple, Fast Dominance Algorithm'
+ //for all nodes, b
+ // if the number of predecessors of b ≥ 2
+ // for all predecessors, p, of b
+ // runner ← p
+ // while runner != doms[b]
+ // add b to runner’s dominance frontier set
+ // runner = doms[runner]
+ for (BasicBlock b : blocks) {
+ if (b.predecessors.size() >= 2) {
+ for (BasicBlock p : b.predecessors) {
+ BasicBlock runner = p;
+ while (runner != b.idom) {
+ runner.dominationFrontier.add(b);
+ runner = runner.idom;
+ }
+ }
+ }
+ }
+ }
+
+ public String generateDotOutput() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("digraph DomTree {\n");
+ for (BasicBlock n : blocks) {
+ sb.append(n.uniqueName()).append(" [label=\"").append(n.label()).append("\"];\n");
+ }
+ for (BasicBlock n : blocks) {
+ BasicBlock idom = n.idom;
+ if (idom == n) continue;
+ sb.append(idom.uniqueName()).append("->").append(n.uniqueName()).append(";\n");
+ }
+ sb.append("}\n");
+ return sb.toString();
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java
new file mode 100644
index 0000000..26a4326
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java
@@ -0,0 +1,685 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import com.compilerprogramming.ezlang.types.Register;
+import com.compilerprogramming.ezlang.types.Type;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+public abstract class Instruction {
+
+ public boolean isTerminal() {
+ return false;
+ }
+ @Override
+ public String toString() {
+ return toStr(new StringBuilder()).toString();
+ }
+
+ public boolean definesVar() { return false; }
+ public Register def() { return null; }
+
+ public boolean usesVars() { return false; }
+ public List uses() { return Collections.emptyList(); }
+ public void replaceDef(Register newReg) {}
+ public void replaceUses(Register[] newUses) {}
+
+ public static class Move extends Instruction {
+ public final Operand from;
+ public final Operand to;
+ public Move(Operand from, Operand to) {
+ this.from = from;
+ this.to = to;
+ }
+
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+
+ @Override
+ public Register def() {
+ if (to instanceof Operand.RegisterOperand registerOperand)
+ return registerOperand.reg;
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ this.to.replaceRegister(newReg);
+ }
+
+ @Override
+ public boolean usesVars() {
+ return (from instanceof Operand.RegisterOperand);
+ }
+
+ @Override
+ public List uses() {
+ if (from instanceof Operand.RegisterOperand registerOperand)
+ return List.of(registerOperand.reg);
+ return super.uses();
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ from.replaceRegister(newUses[0]);
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(to).append(" = ").append(from);
+ }
+ }
+
+ public static class NewArray extends Instruction {
+ public final Type.TypeArray type;
+ public final Operand.RegisterOperand destOperand;
+ public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand) {
+ this.type = type;
+ this.destOperand = destOperand;
+ }
+
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+
+ @Override
+ public Register def() {
+ return destOperand.reg;
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ destOperand.replaceRegister(newReg);
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(destOperand)
+ .append(" = ")
+ .append("New(")
+ .append(type)
+ .append(")");
+ }
+ }
+
+ public static class NewStruct extends Instruction {
+ public final Type.TypeStruct type;
+ public final Operand.RegisterOperand destOperand;
+ public NewStruct(Type.TypeStruct type, Operand.RegisterOperand destOperand) {
+ this.type = type;
+ this.destOperand = destOperand;
+ }
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+
+ @Override
+ public Register def() {
+ return destOperand.reg;
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ destOperand.replaceRegister(newReg);
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(destOperand)
+ .append(" = ")
+ .append("New(")
+ .append(type)
+ .append(")");
+ }
+ }
+
+ public static class ArrayLoad extends Instruction {
+ public final Operand arrayOperand;
+ public final Operand indexOperand;
+ public final Operand.RegisterOperand destOperand;
+ public ArrayLoad(Operand.LoadIndexedOperand from, Operand.RegisterOperand to) {
+ arrayOperand = from.arrayOperand;
+ indexOperand = from.indexOperand;
+ destOperand = to;
+ }
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+ @Override
+ public Register def() {
+ return destOperand.reg;
+ }
+ @Override
+ public void replaceDef(Register newReg) {
+ destOperand.replaceRegister(newReg);
+ }
+
+ @Override
+ public boolean usesVars() {
+ return true;
+ }
+
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ if (arrayOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ if (indexOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ return usesList;
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ int i = 0;
+ if (arrayOperand instanceof Operand.RegisterOperand) {
+ arrayOperand.replaceRegister(newUses[i++]);
+ }
+ if (indexOperand instanceof Operand.RegisterOperand) {
+ indexOperand.replaceRegister(newUses[i]);
+ }
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(destOperand)
+ .append(" = ")
+ .append(arrayOperand)
+ .append("[")
+ .append(indexOperand)
+ .append("]");
+ }
+ }
+
+ public static class ArrayStore extends Instruction {
+ public final Operand arrayOperand;
+ public final Operand indexOperand;
+ public final Operand sourceOperand;
+ public ArrayStore(Operand from, Operand.LoadIndexedOperand to) {
+ arrayOperand = to.arrayOperand;
+ indexOperand = to.indexOperand;
+ sourceOperand = from;
+ }
+ @Override
+ public boolean usesVars() {
+ return true;
+ }
+
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ if (arrayOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ if (indexOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ if (sourceOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ return usesList;
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ int i = 0;
+ if (arrayOperand instanceof Operand.RegisterOperand) {
+ arrayOperand.replaceRegister(newUses[i++]);
+ }
+ if (indexOperand instanceof Operand.RegisterOperand) {
+ indexOperand.replaceRegister(newUses[i++]);
+ }
+ if (sourceOperand instanceof Operand.RegisterOperand) {
+ sourceOperand.replaceRegister(newUses[i]);
+ }
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb
+ .append(arrayOperand)
+ .append("[")
+ .append(indexOperand)
+ .append("] = ")
+ .append(sourceOperand);
+ }
+ }
+
+ public static class GetField extends Instruction {
+ public final Operand structOperand;
+ public final String fieldName;
+ public final int fieldIndex;
+ public final Operand.RegisterOperand destOperand;
+ public GetField(Operand.LoadFieldOperand from, Operand.RegisterOperand to)
+ {
+ this.structOperand = from.structOperand;
+ this.fieldName = from.fieldName;
+ this.fieldIndex = from.fieldIndex;
+ this.destOperand = to;
+ }
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+ @Override
+ public Register def() {
+ return destOperand.reg;
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ destOperand.replaceRegister(newReg);
+ }
+
+ @Override
+ public boolean usesVars() {
+ return true;
+ }
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ if (structOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ return usesList;
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ if (newUses.length > 0)
+ structOperand.replaceRegister(newUses[0]);
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(destOperand)
+ .append(" = ")
+ .append(structOperand)
+ .append(".")
+ .append(fieldName);
+ }
+ }
+
+ public static class SetField extends Instruction {
+ public final Operand structOperand;
+ public final String fieldName;
+ public final int fieldIndex;
+ public final Operand sourceOperand;
+ public SetField(Operand from,Operand.LoadFieldOperand to)
+ {
+ this.structOperand = to.structOperand;
+ this.fieldName = to.fieldName;
+ this.fieldIndex = to.fieldIndex;
+ this.sourceOperand = from;
+ }
+ @Override
+ public boolean usesVars() {
+ return true;
+ }
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ if (structOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ if (sourceOperand instanceof Operand.RegisterOperand registerOperand)
+ usesList.add(registerOperand.reg);
+ return usesList;
+ }
+ @Override
+ public void replaceUses(Register[] newUses) {
+ int i = 0;
+ if (structOperand instanceof Operand.RegisterOperand) {
+ structOperand.replaceRegister(newUses[i++]);
+ }
+ if (sourceOperand instanceof Operand.RegisterOperand) {
+ sourceOperand.replaceRegister(newUses[i]);
+ }
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb
+ .append(structOperand)
+ .append(".")
+ .append(fieldName)
+ .append(" = ")
+ .append(sourceOperand);
+ }
+ }
+ public static class Return extends Move {
+ public Return(Operand from, int slot, int regNum, Type type) {
+ super(from, new Operand.ReturnRegisterOperand(slot, regNum, "%ret", type));
+ }
+ }
+ public static class Unary extends Instruction {
+ public final String unop;
+ public final Operand.RegisterOperand result;
+ public final Operand operand;
+ public Unary(String unop, Operand.RegisterOperand result, Operand operand) {
+ this.unop = unop;
+ this.result = result;
+ this.operand = operand;
+ }
+
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+
+ @Override
+ public Register def() {
+ return result.reg;
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ result.replaceRegister(newReg);
+ }
+
+ @Override
+ public boolean usesVars() {
+ return operand instanceof Operand.RegisterOperand registerOperand;
+ }
+
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ if (operand instanceof Operand.RegisterOperand registerOperand) {
+ usesList.add(registerOperand.reg);
+ }
+ return usesList;
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ if (newUses.length > 0)
+ operand.replaceRegister(newUses[0]);
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(result).append(" = ").append(unop).append(operand);
+ }
+ }
+
+ public static class Binary extends Instruction {
+ public final String binOp;
+ public final Operand.RegisterOperand result;
+ public final Operand left;
+ public final Operand right;
+ public Binary(String binop, Operand.RegisterOperand result, Operand left, Operand right) {
+ this.binOp = binop;
+ this.result = result;
+ this.left = left;
+ this.right = right;
+ }
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+
+ @Override
+ public Register def() {
+ return result.reg;
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ result.replaceRegister(newReg);
+ }
+
+ @Override
+ public boolean usesVars() {
+ return left instanceof Operand.RegisterOperand ||
+ right instanceof Operand.RegisterOperand;
+ }
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ if (left instanceof Operand.RegisterOperand registerOperand) {
+ usesList.add(registerOperand.reg);
+ }
+ if (right instanceof Operand.RegisterOperand registerOperand) {
+ usesList.add(registerOperand.reg);
+ }
+ return usesList;
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ int i = 0;
+ if (left instanceof Operand.RegisterOperand) {
+ left.replaceRegister(newUses[i++]);
+ }
+ if (right instanceof Operand.RegisterOperand) {
+ right.replaceRegister(newUses[i]);
+ }
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(result).append(" = ").append(left).append(binOp).append(right);
+ }
+ }
+
+ public static class AStoreAppend extends Instruction {
+ public final Operand.RegisterOperand array;
+ public final Operand value;
+ public AStoreAppend(Operand.RegisterOperand array, Operand value) {
+ this.array = array;
+ this.value = value;
+ }
+ @Override
+ public boolean usesVars() {
+ return true;
+ }
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ usesList.add(array.reg);
+ if (value instanceof Operand.RegisterOperand registerOperand) {
+ usesList.add(registerOperand.reg);
+ }
+ return usesList;
+ }
+ @Override
+ public void replaceUses(Register[] newUses) {
+ array.replaceRegister(newUses[0]);
+ if (value instanceof Operand.RegisterOperand)
+ value.replaceRegister(newUses[1]);
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append(array).append(".append(").append(value).append(")");
+ }
+ }
+
+ public static class ConditionalBranch extends Instruction {
+ public final Operand condition;
+ public final BasicBlock trueBlock;
+ public final BasicBlock falseBlock;
+ public ConditionalBranch(BasicBlock currentBlock, Operand condition, BasicBlock trueBlock, BasicBlock falseBlock) {
+ this.condition = condition;
+ this.trueBlock = trueBlock;
+ this.falseBlock = falseBlock;
+ currentBlock.addSuccessor(trueBlock);
+ currentBlock.addSuccessor(falseBlock);
+ }
+ @Override
+ public boolean usesVars() {
+ return condition instanceof Operand.RegisterOperand;
+ }
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ if (condition instanceof Operand.RegisterOperand registerOperand) {
+ usesList.add(registerOperand.reg);
+ }
+ return usesList;
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ if (condition instanceof Operand.RegisterOperand) {
+ condition.replaceRegister(newUses[0]);
+ }
+ }
+
+ @Override
+ public boolean isTerminal() {
+ return true;
+ }
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append("if ").append(condition).append(" goto L").append(trueBlock.bid).append(" else goto L").append(falseBlock.bid);
+ }
+ }
+
+ public static class Call extends Instruction {
+ public final Type.TypeFunction callee;
+ public final Operand.RegisterOperand[] args;
+ public final Operand.RegisterOperand returnOperand;
+ public final int newbase;
+ public Call(int newbase, Operand.RegisterOperand returnOperand, Type.TypeFunction callee, Operand.RegisterOperand... args) {
+ this.returnOperand = returnOperand;
+ this.callee = callee;
+ this.args = args;
+ this.newbase = newbase;
+ }
+
+ @Override
+ public boolean definesVar() {
+ return returnOperand != null;
+ }
+
+ @Override
+ public Register def() {
+ return returnOperand != null ? returnOperand.reg : null;
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ if (returnOperand != null)
+ returnOperand.replaceRegister(newReg);
+ }
+
+ @Override
+ public boolean usesVars() {
+ return args != null && args.length > 0;
+ }
+ @Override
+ public List uses() {
+ List usesList = new ArrayList<>();
+ for (Operand.RegisterOperand argOperand : args) {
+ usesList.add(argOperand.reg);
+ }
+ return usesList;
+ }
+
+ @Override
+ public void replaceUses(Register[] newUses) {
+ if (args == null)
+ return;
+ for (int i = 0; i < args.length; i++) {
+ args[i].replaceRegister(newUses[i]);
+ }
+ }
+
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ if (returnOperand != null) {
+ sb.append(returnOperand).append(" = ");
+ }
+ sb.append("call ").append(callee);
+ if (args.length > 0)
+ sb.append(" params ");
+ for (int i = 0; i < args.length; i++) {
+ if (i > 0) sb.append(", ");
+ sb.append(args[i]);
+ }
+ return sb;
+ }
+ }
+
+ public static class Jump extends Instruction {
+ public final BasicBlock jumpTo;
+ public Jump(BasicBlock jumpTo) {
+ this.jumpTo = jumpTo;
+ }
+ @Override
+ public boolean isTerminal() {
+ return true;
+ }
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append("goto ").append(" L").append(jumpTo.bid);
+ }
+ }
+
+ public static class Phi extends Instruction {
+ public Operand.RegisterOperand dest;
+ public final List inputs = new ArrayList<>();
+ public Phi(Register dest, List inputs) {
+ this.dest = new Operand.RegisterOperand(dest);
+ for (Register input : inputs) {
+ this.inputs.add(new Operand.RegisterOperand(input));
+ }
+ }
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+ @Override
+ public Register def() {
+ return dest.reg;
+ }
+ @Override
+ public void replaceDef(Register newReg) {
+ dest = new Operand.RegisterOperand(newReg);
+ }
+ public void replaceInput(int i, Register newReg) {
+ inputs.set(i, new Operand.RegisterOperand(newReg));
+ }
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ sb.append(dest).append(" = phi(");
+ for (int i = 0; i < inputs.size(); i++) {
+ if (i > 0) sb.append(", ");
+ sb.append(inputs.get(i));
+ }
+ sb.append(")");
+ return sb;
+ }
+ }
+
+ public static class ArgInstruction extends Instruction {
+ Operand.RegisterOperand arg;
+
+ @Override
+ public boolean definesVar() {
+ return true;
+ }
+ @Override
+ public Register def() {
+ return arg.reg;
+ }
+
+ @Override
+ public void replaceDef(Register newReg) {
+ arg.replaceRegister(newReg);
+ }
+
+ public ArgInstruction(Operand.RegisterOperand arg) {
+ this.arg = arg;
+ }
+ @Override
+ public StringBuilder toStr(StringBuilder sb) {
+ return sb.append("arg ").append(arg);
+ }
+ }
+
+ public abstract StringBuilder toStr(StringBuilder sb);
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopFinder.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopFinder.java
new file mode 100644
index 0000000..313bde4
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopFinder.java
@@ -0,0 +1,87 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+public class LoopFinder {
+ // Based on Compilers: Principles, Techniques and Tools
+ // p 604
+ // 1986 ed
+ public static LoopNest getNaturalLoop(BasicBlock head, BasicBlock backedge) {
+ Stack stack = new Stack<>();
+ LoopNest loop = new LoopNest(head);
+ loop.insert(backedge, stack);
+ // trace back up from backedge to head
+ while (!stack.isEmpty()) {
+ BasicBlock m = stack.pop();
+ for (BasicBlock pred : m.predecessors) {
+ loop.insert(pred, stack);
+ }
+ }
+ return loop;
+ }
+
+ // Based on Compilers: Principles, Techniques and Tools
+ // p 604
+ // 1986 ed
+ public static List findLoops(List nodes) {
+ List list = new ArrayList<>();
+ for (BasicBlock n : nodes) {
+ for (BasicBlock input : n.predecessors) {
+ if (n.dominates(input)) {
+ list.add(getNaturalLoop(n, input));
+ }
+ }
+ }
+ return list;
+ }
+
+ public static List mergeLoopsWithSameHead(List loopNests) {
+ HashMap map = new HashMap<>();
+ for (LoopNest loopNest : loopNests) {
+ LoopNest sameHead = map.get(loopNest._loopHead.bid);
+ if (sameHead == null) map.put(loopNest._loopHead.bid, loopNest);
+ else sameHead._blocks.addAll(loopNest._blocks);
+ }
+ return map.values().stream().collect(Collectors.toList());
+ }
+
+ public static LoopNest buildLoopTree(List loopNests) {
+ for (LoopNest nest1 : loopNests) {
+ for (LoopNest nest2 : loopNests) {
+ boolean isNested = nest1.contains(nest2);
+ if (isNested) {
+ if (nest2._parent == null) nest2._parent = nest1;
+ else if (nest1._loopHead.domDepth > nest2._parent._loopHead.domDepth) nest2._parent = nest1;
+ }
+ }
+ }
+ LoopNest top = null;
+ for (LoopNest nest : loopNests) {
+ if (nest._parent != null) nest._parent._kids.add(nest);
+ else top = nest;
+ }
+ return top;
+ }
+
+ public static void annotateBasicBlocks(LoopNest loop, Set visited) {
+ if (visited.contains(loop))
+ return;
+ visited.add(loop);
+ for (LoopNest kid: loop._kids) {
+ kid._depth = loop._depth+1;
+ annotateBasicBlocks(kid, visited);
+ }
+ for (BasicBlock block: loop._blocks) {
+ if (block.loop == null)
+ block.loop = loop;
+ }
+ }
+
+ public static void annotateBasicBlocks(LoopNest top) {
+ if (top == null) // No loop
+ return;
+ top._depth = 1;
+ annotateBasicBlocks(top, new HashSet<>());
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopNest.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopNest.java
new file mode 100644
index 0000000..1aeca0a
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopNest.java
@@ -0,0 +1,62 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import java.util.*;
+
+public class LoopNest {
+ /**
+ * Parent loop
+ */
+ public LoopNest _parent;
+ /**
+ * The block that is the head of the loop
+ */
+ public final BasicBlock _loopHead;
+ /**
+ * Blocks that are part of this loop
+ */
+ public final Set _blocks;
+ /**
+ * Children as per Loop Tree
+ */
+ public final List _kids;
+ /**
+ * Loop Tree depth - top has depth 1
+ */
+ public int _depth;
+
+ public LoopNest(BasicBlock loopHead) {
+ _loopHead = loopHead;
+ _blocks = new HashSet<>();
+ _blocks.add(loopHead);
+ _kids = new ArrayList<>();
+ }
+
+ public void insert(BasicBlock m, Stack stack) {
+ if (!_blocks.contains(m)) {
+ _blocks.add(m);
+ stack.push(m);
+ }
+ }
+
+ public boolean contains(LoopNest other) {
+ return this != other && _blocks.containsAll(other._blocks);
+ }
+
+ public String uniqueName() { return "Loop_" + _loopHead.uniqueName(); }
+ public String label() { return "Loop(" + _loopHead.label() + ":" + _depth + ")"; }
+
+ public static String generateDotOutput(List loopNests) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("digraph LoopTree {\n");
+ for (LoopNest n : loopNests) {
+ sb.append(n.uniqueName()).append(" [label=\"").append(n.label()).append("\"];\n");
+ }
+ for (LoopNest n : loopNests) {
+ for (LoopNest c: n._kids) {
+ sb.append(n.uniqueName()).append("->").append(c.uniqueName()).append(";\n");
+ }
+ }
+ sb.append("}\n");
+ return sb.toString();
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java
new file mode 100644
index 0000000..90e481d
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java
@@ -0,0 +1,130 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import com.compilerprogramming.ezlang.types.Register;
+import com.compilerprogramming.ezlang.types.Type;
+
+public class Operand {
+
+ Type type;
+
+ public void replaceRegister(Register register) {}
+
+ public static class ConstantOperand extends Operand {
+ public final long value;
+ public ConstantOperand(long value, Type type) {
+ this.value = value;
+ this.type = type;
+ }
+ @Override
+ public String toString() {
+ return String.valueOf(value);
+ }
+ }
+
+ public static class RegisterOperand extends Operand {
+ public Register reg;
+ public RegisterOperand(int slot, int regnum, String name, Type type, boolean isTemp) {
+ this.reg = new Register(slot, regnum, name, type);
+ }
+ public RegisterOperand(Register reg) {
+ this.reg = reg;
+ }
+ public int slot() { return reg.slot; }
+
+ @Override
+ public void replaceRegister(Register register) {
+ this.reg = register;
+ }
+
+ @Override
+ public String toString() {
+ return reg.name;
+ }
+ }
+
+ public static class LocalRegisterOperand extends RegisterOperand {
+ public LocalRegisterOperand(Register reg) {
+ super(reg);
+ }
+ }
+
+ public static class LocalFunctionOperand extends Operand {
+ public final Type.TypeFunction functionType;
+ public LocalFunctionOperand(Type.TypeFunction functionType) {
+ this.functionType = functionType;
+ }
+ @Override
+ public String toString() {
+ return functionType.toString();
+ }
+ }
+
+ /**
+ * Represents the return register, which is the location where
+ * the caller will expect to see any return value. The VM must map
+ * this to appropriate location.
+ */
+ public static class ReturnRegisterOperand extends RegisterOperand {
+ public ReturnRegisterOperand(int slot, int regnum, String name, Type type) { super(slot, regnum, name, type, true); }
+ @Override
+ public String toString() { return "%ret"; }
+ }
+
+ /**
+ * Represents a temp register, maps to a location on the
+ * virtual stack. Temps start at offset 0, but this is a relative
+ * register number from start of temp area.
+ */
+ public static class TempRegisterOperand extends RegisterOperand {
+ public TempRegisterOperand(int offset, int regnum, String name, Type type) {
+ super(offset, regnum, name, type, true);
+ }
+ }
+
+ public static class IndexedOperand extends Operand {}
+
+ public static class LoadIndexedOperand extends IndexedOperand {
+ public final Operand arrayOperand;
+ public final Operand indexOperand;
+ public LoadIndexedOperand(Operand arrayOperand, Operand indexOperand) {
+ this.arrayOperand = arrayOperand;
+ this.indexOperand = indexOperand;
+ assert !(indexOperand instanceof IndexedOperand) &&
+ !(arrayOperand instanceof IndexedOperand);
+ }
+ @Override
+ public String toString() {
+ return arrayOperand + "[" + indexOperand + "]";
+ }
+ }
+
+ public static class LoadFieldOperand extends IndexedOperand {
+ public final Operand structOperand;
+ public final int fieldIndex;
+ public final String fieldName;
+ public LoadFieldOperand(Operand structOperand, String fieldName, int field) {
+ this.structOperand = structOperand;
+ this.fieldName = fieldName;
+ this.fieldIndex = field;
+ assert !(structOperand instanceof IndexedOperand);
+ }
+
+ @Override
+ public String toString() {
+ return structOperand + "." + fieldName;
+ }
+ }
+
+ public static class NewTypeOperand extends Operand {
+ public final Type type;
+ public NewTypeOperand(Type type) {
+ this.type = type;
+ }
+
+ @Override
+ public String toString() {
+ return "New(" + type + ")";
+ }
+ }
+
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/SSATransform.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/SSATransform.java
new file mode 100644
index 0000000..3ff81cf
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/SSATransform.java
@@ -0,0 +1,228 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import com.compilerprogramming.ezlang.types.Register;
+
+import java.util.*;
+
+/**
+ * Transform a bytecode function to SSA form
+ */
+public class SSATransform {
+
+ BytecodeFunction function;
+ DominatorTree domTree;
+ Register[] globals;
+ BBSet[] blockSets;
+ List blocks;
+ int[] counters;
+ VersionStack[] stacks;
+
+ public SSATransform(BytecodeFunction bytecodeFunction) {
+ this.function = bytecodeFunction;
+ setupGlobals();
+ computeDomTreeAndDominanceFrontiers();
+ this.blocks = domTree.blocks;
+ findGlobalVars();
+ insertPhis();
+ renameVars();
+ }
+
+ private void computeDomTreeAndDominanceFrontiers() {
+ domTree = new DominatorTree(function.entry);
+ }
+
+ private void setupGlobals() {
+ // fixme this should really just look at locals I think
+ globals = new Register[function.nextReg];
+ blockSets = new BBSet[function.nextReg];
+ }
+
+ /**
+ * Compute set of registers that are live across multiple blocks
+ * i.e. are not exclusively used in a single block.
+ */
+ private void findGlobalVars() {
+ for (BasicBlock block : blocks) {
+ var varKill = new HashSet();
+ for (Instruction instruction: block.instructions) {
+ if (instruction.usesVars()) {
+ for (Register reg : instruction.uses()) {
+ if (!varKill.contains(reg.id)) {
+ globals[reg.id] = reg;
+ }
+ }
+ }
+ if (instruction.definesVar()) {
+ Register reg = instruction.def();
+ varKill.add(reg.id);
+ if (blockSets[reg.id] == null) {
+ blockSets[reg.id] = new BBSet();
+ }
+ blockSets[reg.id].add(block);
+ }
+ }
+ }
+ }
+
+ void insertPhis() {
+ for (int i = 0; i < globals.length; i++) {
+ Register x = globals[i];
+ if (x != null) {
+ var visited = new BitSet();
+ var worklist = new WorkList(blockSets[x.id].blocks);
+ var b = worklist.pop();
+ while (b != null) {
+ visited.set(b.bid);
+ for (BasicBlock d: b.dominationFrontier) {
+ // insert phi for x in d
+ d.insertPhiFor(x);
+ if (!visited.get(d.bid))
+ worklist.push(d);
+ }
+ b = worklist.pop();
+ }
+ }
+ }
+ }
+
+ void renameVars() {
+ initVersionCounters();
+ search(function.entry);
+ }
+
+ /**
+ * Creates and pushes new name
+ */
+ Register makeVersion(Register reg) {
+ int version = counters[reg.id];
+ if (version != reg.ssaVersion)
+ reg = reg.cloneWithVersion(version);
+ stacks[reg.id].push(reg);
+ counters[reg.id] = counters[reg.id] + 1;
+ return reg;
+ }
+
+ /**
+ * Recursively walk the Dominator Tree, renaming variables.
+ * Implementation is based on the algorithm in the Preston Briggs
+ * paper Practical Improvements to the Construction and Destruction of
+ * Static Single Assignment Form
+ */
+ void search(BasicBlock block) {
+ // Replace v = phi(...) with v_i = phi(...)
+ for (Instruction.Phi phi: block.phis()) {
+ Register ssaReg = makeVersion(phi.def());
+ phi.replaceDef(ssaReg);
+ }
+ // for each instruction v = x op y
+ // first replace x,y
+ // then replace v
+ for (Instruction instruction: block.instructions) {
+ if (instruction instanceof Instruction.Phi)
+ continue;
+ // first replace x,y
+ if (instruction.usesVars()) {
+ var uses = instruction.uses();
+ Register[] newUses = new Register[uses.size()];
+ for (int i = 0; i < newUses.length; i++) {
+ Register oldReg = uses.get(i);
+ newUses[i] = stacks[oldReg.id].top();
+ instruction.replaceUses(newUses);
+ }
+ }
+ // then replace v
+ if (instruction.definesVar()) {
+ Register ssaReg = makeVersion(instruction.def());
+ instruction.replaceDef(ssaReg);
+ }
+ }
+ // Update phis in successor blocks
+ for (BasicBlock s: block.successors) {
+ int j = whichPred(s,block);
+ for (Instruction.Phi phi: s.phis()) {
+ Register oldReg = phi.inputs.get(j).reg;
+ phi.replaceInput(j, stacks[oldReg.id].top());
+ }
+ }
+ // Recurse down the dominator tree
+ for (BasicBlock c: block.dominatedChildren) {
+ search(c);
+ }
+ // Pop stacks for defs
+ for (Instruction i: block.instructions) {
+ if (i.definesVar()) {
+ var reg = i.def();
+ stacks[reg.id].pop();
+ }
+ }
+ }
+
+ private int whichPred(BasicBlock s, BasicBlock block) {
+ int i = 0;
+ for (BasicBlock p: s.predecessors) {
+ if (p == block)
+ return i;
+ i++;
+ }
+ throw new IllegalStateException();
+ }
+
+ private void initVersionCounters() {
+ counters = new int[globals.length];
+ stacks = new VersionStack[globals.length];
+ for (int i = 0; i < globals.length; i++) {
+ counters[i] = 0;
+ stacks[i] = new VersionStack();
+ }
+ }
+
+ static class BBSet {
+ Set blocks = new HashSet<>();
+ void add(BasicBlock block) { blocks.add(block); }
+ }
+
+ static class VersionStack {
+ List stack = new ArrayList<>();
+ void push(Register r) { stack.add(r); }
+ Register top() { return stack.getLast(); }
+ void pop() { stack.removeLast(); }
+ }
+
+ /**
+ * Simple worklist
+ */
+ public static class WorkList {
+
+ private ArrayList blocks;
+ private final BitSet members;
+
+ WorkList() {
+ blocks = new ArrayList<>();
+ members = new BitSet();
+ }
+ WorkList(Collection blocks) {
+ this();
+ addAll(blocks);
+ }
+ public BasicBlock push( BasicBlock x ) {
+ if( x==null ) return null;
+ int idx = x.bid;
+ if( !members.get(idx) ) {
+ members.set(idx);
+ blocks.add(x);
+ }
+ return x;
+ }
+ public void addAll( Collection ary ) {
+ for( BasicBlock n : ary )
+ push(n);
+ }
+ BasicBlock pop() {
+ if ( blocks.isEmpty() )
+ return null;
+ var x = blocks.removeFirst();
+ members.clear(x.bid);
+ return x;
+ }
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java
new file mode 100644
index 0000000..dbfcac9
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java
@@ -0,0 +1,12 @@
+package com.compilerprogramming.ezlang.interpreter;
+
+public class ExecutionStack {
+
+ public Value[] stack;
+ public int sp;
+
+ public ExecutionStack(int maxStackSize) {
+ this.stack = new Value[maxStackSize];
+ this.sp = -1;
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java
new file mode 100644
index 0000000..3e48b71
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java
@@ -0,0 +1,256 @@
+package com.compilerprogramming.ezlang.interpreter;
+
+import com.compilerprogramming.ezlang.bytecode.BasicBlock;
+import com.compilerprogramming.ezlang.bytecode.BytecodeFunction;
+import com.compilerprogramming.ezlang.bytecode.Instruction;
+import com.compilerprogramming.ezlang.bytecode.Operand;
+import com.compilerprogramming.ezlang.exceptions.CompilerException;
+import com.compilerprogramming.ezlang.exceptions.InterpreterException;
+import com.compilerprogramming.ezlang.types.Symbol;
+import com.compilerprogramming.ezlang.types.Type;
+import com.compilerprogramming.ezlang.types.TypeDictionary;
+
+public class Interpreter {
+
+ TypeDictionary typeDictionary;
+
+ public Interpreter(TypeDictionary typeDictionary) {
+ this.typeDictionary = typeDictionary;
+ }
+
+ public Value run(String functionName) {
+ Symbol symbol = typeDictionary.lookup(functionName);
+ if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) {
+ Frame frame = new Frame(functionSymbol);
+ ExecutionStack execStack = new ExecutionStack(1024);
+ return interpret(execStack, frame);
+ }
+ else {
+ throw new InterpreterException("Unknown function: " + functionName);
+ }
+ }
+
+ public Value interpret(ExecutionStack execStack, Frame frame) {
+ BytecodeFunction currentFunction = frame.bytecodeFunction;
+ BasicBlock currentBlock = currentFunction.entry;
+ int ip = -1;
+ int base = frame.base;
+ boolean done = false;
+ Value returnValue = null;
+
+ while (!done) {
+ Instruction instruction;
+
+ ip++;
+ instruction = currentBlock.instructions.get(ip);
+ switch (instruction) {
+ case Instruction.Return returnInst -> {
+ if (returnInst.from instanceof Operand.ConstantOperand constantOperand) {
+ execStack.stack[base] = new Value.IntegerValue(constantOperand.value);
+ }
+ else if (returnInst.from instanceof Operand.RegisterOperand registerOperand) {
+ execStack.stack[base] = execStack.stack[base+registerOperand.slot()];
+ }
+ else throw new IllegalStateException();
+ returnValue = execStack.stack[base];
+ }
+ case Instruction.Move moveInst -> {
+ if (moveInst.to instanceof Operand.RegisterOperand toReg) {
+ if (moveInst.from instanceof Operand.RegisterOperand fromReg) {
+ execStack.stack[base + toReg.slot()] = execStack.stack[base + fromReg.slot()];
+ }
+ else if (moveInst.from instanceof Operand.ConstantOperand constantOperand) {
+ execStack.stack[base + toReg.slot()] = new Value.IntegerValue(constantOperand.value);
+ }
+ else throw new IllegalStateException();
+ }
+ else throw new IllegalStateException();
+ }
+ case Instruction.Jump jumpInst -> {
+ currentBlock = jumpInst.jumpTo;
+ ip = -1;
+ if (currentBlock == currentFunction.exit)
+ done = true;
+ }
+ case Instruction.ConditionalBranch cbrInst -> {
+ boolean condition;
+ if (cbrInst.condition instanceof Operand.RegisterOperand registerOperand) {
+ Value value = execStack.stack[base + registerOperand.slot()];
+ if (value instanceof Value.IntegerValue integerValue) {
+ condition = integerValue.value != 0;
+ }
+ else {
+ condition = value != null;
+ }
+ }
+ else if (cbrInst.condition instanceof Operand.ConstantOperand constantOperand) {
+ condition = constantOperand.value != 0;
+ }
+ else throw new IllegalStateException();
+ if (condition)
+ currentBlock = cbrInst.trueBlock;
+ else
+ currentBlock = cbrInst.falseBlock;
+ ip = -1;
+ if (currentBlock == currentFunction.exit)
+ done = true;
+ }
+ case Instruction.Call callInst -> {
+ // Copy args to new frame
+ int baseReg = base+currentFunction.frameSize();
+ int reg = baseReg;
+ for (Operand.RegisterOperand arg: callInst.args) {
+ execStack.stack[base + reg] = execStack.stack[base + arg.slot()];
+ reg += 1;
+ }
+ // Call function
+ Frame newFrame = new Frame(frame, baseReg, callInst.callee);
+ interpret(execStack, newFrame);
+ // Copy return value in expected location
+ if (!(callInst.callee.returnType instanceof Type.TypeVoid)) {
+ execStack.stack[base + callInst.returnOperand.slot()] = execStack.stack[baseReg];
+ }
+ }
+ case Instruction.Unary unaryInst -> {
+ // We don't expect constant here because we fold constants in unary expressions
+ Operand.RegisterOperand unaryOperand = (Operand.RegisterOperand) unaryInst.operand;
+ Value unaryValue = execStack.stack[base + unaryOperand.slot()];
+ if (unaryValue instanceof Value.IntegerValue integerValue) {
+ switch (unaryInst.unop) {
+ case "-": execStack.stack[base + unaryInst.result.slot()] = new Value.IntegerValue(-integerValue.value); break;
+ // Maybe below we should explicitly set Int
+ case "!": execStack.stack[base + unaryInst.result.slot()] = new Value.IntegerValue(integerValue.value==0?1:0); break;
+ default: throw new CompilerException("Invalid unary op");
+ }
+ }
+ else
+ throw new IllegalStateException("Unexpected unary operand: " + unaryOperand);
+ }
+ case Instruction.Binary binaryInst -> {
+ long x, y;
+ long value = 0;
+ if (binaryInst.left instanceof Operand.ConstantOperand constant)
+ x = constant.value;
+ else if (binaryInst.left instanceof Operand.RegisterOperand registerOperand)
+ x = ((Value.IntegerValue) execStack.stack[base + registerOperand.slot()]).value;
+ else throw new IllegalStateException();
+ if (binaryInst.right instanceof Operand.ConstantOperand constant)
+ y = constant.value;
+ else if (binaryInst.right instanceof Operand.RegisterOperand registerOperand)
+ y = ((Value.IntegerValue) execStack.stack[base + registerOperand.slot()]).value;
+ else throw new IllegalStateException();
+ switch (binaryInst.binOp) {
+ case "+": value = x + y; break;
+ case "-": value = x - y; break;
+ case "*": value = x * y; break;
+ case "/": value = x / y; break;
+ case "%": value = x % y; break;
+ case "==": value = x == y ? 1 : 0; break;
+ case "!=": value = x != y ? 1 : 0; break;
+ case "<": value = x < y ? 1: 0; break;
+ case ">": value = x > y ? 1 : 0; break;
+ case "<=": value = x <= y ? 1 : 0; break;
+ case ">=": value = x <= y ? 1 : 0; break;
+ default: throw new IllegalStateException();
+ }
+ execStack.stack[base + binaryInst.result.slot()] = new Value.IntegerValue(value);
+ }
+ case Instruction.NewArray newArrayInst -> {
+ execStack.stack[base + newArrayInst.destOperand.slot()] = new Value.ArrayValue(newArrayInst.type);
+ }
+ case Instruction.NewStruct newStructInst -> {
+ execStack.stack[base + newStructInst.destOperand.slot()] = new Value.StructValue(newStructInst.type);
+ }
+ case Instruction.AStoreAppend arrayAppendInst -> {
+ Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayAppendInst.array.slot()];
+ if (arrayAppendInst.value instanceof Operand.ConstantOperand constant) {
+ arrayValue.values.add(new Value.IntegerValue(constant.value));
+ }
+ else if (arrayAppendInst.value instanceof Operand.RegisterOperand registerOperand) {
+ arrayValue.values.add(execStack.stack[base + registerOperand.slot()]);
+ }
+ else throw new IllegalStateException();
+ }
+ case Instruction.ArrayStore arrayStoreInst -> {
+ if (arrayStoreInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) {
+ Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.slot()];
+ int index = 0;
+ if (arrayStoreInst.indexOperand instanceof Operand.ConstantOperand constant) {
+ index = (int) constant.value;
+ }
+ else if (arrayStoreInst.indexOperand instanceof Operand.RegisterOperand registerOperand) {
+ Value.IntegerValue indexValue = (Value.IntegerValue) execStack.stack[base + registerOperand.slot()];
+ index = (int) indexValue.value;
+ }
+ else throw new IllegalStateException();
+ Value value;
+ if (arrayStoreInst.sourceOperand instanceof Operand.ConstantOperand constantOperand) {
+ value = new Value.IntegerValue(constantOperand.value);
+ }
+ else if (arrayStoreInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) {
+ value = execStack.stack[base + registerOperand.slot()];
+ }
+ else throw new IllegalStateException();
+ arrayValue.values.set(index, value);
+ } else throw new IllegalStateException();
+ }
+ case Instruction.ArrayLoad arrayLoadInst -> {
+ if (arrayLoadInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) {
+ Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.slot()];
+ if (arrayLoadInst.indexOperand instanceof Operand.ConstantOperand constant) {
+ execStack.stack[base + arrayLoadInst.destOperand.slot()] = arrayValue.values.get((int) constant.value);
+ }
+ else if (arrayLoadInst.indexOperand instanceof Operand.RegisterOperand registerOperand) {
+ Value.IntegerValue index = (Value.IntegerValue) execStack.stack[base + registerOperand.slot()];
+ execStack.stack[base + arrayLoadInst.destOperand.slot()] = arrayValue.values.get((int) index.value);
+ }
+ else throw new IllegalStateException();
+ } else throw new IllegalStateException();
+ }
+ case Instruction.SetField setFieldInst -> {
+ if (setFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) {
+ Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.slot()];
+ int index = setFieldInst.fieldIndex;
+ Value value;
+ if (setFieldInst.sourceOperand instanceof Operand.ConstantOperand constant) {
+ value = new Value.IntegerValue(constant.value);
+ }
+ else if (setFieldInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) {
+ value = execStack.stack[base + registerOperand.slot()];
+ }
+ else throw new IllegalStateException();
+ structValue.fields[index] = value;
+ } else throw new IllegalStateException();
+ }
+ case Instruction.GetField getFieldInst -> {
+ if (getFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) {
+ Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.slot()];
+ int index = getFieldInst.fieldIndex;
+ execStack.stack[base + getFieldInst.destOperand.slot()] = structValue.fields[index];
+ } else throw new IllegalStateException();
+ }
+ case Instruction.ArgInstruction argInst -> {}
+ default -> throw new IllegalStateException("Unexpected value: " + instruction);
+ }
+ }
+ return returnValue;
+ }
+
+ static class Frame {
+ Frame caller;
+ int base;
+ BytecodeFunction bytecodeFunction;
+
+ public Frame(Symbol.FunctionTypeSymbol functionSymbol) {
+ this.caller = null;
+ this.base = 0;
+ this.bytecodeFunction = (BytecodeFunction) functionSymbol.code();
+ }
+
+ Frame(Frame caller, int base, Type.TypeFunction functionType) {
+ this.caller = caller;
+ this.base = base;
+ this.bytecodeFunction = (BytecodeFunction) functionType.code;
+ }
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java
new file mode 100644
index 0000000..e03fb80
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java
@@ -0,0 +1,30 @@
+package com.compilerprogramming.ezlang.interpreter;
+
+import com.compilerprogramming.ezlang.types.Type;
+
+import java.util.ArrayList;
+
+public class Value {
+ static public class IntegerValue extends Value {
+ public IntegerValue(long value) {
+ this.value = value;
+ }
+ public final long value;
+ }
+ static public class ArrayValue extends Value {
+ public final Type.TypeArray arrayType;
+ public final ArrayList values;
+ public ArrayValue(Type.TypeArray arrayType) {
+ this.arrayType = arrayType;
+ values = new ArrayList<>();
+ }
+ }
+ static public class StructValue extends Value {
+ public final Type.TypeStruct structType;
+ public final Value[] fields;
+ public StructValue(Type.TypeStruct structType) {
+ this.structType = structType;
+ this.fields = new Value[structType.numFields()];
+ }
+ }
+}
diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java
new file mode 100644
index 0000000..4401d4b
--- /dev/null
+++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java
@@ -0,0 +1,619 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import com.compilerprogramming.ezlang.lexer.Lexer;
+import com.compilerprogramming.ezlang.parser.Parser;
+import com.compilerprogramming.ezlang.semantic.SemaAssignTypes;
+import com.compilerprogramming.ezlang.semantic.SemaDefineTypes;
+import com.compilerprogramming.ezlang.types.Symbol;
+import com.compilerprogramming.ezlang.types.TypeDictionary;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.BitSet;
+
+public class TestCompiler {
+
+ String compileSrc(String src) {
+ Parser parser = new Parser();
+ var program = parser.parse(new Lexer(src));
+ var typeDict = new TypeDictionary();
+ var sema = new SemaDefineTypes(typeDict);
+ sema.analyze(program);
+ var sema2 = new SemaAssignTypes(typeDict);
+ sema2.analyze(program);
+ var byteCodeCompiler = new BytecodeCompiler();
+ byteCodeCompiler.compile(typeDict);
+ StringBuilder sb = new StringBuilder();
+ for (Symbol s : typeDict.bindings.values()) {
+ if (s instanceof Symbol.FunctionTypeSymbol f) {
+ var functionBuilder = (BytecodeFunction) f.code();
+ BasicBlock.toStr(sb, functionBuilder.entry, new BitSet());
+ }
+ }
+ return sb.toString();
+ }
+
+ @Test
+ public void testFunction1() {
+ String src = """
+ func foo(n: Int)->Int {
+ return 1;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %ret = 1
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction2() {
+ String src = """
+ func foo(n: Int)->Int {
+ return -1;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %ret = -1
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction3() {
+ String src = """
+ func foo(n: Int)->Int {
+ return n;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %ret = n
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction4() {
+ String src = """
+ func foo(n: Int)->Int {
+ return -n;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %t2 = -n
+ %ret = %t2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction5() {
+ String src = """
+ func foo(n: Int)->Int {
+ return n+1;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %t2 = n+1
+ %ret = %t2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction6() {
+ String src = """
+ func foo(n: Int)->Int {
+ return 1+1;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %ret = 2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction7() {
+ String src = """
+ func foo(n: Int)->Int {
+ return 1+1-1;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %ret = 1
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction8() {
+ String src = """
+ func foo(n: Int)->Int {
+ return 2==2;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %ret = 1
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction9() {
+ String src = """
+ func foo(n: Int)->Int {
+ return 1!=1;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %ret = 0
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction10() {
+ String src = """
+ func foo(n: [Int])->Int {
+ return n[0];
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %t2 = n[0]
+ %ret = %t2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction11() {
+ String src = """
+ func foo(n: [Int])->Int {
+ return n[0]+n[1];
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %t2 = n[0]
+ %t3 = n[1]
+ %t4 = %t2+%t3
+ %ret = %t4
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction12() {
+ String src = """
+ func foo()->[Int] {
+ return new [Int] { 1, 2, 3 };
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ %t1 = New([Int,Int])
+ %t1.append(1)
+ %t1.append(2)
+ %t1.append(3)
+ %ret = %t1
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction13() {
+ String src = """
+ func foo(n: Int) -> [Int] {
+ return new [Int] { n };
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ %t2 = New([Int,Int])
+ %t2.append(n)
+ %ret = %t2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction14() {
+ String src = """
+ func add(x: Int, y: Int) -> Int {
+ return x+y;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg x
+ arg y
+ %t3 = x+y
+ %ret = %t3
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction15() {
+ String src = """
+ struct Person
+ {
+ var age: Int
+ var children: Int
+ }
+ func foo(p: Person) -> Person {
+ p.age = 10;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg p
+ p.age = 10
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction16() {
+ String src = """
+ struct Person
+ {
+ var age: Int
+ var children: Int
+ }
+ func foo() -> Person {
+ return new Person { age=10, children=0 };
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ %t1 = New(Person)
+ %t1.age = 10
+ %t1.children = 0
+ %ret = %t1
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction17() {
+ String src = """
+ func foo(array: [Int]) {
+ array[0] = 1
+ array[1] = 2
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg array
+ array[0] = 1
+ array[1] = 2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction18() {
+ String src = """
+ func min(x: Int, y: Int) -> Int {
+ if (x < y)
+ return x;
+ return y;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg x
+ arg y
+ %t3 = x 0) {
+ n = n - 1;
+ }
+ return;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg n
+ goto L2
+L2:
+ %t2 = n>0
+ if %t2 goto L3 else goto L4
+L3:
+ %t3 = n-1
+ n = %t3
+ goto L2
+L4:
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction22() {
+ String src = """
+ func foo() {}
+ func bar() { foo(); }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ goto L1
+L1:
+L0:
+ call foo
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction23() {
+ String src = """
+ func foo(x: Int, y: Int) {}
+ func bar() { foo(1,2); }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg x
+ arg y
+ goto L1
+L1:
+L0:
+ %t1 = 1
+ %t2 = 2
+ call foo params %t1, %t2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction24() {
+ String src = """
+ func foo(x: Int, y: Int)->Int { return x+y; }
+ func bar()->Int { var t = foo(1,2); return t+1; }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg x
+ arg y
+ %t3 = x+y
+ %ret = %t3
+ goto L1
+L1:
+L0:
+ %t2 = 1
+ %t3 = 2
+ %t4 = call foo params %t2, %t3
+ t = %t4
+ %t5 = t+1
+ %ret = %t5
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction25() {
+ String src = """
+ struct Person
+ {
+ var age: Int
+ var children: Int
+ }
+ func foo(p: Person) -> Int {
+ return p.age;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg p
+ %t2 = p.age
+ %ret = %t2
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction26() {
+ String src = """
+ struct Person
+ {
+ var age: Int
+ var parent: Person
+ }
+ func foo(p: Person) -> Int {
+ return p.parent.age;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg p
+ %t2 = p.parent
+ %t3 = %t2.age
+ %ret = %t3
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction27() {
+ String src = """
+ struct Person
+ {
+ var age: Int
+ var parent: Person
+ }
+ func foo(p: [Person], i: Int) -> Int {
+ return p[i].parent.age;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg p
+ arg i
+ %t3 = p[i]
+ %t4 = %t3.parent
+ %t5 = %t4.age
+ %ret = %t5
+ goto L1
+L1:
+""", result);
+ }
+
+ @Test
+ public void testFunction28() {
+ String src = """
+ func foo(x: Int, y: Int)->Int { return x+y; }
+ func bar(a: Int)->Int { var t = foo(a,2); return t+1; }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+L0:
+ arg x
+ arg y
+ %t3 = x+y
+ %ret = %t3
+ goto L1
+L1:
+L0:
+ arg a
+ %t3 = a
+ %t4 = 2
+ %t5 = call foo params %t3, %t4
+ t = %t5
+ %t6 = t+1
+ %ret = %t6
+ goto L1
+L1:
+""", result);
+ }
+}
diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestDominators.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestDominators.java
new file mode 100644
index 0000000..293bd72
--- /dev/null
+++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestDominators.java
@@ -0,0 +1,94 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class TestDominators {
+ BasicBlock add(List nodes, BasicBlock node) {
+ nodes.add(node);
+ return node;
+ }
+
+ BasicBlock makeGraph(List nodes) {
+ BasicBlock r0 = add(nodes, new BasicBlock(1));
+ BasicBlock r1 = add(nodes, new BasicBlock(2, r0));
+ BasicBlock r2 = add(nodes, new BasicBlock(3, r1));
+ BasicBlock r3 = add(nodes, new BasicBlock(4, r2));
+ BasicBlock r4 = add(nodes, new BasicBlock(5, r3));
+ BasicBlock r5 = add(nodes, new BasicBlock(6, r1));
+ BasicBlock r6 = add(nodes, new BasicBlock(7, r5));
+ BasicBlock r7 = add(nodes, new BasicBlock(8, r6));
+ BasicBlock r8 = add(nodes, new BasicBlock(9, r5));
+ r8.addSuccessor(r7);
+ r7.addSuccessor(r3);
+ r3.addSuccessor(r1);
+ return r0;
+ }
+
+ @Test
+ public void testDominatorTree() {
+ List nodes = new ArrayList<>();
+ BasicBlock root = makeGraph(nodes);
+ DominatorTree tree = new DominatorTree(root);
+ System.out.println(tree.generateDotOutput());
+ long[] expectedIdoms = {0,1,1,2,2,4,2,6,6,6};
+ for (BasicBlock n: nodes) {
+ Assert.assertEquals(expectedIdoms[(int)n.bid], n.idom.bid);
+ }
+ }
+
+ BasicBlock makeGraph2(List nodes) {
+
+ BasicBlock r1 = add(nodes, new BasicBlock(1));
+ BasicBlock r2 = add(nodes, new BasicBlock(2, r1));
+ BasicBlock r3 = add(nodes, new BasicBlock(3, r2));
+ BasicBlock r4 = add(nodes, new BasicBlock(4, r2));
+ BasicBlock r5 = add(nodes, new BasicBlock(5, r4));
+ BasicBlock r6 = add(nodes, new BasicBlock(6, r4));
+ BasicBlock r7 = add(nodes, new BasicBlock(7, r5, r6));
+ BasicBlock r8 = add(nodes, new BasicBlock(8, r5));
+ BasicBlock r9 = add(nodes, new BasicBlock(9, r8));
+ BasicBlock r10 = add(nodes, new BasicBlock(10, r9));
+ BasicBlock r11 = add(nodes, new BasicBlock(11, r7));
+ BasicBlock r12 = add(nodes, new BasicBlock(12, r10, r11));
+
+ r3.addSuccessor(r2);
+ r4.addSuccessor(r2);
+ r10.addSuccessor(r5);
+ r9.addSuccessor(r8);
+ return r1;
+ }
+
+ public String generateDotOutput(List nodes) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("digraph g {\n");
+ for (BasicBlock n: nodes)
+ sb.append(n.uniqueName()).append(";\n");
+ for (BasicBlock n: nodes) {
+ for (BasicBlock use: n.successors) {
+ sb.append(n.uniqueName()).append("->").append(use.uniqueName()).append(";\n");
+ }
+ }
+ sb.append("}\n");
+ return sb.toString();
+ }
+
+ @Test
+ public void testLoopNests() {
+ List nodes = new ArrayList<>();
+ BasicBlock root = makeGraph2(nodes);
+ System.out.println(generateDotOutput(nodes));
+ DominatorTree tree = new DominatorTree(root);
+ List loopNests = LoopFinder.findLoops(nodes);
+ Assert.assertEquals(2, loopNests.get(0)._loopHead.bid);
+ Assert.assertEquals(2, loopNests.get(1)._loopHead.bid);
+ Assert.assertEquals(5, loopNests.get(2)._loopHead.bid);
+ Assert.assertEquals(8, loopNests.get(3)._loopHead.bid);
+ List loops = LoopFinder.mergeLoopsWithSameHead(loopNests);
+ return;
+ }
+
+}
diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java
new file mode 100644
index 0000000..a2fdd69
--- /dev/null
+++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java
@@ -0,0 +1,127 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import com.compilerprogramming.ezlang.interpreter.Interpreter;
+import com.compilerprogramming.ezlang.interpreter.Value;
+import com.compilerprogramming.ezlang.lexer.Lexer;
+import com.compilerprogramming.ezlang.parser.Parser;
+import com.compilerprogramming.ezlang.semantic.SemaAssignTypes;
+import com.compilerprogramming.ezlang.semantic.SemaDefineTypes;
+import com.compilerprogramming.ezlang.types.Symbol;
+import com.compilerprogramming.ezlang.types.TypeDictionary;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.BitSet;
+
+public class TestInterpreter {
+
+ Value compileAndRun(String src, String mainFunction) {
+ Parser parser = new Parser();
+ var program = parser.parse(new Lexer(src));
+ var typeDict = new TypeDictionary();
+ var sema = new SemaDefineTypes(typeDict);
+ sema.analyze(program);
+ var sema2 = new SemaAssignTypes(typeDict);
+ sema2.analyze(program);
+ var byteCodeCompiler = new BytecodeCompiler();
+ byteCodeCompiler.compile(typeDict);
+ StringBuilder sb = new StringBuilder();
+ for (Symbol s : typeDict.bindings.values()) {
+ if (s instanceof Symbol.FunctionTypeSymbol f) {
+ var functionBuilder = (BytecodeFunction) f.code();
+ BasicBlock.toStr(sb, functionBuilder.entry, new BitSet());
+ }
+ }
+ System.out.println(sb.toString());
+ var interpreter = new Interpreter(typeDict);
+ return interpreter.run(mainFunction);
+ }
+
+ @Test
+ public void testFunction1() {
+ String src = """
+ func foo()->Int {
+ return 42;
+ }
+ """;
+ var value = compileAndRun(src, "foo");
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == 42);
+ }
+
+ @Test
+ public void testFunction2() {
+ String src = """
+ func bar()->Int {
+ return 42;
+ }
+ func foo()->Int {
+ return bar();
+ }
+ """;
+ var value = compileAndRun(src, "foo");
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == 42);
+ }
+
+ @Test
+ public void testFunction3() {
+ String src = """
+ func negate(n: Int)->Int {
+ return -n;
+ }
+ func foo()->Int {
+ return negate(42);
+ }
+ """;
+ var value = compileAndRun(src, "foo");
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == -42);
+ }
+
+ @Test
+ public void testFunction4() {
+ String src = """
+ func foo(x: Int, y: Int)->Int { return x+y; }
+ func bar()->Int { var t = foo(1,2); return t+1; }
+ """;
+ var value = compileAndRun(src, "bar");
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == 4);
+ }
+
+ @Test
+ public void testFunction5() {
+ String src = """
+ func bar()->Int { var t = new [Int] {1,21,3}; return t[1]; }
+ """;
+ var value = compileAndRun(src, "bar");
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == 21);
+ }
+
+ @Test
+ public void testFunction6() {
+ String src = """
+ struct Test
+ {
+ var field: Int
+ }
+ func foo()->Test
+ {
+ var test = new Test{ field = 42 }
+ return test
+ }
+ func bar()->Int { return foo().field }
+ """;
+ var value = compileAndRun(src, "bar");
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == 42);
+ }
+}
diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestSSATransform.java
new file mode 100644
index 0000000..70b82fe
--- /dev/null
+++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestSSATransform.java
@@ -0,0 +1,419 @@
+package com.compilerprogramming.ezlang.bytecode;
+
+import com.compilerprogramming.ezlang.lexer.Lexer;
+import com.compilerprogramming.ezlang.parser.Parser;
+import com.compilerprogramming.ezlang.semantic.SemaAssignTypes;
+import com.compilerprogramming.ezlang.semantic.SemaDefineTypes;
+import com.compilerprogramming.ezlang.types.Symbol;
+import com.compilerprogramming.ezlang.types.TypeDictionary;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.BitSet;
+
+public class TestSSATransform {
+
+ String compileSrc(String src) {
+ Parser parser = new Parser();
+ var program = parser.parse(new Lexer(src));
+ var typeDict = new TypeDictionary();
+ var sema = new SemaDefineTypes(typeDict);
+ sema.analyze(program);
+ var sema2 = new SemaAssignTypes(typeDict);
+ sema2.analyze(program);
+ var byteCodeCompiler = new BytecodeCompiler();
+ byteCodeCompiler.compile(typeDict);
+ StringBuilder sb = new StringBuilder();
+ for (Symbol s : typeDict.bindings.values()) {
+ if (s instanceof Symbol.FunctionTypeSymbol f) {
+ var functionBuilder = (BytecodeFunction) f.code();
+ sb.append("func ").append(f.name).append("\n");
+ sb.append("Before SSA\n");
+ sb.append("==========\n");
+ BasicBlock.toStr(sb, functionBuilder.entry, new BitSet());
+ new SSATransform(functionBuilder);
+ sb.append("After SSA\n");
+ sb.append("=========\n");
+ BasicBlock.toStr(sb, functionBuilder.entry, new BitSet());
+ }
+ }
+ return sb.toString();
+ }
+
+ @Test
+ public void test1() {
+ String src = """
+ func foo(d: Int) {
+ var a = 42;
+ var b = a;
+ var c = a + b;
+ a = c + 23;
+ c = a + d;
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+func foo
+Before SSA
+==========
+L0:
+ arg d
+ a = 42
+ b = a
+ %t5 = a+b
+ c = %t5
+ %t6 = c+23
+ a = %t6
+ %t7 = a+d
+ c = %t7
+ goto L1
+L1:
+After SSA
+=========
+L0:
+ arg d
+ a = 42
+ b = a
+ %t5 = a+b
+ c = %t5
+ %t6 = c+23
+ a_1 = %t6
+ %t7 = a_1+d
+ c_1 = %t7
+ goto L1
+L1:
+""", result);
+
+ }
+ @Test
+ public void test2() {
+ String src = """
+ func foo(d: Int)->Int {
+ var a = 42
+ if (d)
+ {
+ a = a + 1
+ }
+ else
+ {
+ a = a - 1
+ }
+ return a
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+func foo
+Before SSA
+==========
+L0:
+ arg d
+ a = 42
+ if d goto L2 else goto L3
+L2:
+ %t3 = a+1
+ a = %t3
+ goto L4
+L4:
+ %ret = a
+ goto L1
+L1:
+L3:
+ %t4 = a-1
+ a = %t4
+ goto L4
+After SSA
+=========
+L0:
+ arg d
+ a = 42
+ if d goto L2 else goto L3
+L2:
+ %t3 = a+1
+ a_2 = %t3
+ goto L4
+L4:
+ a_3 = phi(a_2, a_1)
+ %ret = a_3
+ goto L1
+L1:
+L3:
+ %t4 = a-1
+ a_1 = %t4
+ goto L4
+""", result);
+
+ }
+ @Test
+ public void test3() {
+ String src = """
+ func factorial(num: Int)->Int {
+ var result = 1
+ while (num > 1)
+ {
+ result = result * num
+ num = num - 1
+ }
+ return result
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+func factorial
+Before SSA
+==========
+L0:
+ arg num
+ result = 1
+ goto L2
+L2:
+ %t3 = num>1
+ if %t3 goto L3 else goto L4
+L3:
+ %t4 = result*num
+ result = %t4
+ %t5 = num-1
+ num = %t5
+ goto L2
+L4:
+ %ret = result
+ goto L1
+L1:
+After SSA
+=========
+L0:
+ arg num
+ result = 1
+ goto L2
+L2:
+ result_1 = phi(result, result_2)
+ num_1 = phi(num, num_2)
+ %t3 = num_1>1
+ if %t3 goto L3 else goto L4
+L3:
+ %t4 = result_1*num_1
+ result_2 = %t4
+ %t5 = num_1-1
+ num_2 = %t5
+ goto L2
+L4:
+ %ret = result_1
+ goto L1
+L1:
+""", result);
+
+ }
+
+ @Test
+ public void test4() {
+ String src = """
+ func print(a: Int, b: Int, c:Int, d:Int) {}
+ func example14_66(p: Int, q: Int, r: Int, s: Int, t: Int) {
+ var i = 1
+ var j = 1
+ var k = 1
+ var l = 1
+ while (1) {
+ if (p) {
+ j = i
+ if (q) {
+ l = 2
+ }
+ else {
+ l = 3
+ }
+ k = k + 1
+ }
+ else {
+ k = k + 2
+ }
+ print(i,j,k,l)
+ while (1) {
+ if (r) {
+ l = l + 4
+ }
+ if (!s)
+ break
+ }
+ i = i + 6
+ if (!t)
+ break
+ }
+ }
+ """;
+ String result = compileSrc(src);
+ Assert.assertEquals("""
+func print
+Before SSA
+==========
+L0:
+ arg a
+ arg b
+ arg c
+ arg d
+ goto L1
+L1:
+After SSA
+=========
+L0:
+ arg a
+ arg b
+ arg c
+ arg d
+ goto L1
+L1:
+func example14_66
+Before SSA
+==========
+L0:
+ arg p
+ arg q
+ arg r
+ arg s
+ arg t
+ i = 1
+ j = 1
+ k = 1
+ l = 1
+ goto L2
+L2:
+ if 1 goto L3 else goto L4
+L3:
+ if p goto L5 else goto L6
+L5:
+ j = i
+ if q goto L8 else goto L9
+L8:
+ l = 2
+ goto L10
+L10:
+ %t10 = k+1
+ k = %t10
+ goto L7
+L7:
+ %t12 = i
+ %t13 = j
+ %t14 = k
+ %t15 = l
+ call print params %t12, %t13, %t14, %t15
+ goto L11
+L11:
+ if 1 goto L12 else goto L13
+L12:
+ if r goto L14 else goto L15
+L14:
+ %t16 = l+4
+ l = %t16
+ goto L15
+L15:
+ %t17 = !s
+ if %t17 goto L16 else goto L17
+L16:
+ goto L13
+L13:
+ %t18 = i+6
+ i = %t18
+ %t19 = !t
+ if %t19 goto L18 else goto L19
+L18:
+ goto L4
+L4:
+ goto L1
+L1:
+L19:
+ goto L2
+L17:
+ goto L11
+L9:
+ l = 3
+ goto L10
+L6:
+ %t11 = k+2
+ k = %t11
+ goto L7
+After SSA
+=========
+L0:
+ arg p
+ arg q
+ arg r
+ arg s
+ arg t
+ i = 1
+ j = 1
+ k = 1
+ l = 1
+ goto L2
+L2:
+ l_1 = phi(l, l_9)
+ k_1 = phi(k, k_4)
+ j_1 = phi(j, j_3)
+ i_1 = phi(i, i_2)
+ if 1 goto L3 else goto L4
+L3:
+ if p goto L5 else goto L6
+L5:
+ j_2 = i_1
+ if q goto L8 else goto L9
+L8:
+ l_3 = 2
+ goto L10
+L10:
+ l_4 = phi(l_3, l_2)
+ %t10 = k_1+1
+ k_3 = %t10
+ goto L7
+L7:
+ l_5 = phi(l_4, l_1)
+ k_4 = phi(k_3, k_2)
+ j_3 = phi(j_2, j_1)
+ %t12 = i_1
+ %t13 = j_3
+ %t14 = k_4
+ %t15 = l_5
+ call print params %t12, %t13, %t14, %t15
+ goto L11
+L11:
+ l_6 = phi(l_5, l_8)
+ if 1 goto L12 else goto L13
+L12:
+ if r goto L14 else goto L15
+L14:
+ %t16 = l_6+4
+ l_7 = %t16
+ goto L15
+L15:
+ l_8 = phi(l_6, l_7)
+ %t17 = !s
+ if %t17 goto L16 else goto L17
+L16:
+ goto L13
+L13:
+ l_9 = phi(l_6, l_8)
+ %t18 = i_1+6
+ i_2 = %t18
+ %t19 = !t
+ if %t19 goto L18 else goto L19
+L18:
+ goto L4
+L4:
+ l_10 = phi(l_1, l_9)
+ k_5 = phi(k_1, k_4)
+ j_4 = phi(j_1, j_3)
+ i_3 = phi(i_1, i_2)
+ goto L1
+L1:
+L19:
+ goto L2
+L17:
+ goto L11
+L9:
+ l_2 = 3
+ goto L10
+L6:
+ %t11 = k_1+2
+ k_2 = %t11
+ goto L7
+""", result);
+ }
+}
diff --git a/pom.xml b/pom.xml
index b972fb6..4ea1d08 100644
--- a/pom.xml
+++ b/pom.xml
@@ -27,6 +27,7 @@
semantic
stackvm
registervm
+ optvm
diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java
index 0513127..a83a5f8 100644
--- a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java
+++ b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java
@@ -2,6 +2,7 @@
import com.compilerprogramming.ezlang.exceptions.CompilerException;
import com.compilerprogramming.ezlang.parser.AST;
+import com.compilerprogramming.ezlang.types.Register;
import com.compilerprogramming.ezlang.types.Scope;
import com.compilerprogramming.ezlang.types.Symbol;
import com.compilerprogramming.ezlang.types.Type;
@@ -59,7 +60,7 @@ private void setVirtualRegisters(Scope scope) {
reg = scope.parent.maxReg;
for (Symbol symbol: scope.getLocalSymbols()) {
if (symbol instanceof Symbol.VarSymbol varSymbol) {
- varSymbol.reg = reg++;
+ varSymbol.reg = new Register(reg++, 0, varSymbol.name, varSymbol.type);
}
}
scope.maxReg = reg;
@@ -145,7 +146,7 @@ private void compileAssign(AST.AssignStmt assignStmt) {
codeIndexedStore();
else if (assignStmt.lhs instanceof AST.NameExpr symbolExpr) {
Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol;
- code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(varSymbol.reg, varSymbol.name)));
+ code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(varSymbol.reg.slot, varSymbol.name)));
}
else
throw new CompilerException("Invalid assignment expression: " + assignStmt.lhs);
@@ -240,7 +241,7 @@ private void compileLet(AST.VarStmt letStmt) {
boolean indexed = compileExpr(letStmt.expr);
if (indexed)
codeIndexedLoad();
- code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(letStmt.symbol.reg, letStmt.symbol.name)));
+ code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(letStmt.symbol.reg.slot, letStmt.symbol.name)));
}
}
@@ -402,7 +403,7 @@ private boolean compileSymbolExpr(AST.NameExpr symbolExpr) {
pushOperand(new Operand.LocalFunctionOperand(functionType));
else {
Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol;
- pushLocal(varSymbol.reg, varSymbol.name);
+ pushLocal(varSymbol.reg.slot, varSymbol.name);
}
return false;
}
diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java
index c51dcdf..1c126b1 100644
--- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java
+++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java
@@ -41,7 +41,7 @@ public ASTVisitor visit(AST.FuncDecl funcDecl, boolean enter) {
throw new CompilerException("Symbol " + funcDecl.name + " is already declared");
}
// Create function scope, that houses function parameters
- currentScope = new Scope(currentScope);
+ currentScope = new Scope(currentScope, true);
funcDecl.scope = currentScope;
// Install a symbol for the function,
// type is not fully formed at this stage
@@ -116,7 +116,7 @@ else if (varDecl.varType == AST.VarType.FUNCTION_PARAMETER
if (currentScope.localLookup(varDecl.name) != null)
throw new CompilerException("Parameter " + varDecl.name + " is already declared");
Type.TypeFunction type = (Type.TypeFunction) currentFuncDecl.symbol.type;
- varDecl.symbol = currentScope.install(varDecl.name, new Symbol.VarSymbol(varDecl.name, varDecl.typeExpr.type));
+ varDecl.symbol = currentScope.install(varDecl.name, new Symbol.ParameterSymbol(varDecl.name, varDecl.typeExpr.type));
type.addArg(varDecl.symbol);
}
else if (varDecl.varType == AST.VarType.VARIABLE) {
diff --git a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java
index ea26f5f..f7d6668 100644
--- a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java
+++ b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java
@@ -2,6 +2,7 @@
import com.compilerprogramming.ezlang.exceptions.CompilerException;
import com.compilerprogramming.ezlang.parser.AST;
+import com.compilerprogramming.ezlang.types.Register;
import com.compilerprogramming.ezlang.types.Scope;
import com.compilerprogramming.ezlang.types.Symbol;
import com.compilerprogramming.ezlang.types.Type;
@@ -32,7 +33,7 @@ private void setVirtualRegisters(Scope scope) {
reg = scope.parent.maxReg;
for (Symbol symbol: scope.getLocalSymbols()) {
if (symbol instanceof Symbol.VarSymbol varSymbol) {
- varSymbol.reg = reg++;
+ varSymbol.reg = new Register(reg, 0, varSymbol.name, varSymbol.type);
}
}
scope.maxReg = reg;
@@ -112,7 +113,7 @@ private void compileAssign(AST.AssignStmt assignStmt) {
code(new Instruction.StoreIndexed());
else if (assignStmt.lhs instanceof AST.NameExpr symbolExpr) {
Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol;
- code(new Instruction.Store(varSymbol.reg));
+ code(new Instruction.Store(varSymbol.reg.slot));
}
else
throw new CompilerException("Invalid assignment expression: " + assignStmt.lhs);
@@ -203,7 +204,7 @@ private void compileLet(AST.VarStmt letStmt) {
boolean indexed = compileExpr(letStmt.expr);
if (indexed)
code(new Instruction.LoadIndexed());
- code(new Instruction.Store(letStmt.symbol.reg));
+ code(new Instruction.Store(letStmt.symbol.reg.slot));
}
}
@@ -322,7 +323,7 @@ private boolean compileSymbolExpr(AST.NameExpr symbolExpr) {
code(new Instruction.LoadFunction(functionType));
else {
Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol;
- code(new Instruction.LoadVar(varSymbol.reg));
+ code(new Instruction.LoadVar(varSymbol.reg.slot));
}
return false;
}
diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Register.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Register.java
new file mode 100644
index 0000000..e65e03a
--- /dev/null
+++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Register.java
@@ -0,0 +1,44 @@
+package com.compilerprogramming.ezlang.types;
+
+import java.util.Objects;
+
+public class Register {
+ /**
+ * Slot in the function's frame
+ */
+ public final int slot;
+ /**
+ * Unique virtual ID
+ */
+ public final int id;
+ public final int ssaVersion;
+ public final String name;
+ public final Type type;
+
+ public Register(int slot, int id, String name, Type type) {
+ this(slot, id, name, type, 0);
+ }
+ public Register(int slot, int id, String name, Type type, int ssaVersion) {
+ this.slot = slot;
+ this.id = id;
+ this.name = name;
+ this.type = type;
+ this.ssaVersion = ssaVersion;
+ }
+ public Register cloneWithVersion(int ssaVersion) {
+ return new Register(this.slot, this.id, this.name+"_"+ssaVersion, this.type, ssaVersion);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Register register = (Register) o;
+ return id == register.id && ssaVersion == register.ssaVersion;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(id, ssaVersion);
+ }
+}
diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java
index e6d5b5a..70b6b2d 100644
--- a/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java
+++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java
@@ -13,13 +13,19 @@ public class Scope {
// values assigned by compiler
public int maxReg;
+ public final boolean isFunctionParameterScope;
- public Scope(Scope parent) {
+ public Scope(Scope parent, boolean isFunctionParameterScope) {
this.parent = parent;
+ this.isFunctionParameterScope = isFunctionParameterScope;
if (parent != null)
parent.children.add(this);
}
+ public Scope(Scope parent) {
+ this(parent, false);
+ }
+
public Symbol lookup(String name) {
Symbol symbol = bindings.get(name);
if (symbol == null && parent != null)
diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java
index 0252f7b..dc1de20 100644
--- a/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java
+++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java
@@ -35,9 +35,15 @@ public Object code() {
public static class VarSymbol extends Symbol {
// Values assigned by bytecode compiler
- public int reg;
+ public Register reg;
public VarSymbol(String name, Type type) {
super(name, type);
}
}
+
+ public static class ParameterSymbol extends VarSymbol {
+ public ParameterSymbol(String name, Type type) {
+ super(name, type);
+ }
+ }
}