提交 0e9ea2c9 编写于 作者: A Arjen Poutsma

Fix review remarks on Servlet.fn

This commit incoporates the remarks made during the Servlet.fn review.

See gh-21490
上级 2f682e10
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
package org.springframework.web.servlet.function; package org.springframework.web.servlet.function;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Type;
import java.net.URI; import java.net.URI;
import java.time.Instant; import java.time.Instant;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
...@@ -39,12 +40,14 @@ import org.reactivestreams.Publisher; ...@@ -39,12 +40,14 @@ import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber; import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.CacheControl; import org.springframework.http.CacheControl;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
...@@ -64,7 +67,7 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -64,7 +67,7 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
private final T entity; private final T entity;
private final BuilderFunction<T> builderFunction; private final Type entityType;
private int status = HttpStatus.OK.value(); private int status = HttpStatus.OK.value();
...@@ -73,9 +76,9 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -73,9 +76,9 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
private final MultiValueMap<String, Cookie> cookies = new LinkedMultiValueMap<>(); private final MultiValueMap<String, Cookie> cookies = new LinkedMultiValueMap<>();
private DefaultEntityResponseBuilder(T entity, BuilderFunction<T> builderFunction) { private DefaultEntityResponseBuilder(T entity, @Nullable Type entityType) {
this.entity = entity; this.entity = entity;
this.builderFunction = builderFunction; this.entityType = (entityType != null) ? entityType : entity.getClass();
} }
@Override @Override
...@@ -185,9 +188,23 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -185,9 +188,23 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
return this; return this;
} }
@SuppressWarnings({"rawtypes", "unchecked"})
@Override @Override
public EntityResponse<T> build() { public EntityResponse<T> build() {
return this.builderFunction.build(this.status, this.headers, this.cookies, this.entity); if (this.entity instanceof CompletionStage) {
CompletionStage completionStage = (CompletionStage) this.entity;
return new CompletionStageEntityResponse(this.status, this.headers, this.cookies,
completionStage, this.entityType);
}
else if (this.entity instanceof Publisher) {
Publisher publisher = (Publisher) this.entity;
return new PublisherEntityResponse(this.status, this.headers, this.cookies, publisher,
this.entityType);
}
else {
return new DefaultEntityResponse<>(this.status, this.headers, this.cookies, this.entity,
this.entityType);
}
} }
...@@ -195,42 +212,16 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -195,42 +212,16 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
* Return a new {@link EntityResponse.Builder} from the given object. * Return a new {@link EntityResponse.Builder} from the given object.
*/ */
public static <T> EntityResponse.Builder<T> fromObject(T t) { public static <T> EntityResponse.Builder<T> fromObject(T t) {
return new DefaultEntityResponseBuilder<>(t, DefaultEntityResponse::new); return new DefaultEntityResponseBuilder<>(t, null);
}
/**
* Return a new {@link EntityResponse.Builder} from the given completion stage.
*/
public static <T> EntityResponse.Builder<CompletionStage<T>> fromCompletionStage(
CompletionStage<T> completionStage) {
return new DefaultEntityResponseBuilder<>(completionStage,
CompletionStageEntityResponse::new);
} }
/** /**
* Return a new {@link EntityResponse.Builder} from the given Reactive Streams publisher. * Return a new {@link EntityResponse.Builder} from the given object and type reference.
*/ */
public static <T> EntityResponse.Builder<Publisher<T>> fromPublisher(Publisher<T> publisher) { public static <T> EntityResponse.Builder<T> fromObject(T t, ParameterizedTypeReference<?> bodyType) {
return new DefaultEntityResponseBuilder<>(publisher, PublisherEntityResponse::new); return new DefaultEntityResponseBuilder<>(t, bodyType.getType());
} }
@SuppressWarnings("unchecked")
private static <T> HttpMessageConverter<T> cast(HttpMessageConverter<?> messageConverter) {
return (HttpMessageConverter<T>) messageConverter;
}
/**
* Defines contract for building {@link EntityResponse} instances.
*/
private interface BuilderFunction<T> {
EntityResponse<T> build(int statusCode, HttpHeaders headers,
MultiValueMap<String, Cookie> cookies, T entity);
}
/** /**
* Default {@link EntityResponse} implementation for synchronous bodies. * Default {@link EntityResponse} implementation for synchronous bodies.
*/ */
...@@ -240,12 +231,15 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -240,12 +231,15 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
private final T entity; private final T entity;
private final Type entityType;
public DefaultEntityResponse(int statusCode, HttpHeaders headers, public DefaultEntityResponse(int statusCode, HttpHeaders headers,
MultiValueMap<String, Cookie> cookies, T entity) { MultiValueMap<String, Cookie> cookies, T entity, Type entityType) {
super(statusCode, headers, cookies); super(statusCode, headers, cookies);
this.entity = entity; this.entity = entity;
this.entityType = entityType;
} }
@Override @Override
...@@ -258,11 +252,12 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -258,11 +252,12 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
HttpServletResponse servletResponse, Context context) HttpServletResponse servletResponse, Context context)
throws ServletException, IOException { throws ServletException, IOException {
writeEntityWithMessageConverters(this.entity, servletRequest, servletResponse, context); writeEntityWithMessageConverters(this.entity, servletRequest,servletResponse, context);
return null; return null;
} }
@SuppressWarnings("unchecked")
protected void writeEntityWithMessageConverters(Object entity, protected void writeEntityWithMessageConverters(Object entity,
HttpServletRequest request, HttpServletResponse response, HttpServletRequest request, HttpServletResponse response,
ServerResponse.Context context) ServerResponse.Context context)
...@@ -271,30 +266,39 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -271,30 +266,39 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response);
MediaType contentType = getContentType(response); MediaType contentType = getContentType(response);
Class<?> entityType = entity.getClass(); Class<?> entityClass = entity.getClass();
HttpMessageConverter<Object> messageConverter = context.messageConverters().stream() for (HttpMessageConverter<?> messageConverter : context.messageConverters()) {
.filter(converter -> converter.canWrite(entityType, contentType)) if (messageConverter instanceof GenericHttpMessageConverter<?>) {
.findFirst() GenericHttpMessageConverter<Object> genericMessageConverter =
.map(DefaultEntityResponseBuilder::cast) (GenericHttpMessageConverter<Object>) messageConverter;
.orElseThrow(() -> new HttpMediaTypeNotAcceptableException( if (genericMessageConverter.canWrite(this.entityType, entityClass, contentType)) {
producibleMediaTypes(context.messageConverters(), entityType))); genericMessageConverter.write(entity, this.entityType, contentType, serverResponse);
return;
}
}
if (messageConverter.canWrite(entityClass, contentType)) {
((HttpMessageConverter<Object>)messageConverter).write(entity, contentType, serverResponse);
return;
}
}
messageConverter.write(entity, contentType, serverResponse); List<MediaType> producibleMediaTypes = producibleMediaTypes(context.messageConverters(), entityClass);
throw new HttpMediaTypeNotAcceptableException(producibleMediaTypes);
} }
@Nullable @Nullable
private MediaType getContentType(HttpServletResponse response) { private static MediaType getContentType(HttpServletResponse response) {
try { try {
return MediaType.parseMediaType(response.getContentType()); return MediaType.parseMediaType(response.getContentType()).removeQualityValue();
} }
catch (InvalidMediaTypeException ex) { catch (InvalidMediaTypeException ex) {
return null; return null;
} }
} }
protected final void tryWriteEntityWithMessageConverters(Object entity, protected void tryWriteEntityWithMessageConverters(Object entity,
HttpServletRequest request, HttpServletResponse response, HttpServletRequest request, HttpServletResponse response,
ServerResponse.Context context) { ServerResponse.Context context) {
try { try {
writeEntityWithMessageConverters(entity, request, response, context); writeEntityWithMessageConverters(entity, request, response, context);
...@@ -323,10 +327,10 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -323,10 +327,10 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
private static class CompletionStageEntityResponse<T> private static class CompletionStageEntityResponse<T>
extends DefaultEntityResponse<CompletionStage<T>> { extends DefaultEntityResponse<CompletionStage<T>> {
public CompletionStageEntityResponse(int statusCode, public CompletionStageEntityResponse(int statusCode, HttpHeaders headers,
HttpHeaders headers, MultiValueMap<String, Cookie> cookies, CompletionStage<T> entity,
MultiValueMap<String, Cookie> cookies, CompletionStage<T> entity) { Type entityType) {
super(statusCode, headers, cookies, entity); super(statusCode, headers, cookies, entity, entityType);
} }
@Override @Override
...@@ -338,6 +342,7 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -338,6 +342,7 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
entity().whenComplete((entity, throwable) -> { entity().whenComplete((entity, throwable) -> {
try { try {
if (entity != null) { if (entity != null) {
tryWriteEntityWithMessageConverters(entity, tryWriteEntityWithMessageConverters(entity,
(HttpServletRequest) asyncContext.getRequest(), (HttpServletRequest) asyncContext.getRequest(),
(HttpServletResponse) asyncContext.getResponse(), (HttpServletResponse) asyncContext.getResponse(),
...@@ -358,8 +363,9 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -358,8 +363,9 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
private static class PublisherEntityResponse<T> extends DefaultEntityResponse<Publisher<T>> { private static class PublisherEntityResponse<T> extends DefaultEntityResponse<Publisher<T>> {
public PublisherEntityResponse(int statusCode, HttpHeaders headers, public PublisherEntityResponse(int statusCode, HttpHeaders headers,
MultiValueMap<String, Cookie> cookies, Publisher<T> entity) { MultiValueMap<String, Cookie> cookies, Publisher<T> entity,
super(statusCode, headers, cookies, entity); Type entityType) {
super(statusCode, headers, cookies, entity, entityType);
} }
@Override @Override
...@@ -425,6 +431,8 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> { ...@@ -425,6 +431,8 @@ class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
(HttpServletRequest) this.asyncContext.getRequest(), (HttpServletRequest) this.asyncContext.getRequest(),
(HttpServletResponse) this.asyncContext.getResponse(), (HttpServletResponse) this.asyncContext.getResponse(),
this.context); this.context);
this.asyncContext.complete();
} }
@Override @Override
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
package org.springframework.web.servlet.function; package org.springframework.web.servlet.function;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type; import java.lang.reflect.Type;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
...@@ -48,7 +49,6 @@ import org.springframework.http.MediaType; ...@@ -48,7 +49,6 @@ import org.springframework.http.MediaType;
import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
...@@ -152,35 +152,42 @@ class DefaultServerRequest implements ServerRequest { ...@@ -152,35 +152,42 @@ class DefaultServerRequest implements ServerRequest {
@Override @Override
public <T> T body(ParameterizedTypeReference<T> bodyType) throws IOException, ServletException { public <T> T body(ParameterizedTypeReference<T> bodyType) throws IOException, ServletException {
Type type = bodyType.getType(); Type type = bodyType.getType();
Class<?> contextClass = null; return bodyInternal(type, bodyClass(type));
}
static Class<?> bodyClass(Type type) {
if (type instanceof Class) { if (type instanceof Class) {
contextClass = (Class<?>) type; return (Class<?>) type;
} }
return bodyInternal(type, contextClass); if (type instanceof ParameterizedType) {
ParameterizedType parameterizedType = (ParameterizedType) type;
if (parameterizedType.getRawType() instanceof Class) {
return (Class<?>) parameterizedType.getRawType();
}
}
return Object.class;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private <T> T bodyInternal(Type type, @Nullable Class<?> contextClass) private <T> T bodyInternal(Type bodyType, Class<?> bodyClass)
throws ServletException, IOException { throws ServletException, IOException {
MediaType contentType = MediaType contentType =
this.headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); this.headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
for (HttpMessageConverter<?> messageConverter : this.messageConverters) { for (HttpMessageConverter<?> messageConverter : this.messageConverters) {
if (messageConverter instanceof GenericHttpMessageConverter<?>) { if (messageConverter instanceof GenericHttpMessageConverter) {
GenericHttpMessageConverter<T> genericMessageConverter = GenericHttpMessageConverter<T> genericMessageConverter =
(GenericHttpMessageConverter<T>) messageConverter; (GenericHttpMessageConverter<T>) messageConverter;
if (genericMessageConverter.canRead(type, contextClass, contentType)) { if (genericMessageConverter.canRead(bodyType, bodyClass, contentType)) {
return genericMessageConverter.read(type, contextClass, this.serverHttpRequest); return genericMessageConverter.read(bodyType, bodyClass, this.serverHttpRequest);
} }
} }
else { if (messageConverter.canRead(bodyClass, contentType)) {
if (messageConverter.canRead(contextClass, contentType)) { HttpMessageConverter<T> theConverter =
HttpMessageConverter<T> theConverter = (HttpMessageConverter<T>) messageConverter;
(HttpMessageConverter<T>) messageConverter; Class<? extends T> clazz = (Class<? extends T>) bodyClass;
Class<? extends T> clazz = (Class<? extends T>) contextClass; return theConverter.read(clazz, this.serverHttpRequest);
return theConverter.read(clazz, this.serverHttpRequest);
}
} }
} }
throw new HttpMediaTypeNotSupportedException(contentType, this.allSupportedMediaTypes); throw new HttpMediaTypeNotSupportedException(contentType, this.allSupportedMediaTypes);
...@@ -196,6 +203,11 @@ class DefaultServerRequest implements ServerRequest { ...@@ -196,6 +203,11 @@ class DefaultServerRequest implements ServerRequest {
return this.attributes; return this.attributes;
} }
@Override
public Optional<String> param(String name) {
return Optional.ofNullable(servletRequest().getParameter(name));
}
@Override @Override
public MultiValueMap<String, String> params() { public MultiValueMap<String, String> params() {
return this.params; return this.params;
......
...@@ -46,7 +46,6 @@ import org.springframework.http.HttpMethod; ...@@ -46,7 +46,6 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
...@@ -239,35 +238,29 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { ...@@ -239,35 +238,29 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder {
@Override @Override
public <T> T body(ParameterizedTypeReference<T> bodyType) throws IOException, ServletException { public <T> T body(ParameterizedTypeReference<T> bodyType) throws IOException, ServletException {
Type type = bodyType.getType(); Type type = bodyType.getType();
Class<?> contextClass = null; return bodyInternal(type, DefaultServerRequest.bodyClass(type));
if (type instanceof Class) {
contextClass = (Class<?>) type;
}
return bodyInternal(type, contextClass);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private <T> T bodyInternal(Type type, @Nullable Class<?> contextClass) private <T> T bodyInternal(Type bodyType, Class<?> bodyClass)
throws ServletException, IOException { throws ServletException, IOException {
HttpInputMessage inputMessage = new BuiltInputMessage(); HttpInputMessage inputMessage = new BuiltInputMessage();
MediaType contentType = headers().contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); MediaType contentType = headers().contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
for (HttpMessageConverter<?> messageConverter : this.messageConverters) { for (HttpMessageConverter<?> messageConverter : this.messageConverters) {
if (messageConverter instanceof GenericHttpMessageConverter<?>) { if (messageConverter instanceof GenericHttpMessageConverter) {
GenericHttpMessageConverter<T> genericMessageConverter = GenericHttpMessageConverter<T> genericMessageConverter =
(GenericHttpMessageConverter<T>) messageConverter; (GenericHttpMessageConverter<T>) messageConverter;
if (genericMessageConverter.canRead(type, contextClass, contentType)) { if (genericMessageConverter.canRead(bodyType, bodyClass, contentType)) {
return genericMessageConverter.read(type, contextClass, inputMessage); return genericMessageConverter.read(bodyType, bodyClass, inputMessage);
} }
} }
else { if (messageConverter.canRead(bodyClass, contentType)) {
if (messageConverter.canRead(contextClass, contentType)) { HttpMessageConverter<T> theConverter =
HttpMessageConverter<T> theConverter = (HttpMessageConverter<T>) messageConverter;
(HttpMessageConverter<T>) messageConverter; Class<? extends T> clazz = (Class<? extends T>) bodyClass;
Class<? extends T> clazz = (Class<? extends T>) contextClass; return theConverter.read(clazz, inputMessage);
return theConverter.read(clazz, inputMessage);
}
} }
} }
throw new HttpMediaTypeNotSupportedException(contentType, Collections.emptyList()); throw new HttpMediaTypeNotSupportedException(contentType, Collections.emptyList());
......
...@@ -29,7 +29,6 @@ import java.util.LinkedHashSet; ...@@ -29,7 +29,6 @@ import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Predicate; import java.util.function.Predicate;
...@@ -38,8 +37,7 @@ import javax.servlet.http.Cookie; ...@@ -38,8 +37,7 @@ import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.reactivestreams.Publisher; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.CacheControl; import org.springframework.http.CacheControl;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
...@@ -196,16 +194,8 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { ...@@ -196,16 +194,8 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
} }
@Override @Override
public ServerResponse asyncBody(CompletionStage<?> asyncBody) { public <T> ServerResponse body(T body, ParameterizedTypeReference<T> bodyType) {
return DefaultEntityResponseBuilder.fromCompletionStage(asyncBody) return DefaultEntityResponseBuilder.fromObject(body, bodyType)
.headers(this.headers)
.status(this.statusCode)
.build();
}
@Override
public ServerResponse asyncBody(Publisher<?> futureBody) {
return DefaultEntityResponseBuilder.fromPublisher(futureBody)
.headers(this.headers) .headers(this.headers)
.status(this.statusCode) .status(this.statusCode)
.build(); .build();
......
...@@ -20,12 +20,10 @@ import java.net.URI; ...@@ -20,12 +20,10 @@ import java.net.URI;
import java.time.Instant; import java.time.Instant;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer; import java.util.function.Consumer;
import javax.servlet.http.Cookie; import javax.servlet.http.Cookie;
import org.reactivestreams.Publisher; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.CacheControl; import org.springframework.http.CacheControl;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
...@@ -61,24 +59,16 @@ public interface EntityResponse<T> extends ServerResponse { ...@@ -61,24 +59,16 @@ public interface EntityResponse<T> extends ServerResponse {
} }
/** /**
* Create a builder for an asynchronous body supplied by the given {@link CompletionStage}. * Create a builder with the given object and type reference.
* @param completionStage the supplier of the response body * @param t the object that represents the body of the response
* @param <T> the type of the elements contained in the publisher * @param entityType the type of the entity, used to capture the generic type
* @param <T> the type of element contained in the publisher
* @return the created builder * @return the created builder
*/ */
static <T> Builder<CompletionStage<T>> fromCompletionStage(CompletionStage<T> completionStage) { static <T> Builder<T> fromObject(T t, ParameterizedTypeReference<T> entityType) {
return DefaultEntityResponseBuilder.fromCompletionStage(completionStage); return DefaultEntityResponseBuilder.fromObject(t, entityType);
} }
/**
* Create a builder for an asynchronous body supplied by the given {@link Publisher}.
* @param publisher the supplier of the response body
* @param <T> the type of the elements contained in the publisher
* @return the created builder
*/
static <T> Builder<Publisher<T>> fromPublisher(Publisher<T> publisher) {
return DefaultEntityResponseBuilder.fromPublisher(publisher);
}
/** /**
* Defines a builder for {@code EntityResponse}. * Defines a builder for {@code EntityResponse}.
......
...@@ -34,6 +34,7 @@ import javax.servlet.http.HttpServletResponse; ...@@ -34,6 +34,7 @@ import javax.servlet.http.HttpServletResponse;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.CacheControl; import org.springframework.http.CacheControl;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
...@@ -362,26 +363,23 @@ public interface ServerResponse { ...@@ -362,26 +363,23 @@ public interface ServerResponse {
/** /**
* Set the body of the response to the given {@code Object} and return it. * Set the body of the response to the given {@code Object} and return it.
*
* <p>Asynchronous response bodies are supported by providing a {@link CompletionStage} or
* {@link Publisher} as body.
* @param body the body of the response * @param body the body of the response
* @return the built response * @return the built response
*/ */
ServerResponse body(Object body); ServerResponse body(Object body);
/** /**
* Set the asynchronous body of the response to the given {@link CompletionStage} and * Set the body of the response to the given {@code Object} and return it. The parameter
* return it. * {@code bodyType} is used to capture the generic type.
* @param asyncBody the body of the response *
* @return the built response * @param body the body of the response
*/ * @param bodyType the type of the body, used to capture the generic type
ServerResponse asyncBody(CompletionStage<?> asyncBody);
/**
* Set the asynchronous body of the response to the given {@link Publisher} and
* return it.
* @param asyncBody the body of the response
* @return the built response * @return the built response
*/ */
ServerResponse asyncBody(Publisher<?> asyncBody); <T> ServerResponse body(T body, ParameterizedTypeReference<T> bodyType);
/** /**
* Render the template with the given {@code name} using the given {@code modelAttributes}. * Render the template with the given {@code name} using the given {@code modelAttributes}.
......
...@@ -29,6 +29,7 @@ import javax.servlet.http.Cookie; ...@@ -29,6 +29,7 @@ import javax.servlet.http.Cookie;
import org.junit.Test; import org.junit.Test;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.CacheControl; import org.springframework.http.CacheControl;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
...@@ -63,6 +64,16 @@ public class DefaultEntityResponseBuilderTests { ...@@ -63,6 +64,16 @@ public class DefaultEntityResponseBuilderTests {
assertSame(body, response.entity()); assertSame(body, response.entity());
} }
@Test
public void fromObjectTypeReference() {
String body = "foo";
EntityResponse<String> response = EntityResponse.fromObject(body,
new ParameterizedTypeReference<String>() {})
.build();
assertSame(body, response.entity());
}
@Test @Test
public void status() { public void status() {
String body = "foo"; String body = "foo";
......
...@@ -36,6 +36,7 @@ import org.springframework.http.HttpRange; ...@@ -36,6 +36,7 @@ import org.springframework.http.HttpRange;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpSession; import org.springframework.mock.web.test.MockHttpSession;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
...@@ -240,14 +241,16 @@ public class DefaultServerRequestTests { ...@@ -240,14 +241,16 @@ public class DefaultServerRequestTests {
@Test @Test
public void bodyParameterizedTypeReference() throws Exception { public void bodyParameterizedTypeReference() throws Exception {
MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/"); MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/");
servletRequest.setContentType(MediaType.TEXT_PLAIN_VALUE); servletRequest.setContentType(MediaType.APPLICATION_JSON_VALUE);
servletRequest.setContent("foo".getBytes(UTF_8)); servletRequest.setContent("[\"foo\",\"bar\"]".getBytes(UTF_8));
DefaultServerRequest request = new DefaultServerRequest(servletRequest, DefaultServerRequest request = new DefaultServerRequest(servletRequest,
this.messageConverters); Collections.singletonList(new MappingJackson2HttpMessageConverter()));
String result = request.body(new ParameterizedTypeReference<String>() {}); List<String> result = request.body(new ParameterizedTypeReference<List<String>>() {});
assertEquals("foo", result); assertEquals(2, result.size());
assertEquals("foo", result.get(0));
assertEquals("bar", result.get(1));
} }
@Test(expected = HttpMediaTypeNotSupportedException.class) @Test(expected = HttpMediaTypeNotSupportedException.class)
......
...@@ -20,6 +20,7 @@ import java.net.URI; ...@@ -20,6 +20,7 @@ import java.net.URI;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.List; import java.util.List;
...@@ -31,6 +32,7 @@ import org.junit.Test; ...@@ -31,6 +32,7 @@ import org.junit.Test;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.CacheControl; import org.springframework.http.CacheControl;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
...@@ -38,6 +40,7 @@ import org.springframework.http.HttpStatus; ...@@ -38,6 +40,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
...@@ -288,10 +291,27 @@ public class DefaultServerResponseBuilderTests { ...@@ -288,10 +291,27 @@ public class DefaultServerResponseBuilderTests {
} }
@Test @Test
public void asyncBodyCompletionStage() throws Exception { public void bodyWithParameterizedTypeReference() throws Exception {
List<String> body = new ArrayList<>();
body.add("foo");
body.add("bar");
ServerResponse response = ServerResponse.ok().body(body, new ParameterizedTypeReference<List<String>>() {});
MockHttpServletRequest mockRequest = new MockHttpServletRequest("GET", "http://example.com");
MockHttpServletResponse mockResponse = new MockHttpServletResponse();
ServerResponse.Context context = () -> Collections.singletonList(new MappingJackson2HttpMessageConverter());
ModelAndView mav = response.writeTo(mockRequest, mockResponse, context);
assertNull(mav);
assertEquals("[\"foo\",\"bar\"]", mockResponse.getContentAsString());
}
@Test
public void bodyCompletionStage() throws Exception {
String body = "foo"; String body = "foo";
CompletionStage<String> completionStage = CompletableFuture.completedFuture(body); CompletionStage<String> completionStage = CompletableFuture.completedFuture(body);
ServerResponse response = ServerResponse.ok().asyncBody(completionStage); ServerResponse response = ServerResponse.ok().body(completionStage);
MockHttpServletRequest mockRequest = new MockHttpServletRequest("GET", "http://example.com"); MockHttpServletRequest mockRequest = new MockHttpServletRequest("GET", "http://example.com");
MockHttpServletResponse mockResponse = new MockHttpServletResponse(); MockHttpServletResponse mockResponse = new MockHttpServletResponse();
...@@ -307,10 +327,10 @@ public class DefaultServerResponseBuilderTests { ...@@ -307,10 +327,10 @@ public class DefaultServerResponseBuilderTests {
} }
@Test @Test
public void asyncBodyPublisher() throws Exception { public void bodyPublisher() throws Exception {
String body = "foo"; String body = "foo";
Publisher<String> publisher = Mono.just(body); Publisher<String> publisher = Mono.just(body);
ServerResponse response = ServerResponse.ok().asyncBody(publisher); ServerResponse response = ServerResponse.ok().body(publisher);
MockHttpServletRequest mockRequest = new MockHttpServletRequest("GET", "http://example.com"); MockHttpServletRequest mockRequest = new MockHttpServletRequest("GET", "http://example.com");
MockHttpServletResponse mockResponse = new MockHttpServletResponse(); MockHttpServletResponse mockResponse = new MockHttpServletResponse();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册