diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index 418db05..06c42b5 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -460,13 +460,10 @@ private boolean codeBoolean(AST.BinaryExpr binaryExpr) { jumpTo(l3); startBlock(l2); // Below we must write to the same temp - //code(new Instruction.Move(new Operand.ConstantOperand(isAnd ? 0 : 1, typeDictionary.INT), new Operand.TempRegisterOperand(temp.reg))); code(new Instruction.Move(new Operand.ConstantOperand(isAnd ? 0 : 1, typeDictionary.INT), temp)); jumpTo(l3); startBlock(l3); // leave temp on virtual stack -// var temp2 = (Operand.TempRegisterOperand) pop(); -// pushOperand(new Operand.TempRegisterOperand(temp2.reg)); return false; } @@ -483,10 +480,19 @@ private boolean compileBinaryExpr(AST.BinaryExpr binaryExpr) { indexed = compileExpr(binaryExpr.expr2); if (indexed) codeIndexedLoad(); - opCode = binaryExpr.op.str; Operand right = pop(); Operand left = pop(); - if (left instanceof Operand.ConstantOperand leftconstant && + if (left instanceof Operand.NullConstantOperand && + right instanceof Operand.NullConstantOperand) { + long value = 0; + switch (opCode) { + case "==": value = 1; break; + case "!=": value = 0; break; + default: throw new CompilerException("Invalid binary op"); + } + pushConstant(value, typeDictionary.INT); + } + else if (left instanceof Operand.ConstantOperand leftconstant && right instanceof Operand.ConstantOperand rightconstant) { long value = 0; switch (opCode) { @@ -535,7 +541,11 @@ private boolean compileUnaryExpr(AST.UnaryExpr unaryExpr) { } private boolean compileConstantExpr(AST.LiteralExpr constantExpr) { - pushConstant(constantExpr.value.num.intValue(), constantExpr.type); + if (constantExpr.type instanceof Type.TypeInteger) + pushConstant(constantExpr.value.num.intValue(), constantExpr.type); + else if (constantExpr.type instanceof Type.TypeNull) + pushNullConstant(constantExpr.type); + else throw new CompilerException("Invalid constant type"); return false; } @@ -543,6 +553,10 @@ private void pushConstant(long value, Type type) { pushOperand(new Operand.ConstantOperand(value, type)); } + private void pushNullConstant(Type type) { + pushOperand(new Operand.NullConstantOperand(type)); + } + private Operand.TempRegisterOperand createTemp(Type type) { var tempRegister = new Operand.TempRegisterOperand(registerPool.newTempReg(type)); pushOperand(tempRegister); @@ -552,8 +566,8 @@ private Operand.TempRegisterOperand createTemp(Type type) { Type typeOfOperand(Operand operand) { if (operand instanceof Operand.ConstantOperand constant) return constant.type; -// else if (operand instanceof Operand.NullConstantOperand nullConstantOperand) -// return nullConstantOperand.type; + else if (operand instanceof Operand.NullConstantOperand nullConstantOperand) + return nullConstantOperand.type; else if (operand instanceof Operand.RegisterOperand registerOperand) return registerOperand.type; else throw new CompilerException("Invalid operand"); @@ -569,7 +583,7 @@ private Operand.TempRegisterOperand createTempAndMove(Operand src) { private Operand.RegisterOperand ensureTemp() { Operand top = top(); if (top instanceof Operand.ConstantOperand - //|| top instanceof Operand.NullConstantOperand + || top instanceof Operand.NullConstantOperand || top instanceof Operand.LocalRegisterOperand) { return createTempAndMove(pop()); } else if (top instanceof Operand.IndexedOperand) { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java index 1714f6b..1749d49 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java @@ -65,7 +65,7 @@ public void rename(Integer source, Integer target) { var toSet = edges.get(target); if (toSet == null) { //throw new RuntimeException("Cannot find edge " + target + " from " + source); - return; // FIXME this is workaround to handle sceanrio where target is arg register but we need a better way + return; // FIXME this is workaround to handle scenario where target is arg register but we need a better way } toSet.addAll(fromSet); // If any node interfered with from diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java index b4bde07..57a691e 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java @@ -18,6 +18,16 @@ public String toString() { } } + public static class NullConstantOperand extends Operand { + public NullConstantOperand(Type type) { + this.type = type; + } + @Override + public String toString() { + return "null"; + } + } + public static class RegisterOperand extends Operand { final Register reg; protected RegisterOperand(Register reg) { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java index 73a04fc..c83ee31 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java @@ -377,35 +377,40 @@ private boolean evalInstruction(Instruction instruction) { } case Instruction.Binary binaryInst -> { var cell = valueLattice.get(binaryInst.result().reg); - LatticeElement left, right; + LatticeElement left = null; + LatticeElement right = null; + // TODO we cannot yet evaluate null in comparisons if (binaryInst.left() instanceof Operand.ConstantOperand constant) left = new LatticeElement(V_CONSTANT, constant.value); else if (binaryInst.left() instanceof Operand.RegisterOperand registerOperand) left = valueLattice.get(registerOperand.reg); - else throw new IllegalStateException(); if (binaryInst.right() instanceof Operand.ConstantOperand constant) right = new LatticeElement(V_CONSTANT, constant.value); else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) right = valueLattice.get(registerOperand.reg); - else throw new IllegalStateException(); - switch (binaryInst.binOp) { - case "+": - case "-": - case "*": - case "/": - case "%": - changed = evalArith(cell, left, right, binaryInst.binOp); - break; - case "==": - case "!=": - case "<": - case ">": - case "<=": - case ">=": - changed = evalLogical(cell, left, right, binaryInst.binOp); - break; - default: - throw new IllegalStateException(); + if (left != null && right != null) { + switch (binaryInst.binOp) { + case "+": + case "-": + case "*": + case "/": + case "%": + changed = evalArith(cell, left, right, binaryInst.binOp); + break; + case "==": + case "!=": + case "<": + case ">": + case "<=": + case ">=": + changed = evalLogical(cell, left, right, binaryInst.binOp); + break; + default: + throw new IllegalStateException(); + } + } + else { + cell.setKind(V_VARYING); } } case Instruction.NewArray newArrayInst -> { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java index 6aec1b8..d59e1f0 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -48,6 +48,9 @@ public Value interpret(ExecutionStack execStack, Frame frame) { if (retInst.value() instanceof Operand.ConstantOperand constantOperand) { execStack.stack[base] = new Value.IntegerValue(constantOperand.value); } + else if (retInst.value() instanceof Operand.NullConstantOperand) { + execStack.stack[base] = new Value.NullValue(); + } else if (retInst.value() instanceof Operand.RegisterOperand registerOperand) { execStack.stack[base] = execStack.stack[base+registerOperand.frameSlot()]; } @@ -63,6 +66,9 @@ else if (retInst.value() instanceof Operand.RegisterOperand registerOperand) { else if (moveInst.from() instanceof Operand.ConstantOperand constantOperand) { execStack.stack[base + toReg.frameSlot()] = new Value.IntegerValue(constantOperand.value); } + else if (moveInst.from() instanceof Operand.NullConstantOperand) { + execStack.stack[base + toReg.frameSlot()] = new Value.NullValue(); + } else throw new IllegalStateException(); } else throw new IllegalStateException(); @@ -107,6 +113,9 @@ else if (cbrInst.condition() instanceof Operand.ConstantOperand constantOperand) else if (arg instanceof Operand.ConstantOperand constantOperand) { execStack.stack[base + reg] = new Value.IntegerValue(constantOperand.value); } + else if (arg instanceof Operand.NullConstantOperand) { + execStack.stack[base + reg] = new Value.NullValue(); + } reg += 1; } // Call function @@ -135,31 +144,60 @@ else if (arg instanceof Operand.ConstantOperand constantOperand) { case Instruction.Binary binaryInst -> { long x, y; long value = 0; - if (binaryInst.left() instanceof Operand.ConstantOperand constant) - x = constant.value; - else if (binaryInst.left() instanceof Operand.RegisterOperand registerOperand) - x = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; - else throw new IllegalStateException(); - if (binaryInst.right() instanceof Operand.ConstantOperand constant) - y = constant.value; - else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) - y = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; - else throw new IllegalStateException(); - switch (binaryInst.binOp) { - case "+": value = x + y; break; - case "-": value = x - y; break; - case "*": value = x * y; break; - case "/": value = x / y; break; - case "%": value = x % y; break; - case "==": value = x == y ? 1 : 0; break; - case "!=": value = x != y ? 1 : 0; break; - case "<": value = x < y ? 1: 0; break; - case ">": value = x > y ? 1 : 0; break; - case "<=": value = x <= y ? 1 : 0; break; - case ">=": value = x <= y ? 1 : 0; break; - default: throw new IllegalStateException(); + boolean intOp = true; + if (binaryInst.binOp.equals("==") || binaryInst.binOp.equals("!=")) { + Operand.RegisterOperand nonNullLitOperand = null; + if (binaryInst.left() instanceof Operand.NullConstantOperand) { + nonNullLitOperand = (Operand.RegisterOperand)binaryInst.right(); + } + else if (binaryInst.right() instanceof Operand.NullConstantOperand) { + nonNullLitOperand = (Operand.RegisterOperand)binaryInst.left(); + } + if (nonNullLitOperand != null) { + intOp = false; + Value otherValue = execStack.stack[base + nonNullLitOperand.frameSlot()]; + switch (binaryInst.binOp) { + case "==": { + value = otherValue instanceof Value.NullValue ? 1 : 0; + break; + } + case "!=": { + value = otherValue instanceof Value.NullValue ? 0 : 1; + break; + } + default: + throw new IllegalStateException(); + } + execStack.stack[base + binaryInst.result().frameSlot()] = new Value.IntegerValue(value); + } + } + if (intOp) { + if (binaryInst.left() instanceof Operand.ConstantOperand constant) + x = constant.value; + else if (binaryInst.left() instanceof Operand.RegisterOperand registerOperand) + x = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; + else throw new IllegalStateException(); + if (binaryInst.right() instanceof Operand.ConstantOperand constant) + y = constant.value; + else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) + y = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; + else throw new IllegalStateException(); + switch (binaryInst.binOp) { + case "+": value = x + y; break; + case "-": value = x - y; break; + case "*": value = x * y; break; + case "/": value = x / y; break; + case "%": value = x % y; break; + case "==": value = x == y ? 1 : 0; break; + case "!=": value = x != y ? 1 : 0; break; + case "<": value = x < y ? 1: 0; break; + case ">": value = x > y ? 1 : 0; break; + case "<=": value = x <= y ? 1 : 0; break; + case ">=": value = x <= y ? 1 : 0; break; + default: throw new IllegalStateException(); + } + execStack.stack[base + binaryInst.result().frameSlot()] = new Value.IntegerValue(value); } - execStack.stack[base + binaryInst.result().frameSlot()] = new Value.IntegerValue(value); } case Instruction.NewArray newArrayInst -> { execStack.stack[base + newArrayInst.destOperand().frameSlot()] = new Value.ArrayValue(newArrayInst.type); @@ -172,6 +210,9 @@ else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) if (arrayAppendInst.value() instanceof Operand.ConstantOperand constant) { arrayValue.values.add(new Value.IntegerValue(constant.value)); } + else if (arrayAppendInst.value() instanceof Operand.NullConstantOperand) { + arrayValue.values.add(new Value.NullValue()); + } else if (arrayAppendInst.value() instanceof Operand.RegisterOperand registerOperand) { arrayValue.values.add(execStack.stack[base + registerOperand.frameSlot()]); } @@ -193,6 +234,9 @@ else if (arrayStoreInst.indexOperand() instanceof Operand.RegisterOperand regist if (arrayStoreInst.sourceOperand() instanceof Operand.ConstantOperand constantOperand) { value = new Value.IntegerValue(constantOperand.value); } + else if (arrayStoreInst.sourceOperand() instanceof Operand.NullConstantOperand) { + value = new Value.NullValue(); + } else if (arrayStoreInst.sourceOperand() instanceof Operand.RegisterOperand registerOperand) { value = execStack.stack[base + registerOperand.frameSlot()]; } @@ -221,6 +265,9 @@ else if (arrayLoadInst.indexOperand() instanceof Operand.RegisterOperand registe if (setFieldInst.sourceOperand() instanceof Operand.ConstantOperand constant) { value = new Value.IntegerValue(constant.value); } + else if (setFieldInst.sourceOperand() instanceof Operand.NullConstantOperand) { + value = new Value.NullValue(); + } else if (setFieldInst.sourceOperand() instanceof Operand.RegisterOperand registerOperand) { value = execStack.stack[base + registerOperand.frameSlot()]; } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java index e03fb80..7f914a6 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java @@ -11,6 +11,9 @@ public IntegerValue(long value) { } public final long value; } + static public class NullValue extends Value { + public NullValue() {} + } static public class ArrayValue extends Value { public final Type.TypeArray arrayType; public final ArrayList values; diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index 3f8a052..1c4fc5b 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -2525,6 +2525,100 @@ func foo()->Int %t0_0 = 0 %t0_2 = %t0_0 goto L4 +""", result); + } + + @Test + public void testSSA15() { + String src = """ + struct Foo + { + var i: Int + } + func foo()->Int + { + var f = new [Foo?] { new Foo{i = 1}, null } + return null == f[1] && 1 == f[0].i + } +"""; + String result = compileSrc(src); + Assert.assertEquals(""" +func foo +Before SSA +========== +L0: + %t1 = New([Foo?]) + %t2 = New(Foo) + %t2.i = 1 + %t1.append(%t2) + %t1.append(null) + f = %t1 + %t3 = f[1] + %t4 = null==%t3 + if %t4 goto L2 else goto L3 +L2: + %t5 = f[0] + %t6 = %t5.i + %t7 = 1==%t6 + goto L4 +L4: + ret %t7 + goto L1 +L1: +L3: + %t7 = 0 + goto L4 +After SSA +========= +L0: + %t1_0 = New([Foo?]) + %t2_0 = New(Foo) + %t2_0.i = 1 + %t1_0.append(%t2_0) + %t1_0.append(null) + f_0 = %t1_0 + %t3_0 = f_0[1] + %t4_0 = null==%t3_0 + if %t4_0 goto L2 else goto L3 +L2: + %t5_0 = f_0[0] + %t6_0 = %t5_0.i + %t7_1 = 1==%t6_0 + goto L4 +L4: + %t7_2 = phi(%t7_1, %t7_0) + ret %t7_2 + goto L1 +L1: +L3: + %t7_0 = 0 + goto L4 +After exiting SSA +================= +L0: + %t1_0 = New([Foo?]) + %t2_0 = New(Foo) + %t2_0.i = 1 + %t1_0.append(%t2_0) + %t1_0.append(null) + f_0 = %t1_0 + %t3_0 = f_0[1] + %t4_0 = null==%t3_0 + if %t4_0 goto L2 else goto L3 +L2: + %t5_0 = f_0[0] + %t6_0 = %t5_0.i + %t7_1 = 1==%t6_0 + %t7_2 = %t7_1 + goto L4 +L4: + ret %t7_2 + goto L1 +L1: +L3: + %t7_0 = 0 + %t7_2 = %t7_0 + goto L4 """, result); } } \ No newline at end of file diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java index e463a29..4f8c0f0 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java @@ -446,78 +446,78 @@ func foo()->Int { } -// @Test -// public void testFunction100() { -// String src = """ -// struct Test -// { -// var field: Int -// } -// func foo()->Test? -// { -// return null; -// } -// -// """; -// var value = compileAndRun(src, "foo", Options.OPT); -// Assert.assertNotNull(value); -// Assert.assertTrue(value instanceof Value.NullValue); -// } -// -// @Test -// public void testFunction101() { -// String src = """ -// func foo()->Int -// { -// return null == null; -// } -// -// """; -// var value = compileAndRun(src, "foo", Options.OPT); -// Assert.assertNotNull(value); -// Assert.assertTrue(value instanceof Value.IntegerValue integerValue && -// integerValue.value == 1); -// } -// -// @Test -// public void testFunction102() { -// String src = """ -// struct Foo -// { -// var next: Foo? -// } -// func foo()->Int -// { -// var f = new Foo{ next = null } -// return null == f.next -// } -// -// """; -// var value = compileAndRun(src, "foo", Options.OPT); -// Assert.assertNotNull(value); -// Assert.assertTrue(value instanceof Value.IntegerValue integerValue && -// integerValue.value == 1); -// } -// -// @Test -// public void testFunction103() { -// String src = """ -// struct Foo -// { -// var i: Int -// } -// func foo()->Int -// { -// var f = new [Foo?] { new Foo{i = 1}, null } -// return null == f[1] && 1 == f[0].i -// } -// -// """; -// var value = compileAndRun(src, "foo", Options.OPT); -// Assert.assertNotNull(value); -// Assert.assertTrue(value instanceof Value.IntegerValue integerValue && -// integerValue.value == 1); -// } + @Test + public void testFunction100() { + String src = """ + struct Test + { + var field: Int + } + func foo()->Test? + { + return null; + } + + """; + var value = compileAndRun(src, "foo", Options.OPT); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.NullValue); + } + + @Test + public void testFunction101() { + String src = """ + func foo()->Int + { + return null == null; + } + + """; + var value = compileAndRun(src, "foo", Options.OPT); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction102() { + String src = """ + struct Foo + { + var next: Foo? + } + func foo()->Int + { + var f = new Foo{ next = null } + return null == f.next + } + + """; + var value = compileAndRun(src, "foo", Options.OPT); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction103() { + String src = """ + struct Foo + { + var i: Int + } + func foo()->Int + { + var f = new [Foo?] { new Foo{i = 1}, null } + return null == f[1] && 1 == f[0].i + } + + """; + var value = compileAndRun(src, "foo", Options.OPT); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } @Test public void testFunction104() { diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index 56d5f55..9cd5112 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -5,6 +5,7 @@ import com.compilerprogramming.ezlang.types.Scope; import com.compilerprogramming.ezlang.types.Symbol; import com.compilerprogramming.ezlang.types.Type; +import com.compilerprogramming.ezlang.types.TypeDictionary; import java.util.ArrayList; import java.util.List; @@ -19,6 +20,7 @@ public class CompiledFunction { private BasicBlock currentContinueTarget; public int maxLocalReg; public int maxStackSize; + private final TypeDictionary typeDictionary; /** * We essentially do a form of abstract interpretation as we generate @@ -30,7 +32,7 @@ public class CompiledFunction { */ private List virtualStack = new ArrayList<>(); - public CompiledFunction(Symbol.FunctionTypeSymbol functionSymbol) { + public CompiledFunction(Symbol.FunctionTypeSymbol functionSymbol, TypeDictionary typeDictionary) { AST.FuncDecl funcDecl = (AST.FuncDecl) functionSymbol.functionDecl; setVirtualRegisters(funcDecl.scope); this.BID = 0; @@ -38,6 +40,7 @@ public CompiledFunction(Symbol.FunctionTypeSymbol functionSymbol) { this.exit = createBlock(); this.currentBreakTarget = null; this.currentContinueTarget = null; + this.typeDictionary = typeDictionary; compileStatement(funcDecl.block); exitBlockIfNeeded(); } @@ -408,18 +411,58 @@ private boolean compileSymbolExpr(AST.NameExpr symbolExpr) { return false; } + private boolean codeBoolean(AST.BinaryExpr binaryExpr) { + boolean isAnd = binaryExpr.op.str.equals("&&"); + BasicBlock l1 = createBlock(); + BasicBlock l2 = createBlock(); + BasicBlock l3 = createBlock(); + boolean indexed = compileExpr(binaryExpr.expr1); + if (indexed) + codeIndexedLoad(); + if (isAnd) { + code(new Instruction.ConditionalBranch(currentBlock, pop(), l1, l2)); + } else { + code(new Instruction.ConditionalBranch(currentBlock, pop(), l2, l1)); + } + startBlock(l1); + compileExpr(binaryExpr.expr2); + var temp = ensureTemp(); + jumpTo(l3); + startBlock(l2); + // Below we must write to the same temp + code(new Instruction.Move(new Operand.ConstantOperand(isAnd ? 0 : 1, typeDictionary.INT), temp)); + jumpTo(l3); + startBlock(l3); + // leave temp on virtual stack + return false; + } + + private boolean compileBinaryExpr(AST.BinaryExpr binaryExpr) { - String opCode = null; + String opCode = binaryExpr.op.str; + if (opCode.equals("&&") || + opCode.equals("||")) { + return codeBoolean(binaryExpr); + } boolean indexed = compileExpr(binaryExpr.expr1); if (indexed) codeIndexedLoad(); indexed = compileExpr(binaryExpr.expr2); if (indexed) codeIndexedLoad(); - opCode = binaryExpr.op.str; Operand right = pop(); Operand left = pop(); - if (left instanceof Operand.ConstantOperand leftconstant && + if (left instanceof Operand.NullConstantOperand && + right instanceof Operand.NullConstantOperand) { + long value = 0; + switch (opCode) { + case "==": value = 1; break; + case "!=": value = 0; break; + default: throw new CompilerException("Invalid binary op"); + } + pushConstant(value, typeDictionary.INT); + } + else if (left instanceof Operand.ConstantOperand leftconstant && right instanceof Operand.ConstantOperand rightconstant) { long value = 0; switch (opCode) { @@ -468,7 +511,11 @@ private boolean compileUnaryExpr(AST.UnaryExpr unaryExpr) { } private boolean compileConstantExpr(AST.LiteralExpr constantExpr) { - pushConstant(constantExpr.value.num.intValue(), constantExpr.type); + if (constantExpr.type instanceof Type.TypeInteger) + pushConstant(constantExpr.value.num.intValue(), constantExpr.type); + else if (constantExpr.type instanceof Type.TypeNull) + pushNullConstant(constantExpr.type); + else throw new CompilerException("Invalid constant type"); return false; } @@ -476,6 +523,10 @@ private void pushConstant(long value, Type type) { pushOperand(new Operand.ConstantOperand(value, type)); } + private void pushNullConstant(Type type) { + pushOperand(new Operand.NullConstantOperand(type)); + } + private Operand.TempRegisterOperand createTemp(Type type) { var tempRegister = new Operand.TempRegisterOperand(virtualStack.size()+maxLocalReg, type); pushOperand(tempRegister); @@ -484,6 +535,36 @@ private Operand.TempRegisterOperand createTemp(Type type) { return tempRegister; } + Type typeOfOperand(Operand operand) { + if (operand instanceof Operand.ConstantOperand constant) + return constant.type; + else if (operand instanceof Operand.NullConstantOperand nullConstantOperand) + return nullConstantOperand.type; + else if (operand instanceof Operand.RegisterOperand registerOperand) + return registerOperand.type; + else throw new CompilerException("Invalid operand"); + } + + private Operand.TempRegisterOperand createTempAndMove(Operand src) { + Type type = typeOfOperand(src); + var temp = createTemp(type); + code(new Instruction.Move(src, temp)); + return temp; + } + + private Operand.RegisterOperand ensureTemp() { + Operand top = top(); + if (top instanceof Operand.ConstantOperand + || top instanceof Operand.NullConstantOperand + || top instanceof Operand.LocalRegisterOperand) { + return createTempAndMove(pop()); + } else if (top instanceof Operand.IndexedOperand) { + return codeIndexedLoad(); + } else if (top instanceof Operand.TempRegisterOperand tempRegisterOperand) { + return tempRegisterOperand; + } else throw new CompilerException("Cannot convert to temporary register"); + } + private void pushLocal(int regnum, String varName) { pushOperand(new Operand.LocalRegisterOperand(regnum, varName)); } @@ -500,7 +581,7 @@ private Operand top() { return virtualStack.getLast(); } - private void codeIndexedLoad() { + private Operand.TempRegisterOperand codeIndexedLoad() { Operand indexed = pop(); var temp = createTemp(indexed.type); if (indexed instanceof Operand.LoadIndexedOperand loadIndexedOperand) { @@ -511,6 +592,7 @@ else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) { } else code(new Instruction.Move(indexed, temp)); + return temp; } private void codeIndexedStore() { diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java index 5c0b365..37eaea0 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java @@ -16,7 +16,7 @@ private void compile(TypeDictionary typeDictionary) { for (Symbol symbol: typeDictionary.getLocalSymbols()) { if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) { Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type; - functionType.code = new CompiledFunction(functionSymbol); + functionType.code = new CompiledFunction(functionSymbol, typeDictionary); } } } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java index 9cb9a70..54188da 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java @@ -18,6 +18,16 @@ public String toString() { } } + public static class NullConstantOperand extends Operand { + public NullConstantOperand(Type type) { + this.type = type; + } + @Override + public String toString() { + return "null"; + } + } + public static abstract class RegisterOperand extends Operand { public final int regnum; public RegisterOperand(int regnum) { diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java index 7be8ef9..5ceaeaa 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -48,6 +48,9 @@ public Value interpret(ExecutionStack execStack, Frame frame) { if (retInst.value() instanceof Operand.ConstantOperand constantOperand) { execStack.stack[base] = new Value.IntegerValue(constantOperand.value); } + else if (retInst.value() instanceof Operand.NullConstantOperand) { + execStack.stack[base] = new Value.NullValue(); + } else if (retInst.value() instanceof Operand.RegisterOperand registerOperand) { execStack.stack[base] = execStack.stack[base+registerOperand.frameSlot()]; } @@ -63,6 +66,9 @@ else if (retInst.value() instanceof Operand.RegisterOperand registerOperand) { else if (moveInst.from() instanceof Operand.ConstantOperand constantOperand) { execStack.stack[base + toReg.frameSlot()] = new Value.IntegerValue(constantOperand.value); } + else if (moveInst.from() instanceof Operand.NullConstantOperand) { + execStack.stack[base + toReg.frameSlot()] = new Value.NullValue(); + } else throw new IllegalStateException(); } else throw new IllegalStateException(); @@ -107,6 +113,9 @@ else if (cbrInst.condition() instanceof Operand.ConstantOperand constantOperand) else if (arg instanceof Operand.ConstantOperand constantOperand) { execStack.stack[base + reg] = new Value.IntegerValue(constantOperand.value); } + else if (arg instanceof Operand.NullConstantOperand) { + execStack.stack[base + reg] = new Value.NullValue(); + } reg += 1; } // Call function @@ -135,31 +144,60 @@ else if (arg instanceof Operand.ConstantOperand constantOperand) { case Instruction.Binary binaryInst -> { long x, y; long value = 0; - if (binaryInst.left() instanceof Operand.ConstantOperand constant) - x = constant.value; - else if (binaryInst.left() instanceof Operand.RegisterOperand registerOperand) - x = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; - else throw new IllegalStateException(); - if (binaryInst.right() instanceof Operand.ConstantOperand constant) - y = constant.value; - else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) - y = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; - else throw new IllegalStateException(); - switch (binaryInst.binOp) { - case "+": value = x + y; break; - case "-": value = x - y; break; - case "*": value = x * y; break; - case "/": value = x / y; break; - case "%": value = x % y; break; - case "==": value = x == y ? 1 : 0; break; - case "!=": value = x != y ? 1 : 0; break; - case "<": value = x < y ? 1: 0; break; - case ">": value = x > y ? 1 : 0; break; - case "<=": value = x <= y ? 1 : 0; break; - case ">=": value = x <= y ? 1 : 0; break; - default: throw new IllegalStateException(); + boolean intOp = true; + if (binaryInst.binOp.equals("==") || binaryInst.binOp.equals("!=")) { + Operand.RegisterOperand nonNullLitOperand = null; + if (binaryInst.left() instanceof Operand.NullConstantOperand) { + nonNullLitOperand = (Operand.RegisterOperand)binaryInst.right(); + } + else if (binaryInst.right() instanceof Operand.NullConstantOperand) { + nonNullLitOperand = (Operand.RegisterOperand)binaryInst.left(); + } + if (nonNullLitOperand != null) { + intOp = false; + Value otherValue = execStack.stack[base + nonNullLitOperand.frameSlot()]; + switch (binaryInst.binOp) { + case "==": { + value = otherValue instanceof Value.NullValue ? 1 : 0; + break; + } + case "!=": { + value = otherValue instanceof Value.NullValue ? 0 : 1; + break; + } + default: + throw new IllegalStateException(); + } + execStack.stack[base + binaryInst.result().frameSlot()] = new Value.IntegerValue(value); + } + } + if (intOp) { + if (binaryInst.left() instanceof Operand.ConstantOperand constant) + x = constant.value; + else if (binaryInst.left() instanceof Operand.RegisterOperand registerOperand) + x = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; + else throw new IllegalStateException(); + if (binaryInst.right() instanceof Operand.ConstantOperand constant) + y = constant.value; + else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) + y = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; + else throw new IllegalStateException(); + switch (binaryInst.binOp) { + case "+": value = x + y; break; + case "-": value = x - y; break; + case "*": value = x * y; break; + case "/": value = x / y; break; + case "%": value = x % y; break; + case "==": value = x == y ? 1 : 0; break; + case "!=": value = x != y ? 1 : 0; break; + case "<": value = x < y ? 1: 0; break; + case ">": value = x > y ? 1 : 0; break; + case "<=": value = x <= y ? 1 : 0; break; + case ">=": value = x <= y ? 1 : 0; break; + default: throw new IllegalStateException(); + } + execStack.stack[base + binaryInst.result().frameSlot()] = new Value.IntegerValue(value); } - execStack.stack[base + binaryInst.result().frameSlot()] = new Value.IntegerValue(value); } case Instruction.NewArray newArrayInst -> { execStack.stack[base + newArrayInst.destOperand().frameSlot()] = new Value.ArrayValue(newArrayInst.type); @@ -172,6 +210,9 @@ else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) if (arrayAppendInst.value() instanceof Operand.ConstantOperand constant) { arrayValue.values.add(new Value.IntegerValue(constant.value)); } + else if (arrayAppendInst.value() instanceof Operand.NullConstantOperand) { + arrayValue.values.add(new Value.NullValue()); + } else if (arrayAppendInst.value() instanceof Operand.RegisterOperand registerOperand) { arrayValue.values.add(execStack.stack[base + registerOperand.frameSlot()]); } @@ -193,6 +234,9 @@ else if (arrayStoreInst.indexOperand() instanceof Operand.RegisterOperand regist if (arrayStoreInst.sourceOperand() instanceof Operand.ConstantOperand constantOperand) { value = new Value.IntegerValue(constantOperand.value); } + else if (arrayStoreInst.sourceOperand() instanceof Operand.NullConstantOperand) { + value = new Value.NullValue(); + } else if (arrayStoreInst.sourceOperand() instanceof Operand.RegisterOperand registerOperand) { value = execStack.stack[base + registerOperand.frameSlot()]; } @@ -221,6 +265,9 @@ else if (arrayLoadInst.indexOperand() instanceof Operand.RegisterOperand registe if (setFieldInst.sourceOperand() instanceof Operand.ConstantOperand constant) { value = new Value.IntegerValue(constant.value); } + else if (setFieldInst.sourceOperand() instanceof Operand.NullConstantOperand) { + value = new Value.NullValue(); + } else if (setFieldInst.sourceOperand() instanceof Operand.RegisterOperand registerOperand) { value = execStack.stack[base + registerOperand.frameSlot()]; } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java index e03fb80..7f914a6 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java @@ -11,6 +11,9 @@ public IntegerValue(long value) { } public final long value; } + static public class NullValue extends Value { + public NullValue() {} + } static public class ArrayValue extends Value { public final Type.TypeArray arrayType; public final ArrayList values; diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java index 0b0d91f..2deb384 100644 --- a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java @@ -564,4 +564,20 @@ public void testFunction28() { L1: """, result); } + + @Test + public void testFunction29() { + String src = """ + struct Foo { var bar: Int } + func foo()->Foo? { return null; } + """; + String result = compileSrc(src); + Assert.assertEquals(""" +L0: + ret null + goto L1 +L1: +""", result); + } + } diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java b/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java index 4913352..fcb93e2 100644 --- a/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java @@ -102,4 +102,146 @@ func bar()->Int { return foo().field } Assert.assertTrue(value instanceof Value.IntegerValue integerValue && integerValue.value == 42); } + + @Test + public void testFunction100() { + String src = """ + struct Test + { + var field: Int + } + func foo()->Test? + { + return null; + } + + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.NullValue); + } + + @Test + public void testFunction101() { + String src = """ + func foo()->Int + { + return null == null; + } + + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction102() { + String src = """ + struct Foo + { + var next: Foo? + } + func foo()->Int + { + var f = new Foo{ next = null } + return null == f.next + } + + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction103() { + String src = """ + struct Foo + { + var i: Int + } + func foo()->Int + { + var f = new [Foo?] { new Foo{i = 1}, null } + return null == f[1] && 1 == f[0].i + } + + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction104() { + String src = """ + func foo()->Int + { + return 1 == 1 && 2 == 2 + } + + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction105() { + String src = """ + func bar(a: Int, b: Int)->Int + { + return a+1 == b-1 && b / a == 2 + } + func foo()->Int + { + return bar(3,5) + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 0); + } + + @Test + public void testFunction106() { + String src = """ + func bar(a: Int, b: Int)->Int + { + return a+1 == b-1 || b / a == 2 + } + func foo()->Int + { + return bar(3,5) + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } + + @Test + public void testFunction107() { + String src = """ + func bar(a: [Int])->Int + { + return a[0]+a[2] == a[1]-a[2] || a[1] / a[0] == 2 + } + func foo()->Int + { + return bar(new [Int] {3,5,1}) + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue && + integerValue.value == 1); + } } diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java index d996df5..f22c855 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java @@ -77,6 +77,8 @@ public ASTVisitor visit(AST.BinaryExpr binaryExpr, boolean enter) { else if (((binaryExpr.expr1.type instanceof Type.TypeNull && binaryExpr.expr2.type instanceof Type.TypeNullable) || (binaryExpr.expr1.type instanceof Type.TypeNullable && + binaryExpr.expr2.type instanceof Type.TypeNull) || + (binaryExpr.expr1.type instanceof Type.TypeNull && binaryExpr.expr2.type instanceof Type.TypeNull)) && (binaryExpr.op.str.equals("==") || binaryExpr.op.str.equals("!="))) { binaryExpr.type = typeDictionary.INT; @@ -269,7 +271,7 @@ public ASTVisitor visit(AST.ReturnStmt returnStmt, boolean enter) { return this; Type.TypeFunction functionType = (Type.TypeFunction) currentFuncDecl.symbol.type; if (returnStmt.expr != null) { - validType(returnStmt.expr.type, false); + validType(returnStmt.expr.type, true); checkAssignmentCompatible(functionType.returnType, returnStmt.expr.type); } else if (!(functionType.returnType instanceof Type.TypeVoid)) {