diff --git a/sourcegen-generator-java/src/main/java/io/micronaut/sourcegen/JavaPoetSourceGenerator.java b/sourcegen-generator-java/src/main/java/io/micronaut/sourcegen/JavaPoetSourceGenerator.java index a262975f7..2c0332a9c 100644 --- a/sourcegen-generator-java/src/main/java/io/micronaut/sourcegen/JavaPoetSourceGenerator.java +++ b/sourcegen-generator-java/src/main/java/io/micronaut/sourcegen/JavaPoetSourceGenerator.java @@ -57,6 +57,7 @@ import java.lang.reflect.Array; import java.util.Collection; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -72,6 +73,7 @@ */ @Internal public sealed class JavaPoetSourceGenerator implements SourceGenerator permits GroovyPoetSourceGenerator { + private static final String EXCEPTION_NAME = "$exception"; @Override public VisitorContext.Language getLanguage() { @@ -170,7 +172,7 @@ private TypeSpec.Builder getEnumBuilder(EnumDef enumDef) { if (exps != null) { CodeBlock.Builder expBuilder = CodeBlock.builder(); for (int i = 0; i < exps.size(); i++) { - expBuilder.add(renderExpression(null, null, exps.get(i))); + expBuilder.add(renderExpression(null, null, Map.of(), exps.get(i))); if (i < exps.size() - 1) { expBuilder.add(", "); } @@ -296,13 +298,14 @@ private void buildFields(ObjectDef objectDef, TypeSpec.Builder builder) { ((EnumDef) objectDef).getFields(); for (FieldDef field : fields) { FieldSpec.Builder fieldBuilder = FieldSpec.builder( - asType(field.getType(), objectDef), - field.getName() - ).addModifiers(field.getModifiersArray()); + asType(field.getType(), objectDef), + field.getName() + ).addModifiers(field.getModifiersArray()); field.getInitializer().ifPresent(init -> fieldBuilder.initializer(renderExpression( + objectDef, null, - null, + Map.of(), init )) ); @@ -376,7 +379,7 @@ private MethodSpec asMethodSpec(ObjectDef objectDef, MethodDef method) { ); } method.getStatements().stream() - .map(st -> renderStatementCodeBlock(objectDef, method, st)) + .map(st -> renderStatementCodeBlock(objectDef, method, Map.of(), st)) .forEach(methodBuilder::addCode); return methodBuilder.build(); @@ -403,7 +406,7 @@ private void addAnnotationValue(AnnotationSpec.Builder builder, String memberNam } else if (value instanceof AnnotationDef annotationValue) { builder.addMember(memberName, asAnnotationSpec(annotationValue)); } else if (value instanceof VariableDef variableDef) { - builder.addMember(memberName, renderVariable(null, null, variableDef)); + builder.addMember(memberName, renderVariable(null, null, Map.of(), variableDef)); } else if (value instanceof Class) { builder.addMember(memberName, "$T.class", value); } else if (value instanceof Enum) { @@ -505,11 +508,14 @@ private static ClassName asClassType(ClassTypeDef classTypeDef) { return ClassName.bestGuess(classTypeDef.getCanonicalName()); } - private CodeBlock renderStatement(@Nullable ObjectDef objectDef, MethodDef methodDef, StatementDef statementDef) { + private CodeBlock renderStatement(@Nullable ObjectDef objectDef, + MethodDef methodDef, + Map remappedLocals, + StatementDef statementDef) { if (statementDef instanceof StatementDef.Throw aThrow) { return CodeBlock.concat( CodeBlock.of("throw "), - renderExpression(objectDef, methodDef, aThrow.expression()) + renderExpression(objectDef, methodDef, remappedLocals, aThrow.expression()) ); } if (statementDef instanceof StatementDef.Return aReturn) { @@ -518,22 +524,22 @@ private CodeBlock renderStatement(@Nullable ObjectDef objectDef, MethodDef metho } return CodeBlock.concat( CodeBlock.of("return "), - renderExpression(objectDef, methodDef, aReturn.expression()) + renderExpression(objectDef, methodDef, remappedLocals, aReturn.expression()) ); } if (statementDef instanceof StatementDef.Assign assign) { return CodeBlock.concat( - renderExpression(objectDef, methodDef, assign.variable()), + renderExpression(objectDef, methodDef, remappedLocals, assign.variable()), CodeBlock.of(" = "), - renderExpression(objectDef, methodDef, assign.expression()) + renderExpression(objectDef, methodDef, remappedLocals, assign.expression()) ); } if (statementDef instanceof StatementDef.PutField putField) { VariableDef.Field field = putField.field(); return CodeBlock.concat( - renderExpression(objectDef, methodDef, field.instance()), + renderExpression(objectDef, methodDef, remappedLocals, field.instance()), CodeBlock.of(".$L = ", field.name()), - renderExpression(objectDef, methodDef, putField.expression()) + renderExpression(objectDef, methodDef, remappedLocals, putField.expression()) ); } if (statementDef instanceof StatementDef.PutStaticField putStaticField) { @@ -541,37 +547,68 @@ private CodeBlock renderStatement(@Nullable ObjectDef objectDef, MethodDef metho return CodeBlock.concat( CodeBlock.of("$T.$L", asType(field.type(), objectDef), field.name()), CodeBlock.of(" = "), - renderExpression(objectDef, methodDef, putStaticField.expression()) + renderExpression(objectDef, methodDef, remappedLocals, putStaticField.expression()) ); } if (statementDef instanceof StatementDef.DefineAndAssign assign) { return CodeBlock.concat( CodeBlock.of("$T $L", asType(assign.variable().type(), objectDef), assign.variable().name()), CodeBlock.of(" = "), - renderExpression(objectDef, methodDef, assign.expression()) + renderExpression(objectDef, methodDef, remappedLocals, assign.expression()) ); } if (statementDef instanceof ExpressionDef expressionDef) { - return renderExpression(objectDef, methodDef, expressionDef); + return renderExpression(objectDef, methodDef, remappedLocals, expressionDef); } throw new IllegalStateException("Unrecognized statement: " + statementDef); } - private CodeBlock renderStatementCodeBlock(@Nullable ObjectDef objectDef, MethodDef methodDef, StatementDef statementDef) { + private CodeBlock renderStatementCodeBlock(@Nullable ObjectDef objectDef, + MethodDef methodDef, + Map remappedLocals, + StatementDef statementDef) { if (statementDef instanceof StatementDef.Multi statements) { CodeBlock.Builder builder = CodeBlock.builder(); for (StatementDef statement : statements.statements()) { - builder.add(renderStatementCodeBlock(objectDef, methodDef, statement)); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, statement)); + } + return builder.build(); + } + if (statementDef instanceof StatementDef.Try tryStatement) { + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("try {\n"); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, tryStatement.statement())); + int i = 0; + for (StatementDef.Try.Catch aCatch : tryStatement.catches()) { + String exceptionLocal = "e" + i++; + builder.add(CodeBlock.of("\n} catch ($T $L) {", asType(aCatch.exception(), objectDef), exceptionLocal)); + Map newRemappedLocals = new LinkedHashMap<>(remappedLocals); + newRemappedLocals.put(EXCEPTION_NAME, exceptionLocal); + builder.add(renderStatementCodeBlock(objectDef, methodDef, newRemappedLocals, aCatch.statement())); + } + if (tryStatement.finallyStatement() != null) { + builder.add("\n} finally {"); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, tryStatement.finallyStatement())); } + builder.add("\n}"); + return builder.build(); + } + if (statementDef instanceof StatementDef.Synchronized s) { + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("synchronized ("); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, s.monitor())); + builder.add(") {\n"); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, s.statement())); + builder.add("\n}"); return builder.build(); } if (statementDef instanceof StatementDef.If ifStatement) { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("if ("); - builder.add(renderExpression(objectDef, methodDef, ifStatement.condition())); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, ifStatement.condition())); builder.add(") {\n"); builder.indent(); - builder.add(renderStatementCodeBlock(objectDef, methodDef, ifStatement.statement())); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, ifStatement.statement())); builder.unindent(); builder.add("}\n"); return builder.build(); @@ -579,14 +616,14 @@ private CodeBlock renderStatementCodeBlock(@Nullable ObjectDef objectDef, Method if (statementDef instanceof StatementDef.IfElse ifStatement) { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("if ("); - builder.add(renderExpression(objectDef, methodDef, ifStatement.condition())); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, ifStatement.condition())); builder.add(") {\n"); builder.indent(); - builder.add(renderStatementCodeBlock(objectDef, methodDef, ifStatement.statement())); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, ifStatement.statement())); builder.unindent(); builder.add("} else {\n"); builder.indent(); - builder.add(renderStatementCodeBlock(objectDef, methodDef, ifStatement.elseStatement())); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, ifStatement.elseStatement())); builder.unindent(); builder.add("}\n"); return builder.build(); @@ -594,15 +631,15 @@ private CodeBlock renderStatementCodeBlock(@Nullable ObjectDef objectDef, Method if (statementDef instanceof StatementDef.Switch aSwitch) { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("switch ("); - builder.add(renderExpression(objectDef, methodDef, aSwitch.expression())); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, aSwitch.expression())); builder.add(") {\n"); builder.indent(); for (Map.Entry e : aSwitch.cases().entrySet()) { builder.add("case "); - builder.add(renderConstantExpression(e.getKey())); + builder.add(renderConstantExpression(remappedLocals, e.getKey())); builder.add(": {\n"); builder.indent(); - builder.add(renderStatementCodeBlock(objectDef, methodDef, e.getValue())); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, e.getValue())); builder.unindent(); builder.add("}\n"); } @@ -610,7 +647,7 @@ private CodeBlock renderStatementCodeBlock(@Nullable ObjectDef objectDef, Method builder.add("default"); builder.add(": {\n"); builder.indent(); - builder.add(renderStatementCodeBlock(objectDef, methodDef, aSwitch.defaultCase())); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, aSwitch.defaultCase())); builder.unindent(); builder.add("}\n"); } @@ -621,30 +658,33 @@ private CodeBlock renderStatementCodeBlock(@Nullable ObjectDef objectDef, Method if (statementDef instanceof StatementDef.While aWhile) { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("while ("); - builder.add(renderExpression(objectDef, methodDef, aWhile.expression())); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, aWhile.expression())); builder.add(") {\n"); builder.indent(); - builder.add(renderStatementCodeBlock(objectDef, methodDef, aWhile.statement())); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, aWhile.statement())); builder.unindent(); builder.add("}\n"); return builder.build(); } return CodeBlock.builder() .addStatement( - renderStatement(objectDef, methodDef, statementDef) + renderStatement(objectDef, methodDef, remappedLocals, statementDef) ).build(); } - private CodeBlock renderExpression(@Nullable ObjectDef objectDef, MethodDef methodDef, ExpressionDef expressionDef) { + private CodeBlock renderExpression(@Nullable ObjectDef objectDef, + MethodDef methodDef, + Map remappedLocals, + ExpressionDef expressionDef) { if (expressionDef instanceof ExpressionDef.ConditionExpressionDef conditionExpressionDef) { - return renderCondition(objectDef, methodDef, conditionExpressionDef); + return renderCondition(objectDef, methodDef, remappedLocals, conditionExpressionDef); } if (expressionDef instanceof ExpressionDef.NewInstance newInstance) { return CodeBlock.concat( CodeBlock.of("new $L(", asType(newInstance.type(), objectDef)), newInstance.values() .stream() - .map(exp -> renderExpression(objectDef, methodDef, exp)) + .map(exp -> renderExpression(objectDef, methodDef, remappedLocals, exp)) .collect(CodeBlock.joining(", ")), CodeBlock.of(")") ); @@ -657,7 +697,7 @@ private CodeBlock renderExpression(@Nullable ObjectDef objectDef, MethodDef meth builder.add("new $T[]{", asType(newArray.type().componentType(), objectDef)); for (Iterator iterator = newArray.expressions().iterator(); iterator.hasNext(); ) { ExpressionDef expression = iterator.next(); - builder.add(renderExpression(objectDef, methodDef, expression)); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, expression)); if (iterator.hasNext()) { builder.add(","); } @@ -667,32 +707,32 @@ private CodeBlock renderExpression(@Nullable ObjectDef objectDef, MethodDef meth } if (expressionDef instanceof ExpressionDef.Cast castExpressionDef) { if (castExpressionDef.type().equals(castExpressionDef.expressionDef().type())) { - return renderExpression(objectDef, methodDef, castExpressionDef.expressionDef()); + return renderExpression(objectDef, methodDef, remappedLocals, castExpressionDef.expressionDef()); } if (castExpressionDef.expressionDef() instanceof VariableDef variableDef) { return CodeBlock.concat( CodeBlock.of("($T) ", asType(castExpressionDef.type(), objectDef)), - renderExpression(objectDef, methodDef, variableDef) + renderExpression(objectDef, methodDef, remappedLocals, variableDef) ); } return CodeBlock.concat( CodeBlock.of("($T) (", asType(castExpressionDef.type(), objectDef)), - renderExpression(objectDef, methodDef, castExpressionDef.expressionDef()), + renderExpression(objectDef, methodDef, remappedLocals, castExpressionDef.expressionDef()), CodeBlock.of(")") ); } if (expressionDef instanceof ExpressionDef.Constant constant) { - return renderConstantExpression(constant); + return renderConstantExpression(remappedLocals, constant); } if (expressionDef instanceof ExpressionDef.InvokeInstanceMethod invokeInstanceMethod) { MethodDef callMethod = invokeInstanceMethod.method(); return CodeBlock.concat( - CodeBlock.of(renderExpression(objectDef, methodDef, invokeInstanceMethod.instance()) + CodeBlock.of(renderExpression(objectDef, methodDef, remappedLocals, invokeInstanceMethod.instance()) + (callMethod.isConstructor() ? "" : "." + callMethod.getName()) + "("), invokeInstanceMethod.values() .stream() - .map(exp -> renderExpression(objectDef, methodDef, exp)) + .map(exp -> renderExpression(objectDef, methodDef, remappedLocals, exp)) .collect(CodeBlock.joining(", ")), CodeBlock.of(")") ); @@ -702,48 +742,48 @@ private CodeBlock renderExpression(@Nullable ObjectDef objectDef, MethodDef meth CodeBlock.of("$T." + staticMethod.method().getName() + "(", asType(staticMethod.classDef(), objectDef)), staticMethod.values() .stream() - .map(exp -> renderExpression(objectDef, methodDef, exp)) + .map(exp -> renderExpression(objectDef, methodDef, remappedLocals, exp)) .collect(CodeBlock.joining(", ")), CodeBlock.of(")") ); } if (expressionDef instanceof ExpressionDef.GetPropertyValue getPropertyValue) { - return renderExpression(objectDef, methodDef, JavaIdioms.getPropertyValue(getPropertyValue)); + return renderExpression(objectDef, methodDef, remappedLocals, JavaIdioms.getPropertyValue(getPropertyValue)); } if (expressionDef instanceof ExpressionDef.MathBinaryOperation mathOperation) { return CodeBlock.concat( - renderExpressionWithParentheses(objectDef, methodDef, mathOperation.left()), + renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, mathOperation.left()), CodeBlock.of(getMathOp(mathOperation)), - renderExpressionWithParentheses(objectDef, methodDef, mathOperation.right()) + renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, mathOperation.right()) ); } if (expressionDef instanceof ExpressionDef.MathUnaryOperation mathOperation) { return CodeBlock.concat( CodeBlock.of(getMathOp(mathOperation)), - renderExpressionWithParentheses(objectDef, methodDef, mathOperation.expression()) + renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, mathOperation.expression()) ); } if (expressionDef instanceof ExpressionDef.IfElse condition) { return CodeBlock.concat( - renderExpression(objectDef, methodDef, condition.condition()), + renderExpression(objectDef, methodDef, remappedLocals, condition.condition()), CodeBlock.of(" ? "), - renderExpression(objectDef, methodDef, condition.ifExpression()), + renderExpression(objectDef, methodDef, remappedLocals, condition.ifExpression()), CodeBlock.of(" : "), - renderExpression(objectDef, methodDef, condition.elseExpression()) + renderExpression(objectDef, methodDef, remappedLocals, condition.elseExpression()) ); } if (expressionDef instanceof ExpressionDef.Switch aSwitch) { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("switch ("); - builder.add(renderExpression(objectDef, methodDef, aSwitch.expression())); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, aSwitch.expression())); builder.add(") {\n"); builder.indent(); for (Map.Entry e : aSwitch.cases().entrySet()) { builder.add("case "); - builder.add(renderConstantExpression(e.getKey())); + builder.add(renderConstantExpression(remappedLocals, e.getKey())); builder.add(" -> "); ExpressionDef value = e.getValue(); - builder.add(renderExpression(objectDef, methodDef, value)); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, value)); if (value instanceof ExpressionDef.SwitchYieldCase) { builder.add("\n"); } else { @@ -753,7 +793,7 @@ private CodeBlock renderExpression(@Nullable ObjectDef objectDef, MethodDef meth if (aSwitch.defaultCase() != null) { builder.add("default"); builder.add(" -> "); - builder.add(renderExpression(objectDef, methodDef, aSwitch.defaultCase())); + builder.add(renderExpression(objectDef, methodDef, remappedLocals, aSwitch.defaultCase())); if (aSwitch.defaultCase() instanceof ExpressionDef.SwitchYieldCase) { builder.add("\n"); } else { @@ -776,9 +816,9 @@ private CodeBlock renderExpression(@Nullable ObjectDef objectDef, MethodDef meth StatementDef last = flatten.get(flatten.size() - 1); List rest = flatten.subList(0, flatten.size() - 1); for (StatementDef statementDef : rest) { - builder.add(renderStatementCodeBlock(objectDef, methodDef, statementDef)); + builder.add(renderStatementCodeBlock(objectDef, methodDef, remappedLocals, statementDef)); } - renderYield(builder, methodDef, last, objectDef); + renderYield(builder, methodDef, remappedLocals, last, objectDef); builder.unindent(); builder.add("}"); String str = builder.build().toString(); @@ -786,13 +826,13 @@ private CodeBlock renderExpression(@Nullable ObjectDef objectDef, MethodDef meth return CodeBlock.ofWithoutFormat(str); } if (expressionDef instanceof VariableDef variableDef) { - return renderVariable(objectDef, methodDef, variableDef); + return renderVariable(objectDef, methodDef, remappedLocals, variableDef); } if (expressionDef instanceof ExpressionDef.InvokeGetClassMethod invokeGetClassMethod) { - return renderExpression(objectDef, methodDef, JavaIdioms.getClass(invokeGetClassMethod)); + return renderExpression(objectDef, methodDef, remappedLocals, JavaIdioms.getClass(invokeGetClassMethod)); } if (expressionDef instanceof ExpressionDef.InvokeHashCodeMethod invokeHashCodeMethod) { - return renderExpression(objectDef, methodDef, JavaIdioms.hashCode(invokeHashCodeMethod)); + return renderExpression(objectDef, methodDef, remappedLocals, JavaIdioms.hashCode(invokeHashCodeMethod)); } throw new IllegalStateException("Unrecognized expression: " + expressionDef); } @@ -819,8 +859,8 @@ private static String getMathOp(ExpressionDef.MathUnaryOperation mathOperation) }; } - private CodeBlock renderExpressionWithParentheses(@Nullable ObjectDef objectDef, MethodDef methodDef, ExpressionDef expressionDef) { - var rendered = renderExpression(objectDef, methodDef, expressionDef); + private CodeBlock renderExpressionWithParentheses(@Nullable ObjectDef objectDef, MethodDef methodDef, Map remappedLocals, ExpressionDef expressionDef) { + var rendered = renderExpression(objectDef, methodDef, remappedLocals, expressionDef); while (expressionDef instanceof ExpressionDef.Cast cast) { expressionDef = cast.expressionDef(); } @@ -838,42 +878,45 @@ private CodeBlock addParentheses(CodeBlock rendered) { ); } - private CodeBlock renderCondition(@Nullable ObjectDef objectDef, MethodDef methodDef, ExpressionDef.ConditionExpressionDef expressionDef) { + private CodeBlock renderCondition(@Nullable ObjectDef objectDef, + MethodDef methodDef, + Map remappedLocals, + ExpressionDef.ConditionExpressionDef expressionDef) { if (expressionDef instanceof ExpressionDef.IsNull isNull) { - return renderCondition(objectDef, methodDef, new ExpressionDef.ComparisonOperation(ExpressionDef.ComparisonOperation.OpType.EQUAL_TO, isNull.expression(), ExpressionDef.nullValue())); + return renderCondition(objectDef, methodDef, remappedLocals, new ExpressionDef.ComparisonOperation(ExpressionDef.ComparisonOperation.OpType.EQUAL_TO, isNull.expression(), ExpressionDef.nullValue())); } if (expressionDef instanceof ExpressionDef.IsNotNull isNotNull) { - return renderCondition(objectDef, methodDef, new ExpressionDef.ComparisonOperation(ExpressionDef.ComparisonOperation.OpType.NOT_EQUAL_TO, isNotNull.expression(), ExpressionDef.nullValue())); + return renderCondition(objectDef, methodDef, remappedLocals, new ExpressionDef.ComparisonOperation(ExpressionDef.ComparisonOperation.OpType.NOT_EQUAL_TO, isNotNull.expression(), ExpressionDef.nullValue())); } if (expressionDef instanceof ExpressionDef.IsTrue isTrue) { - return renderExpressionWithParentheses(objectDef, methodDef, isTrue.expression()); + return renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, isTrue.expression()); } if (expressionDef instanceof ExpressionDef.IsFalse isFalse) { return CodeBlock.concat( CodeBlock.of("!"), - renderExpressionWithParentheses(objectDef, methodDef, isFalse.expression()) + renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, isFalse.expression()) ); } if (expressionDef instanceof ExpressionDef.ComparisonOperation comparisonOperation) { return CodeBlock.concat( - renderExpressionWithParentheses(objectDef, methodDef, comparisonOperation.left()), + renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, comparisonOperation.left()), CodeBlock.of(getOpType(comparisonOperation)), - renderExpressionWithParentheses(objectDef, methodDef, comparisonOperation.right()) + renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, comparisonOperation.right()) ); } if (expressionDef instanceof ExpressionDef.And andExpressionDef) { return CodeBlock.concat( - renderCondition(objectDef, methodDef, andExpressionDef.left()), + renderCondition(objectDef, methodDef, remappedLocals, andExpressionDef.left()), CodeBlock.of(" && "), - renderCondition(objectDef, methodDef, andExpressionDef.right()) + renderCondition(objectDef, methodDef, remappedLocals, andExpressionDef.right()) ); } if (expressionDef instanceof ExpressionDef.Or orExpressionDef) { return addParentheses( CodeBlock.concat( - renderCondition(objectDef, methodDef, orExpressionDef.left()), + renderCondition(objectDef, methodDef, remappedLocals, orExpressionDef.left()), CodeBlock.of(" || "), - renderCondition(objectDef, methodDef, orExpressionDef.right()) + renderCondition(objectDef, methodDef, remappedLocals, orExpressionDef.right()) ) ); } @@ -883,9 +926,9 @@ private CodeBlock renderCondition(@Nullable ObjectDef objectDef, MethodDef metho ExpressionDef right = equalsStructurally.other(); TypeDef rightType = right.type(); if (leftType.isPrimitive() || rightType.isPrimitive()) { - return renderEqualsReferentially(objectDef, methodDef, left, right); + return renderEqualsReferentially(objectDef, methodDef, remappedLocals, left, right); } - return renderExpressionWithParentheses(objectDef, methodDef, JavaIdioms.equalsStructurally(equalsStructurally)); + return renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, JavaIdioms.equalsStructurally(equalsStructurally)); } if (expressionDef instanceof ExpressionDef.NotEqualsStructurally notEqualsStructurally) { ExpressionDef left = notEqualsStructurally.instance(); @@ -893,19 +936,19 @@ private CodeBlock renderCondition(@Nullable ObjectDef objectDef, MethodDef metho ExpressionDef right = notEqualsStructurally.other(); TypeDef rightType = right.type(); if (leftType.isPrimitive() || rightType.isPrimitive()) { - return renderEqualsReferentially(objectDef, methodDef, left, right); + return renderEqualsReferentially(objectDef, methodDef, remappedLocals, left, right); } - return renderExpressionWithParentheses(objectDef, methodDef, JavaIdioms.equalsStructurally(notEqualsStructurally.instance(), notEqualsStructurally.other()).isFalse()); + return renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, JavaIdioms.equalsStructurally(notEqualsStructurally.instance(), notEqualsStructurally.other()).isFalse()); } if (expressionDef instanceof ExpressionDef.EqualsReferentially equalsReferentially) { ExpressionDef left = equalsReferentially.instance(); ExpressionDef right = equalsReferentially.other(); - return renderEqualsReferentially(objectDef, methodDef, left, right); + return renderEqualsReferentially(objectDef, methodDef, remappedLocals, left, right); } if (expressionDef instanceof ExpressionDef.NotEqualsReferentially notEqualsReferentially) { ExpressionDef left = notEqualsReferentially.instance(); ExpressionDef right = notEqualsReferentially.other(); - return renderNotEqualsReferentially(objectDef, methodDef, left, right); + return renderNotEqualsReferentially(objectDef, methodDef, remappedLocals, left, right); } throw new IllegalStateException("Unrecognized condition: " + expressionDef); } @@ -921,28 +964,28 @@ private static String getOpType(ExpressionDef.ComparisonOperation comparisonOper }; } - private CodeBlock renderEqualsReferentially(ObjectDef objectDef, MethodDef methodDef, ExpressionDef left, ExpressionDef right) { + private CodeBlock renderEqualsReferentially(ObjectDef objectDef, MethodDef methodDef, Map remappedLocals, ExpressionDef left, ExpressionDef right) { return CodeBlock.builder() - .add(renderExpressionWithParentheses(objectDef, methodDef, left)) + .add(renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, left)) .add(" == ") - .add(renderExpressionWithParentheses(objectDef, methodDef, right)) + .add(renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, right)) .build(); } - private CodeBlock renderNotEqualsReferentially(ObjectDef objectDef, MethodDef methodDef, ExpressionDef left, ExpressionDef right) { + private CodeBlock renderNotEqualsReferentially(ObjectDef objectDef, MethodDef methodDef, Map remappedLocals, ExpressionDef left, ExpressionDef right) { return CodeBlock.builder() - .add(renderExpressionWithParentheses(objectDef, methodDef, left)) + .add(renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, left)) .add(" != ") - .add(renderExpressionWithParentheses(objectDef, methodDef, right)) + .add(renderExpressionWithParentheses(objectDef, methodDef, remappedLocals, right)) .build(); } - private void renderYield(CodeBlock.Builder builder, MethodDef methodDef, StatementDef statementDef, ObjectDef objectDef) { + private void renderYield(CodeBlock.Builder builder, MethodDef methodDef, Map remappedLocals, StatementDef statementDef, ObjectDef objectDef) { if (statementDef instanceof StatementDef.Return aReturn) { builder.addStatement( CodeBlock.concat( CodeBlock.of("yield "), - renderExpression(objectDef, methodDef, aReturn.expression()) + renderExpression(objectDef, methodDef, remappedLocals, aReturn.expression()) ) ); } else { @@ -950,7 +993,7 @@ private void renderYield(CodeBlock.Builder builder, MethodDef methodDef, Stateme } } - private CodeBlock renderConstantExpression(ExpressionDef.Constant constant) { + private CodeBlock renderConstantExpression(Map remappedLocals, ExpressionDef.Constant constant) { TypeDef type = constant.type(); Object value = constant.value(); if (value == null) { @@ -960,6 +1003,7 @@ private CodeBlock renderConstantExpression(ExpressionDef.Constant constant) { return renderExpression( null, null, + remappedLocals, classTypeDef.getStaticField(value instanceof Enum anEnum ? anEnum.name() : value.toString(), type) ); } @@ -974,7 +1018,7 @@ private CodeBlock renderConstantExpression(ExpressionDef.Constant constant) { if (value.getClass().isArray()) { final var array = value; final var values = IntStream.range(0, Array.getLength(array)) - .mapToObj(i -> renderConstantExpression(new ExpressionDef.Constant(arrayDef.componentType(), Array.get(array, i)))) + .mapToObj(i -> renderConstantExpression(remappedLocals, new ExpressionDef.Constant(arrayDef.componentType(), Array.get(array, i)))) .collect(CodeBlock.joining(", ")); final String typeName; if (arrayDef.componentType() instanceof ClassTypeDef arrayClassTypeDef) { @@ -1006,7 +1050,10 @@ private CodeBlock renderConstantExpression(ExpressionDef.Constant constant) { throw new IllegalStateException("Unrecognized expression: " + constant); } - private CodeBlock renderVariable(@Nullable ObjectDef objectDef, @Nullable MethodDef methodDef, VariableDef variableDef) { + private CodeBlock renderVariable(@Nullable ObjectDef objectDef, @Nullable MethodDef methodDef, Map remappedLocals, VariableDef variableDef) { + if (variableDef instanceof VariableDef.ExceptionVar) { + return CodeBlock.of(Objects.requireNonNull(remappedLocals.get(EXCEPTION_NAME))); + } if (variableDef instanceof VariableDef.Local localVariableDef) { return CodeBlock.of(localVariableDef.name()); } @@ -1035,7 +1082,7 @@ private CodeBlock renderVariable(@Nullable ObjectDef objectDef, @Nullable Method } else { throw new IllegalStateException("Field access not supported on the object definition: " + objectDef); } - return CodeBlock.of(renderExpression(objectDef, methodDef, field.instance()) + "." + field.name()); + return CodeBlock.of(renderExpression(objectDef, methodDef, remappedLocals, field.instance()) + "." + field.name()); } if (variableDef instanceof VariableDef.This) { if (objectDef == null) { @@ -1043,10 +1090,13 @@ private CodeBlock renderVariable(@Nullable ObjectDef objectDef, @Nullable Method } return CodeBlock.of("this"); } - if (variableDef instanceof VariableDef.Super) { + if (variableDef instanceof VariableDef.Super aSuper) { if (objectDef == null) { throw new IllegalStateException("Accessing 'super' is not available"); } + if (aSuper.type() != TypeDef.SUPER) { + return CodeBlock.of("$T.super", asType(aSuper.type(), objectDef)); + } return CodeBlock.of("super"); } throw new IllegalStateException("Unrecognized variable: " + variableDef); diff --git a/test-suite-bytecode/src/test/java/io/micronaut/sourcegen/example/MethodInvokerTest.java b/test-suite-bytecode/src/test/java/io/micronaut/sourcegen/example/MethodInvokerTest.java index 0b48d949b..fbb1fc181 100644 --- a/test-suite-bytecode/src/test/java/io/micronaut/sourcegen/example/MethodInvokerTest.java +++ b/test-suite-bytecode/src/test/java/io/micronaut/sourcegen/example/MethodInvokerTest.java @@ -89,6 +89,35 @@ public void testStaticMethod() { ); } + @Test + public void testInterfaceSuperMethod() { + Assertions.assertEquals( + "ABCSTT102", + new MethodRepositoryInvoker() { + + @Override + public String interfaceMethod(String string, Integer integer, int i) { + throw new UnsupportedOperationException(); + } + + @Override + public double interfaceMethodReturnsDouble() { + throw new UnsupportedOperationException(); + } + + @Override + public long interfaceMethodReturnsLong() { + throw new UnsupportedOperationException(); + } + + @Override + public int interfaceMethodReturnsInt() { + throw new UnsupportedOperationException(); + } + }.defaultMethod("STT", 100, 2) + ); + } + @Test public void testDefaultMethodIgnoreResult() { Assertions.assertEquals( diff --git a/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMethodInvocationVisitor.java b/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMethodInvocationVisitor.java index 16cb03b34..ed18585ef 100644 --- a/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMethodInvocationVisitor.java +++ b/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMethodInvocationVisitor.java @@ -19,6 +19,7 @@ import io.micronaut.core.annotation.NonNull; import io.micronaut.core.reflect.ReflectionUtils; import io.micronaut.inject.ast.ClassElement; +import io.micronaut.inject.ast.MethodElement; import io.micronaut.inject.visitor.TypeElementVisitor; import io.micronaut.inject.visitor.VisitorContext; import io.micronaut.sourcegen.custom.example.GenerateMethodInvocation; @@ -28,6 +29,7 @@ import io.micronaut.sourcegen.model.ClassTypeDef; import io.micronaut.sourcegen.model.ExpressionDef; import io.micronaut.sourcegen.model.FieldDef; +import io.micronaut.sourcegen.model.JavaIdioms; import io.micronaut.sourcegen.model.MethodDef; import io.micronaut.sourcegen.model.StatementDef; import io.micronaut.sourcegen.model.TypeDef; @@ -172,6 +174,22 @@ public void visitClass(ClassElement element, VisitorContext context) { sourceGenerator.write(classDef, context, element); + ClassTypeDef myRepoType = ClassTypeDef.of(myRepository); + MethodElement defaultMethod = myRepository.findMethod("defaultMethod").get(); + ClassDef interfaceSuperInvokerDef = ClassDef.builder("io.micronaut.sourcegen.example.MethodRepositoryInvoker") + .addModifiers(Modifier.ABSTRACT) + .addSuperinterface(myRepoType) + .addMethod(MethodDef.override(defaultMethod) + .build((aThis, methodParameters) -> + JavaIdioms.concatStrings( + ExpressionDef.constant("ABC"), + aThis.superRef(myRepoType) + .invoke(defaultMethod, methodParameters) + ).returning()) + ).build(); + + sourceGenerator.write(interfaceSuperInvokerDef, context, element); + FieldDef targetField = FieldDef.builder("target", TypeDef.OBJECT).addModifiers(Modifier.PRIVATE).build(); FieldDef lockField = FieldDef.builder("lock", ClassTypeDef.of(ReentrantReadWriteLock.class)) .addModifiers(Modifier.PRIVATE, Modifier.FINAL) @@ -224,7 +242,9 @@ public void visitClass(ClassElement element, VisitorContext context) { methodParameters.get(1).invoke("getAndIncrement", TypeDef.Primitive.INT), StatementDef.doTry( StatementDef.multi( - ClassTypeDef.of(IllegalStateException.class).instantiate().doThrow(), + ExpressionDef.trueValue().isTrue().doIf( + ClassTypeDef.of(IllegalStateException.class).instantiate().doThrow() + ), methodParameters.get(1).invoke("getAndIncrement", TypeDef.Primitive.INT), aThis.field(targetField).returning() ) @@ -238,7 +258,9 @@ public void visitClass(ClassElement element, VisitorContext context) { methodParameters.get(1).invoke("getAndIncrement", TypeDef.Primitive.INT), StatementDef.doTry( StatementDef.multi( - ClassTypeDef.of(IllegalStateException.class).instantiate(ExpressionDef.constant("Bam")).doThrow(), + ExpressionDef.trueValue().isTrue().doIf( + ClassTypeDef.of(IllegalStateException.class).instantiate(ExpressionDef.constant("Bam")).doThrow() + ), methodParameters.get(1).invoke("getAndIncrement", TypeDef.Primitive.INT), aThis.field(targetField).returning() ) diff --git a/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMyRepository2Visitor.java b/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMyRepository2Visitor.java index 1a873a02c..135acfdd3 100644 --- a/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMyRepository2Visitor.java +++ b/test-suite-custom-generators/src/main/java/io/micronaut/sourcegen/custom/visitor/GenerateMyRepository2Visitor.java @@ -20,20 +20,14 @@ import io.micronaut.inject.ast.ClassElement; import io.micronaut.inject.visitor.TypeElementVisitor; import io.micronaut.inject.visitor.VisitorContext; -import io.micronaut.sourcegen.custom.example.GenerateMyRepository1; import io.micronaut.sourcegen.custom.example.GenerateMyRepository2; import io.micronaut.sourcegen.generator.SourceGenerator; import io.micronaut.sourcegen.generator.SourceGenerators; -import io.micronaut.sourcegen.model.ClassDef; -import io.micronaut.sourcegen.model.ClassTypeDef; import io.micronaut.sourcegen.model.InterfaceDef; -import io.micronaut.sourcegen.model.MethodDef; -import io.micronaut.sourcegen.model.PropertyDef; import io.micronaut.sourcegen.model.TypeDef; import javax.lang.model.element.Modifier; import java.util.List; -import java.util.Optional; @Internal public final class GenerateMyRepository2Visitor implements TypeElementVisitor { diff --git a/test-suite-java/src/main/java/io/micronaut/sourcegen/example/MyRepository.java b/test-suite-java/src/main/java/io/micronaut/sourcegen/example/MyRepository.java new file mode 100644 index 000000000..80d2f9987 --- /dev/null +++ b/test-suite-java/src/main/java/io/micronaut/sourcegen/example/MyRepository.java @@ -0,0 +1,36 @@ +/* + * Copyright 2017-2024 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.sourcegen.example; + +public interface MyRepository { + + static String staticMethod(String string, Integer integer, int i) { + return string + (integer + i); + } + + default String defaultMethod(String string, Integer integer, int i) { + return string + (integer + i); + } + + String interfaceMethod(String string, Integer integer, int i); + + double interfaceMethodReturnsDouble(); + + long interfaceMethodReturnsLong(); + + int interfaceMethodReturnsInt(); + +} diff --git a/test-suite-java/src/main/java/io/micronaut/sourcegen/example/Trigger.java b/test-suite-java/src/main/java/io/micronaut/sourcegen/example/Trigger.java index 5be15f378..ec0356fcc 100644 --- a/test-suite-java/src/main/java/io/micronaut/sourcegen/example/Trigger.java +++ b/test-suite-java/src/main/java/io/micronaut/sourcegen/example/Trigger.java @@ -34,6 +34,7 @@ @GenerateAnnotatedType @GenerateInnerTypes @GenerateMyEnum2 +@GenerateMethodInvocation public class Trigger { public List copyAddresses; } diff --git a/test-suite-java/src/test/java/io/micronaut/sourcegen/example/MethodInvokerTest.java b/test-suite-java/src/test/java/io/micronaut/sourcegen/example/MethodInvokerTest.java new file mode 100644 index 000000000..fbb1fc181 --- /dev/null +++ b/test-suite-java/src/test/java/io/micronaut/sourcegen/example/MethodInvokerTest.java @@ -0,0 +1,321 @@ +package io.micronaut.sourcegen.example; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +public class MethodInvokerTest { + + MyRepository repository = new MyRepository() { + @Override + public String interfaceMethod(String string, Integer integer, int i) { + return string + (integer + i); + } + + @Override + public double interfaceMethodReturnsDouble() { + return 1; + } + + @Override + public long interfaceMethodReturnsLong() { + return 2; + } + + @Override + public int interfaceMethodReturnsInt() { + return 3; + } + }; + + MyRepository repositoryImplementedDefault = new MyRepository() { + @Override + public String interfaceMethod(String string, Integer integer, int i) { + return string + (integer + i); + } + + @Override + public String defaultMethod(String string, Integer integer, int i) { + return "X" + MyRepository.super.defaultMethod(string, integer, i); + } + + @Override + public double interfaceMethodReturnsDouble() { + return 1; + } + + @Override + public long interfaceMethodReturnsLong() { + return 2; + } + + @Override + public int interfaceMethodReturnsInt() { + return 3; + } + }; + + @Test + public void testDefaultMethod() { + Assertions.assertEquals( + "DEFAULT102", + MethodInvoker.invokeDefaultMethod(repository, "DEFAULT", 100, 2) + ); + } + + @Test + public void testDefaultMethod2() { + Assertions.assertEquals( + "XDEFAULT102", + MethodInvoker.invokeDefaultMethod(repositoryImplementedDefault, "DEFAULT", 100, 2) + ); + } + + @Test + public void testInterfaceMethod() { + Assertions.assertEquals( + "IFC102", + MethodInvoker.invokeInterfaceMethod(repository, "IFC", 100, 2) + ); + } + + @Test + public void testStaticMethod() { + Assertions.assertEquals( + "STT102", + MethodInvoker.invokeStaticMethod("STT", 100, 2) + ); + } + + @Test + public void testInterfaceSuperMethod() { + Assertions.assertEquals( + "ABCSTT102", + new MethodRepositoryInvoker() { + + @Override + public String interfaceMethod(String string, Integer integer, int i) { + throw new UnsupportedOperationException(); + } + + @Override + public double interfaceMethodReturnsDouble() { + throw new UnsupportedOperationException(); + } + + @Override + public long interfaceMethodReturnsLong() { + throw new UnsupportedOperationException(); + } + + @Override + public int interfaceMethodReturnsInt() { + throw new UnsupportedOperationException(); + } + }.defaultMethod("STT", 100, 2) + ); + } + + @Test + public void testDefaultMethodIgnoreResult() { + Assertions.assertEquals( + "Ignored", + MethodInvoker.invokeDefaultMethodIgnoreResult(repository, "DEFAULT", 100, 2) + ); + } + + @Test + public void testDefaultMethod2IgnoreResult() { + Assertions.assertEquals( + "Ignored", + MethodInvoker.invokeDefaultMethodIgnoreResult(repositoryImplementedDefault, "DEFAULT", 100, 2) + ); + } + + @Test + public void testInterfaceMethodIgnoreResult() { + Assertions.assertEquals( + "Ignored", + MethodInvoker.invokeInterfaceMethodIgnoreResult(repository, "IFC", 100, 2) + ); + } + + @Test + public void testStaticMethodIgnoreResult() { + Assertions.assertEquals( + "Ignored", + MethodInvoker.invokeStaticMethodIgnoreResult("STT", 100, 2) + ); + } + + @Test + public void testInvokeTryFinallyMethod() { + AtomicInteger lock = new AtomicInteger(); + Assertions.assertEquals( + 0, + lock.get() + ); + MethodInvoker.invokeTryFinally(lock); + Assertions.assertEquals( + 1, + lock.get() + ); + } + + @Test + public void testInvokeTryFinallyLockMethod() { + ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + Assertions.assertEquals( + 123, + MethodInvoker.invokeTryFinallyReadLock(lock) + ); + Assertions.assertEquals( + 123, + MethodInvoker.invokeTryFinallyWriteLock(lock) + ); + Assertions.assertEquals( + 123, + MethodInvoker.invokeTryFinallyReadLock(lock) + ); + Assertions.assertEquals( + 123, + MethodInvoker.invokeTryFinallyWriteLock(lock) + ); + } + + @Test + public void testSwapper1() { + Swapper swapper = new Swapper(); + Assertions.assertEquals( + null, + swapper.getTarget() + ); + Object o1 = new Object(); + swapper.swap(o1); + Assertions.assertEquals( + o1, + swapper.getTarget() + ); + Object o2 = new Object(); + swapper.swap(o2); + Assertions.assertEquals( + o2, + swapper.getTarget() + ); + } + + @Test + public void testSwapper2() { + Swapper swapper = new Swapper(); + Assertions.assertEquals( + null, + swapper.getTarget() + ); + AtomicInteger counter = new AtomicInteger(); + Object o1 = new Object(); + swapper.swap2(o1, counter); + Assertions.assertEquals( + o1, + swapper.getTarget() + ); + Assertions.assertEquals( + 0, + counter.get() + ); + Object o2 = new Object(); + swapper.swap2(o2, counter); + Assertions.assertEquals( + o2, + swapper.getTarget() + ); + Assertions.assertEquals( + 0, + counter.get() + ); + } + + @Test + public void testSwapper3() { + Swapper swapper = new Swapper(); + Assertions.assertEquals( + null, + swapper.getTarget() + ); + AtomicInteger counter = new AtomicInteger(); + try { + swapper.swap3(new Object(), counter); + Assertions.fail(); + } catch (IllegalStateException e) { + // Ignore + } + Assertions.assertEquals( + 0, + counter.get() + ); + try { + swapper.swap3(new Object(), counter); + Assertions.fail(); + } catch (IllegalStateException e) { + // Ignore + } + Assertions.assertEquals( + 0, + counter.get() + ); + } + + @Test + public void testSwapper4() { + Swapper swapper = new Swapper(); + AtomicInteger counter = new AtomicInteger(); + Object result = swapper.swap4(new Object(), counter); + Assertions.assertEquals( + 0, + counter.get() + ); + Assertions.assertEquals( + "Bam", + result + ); + } + + @Test + public void testSwapper5() { + Swapper swapper = new Swapper(); + Assertions.assertEquals( + null, + swapper.getTarget() + ); + Object o1 = new Object(); + swapper.swap5(o1); + Assertions.assertEquals( + o1, + swapper.getTarget() + ); + Object o2 = new Object(); + swapper.swap5(o2); + Assertions.assertEquals( + o2, + swapper.getTarget() + ); + } + + @Test + public void testSwapper6() { + Swapper swapper = new Swapper(); + try { + swapper.swap6(new Object()); + Assertions.fail(); + } catch (IllegalStateException e) { + // Ignore + } + try { + swapper.swap6(new Object()); + } catch (IllegalStateException e) { + // Ignore + } + } + +}