diff --git a/optvm/pom.xml b/optvm/pom.xml new file mode 100644 index 0000000..1e8c725 --- /dev/null +++ b/optvm/pom.xml @@ -0,0 +1,31 @@ + + + 4.0.0 + + com.compilerprogramming.ezlang + compilercraft + 1.0 + + optvm + jar + Optimizing VM + + + com.compilerprogramming.ezlang + parser + 1.0 + + + com.compilerprogramming.ezlang + types + 1.0 + + + com.compilerprogramming.ezlang + semantic + 1.0 + + + \ No newline at end of file diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BasicBlock.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BasicBlock.java new file mode 100644 index 0000000..41293df --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BasicBlock.java @@ -0,0 +1,156 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.types.Register; + +import java.util.*; + +public class BasicBlock { + public final int bid; + public final boolean loopHead; + public final List successors = new ArrayList<>(); // successors + public final List predecessors = new ArrayList<>(); + public final List instructions = new ArrayList<>(); + /** + * The preorder traversal number, also acts as a flag indicating whether the + * BB is yet to be visited (_pre==0 means not yet visited). + */ + int pre; + /** + * The depth of the BB in the dominator tree + */ + int domDepth; + /** + * Reverse post order traversal number; + * Sort node list in ascending order by this to traverse graph in reverse post order. + * In RPO order if an edge exists from A to B we visit A followed by B, but cycles have to + * be dealt with in another way. + */ + int rpo; + /** + * Immediate dominator is the closest strict dominator. + * @see DominatorTree + */ + public BasicBlock idom; + /** + * Nodes for whom this node is the immediate dominator, + * thus the dominator tree. + */ + public List dominatedChildren = new ArrayList<>(); + /** + * Dominance frontier + */ + public Set dominationFrontier = new HashSet<>(); + + /** + * Nearest Loop to which this BB belongs + */ + public LoopNest loop; + + public BasicBlock(int bid, boolean loopHead) { + this.bid = bid; + this.loopHead = loopHead; + } + public BasicBlock(int bid) { + this(bid, false); + } + // For testing only + public BasicBlock(int bid, BasicBlock... preds) { + this.bid = bid; + this.loopHead = false; + for (BasicBlock bb : preds) + bb.addSuccessor(this); + } + public void add(Instruction instruction) { + instructions.add(instruction); + } + public void addSuccessor(BasicBlock successor) { + successors.add(successor); + successor.predecessors.add(this); + } + + /** + * Initially the phi has the form + * v = phi(v,v,...) + */ + public void insertPhiFor(Register var) { + for (Instruction i: instructions) { + if (i instanceof Instruction.Phi phi) { + if (phi.def().id == var.id) + // already added + return; + } + else break; + } + List inputs = new ArrayList<>(); + for (int i = 0; i < predecessors.size(); i++) + inputs.add(var); + Instruction.Phi phi = new Instruction.Phi(var, inputs); + instructions.add(0, phi); + } + public List phis() { + List list = new ArrayList<>(); + for (Instruction i: instructions) { + if (i instanceof Instruction.Phi phi) + list.add(phi); + else break; + } + return list; + } + public static StringBuilder toStr(StringBuilder sb, BasicBlock bb, BitSet visited) + { + if (visited.get(bb.bid)) + return sb; + visited.set(bb.bid); + sb.append("L").append(bb.bid).append(":\n"); + for (Instruction n: bb.instructions) { + sb.append(" "); + n.toStr(sb).append("\n"); + } + for (BasicBlock succ: bb.successors) { + toStr(sb, succ, visited); + } + return sb; + } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BasicBlock that = (BasicBlock) o; + return bid == that.bid; + } + + @Override + public int hashCode() { + return Objects.hash(bid); + } + + public String label() { + return "BB(" + bid + ")"; + } + + public String uniqueName() { + return "BB_" + bid; + } + + //////////////// dominator calculations ///////////////////// + + public void resetDomInfo() { + domDepth = 0; + idom = null; + dominatedChildren.clear(); + dominationFrontier.clear(); + } + + public void resetRPO() { + pre = 0; + rpo = 0; + } + + public boolean dominates(BasicBlock other) { + if (this == other) return true; + while (other.domDepth > domDepth) other = other.idom; + return this == other; + } + + /////////////////// End of dominator calculations ////////////////////////////////// +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeCompiler.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeCompiler.java new file mode 100644 index 0000000..3bef310 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeCompiler.java @@ -0,0 +1,17 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; +import com.compilerprogramming.ezlang.types.TypeDictionary; + +public class BytecodeCompiler { + + public void compile(TypeDictionary typeDictionary) { + for (Symbol symbol: typeDictionary.getLocalSymbols()) { + if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) { + Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type; + functionType.code = new BytecodeFunction(functionSymbol); + } + } + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java new file mode 100644 index 0000000..dfdfdff --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java @@ -0,0 +1,558 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.exceptions.CompilerException; +import com.compilerprogramming.ezlang.parser.AST; +import com.compilerprogramming.ezlang.types.Register; +import com.compilerprogramming.ezlang.types.Scope; +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; + +import java.util.ArrayList; +import java.util.List; + +public class BytecodeFunction { + + public BasicBlock entry; + public BasicBlock exit; + public int maxLocalReg; + public int maxStackSize; + private int bid = 0; + private BasicBlock currentBlock; + private BasicBlock currentBreakTarget; + private BasicBlock currentContinueTarget; + private Type.TypeFunction functionType; + /** + * Each register is assigned a unique ID which is not + * the same as the slot number inside the frame, as that + * can be shared by registers because of disjoint life times. + * Start the ID at 1, so that we reserve 0 for the return register + */ + public int nextReg = 1; + static final int RETURN_REG_ID = 0; + public final int numLocalRegs; + + /** + * We essentially do a form of abstract interpretation as we generate + * the bytecode instructions. For this purpose we use a virtual operand stack. + * + * This is similar to the technique described in + * Dynamic Optimization through the use of Automatic Runtime Specialization + * by John Whaley + */ + private List virtualStack = new ArrayList<>(); + + public BytecodeFunction(Symbol.FunctionTypeSymbol functionSymbol) { + AST.FuncDecl funcDecl = (AST.FuncDecl) functionSymbol.functionDecl; + setVirtualRegisters(funcDecl.scope); + this.numLocalRegs = this.nextReg; // Before assigning temps + this.bid = 0; + this.entry = this.currentBlock = createBlock(); + this.exit = createBlock(); + this.currentBreakTarget = null; + this.currentContinueTarget = null; + this.functionType = (Type.TypeFunction) functionSymbol.type; + generateArgInstructions(funcDecl.scope); + compileStatement(funcDecl.block); + exitBlockIfNeeded(); + } + + private void generateArgInstructions(Scope scope) { + if (scope.isFunctionParameterScope) { + for (Symbol symbol: scope.getLocalSymbols()) { + if (symbol instanceof Symbol.ParameterSymbol parameterSymbol) { + code(new Instruction.ArgInstruction(new Operand.LocalRegisterOperand(parameterSymbol.reg))); + } + } + } + } + + public int frameSize() { + return maxLocalReg+maxStackSize; + } + + private void exitBlockIfNeeded() { + if (currentBlock != null && + currentBlock != exit) { + startBlock(exit); + } + } + + private void setVirtualRegisters(Scope scope) { + int reg = 0; + if (scope.parent != null) + reg = scope.parent.maxReg; + for (Symbol symbol: scope.getLocalSymbols()) { + if (symbol instanceof Symbol.VarSymbol varSymbol) { + varSymbol.reg = new Register(reg++, nextReg++, varSymbol.name, varSymbol.type); + } + } + scope.maxReg = reg; + if (maxLocalReg < scope.maxReg) + maxLocalReg = scope.maxReg; + for (Scope childScope: scope.children) { + setVirtualRegisters(childScope); + } + } + + private BasicBlock createBlock() { + return new BasicBlock(bid++); + } + + private BasicBlock createLoopHead() { + return new BasicBlock(bid++, true); + } + + private void compileBlock(AST.BlockStmt block) { + for (AST.Stmt stmt: block.stmtList) { + compileStatement(stmt); + } + } + + private void compileReturn(AST.ReturnStmt returnStmt) { + if (returnStmt.expr != null) { + boolean isIndexed = compileExpr(returnStmt.expr); + if (isIndexed) + codeIndexedLoad(); + if (virtualStack.size() == 1) + code(new Instruction.Return(pop(), 0, RETURN_REG_ID, functionType.returnType)); + else if (virtualStack.size() > 1) + throw new CompilerException("Virtual stack has more than one item at return"); + } + jumpTo(exit); + } + + private void code(Instruction instruction) { + currentBlock.add(instruction); + } + + private void compileStatement(AST.Stmt statement) { + switch (statement) { + case AST.BlockStmt blockStmt -> { + compileBlock(blockStmt); + } + case AST.VarStmt letStmt -> { + compileLet(letStmt); + } + case AST.IfElseStmt ifElseStmt -> { + compileIf(ifElseStmt); + } + case AST.WhileStmt whileStmt -> { + compileWhile(whileStmt); + } + case AST.ContinueStmt continueStmt -> { + compileContinue(continueStmt); + } + case AST.BreakStmt breakStmt -> { + compileBreak(breakStmt); + } + case AST.ReturnStmt returnStmt -> { + compileReturn(returnStmt); + } + case AST.AssignStmt assignStmt -> { + compileAssign(assignStmt); + } + case AST.ExprStmt exprStmt -> { + compileExprStmt(exprStmt); + } + default -> throw new IllegalStateException("Unexpected value: " + statement); + } + } + + private void compileAssign(AST.AssignStmt assignStmt) { + boolean indexedLhs = false; + if (!(assignStmt.lhs instanceof AST.NameExpr)) + indexedLhs = compileExpr(assignStmt.lhs); + boolean indexedRhs = compileExpr(assignStmt.rhs); + if (indexedRhs) + codeIndexedLoad(); + if (indexedLhs) + codeIndexedStore(); + else if (assignStmt.lhs instanceof AST.NameExpr symbolExpr) { + Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol; + code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(varSymbol.reg))); + } + else + throw new CompilerException("Invalid assignment expression: " + assignStmt.lhs); + } + + private void compileExprStmt(AST.ExprStmt exprStmt) { + boolean indexed = compileExpr(exprStmt.expr); + if (indexed) + codeIndexedLoad(); + if (!vstackEmpty()) + pop(); + } + + private void compileContinue(AST.ContinueStmt continueStmt) { + if (currentContinueTarget == null) + throw new CompilerException("No continue target found"); + jumpTo(currentContinueTarget); + } + + private void compileBreak(AST.BreakStmt breakStmt) { + if (currentBreakTarget == null) + throw new CompilerException("No break target found"); + jumpTo(currentBreakTarget); + } + + private void compileWhile(AST.WhileStmt whileStmt) { + BasicBlock loopBlock = createLoopHead(); + BasicBlock bodyBlock = createBlock(); + BasicBlock exitBlock = createBlock(); + BasicBlock savedBreakTarget = currentBreakTarget; + BasicBlock savedContinueTarget = currentContinueTarget; + currentBreakTarget = exitBlock; + currentContinueTarget = loopBlock; + startBlock(loopBlock); + boolean indexed = compileExpr(whileStmt.condition); + if (indexed) + codeIndexedLoad(); + code(new Instruction.ConditionalBranch(currentBlock, pop(), bodyBlock, exitBlock)); + assert vstackEmpty(); + startBlock(bodyBlock); + compileStatement(whileStmt.stmt); + if (!isBlockTerminated(currentBlock)) + jumpTo(loopBlock); + startBlock(exitBlock); + currentContinueTarget = savedContinueTarget; + currentBreakTarget = savedBreakTarget; + } + + private boolean isBlockTerminated(BasicBlock block) { + return (block.instructions.size() > 0 && + block.instructions.getLast().isTerminal()); + } + + private void jumpTo(BasicBlock block) { + assert !isBlockTerminated(currentBlock); + currentBlock.add(new Instruction.Jump(block)); + currentBlock.addSuccessor(block); + } + + private void startBlock(BasicBlock block) { + if (!isBlockTerminated(currentBlock)) { + jumpTo(block); + } + currentBlock = block; + } + + private void compileIf(AST.IfElseStmt ifElseStmt) { + BasicBlock ifBlock = createBlock(); + boolean needElse = ifElseStmt.elseStmt != null; + BasicBlock elseBlock = needElse ? createBlock() : null; + BasicBlock exitBlock = createBlock(); + boolean indexed = compileExpr(ifElseStmt.condition); + if (indexed) + codeIndexedLoad(); + code(new Instruction.ConditionalBranch(currentBlock, pop(), ifBlock, needElse ? elseBlock : exitBlock)); + assert vstackEmpty(); + startBlock(ifBlock); + compileStatement(ifElseStmt.ifStmt); + if (!isBlockTerminated(currentBlock)) + jumpTo(exitBlock); + if (elseBlock != null) { + startBlock(elseBlock); + compileStatement(ifElseStmt.elseStmt); + if (!isBlockTerminated(currentBlock)) + jumpTo(exitBlock); + } + startBlock(exitBlock); + } + + private void compileLet(AST.VarStmt letStmt) { + if (letStmt.expr != null) { + boolean indexed = compileExpr(letStmt.expr); + if (indexed) + codeIndexedLoad(); + code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(letStmt.symbol.reg))); + } + } + + private boolean compileExpr(AST.Expr expr) { + switch (expr) { + case AST.LiteralExpr constantExpr -> { + return compileConstantExpr(constantExpr); + } + case AST.BinaryExpr binaryExpr -> { + return compileBinaryExpr(binaryExpr); + } + case AST.UnaryExpr unaryExpr -> { + return compileUnaryExpr(unaryExpr); + } + case AST.NameExpr symbolExpr -> { + return compileSymbolExpr(symbolExpr); + } + case AST.NewExpr newExpr -> { + return compileNewExpr(newExpr); + } + case AST.ArrayIndexExpr arrayIndexExpr -> { + return compileArrayIndexExpr(arrayIndexExpr); + } + case AST.FieldExpr fieldExpr -> { + return compileFieldExpr(fieldExpr); + } + case AST.SetFieldExpr setFieldExpr -> { + return compileSetFieldExpr(setFieldExpr); + } + case AST.CallExpr callExpr -> { + return compileCallExpr(callExpr); + } + default -> throw new IllegalStateException("Unexpected value: " + expr); + } + } + + private boolean compileCallExpr(AST.CallExpr callExpr) { + compileExpr(callExpr.callee); + var callee = pop(); + Type.TypeFunction calleeType = null; + if (callee instanceof Operand.LocalFunctionOperand functionOperand) + calleeType = functionOperand.functionType; + else throw new CompilerException("Cannot call a non function type"); + var returnStackPos = virtualStack.size(); + List args = new ArrayList<>(); + for (AST.Expr expr: callExpr.args) { + boolean indexed = compileExpr(expr); + if (indexed) + codeIndexedLoad(); + var arg = top(); + if (!(arg instanceof Operand.TempRegisterOperand) ) { + var origArg = pop(); + arg = createTemp(origArg.type); + code(new Instruction.Move(origArg, arg)); + } + args.add((Operand.RegisterOperand) arg); + } + // Simulate the actions on the stack + for (int i = 0; i < args.size(); i++) + pop(); + Operand.TempRegisterOperand ret = null; + if (callExpr.callee.type instanceof Type.TypeFunction tf && + !(tf.returnType instanceof Type.TypeVoid)) { + ret = createTemp(tf.returnType); + //assert ret.regnum-maxLocalReg == returnStackPos; + } + code(new Instruction.Call(returnStackPos, ret, calleeType, args.toArray(new Operand.RegisterOperand[args.size()]))); + return false; + } + + private Type.TypeStruct getStructType(Type t) { + if (t instanceof Type.TypeStruct typeStruct) { + return typeStruct; + } + else if (t instanceof Type.TypeNullable ptr && + ptr.baseType instanceof Type.TypeStruct typeStruct) { + return typeStruct; + } + else + throw new CompilerException("Unexpected type: " + t); + } + + private boolean compileFieldExpr(AST.FieldExpr fieldExpr) { + Type.TypeStruct typeStruct = getStructType(fieldExpr.object.type); + int fieldIndex = typeStruct.getFieldIndex(fieldExpr.fieldName); + if (fieldIndex < 0) + throw new CompilerException("Field " + fieldExpr.fieldName + " not found"); + boolean indexed = compileExpr(fieldExpr.object); + if (indexed) + codeIndexedLoad(); + pushOperand(new Operand.LoadFieldOperand(pop(), fieldExpr.fieldName, fieldIndex)); + return true; + } + + private boolean compileArrayIndexExpr(AST.ArrayIndexExpr arrayIndexExpr) { + compileExpr(arrayIndexExpr.array); + boolean indexed = compileExpr(arrayIndexExpr.expr); + if (indexed) + codeIndexedLoad(); + Operand index = pop(); + Operand array = pop(); + pushOperand(new Operand.LoadIndexedOperand(array, index)); + return true; + } + + private boolean compileSetFieldExpr(AST.SetFieldExpr setFieldExpr) { + Type.TypeStruct structType = (Type.TypeStruct) setFieldExpr.objectType; + int fieldIndex = structType.getFieldIndex(setFieldExpr.fieldName); + if (fieldIndex == -1) + throw new CompilerException("Field " + setFieldExpr.fieldName + " not found in struct " + structType.name); + pushOperand(new Operand.LoadFieldOperand(top(), setFieldExpr.fieldName, fieldIndex)); + boolean indexed = compileExpr(setFieldExpr.value); + if (indexed) + codeIndexedLoad(); + codeIndexedStore(); + return false; + } + + private void codeNew(Type type) { + var temp = createTemp(type); + if (type instanceof Type.TypeArray typeArray) { + code(new Instruction.NewArray(typeArray, temp)); + } + else if (type instanceof Type.TypeStruct typeStruct) { + code(new Instruction.NewStruct(typeStruct, temp)); + } + else + throw new CompilerException("Unexpected type: " + type); + } + + private void codeStoreAppend() { + var operand = pop(); + code(new Instruction.AStoreAppend((Operand.RegisterOperand) top(), operand)); + } + + private boolean compileNewExpr(AST.NewExpr newExpr) { + codeNew(newExpr.type); + if (newExpr.initExprList != null && !newExpr.initExprList.isEmpty()) { + if (newExpr.type instanceof Type.TypeArray) { + for (AST.Expr expr : newExpr.initExprList) { + // Maybe have specific AST similar to how we have SetFieldExpr? + boolean indexed = compileExpr(expr); + if (indexed) + codeIndexedLoad(); + codeStoreAppend(); + } + } + else if (newExpr.type instanceof Type.TypeStruct) { + for (AST.Expr expr : newExpr.initExprList) { + compileExpr(expr); + } + } + } + return false; + } + + private boolean compileSymbolExpr(AST.NameExpr symbolExpr) { + if (symbolExpr.type instanceof Type.TypeFunction functionType) + pushOperand(new Operand.LocalFunctionOperand(functionType)); + else { + Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol; + pushLocal(varSymbol.reg); + } + return false; + } + + private boolean compileBinaryExpr(AST.BinaryExpr binaryExpr) { + String opCode = null; + boolean indexed = compileExpr(binaryExpr.expr1); + if (indexed) + codeIndexedLoad(); + indexed = compileExpr(binaryExpr.expr2); + if (indexed) + codeIndexedLoad(); + opCode = binaryExpr.op.str; + Operand right = pop(); + Operand left = pop(); + if (left instanceof Operand.ConstantOperand leftconstant && + right instanceof Operand.ConstantOperand rightconstant) { + long value = 0; + switch (opCode) { + case "+": value = leftconstant.value + rightconstant.value; break; + case "-": value = leftconstant.value - rightconstant.value; break; + case "*": value = leftconstant.value * rightconstant.value; break; + case "/": value = leftconstant.value / rightconstant.value; break; + case "%": value = leftconstant.value % rightconstant.value; break; + case "==": value = leftconstant.value == rightconstant.value ? 1 : 0; break; + case "!=": value = leftconstant.value != rightconstant.value ? 1 : 0; break; + case "<": value = leftconstant.value < rightconstant.value ? 1: 0; break; + case ">": value = leftconstant.value > rightconstant.value ? 1 : 0; break; + case "<=": value = leftconstant.value <= rightconstant.value ? 1 : 0; break; + case ">=": value = leftconstant.value <= rightconstant.value ? 1 : 0; break; + default: throw new CompilerException("Invalid binary op"); + } + pushConstant(value, leftconstant.type); + } + else { + var temp = createTemp(binaryExpr.type); + code(new Instruction.Binary(opCode, temp, left, right)); + } + return false; + } + + private boolean compileUnaryExpr(AST.UnaryExpr unaryExpr) { + String opCode; + boolean indexed = compileExpr(unaryExpr.expr); + if (indexed) + codeIndexedLoad(); + opCode = unaryExpr.op.str; + Operand top = pop(); + if (top instanceof Operand.ConstantOperand constant) { + switch (opCode) { + case "-": pushConstant(-constant.value, constant.type); break; + // Maybe below we should explicitly set Int + case "!": pushConstant(constant.value == 0?1:0, constant.type); break; + default: throw new CompilerException("Invalid unary op"); + } + } + else { + var temp = createTemp(unaryExpr.type); + code(new Instruction.Unary(opCode, temp, top)); + } + return false; + } + + private boolean compileConstantExpr(AST.LiteralExpr constantExpr) { + pushConstant(constantExpr.value.num.intValue(), constantExpr.type); + return false; + } + + private void pushConstant(long value, Type type) { + pushOperand(new Operand.ConstantOperand(value, type)); + } + + private Operand.TempRegisterOperand createTemp(Type type) { + var offset = virtualStack.size()+maxLocalReg; + var id = nextReg++; + var name = "%t" + id; + var tempRegister = new Operand.TempRegisterOperand(offset, id, name, type); + pushOperand(tempRegister); + if (maxStackSize < virtualStack.size()) + maxStackSize = virtualStack.size(); + return tempRegister; + } + + private void pushLocal(Register reg) { + pushOperand(new Operand.LocalRegisterOperand(reg)); + } + + private void pushOperand(Operand operand) { + virtualStack.add(operand); + } + + private Operand pop() { + return virtualStack.removeLast(); + } + + private Operand top() { + return virtualStack.getLast(); + } + + private void codeIndexedLoad() { + Operand indexed = pop(); + var temp = createTemp(indexed.type); + if (indexed instanceof Operand.LoadIndexedOperand loadIndexedOperand) { + code(new Instruction.ArrayLoad(loadIndexedOperand, temp)); + } + else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) { + code(new Instruction.GetField(loadFieldOperand, temp)); + } + else + code(new Instruction.Move(indexed, temp)); + } + + private void codeIndexedStore() { + Operand value = pop(); + Operand indexed = pop(); + if (indexed instanceof Operand.LoadIndexedOperand loadIndexedOperand) { + code(new Instruction.ArrayStore(value, loadIndexedOperand)); + } + else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) { + code(new Instruction.SetField(value, loadFieldOperand)); + } + else + code(new Instruction.Move(value, indexed)); + } + + private boolean vstackEmpty() { + return virtualStack.isEmpty(); + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/DominatorTree.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/DominatorTree.java new file mode 100644 index 0000000..ec13103 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/DominatorTree.java @@ -0,0 +1,238 @@ +package com.compilerprogramming.ezlang.bytecode; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.function.Consumer; + +/** + * The dominator tree construction algorithm is based on figure 9.24, + * chapter 9, p 532, of Engineering a Compiler. + *

+ * The algorithm is also described in the paper 'A Simple, Fast + * Dominance Algorithm' by Keith D. Cooper, Timothy J. Harvey and + * Ken Kennedy. + */ +public class DominatorTree { + BasicBlock entry; + // List of basic blocks reachable from _entry block, including the _entry + List blocks; + + int preorder; + int rpostorder; + + /** + * Builds a Dominator Tree. + * + * @param entry The entry block + */ + public DominatorTree(BasicBlock entry) { + this.entry = entry; + blocks = findAllBlocks(entry); + calculateDominatorTree(); + populateTree(); + setDepth(); + calculateDominanceFrontiers(); + } + + /** + * Utility to locate all the basic blocks, order does not matter. + */ + private static List findAllBlocks(BasicBlock root) { + List nodes = new ArrayList<>(); + postOrderWalk(root, (n) -> nodes.add(n), new HashSet<>()); + return nodes; + } + + static void postOrderWalk(BasicBlock n, Consumer consumer, HashSet visited) { + visited.add(n); + /* For each successor node */ + for (BasicBlock s : n.successors) { + if (!visited.contains(s)) + postOrderWalk(s, consumer, visited); + } + consumer.accept(n); + } + + private void calculateDominatorTree() { + resetDomInfo(); + annotateBlocksWithRPO(); + sortBlocksByRPO(); + + // Set IDom entry for root to itself + entry.idom = entry; + boolean changed = true; + while (changed) { + changed = false; + // for all nodes, b, in reverse postorder (except root) + for (BasicBlock bb : blocks) { + if (bb == entry) // skip root + continue; + // NewIDom = first (processed) predecessor of b, pick one + BasicBlock firstPred = findFirstPredecessorWithIdom(bb); + assert firstPred != null; + BasicBlock newIDom = firstPred; + // for all other predecessors, p, of b + for (BasicBlock predecessor : bb.predecessors) { + if (predecessor == firstPred) continue; // all other predecessors + if (predecessor.idom != null) { + // i.e. IDoms[p] calculated + newIDom = intersect(predecessor, newIDom); + } + } + if (bb.idom != newIDom) { + bb.idom = newIDom; + changed = true; + } + } + } + } + + private void resetDomInfo() { + for (BasicBlock bb : blocks) + bb.resetDomInfo(); + } + + /** + * Assign rpo number to all the basic blocks. + * The rpo number defines the Reverse Post Order traversal of blocks. + * The Dominance calculator requires the rpo number. + */ + private void annotateBlocksWithRPO() { + preorder = 1; + rpostorder = blocks.size(); + for (BasicBlock n : blocks) n.resetRPO(); + postOrderWalkSetRPO(entry); + } + + // compute rpo using a depth first search + private void postOrderWalkSetRPO(BasicBlock n) { + n.pre = preorder++; + for (BasicBlock s : n.successors) { + if (s.pre == 0) + postOrderWalkSetRPO(s); + } + n.rpo = rpostorder--; + } + + /** + * Reverse post order gives a topological sort order + */ + private void sortBlocksByRPO() { + blocks.sort(Comparator.comparingInt(n -> n.rpo)); + } + + /** + * Finds nearest common ancestor + *

+ * The algorithm starts at the two nodes whose sets are being intersected, and walks + * upward from each toward the root. By comparing the nodes with their RPO numbers + * the algorithm finds the common ancestor - the immediate dominator of i and j. + */ + private BasicBlock intersect(BasicBlock i, BasicBlock j) { + BasicBlock finger1 = i; + BasicBlock finger2 = j; + while (finger1 != finger2) { + while (finger1.rpo > finger2.rpo) { + finger1 = finger1.idom; + assert finger1 != null; + } + while (finger2.rpo > finger1.rpo) { + finger2 = finger2.idom; + assert finger2 != null; + } + } + return finger1; + } + + /** + * Look for the first predecessor whose immediate dominator has been calculated. + * Because of the order in which this search occurs, we will always find at least 1 + * such predecessor. + */ + private BasicBlock findFirstPredecessorWithIdom(BasicBlock n) { + for (BasicBlock p : n.predecessors) { + if (p.idom != null) return p; + } + return null; + } + + /** + * Setup the dominator tree. + * Each block gets the list of blocks it strictly dominates. + */ + private void populateTree() { + for (BasicBlock block : blocks) { + BasicBlock idom = block.idom; + if (idom == block) // root + continue; + // add edge from idom to n + idom.dominatedChildren.add(block); + } + } + + /** + * Sets the dominator depth on each block + */ + private void setDepth() { + entry.domDepth = 1; + setDepth_(entry); + } + + /** + * Sets the dominator depth on each block + */ + private void setDepth_(BasicBlock block) { + BasicBlock idom = block.idom; + if (idom != block) { + assert idom.domDepth > 0; + block.domDepth = idom.domDepth + 1; + } else { + assert idom.domDepth == 1; + assert idom.domDepth == block.domDepth; + } + for (BasicBlock child : block.dominatedChildren) + setDepth_(child); + } + + /** + * Calculates dominance-frontiers for nodes + */ + private void calculateDominanceFrontiers() { + // Dominance-Frontier Algorithm - fig 5 in 'A Simple, Fast Dominance Algorithm' + //for all nodes, b + // if the number of predecessors of b ≥ 2 + // for all predecessors, p, of b + // runner ← p + // while runner != doms[b] + // add b to runner’s dominance frontier set + // runner = doms[runner] + for (BasicBlock b : blocks) { + if (b.predecessors.size() >= 2) { + for (BasicBlock p : b.predecessors) { + BasicBlock runner = p; + while (runner != b.idom) { + runner.dominationFrontier.add(b); + runner = runner.idom; + } + } + } + } + } + + public String generateDotOutput() { + StringBuilder sb = new StringBuilder(); + sb.append("digraph DomTree {\n"); + for (BasicBlock n : blocks) { + sb.append(n.uniqueName()).append(" [label=\"").append(n.label()).append("\"];\n"); + } + for (BasicBlock n : blocks) { + BasicBlock idom = n.idom; + if (idom == n) continue; + sb.append(idom.uniqueName()).append("->").append(n.uniqueName()).append(";\n"); + } + sb.append("}\n"); + return sb.toString(); + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java new file mode 100644 index 0000000..26a4326 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java @@ -0,0 +1,685 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.types.Register; +import com.compilerprogramming.ezlang.types.Type; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public abstract class Instruction { + + public boolean isTerminal() { + return false; + } + @Override + public String toString() { + return toStr(new StringBuilder()).toString(); + } + + public boolean definesVar() { return false; } + public Register def() { return null; } + + public boolean usesVars() { return false; } + public List uses() { return Collections.emptyList(); } + public void replaceDef(Register newReg) {} + public void replaceUses(Register[] newUses) {} + + public static class Move extends Instruction { + public final Operand from; + public final Operand to; + public Move(Operand from, Operand to) { + this.from = from; + this.to = to; + } + + @Override + public boolean definesVar() { + return true; + } + + @Override + public Register def() { + if (to instanceof Operand.RegisterOperand registerOperand) + return registerOperand.reg; + throw new IllegalStateException(); + } + + @Override + public void replaceDef(Register newReg) { + this.to.replaceRegister(newReg); + } + + @Override + public boolean usesVars() { + return (from instanceof Operand.RegisterOperand); + } + + @Override + public List uses() { + if (from instanceof Operand.RegisterOperand registerOperand) + return List.of(registerOperand.reg); + return super.uses(); + } + + @Override + public void replaceUses(Register[] newUses) { + from.replaceRegister(newUses[0]); + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(to).append(" = ").append(from); + } + } + + public static class NewArray extends Instruction { + public final Type.TypeArray type; + public final Operand.RegisterOperand destOperand; + public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand) { + this.type = type; + this.destOperand = destOperand; + } + + @Override + public boolean definesVar() { + return true; + } + + @Override + public Register def() { + return destOperand.reg; + } + + @Override + public void replaceDef(Register newReg) { + destOperand.replaceRegister(newReg); + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append("New(") + .append(type) + .append(")"); + } + } + + public static class NewStruct extends Instruction { + public final Type.TypeStruct type; + public final Operand.RegisterOperand destOperand; + public NewStruct(Type.TypeStruct type, Operand.RegisterOperand destOperand) { + this.type = type; + this.destOperand = destOperand; + } + @Override + public boolean definesVar() { + return true; + } + + @Override + public Register def() { + return destOperand.reg; + } + + @Override + public void replaceDef(Register newReg) { + destOperand.replaceRegister(newReg); + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append("New(") + .append(type) + .append(")"); + } + } + + public static class ArrayLoad extends Instruction { + public final Operand arrayOperand; + public final Operand indexOperand; + public final Operand.RegisterOperand destOperand; + public ArrayLoad(Operand.LoadIndexedOperand from, Operand.RegisterOperand to) { + arrayOperand = from.arrayOperand; + indexOperand = from.indexOperand; + destOperand = to; + } + @Override + public boolean definesVar() { + return true; + } + @Override + public Register def() { + return destOperand.reg; + } + @Override + public void replaceDef(Register newReg) { + destOperand.replaceRegister(newReg); + } + + @Override + public boolean usesVars() { + return true; + } + + @Override + public List uses() { + List usesList = new ArrayList<>(); + if (arrayOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + if (indexOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + return usesList; + } + + @Override + public void replaceUses(Register[] newUses) { + int i = 0; + if (arrayOperand instanceof Operand.RegisterOperand) { + arrayOperand.replaceRegister(newUses[i++]); + } + if (indexOperand instanceof Operand.RegisterOperand) { + indexOperand.replaceRegister(newUses[i]); + } + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append(arrayOperand) + .append("[") + .append(indexOperand) + .append("]"); + } + } + + public static class ArrayStore extends Instruction { + public final Operand arrayOperand; + public final Operand indexOperand; + public final Operand sourceOperand; + public ArrayStore(Operand from, Operand.LoadIndexedOperand to) { + arrayOperand = to.arrayOperand; + indexOperand = to.indexOperand; + sourceOperand = from; + } + @Override + public boolean usesVars() { + return true; + } + + @Override + public List uses() { + List usesList = new ArrayList<>(); + if (arrayOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + if (indexOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + if (sourceOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + return usesList; + } + + @Override + public void replaceUses(Register[] newUses) { + int i = 0; + if (arrayOperand instanceof Operand.RegisterOperand) { + arrayOperand.replaceRegister(newUses[i++]); + } + if (indexOperand instanceof Operand.RegisterOperand) { + indexOperand.replaceRegister(newUses[i++]); + } + if (sourceOperand instanceof Operand.RegisterOperand) { + sourceOperand.replaceRegister(newUses[i]); + } + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb + .append(arrayOperand) + .append("[") + .append(indexOperand) + .append("] = ") + .append(sourceOperand); + } + } + + public static class GetField extends Instruction { + public final Operand structOperand; + public final String fieldName; + public final int fieldIndex; + public final Operand.RegisterOperand destOperand; + public GetField(Operand.LoadFieldOperand from, Operand.RegisterOperand to) + { + this.structOperand = from.structOperand; + this.fieldName = from.fieldName; + this.fieldIndex = from.fieldIndex; + this.destOperand = to; + } + @Override + public boolean definesVar() { + return true; + } + @Override + public Register def() { + return destOperand.reg; + } + + @Override + public void replaceDef(Register newReg) { + destOperand.replaceRegister(newReg); + } + + @Override + public boolean usesVars() { + return true; + } + @Override + public List uses() { + List usesList = new ArrayList<>(); + if (structOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + return usesList; + } + + @Override + public void replaceUses(Register[] newUses) { + if (newUses.length > 0) + structOperand.replaceRegister(newUses[0]); + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append(structOperand) + .append(".") + .append(fieldName); + } + } + + public static class SetField extends Instruction { + public final Operand structOperand; + public final String fieldName; + public final int fieldIndex; + public final Operand sourceOperand; + public SetField(Operand from,Operand.LoadFieldOperand to) + { + this.structOperand = to.structOperand; + this.fieldName = to.fieldName; + this.fieldIndex = to.fieldIndex; + this.sourceOperand = from; + } + @Override + public boolean usesVars() { + return true; + } + @Override + public List uses() { + List usesList = new ArrayList<>(); + if (structOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + if (sourceOperand instanceof Operand.RegisterOperand registerOperand) + usesList.add(registerOperand.reg); + return usesList; + } + @Override + public void replaceUses(Register[] newUses) { + int i = 0; + if (structOperand instanceof Operand.RegisterOperand) { + structOperand.replaceRegister(newUses[i++]); + } + if (sourceOperand instanceof Operand.RegisterOperand) { + sourceOperand.replaceRegister(newUses[i]); + } + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb + .append(structOperand) + .append(".") + .append(fieldName) + .append(" = ") + .append(sourceOperand); + } + } + public static class Return extends Move { + public Return(Operand from, int slot, int regNum, Type type) { + super(from, new Operand.ReturnRegisterOperand(slot, regNum, "%ret", type)); + } + } + public static class Unary extends Instruction { + public final String unop; + public final Operand.RegisterOperand result; + public final Operand operand; + public Unary(String unop, Operand.RegisterOperand result, Operand operand) { + this.unop = unop; + this.result = result; + this.operand = operand; + } + + @Override + public boolean definesVar() { + return true; + } + + @Override + public Register def() { + return result.reg; + } + + @Override + public void replaceDef(Register newReg) { + result.replaceRegister(newReg); + } + + @Override + public boolean usesVars() { + return operand instanceof Operand.RegisterOperand registerOperand; + } + + @Override + public List uses() { + List usesList = new ArrayList<>(); + if (operand instanceof Operand.RegisterOperand registerOperand) { + usesList.add(registerOperand.reg); + } + return usesList; + } + + @Override + public void replaceUses(Register[] newUses) { + if (newUses.length > 0) + operand.replaceRegister(newUses[0]); + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(result).append(" = ").append(unop).append(operand); + } + } + + public static class Binary extends Instruction { + public final String binOp; + public final Operand.RegisterOperand result; + public final Operand left; + public final Operand right; + public Binary(String binop, Operand.RegisterOperand result, Operand left, Operand right) { + this.binOp = binop; + this.result = result; + this.left = left; + this.right = right; + } + @Override + public boolean definesVar() { + return true; + } + + @Override + public Register def() { + return result.reg; + } + + @Override + public void replaceDef(Register newReg) { + result.replaceRegister(newReg); + } + + @Override + public boolean usesVars() { + return left instanceof Operand.RegisterOperand || + right instanceof Operand.RegisterOperand; + } + @Override + public List uses() { + List usesList = new ArrayList<>(); + if (left instanceof Operand.RegisterOperand registerOperand) { + usesList.add(registerOperand.reg); + } + if (right instanceof Operand.RegisterOperand registerOperand) { + usesList.add(registerOperand.reg); + } + return usesList; + } + + @Override + public void replaceUses(Register[] newUses) { + int i = 0; + if (left instanceof Operand.RegisterOperand) { + left.replaceRegister(newUses[i++]); + } + if (right instanceof Operand.RegisterOperand) { + right.replaceRegister(newUses[i]); + } + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(result).append(" = ").append(left).append(binOp).append(right); + } + } + + public static class AStoreAppend extends Instruction { + public final Operand.RegisterOperand array; + public final Operand value; + public AStoreAppend(Operand.RegisterOperand array, Operand value) { + this.array = array; + this.value = value; + } + @Override + public boolean usesVars() { + return true; + } + @Override + public List uses() { + List usesList = new ArrayList<>(); + usesList.add(array.reg); + if (value instanceof Operand.RegisterOperand registerOperand) { + usesList.add(registerOperand.reg); + } + return usesList; + } + @Override + public void replaceUses(Register[] newUses) { + array.replaceRegister(newUses[0]); + if (value instanceof Operand.RegisterOperand) + value.replaceRegister(newUses[1]); + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(array).append(".append(").append(value).append(")"); + } + } + + public static class ConditionalBranch extends Instruction { + public final Operand condition; + public final BasicBlock trueBlock; + public final BasicBlock falseBlock; + public ConditionalBranch(BasicBlock currentBlock, Operand condition, BasicBlock trueBlock, BasicBlock falseBlock) { + this.condition = condition; + this.trueBlock = trueBlock; + this.falseBlock = falseBlock; + currentBlock.addSuccessor(trueBlock); + currentBlock.addSuccessor(falseBlock); + } + @Override + public boolean usesVars() { + return condition instanceof Operand.RegisterOperand; + } + @Override + public List uses() { + List usesList = new ArrayList<>(); + if (condition instanceof Operand.RegisterOperand registerOperand) { + usesList.add(registerOperand.reg); + } + return usesList; + } + + @Override + public void replaceUses(Register[] newUses) { + if (condition instanceof Operand.RegisterOperand) { + condition.replaceRegister(newUses[0]); + } + } + + @Override + public boolean isTerminal() { + return true; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append("if ").append(condition).append(" goto L").append(trueBlock.bid).append(" else goto L").append(falseBlock.bid); + } + } + + public static class Call extends Instruction { + public final Type.TypeFunction callee; + public final Operand.RegisterOperand[] args; + public final Operand.RegisterOperand returnOperand; + public final int newbase; + public Call(int newbase, Operand.RegisterOperand returnOperand, Type.TypeFunction callee, Operand.RegisterOperand... args) { + this.returnOperand = returnOperand; + this.callee = callee; + this.args = args; + this.newbase = newbase; + } + + @Override + public boolean definesVar() { + return returnOperand != null; + } + + @Override + public Register def() { + return returnOperand != null ? returnOperand.reg : null; + } + + @Override + public void replaceDef(Register newReg) { + if (returnOperand != null) + returnOperand.replaceRegister(newReg); + } + + @Override + public boolean usesVars() { + return args != null && args.length > 0; + } + @Override + public List uses() { + List usesList = new ArrayList<>(); + for (Operand.RegisterOperand argOperand : args) { + usesList.add(argOperand.reg); + } + return usesList; + } + + @Override + public void replaceUses(Register[] newUses) { + if (args == null) + return; + for (int i = 0; i < args.length; i++) { + args[i].replaceRegister(newUses[i]); + } + } + + @Override + public StringBuilder toStr(StringBuilder sb) { + if (returnOperand != null) { + sb.append(returnOperand).append(" = "); + } + sb.append("call ").append(callee); + if (args.length > 0) + sb.append(" params "); + for (int i = 0; i < args.length; i++) { + if (i > 0) sb.append(", "); + sb.append(args[i]); + } + return sb; + } + } + + public static class Jump extends Instruction { + public final BasicBlock jumpTo; + public Jump(BasicBlock jumpTo) { + this.jumpTo = jumpTo; + } + @Override + public boolean isTerminal() { + return true; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append("goto ").append(" L").append(jumpTo.bid); + } + } + + public static class Phi extends Instruction { + public Operand.RegisterOperand dest; + public final List inputs = new ArrayList<>(); + public Phi(Register dest, List inputs) { + this.dest = new Operand.RegisterOperand(dest); + for (Register input : inputs) { + this.inputs.add(new Operand.RegisterOperand(input)); + } + } + @Override + public boolean definesVar() { + return true; + } + @Override + public Register def() { + return dest.reg; + } + @Override + public void replaceDef(Register newReg) { + dest = new Operand.RegisterOperand(newReg); + } + public void replaceInput(int i, Register newReg) { + inputs.set(i, new Operand.RegisterOperand(newReg)); + } + @Override + public StringBuilder toStr(StringBuilder sb) { + sb.append(dest).append(" = phi("); + for (int i = 0; i < inputs.size(); i++) { + if (i > 0) sb.append(", "); + sb.append(inputs.get(i)); + } + sb.append(")"); + return sb; + } + } + + public static class ArgInstruction extends Instruction { + Operand.RegisterOperand arg; + + @Override + public boolean definesVar() { + return true; + } + @Override + public Register def() { + return arg.reg; + } + + @Override + public void replaceDef(Register newReg) { + arg.replaceRegister(newReg); + } + + public ArgInstruction(Operand.RegisterOperand arg) { + this.arg = arg; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append("arg ").append(arg); + } + } + + public abstract StringBuilder toStr(StringBuilder sb); +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopFinder.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopFinder.java new file mode 100644 index 0000000..313bde4 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopFinder.java @@ -0,0 +1,87 @@ +package com.compilerprogramming.ezlang.bytecode; + +import java.util.*; +import java.util.stream.Collectors; + +public class LoopFinder { + // Based on Compilers: Principles, Techniques and Tools + // p 604 + // 1986 ed + public static LoopNest getNaturalLoop(BasicBlock head, BasicBlock backedge) { + Stack stack = new Stack<>(); + LoopNest loop = new LoopNest(head); + loop.insert(backedge, stack); + // trace back up from backedge to head + while (!stack.isEmpty()) { + BasicBlock m = stack.pop(); + for (BasicBlock pred : m.predecessors) { + loop.insert(pred, stack); + } + } + return loop; + } + + // Based on Compilers: Principles, Techniques and Tools + // p 604 + // 1986 ed + public static List findLoops(List nodes) { + List list = new ArrayList<>(); + for (BasicBlock n : nodes) { + for (BasicBlock input : n.predecessors) { + if (n.dominates(input)) { + list.add(getNaturalLoop(n, input)); + } + } + } + return list; + } + + public static List mergeLoopsWithSameHead(List loopNests) { + HashMap map = new HashMap<>(); + for (LoopNest loopNest : loopNests) { + LoopNest sameHead = map.get(loopNest._loopHead.bid); + if (sameHead == null) map.put(loopNest._loopHead.bid, loopNest); + else sameHead._blocks.addAll(loopNest._blocks); + } + return map.values().stream().collect(Collectors.toList()); + } + + public static LoopNest buildLoopTree(List loopNests) { + for (LoopNest nest1 : loopNests) { + for (LoopNest nest2 : loopNests) { + boolean isNested = nest1.contains(nest2); + if (isNested) { + if (nest2._parent == null) nest2._parent = nest1; + else if (nest1._loopHead.domDepth > nest2._parent._loopHead.domDepth) nest2._parent = nest1; + } + } + } + LoopNest top = null; + for (LoopNest nest : loopNests) { + if (nest._parent != null) nest._parent._kids.add(nest); + else top = nest; + } + return top; + } + + public static void annotateBasicBlocks(LoopNest loop, Set visited) { + if (visited.contains(loop)) + return; + visited.add(loop); + for (LoopNest kid: loop._kids) { + kid._depth = loop._depth+1; + annotateBasicBlocks(kid, visited); + } + for (BasicBlock block: loop._blocks) { + if (block.loop == null) + block.loop = loop; + } + } + + public static void annotateBasicBlocks(LoopNest top) { + if (top == null) // No loop + return; + top._depth = 1; + annotateBasicBlocks(top, new HashSet<>()); + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopNest.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopNest.java new file mode 100644 index 0000000..1aeca0a --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/LoopNest.java @@ -0,0 +1,62 @@ +package com.compilerprogramming.ezlang.bytecode; + +import java.util.*; + +public class LoopNest { + /** + * Parent loop + */ + public LoopNest _parent; + /** + * The block that is the head of the loop + */ + public final BasicBlock _loopHead; + /** + * Blocks that are part of this loop + */ + public final Set _blocks; + /** + * Children as per Loop Tree + */ + public final List _kids; + /** + * Loop Tree depth - top has depth 1 + */ + public int _depth; + + public LoopNest(BasicBlock loopHead) { + _loopHead = loopHead; + _blocks = new HashSet<>(); + _blocks.add(loopHead); + _kids = new ArrayList<>(); + } + + public void insert(BasicBlock m, Stack stack) { + if (!_blocks.contains(m)) { + _blocks.add(m); + stack.push(m); + } + } + + public boolean contains(LoopNest other) { + return this != other && _blocks.containsAll(other._blocks); + } + + public String uniqueName() { return "Loop_" + _loopHead.uniqueName(); } + public String label() { return "Loop(" + _loopHead.label() + ":" + _depth + ")"; } + + public static String generateDotOutput(List loopNests) { + StringBuilder sb = new StringBuilder(); + sb.append("digraph LoopTree {\n"); + for (LoopNest n : loopNests) { + sb.append(n.uniqueName()).append(" [label=\"").append(n.label()).append("\"];\n"); + } + for (LoopNest n : loopNests) { + for (LoopNest c: n._kids) { + sb.append(n.uniqueName()).append("->").append(c.uniqueName()).append(";\n"); + } + } + sb.append("}\n"); + return sb.toString(); + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java new file mode 100644 index 0000000..90e481d --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java @@ -0,0 +1,130 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.types.Register; +import com.compilerprogramming.ezlang.types.Type; + +public class Operand { + + Type type; + + public void replaceRegister(Register register) {} + + public static class ConstantOperand extends Operand { + public final long value; + public ConstantOperand(long value, Type type) { + this.value = value; + this.type = type; + } + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static class RegisterOperand extends Operand { + public Register reg; + public RegisterOperand(int slot, int regnum, String name, Type type, boolean isTemp) { + this.reg = new Register(slot, regnum, name, type); + } + public RegisterOperand(Register reg) { + this.reg = reg; + } + public int slot() { return reg.slot; } + + @Override + public void replaceRegister(Register register) { + this.reg = register; + } + + @Override + public String toString() { + return reg.name; + } + } + + public static class LocalRegisterOperand extends RegisterOperand { + public LocalRegisterOperand(Register reg) { + super(reg); + } + } + + public static class LocalFunctionOperand extends Operand { + public final Type.TypeFunction functionType; + public LocalFunctionOperand(Type.TypeFunction functionType) { + this.functionType = functionType; + } + @Override + public String toString() { + return functionType.toString(); + } + } + + /** + * Represents the return register, which is the location where + * the caller will expect to see any return value. The VM must map + * this to appropriate location. + */ + public static class ReturnRegisterOperand extends RegisterOperand { + public ReturnRegisterOperand(int slot, int regnum, String name, Type type) { super(slot, regnum, name, type, true); } + @Override + public String toString() { return "%ret"; } + } + + /** + * Represents a temp register, maps to a location on the + * virtual stack. Temps start at offset 0, but this is a relative + * register number from start of temp area. + */ + public static class TempRegisterOperand extends RegisterOperand { + public TempRegisterOperand(int offset, int regnum, String name, Type type) { + super(offset, regnum, name, type, true); + } + } + + public static class IndexedOperand extends Operand {} + + public static class LoadIndexedOperand extends IndexedOperand { + public final Operand arrayOperand; + public final Operand indexOperand; + public LoadIndexedOperand(Operand arrayOperand, Operand indexOperand) { + this.arrayOperand = arrayOperand; + this.indexOperand = indexOperand; + assert !(indexOperand instanceof IndexedOperand) && + !(arrayOperand instanceof IndexedOperand); + } + @Override + public String toString() { + return arrayOperand + "[" + indexOperand + "]"; + } + } + + public static class LoadFieldOperand extends IndexedOperand { + public final Operand structOperand; + public final int fieldIndex; + public final String fieldName; + public LoadFieldOperand(Operand structOperand, String fieldName, int field) { + this.structOperand = structOperand; + this.fieldName = fieldName; + this.fieldIndex = field; + assert !(structOperand instanceof IndexedOperand); + } + + @Override + public String toString() { + return structOperand + "." + fieldName; + } + } + + public static class NewTypeOperand extends Operand { + public final Type type; + public NewTypeOperand(Type type) { + this.type = type; + } + + @Override + public String toString() { + return "New(" + type + ")"; + } + } + +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/SSATransform.java b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/SSATransform.java new file mode 100644 index 0000000..3ff81cf --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/bytecode/SSATransform.java @@ -0,0 +1,228 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.types.Register; + +import java.util.*; + +/** + * Transform a bytecode function to SSA form + */ +public class SSATransform { + + BytecodeFunction function; + DominatorTree domTree; + Register[] globals; + BBSet[] blockSets; + List blocks; + int[] counters; + VersionStack[] stacks; + + public SSATransform(BytecodeFunction bytecodeFunction) { + this.function = bytecodeFunction; + setupGlobals(); + computeDomTreeAndDominanceFrontiers(); + this.blocks = domTree.blocks; + findGlobalVars(); + insertPhis(); + renameVars(); + } + + private void computeDomTreeAndDominanceFrontiers() { + domTree = new DominatorTree(function.entry); + } + + private void setupGlobals() { + // fixme this should really just look at locals I think + globals = new Register[function.nextReg]; + blockSets = new BBSet[function.nextReg]; + } + + /** + * Compute set of registers that are live across multiple blocks + * i.e. are not exclusively used in a single block. + */ + private void findGlobalVars() { + for (BasicBlock block : blocks) { + var varKill = new HashSet(); + for (Instruction instruction: block.instructions) { + if (instruction.usesVars()) { + for (Register reg : instruction.uses()) { + if (!varKill.contains(reg.id)) { + globals[reg.id] = reg; + } + } + } + if (instruction.definesVar()) { + Register reg = instruction.def(); + varKill.add(reg.id); + if (blockSets[reg.id] == null) { + blockSets[reg.id] = new BBSet(); + } + blockSets[reg.id].add(block); + } + } + } + } + + void insertPhis() { + for (int i = 0; i < globals.length; i++) { + Register x = globals[i]; + if (x != null) { + var visited = new BitSet(); + var worklist = new WorkList(blockSets[x.id].blocks); + var b = worklist.pop(); + while (b != null) { + visited.set(b.bid); + for (BasicBlock d: b.dominationFrontier) { + // insert phi for x in d + d.insertPhiFor(x); + if (!visited.get(d.bid)) + worklist.push(d); + } + b = worklist.pop(); + } + } + } + } + + void renameVars() { + initVersionCounters(); + search(function.entry); + } + + /** + * Creates and pushes new name + */ + Register makeVersion(Register reg) { + int version = counters[reg.id]; + if (version != reg.ssaVersion) + reg = reg.cloneWithVersion(version); + stacks[reg.id].push(reg); + counters[reg.id] = counters[reg.id] + 1; + return reg; + } + + /** + * Recursively walk the Dominator Tree, renaming variables. + * Implementation is based on the algorithm in the Preston Briggs + * paper Practical Improvements to the Construction and Destruction of + * Static Single Assignment Form + */ + void search(BasicBlock block) { + // Replace v = phi(...) with v_i = phi(...) + for (Instruction.Phi phi: block.phis()) { + Register ssaReg = makeVersion(phi.def()); + phi.replaceDef(ssaReg); + } + // for each instruction v = x op y + // first replace x,y + // then replace v + for (Instruction instruction: block.instructions) { + if (instruction instanceof Instruction.Phi) + continue; + // first replace x,y + if (instruction.usesVars()) { + var uses = instruction.uses(); + Register[] newUses = new Register[uses.size()]; + for (int i = 0; i < newUses.length; i++) { + Register oldReg = uses.get(i); + newUses[i] = stacks[oldReg.id].top(); + instruction.replaceUses(newUses); + } + } + // then replace v + if (instruction.definesVar()) { + Register ssaReg = makeVersion(instruction.def()); + instruction.replaceDef(ssaReg); + } + } + // Update phis in successor blocks + for (BasicBlock s: block.successors) { + int j = whichPred(s,block); + for (Instruction.Phi phi: s.phis()) { + Register oldReg = phi.inputs.get(j).reg; + phi.replaceInput(j, stacks[oldReg.id].top()); + } + } + // Recurse down the dominator tree + for (BasicBlock c: block.dominatedChildren) { + search(c); + } + // Pop stacks for defs + for (Instruction i: block.instructions) { + if (i.definesVar()) { + var reg = i.def(); + stacks[reg.id].pop(); + } + } + } + + private int whichPred(BasicBlock s, BasicBlock block) { + int i = 0; + for (BasicBlock p: s.predecessors) { + if (p == block) + return i; + i++; + } + throw new IllegalStateException(); + } + + private void initVersionCounters() { + counters = new int[globals.length]; + stacks = new VersionStack[globals.length]; + for (int i = 0; i < globals.length; i++) { + counters[i] = 0; + stacks[i] = new VersionStack(); + } + } + + static class BBSet { + Set blocks = new HashSet<>(); + void add(BasicBlock block) { blocks.add(block); } + } + + static class VersionStack { + List stack = new ArrayList<>(); + void push(Register r) { stack.add(r); } + Register top() { return stack.getLast(); } + void pop() { stack.removeLast(); } + } + + /** + * Simple worklist + */ + public static class WorkList { + + private ArrayList blocks; + private final BitSet members; + + WorkList() { + blocks = new ArrayList<>(); + members = new BitSet(); + } + WorkList(Collection blocks) { + this(); + addAll(blocks); + } + public BasicBlock push( BasicBlock x ) { + if( x==null ) return null; + int idx = x.bid; + if( !members.get(idx) ) { + members.set(idx); + blocks.add(x); + } + return x; + } + public void addAll( Collection ary ) { + for( BasicBlock n : ary ) + push(n); + } + BasicBlock pop() { + if ( blocks.isEmpty() ) + return null; + var x = blocks.removeFirst(); + members.clear(x.bid); + return x; + } + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java new file mode 100644 index 0000000..dbfcac9 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java @@ -0,0 +1,12 @@ +package com.compilerprogramming.ezlang.interpreter; + +public class ExecutionStack { + + public Value[] stack; + public int sp; + + public ExecutionStack(int maxStackSize) { + this.stack = new Value[maxStackSize]; + this.sp = -1; + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java new file mode 100644 index 0000000..3e48b71 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -0,0 +1,256 @@ +package com.compilerprogramming.ezlang.interpreter; + +import com.compilerprogramming.ezlang.bytecode.BasicBlock; +import com.compilerprogramming.ezlang.bytecode.BytecodeFunction; +import com.compilerprogramming.ezlang.bytecode.Instruction; +import com.compilerprogramming.ezlang.bytecode.Operand; +import com.compilerprogramming.ezlang.exceptions.CompilerException; +import com.compilerprogramming.ezlang.exceptions.InterpreterException; +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; +import com.compilerprogramming.ezlang.types.TypeDictionary; + +public class Interpreter { + + TypeDictionary typeDictionary; + + public Interpreter(TypeDictionary typeDictionary) { + this.typeDictionary = typeDictionary; + } + + public Value run(String functionName) { + Symbol symbol = typeDictionary.lookup(functionName); + if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) { + Frame frame = new Frame(functionSymbol); + ExecutionStack execStack = new ExecutionStack(1024); + return interpret(execStack, frame); + } + else { + throw new InterpreterException("Unknown function: " + functionName); + } + } + + public Value interpret(ExecutionStack execStack, Frame frame) { + BytecodeFunction currentFunction = frame.bytecodeFunction; + BasicBlock currentBlock = currentFunction.entry; + int ip = -1; + int base = frame.base; + boolean done = false; + Value returnValue = null; + + while (!done) { + Instruction instruction; + + ip++; + instruction = currentBlock.instructions.get(ip); + switch (instruction) { + case Instruction.Return returnInst -> { + if (returnInst.from instanceof Operand.ConstantOperand constantOperand) { + execStack.stack[base] = new Value.IntegerValue(constantOperand.value); + } + else if (returnInst.from instanceof Operand.RegisterOperand registerOperand) { + execStack.stack[base] = execStack.stack[base+registerOperand.slot()]; + } + else throw new IllegalStateException(); + returnValue = execStack.stack[base]; + } + case Instruction.Move moveInst -> { + if (moveInst.to instanceof Operand.RegisterOperand toReg) { + if (moveInst.from instanceof Operand.RegisterOperand fromReg) { + execStack.stack[base + toReg.slot()] = execStack.stack[base + fromReg.slot()]; + } + else if (moveInst.from instanceof Operand.ConstantOperand constantOperand) { + execStack.stack[base + toReg.slot()] = new Value.IntegerValue(constantOperand.value); + } + else throw new IllegalStateException(); + } + else throw new IllegalStateException(); + } + case Instruction.Jump jumpInst -> { + currentBlock = jumpInst.jumpTo; + ip = -1; + if (currentBlock == currentFunction.exit) + done = true; + } + case Instruction.ConditionalBranch cbrInst -> { + boolean condition; + if (cbrInst.condition instanceof Operand.RegisterOperand registerOperand) { + Value value = execStack.stack[base + registerOperand.slot()]; + if (value instanceof Value.IntegerValue integerValue) { + condition = integerValue.value != 0; + } + else { + condition = value != null; + } + } + else if (cbrInst.condition instanceof Operand.ConstantOperand constantOperand) { + condition = constantOperand.value != 0; + } + else throw new IllegalStateException(); + if (condition) + currentBlock = cbrInst.trueBlock; + else + currentBlock = cbrInst.falseBlock; + ip = -1; + if (currentBlock == currentFunction.exit) + done = true; + } + case Instruction.Call callInst -> { + // Copy args to new frame + int baseReg = base+currentFunction.frameSize(); + int reg = baseReg; + for (Operand.RegisterOperand arg: callInst.args) { + execStack.stack[base + reg] = execStack.stack[base + arg.slot()]; + reg += 1; + } + // Call function + Frame newFrame = new Frame(frame, baseReg, callInst.callee); + interpret(execStack, newFrame); + // Copy return value in expected location + if (!(callInst.callee.returnType instanceof Type.TypeVoid)) { + execStack.stack[base + callInst.returnOperand.slot()] = execStack.stack[baseReg]; + } + } + case Instruction.Unary unaryInst -> { + // We don't expect constant here because we fold constants in unary expressions + Operand.RegisterOperand unaryOperand = (Operand.RegisterOperand) unaryInst.operand; + Value unaryValue = execStack.stack[base + unaryOperand.slot()]; + if (unaryValue instanceof Value.IntegerValue integerValue) { + switch (unaryInst.unop) { + case "-": execStack.stack[base + unaryInst.result.slot()] = new Value.IntegerValue(-integerValue.value); break; + // Maybe below we should explicitly set Int + case "!": execStack.stack[base + unaryInst.result.slot()] = new Value.IntegerValue(integerValue.value==0?1:0); break; + default: throw new CompilerException("Invalid unary op"); + } + } + else + throw new IllegalStateException("Unexpected unary operand: " + unaryOperand); + } + case Instruction.Binary binaryInst -> { + long x, y; + long value = 0; + if (binaryInst.left instanceof Operand.ConstantOperand constant) + x = constant.value; + else if (binaryInst.left instanceof Operand.RegisterOperand registerOperand) + x = ((Value.IntegerValue) execStack.stack[base + registerOperand.slot()]).value; + else throw new IllegalStateException(); + if (binaryInst.right instanceof Operand.ConstantOperand constant) + y = constant.value; + else if (binaryInst.right instanceof Operand.RegisterOperand registerOperand) + y = ((Value.IntegerValue) execStack.stack[base + registerOperand.slot()]).value; + else throw new IllegalStateException(); + switch (binaryInst.binOp) { + case "+": value = x + y; break; + case "-": value = x - y; break; + case "*": value = x * y; break; + case "/": value = x / y; break; + case "%": value = x % y; break; + case "==": value = x == y ? 1 : 0; break; + case "!=": value = x != y ? 1 : 0; break; + case "<": value = x < y ? 1: 0; break; + case ">": value = x > y ? 1 : 0; break; + case "<=": value = x <= y ? 1 : 0; break; + case ">=": value = x <= y ? 1 : 0; break; + default: throw new IllegalStateException(); + } + execStack.stack[base + binaryInst.result.slot()] = new Value.IntegerValue(value); + } + case Instruction.NewArray newArrayInst -> { + execStack.stack[base + newArrayInst.destOperand.slot()] = new Value.ArrayValue(newArrayInst.type); + } + case Instruction.NewStruct newStructInst -> { + execStack.stack[base + newStructInst.destOperand.slot()] = new Value.StructValue(newStructInst.type); + } + case Instruction.AStoreAppend arrayAppendInst -> { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayAppendInst.array.slot()]; + if (arrayAppendInst.value instanceof Operand.ConstantOperand constant) { + arrayValue.values.add(new Value.IntegerValue(constant.value)); + } + else if (arrayAppendInst.value instanceof Operand.RegisterOperand registerOperand) { + arrayValue.values.add(execStack.stack[base + registerOperand.slot()]); + } + else throw new IllegalStateException(); + } + case Instruction.ArrayStore arrayStoreInst -> { + if (arrayStoreInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.slot()]; + int index = 0; + if (arrayStoreInst.indexOperand instanceof Operand.ConstantOperand constant) { + index = (int) constant.value; + } + else if (arrayStoreInst.indexOperand instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue indexValue = (Value.IntegerValue) execStack.stack[base + registerOperand.slot()]; + index = (int) indexValue.value; + } + else throw new IllegalStateException(); + Value value; + if (arrayStoreInst.sourceOperand instanceof Operand.ConstantOperand constantOperand) { + value = new Value.IntegerValue(constantOperand.value); + } + else if (arrayStoreInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) { + value = execStack.stack[base + registerOperand.slot()]; + } + else throw new IllegalStateException(); + arrayValue.values.set(index, value); + } else throw new IllegalStateException(); + } + case Instruction.ArrayLoad arrayLoadInst -> { + if (arrayLoadInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.slot()]; + if (arrayLoadInst.indexOperand instanceof Operand.ConstantOperand constant) { + execStack.stack[base + arrayLoadInst.destOperand.slot()] = arrayValue.values.get((int) constant.value); + } + else if (arrayLoadInst.indexOperand instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue index = (Value.IntegerValue) execStack.stack[base + registerOperand.slot()]; + execStack.stack[base + arrayLoadInst.destOperand.slot()] = arrayValue.values.get((int) index.value); + } + else throw new IllegalStateException(); + } else throw new IllegalStateException(); + } + case Instruction.SetField setFieldInst -> { + if (setFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) { + Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.slot()]; + int index = setFieldInst.fieldIndex; + Value value; + if (setFieldInst.sourceOperand instanceof Operand.ConstantOperand constant) { + value = new Value.IntegerValue(constant.value); + } + else if (setFieldInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) { + value = execStack.stack[base + registerOperand.slot()]; + } + else throw new IllegalStateException(); + structValue.fields[index] = value; + } else throw new IllegalStateException(); + } + case Instruction.GetField getFieldInst -> { + if (getFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) { + Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.slot()]; + int index = getFieldInst.fieldIndex; + execStack.stack[base + getFieldInst.destOperand.slot()] = structValue.fields[index]; + } else throw new IllegalStateException(); + } + case Instruction.ArgInstruction argInst -> {} + default -> throw new IllegalStateException("Unexpected value: " + instruction); + } + } + return returnValue; + } + + static class Frame { + Frame caller; + int base; + BytecodeFunction bytecodeFunction; + + public Frame(Symbol.FunctionTypeSymbol functionSymbol) { + this.caller = null; + this.base = 0; + this.bytecodeFunction = (BytecodeFunction) functionSymbol.code(); + } + + Frame(Frame caller, int base, Type.TypeFunction functionType) { + this.caller = caller; + this.base = base; + this.bytecodeFunction = (BytecodeFunction) functionType.code; + } + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java new file mode 100644 index 0000000..e03fb80 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java @@ -0,0 +1,30 @@ +package com.compilerprogramming.ezlang.interpreter; + +import com.compilerprogramming.ezlang.types.Type; + +import java.util.ArrayList; + +public class Value { + static public class IntegerValue extends Value { + public IntegerValue(long value) { + this.value = value; + } + public final long value; + } + static public class ArrayValue extends Value { + public final Type.TypeArray arrayType; + public final ArrayList values; + public ArrayValue(Type.TypeArray arrayType) { + this.arrayType = arrayType; + values = new ArrayList<>(); + } + } + static public class StructValue extends Value { + public final Type.TypeStruct structType; + public final Value[] fields; + public StructValue(Type.TypeStruct structType) { + this.structType = structType; + this.fields = new Value[structType.numFields()]; + } + } +} diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java new file mode 100644 index 0000000..4401d4b --- /dev/null +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java @@ -0,0 +1,619 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.lexer.Lexer; +import com.compilerprogramming.ezlang.parser.Parser; +import com.compilerprogramming.ezlang.semantic.SemaAssignTypes; +import com.compilerprogramming.ezlang.semantic.SemaDefineTypes; +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.TypeDictionary; +import org.junit.Assert; +import org.junit.Test; + +import java.util.BitSet; + +public class TestCompiler { + + String compileSrc(String src) { + Parser parser = new Parser(); + var program = parser.parse(new Lexer(src)); + var typeDict = new TypeDictionary(); + var sema = new SemaDefineTypes(typeDict); + sema.analyze(program); + var sema2 = new SemaAssignTypes(typeDict); + sema2.analyze(program); + var byteCodeCompiler = new BytecodeCompiler(); + byteCodeCompiler.compile(typeDict); + StringBuilder sb = new StringBuilder(); + for (Symbol s : typeDict.bindings.values()) { + if (s instanceof Symbol.FunctionTypeSymbol f) { + var functionBuilder = (BytecodeFunction) f.code(); + BasicBlock.toStr(sb, functionBuilder.entry, new BitSet()); + } + } + return sb.toString(); + } + + @Test + public void testFunction1() { + String src = """ + func foo(n: Int)->Int { + return 1; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %ret = 1 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction2() { + String src = """ + func foo(n: Int)->Int { + return -1; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %ret = -1 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction3() { + String src = """ + func foo(n: Int)->Int { + return n; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %ret = n + goto L1 +L1: +""", result); + } + + @Test + public void testFunction4() { + String src = """ + func foo(n: Int)->Int { + return -n; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %t2 = -n + %ret = %t2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction5() { + String src = """ + func foo(n: Int)->Int { + return n+1; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %t2 = n+1 + %ret = %t2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction6() { + String src = """ + func foo(n: Int)->Int { + return 1+1; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %ret = 2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction7() { + String src = """ + func foo(n: Int)->Int { + return 1+1-1; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %ret = 1 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction8() { + String src = """ + func foo(n: Int)->Int { + return 2==2; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %ret = 1 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction9() { + String src = """ + func foo(n: Int)->Int { + return 1!=1; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %ret = 0 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction10() { + String src = """ + func foo(n: [Int])->Int { + return n[0]; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %t2 = n[0] + %ret = %t2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction11() { + String src = """ + func foo(n: [Int])->Int { + return n[0]+n[1]; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %t2 = n[0] + %t3 = n[1] + %t4 = %t2+%t3 + %ret = %t4 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction12() { + String src = """ + func foo()->[Int] { + return new [Int] { 1, 2, 3 }; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + %t1 = New([Int,Int]) + %t1.append(1) + %t1.append(2) + %t1.append(3) + %ret = %t1 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction13() { + String src = """ + func foo(n: Int) -> [Int] { + return new [Int] { n }; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + %t2 = New([Int,Int]) + %t2.append(n) + %ret = %t2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction14() { + String src = """ + func add(x: Int, y: Int) -> Int { + return x+y; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg x + arg y + %t3 = x+y + %ret = %t3 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction15() { + String src = """ + struct Person + { + var age: Int + var children: Int + } + func foo(p: Person) -> Person { + p.age = 10; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg p + p.age = 10 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction16() { + String src = """ + struct Person + { + var age: Int + var children: Int + } + func foo() -> Person { + return new Person { age=10, children=0 }; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + %t1 = New(Person) + %t1.age = 10 + %t1.children = 0 + %ret = %t1 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction17() { + String src = """ + func foo(array: [Int]) { + array[0] = 1 + array[1] = 2 + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg array + array[0] = 1 + array[1] = 2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction18() { + String src = """ + func min(x: Int, y: Int) -> Int { + if (x < y) + return x; + return y; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg x + arg y + %t3 = x 0) { + n = n - 1; + } + return; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg n + goto L2 +L2: + %t2 = n>0 + if %t2 goto L3 else goto L4 +L3: + %t3 = n-1 + n = %t3 + goto L2 +L4: + goto L1 +L1: +""", result); + } + + @Test + public void testFunction22() { + String src = """ + func foo() {} + func bar() { foo(); } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + goto L1 +L1: +L0: + call foo + goto L1 +L1: +""", result); + } + + @Test + public void testFunction23() { + String src = """ + func foo(x: Int, y: Int) {} + func bar() { foo(1,2); } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg x + arg y + goto L1 +L1: +L0: + %t1 = 1 + %t2 = 2 + call foo params %t1, %t2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction24() { + String src = """ + func foo(x: Int, y: Int)->Int { return x+y; } + func bar()->Int { var t = foo(1,2); return t+1; } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg x + arg y + %t3 = x+y + %ret = %t3 + goto L1 +L1: +L0: + %t2 = 1 + %t3 = 2 + %t4 = call foo params %t2, %t3 + t = %t4 + %t5 = t+1 + %ret = %t5 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction25() { + String src = """ + struct Person + { + var age: Int + var children: Int + } + func foo(p: Person) -> Int { + return p.age; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg p + %t2 = p.age + %ret = %t2 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction26() { + String src = """ + struct Person + { + var age: Int + var parent: Person + } + func foo(p: Person) -> Int { + return p.parent.age; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg p + %t2 = p.parent + %t3 = %t2.age + %ret = %t3 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction27() { + String src = """ + struct Person + { + var age: Int + var parent: Person + } + func foo(p: [Person], i: Int) -> Int { + return p[i].parent.age; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg p + arg i + %t3 = p[i] + %t4 = %t3.parent + %t5 = %t4.age + %ret = %t5 + goto L1 +L1: +""", result); + } + + @Test + public void testFunction28() { + String src = """ + func foo(x: Int, y: Int)->Int { return x+y; } + func bar(a: Int)->Int { var t = foo(a,2); return t+1; } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + arg x + arg y + %t3 = x+y + %ret = %t3 + goto L1 +L1: +L0: + arg a + %t3 = a + %t4 = 2 + %t5 = call foo params %t3, %t4 + t = %t5 + %t6 = t+1 + %ret = %t6 + goto L1 +L1: +""", result); + } +} diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestDominators.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestDominators.java new file mode 100644 index 0000000..293bd72 --- /dev/null +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestDominators.java @@ -0,0 +1,94 @@ +package com.compilerprogramming.ezlang.bytecode; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +public class TestDominators { + BasicBlock add(List nodes, BasicBlock node) { + nodes.add(node); + return node; + } + + BasicBlock makeGraph(List nodes) { + BasicBlock r0 = add(nodes, new BasicBlock(1)); + BasicBlock r1 = add(nodes, new BasicBlock(2, r0)); + BasicBlock r2 = add(nodes, new BasicBlock(3, r1)); + BasicBlock r3 = add(nodes, new BasicBlock(4, r2)); + BasicBlock r4 = add(nodes, new BasicBlock(5, r3)); + BasicBlock r5 = add(nodes, new BasicBlock(6, r1)); + BasicBlock r6 = add(nodes, new BasicBlock(7, r5)); + BasicBlock r7 = add(nodes, new BasicBlock(8, r6)); + BasicBlock r8 = add(nodes, new BasicBlock(9, r5)); + r8.addSuccessor(r7); + r7.addSuccessor(r3); + r3.addSuccessor(r1); + return r0; + } + + @Test + public void testDominatorTree() { + List nodes = new ArrayList<>(); + BasicBlock root = makeGraph(nodes); + DominatorTree tree = new DominatorTree(root); + System.out.println(tree.generateDotOutput()); + long[] expectedIdoms = {0,1,1,2,2,4,2,6,6,6}; + for (BasicBlock n: nodes) { + Assert.assertEquals(expectedIdoms[(int)n.bid], n.idom.bid); + } + } + + BasicBlock makeGraph2(List nodes) { + + BasicBlock r1 = add(nodes, new BasicBlock(1)); + BasicBlock r2 = add(nodes, new BasicBlock(2, r1)); + BasicBlock r3 = add(nodes, new BasicBlock(3, r2)); + BasicBlock r4 = add(nodes, new BasicBlock(4, r2)); + BasicBlock r5 = add(nodes, new BasicBlock(5, r4)); + BasicBlock r6 = add(nodes, new BasicBlock(6, r4)); + BasicBlock r7 = add(nodes, new BasicBlock(7, r5, r6)); + BasicBlock r8 = add(nodes, new BasicBlock(8, r5)); + BasicBlock r9 = add(nodes, new BasicBlock(9, r8)); + BasicBlock r10 = add(nodes, new BasicBlock(10, r9)); + BasicBlock r11 = add(nodes, new BasicBlock(11, r7)); + BasicBlock r12 = add(nodes, new BasicBlock(12, r10, r11)); + + r3.addSuccessor(r2); + r4.addSuccessor(r2); + r10.addSuccessor(r5); + r9.addSuccessor(r8); + return r1; + } + + public String generateDotOutput(List nodes) { + StringBuilder sb = new StringBuilder(); + sb.append("digraph g {\n"); + for (BasicBlock n: nodes) + sb.append(n.uniqueName()).append(";\n"); + for (BasicBlock n: nodes) { + for (BasicBlock use: n.successors) { + sb.append(n.uniqueName()).append("->").append(use.uniqueName()).append(";\n"); + } + } + sb.append("}\n"); + return sb.toString(); + } + + @Test + public void testLoopNests() { + List nodes = new ArrayList<>(); + BasicBlock root = makeGraph2(nodes); + System.out.println(generateDotOutput(nodes)); + DominatorTree tree = new DominatorTree(root); + List loopNests = LoopFinder.findLoops(nodes); + Assert.assertEquals(2, loopNests.get(0)._loopHead.bid); + Assert.assertEquals(2, loopNests.get(1)._loopHead.bid); + Assert.assertEquals(5, loopNests.get(2)._loopHead.bid); + Assert.assertEquals(8, loopNests.get(3)._loopHead.bid); + List loops = LoopFinder.mergeLoopsWithSameHead(loopNests); + return; + } + +} diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java new file mode 100644 index 0000000..a2fdd69 --- /dev/null +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java @@ -0,0 +1,127 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.interpreter.Interpreter; +import com.compilerprogramming.ezlang.interpreter.Value; +import com.compilerprogramming.ezlang.lexer.Lexer; +import com.compilerprogramming.ezlang.parser.Parser; +import com.compilerprogramming.ezlang.semantic.SemaAssignTypes; +import com.compilerprogramming.ezlang.semantic.SemaDefineTypes; +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.TypeDictionary; +import org.junit.Assert; +import org.junit.Test; + +import java.util.BitSet; + +public class TestInterpreter { + + Value compileAndRun(String src, String mainFunction) { + Parser parser = new Parser(); + var program = parser.parse(new Lexer(src)); + var typeDict = new TypeDictionary(); + var sema = new SemaDefineTypes(typeDict); + sema.analyze(program); + var sema2 = new SemaAssignTypes(typeDict); + sema2.analyze(program); + var byteCodeCompiler = new BytecodeCompiler(); + byteCodeCompiler.compile(typeDict); + StringBuilder sb = new StringBuilder(); + for (Symbol s : typeDict.bindings.values()) { + if (s instanceof Symbol.FunctionTypeSymbol f) { + var functionBuilder = (BytecodeFunction) f.code(); + BasicBlock.toStr(sb, functionBuilder.entry, new BitSet()); + } + } + System.out.println(sb.toString()); + var interpreter = new Interpreter(typeDict); + return interpreter.run(mainFunction); + } + + @Test + public void testFunction1() { + String src = """ + func foo()->Int { + return 42; + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 42); + } + + @Test + public void testFunction2() { + String src = """ + func bar()->Int { + return 42; + } + func foo()->Int { + return bar(); + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 42); + } + + @Test + public void testFunction3() { + String src = """ + func negate(n: Int)->Int { + return -n; + } + func foo()->Int { + return negate(42); + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == -42); + } + + @Test + public void testFunction4() { + String src = """ + func foo(x: Int, y: Int)->Int { return x+y; } + func bar()->Int { var t = foo(1,2); return t+1; } + """; + var value = compileAndRun(src, "bar"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 4); + } + + @Test + public void testFunction5() { + String src = """ + func bar()->Int { var t = new [Int] {1,21,3}; return t[1]; } + """; + var value = compileAndRun(src, "bar"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 21); + } + + @Test + public void testFunction6() { + String src = """ + struct Test + { + var field: Int + } + func foo()->Test + { + var test = new Test{ field = 42 } + return test + } + func bar()->Int { return foo().field } + """; + var value = compileAndRun(src, "bar"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 42); + } +} diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestSSATransform.java new file mode 100644 index 0000000..70b82fe --- /dev/null +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestSSATransform.java @@ -0,0 +1,419 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.lexer.Lexer; +import com.compilerprogramming.ezlang.parser.Parser; +import com.compilerprogramming.ezlang.semantic.SemaAssignTypes; +import com.compilerprogramming.ezlang.semantic.SemaDefineTypes; +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.TypeDictionary; +import org.junit.Assert; +import org.junit.Test; + +import java.util.BitSet; + +public class TestSSATransform { + + String compileSrc(String src) { + Parser parser = new Parser(); + var program = parser.parse(new Lexer(src)); + var typeDict = new TypeDictionary(); + var sema = new SemaDefineTypes(typeDict); + sema.analyze(program); + var sema2 = new SemaAssignTypes(typeDict); + sema2.analyze(program); + var byteCodeCompiler = new BytecodeCompiler(); + byteCodeCompiler.compile(typeDict); + StringBuilder sb = new StringBuilder(); + for (Symbol s : typeDict.bindings.values()) { + if (s instanceof Symbol.FunctionTypeSymbol f) { + var functionBuilder = (BytecodeFunction) f.code(); + sb.append("func ").append(f.name).append("\n"); + sb.append("Before SSA\n"); + sb.append("==========\n"); + BasicBlock.toStr(sb, functionBuilder.entry, new BitSet()); + new SSATransform(functionBuilder); + sb.append("After SSA\n"); + sb.append("=========\n"); + BasicBlock.toStr(sb, functionBuilder.entry, new BitSet()); + } + } + return sb.toString(); + } + + @Test + public void test1() { + String src = """ + func foo(d: Int) { + var a = 42; + var b = a; + var c = a + b; + a = c + 23; + c = a + d; + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func foo +Before SSA +========== +L0: + arg d + a = 42 + b = a + %t5 = a+b + c = %t5 + %t6 = c+23 + a = %t6 + %t7 = a+d + c = %t7 + goto L1 +L1: +After SSA +========= +L0: + arg d + a = 42 + b = a + %t5 = a+b + c = %t5 + %t6 = c+23 + a_1 = %t6 + %t7 = a_1+d + c_1 = %t7 + goto L1 +L1: +""", result); + + } + @Test + public void test2() { + String src = """ + func foo(d: Int)->Int { + var a = 42 + if (d) + { + a = a + 1 + } + else + { + a = a - 1 + } + return a + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func foo +Before SSA +========== +L0: + arg d + a = 42 + if d goto L2 else goto L3 +L2: + %t3 = a+1 + a = %t3 + goto L4 +L4: + %ret = a + goto L1 +L1: +L3: + %t4 = a-1 + a = %t4 + goto L4 +After SSA +========= +L0: + arg d + a = 42 + if d goto L2 else goto L3 +L2: + %t3 = a+1 + a_2 = %t3 + goto L4 +L4: + a_3 = phi(a_2, a_1) + %ret = a_3 + goto L1 +L1: +L3: + %t4 = a-1 + a_1 = %t4 + goto L4 +""", result); + + } + @Test + public void test3() { + String src = """ + func factorial(num: Int)->Int { + var result = 1 + while (num > 1) + { + result = result * num + num = num - 1 + } + return result + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func factorial +Before SSA +========== +L0: + arg num + result = 1 + goto L2 +L2: + %t3 = num>1 + if %t3 goto L3 else goto L4 +L3: + %t4 = result*num + result = %t4 + %t5 = num-1 + num = %t5 + goto L2 +L4: + %ret = result + goto L1 +L1: +After SSA +========= +L0: + arg num + result = 1 + goto L2 +L2: + result_1 = phi(result, result_2) + num_1 = phi(num, num_2) + %t3 = num_1>1 + if %t3 goto L3 else goto L4 +L3: + %t4 = result_1*num_1 + result_2 = %t4 + %t5 = num_1-1 + num_2 = %t5 + goto L2 +L4: + %ret = result_1 + goto L1 +L1: +""", result); + + } + + @Test + public void test4() { + String src = """ + func print(a: Int, b: Int, c:Int, d:Int) {} + func example14_66(p: Int, q: Int, r: Int, s: Int, t: Int) { + var i = 1 + var j = 1 + var k = 1 + var l = 1 + while (1) { + if (p) { + j = i + if (q) { + l = 2 + } + else { + l = 3 + } + k = k + 1 + } + else { + k = k + 2 + } + print(i,j,k,l) + while (1) { + if (r) { + l = l + 4 + } + if (!s) + break + } + i = i + 6 + if (!t) + break + } + } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func print +Before SSA +========== +L0: + arg a + arg b + arg c + arg d + goto L1 +L1: +After SSA +========= +L0: + arg a + arg b + arg c + arg d + goto L1 +L1: +func example14_66 +Before SSA +========== +L0: + arg p + arg q + arg r + arg s + arg t + i = 1 + j = 1 + k = 1 + l = 1 + goto L2 +L2: + if 1 goto L3 else goto L4 +L3: + if p goto L5 else goto L6 +L5: + j = i + if q goto L8 else goto L9 +L8: + l = 2 + goto L10 +L10: + %t10 = k+1 + k = %t10 + goto L7 +L7: + %t12 = i + %t13 = j + %t14 = k + %t15 = l + call print params %t12, %t13, %t14, %t15 + goto L11 +L11: + if 1 goto L12 else goto L13 +L12: + if r goto L14 else goto L15 +L14: + %t16 = l+4 + l = %t16 + goto L15 +L15: + %t17 = !s + if %t17 goto L16 else goto L17 +L16: + goto L13 +L13: + %t18 = i+6 + i = %t18 + %t19 = !t + if %t19 goto L18 else goto L19 +L18: + goto L4 +L4: + goto L1 +L1: +L19: + goto L2 +L17: + goto L11 +L9: + l = 3 + goto L10 +L6: + %t11 = k+2 + k = %t11 + goto L7 +After SSA +========= +L0: + arg p + arg q + arg r + arg s + arg t + i = 1 + j = 1 + k = 1 + l = 1 + goto L2 +L2: + l_1 = phi(l, l_9) + k_1 = phi(k, k_4) + j_1 = phi(j, j_3) + i_1 = phi(i, i_2) + if 1 goto L3 else goto L4 +L3: + if p goto L5 else goto L6 +L5: + j_2 = i_1 + if q goto L8 else goto L9 +L8: + l_3 = 2 + goto L10 +L10: + l_4 = phi(l_3, l_2) + %t10 = k_1+1 + k_3 = %t10 + goto L7 +L7: + l_5 = phi(l_4, l_1) + k_4 = phi(k_3, k_2) + j_3 = phi(j_2, j_1) + %t12 = i_1 + %t13 = j_3 + %t14 = k_4 + %t15 = l_5 + call print params %t12, %t13, %t14, %t15 + goto L11 +L11: + l_6 = phi(l_5, l_8) + if 1 goto L12 else goto L13 +L12: + if r goto L14 else goto L15 +L14: + %t16 = l_6+4 + l_7 = %t16 + goto L15 +L15: + l_8 = phi(l_6, l_7) + %t17 = !s + if %t17 goto L16 else goto L17 +L16: + goto L13 +L13: + l_9 = phi(l_6, l_8) + %t18 = i_1+6 + i_2 = %t18 + %t19 = !t + if %t19 goto L18 else goto L19 +L18: + goto L4 +L4: + l_10 = phi(l_1, l_9) + k_5 = phi(k_1, k_4) + j_4 = phi(j_1, j_3) + i_3 = phi(i_1, i_2) + goto L1 +L1: +L19: + goto L2 +L17: + goto L11 +L9: + l_2 = 3 + goto L10 +L6: + %t11 = k_1+2 + k_2 = %t11 + goto L7 +""", result); + } +} diff --git a/pom.xml b/pom.xml index b972fb6..4ea1d08 100644 --- a/pom.xml +++ b/pom.xml @@ -27,6 +27,7 @@ semantic stackvm registervm + optvm diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java index 0513127..a83a5f8 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java @@ -2,6 +2,7 @@ import com.compilerprogramming.ezlang.exceptions.CompilerException; import com.compilerprogramming.ezlang.parser.AST; +import com.compilerprogramming.ezlang.types.Register; import com.compilerprogramming.ezlang.types.Scope; import com.compilerprogramming.ezlang.types.Symbol; import com.compilerprogramming.ezlang.types.Type; @@ -59,7 +60,7 @@ private void setVirtualRegisters(Scope scope) { reg = scope.parent.maxReg; for (Symbol symbol: scope.getLocalSymbols()) { if (symbol instanceof Symbol.VarSymbol varSymbol) { - varSymbol.reg = reg++; + varSymbol.reg = new Register(reg++, 0, varSymbol.name, varSymbol.type); } } scope.maxReg = reg; @@ -145,7 +146,7 @@ private void compileAssign(AST.AssignStmt assignStmt) { codeIndexedStore(); else if (assignStmt.lhs instanceof AST.NameExpr symbolExpr) { Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol; - code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(varSymbol.reg, varSymbol.name))); + code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(varSymbol.reg.slot, varSymbol.name))); } else throw new CompilerException("Invalid assignment expression: " + assignStmt.lhs); @@ -240,7 +241,7 @@ private void compileLet(AST.VarStmt letStmt) { boolean indexed = compileExpr(letStmt.expr); if (indexed) codeIndexedLoad(); - code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(letStmt.symbol.reg, letStmt.symbol.name))); + code(new Instruction.Move(pop(), new Operand.LocalRegisterOperand(letStmt.symbol.reg.slot, letStmt.symbol.name))); } } @@ -402,7 +403,7 @@ private boolean compileSymbolExpr(AST.NameExpr symbolExpr) { pushOperand(new Operand.LocalFunctionOperand(functionType)); else { Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol; - pushLocal(varSymbol.reg, varSymbol.name); + pushLocal(varSymbol.reg.slot, varSymbol.name); } return false; } diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java index c51dcdf..1c126b1 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java @@ -41,7 +41,7 @@ public ASTVisitor visit(AST.FuncDecl funcDecl, boolean enter) { throw new CompilerException("Symbol " + funcDecl.name + " is already declared"); } // Create function scope, that houses function parameters - currentScope = new Scope(currentScope); + currentScope = new Scope(currentScope, true); funcDecl.scope = currentScope; // Install a symbol for the function, // type is not fully formed at this stage @@ -116,7 +116,7 @@ else if (varDecl.varType == AST.VarType.FUNCTION_PARAMETER if (currentScope.localLookup(varDecl.name) != null) throw new CompilerException("Parameter " + varDecl.name + " is already declared"); Type.TypeFunction type = (Type.TypeFunction) currentFuncDecl.symbol.type; - varDecl.symbol = currentScope.install(varDecl.name, new Symbol.VarSymbol(varDecl.name, varDecl.typeExpr.type)); + varDecl.symbol = currentScope.install(varDecl.name, new Symbol.ParameterSymbol(varDecl.name, varDecl.typeExpr.type)); type.addArg(varDecl.symbol); } else if (varDecl.varType == AST.VarType.VARIABLE) { diff --git a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java index ea26f5f..f7d6668 100644 --- a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java +++ b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java @@ -2,6 +2,7 @@ import com.compilerprogramming.ezlang.exceptions.CompilerException; import com.compilerprogramming.ezlang.parser.AST; +import com.compilerprogramming.ezlang.types.Register; import com.compilerprogramming.ezlang.types.Scope; import com.compilerprogramming.ezlang.types.Symbol; import com.compilerprogramming.ezlang.types.Type; @@ -32,7 +33,7 @@ private void setVirtualRegisters(Scope scope) { reg = scope.parent.maxReg; for (Symbol symbol: scope.getLocalSymbols()) { if (symbol instanceof Symbol.VarSymbol varSymbol) { - varSymbol.reg = reg++; + varSymbol.reg = new Register(reg, 0, varSymbol.name, varSymbol.type); } } scope.maxReg = reg; @@ -112,7 +113,7 @@ private void compileAssign(AST.AssignStmt assignStmt) { code(new Instruction.StoreIndexed()); else if (assignStmt.lhs instanceof AST.NameExpr symbolExpr) { Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol; - code(new Instruction.Store(varSymbol.reg)); + code(new Instruction.Store(varSymbol.reg.slot)); } else throw new CompilerException("Invalid assignment expression: " + assignStmt.lhs); @@ -203,7 +204,7 @@ private void compileLet(AST.VarStmt letStmt) { boolean indexed = compileExpr(letStmt.expr); if (indexed) code(new Instruction.LoadIndexed()); - code(new Instruction.Store(letStmt.symbol.reg)); + code(new Instruction.Store(letStmt.symbol.reg.slot)); } } @@ -322,7 +323,7 @@ private boolean compileSymbolExpr(AST.NameExpr symbolExpr) { code(new Instruction.LoadFunction(functionType)); else { Symbol.VarSymbol varSymbol = (Symbol.VarSymbol) symbolExpr.symbol; - code(new Instruction.LoadVar(varSymbol.reg)); + code(new Instruction.LoadVar(varSymbol.reg.slot)); } return false; } diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Register.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Register.java new file mode 100644 index 0000000..e65e03a --- /dev/null +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Register.java @@ -0,0 +1,44 @@ +package com.compilerprogramming.ezlang.types; + +import java.util.Objects; + +public class Register { + /** + * Slot in the function's frame + */ + public final int slot; + /** + * Unique virtual ID + */ + public final int id; + public final int ssaVersion; + public final String name; + public final Type type; + + public Register(int slot, int id, String name, Type type) { + this(slot, id, name, type, 0); + } + public Register(int slot, int id, String name, Type type, int ssaVersion) { + this.slot = slot; + this.id = id; + this.name = name; + this.type = type; + this.ssaVersion = ssaVersion; + } + public Register cloneWithVersion(int ssaVersion) { + return new Register(this.slot, this.id, this.name+"_"+ssaVersion, this.type, ssaVersion); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Register register = (Register) o; + return id == register.id && ssaVersion == register.ssaVersion; + } + + @Override + public int hashCode() { + return Objects.hash(id, ssaVersion); + } +} diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java index e6d5b5a..70b6b2d 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Scope.java @@ -13,13 +13,19 @@ public class Scope { // values assigned by compiler public int maxReg; + public final boolean isFunctionParameterScope; - public Scope(Scope parent) { + public Scope(Scope parent, boolean isFunctionParameterScope) { this.parent = parent; + this.isFunctionParameterScope = isFunctionParameterScope; if (parent != null) parent.children.add(this); } + public Scope(Scope parent) { + this(parent, false); + } + public Symbol lookup(String name) { Symbol symbol = bindings.get(name); if (symbol == null && parent != null) diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java index 0252f7b..dc1de20 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java @@ -35,9 +35,15 @@ public Object code() { public static class VarSymbol extends Symbol { // Values assigned by bytecode compiler - public int reg; + public Register reg; public VarSymbol(String name, Type type) { super(name, type); } } + + public static class ParameterSymbol extends VarSymbol { + public ParameterSymbol(String name, Type type) { + super(name, type); + } + } }