Skip to content

Commit b7cb017

Browse files
authored
fix multi inline (#1171)
1 parent 846b91b commit b7cb017

File tree

5 files changed

+290
-72
lines changed

5 files changed

+290
-72
lines changed

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/intermediatelang/optimizer/ConstantAndCopyPropagation.java

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -309,31 +309,19 @@ private void analyzeComponent(List<Node> scc, Map<Node, Knowledge> knowledge) {
309309
} else {
310310
Value newValue = null;
311311

312-
// Try constant folding first
313-
ImExpr foldedExpr = tryConstantFold(right, newOut);
314-
if (foldedExpr != null && foldedExpr != right) {
315-
// We successfully folded to a constant
316-
newValue = Value.tryValue(foldedExpr);
317-
if (newValue != null) {
318-
// Replace the RHS with the folded constant in the AST
319-
right.replaceBy(foldedExpr);
320-
}
321-
}
322-
323-
// If no folding happened, try regular value propagation
324-
if (newValue == null) {
325-
if (right instanceof ImConst) {
326-
newValue = Value.tryValue(right);
327-
} else if (right instanceof ImVarAccess) {
328-
ImVar varRight = ((ImVarAccess) right).getVar();
329-
if(newOut.containsKey(varRight)) {
330-
newValue = newOut.get(varRight).getOrNull();
331-
} else {
332-
newValue = Value.tryValue(right);
333-
}
334-
} else if(right instanceof ImTupleExpr) {
312+
// Constant folding is intentionally centralized in SimpleRewrites.
313+
// This pass performs propagation only to keep fold semantics in one place.
314+
if (right instanceof ImConst) {
315+
newValue = Value.tryValue(right);
316+
} else if (right instanceof ImVarAccess) {
317+
ImVar varRight = ((ImVarAccess) right).getVar();
318+
if(newOut.containsKey(varRight)) {
319+
newValue = newOut.get(varRight).getOrNull();
320+
} else {
335321
newValue = Value.tryValue(right);
336322
}
323+
} else if(right instanceof ImTupleExpr) {
324+
newValue = Value.tryValue(right);
337325
}
338326

339327
if (newValue == null) {

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImInliner.java

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,6 @@ private String skipReason(ImFunction caller, ImFunctionCall call, ImFunction f)
116116
if (call.getCallType() == CallType.EXECUTE) {
117117
return "execute_call";
118118
}
119-
if (translator.isLuaTarget() && !maxOneReturn(f)) {
120-
return "lua_multi_return_inline_disabled";
121-
}
122119
if (translator.isLuaTarget() && containsFuncRef(f)) {
123120
return "lua_callback_funcref_barrier";
124121
}
@@ -241,14 +238,14 @@ public void visit(ImFunctionCall called) {
241238
private ImStmts rewriteForEarlyReturns(ImStmts body, ImVar doneVar, ImVar retVar) {
242239
ImStmts rewritten = JassIm.ImStmts();
243240
for (ImStmt s : body) {
244-
ImStmt transformed = rewriteStmtForEarlyReturn(s, doneVar, retVar);
241+
ImStmts transformed = rewriteStmtForEarlyReturn(s, doneVar, retVar);
245242
ImExpr notDone = JassIm.ImOperatorCall(de.peeeq.wurstscript.WurstOperator.NOT, JassIm.ImExprs(JassIm.ImVarAccess(doneVar)));
246-
rewritten.add(JassIm.ImIf(s.attrTrace(), notDone, JassIm.ImStmts(transformed), JassIm.ImStmts()));
243+
rewritten.add(JassIm.ImIf(s.attrTrace(), notDone, transformed, JassIm.ImStmts()));
247244
}
248245
return rewritten;
249246
}
250247

251-
private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar) {
248+
private ImStmts rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar) {
252249
if (s instanceof ImReturn) {
253250
ImReturn r = (ImReturn) s;
254251
ImStmts b = JassIm.ImStmts();
@@ -258,27 +255,27 @@ private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar)
258255
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(retVar), rv));
259256
}
260257
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(doneVar), JassIm.ImBoolVal(true)));
261-
return ImHelper.statementExprVoid(b);
258+
return b;
262259
} else if (s instanceof ImIf) {
263260
ImIf imIf = (ImIf) s;
264261
ImStmts thenBlock = rewriteForEarlyReturns(imIf.getThenBlock().copy(), doneVar, retVar);
265262
ImStmts elseBlock = rewriteForEarlyReturns(imIf.getElseBlock().copy(), doneVar, retVar);
266-
return JassIm.ImIf(imIf.getTrace(), imIf.getCondition().copy(), thenBlock, elseBlock);
263+
return JassIm.ImStmts(JassIm.ImIf(imIf.getTrace(), imIf.getCondition().copy(), thenBlock, elseBlock));
267264
} else if (s instanceof ImLoop) {
268265
ImLoop l = (ImLoop) s;
269266
ImStmts loopBody = JassIm.ImStmts();
270267
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
271268
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
272-
return JassIm.ImLoop(l.getTrace(), loopBody);
269+
return JassIm.ImStmts(JassIm.ImLoop(l.getTrace(), loopBody));
273270
} else if (s instanceof ImVarargLoop) {
274271
ImVarargLoop l = (ImVarargLoop) s;
275272
ImStmts loopBody = JassIm.ImStmts();
276273
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
277274
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
278-
return JassIm.ImVarargLoop(l.getTrace(), loopBody, l.getLoopVar());
275+
return JassIm.ImStmts(JassIm.ImVarargLoop(l.getTrace(), loopBody, l.getLoopVar()));
279276
}
280277
// Keep tree ownership valid when rewrapping statements into new blocks.
281-
return s.copy();
278+
return JassIm.ImStmts(s.copy());
282279
}
283280

