diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/OverrideUtils.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/OverrideUtils.java index 7d4b7e04a..c356a9184 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/OverrideUtils.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/OverrideUtils.java @@ -119,8 +119,7 @@ public static void addOverride( ImFunction implementation = JassIm.ImFunction(e, subMethod.getName() + "_wrapper", JassIm.ImTypeVars(), parameters, rType, locals, body, flags); tr.getImProg().getFunctions().add(implementation); - List subMethods = Collections.emptyList(); - ImMethod wrapperMethod = JassIm.ImMethod(e, subMethod.getMethodClass(), subMethod.getName() + "_wrapper", implementation, subMethods, false); + ImMethod wrapperMethod = JassIm.ImMethod(e, subMethod.getMethodClass(), subMethod.getName() + "_wrapper", implementation, JassIm.ImMethods(), false); subClass.getMethods().add(wrapperMethod); superMethodIm.getSubMethods().add(wrapperMethod); } diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/StackTraceInjector2.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/StackTraceInjector2.java index 2162def15..fa516d5ab 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/StackTraceInjector2.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/StackTraceInjector2.java @@ -45,6 +45,8 @@ public void transform(TimeTaker timeTaker) { final Multimap calls = LinkedListMultimap.create(); // called function -> calling function final Multimap callRelation = LinkedListMultimap.create(); + // calling function -> called function + final Multimap directCalls = LinkedListMultimap.create(); final List funcRefs = Lists.newArrayList(); prog.accept(new ImProg.DefaultVisitor() { @@ -72,9 +74,19 @@ public void visit(ImFunctionCall c) { calls.put(c.getFunc(), c); ImFunction caller = c.getNearestFunc(); callRelation.put(c.getFunc(), caller); + directCalls.put(caller, c.getFunc()); } } + @Override + public void visit(ImMethodCall c) { + super.visit(c); + ImFunction caller = c.getNearestFunc(); + ImFunction callee = c.getMethod().getImplementation(); + directCalls.put(caller, callee); + callRelation.put(callee, caller); + } + @Override public void visit(ImFuncRef imFuncRef) { super.visit(imFuncRef); @@ -97,6 +109,7 @@ public void visit(ImFuncRef imFuncRef) { Set affectedFuncs = Sets.newLinkedHashSet(stackTraceGets.keySet()); if (tr.isLuaTarget()) { + Set configOnlyFuncs = getConfigOnlyFunctions(directCalls, callRelation); // in Lua all functions are potentially affected, because we don't know where Lua might crash Stream.concat( prog.getFunctions().stream(), @@ -106,6 +119,7 @@ public void visit(ImFuncRef imFuncRef) { && !f.hasFlag(FunctionFlagEnum.IS_BJ) && !f.hasFlag(FunctionFlagEnum.IS_EXTERN)) .collect(Collectors.toCollection(() -> affectedFuncs)); + affectedFuncs.removeAll(configOnlyFuncs); } else { for (ImFunction stackTraceUse : stackTraceGets.keys()) { @@ -123,6 +137,60 @@ public void visit(ImFuncRef imFuncRef) { } + private Set getFunctionsReachableFrom(String functionName, Multimap directCalls) { + Set reachable = Sets.newLinkedHashSet(); + Deque worklist = new ArrayDeque<>(); + Stream.concat( + prog.getFunctions().stream(), + prog.getClasses().stream().flatMap(c -> c.getFunctions().stream())) + .filter(f -> f.getName().equals(functionName)) + .forEach(worklist::addLast); + while (!worklist.isEmpty()) { + ImFunction current = worklist.removeFirst(); + if (!reachable.add(current)) { + continue; + } + for (ImFunction callee : directCalls.get(current)) { + worklist.addLast(callee); + } + } + return reachable; + } + + private Set getConfigOnlyFunctions(Multimap directCalls, + Multimap reverseCalls) { + Set configReachable = getFunctionsReachableFrom("config", directCalls); + if (configReachable.isEmpty()) { + return Collections.emptySet(); + } + + Set runtimeSharedFuncs = Sets.newLinkedHashSet(); + Deque worklist = new ArrayDeque<>(); + for (ImFunction f : configReachable) { + boolean hasNonConfigCaller = reverseCalls.get(f).stream() + .anyMatch(caller -> !configReachable.contains(caller)); + if (hasNonConfigCaller) { + worklist.addLast(f); + } + } + + while (!worklist.isEmpty()) { + ImFunction current = worklist.removeFirst(); + if (!runtimeSharedFuncs.add(current)) { + continue; + } + for (ImFunction callee : directCalls.get(current)) { + if (configReachable.contains(callee)) { + worklist.addLast(callee); + } + } + } + + Set configOnlyFuncs = Sets.newLinkedHashSet(configReachable); + configOnlyFuncs.removeAll(runtimeSharedFuncs); + return configOnlyFuncs; + } + private void rewriteMethodCalls(Set affectedFuncs) { if (!tr.isLuaTarget()) { return; diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/ExprTranslation.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/ExprTranslation.java index 76843e035..8ec95ee4f 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/ExprTranslation.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/ExprTranslation.java @@ -175,6 +175,17 @@ public static LuaExpr translate(ImMemberAccess e, LuaTranslator tr) { } public static LuaExpr translate(ImMethodCall e, LuaTranslator tr) { + ImMethod method = e.getMethod(); + if (!method.getIsAbstract() + && method.getImplementation() != null + && method.getSubMethods().isEmpty()) { + LuaExprlist args = LuaAst.LuaExprlist(); + args.add(e.getReceiver().translateToLua(tr)); + for (ImExpr arg : e.getArguments()) { + args.add(arg.translateToLua(tr)); + } + return LuaAst.LuaExprFunctionCall(tr.luaFunc.getFor(method.getImplementation()), args); + } return LuaAst.LuaExprMethodCall(e.getReceiver().translateToLua(tr), tr.luaMethod.getFor(e.getMethod()), tr.translateExprList(e.getArguments())); } diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java index d8f6b5edd..eef1e8778 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java @@ -2,12 +2,7 @@ import de.peeeq.datastructures.UnionFind; import de.peeeq.wurstscript.WLogger; -import de.peeeq.wurstscript.ast.ClassDef; -import de.peeeq.wurstscript.ast.Element; -import de.peeeq.wurstscript.ast.FuncDef; -import de.peeeq.wurstscript.ast.NameDef; -import de.peeeq.wurstscript.ast.WParameter; -import de.peeeq.wurstscript.ast.WPackage; +import de.peeeq.wurstscript.ast.*; import de.peeeq.wurstscript.jassIm.*; import de.peeeq.wurstscript.luaAst.*; import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagEnum; @@ -220,7 +215,7 @@ public LuaMethod initFor(ImClass a) { private final Lazy errorFunc = Lazy.create(() -> this.getProg().getFunctions().stream() .flatMap(f -> { - Element trace = f.attrTrace(); + de.peeeq.wurstscript.ast.Element trace = f.attrTrace(); if (trace instanceof FuncDef) { FuncDef fd = (FuncDef) trace; if (fd.getName().equals("error") @@ -438,7 +433,7 @@ private void collectPredefinedNames() { } private void setNameFromTrace(JassImElementWithName named) { - Element trace = named.attrTrace(); + de.peeeq.wurstscript.ast.Element trace = named.attrTrace(); if (trace instanceof NameDef) { named.setName(((NameDef) trace).getName()); } @@ -459,9 +454,7 @@ private void normalizeMethodNames() { // give all related methods the same name in deterministic order List> groups = new ArrayList<>(); for (Set group : methodUnions.groups().values()) { - List sortedGroup = new ArrayList<>(group); - sortedGroup.sort(Comparator.comparing(this::methodSortKey)); - groups.add(sortedGroup); + groups.addAll(partitionMethodsByDispatchSignature(group)); } groups.sort(Comparator.comparing(g -> g.isEmpty() ? "" : methodSortKey(g.get(0)))); for (List group : groups) { @@ -1026,7 +1019,7 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod // local new_inst = { ... } LuaTableFields initialFieldValues = LuaAst.LuaTableFields(); LuaVariable newInst = LuaAst.LuaVariable("new_inst", LuaAst.LuaTableConstructor(initialFieldValues)); - for (ImVar field : c.getFields()) { + for (ImVar field : collectFieldsForAllocation(c)) { initialFieldValues.add( LuaAst.LuaTableNamedField(field.getName(), defaultValue(field.getType())) ); @@ -1044,6 +1037,25 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod body.add(LuaAst.LuaReturn(LuaAst.LuaExprVarAccess(newInst))); } + private List collectFieldsForAllocation(ImClass c) { + List result = new ArrayList<>(); + Set visited = new HashSet<>(); + collectFieldsForAllocation(c, result, visited); + return result; + } + + private void collectFieldsForAllocation(ImClass c, List out, Set visited) { + if (!visited.add(c)) { + return; + } + List superClasses = new ArrayList<>(c.getSuperClasses()); + superClasses.sort(Comparator.comparing(sc -> classSortKey(sc.getClassDef()))); + for (ImClassType sc : superClasses) { + collectFieldsForAllocation(sc.getClassDef(), out, visited); + } + out.addAll(c.getFields()); + } + private void initClassTables(ImClass c) { LuaVariable classVar = luaClassVar.getFor(c); // create methods: @@ -1087,7 +1099,10 @@ private void createMethods(ImClass c, LuaVariable classVar) { groupedMethods.computeIfAbsent(root, k -> new ArrayList<>()).add(method); } - List> groups = new ArrayList<>(groupedMethods.values()); + List> groups = new ArrayList<>(); + for (List methods : groupedMethods.values()) { + groups.addAll(partitionMethodsByDispatchSignature(methods)); + } groups.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0)))); Map slotToImpl = new TreeMap<>(); for (List groupMethods : groups) { @@ -1122,6 +1137,58 @@ private void createMethods(ImClass c, LuaVariable classVar) { } } + private List> partitionMethodsByDispatchSignature(Collection methods) { + Map> partitions = new TreeMap<>(); + for (ImMethod method : methods) { + partitions.computeIfAbsent(dispatchSignatureKey(method), k -> new ArrayList<>()).add(method); + } + List> result = new ArrayList<>(partitions.values()); + for (List group : result) { + group.sort(Comparator.comparing(this::methodSortKey)); + } + result.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0)))); + return result; + } + + private String dispatchSignatureKey(ImMethod method) { + ImFunction implementation = resolveDispatchSignatureImplementation(method, new HashSet<>()); + if (implementation == null) { + return ""; + } + StringBuilder sb = new StringBuilder(); + sb.append(typeKey(implementation.getReturnType())).append("|"); + ImVars params = implementation.getParameters(); + for (int i = 1; i < params.size(); i++) { + if (i > 1) { + sb.append(","); + } + sb.append(typeKey(params.get(i).getType())); + } + return sb.toString(); + } + + private ImFunction resolveDispatchSignatureImplementation(ImMethod method, Set visited) { + if (method == null || !visited.add(method)) { + return null; + } + if (method.getImplementation() != null) { + return method.getImplementation(); + } + List subMethods = new ArrayList<>(method.getSubMethods()); + subMethods.sort(Comparator.comparing(this::methodSortKey)); + for (ImMethod subMethod : subMethods) { + ImFunction resolved = resolveDispatchSignatureImplementation(subMethod, visited); + if (resolved != null) { + return resolved; + } + } + return null; + } + + private String typeKey(ImType type) { + return type == null ? "" : type.toString(); + } + private Set collectDispatchSlotNames(ImClass receiverClass, List groupMethods) { Set slotNames = new TreeSet<>(); Set semanticNames = new TreeSet<>(); @@ -1141,6 +1208,12 @@ private Set collectDispatchSlotNames(ImClass receiverClass, List classNames = new TreeSet<>(); @@ -1327,6 +1400,32 @@ private String semanticNameFromMethodName(String methodName) { return methodName; } + private boolean isClosureGeneratedClass(ImClass c) { + return c != null && c.attrTrace() instanceof ExprClosure; + } + + private String sourceSemanticName(ImMethod method) { + if (method == null) { + return ""; + } + de.peeeq.wurstscript.ast.Element trace = method.attrTrace(); + if (trace instanceof FuncDef funcDef) { + return funcDef.getName(); + } + if (trace instanceof AstElementWithFuncName withFuncName) { + return withFuncName.getFuncNameId().getName(); + } + if (method.getImplementation() != null) { + String implementationName = method.getImplementation().getName(); + int firstUnderscore = implementationName.indexOf('_'); + if (firstUnderscore > 0) { + return implementationName.substring(0, firstUnderscore); + } + return implementationName; + } + return ""; + } + private String classSortKey(ImClass c) { if (c == null) { return ""; @@ -1471,7 +1570,7 @@ public LuaFunction getErrorFunc() { } public String getTypeCastingFunctionName(ImFunction f) { - Element trace = f.attrTrace(); + de.peeeq.wurstscript.ast.Element trace = f.attrTrace(); if (trace instanceof FuncDef fd && fd.attrNearestPackage() instanceof WPackage p) { if ("TypeCasting".equals(p.getName())) { return fd.getName(); diff --git a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java index 9e0832362..9366d263c 100644 --- a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java +++ b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java @@ -67,6 +67,15 @@ private void assertFunctionBodyContains(String output, String functionName, Stri assertTrue("Function " + functionName + " was not found.", found); } + private String getFunctionBody(String output, String functionName) { + Pattern pattern = Pattern.compile("function\\s*" + functionName + "\\s*\\(.*\\).*\\n" + "((?:\\n|.)*?)end"); + Matcher matcher = pattern.matcher(output); + if (!matcher.find()) { + fail("Function " + functionName + " was not found."); + } + return matcher.group(1); + } + private void assertDoesNotContainRegex(String output, String regex) { Pattern pattern = Pattern.compile(regex); Matcher matcher = pattern.matcher(output); @@ -87,6 +96,38 @@ private void assertOccursBefore(String output, String first, String second) { assertTrue("Expected '" + first + "' before '" + second + "'.", firstPos < secondPos); } + private List uniqueMatches(String output, String regex, int group) { + Matcher matcher = Pattern.compile(regex).matcher(output); + List result = new ArrayList<>(); + while (matcher.find()) { + String value = matcher.group(group); + if (!result.contains(value)) { + result.add(value); + } + } + return result; + } + + private String singleMatch(String output, String regex, int group) { + Matcher matcher = Pattern.compile(regex).matcher(output); + assertTrue("Expected pattern to occur: " + regex, matcher.find()); + String result = matcher.group(group); + assertFalse("Expected exactly one match for pattern: " + regex, matcher.find()); + return result; + } + + private List nonBaseSubclassBindings(String output, String baseName, String slotName) { + Matcher matcher = Pattern.compile("([A-Za-z0-9_]+)\\." + Pattern.quote(slotName) + "\\s*=\\s*[A-Za-z0-9_]+").matcher(output); + List result = new ArrayList<>(); + while (matcher.find()) { + String owner = matcher.group(1); + if (!owner.equals(baseName) && owner.startsWith(baseName + "_") && !result.contains(owner)) { + result.add(owner); + } + } + return result; + } + private String compileLuaWithRunArgs(String testName, boolean withStdLib, String... lines) { RunArgs runArgs = new RunArgs().with("-lua", "-inline", "-localOptimizations", "-stacktraces"); WurstGui gui = new WurstGuiCliImpl(); @@ -382,7 +423,7 @@ public void lazyGenericClosureDispatchWorksInLua() throws IOException { ); String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_lazyGenericClosureDispatchWorksInLua.lua"), Charsets.UTF_8); assertTrue(compiled.contains("Lazy_lazy_Test.Lazy_retrieve =")); - assertTrue(compiled.contains("l:Lazy_get()")); + assertTrue(compiled.contains("Lazy_Lazy_get(l)") || compiled.contains("l:Lazy_get()")); } @Test @@ -435,6 +476,433 @@ public void overloadedMethodsDoNotAliasInLuaDispatchTables() throws IOException assertFalse(compiled.contains("Writer.Writer_write = Writer_Writer_write1")); } + @Test + public void overloadedOverrideDispatchDoesNotCollapseLuaSlots() throws IOException { + test().testLua(true).lines( + "package DispatchOverloadBug", + "class Base", + " function doThing(int a)", + " this.doThing(a, 0)", + " function doThing(int a, int b)", + " this.doThing(a, b, false)", + " function doThing(int a, int b, boolean flag)", + " skip", + "class Child extends Base", + " override function doThing(int a, int b)", + " super.doThing(a, b)", + "init", + " let c = new Child()", + " c.doThing(1)" + ); + + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_overloadedOverrideDispatchDoesNotCollapseLuaSlots.lua"), Charsets.UTF_8); + Matcher slotMatcher = Pattern.compile("Child\\.(Base_doThing\\d*)\\s*=\\s*Child_Child_doThing").matcher(compiled); + List overriddenSlots = new ArrayList<>(); + while (slotMatcher.find()) { + overriddenSlots.add(slotMatcher.group(1)); + } + assertEquals("Expected Child override to bind exactly one dispatch slot.", 1, overriddenSlots.size()); + + Matcher baseSlotsMatcher = Pattern.compile("Base\\.(Base_doThing\\d*)\\s*=").matcher(compiled); + List baseSlots = new ArrayList<>(); + while (baseSlotsMatcher.find()) { + String slot = baseSlotsMatcher.group(1); + if (!baseSlots.contains(slot)) { + baseSlots.add(slot); + } + } + assertEquals("Expected three distinct Base overload dispatch slots.", 3, baseSlots.size()); + assertTrue(compiled.contains("this:Base_doThing1(a, 0)")); + assertTrue(compiled.contains("Base_Base_doThing2(this1, a1, b, false)")); + assertTrue(compiled.contains("Child.Base_doThing1 = Child_Child_doThing")); + assertTrue(compiled.contains("Child.Base_doThing2 = Base_Base_doThing2")); + } + + @Test + public void moduleProvidedOverloadedOverrideDoesNotCollapseLuaSlots() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_moduleProvidedOverloadedOverrideDoesNotCollapseLuaSlots", + false, + "package Test", + "module M", + " function setup(int a)", + " this.setup(a, 0)", + " function setup(int a, int b)", + " this.setup(a, b, false)", + " function setup(int a, int b, boolean c)", + " skip", + "class Base", + " use M", + "class Child extends Base", + " override function setup(int a, int b)", + " super.setup(a, b)", + "init", + " let c = new Child()", + " c.setup(1)" + ); + + List overriddenSlots = uniqueMatches(compiled, "Child\\.Base(?:_M)?_setup(\\d*)\\s*=\\s*Child_Child_setup", 1); + List baseSlots = uniqueMatches(compiled, "Base\\.Base(?:_M)?_setup(\\d*)\\s*=", 1); + + assertEquals("Expected exactly one overridden setup overload family from module-provided methods.", 1, overriddenSlots.size()); + assertEquals("Expected three distinct setup slots on Base.", 3, baseSlots.size()); + assertTrue(compiled.contains(":Base_M_setup1(") || compiled.contains(":Base_setup1(")); + assertContainsRegex(compiled, "Child\\.Base(?:_M)?_setup" + Pattern.quote(overriddenSlots.get(0)) + "\\s*=\\s*Child_Child_setup"); + } + + @Test + public void multiLevelOverloadedOverridesKeepDistinctLuaSlots() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_multiLevelOverloadedOverridesKeepDistinctLuaSlots", + false, + "package Test", + "class Base", + " function setup(int a)", + " this.setup(a, 0)", + " function setup(int a, int b)", + " this.setup(a, b, false)", + " function setup(int a, int b, boolean c)", + " skip", + "class Mid extends Base", + " override function setup(int a, int b)", + " super.setup(a, b)", + "class Child extends Mid", + " override function setup(int a, int b, boolean c)", + " super.setup(a, b, c)", + "init", + " let c = new Child()", + " c.setup(1)" + ); + + List midSlots = uniqueMatches(compiled, "Mid\\.(Base_setup\\d*)\\s*=\\s*Mid_Mid_setup", 1); + List childSlots = uniqueMatches(compiled, "Child\\.(Base_setup\\d*)\\s*=\\s*Child_Child_setup", 1); + List baseSlots = uniqueMatches(compiled, "Base\\.(Base_setup\\d*)\\s*=", 1); + + assertEquals("Expected Mid to override exactly one base setup slot.", 1, midSlots.size()); + assertEquals("Expected Child to override exactly one base setup slot.", 1, childSlots.size()); + assertEquals("Expected three distinct setup slots across hierarchy.", 3, baseSlots.size()); + assertFalse("Expected different overload slots for Mid and Child overrides.", midSlots.get(0).equals(childSlots.get(0))); + assertTrue(compiled.contains("Child." + midSlots.get(0) + " = Mid_Mid_setup")); + assertTrue(compiled.contains("Child." + childSlots.get(0) + " = Child_Child_setup")); + } + + @Test + public void generatedOverloadDispatchMatrixKeepsDistinctLuaSlots() { + for (int overloadCount = 3; overloadCount <= 5; overloadCount++) { + for (int overrideIndex = 2; overrideIndex < overloadCount; overrideIndex++) { + String testName = "LuaTranslationTests_generatedOverloadDispatchMatrixKeepsDistinctLuaSlots_" + + overloadCount + "_" + overrideIndex; + String compiled = compileLuaWithRunArgs( + testName, + false, + generatedOverloadDispatchMatrixLines(overloadCount, overrideIndex) + ); + + List baseSlots = uniqueMatches(compiled, "Base\\.(Base_route\\d*)\\s*=", 1); + List overriddenSlots = uniqueMatches(compiled, "Child\\.(Base_route\\d*)\\s*=\\s*Child_Child_route", 1); + + assertEquals("Expected one overridden route slot for overloadCount=" + overloadCount + ", overrideIndex=" + overrideIndex, + 1, overriddenSlots.size()); + assertEquals("Expected distinct route slots to match overload count for overloadCount=" + overloadCount, + overloadCount, baseSlots.size()); + assertTrue("Missing child binding for overridden slot " + overriddenSlots.get(0), + compiled.contains("Child." + overriddenSlots.get(0) + " = Child_Child_route")); + + for (int stubIndex = Math.max(overrideIndex - 1, 1); stubIndex < overloadCount - 1; stubIndex++) { + String functionName = stubIndex == 0 ? "Base_Base_route" : "Base_Base_route" + stubIndex; + String nextFunctionName = "Base_Base_route" + (stubIndex + 1); + String body = getFunctionBody(compiled, functionName); + assertTrue("Expected " + functionName + " to directly call " + nextFunctionName, + body.contains(nextFunctionName + "(")); + assertFalse("Expected " + functionName + " to avoid virtual route dispatch in stub body.", + body.contains(":Base_route(")); + assertFalse("Expected " + functionName + " to avoid virtual route dispatch to any numbered route slot.", + body.contains(":Base_route1(") + || body.contains(":Base_route2(") + || body.contains(":Base_route3(") + || body.contains(":Base_route4(")); + } + } + } + } + + @Test + public void middleOverloadStubDirectCallsFinalImplInLua() { + test().testLua(true).lines( + "package Test", + "class Base", + " function setup(int a, int b, int c, int d, int e, int f)", + " skip", + " function setup(int a, int b, int c, int d, int e, int f, int g)", + " this.setup(a, b, c, d, e, f, g, 0)", + " function setup(int a, int b, int c, int d, int e, int f, int g, int h)", + " this.setup(a, b, c, d, e, f, g, h, 0)", + " function setup(int a, int b, int c, int d, int e, int f, int g, int h, int i)", + " skip", + "class Child extends Base", + " override function setup(int a, int b, int c, int d, int e, int f, int g, int h)", + " super.setup(a, b, c, d, e, f, g, h)", + "init", + " let c = new Child()", + " c.setup(1, 2, 3, 4, 5, 6, 7)" + ); + + try { + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_middleOverloadStubDirectCallsFinalImplInLua.lua"), Charsets.UTF_8); + String middleBody = getFunctionBody(compiled, "Base_Base_setup1"); + assertTrue(middleBody.contains("Base_Base_setup2(")); + assertFalse(middleBody.contains(":Base_setup(")); + assertFalse(middleBody.contains(":Base_setup1(")); + assertFalse(middleBody.contains(":Base_setup2(")); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Test + public void kuiStyleMiddleOverrideChainUsesDirectFinalImplCallInLua() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_kuiStyleMiddleOverrideChainUsesDirectFinalImplCallInLua", + false, + "package Test", + "class KUIFrame", + " function setup(int a, int b, int c, int d, int e, int f, int g)", + " skip", + "class KUIWindow extends KUIFrame", + " function setup(int a, int b, int c, int d, int e, int f, int g, int h)", + " this.setup(a, b, c, d, e, f, g, h, 0)", + " function setup(int a, int b, int c, int d, int e, int f, int g, int h, int i)", + " this.setup(a, b, c, d, e, f, g, h, i, 0)", + " function setup(int a, int b, int c, int d, int e, int f, int g, int h, int i, int j)", + " skip", + "class KUITargetUnit extends KUIWindow", + " override function setup(int a, int b, int c, int d, int e, int f, int g, int h, int i)", + " super.setup(a, b, c, d, e, f, g, h, i)", + "init", + " let x = new KUITargetUnit()", + " x.setup(1, 2, 3, 4, 5, 6, 7, 8)" + ); + + String setupBody = getFunctionBody(compiled, "KUIWindow_KUIWindow_setup1"); + assertTrue(setupBody.contains("KUIWindow_KUIWindow_setup2(")); + assertFalse(setupBody.contains(":KUIWindow_setup(")); + assertFalse(setupBody.contains(":KUITargetUnit_setup(")); + } + + @Test + public void closureInterfaceCallsitesUseSubclassBindingsForSameLuaSlot() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_closureInterfaceCallsitesUseSubclassBindingsForSameLuaSlot", + false, + "package Test", + "interface LLItrClosure", + " function run(int t)", + "class LinkedList", + " function forEach(LLItrClosure f)", + " f.run(1)", + "function feed(LinkedList xs)", + " xs.forEach((int t) -> skip)", + " xs.forEach((int t) -> skip)", + "init", + " let xs = new LinkedList()", + " feed(xs)" + ); + + String forEachBody = getFunctionBody(compiled, "LinkedList_LinkedList_forEach"); + String slotName = singleMatch(forEachBody, ":([A-Za-z0-9_]+)\\(", 1); + assertEquals("run", slotName); + List subclassBindings = nonBaseSubclassBindings(compiled, "LLItrClosure", slotName); + + assertTrue("Expected generated closure subclasses for LLItrClosure to bind slot " + slotName, + subclassBindings.size() >= 2); + assertContainsRegex(compiled, "LLItrClosure_[A-Za-z0-9_]+\\." + Pattern.quote(slotName) + "\\s*="); + } + + @Test + public void abstractCallbackFamiliesKeepCallsiteAndSubclassSlotNamesAlignedInLua() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_abstractCallbackFamiliesKeepCallsiteAndSubclassSlotNamesAlignedInLua", + false, + "package Test", + "public abstract class OtherCallback", + " abstract function callback(int elem)", + "class OtherRegistry", + " function applyTo(OtherCallback cb)", + " cb.callback(0)", + "public abstract class ForElementCallback", + " abstract function callback(int elem)", + "class Registry", + " function forEachIn(ForElementCallback cb)", + " cb.callback(1)", + "function runAll(Registry r, OtherRegistry o)", + " r.forEachIn() elem ->", + " skip", + " r.forEachIn() elem ->", + " skip", + " o.applyTo() elem ->", + " skip", + "init", + " runAll(new Registry(), new OtherRegistry())" + ); + + String forEachBody = getFunctionBody(compiled, "Registry_Registry_forEachIn"); + String callbackSlot = singleMatch(forEachBody, ":([A-Za-z0-9_]+)\\(", 1); + assertEquals("callback", callbackSlot); + List subclassBindings = nonBaseSubclassBindings(compiled, "ForElementCallback", callbackSlot); + + assertTrue("Expected generated ForElementCallback closure subclasses to bind slot " + callbackSlot, + subclassBindings.size() >= 2); + assertContainsRegex(compiled, "ForElementCallback_[A-Za-z0-9_]+\\." + Pattern.quote(callbackSlot) + "\\s*="); + + String otherBody = getFunctionBody(compiled, "OtherRegistry_OtherRegistry_applyTo"); + String otherSlot = singleMatch(otherBody, ":([A-Za-z0-9_]+)\\(", 1); + assertTrue(otherSlot.startsWith("callback")); + List otherSubclassBindings = nonBaseSubclassBindings(compiled, "OtherCallback", otherSlot); + + assertTrue("Expected generated OtherCallback closure subclass to bind slot " + otherSlot, + otherSubclassBindings.size() >= 1); + assertContainsRegex(compiled, "OtherCallback_[A-Za-z0-9_]+\\." + Pattern.quote(otherSlot) + "\\s*="); + } + + @Test + public void genericClosureInterfacesKeepPrefixedBaseSlotNamesInLua() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_genericClosureInterfacesKeepPrefixedBaseSlotNamesInLua", + false, + "package Test", + "class Box", + " T elem", + " construct(T elem)", + " this.elem = elem", + "class LinkedList", + " Box stored = null", + " function add(X x)", + " stored = new Box(x)", + " function forEach(LLItrClosure f)", + " if stored != null", + " f.run(stored.elem)", + "public interface LLItrClosure", + " function run(T t)", + "function test()", + " let xs = new LinkedList()", + " xs.add(1)", + " xs.forEach() itr ->", + " skip", + "init", + " test()" + ); + + String forEachBody = getFunctionBody(compiled, "LinkedList_LinkedList_forEach"); + String slotName = singleMatch(forEachBody, ":([A-Za-z0-9_]+)\\(", 1); + assertEquals("LLItrClosure_run", slotName); + assertContainsRegex(compiled, "LLItrClosure_[A-Za-z0-9_]+\\." + Pattern.quote(slotName) + "\\s*="); + assertDoesNotContainRegex(compiled, "LLItrClosure_[A-Za-z0-9_]+\\.LLItrClosure_run\\d+\\s*="); + } + + @Test + public void genericAbstractCallbacksKeepPrefixedBaseSlotNamesInLua() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_genericAbstractCallbacksKeepPrefixedBaseSlotNamesInLua", + false, + "package Test", + "class Registry", + " T stored = null", + " construct(T stored)", + " this.stored = stored", + " function forEachIn(ForElementCallback cb)", + " if stored != null", + " cb.callback(stored)", + "public abstract class ForElementCallback", + " abstract function callback(T elem)", + "function test()", + " let r = new Registry(1)", + " r.forEachIn() elem ->", + " skip", + "init", + " test()" + ); + + String forEachBody = getFunctionBody(compiled, "Registry_Registry_forEachIn"); + String slotName = singleMatch(forEachBody, ":([A-Za-z0-9_]+)\\(", 1); + assertEquals("ForElementCallback_callback", slotName); + assertContainsRegex(compiled, "ForElementCallback_[A-Za-z0-9_]+\\." + Pattern.quote(slotName) + "\\s*="); + assertDoesNotContainRegex(compiled, "ForElementCallback_[A-Za-z0-9_]+\\.ForElementCallback_callback\\d+\\s*="); + } + + @Test + public void genericRegistryModuleCallbacksKeepModuleQualifiedLuaSlotNames() throws IOException { + test().testLua(true).lines( + "package Test", + "public interface ForElementCallback", + " function callback(T elem)", + "public module RegistryModule", + " static T stored = null", + " protected static function addToRegistry(T obj) returns int", + " stored = obj", + " return 0", + " static function forEachIn(ForElementCallback cb) returns ForElementCallback", + " if stored != null", + " cb.callback(stored)", + " return cb", + "class Dungeon", + " use RegistryModule", + " static function makeOne()", + " addToRegistry(new Dungeon())", + "function setup()", + " Dungeon.makeOne()", + " Dungeon.forEachIn((dungeon) -> begin", + " skip", + " end)", + "init", + " setup()" + ); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_genericRegistryModuleCallbacksKeepModuleQualifiedLuaSlotNames.lua"), Charsets.UTF_8); + + assertContainsRegex(compiled, "function\\s+[A-Za-z0-9_]*forEachIn\\("); + assertContainsRegex(compiled, "ForElementCallback_[A-Za-z0-9_]+\\.(RegistryModule_)?ForElementCallback_callback\\s*="); + assertDoesNotContainRegex(compiled, "ForElementCallback_[A-Za-z0-9_]+\\.ForElementCallback_callback\\d+\\s*="); + } + + @Test + public void nestedGenericLinkedListClosuresKeepPrefixedLuaSlotNames() throws IOException { + test().testLua(true).lines( + "package Test", + "public interface CallbackSingle", + " function run()", + "public function _Logging_executeCustom(string msg, CallbackSingle cb)", + " cb.run()", + "public interface LLItrClosure", + " function run(T t)", + "class Node", + " T elem", + " Node next = null", + "class LinkedList", + " Node first = null", + " function add(T elem)", + " let n = new Node()", + " n.elem = elem", + " n.next = first", + " first = n", + " function forEach(LLItrClosure itr)", + " var cur = first", + " while cur != null", + " itr.run(cur.elem)", + " cur = cur.next", + "function setup()", + " let ids = new LinkedList()", + " ids.add(1)", + " ids.forEach() id ->", + " _Logging_executeCustom(\"x\") ->", + " skip", + "init", + " setup()" + ); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_nestedGenericLinkedListClosuresKeepPrefixedLuaSlotNames.lua"), Charsets.UTF_8); + + assertContainsRegex(compiled, "function\\s+LinkedList_[A-Za-z0-9_]*forEach\\("); + assertContainsRegex(compiled, "LLItrClosure_[A-Za-z0-9_]+\\.LLItrClosure_run\\s*="); + } + @Test public void mainAndConfigNamesFixed() throws IOException { test().testLua(true).lines( @@ -605,6 +1073,59 @@ public void configEntrypointNotRenamedWhenUserHasConfigFunction() throws IOExcep assertTrue(compiled.contains("config1()")); } + @Test + public void stacktracesAreNotInjectedIntoConfigSeededLuaFunctions() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_stacktracesAreNotInjectedIntoConfigSeededLuaFunctions", + false, + "package Test", + "@noinline function configHelper()", + " skip", + "@noinline function runtimeHelper()", + " skip", + "function config()", + " configHelper()", + "init", + " runtimeHelper()" + ); + + assertDoesNotContainRegex(compiled, "function\\s+config\\([^\\)]*__wurst_stackPos"); + assertDoesNotContainRegex(compiled, "function\\s+config1\\([^\\)]*__wurst_stackPos"); + assertDoesNotContainRegex(compiled, "function\\s+configHelper\\([^\\)]*__wurst_stackPos"); + assertDoesNotContainRegex(compiled, "configHelper\\(\"when calling configHelper"); + + assertContainsRegex(compiled, "function\\s+runtimeHelper\\([^\\)]*__wurst_stackPos"); + assertContainsRegex(compiled, "runtimeHelper\\(\"when calling runtimeHelper"); + } + + @Test + public void stacktracesStayInjectedForLuaHelpersSharedByConfigAndRuntime() { + String compiled = compileLuaWithRunArgs( + "LuaTranslationTests_stacktracesStayInjectedForLuaHelpersSharedByConfigAndRuntime", + false, + "package Test", + "@noinline function sharedLeaf()", + " skip", + "@noinline function sharedHelper()", + " sharedLeaf()", + "@noinline function configOnlyHelper()", + " skip", + "function config()", + " configOnlyHelper()", + " sharedHelper()", + "init", + " sharedHelper()" + ); + + assertDoesNotContainRegex(compiled, "function\\s+configOnlyHelper\\([^\\)]*__wurst_stackPos"); + assertDoesNotContainRegex(compiled, "configOnlyHelper\\(\"when calling configOnlyHelper"); + + assertContainsRegex(compiled, "function\\s+sharedHelper\\([^\\)]*__wurst_stackPos"); + assertContainsRegex(compiled, "function\\s+sharedLeaf\\([^\\)]*__wurst_stackPos"); + assertContainsRegex(compiled, "sharedHelper\\(\"when calling sharedHelper"); + assertContainsRegex(compiled, "sharedLeaf\\(\"when calling sharedLeaf"); + } + @Test public void objectIndexFunctionsDoNotCollideWithUserFunctions() throws IOException { test().testLua(true).lines( @@ -952,6 +1473,56 @@ private CU[] genericOverrideGlobalStateReproUnits() { }; } + private String[] generatedOverloadDispatchMatrixLines(int overloadCount, int overrideIndex) { + List lines = new ArrayList<>(); + lines.add("package Test"); + lines.add("class Base"); + for (int i = 1; i <= overloadCount; i++) { + lines.add(" function route(" + routeParams(i) + ")"); + if (i < overloadCount) { + lines.add(" this.route(" + routeArgsToNext(i) + ")"); + } else { + lines.add(" skip"); + } + } + lines.add("class Child extends Base"); + lines.add(" override function route(" + routeParams(overrideIndex) + ")"); + lines.add(" super.route(" + routeArgsCurrent(overrideIndex) + ")"); + lines.add("init"); + lines.add(" let c = new Child()"); + lines.add(" c.route(" + routeStartArgs() + ")"); + return lines.toArray(new String[0]); + } + + private String routeParams(int arity) { + List params = new ArrayList<>(); + for (int i = 1; i <= arity; i++) { + params.add("int p" + i); + } + return String.join(", ", params); + } + + private String routeArgsToNext(int arity) { + List args = new ArrayList<>(); + for (int i = 1; i <= arity; i++) { + args.add("p" + i); + } + args.add("0"); + return String.join(", ", args); + } + + private String routeArgsCurrent(int arity) { + List args = new ArrayList<>(); + for (int i = 1; i <= arity; i++) { + args.add("p" + i); + } + return String.join(", ", args); + } + + private String routeStartArgs() { + return "1"; + } + @Test public void largeFunctionSpillsLocalsIntoTableInLua() throws IOException { List lines = new ArrayList<>(); @@ -1504,4 +2075,29 @@ public void removesUnusedClassesFromLuaOutput() throws IOException { assertFalse(compiled.contains("Drop")); } + @Test + public void subclassAllocationIncludesInheritedFieldsInLua() throws IOException { + test().testLua(true).lines( + "package Test", + "class Window", + " int anchorTop", + " int anchorBottom", + " function anchored() returns boolean", + " return anchorTop == 0 and anchorBottom == 0", + "class Child extends Window", + " int ownField", + "init", + " let c = new Child()", + " c.ownField = 7", + " if c.anchored()", + " skip" + ); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_subclassAllocationIncludesInheritedFieldsInLua.lua"), Charsets.UTF_8); + + assertContainsRegex(compiled, + "function\\s+[A-Za-z0-9_]+:create\\d+\\s*\\(\\)\\s*\\n\\s*local new_inst = \\(\\{[^\\n]*Window_anchorTop="); + assertContainsRegex(compiled, + "function\\s+[A-Za-z0-9_]+:create\\d+\\s*\\(\\)\\s*\\n\\s*local new_inst = \\(\\{[^\\n]*Window_anchorBottom="); + } + }