Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ public static void addOverride(
ImFunction implementation = JassIm.ImFunction(e, subMethod.getName() + "_wrapper", JassIm.ImTypeVars(), parameters, rType, locals, body, flags);
tr.getImProg().getFunctions().add(implementation);

List<ImMethod> subMethods = Collections.emptyList();
ImMethod wrapperMethod = JassIm.ImMethod(e, subMethod.getMethodClass(), subMethod.getName() + "_wrapper", implementation, subMethods, false);
ImMethod wrapperMethod = JassIm.ImMethod(e, subMethod.getMethodClass(), subMethod.getName() + "_wrapper", implementation, JassIm.ImMethods(), false);
subClass.getMethods().add(wrapperMethod);
superMethodIm.getSubMethods().add(wrapperMethod);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public void transform(TimeTaker timeTaker) {
final Multimap<ImFunction, ImFunctionCall> calls = LinkedListMultimap.create();
// called function -> calling function
final Multimap<ImFunction, ImFunction> callRelation = LinkedListMultimap.create();
// calling function -> called function
final Multimap<ImFunction, ImFunction> directCalls = LinkedListMultimap.create();
final List<ImFuncRefOrCall> funcRefs = Lists.newArrayList();
prog.accept(new ImProg.DefaultVisitor() {

Expand Down Expand Up @@ -72,9 +74,19 @@ public void visit(ImFunctionCall c) {
calls.put(c.getFunc(), c);
ImFunction caller = c.getNearestFunc();
callRelation.put(c.getFunc(), caller);
directCalls.put(caller, c.getFunc());
}
}

@Override
public void visit(ImMethodCall c) {
super.visit(c);
ImFunction caller = c.getNearestFunc();
ImFunction callee = c.getMethod().getImplementation();
directCalls.put(caller, callee);
callRelation.put(callee, caller);
}

@Override
public void visit(ImFuncRef imFuncRef) {
super.visit(imFuncRef);
Expand All @@ -97,6 +109,7 @@ public void visit(ImFuncRef imFuncRef) {
Set<ImFunction> affectedFuncs = Sets.newLinkedHashSet(stackTraceGets.keySet());

if (tr.isLuaTarget()) {
Set<ImFunction> configOnlyFuncs = getConfigOnlyFunctions(directCalls, callRelation);
// in Lua all functions are potentially affected, because we don't know where Lua might crash
Stream.concat(
prog.getFunctions().stream(),
Expand All @@ -106,6 +119,7 @@ public void visit(ImFuncRef imFuncRef) {
&& !f.hasFlag(FunctionFlagEnum.IS_BJ)
&& !f.hasFlag(FunctionFlagEnum.IS_EXTERN))
.collect(Collectors.toCollection(() -> affectedFuncs));
affectedFuncs.removeAll(configOnlyFuncs);

} else {
for (ImFunction stackTraceUse : stackTraceGets.keys()) {
Expand All @@ -123,6 +137,60 @@ public void visit(ImFuncRef imFuncRef) {

}

private Set<ImFunction> getFunctionsReachableFrom(String functionName, Multimap<ImFunction, ImFunction> directCalls) {
Set<ImFunction> reachable = Sets.newLinkedHashSet();
Deque<ImFunction> worklist = new ArrayDeque<>();
Stream.concat(
prog.getFunctions().stream(),
prog.getClasses().stream().flatMap(c -> c.getFunctions().stream()))
.filter(f -> f.getName().equals(functionName))
.forEach(worklist::addLast);
while (!worklist.isEmpty()) {
ImFunction current = worklist.removeFirst();
if (!reachable.add(current)) {
continue;
}
for (ImFunction callee : directCalls.get(current)) {
worklist.addLast(callee);
}
}
return reachable;
}

private Set<ImFunction> getConfigOnlyFunctions(Multimap<ImFunction, ImFunction> directCalls,
Multimap<ImFunction, ImFunction> reverseCalls) {
Set<ImFunction> configReachable = getFunctionsReachableFrom("config", directCalls);
if (configReachable.isEmpty()) {
return Collections.emptySet();
}

Set<ImFunction> runtimeSharedFuncs = Sets.newLinkedHashSet();
Deque<ImFunction> worklist = new ArrayDeque<>();
for (ImFunction f : configReachable) {
boolean hasNonConfigCaller = reverseCalls.get(f).stream()
.anyMatch(caller -> !configReachable.contains(caller));
if (hasNonConfigCaller) {
worklist.addLast(f);
}
}

while (!worklist.isEmpty()) {
ImFunction current = worklist.removeFirst();
if (!runtimeSharedFuncs.add(current)) {
continue;
}
for (ImFunction callee : directCalls.get(current)) {
if (configReachable.contains(callee)) {
worklist.addLast(callee);
}
}
}

Set<ImFunction> configOnlyFuncs = Sets.newLinkedHashSet(configReachable);
configOnlyFuncs.removeAll(runtimeSharedFuncs);
return configOnlyFuncs;
}

private void rewriteMethodCalls(Set<ImFunction> affectedFuncs) {
if (!tr.isLuaTarget()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,17 @@ public static LuaExpr translate(ImMemberAccess e, LuaTranslator tr) {
}

public static LuaExpr translate(ImMethodCall e, LuaTranslator tr) {
ImMethod method = e.getMethod();
if (!method.getIsAbstract()
&& method.getImplementation() != null
&& method.getSubMethods().isEmpty()) {
LuaExprlist args = LuaAst.LuaExprlist();
args.add(e.getReceiver().translateToLua(tr));
for (ImExpr arg : e.getArguments()) {
args.add(arg.translateToLua(tr));
}
return LuaAst.LuaExprFunctionCall(tr.luaFunc.getFor(method.getImplementation()), args);
}
return LuaAst.LuaExprMethodCall(e.getReceiver().translateToLua(tr), tr.luaMethod.getFor(e.getMethod()), tr.translateExprList(e.getArguments()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@

import de.peeeq.datastructures.UnionFind;
import de.peeeq.wurstscript.WLogger;
import de.peeeq.wurstscript.ast.ClassDef;
import de.peeeq.wurstscript.ast.Element;
import de.peeeq.wurstscript.ast.FuncDef;
import de.peeeq.wurstscript.ast.NameDef;
import de.peeeq.wurstscript.ast.WParameter;
import de.peeeq.wurstscript.ast.WPackage;
import de.peeeq.wurstscript.ast.*;
import de.peeeq.wurstscript.jassIm.*;
import de.peeeq.wurstscript.luaAst.*;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagEnum;
Expand Down Expand Up @@ -220,7 +215,7 @@ public LuaMethod initFor(ImClass a) {
private final Lazy<LuaFunction> errorFunc = Lazy.create(() ->
this.getProg().getFunctions().stream()
.flatMap(f -> {
Element trace = f.attrTrace();
de.peeeq.wurstscript.ast.Element trace = f.attrTrace();
if (trace instanceof FuncDef) {
FuncDef fd = (FuncDef) trace;
if (fd.getName().equals("error")
Expand Down Expand Up @@ -438,7 +433,7 @@ private void collectPredefinedNames() {
}

private void setNameFromTrace(JassImElementWithName named) {
Element trace = named.attrTrace();
de.peeeq.wurstscript.ast.Element trace = named.attrTrace();
if (trace instanceof NameDef) {
named.setName(((NameDef) trace).getName());
}
Expand All @@ -459,9 +454,7 @@ private void normalizeMethodNames() {
// give all related methods the same name in deterministic order
List<List<ImMethod>> groups = new ArrayList<>();
for (Set<ImMethod> group : methodUnions.groups().values()) {
List<ImMethod> sortedGroup = new ArrayList<>(group);
sortedGroup.sort(Comparator.comparing(this::methodSortKey));
groups.add(sortedGroup);
groups.addAll(partitionMethodsByDispatchSignature(group));
}
groups.sort(Comparator.comparing(g -> g.isEmpty() ? "" : methodSortKey(g.get(0))));
for (List<ImMethod> group : groups) {
Expand Down Expand Up @@ -1026,7 +1019,7 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod
// local new_inst = { ... }
LuaTableFields initialFieldValues = LuaAst.LuaTableFields();
LuaVariable newInst = LuaAst.LuaVariable("new_inst", LuaAst.LuaTableConstructor(initialFieldValues));
for (ImVar field : c.getFields()) {
for (ImVar field : collectFieldsForAllocation(c)) {
initialFieldValues.add(
LuaAst.LuaTableNamedField(field.getName(), defaultValue(field.getType()))
);
Expand All @@ -1044,6 +1037,25 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod
body.add(LuaAst.LuaReturn(LuaAst.LuaExprVarAccess(newInst)));
}

private List<ImVar> collectFieldsForAllocation(ImClass c) {
List<ImVar> result = new ArrayList<>();
Set<ImClass> visited = new HashSet<>();
collectFieldsForAllocation(c, result, visited);
return result;
}

private void collectFieldsForAllocation(ImClass c, List<ImVar> out, Set<ImClass> visited) {
if (!visited.add(c)) {
return;
}
List<ImClassType> superClasses = new ArrayList<>(c.getSuperClasses());
superClasses.sort(Comparator.comparing(sc -> classSortKey(sc.getClassDef())));
for (ImClassType sc : superClasses) {
collectFieldsForAllocation(sc.getClassDef(), out, visited);
}
out.addAll(c.getFields());
}

private void initClassTables(ImClass c) {
LuaVariable classVar = luaClassVar.getFor(c);
// create methods:
Expand Down Expand Up @@ -1087,7 +1099,10 @@ private void createMethods(ImClass c, LuaVariable classVar) {
groupedMethods.computeIfAbsent(root, k -> new ArrayList<>()).add(method);
}

List<List<ImMethod>> groups = new ArrayList<>(groupedMethods.values());
List<List<ImMethod>> groups = new ArrayList<>();
for (List<ImMethod> methods : groupedMethods.values()) {
groups.addAll(partitionMethodsByDispatchSignature(methods));
}
groups.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0))));
Map<String, ImMethod> slotToImpl = new TreeMap<>();
for (List<ImMethod> groupMethods : groups) {
Expand Down Expand Up @@ -1122,6 +1137,58 @@ private void createMethods(ImClass c, LuaVariable classVar) {
}
}

private List<List<ImMethod>> partitionMethodsByDispatchSignature(Collection<ImMethod> methods) {
Map<String, List<ImMethod>> partitions = new TreeMap<>();
for (ImMethod method : methods) {
partitions.computeIfAbsent(dispatchSignatureKey(method), k -> new ArrayList<>()).add(method);
}
List<List<ImMethod>> result = new ArrayList<>(partitions.values());
for (List<ImMethod> group : result) {
group.sort(Comparator.comparing(this::methodSortKey));
}
result.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0))));
return result;
}

private String dispatchSignatureKey(ImMethod method) {
ImFunction implementation = resolveDispatchSignatureImplementation(method, new HashSet<>());
if (implementation == null) {
return "<abstract>";
}
StringBuilder sb = new StringBuilder();
sb.append(typeKey(implementation.getReturnType())).append("|");
ImVars params = implementation.getParameters();
for (int i = 1; i < params.size(); i++) {
if (i > 1) {
sb.append(",");
}
sb.append(typeKey(params.get(i).getType()));
}
return sb.toString();
}

private ImFunction resolveDispatchSignatureImplementation(ImMethod method, Set<ImMethod> visited) {
if (method == null || !visited.add(method)) {
return null;
}
if (method.getImplementation() != null) {
return method.getImplementation();
}
List<ImMethod> subMethods = new ArrayList<>(method.getSubMethods());
subMethods.sort(Comparator.comparing(this::methodSortKey));
for (ImMethod subMethod : subMethods) {
ImFunction resolved = resolveDispatchSignatureImplementation(subMethod, visited);
if (resolved != null) {
return resolved;
}
}
return null;
}

private String typeKey(ImType type) {
return type == null ? "<null>" : type.toString();
}

private Set<String> collectDispatchSlotNames(ImClass receiverClass, List<ImMethod> groupMethods) {
Set<String> slotNames = new TreeSet<>();
Set<String> semanticNames = new TreeSet<>();
Expand All @@ -1141,6 +1208,12 @@ private Set<String> collectDispatchSlotNames(ImClass receiverClass, List<ImMetho
if (owner != null && !semanticName.isEmpty()) {
slotNames.add(owner.getName() + "_" + semanticName);
}
String sourceSemanticName = sourceSemanticName(m);
if (owner != null && isClosureGeneratedClass(owner) && !sourceSemanticName.isEmpty()) {
semanticNames.add(sourceSemanticName);
slotNames.add(sourceSemanticName);
slotNames.add(owner.getName() + "_" + sourceSemanticName);
}
}
if (receiverClass != null && !semanticNames.isEmpty()) {
Set<String> classNames = new TreeSet<>();
Expand Down Expand Up @@ -1327,6 +1400,32 @@ private String semanticNameFromMethodName(String methodName) {
return methodName;
}

private boolean isClosureGeneratedClass(ImClass c) {
return c != null && c.attrTrace() instanceof ExprClosure;
}

private String sourceSemanticName(ImMethod method) {
if (method == null) {
return "";
}
de.peeeq.wurstscript.ast.Element trace = method.attrTrace();
if (trace instanceof FuncDef funcDef) {
return funcDef.getName();
}
if (trace instanceof AstElementWithFuncName withFuncName) {
return withFuncName.getFuncNameId().getName();
}
if (method.getImplementation() != null) {
String implementationName = method.getImplementation().getName();
int firstUnderscore = implementationName.indexOf('_');
if (firstUnderscore > 0) {
return implementationName.substring(0, firstUnderscore);
}
return implementationName;
}
return "";
}

private String classSortKey(ImClass c) {
if (c == null) {
return "";
Expand Down Expand Up @@ -1471,7 +1570,7 @@ public LuaFunction getErrorFunc() {
}

public String getTypeCastingFunctionName(ImFunction f) {
Element trace = f.attrTrace();
de.peeeq.wurstscript.ast.Element trace = f.attrTrace();
if (trace instanceof FuncDef fd && fd.attrNearestPackage() instanceof WPackage p) {
if ("TypeCasting".equals(p.getName())) {
return fd.getName();
Expand Down
Loading
Loading