284281
private void rateInlinableFunctions() {
@@ -338,11 +335,6 @@ private boolean shouldInline(ImFunction caller, ImFunctionCall call, ImFunction
338335
if (f.isNative() || call.getCallType() == CallType.EXECUTE) {
339336
return false;
340337
}
341-
if (translator.isLuaTarget() && !maxOneReturn(f)) {
342-
// Conservative safety: Lua inliner multi-return rewriting is not yet fully robust
343-
// across all lowered patterns. Keep call semantics intact for now.
344-
return false;
345-
}
346338
if (translator.isLuaTarget() && containsFuncRef(f)) {
347339
// Functions that build callback refs are lowered with Lua-specific wrappers/xpcall.
348340
// Keeping them as standalone calls avoids callback context/vararg scope breakage.

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ public class LuaTranslator {
101101

102102
final ImProg prog;
103103
final LuaCompilationUnit luaModel;
104+
private final LuaStatements deferredMainInit = LuaAst.LuaStatements();
104105
private final Set<String> usedNames = new HashSet<>(Arrays.asList(
105106
// reserved function names
106107
"print", "tostring", "error",
@@ -180,7 +181,7 @@ public LuaMethod initFor(ImMethod a) {
180181
GetAForB<ImClass, LuaVariable> luaClassVar = new GetAForB<ImClass, LuaVariable>() {
181182
@Override
182183
public LuaVariable initFor(ImClass a) {
183-
return LuaAst.LuaVariable(uniqueName(a.getName()), LuaAst.LuaNoExpr());
184+
return LuaAst.LuaVariable(uniqueName(a.getName()), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields()));
184185
}
185186
};
186187

@@ -321,9 +322,10 @@ public LuaCompilationUnit translate() {
321322
initClassTables(c);
322323
}
323324

325+
emitExperimentalHashtableLeakGuards();
326+
prependDeferredMainInitToMain();
324327
cleanStatements();
325328
enforceLuaLocalLimits();
326-
emitExperimentalHashtableLeakGuards();
327329

328330
return luaModel;
329331
}
@@ -353,12 +355,39 @@ private void ensureWurstContextCallbackHelpers() {
353355
}
354356

355357
private void emitExperimentalHashtableLeakGuards() {
356-
luaModel.add(LuaAst.LuaLiteral("-- Wurst experimental Lua assertion guards: raw WC3 hashtable natives must not be called."));
358+
deferMainInit(LuaAst.LuaLiteral("-- Wurst experimental Lua assertion guards: raw WC3 hashtable natives must not be called."));
359+
deferMainInit(LuaAst.LuaLiteral("do"));
360+
deferMainInit(LuaAst.LuaLiteral(" local __wurst_guard_ok = pcall(function()"));
357361
for (String nativeName : allHashtableNativeNames()) {
358-
luaModel.add(LuaAst.LuaLiteral("if " + nativeName + " ~= nil then " + nativeName
362+
deferMainInit(LuaAst.LuaLiteral(" if " + nativeName + " ~= nil then " + nativeName
359363
+ " = function(...) error(\"Wurst Lua assertion failed: unexpected call to native " + nativeName
360364
+ ". Expected __wurst_" + nativeName + ".\") end end"));
361365
}
366+
deferMainInit(LuaAst.LuaLiteral(" end)"));
367+
deferMainInit(LuaAst.LuaLiteral(" if not __wurst_guard_ok then"));
368+
deferMainInit(LuaAst.LuaLiteral(" -- Some Lua runtimes lock native globals. Compile-time leak checks stay authoritative."));
369+
deferMainInit(LuaAst.LuaLiteral(" end"));
370+
deferMainInit(LuaAst.LuaLiteral("end"));
371+
}
372+
373+
private void deferMainInit(LuaStatement statement) {
374+
deferredMainInit.add(statement);
375+
}
376+
377+
private void prependDeferredMainInitToMain() {
378+
if (deferredMainInit.isEmpty()) {
379+
return;
380+
}
381+
ImFunction mainIm = imTr.getMainFunc();
382+
if (mainIm == null) {
383+
return;
384+
}
385+
LuaFunction mainLua = luaFunc.getFor(mainIm);
386+
LuaStatements mainBody = mainLua.getBody();
387+
for (int i = deferredMainInit.size() - 1; i >= 0; i--) {
388+
LuaStatement stmt = deferredMainInit.remove(i);
389+
mainBody.add(0, stmt);
390+
}
362391
}
363392

364393
public static void assertNoLeakedHashtableNativeCalls(String luaCode) {
@@ -534,15 +563,17 @@ private void createInstanceOfFunction() {
534563

535564
private void createObjectIndexFunctions() {
536565
String vName = "__wurst_objectIndexMap";
537-
LuaVariable v = LuaAst.LuaVariable(vName, LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
538-
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0"))
539-
)));
566+
LuaVariable v = LuaAst.LuaVariable(vName, LuaAst.LuaExprNull());
540567
luaModel.add(v);
541-
542-
LuaVariable im = LuaAst.LuaVariable("__wurst_number_wrapper_map", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
568+
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(v), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
543569
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0"))
544-
)));
570+
))));
571+
572+
LuaVariable im = LuaAst.LuaVariable("__wurst_number_wrapper_map", LuaAst.LuaExprNull());
545573
luaModel.add(im);
574+
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(im), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
575+
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0"))
576+
))));
546577

