diff --git a/spring-core/src/main/java/org/springframework/core/MethodParameter.java b/spring-core/src/main/java/org/springframework/core/MethodParameter.java index 040819bea6673592691a9e9d1c66823afa3174c6..e3ae0562bac8780ea4648afb515395c732fe7add 100644 --- a/spring-core/src/main/java/org/springframework/core/MethodParameter.java +++ b/spring-core/src/main/java/org/springframework/core/MethodParameter.java @@ -26,11 +26,9 @@ import java.lang.reflect.Parameter; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Predicate; -import java.util.stream.Collectors; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; @@ -398,7 +396,7 @@ public class MethodParameter { * either in the form of Java 8's {@link java.util.Optional}, any variant * of a parameter-level {@code Nullable} annotation (such as from JSR-305 * or the FindBugs set of annotations), or a language-level nullable type - * declaration in Kotlin. + * declaration or {@code Continuation} parameter in Kotlin. * @since 4.3 */ public boolean isOptional() { @@ -867,37 +865,39 @@ public class MethodParameter { private static class KotlinDelegate { /** - * Check whether the specified {@link MethodParameter} represents a nullable Kotlin type - * or an optional parameter (with a default value in the Kotlin declaration). + * Check whether the specified {@link MethodParameter} represents a nullable Kotlin type, + * an optional parameter (with a default value in the Kotlin declaration) or a {@code Continuation} parameter + * used in suspending functions. */ public static boolean isOptional(MethodParameter param) { Method method = param.getMethod(); - Constructor ctor = param.getConstructor(); int index = param.getParameterIndex(); if (method != null && index == -1) { KFunction function = ReflectJvmMapping.getKotlinFunction(method); return (function != null && function.getReturnType().isMarkedNullable()); } - else { - KFunction function = null; - Predicate predicate = null; - if (method != null) { - function = ReflectJvmMapping.getKotlinFunction(method); - predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()); - } - else if (ctor != null) { - function = ReflectJvmMapping.getKotlinFunction(ctor); - predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()) || - KParameter.Kind.INSTANCE.equals(p.getKind()); + KFunction function; + Predicate predicate; + if (method != null) { + if (param.parameterType.getName().equals("kotlin.coroutines.Continuation")) { + return true; } - if (function != null) { - List parameters = function.getParameters(); - KParameter parameter = parameters - .stream() - .filter(predicate) - .collect(Collectors.toList()) - .get(index); - return (parameter.getType().isMarkedNullable() || parameter.isOptional()); + function = ReflectJvmMapping.getKotlinFunction(method); + predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()); + } + else { + function = ReflectJvmMapping.getKotlinFunction(param.getConstructor()); + predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()) || + KParameter.Kind.INSTANCE.equals(p.getKind()); + } + if (function != null) { + int i = 0; + for (KParameter kParameter : function.getParameters()) { + if (predicate.test(kParameter)) { + if (index == i++) { + return (kParameter.getType().isMarkedNullable() || kParameter.isOptional()); + } + } } } return false; diff --git a/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt b/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt index b88fbaa64b8722835286d339233045b2ed802149..419698d2ff6c3af0be715c8f615c43158aa7574a 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt @@ -20,6 +20,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import java.lang.reflect.Method import java.lang.reflect.TypeVariable +import kotlin.coroutines.Continuation import kotlin.reflect.full.declaredFunctions import kotlin.reflect.jvm.javaMethod @@ -101,6 +102,13 @@ class KotlinMethodParameterTests { assertThat(returnGenericParameterType("suspendFun8")).isEqualTo(Object::class.java) } + @Test + fun `Continuation parameter is optional`() { + val method = this::class.java.getDeclaredMethod("suspendFun", String::class.java, Continuation::class.java) + assertThat(MethodParameter(method, 0).isOptional).isFalse() + assertThat(MethodParameter(method, 1).isOptional).isTrue() + } + private fun returnParameterType(funName: String) = returnMethodParameter(funName).parameterType private fun returnGenericParameterType(funName: String) = returnMethodParameter(funName).genericParameterType private fun returnGenericParameterTypeName(funName: String) = returnGenericParameterType(funName).typeName