Skip to content

fix: Preserve NULL Parameters in case of AOP Proxy #35014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ public static Object invokeSuspendingFunction(Method method, @Nullable Object ta
Continuation<?> continuation = (Continuation<?>) args[args.length -1];
Assert.state(continuation != null, "No Continuation available");
CoroutineContext context = continuation.getContext().minusKey(Job.Key);
return CoroutinesUtils.invokeSuspendingFunction(context, method, target, args);
return CoroutinesUtils.invokeSuspendingFunctionPreserveNulls(context, method, target, args);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ class AopUtilsKotlinTests {
}
}

@Test
fun `Invoking suspending function with null argument should not return default value`() {
val method = ReflectionUtils.findMethod(WithoutInterface::class.java, "handleWithDefaultParam",
String::class. java, Continuation::class.java)!!
val continuation = Continuation<Any>(CoroutineName("test")) { }
val result = AopUtils.invokeJoinpointUsingReflection(WithoutInterface(), method, arrayOf(null, continuation))
assertThat(result).isInstanceOfSatisfying(Mono::class.java) {
assertThat(it.block()).isEqualTo(null)
}
}

@Test
fun `Invoking suspending function on bridged method should return Mono`() {
val value = "foo"
Expand All @@ -65,6 +76,11 @@ class AopUtilsKotlinTests {
delay(1)
return value
}

suspend fun handleWithDefaultParam(value: String? = "defaultVal") : String? {
delay(1)
return value
}
}

interface ProxyInterface<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,35 @@ public static Publisher<?> invokeSuspendingFunction(Method method, Object target
@SuppressWarnings({"DataFlowIssue", "NullAway"})
public static Publisher<?> invokeSuspendingFunction(
CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) {
return invokeSuspendingFunctionCore(context, method, target, args, false);
}

/**
* Invoke a suspending function and convert it to {@link Mono} or
* {@link Flux}.
* @param context the coroutine context to use
* @param method the suspending function to invoke
* @param target the target to invoke {@code method} on
* @param args the function arguments. If the {@code Continuation} argument is specified as the last argument
* (typically {@code null}), it is ignored.
* @return the method invocation result as reactive stream
* @throws IllegalArgumentException if {@code method} is not a suspending function
* @since 6.0
* This function preservers the null parameter passed in argument
*/
@SuppressWarnings({"DataFlowIssue", "NullAway"})
public static Publisher<?> invokeSuspendingFunctionPreserveNulls(
CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) {
return invokeSuspendingFunctionCore(context, method, target, args, true);
}

private static Publisher<?> invokeSuspendingFunctionCore(
CoroutineContext context,
Method method,
@Nullable Object target,
@Nullable Object[] args,
boolean preserveNulls)
{

Assert.isTrue(KotlinDetector.isSuspendingFunction(method), "Method must be a suspending function");
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
Expand All @@ -120,26 +149,7 @@ public static Publisher<?> invokeSuspendingFunction(
KCallablesJvm.setAccessible(function, true);
}
Mono<Object> mono = MonoKt.mono(context, (scope, continuation) -> {
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE, EXTENSION_RECEIVER -> {
Object arg = args[index];
if (!(parameter.isOptional() && arg == null)) {
KType type = parameter.getType();
if (!(type.isMarkedNullable() && arg == null) &&
type.getClassifier() instanceof KClass<?> kClass &&
KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) {
arg = box(kClass, arg);
}
argMap.put(parameter, arg);
}
index++;
}
}
}
Map<KParameter, Object> argMap = buildArgMap(function, target, args, preserveNulls);
return KCallables.callSuspendBy(function, argMap, continuation);
})
.filter(result -> result != Unit.INSTANCE)
Expand All @@ -158,6 +168,40 @@ public static Publisher<?> invokeSuspendingFunction(
return mono;
}

private static Map<KParameter, Object> buildArgMap(
KFunction<?> function,
@Nullable Object target,
@Nullable Object[] args,
boolean preserveNulls) {

Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
int index = 0;

for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE, EXTENSION_RECEIVER -> {
Object arg = args[index];

if (!(parameter.isOptional() && arg == null)) {
KType type = parameter.getType();
if (!(type.isMarkedNullable() && arg == null) &&
type.getClassifier() instanceof KClass<?> kClass &&
KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) {
arg = box(kClass, arg);
}
argMap.put(parameter, arg);
} else if(preserveNulls) {
argMap.put(parameter, arg);
}
index++;
}
}
}
return argMap;
}


private static Object box(KClass<?> kClass, @Nullable Object arg) {
KFunction<?> constructor = Objects.requireNonNull(KClasses.getPrimaryConstructor(kClass));
KType type = constructor.getParameters().get(0).getType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ class CoroutinesUtilsTests {
}
}

@Test
fun invokeSuspendingFunctionWithNullableParameterPreservesNull() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithOptionalParameterAndDefaultValue", String::class.java, Continuation::class.java)
val mono = CoroutinesUtils.invokeSuspendingFunctionPreserveNulls(Dispatchers.Unconfined, method, this, null, null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingleOrNull()).isNull()
}
}


@Test
fun invokePrivateSuspendingFunction() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("privateSuspendingFunction", String::class.java, Continuation::class.java)
Expand Down Expand Up @@ -300,6 +310,12 @@ class CoroutinesUtilsTests {
return value
}

suspend fun suspendingFunctionWithOptionalParameterAndDefaultValue(value: String? = "foo"): String? {
delay(1)
return value
}


suspend fun suspendingFunctionWithMono(): Mono<String> {
delay(1)
return Mono.just("foo")
Expand Down