Skip to content

Commit 4263652

Browse files
authored
Improve lambda implementation (#273)
1 parent 9c377f0 commit 4263652

File tree

7 files changed

+505
-214
lines changed

7 files changed

+505
-214
lines changed

sourcegen-bytecode-writer/src/main/java/io/micronaut/sourcegen/bytecode/expression/LambdaExpressionWriter.java

+18-59
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
import javax.lang.model.element.Modifier;
3232
import java.util.ArrayList;
33-
import java.util.HashSet;
33+
import java.util.LinkedHashSet;
3434
import java.util.List;
3535
import java.util.Set;
3636

@@ -44,8 +44,8 @@ final class LambdaExpressionWriter extends AbstractStatementAwareExpressionWrite
4444
private static final String METAFACTORY_METHOD = "metafactory";
4545
private static final String METAFACTORY_DESCRIPTOR =
4646
"(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;" +
47-
"Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;" +
48-
"Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;";
47+
"Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;" +
48+
"Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;";
4949

5050
private final Lambda lambda;
5151

@@ -55,12 +55,12 @@ public LambdaExpressionWriter(Lambda lambda) {
5555

5656
@Override
5757
public void write(GeneratorAdapter generatorAdapter, MethodContext context) {
58-
List<VariableDef> capturedVariables = captureVariables(lambda.method());
58+
List<VariableDef> capturedVariables = captureVariables(lambda.implementation());
5959
MethodDef implementationMethodDef = createLambdaMethodDef(context, lambda, capturedVariables);
6060
context.lambdaMethods().add(implementationMethodDef);
6161

6262
// The captured variables are the parameters to the called bootstrap method
63-
for (VariableDef variable: capturedVariables) {
63+
for (VariableDef variable : capturedVariables) {
6464
new VariableExpressionWriter(variable).write(generatorAdapter, context);
6565
}
6666

@@ -87,19 +87,19 @@ public void write(GeneratorAdapter generatorAdapter, MethodContext context) {
8787
);
8888

8989
generatorAdapter.visitInvokeDynamicInsn(
90-
lambda.method().getName(),
90+
lambda.implementation().getName(),
9191
createDynamicInvocationDescriptor(capturedVariables, context),
9292
bootstrapMethodHandle,
93-
Type.getType(TypeUtils.getMethodDescriptor(context.objectDef(), lambda.overriddenMethod())),
93+
Type.getType(TypeUtils.getMethodDescriptor(context.objectDef(), lambda.target())),
9494
lambdaMethodHandle,
95-
Type.getType(TypeUtils.getMethodDescriptor(context.objectDef(), lambda.method()))
95+
Type.getType(TypeUtils.getMethodDescriptor(context.objectDef(), lambda.implementation()))
9696
);
9797
popValueIfNeeded(generatorAdapter, lambda.type());
9898
}
9999

100100
private String createDynamicInvocationDescriptor(List<VariableDef> capturedVariables, MethodContext context) {
101101
StringBuilder dynamicDescriptor = new StringBuilder("(");
102-
for (VariableDef variable: capturedVariables) {
102+
for (VariableDef variable : capturedVariables) {
103103
dynamicDescriptor.append(TypeUtils.getType(variable.type(), context.objectDef()));
104104
}
105105
dynamicDescriptor.append(")");
@@ -108,11 +108,11 @@ private String createDynamicInvocationDescriptor(List<VariableDef> capturedVaria
108108
}
109109

110110
private MethodDef createLambdaMethodDef(MethodContext context, Lambda lambda, List<VariableDef> capturedVariables) {
111-
MethodDef original = lambda.method();
111+
MethodDef original = lambda.implementation();
112112
List<ParameterDef> parameters = new ArrayList<>();
113113

114114
// The captured variables are parameters
115-
for (VariableDef variable: capturedVariables) {
115+
for (VariableDef variable : capturedVariables) {
116116
if (variable instanceof VariableDef.Local local) {
117117
parameters.add(ParameterDef.builder(local.name(), local.type()).build());
118118
} else if (variable instanceof VariableDef.MethodParameter parameter) {
@@ -130,7 +130,7 @@ private MethodDef createLambdaMethodDef(MethodContext context, Lambda lambda, Li
130130

131131
parameters.addAll(original.getParameters());
132132
return MethodDef.builder("lambda$" + context.methodDef().getName() + "$" +
133-
context.lambdaMethods().size())
133+
context.lambdaMethods().size())
134134
.addModifiers(Modifier.PRIVATE, Modifier.STATIC)
135135
.addParameters(parameters)
136136
.returns(original.getReturnType())
@@ -139,59 +139,19 @@ private MethodDef createLambdaMethodDef(MethodContext context, Lambda lambda, Li
139139
}
140140

141141
private List<VariableDef> captureVariables(MethodDef method) {
142-
Set<String> variables = new HashSet<>(
142+
Set<String> variables = new LinkedHashSet<>(
143143
method.getParameters().stream().map(v -> v.getName()).toList()
144144
);
145145
List<VariableDef> capturedVariables = new ArrayList<>();
146-
for (StatementDef statement: method.getStatements()) {
146+
for (StatementDef statement : method.getStatements()) {
147147
captureVariables(statement, variables, capturedVariables);
148148
}
149149
return capturedVariables;
150150
}
151151

152152
private void captureVariables(StatementDef statement, Set<String> variables, List<VariableDef> capturedVariables) {
153-
if (statement instanceof StatementDef.Multi multi) {
154-
for (StatementDef s: multi.statements()) {
155-
captureVariables(s, variables, capturedVariables);
156-
}
157-
} else if (statement instanceof StatementDef.Return returnStatement) {
158-
captureVariables(returnStatement.expression(), variables, capturedVariables);
159-
} else if (statement instanceof StatementDef.Synchronized sync) {
160-
captureVariables(sync.monitor(), variables, capturedVariables);
161-
captureVariables(sync.statement(), variables, capturedVariables);
162-
} else if (statement instanceof StatementDef.Throw throwStatement) {
163-
captureVariables(throwStatement.expression(), variables, capturedVariables);
164-
} else if (statement instanceof StatementDef.Assign assign) {
165-
captureVariables(assign.variable(), variables, capturedVariables);
166-
captureVariables(assign.expression(), variables, capturedVariables);
167-
} else if (statement instanceof StatementDef.DefineAndAssign assign) {
168-
variables.add(assign.variable().name());
169-
captureVariables(assign.expression(), variables, capturedVariables);
170-
} else if (statement instanceof StatementDef.While w) {
171-
captureVariables(w.expression(), variables, capturedVariables);
172-
captureVariables(w.statement(), variables, capturedVariables);
173-
} else if (statement instanceof StatementDef.If ifStatement) {
174-
captureVariables(ifStatement.condition(), variables, capturedVariables);
175-
captureVariables(ifStatement.statement(), variables, capturedVariables);
176-
} else if (statement instanceof StatementDef.Try tryStatement) {
177-
captureVariables(tryStatement.statement(), variables, capturedVariables);
178-
captureVariables(tryStatement.finallyStatement(), variables, capturedVariables);
179-
for (StatementDef.Try.Catch cat: tryStatement.catches()) {
180-
captureVariables(cat.statement(), variables, capturedVariables);
181-
}
182-
} else if (statement instanceof StatementDef.IfElse ifElse) {
183-
captureVariables(ifElse.condition(), variables, capturedVariables);
184-
captureVariables(ifElse.statement(), variables, capturedVariables);
185-
captureVariables(ifElse.elseStatement(), variables, capturedVariables);
186-
} else if (statement instanceof StatementDef.PutField putField) {
187-
capturedVariables.add(putField.field());
188-
variables.add(putField.field().name());
189-
captureVariables(putField.expression(), variables, capturedVariables);
190-
} else if (statement instanceof StatementDef.PutStaticField putStaticField) {
191-
captureVariables(putStaticField.expression(), variables, capturedVariables);
192-
} else {
193-
throw new IllegalStateException("Unsupported statement type in lambda: " + statement.getClass().getName());
194-
}
153+
statement.nestedExpressionsStream()
154+
.forEach(expressionDef -> captureVariables(expressionDef, variables, capturedVariables));
195155
}
196156

197157
private void captureVariables(ExpressionDef expression, Set<String> variables, List<VariableDef> capturedVariables) {
@@ -225,9 +185,8 @@ private void captureVariables(ExpressionDef expression, Set<String> variables, L
225185
}
226186
}
227187
} else {
228-
for (ExpressionDef operand: expression.operands()) {
229-
captureVariables(operand, variables, capturedVariables);
230-
}
188+
expression.nestedExpressionsStream()
189+
.forEach(expressionDef -> captureVariables(expressionDef, variables, capturedVariables));
231190
}
232191
}
233192

sourcegen-bytecode-writer/src/test/java/io/micronaut/sourcegen/bytecode/ByteCodeWriterTest.java

+198
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import io.micronaut.core.annotation.AnnotationClassValue;
55
import io.micronaut.core.reflect.ReflectionUtils;
66
import io.micronaut.inject.visitor.VisitorContext;
7+
import io.micronaut.sourcegen.custom.visitor.GenerateLambdaVisitor;
78
import io.micronaut.sourcegen.custom.visitor.innerTypes.GenerateInnerTypeInEnumVisitor;
89
import io.micronaut.sourcegen.model.ClassDef;
910
import io.micronaut.sourcegen.model.ClassTypeDef;
@@ -54,6 +55,203 @@
5455

5556
class ByteCodeWriterTest {
5657

58+
@Test
59+
void lambda() {
60+
ClassDef classDef = GenerateLambdaVisitor.getSpec("example").theClass();
61+
62+
StringWriter bytecodeWriter = new StringWriter();
63+
byte[] bytes = generateFile(classDef, bytecodeWriter);
64+
65+
String bytecode = bytecodeWriter.toString();
66+
Assertions.assertEquals("""
67+
// class version 61.0 (61)
68+
// access flags 0x1
69+
// signature Ljava/lang/Object;
70+
// declaration: example/MyClassWithLambda
71+
public class example/MyClassWithLambda {
72+
73+
74+
// access flags 0x0
75+
Ljava/lang/String; name
76+
77+
// access flags 0x1
78+
public <init>()V
79+
ALOAD 0
80+
INVOKESPECIAL java/lang/Object.<init> ()V
81+
RETURN
82+
83+
// access flags 0x1
84+
public toString()Ljava/lang/String;
85+
LDC "MyClass"
86+
ARETURN
87+
88+
// access flags 0x1
89+
public callLambda(Ljava/lang/String;)Ljava/lang/String;
90+
L0
91+
L1
92+
INVOKEDYNAMIC apply()Lexample/StringFunction; [
93+
// handle kind 0x6 : INVOKESTATIC
94+
java/lang/invoke/LambdaMetafactory.metafactory(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;
95+
// arguments:
96+
(Ljava/lang/String;)Ljava/lang/String;,\s
97+
// handle kind 0x6 : INVOKESTATIC
98+
example/MyClassWithLambda.lambda$callLambda$0(Ljava/lang/String;)Ljava/lang/String;,\s
99+
(Ljava/lang/String;)Ljava/lang/String;
100+
]
101+
ASTORE 2
102+
ALOAD 2
103+
ALOAD 1
104+
INVOKEINTERFACE example/StringFunction.apply (Ljava/lang/String;)Ljava/lang/String; (itf)
105+
ARETURN
106+
L2
107+
LOCALVARIABLE input Ljava/lang/String; L0 L2 1
108+
LOCALVARIABLE function Lexample/StringFunction; L1 L2 2
109+
110+
// access flags 0xA
111+
private static lambda$callLambda$0(Ljava/lang/String;)Ljava/lang/String;
112+
L0
113+
ALOAD 0
114+
ICONST_1
115+
INVOKEVIRTUAL java/lang/String.substring (I)Ljava/lang/String;
116+
ARETURN
117+
L1
118+
LOCALVARIABLE arg1 Ljava/lang/String; L0 L1 1
119+
120+
// access flags 0x1
121+
public callStatefulLambda(Ljava/lang/String;)Ljava/lang/String;
122+
L0
123+
L1
124+
LDC "prefix_"
125+
ASTORE 2
126+
L2
127+
ALOAD 2
128+
ALOAD 0
129+
INVOKEDYNAMIC apply(Ljava/lang/String;Lexample/MyClassWithLambda;)Lexample/StringFunction; [
130+
// handle kind 0x6 : INVOKESTATIC
131+
java/lang/invoke/LambdaMetafactory.metafactory(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;
132+
// arguments:
133+
(Ljava/lang/String;)Ljava/lang/String;,\s
134+
// handle kind 0x6 : INVOKESTATIC
135+
example/MyClassWithLambda.lambda$callStatefulLambda$0(Ljava/lang/String;Lexample/MyClassWithLambda;Ljava/lang/String;)Ljava/lang/String;,\s
136+
(Ljava/lang/String;)Ljava/lang/String;
137+
]
138+
ASTORE 3
139+
ALOAD 3
140+
ALOAD 1
141+
INVOKEINTERFACE example/StringFunction.apply (Ljava/lang/String;)Ljava/lang/String; (itf)
142+
ARETURN
143+
L3
144+
LOCALVARIABLE input Ljava/lang/String; L0 L3 1
145+
LOCALVARIABLE constant Ljava/lang/String; L1 L3 2
146+
LOCALVARIABLE function Lexample/StringFunction; L2 L3 3
147+
148+
// access flags 0xA
149+
private static lambda$callStatefulLambda$0(Ljava/lang/String;Lexample/MyClassWithLambda;Ljava/lang/String;)Ljava/lang/String;
150+
L0
151+
ALOAD 0
152+
ALOAD 2
153+
ICONST_1
154+
INVOKEVIRTUAL java/lang/String.substring (I)Ljava/lang/String;
155+
ALOAD 1
156+
INVOKEVIRTUAL example/MyClassWithLambda.toString ()Ljava/lang/String;
157+
INVOKEVIRTUAL java/lang/String.concat (Ljava/lang/String;)Ljava/lang/String;
158+
ALOAD 1
159+
GETFIELD example/MyClassWithLambda.name : Ljava/lang/String;
160+
INVOKEVIRTUAL java/lang/String.concat (Ljava/lang/String;)Ljava/lang/String;
161+
INVOKEVIRTUAL java/lang/String.concat (Ljava/lang/String;)Ljava/lang/String;
162+
ARETURN
163+
L1
164+
LOCALVARIABLE constant Ljava/lang/String; L0 L1 1
165+
LOCALVARIABLE this Lexample/MyClassWithLambda; L0 L1 2
166+
LOCALVARIABLE arg1 Ljava/lang/String; L0 L1 3
167+
168+
// access flags 0x1
169+
public callGenericLambda(Ljava/lang/String;)Ljava/lang/String;
170+
L0
171+
L1
172+
LDC "prefix_"
173+
ASTORE 2
174+
L2
175+
ALOAD 2
176+
INVOKEDYNAMIC apply(Ljava/lang/String;)Ljava/util/function/Function; [
177+
// handle kind 0x6 : INVOKESTATIC
178+
java/lang/invoke/LambdaMetafactory.metafactory(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;
179+
// arguments:
180+
(Ljava/lang/Object;)Ljava/lang/Object;,\s
181+
// handle kind 0x6 : INVOKESTATIC
182+
example/MyClassWithLambda.lambda$callGenericLambda$0(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;,\s
183+
(Ljava/lang/String;)Ljava/lang/String;
184+
]
185+
ASTORE 3
186+
ALOAD 3
187+
ALOAD 1
188+
INVOKEINTERFACE java/util/function/Function.apply (Ljava/lang/Object;)Ljava/lang/Object; (itf)
189+
CHECKCAST java/lang/String
190+
ARETURN
191+
L3
192+
LOCALVARIABLE input Ljava/lang/String; L0 L3 1
193+
LOCALVARIABLE constant Ljava/lang/String; L1 L3 2
194+
LOCALVARIABLE function Ljava/util/function/Function; L2 L3 3
195+
196+
// access flags 0xA
197+
private static lambda$callGenericLambda$0(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;
198+
L0
199+
ALOAD 0
200+
ALOAD 1
201+
ICONST_1
202+
INVOKEVIRTUAL java/lang/String.substring (I)Ljava/lang/String;
203+
INVOKEVIRTUAL java/lang/String.concat (Ljava/lang/String;)Ljava/lang/String;
204+
ARETURN
205+
L1
206+
LOCALVARIABLE constant Ljava/lang/String; L0 L1 1
207+
LOCALVARIABLE arg1 Ljava/lang/String; L0 L1 2
208+
}
209+
""", bytecode);
210+
211+
Assertions.assertEquals("""
212+
package example;
213+
214+
import java.util.function.Function;
215+
216+
public class MyClassWithLambda {
217+
String name;
218+
219+
public String toString() {
220+
return "MyClass";
221+
}
222+
223+
public String callLambda(String input) {
224+
StringFunction function = MyClassWithLambda::lambda$callLambda$0;
225+
return function.apply(input);
226+
}
227+
228+
private static String lambda$callLambda$0(String var0) {
229+
return var0.substring(1);
230+
}
231+
232+
public String callStatefulLambda(String input) {
233+
String constant = "prefix_";
234+
StringFunction function = MyClassWithLambda::lambda$callStatefulLambda$0;
235+
return function.apply(input);
236+
}
237+
238+
private static String lambda$callStatefulLambda$0(String var0, MyClassWithLambda constant, String var2) {
239+
return var0.concat(((String)var2).substring(1).concat(((MyClassWithLambda)constant).toString()).concat(constant.name));
240+
}
241+
242+
public String callGenericLambda(String input) {
243+
String constant = "prefix_";
244+
Function function = MyClassWithLambda::lambda$callGenericLambda$0;
245+
return (String)function.apply(input);
246+
}
247+
248+
private static String lambda$callGenericLambda$0(String var0, String constant) {
249+
return var0.concat(constant.substring(1));
250+
}
251+
}
252+
""", decompileToJava(bytes));
253+
}
254+
57255
@Test
58256
void voidClassLoading() {
59257
final Constructor<?> CONSTRUCTOR_CLASS_VALUE = ReflectionUtils.getRequiredInternalConstructor(

0 commit comments

Comments
 (0)