Skip to content

Commit bde3a82

Browse files
committed
lua cleanup
1 parent 014f2e2 commit bde3a82

File tree

8 files changed

+751
-396
lines changed

8 files changed

+751
-396
lines changed

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstio/WurstCompilerJassImpl.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,13 @@ public LuaCompilationUnit transformProgToLua() {
889889
}
890890
ImTranslator imTranslator2 = getImTranslator();
891891
ImOptimizer optimizer = new ImOptimizer(timeTaker, imTranslator2);
892+
893+
// Lower Lua-specific native calls into IM-level wrappers before optimization,
894+
// so the optimizer can inline and eliminate the nil-safety checks and remapped stubs.
895+
beginPhase(4, "lua native lowering");
896+
LuaNativeLowering.transform(imProg);
897+
timeTaker.endPhase();
898+
892899
// inliner
893900
stage = 5;
894901
if (runArgs.isInline()) {

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstio/languageserver/requests/MapRequest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ protected File compileMap(File projectFolder, WurstGui gui, Optional<File> mapCo
169169

170170
String compiledMapScript = sb.toString();
171171
LuaTranslator.assertNoLeakedHashtableNativeCalls(compiledMapScript);
172+
LuaTranslator.assertNoLeakedGetHandleIdCalls(compiledMapScript);
172173
File buildDir = getBuildDir();
173174
File outFile = new File(buildDir, BUILD_COMPILED_LUA_NAME);
174175
Files.write(compiledMapScript.getBytes(Charsets.UTF_8), outFile);
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
package de.peeeq.wurstscript.translation.imtranslation;
2+
3+
import de.peeeq.wurstscript.WurstOperator;
4+
import de.peeeq.wurstscript.jassIm.*;
5+
6+
import java.util.*;
7+
8+
/**
9+
* IM-level lowering pass for the Lua backend, run before optimization so the
10+
* optimizer can inline and eliminate the generated wrappers.
11+
*
12+
* <p>Three classes of WC3 BJ calls are transformed:
13+
* <ol>
14+
* <li><b>GetHandleId</b> – replaced 1:1 by {@code __wurst_GetHandleId}, whose Lua
15+
* implementation uses a stable table counter instead of the WC3 handle ID
16+
* (which can desync in Lua mode).</li>
17+
* <li><b>Hashtable natives</b> ({@code SaveInteger}, {@code LoadBoolean}, …) and
18+
* <b>context-callback natives</b> ({@code ForForce}, {@code ForGroup}, …) –
19+
* replaced 1:1 by their {@code __wurst_} prefixed equivalents, whose Lua
20+
* implementations are provided by {@link de.peeeq.wurstscript.translation.lua.translation.LuaNatives}.</li>
21+
* <li><b>All other BJ calls with at least one handle-typed parameter</b> – wrapped
22+
* by a generated IM function that first checks each handle param for {@code null}
23+
* and returns the type-appropriate default (0 / 0.0 / false / "" / nil), then
24+
* delegates to the original BJ function. This matches Jass behavior, which
25+
* silently returns defaults on null-handle calls instead of crashing.</li>
26+
* </ol>
27+
*
28+
* <p>IS_NATIVE stubs added for category 1 and 2 are recognised by
29+
* {@link de.peeeq.wurstscript.translation.lua.translation.LuaTranslator#translateFunc} as
30+
* Wurst-owned natives and filled in by
31+
* {@link de.peeeq.wurstscript.translation.lua.translation.LuaNatives}.
32+
*/
33+
public final class LuaNativeLowering {
34+
35+
/** Hashtable native names that need to be remapped to {@code __wurst_} equivalents. */
36+
private static final Set<String> HASHTABLE_NATIVE_NAMES = new HashSet<>(Arrays.asList(
37+
"InitHashtable",
38+
"SaveInteger", "SaveBoolean", "SaveReal", "SaveStr",
39+
"LoadInteger", "LoadBoolean", "LoadReal", "LoadStr",
40+
"HaveSavedInteger", "HaveSavedBoolean", "HaveSavedReal", "HaveSavedString", "HaveSavedHandle",
41+
"FlushChildHashtable", "FlushParentHashtable",
42+
"RemoveSavedInteger", "RemoveSavedBoolean", "RemoveSavedReal", "RemoveSavedString", "RemoveSavedHandle",
43+
// Handle-typed save/load variants
44+
"SavePlayerHandle", "SaveWidgetHandle", "SaveDestructableHandle", "SaveItemHandle", "SaveUnitHandle",
45+
"SaveAbilityHandle", "SaveTimerHandle", "SaveTriggerHandle", "SaveTriggerConditionHandle",
46+
"SaveTriggerActionHandle", "SaveTriggerEventHandle", "SaveForceHandle", "SaveGroupHandle",
47+
"SaveLocationHandle", "SaveRectHandle", "SaveBooleanExprHandle", "SaveSoundHandle", "SaveEffectHandle",
48+
"SaveUnitPoolHandle", "SaveItemPoolHandle", "SaveQuestHandle", "SaveQuestItemHandle",
49+
"SaveDefeatConditionHandle", "SaveTimerDialogHandle", "SaveLeaderboardHandle", "SaveMultiboardHandle",
50+
"SaveMultiboardItemHandle", "SaveTrackableHandle", "SaveDialogHandle", "SaveButtonHandle",
51+
"SaveTextTagHandle", "SaveLightningHandle", "SaveImageHandle", "SaveUbersplatHandle", "SaveRegionHandle",
52+
"SaveFogStateHandle", "SaveFogModifierHandle", "SaveAgentHandle", "SaveHashtableHandle", "SaveFrameHandle",
53+
"LoadPlayerHandle", "LoadWidgetHandle", "LoadDestructableHandle", "LoadItemHandle", "LoadUnitHandle",
54+
"LoadAbilityHandle", "LoadTimerHandle", "LoadTriggerHandle", "LoadTriggerConditionHandle",
55+
"LoadTriggerActionHandle", "LoadTriggerEventHandle", "LoadForceHandle", "LoadGroupHandle",
56+
"LoadLocationHandle", "LoadRectHandle", "LoadBooleanExprHandle", "LoadSoundHandle", "LoadEffectHandle",
57+
"LoadUnitPoolHandle", "LoadItemPoolHandle", "LoadQuestHandle", "LoadQuestItemHandle",
58+
"LoadDefeatConditionHandle", "LoadTimerDialogHandle", "LoadLeaderboardHandle", "LoadMultiboardHandle",
59+
"LoadMultiboardItemHandle", "LoadTrackableHandle", "LoadDialogHandle", "LoadButtonHandle",
60+
"LoadTextTagHandle", "LoadLightningHandle", "LoadImageHandle", "LoadUbersplatHandle", "LoadRegionHandle",
61+
"LoadFogStateHandle", "LoadFogModifierHandle", "LoadHashtableHandle", "LoadFrameHandle"
62+
));
63+
64+
/** Context-callback natives that need to be remapped to {@code __wurst_} equivalents. */
65+
private static final Set<String> CONTEXT_CALLBACK_NATIVE_NAMES = new HashSet<>(Arrays.asList(
66+
"ForForce", "GetEnumPlayer",
67+
"ForGroup", "GetEnumUnit",
68+
"EnumItemsInRect", "GetEnumItem",
69+
"EnumDestructablesInRect", "GetEnumDestructable"
70+
));
71+
72+
private LuaNativeLowering() {}
73+
74+
/**
75+
* Transforms the IM program in place.
76+
*
77+
* <p>Must be called <em>before</em> the optimizer so that the optimizer
78+
* can inline and eliminate the generated wrappers.
79+
*/
80+
public static void transform(ImProg prog) {
81+
// Maps original BJ function → replacement (either a IS_NATIVE stub or a nil-safety wrapper)
82+
Map<ImFunction, ImFunction> replacements = new LinkedHashMap<>();
83+
// Nil-safety wrappers are collected separately and added to prog AFTER the traversal,
84+
// so the traversal does not visit their bodies and replace their internal BJ delegate calls.
85+
List<ImFunction> deferredWrappers = new ArrayList<>();
86+
87+
// Snapshot to avoid ConcurrentModificationException when createNativeStub adds to prog.getFunctions()
88+
List<ImFunction> snapshot = new ArrayList<>(prog.getFunctions());
89+
for (ImFunction f : snapshot) {
90+
if (!f.isBj()) {
91+
continue;
92+
}
93+
String name = f.getName();
94+
95+
if ("GetHandleId".equals(name)) {
96+
replacements.put(f, createNativeStub("__wurst_GetHandleId", f, prog));
97+
} else if (HASHTABLE_NATIVE_NAMES.contains(name)) {
98+
replacements.put(f, createNativeStub("__wurst_" + name, f, prog));
99+
} else if (CONTEXT_CALLBACK_NATIVE_NAMES.contains(name)) {
100+
replacements.put(f, createNativeStub("__wurst_" + name, f, prog));
101+
} else if (hasHandleParam(f)) {
102+
ImFunction wrapper = createNilSafeWrapper(f);
103+
replacements.put(f, wrapper);
104+
deferredWrappers.add(wrapper);
105+
}
106+
}
107+
108+
if (replacements.isEmpty()) {
109+
return;
110+
}
111+
112+
// Replace all call sites in the existing IM (before adding wrappers).
113+
// Wrappers are deferred so their internal BJ delegate calls are not replaced.
114+
prog.accept(new Element.DefaultVisitor() {
115+
@Override
116+
public void visit(ImFunctionCall call) {
117+
super.visit(call);
118+
ImFunction replacement = replacements.get(call.getFunc());
119+
if (replacement != null) {
120+
call.replaceBy(JassIm.ImFunctionCall(
121+
call.attrTrace(), replacement,
122+
JassIm.ImTypeArguments(),
123+
call.getArguments().copy(),
124+
false, CallType.NORMAL));
125+
}
126+
}
127+
});
128+
129+
// Add nil-safety wrapper functions AFTER traversal so their own bodies are not traversed.
130+
prog.getFunctions().addAll(deferredWrappers);
131+
}
132+
133+
/**
134+
* Creates a new IS_NATIVE (non-BJ) IM function stub with the same signature as
135+
* {@code original}. The Lua translator will fill in the body via
136+
* {@code LuaNatives.get()} when it encounters the stub.
137+
*/
138+
private static ImFunction createNativeStub(String name, ImFunction original, ImProg prog) {
139+
ImVars params = JassIm.ImVars();
140+
for (ImVar p : original.getParameters()) {
141+
params.add(JassIm.ImVar(p.attrTrace(), p.getType().copy(), p.getName(), false));
142+
}
143+
ImFunction stub = JassIm.ImFunction(
144+
original.attrTrace(), name,
145+
JassIm.ImTypeVars(), params,
146+
original.getReturnType().copy(),
147+
JassIm.ImVars(), JassIm.ImStmts(),
148+
Collections.singletonList(FunctionFlagEnum.IS_NATIVE));
149+
prog.getFunctions().add(stub);
150+
return stub;
151+
}
152+
153+
/**
154+
* Creates a nil-safety wrapper for {@code bjNative}.
155+
*
156+
* <p>The generated function checks each handle-typed parameter against
157+
* {@code null} and returns the type-appropriate default value if any is
158+
* null. Otherwise it delegates to the original BJ function.
159+
*/
160+
private static ImFunction createNilSafeWrapper(ImFunction bjNative) {
161+
ImVars params = JassIm.ImVars();
162+
List<ImVar> paramVars = new ArrayList<>();
163+
for (ImVar p : bjNative.getParameters()) {
164+
ImVar copy = JassIm.ImVar(p.attrTrace(), p.getType().copy(), p.getName(), false);
165+
params.add(copy);
166+
paramVars.add(copy);
167+
}
168+
169+
ImStmts body = JassIm.ImStmts();
170+
171+
// Null-check each handle param: if param == null then return <default> end
172+
ImExpr returnDefault = defaultValueExpr(bjNative.getReturnType());
173+
for (ImVar param : paramVars) {
174+
if (isHandleType(param.getType())) {
175+
ImExpr condition = JassIm.ImOperatorCall(WurstOperator.EQ, JassIm.ImExprs(
176+
JassIm.ImVarAccess(param),
177+
JassIm.ImNull(param.getType().copy())
178+
));
179+
ImStmts thenBlock = JassIm.ImStmts(
180+
JassIm.ImReturn(bjNative.attrTrace(), returnDefault.copy())
181+
);
182+
body.add(JassIm.ImIf(bjNative.attrTrace(), condition, thenBlock, JassIm.ImStmts()));
183+
}
184+
}
185+
186+
// Delegate to the original BJ native
187+
ImExprs callArgs = JassIm.ImExprs();
188+
for (ImVar pv : paramVars) {
189+
callArgs.add(JassIm.ImVarAccess(pv));
190+
}
191+
ImFunctionCall delegate = JassIm.ImFunctionCall(
192+
bjNative.attrTrace(), bjNative,
193+
JassIm.ImTypeArguments(), callArgs, false, CallType.NORMAL);
194+
195+
if (bjNative.getReturnType() instanceof ImVoid) {
196+
body.add(delegate);
197+
} else {
198+
body.add(JassIm.ImReturn(bjNative.attrTrace(), delegate));
199+
}
200+
201+
return JassIm.ImFunction(
202+
bjNative.attrTrace(),
203+
"__wurst_safe_" + bjNative.getName(),
204+
JassIm.ImTypeVars(), params,
205+
bjNative.getReturnType().copy(),
206+
JassIm.ImVars(), body,
207+
Collections.emptyList());
208+
}
209+
210+
private static boolean hasHandleParam(ImFunction f) {
211+
for (ImVar p : f.getParameters()) {
212+
if (isHandleType(p.getType())) {
213+
return true;
214+
}
215+
}
216+
return false;
217+
}
218+
219+
/** Returns true for WC3 handle types (ImSimpleType that is not int/real/boolean/string). */
220+
static boolean isHandleType(ImType type) {
221+
if (!(type instanceof ImSimpleType)) {
222+
return false;
223+
}
224+
String n = ((ImSimpleType) type).getTypename();
225+
return !n.equals("integer") && !n.equals("real") && !n.equals("boolean") && !n.equals("string");
226+
}
227+
228+
/** Returns an IM expression representing the safe default for the given return type. */
229+
private static ImExpr defaultValueExpr(ImType returnType) {
230+
if (returnType instanceof ImSimpleType) {
231+
String n = ((ImSimpleType) returnType).getTypename();
232+
switch (n) {
233+
case "integer": return JassIm.ImIntVal(0);
234+
case "real": return JassIm.ImRealVal("0.0");
235+
case "boolean": return JassIm.ImBoolVal(false);
236+
case "string": return JassIm.ImStringVal("");
237+
}
238+
}
239+
// void or handle type → null
240+
if (returnType instanceof ImVoid) {
241+
return JassIm.ImNull(JassIm.ImVoid());
242+
}
243+
return JassIm.ImNull(returnType.copy());
244+
}
245+
}

0 commit comments

Comments
 (0)