Skip to content

Commit 91cdc24

Browse files
authored
fix lua generic dispatch (#1178)
1 parent fa03c67 commit 91cdc24

File tree

5 files changed

+790
-17
lines changed

5 files changed

+790
-17
lines changed

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/OverrideUtils.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ public static void addOverride(
119119
ImFunction implementation = JassIm.ImFunction(e, subMethod.getName() + "_wrapper", JassIm.ImTypeVars(), parameters, rType, locals, body, flags);
120120
tr.getImProg().getFunctions().add(implementation);
121121

122-
List<ImMethod> subMethods = Collections.emptyList();
123-
ImMethod wrapperMethod = JassIm.ImMethod(e, subMethod.getMethodClass(), subMethod.getName() + "_wrapper", implementation, subMethods, false);
122+
ImMethod wrapperMethod = JassIm.ImMethod(e, subMethod.getMethodClass(), subMethod.getName() + "_wrapper", implementation, JassIm.ImMethods(), false);
124123
subClass.getMethods().add(wrapperMethod);
125124
superMethodIm.getSubMethods().add(wrapperMethod);
126125
}

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/StackTraceInjector2.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ public void transform(TimeTaker timeTaker) {
4545
final Multimap<ImFunction, ImFunctionCall> calls = LinkedListMultimap.create();
4646
// called function -> calling function
4747
final Multimap<ImFunction, ImFunction> callRelation = LinkedListMultimap.create();
48+
// calling function -> called function
49+
final Multimap<ImFunction, ImFunction> directCalls = LinkedListMultimap.create();
4850
final List<ImFuncRefOrCall> funcRefs = Lists.newArrayList();
4951
prog.accept(new ImProg.DefaultVisitor() {
5052

@@ -72,9 +74,19 @@ public void visit(ImFunctionCall c) {
7274
calls.put(c.getFunc(), c);
7375
ImFunction caller = c.getNearestFunc();
7476
callRelation.put(c.getFunc(), caller);
77+
directCalls.put(caller, c.getFunc());
7578
}
7679
}
7780

81+
@Override
82+
public void visit(ImMethodCall c) {
83+
super.visit(c);
84+
ImFunction caller = c.getNearestFunc();
85+
ImFunction callee = c.getMethod().getImplementation();
86+
directCalls.put(caller, callee);
87+
callRelation.put(callee, caller);
88+
}
89+
7890
@Override
7991
public void visit(ImFuncRef imFuncRef) {
8092
super.visit(imFuncRef);
@@ -97,6 +109,7 @@ public void visit(ImFuncRef imFuncRef) {
97109
Set<ImFunction> affectedFuncs = Sets.newLinkedHashSet(stackTraceGets.keySet());
98110

99111
if (tr.isLuaTarget()) {
112+
Set<ImFunction> configOnlyFuncs = getConfigOnlyFunctions(directCalls, callRelation);
100113
// in Lua all functions are potentially affected, because we don't know where Lua might crash
101114
Stream.concat(
102115
prog.getFunctions().stream(),
@@ -106,6 +119,7 @@ public void visit(ImFuncRef imFuncRef) {
106119
&& !f.hasFlag(FunctionFlagEnum.IS_BJ)
107120
&& !f.hasFlag(FunctionFlagEnum.IS_EXTERN))
108121
.collect(Collectors.toCollection(() -> affectedFuncs));
122+
affectedFuncs.removeAll(configOnlyFuncs);
109123

110124
} else {
111125
for (ImFunction stackTraceUse : stackTraceGets.keys()) {
@@ -123,6 +137,60 @@ public void visit(ImFuncRef imFuncRef) {
123137

124138
}
125139

140+
private Set<ImFunction> getFunctionsReachableFrom(String functionName, Multimap<ImFunction, ImFunction> directCalls) {
141+
Set<ImFunction> reachable = Sets.newLinkedHashSet();
142+
Deque<ImFunction> worklist = new ArrayDeque<>();
143+
Stream.concat(
144+
prog.getFunctions().stream(),
145+
prog.getClasses().stream().flatMap(c -> c.getFunctions().stream()))
146+
.filter(f -> f.getName().equals(functionName))
147+
.forEach(worklist::addLast);
148+
while (!worklist.isEmpty()) {
149+
ImFunction current = worklist.removeFirst();
150+
if (!reachable.add(current)) {
151+
continue;
152+
}
153+
for (ImFunction callee : directCalls.get(current)) {
154+
worklist.addLast(callee);
155+
}
156+
}
157+
return reachable;
158+
}
159+
160+
private Set<ImFunction> getConfigOnlyFunctions(Multimap<ImFunction, ImFunction> directCalls,
161+
Multimap<ImFunction, ImFunction> reverseCalls) {
162+
Set<ImFunction> configReachable = getFunctionsReachableFrom("config", directCalls);
163+
if (configReachable.isEmpty()) {
164+
return Collections.emptySet();
165+
}
166+
167+
Set<ImFunction> runtimeSharedFuncs = Sets.newLinkedHashSet();
168+
Deque<ImFunction> worklist = new ArrayDeque<>();
169+
for (ImFunction f : configReachable) {
170+
boolean hasNonConfigCaller = reverseCalls.get(f).stream()
171+
.anyMatch(caller -> !configReachable.contains(caller));
172+
if (hasNonConfigCaller) {
173+
worklist.addLast(f);
174+
}
175+
}
176+
177+
while (!worklist.isEmpty()) {
178+
ImFunction current = worklist.removeFirst();
179+
if (!runtimeSharedFuncs.add(current)) {
180+
continue;
181+
}
182+
for (ImFunction callee : directCalls.get(current)) {
183+
if (configReachable.contains(callee)) {
184+
worklist.addLast(callee);
185+
}
186+
}
187+
}
188+
189+
Set<ImFunction> configOnlyFuncs = Sets.newLinkedHashSet(configReachable);
190+
configOnlyFuncs.removeAll(runtimeSharedFuncs);
191+
return configOnlyFuncs;
192+
}
193+
126194
private void rewriteMethodCalls(Set<ImFunction> affectedFuncs) {
127195
if (!tr.isLuaTarget()) {
128196
return;

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/ExprTranslation.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,17 @@ public static LuaExpr translate(ImMemberAccess e, LuaTranslator tr) {
175175
}
176176

177177
public static LuaExpr translate(ImMethodCall e, LuaTranslator tr) {
178+
ImMethod method = e.getMethod();
179+
if (!method.getIsAbstract()
180+
&& method.getImplementation() != null
181+
&& method.getSubMethods().isEmpty()) {
182+
LuaExprlist args = LuaAst.LuaExprlist();
183+
args.add(e.getReceiver().translateToLua(tr));
184+
for (ImExpr arg : e.getArguments()) {
185+
args.add(arg.translateToLua(tr));
186+
}
187+
return LuaAst.LuaExprFunctionCall(tr.luaFunc.getFor(method.getImplementation()), args);
188+
}
178189
return LuaAst.LuaExprMethodCall(e.getReceiver().translateToLua(tr), tr.luaMethod.getFor(e.getMethod()), tr.translateExprList(e.getArguments()));
179190
}
180191

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java

Lines changed: 113 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,7 @@
22

33
import de.peeeq.datastructures.UnionFind;
44
import de.peeeq.wurstscript.WLogger;
5-
import de.peeeq.wurstscript.ast.ClassDef;
6-
import de.peeeq.wurstscript.ast.Element;
7-
import de.peeeq.wurstscript.ast.FuncDef;
8-
import de.peeeq.wurstscript.ast.NameDef;
9-
import de.peeeq.wurstscript.ast.WParameter;
10-
import de.peeeq.wurstscript.ast.WPackage;
5+
import de.peeeq.wurstscript.ast.*;
116
import de.peeeq.wurstscript.jassIm.*;
127
import de.peeeq.wurstscript.luaAst.*;
138
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagEnum;
@@ -220,7 +215,7 @@ public LuaMethod initFor(ImClass a) {
220215
private final Lazy<LuaFunction> errorFunc = Lazy.create(() ->
221216
this.getProg().getFunctions().stream()
222217
.flatMap(f -> {
223-
Element trace = f.attrTrace();
218+
de.peeeq.wurstscript.ast.Element trace = f.attrTrace();
224219
if (trace instanceof FuncDef) {
225220
FuncDef fd = (FuncDef) trace;
226221
if (fd.getName().equals("error")
@@ -438,7 +433,7 @@ private void collectPredefinedNames() {
438433
}
439434

440435
private void setNameFromTrace(JassImElementWithName named) {
441-
Element trace = named.attrTrace();
436+
de.peeeq.wurstscript.ast.Element trace = named.attrTrace();
442437
if (trace instanceof NameDef) {
443438
named.setName(((NameDef) trace).getName());
444439
}
@@ -459,9 +454,7 @@ private void normalizeMethodNames() {
459454
// give all related methods the same name in deterministic order
460455
List<List<ImMethod>> groups = new ArrayList<>();
461456
for (Set<ImMethod> group : methodUnions.groups().values()) {
462-
List<ImMethod> sortedGroup = new ArrayList<>(group);
463-
sortedGroup.sort(Comparator.comparing(this::methodSortKey));
464-
groups.add(sortedGroup);
457+
groups.addAll(partitionMethodsByDispatchSignature(group));
465458
}
466459
groups.sort(Comparator.comparing(g -> g.isEmpty() ? "" : methodSortKey(g.get(0))));
467460
for (List<ImMethod> group : groups) {
@@ -1026,7 +1019,7 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod
10261019
// local new_inst = { ... }
10271020
LuaTableFields initialFieldValues = LuaAst.LuaTableFields();
10281021
LuaVariable newInst = LuaAst.LuaVariable("new_inst", LuaAst.LuaTableConstructor(initialFieldValues));
1029-
for (ImVar field : c.getFields()) {
1022+
for (ImVar field : collectFieldsForAllocation(c)) {
10301023
initialFieldValues.add(
10311024
LuaAst.LuaTableNamedField(field.getName(), defaultValue(field.getType()))
10321025
);
@@ -1044,6 +1037,25 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod
10441037
body.add(LuaAst.LuaReturn(LuaAst.LuaExprVarAccess(newInst)));
10451038
}
10461039

1040+
private List<ImVar> collectFieldsForAllocation(ImClass c) {
1041+
List<ImVar> result = new ArrayList<>();
1042+
Set<ImClass> visited = new HashSet<>();
1043+
collectFieldsForAllocation(c, result, visited);
1044+
return result;
1045+
}
1046+
1047+
private void collectFieldsForAllocation(ImClass c, List<ImVar> out, Set<ImClass> visited) {
1048+
if (!visited.add(c)) {
1049+
return;
1050+
}
1051+
List<ImClassType> superClasses = new ArrayList<>(c.getSuperClasses());
1052+
superClasses.sort(Comparator.comparing(sc -> classSortKey(sc.getClassDef())));
1053+
for (ImClassType sc : superClasses) {
1054+
collectFieldsForAllocation(sc.getClassDef(), out, visited);
1055+
}
1056+
out.addAll(c.getFields());
1057+
}
1058+
10471059
private void initClassTables(ImClass c) {
10481060
LuaVariable classVar = luaClassVar.getFor(c);
10491061
// create methods:
@@ -1087,7 +1099,10 @@ private void createMethods(ImClass c, LuaVariable classVar) {
10871099
groupedMethods.computeIfAbsent(root, k -> new ArrayList<>()).add(method);
10881100
}
10891101

1090-
List<List<ImMethod>> groups = new ArrayList<>(groupedMethods.values());
1102+
List<List<ImMethod>> groups = new ArrayList<>();
1103+
for (List<ImMethod> methods : groupedMethods.values()) {
1104+
groups.addAll(partitionMethodsByDispatchSignature(methods));
1105+
}
10911106
groups.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0))));
10921107
Map<String, ImMethod> slotToImpl = new TreeMap<>();
10931108
for (List<ImMethod> groupMethods : groups) {
@@ -1122,6 +1137,58 @@ private void createMethods(ImClass c, LuaVariable classVar) {
11221137
}
11231138
}
11241139

1140+
private List<List<ImMethod>> partitionMethodsByDispatchSignature(Collection<ImMethod> methods) {
1141+
Map<String, List<ImMethod>> partitions = new TreeMap<>();
1142+
for (ImMethod method : methods) {
1143+
partitions.computeIfAbsent(dispatchSignatureKey(method), k -> new ArrayList<>()).add(method);
1144+
}
1145+
List<List<ImMethod>> result = new ArrayList<>(partitions.values());
1146+
for (List<ImMethod> group : result) {
1147+
group.sort(Comparator.comparing(this::methodSortKey));
1148+
}
1149+
result.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0))));
1150+
return result;
1151+
}
1152+
1153+
private String dispatchSignatureKey(ImMethod method) {
1154+
ImFunction implementation = resolveDispatchSignatureImplementation(method, new HashSet<>());
1155+
if (implementation == null) {
1156+
return "<abstract>";
1157+
}
1158+
StringBuilder sb = new StringBuilder();
1159+
sb.append(typeKey(implementation.getReturnType())).append("|");
1160+
ImVars params = implementation.getParameters();
1161+
for (int i = 1; i < params.size(); i++) {
1162+
if (i > 1) {
1163+
sb.append(",");
1164+
}
1165+
sb.append(typeKey(params.get(i).getType()));
1166+
}
1167+
return sb.toString();
1168+
}
1169+
1170+
private ImFunction resolveDispatchSignatureImplementation(ImMethod method, Set<ImMethod> visited) {
1171+
if (method == null || !visited.add(method)) {
1172+
return null;
1173+
}
1174+
if (method.getImplementation() != null) {
1175+
return method.getImplementation();
1176+
}
1177+
List<ImMethod> subMethods = new ArrayList<>(method.getSubMethods());
1178+
subMethods.sort(Comparator.comparing(this::methodSortKey));
1179+
for (ImMethod subMethod : subMethods) {
1180+
ImFunction resolved = resolveDispatchSignatureImplementation(subMethod, visited);
1181+
if (resolved != null) {
1182+
return resolved;
1183+
}
1184+
}
1185+
return null;
1186+
}
1187+
1188+
private String typeKey(ImType type) {
1189+
return type == null ? "<null>" : type.toString();
1190+
}
1191+
11251192
private Set<String> collectDispatchSlotNames(ImClass receiverClass, List<ImMethod> groupMethods) {
11261193
Set<String> slotNames = new TreeSet<>();
11271194
Set<String> semanticNames = new TreeSet<>();
@@ -1141,6 +1208,12 @@ private Set<String> collectDispatchSlotNames(ImClass receiverClass, List<ImMetho
11411208
if (owner != null && !semanticName.isEmpty()) {
11421209
slotNames.add(owner.getName() + "_" + semanticName);
11431210
}
1211+
String sourceSemanticName = sourceSemanticName(m);
1212+
if (owner != null && isClosureGeneratedClass(owner) && !sourceSemanticName.isEmpty()) {
1213+
semanticNames.add(sourceSemanticName);
1214+
slotNames.add(sourceSemanticName);
1215+
slotNames.add(owner.getName() + "_" + sourceSemanticName);
1216+
}
11441217
}
11451218
if (receiverClass != null && !semanticNames.isEmpty()) {
11461219
Set<String> classNames = new TreeSet<>();
@@ -1327,6 +1400,32 @@ private String semanticNameFromMethodName(String methodName) {
13271400
return methodName;
13281401
}
13291402

1403+
private boolean isClosureGeneratedClass(ImClass c) {
1404+
return c != null && c.attrTrace() instanceof ExprClosure;
1405+
}
1406+
1407+
private String sourceSemanticName(ImMethod method) {
1408+
if (method == null) {
1409+
return "";
1410+
}
1411+
de.peeeq.wurstscript.ast.Element trace = method.attrTrace();
1412+
if (trace instanceof FuncDef funcDef) {
1413+
return funcDef.getName();
1414+
}
1415+
if (trace instanceof AstElementWithFuncName withFuncName) {
1416+
return withFuncName.getFuncNameId().getName();
1417+
}
1418+
if (method.getImplementation() != null) {
1419+
String implementationName = method.getImplementation().getName();
1420+
int firstUnderscore = implementationName.indexOf('_');
1421+
if (firstUnderscore > 0) {
1422+
return implementationName.substring(0, firstUnderscore);
1423+
}
1424+
return implementationName;
1425+
}
1426+
return "";
1427+
}
1428+
13301429
private String classSortKey(ImClass c) {
13311430
if (c == null) {
13321431
return "";
@@ -1471,7 +1570,7 @@ public LuaFunction getErrorFunc() {
14711570
}
14721571

14731572
public String getTypeCastingFunctionName(ImFunction f) {
1474-
Element trace = f.attrTrace();
1573+
de.peeeq.wurstscript.ast.Element trace = f.attrTrace();
14751574
if (trace instanceof FuncDef fd && fd.attrNearestPackage() instanceof WPackage p) {
14761575
if ("TypeCasting".equals(p.getName())) {
14771576
return fd.getName();

0 commit comments

Comments
 (0)