547578
{
548579
String[] code = {
@@ -597,12 +628,13 @@ private void createObjectIndexFunctions() {
597628
}
598629

599630
private void createStringIndexFunctions() {
600-
LuaVariable map = LuaAst.LuaVariable("__wurst_string_index_map", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
631+
LuaVariable map = LuaAst.LuaVariable("__wurst_string_index_map", LuaAst.LuaExprNull());
632+
luaModel.add(map);
633+
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(map), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
601634
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0")),
602635
LuaAst.LuaTableNamedField("byString", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields())),
603636
LuaAst.LuaTableNamedField("byIndex", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields()))
604-
)));
605-
luaModel.add(map);
637+
))));
606638

607639
{
608640
String[] code = {
@@ -995,8 +1027,6 @@ private void translateClass(ImClass c) {
9951027

9961028
luaModel.add(initMethod);
9971029

998-
classVar.setInitialValue(emptyTable());
999-
10001030
// translate functions
10011031
for (ImFunction f : c.getFunctions()) {
10021032
translateFunc(f);
@@ -1038,14 +1068,14 @@ private void initClassTables(ImClass c) {
10381068
// set supertype metadata:
10391069
LuaTableFields superClasses = LuaAst.LuaTableFields();
10401070
collectSuperClasses(superClasses, c, new HashSet<>());
1041-
luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
1071+
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
10421072
LuaAst.LuaExprVarAccess(classVar),
10431073
WURST_SUPERTYPES),
10441074
LuaAst.LuaTableConstructor(superClasses)
10451075
));
10461076

10471077
// set typeid metadata:
1048-
luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
1078+
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
10491079
LuaAst.LuaExprVarAccess(classVar),
10501080
ExprTranslation.TYPE_ID),
10511081
LuaAst.LuaExprIntVal("" + prog.attrTypeId().get(c))
@@ -1100,7 +1130,7 @@ private void createMethods(ImClass c, LuaVariable classVar) {
11001130
if (impl == null || impl.getImplementation() == null) {
11011131
continue;
11021132
}
1103-
luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
1133+
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
11041134
LuaAst.LuaExprVarAccess(classVar),
11051135
e.getKey()),
11061136
LuaAst.LuaExprFuncRef(luaFunc.getFor(impl.getImplementation()))
@@ -1343,8 +1373,9 @@ private void translateGlobal(ImVar v) {
13431373
return;
13441374
}
13451375
LuaVariable lv = luaVar.getFor(v);
1346-
lv.setInitialValue(defaultValue(v.getType()));
1376+
lv.setInitialValue(LuaAst.LuaExprNull());
13471377
luaModel.add(lv);
1378+
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(lv), defaultValue(v.getType())));
13481379
}
13491380

13501381
private LuaExpr defaultValue(ImType type) {

0 commit comments

Comments
 (0)