Skip to content

Fix the ClassCastException when CacheRequestBodyFilter is used with CircuitBreakerFilter. #3547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public RouteLocator routes(RouteLocatorBuilder builder) {
.route("cache_request_body_route", r -> r.path("/downstream/**")
.filters(f -> f.prefixPath("/httpbin")
.cacheRequestBody(String.class).uri(uri))
.build();
.build());
}
----

Expand All @@ -36,7 +36,7 @@ spring:
bodyClass: java.lang.String
----
`CacheRequestBody` extracts the request body and converts it to a body class (such as `java.lang.String`, defined in the preceding example).
`CacheRequestBody` then places it in the attributes available from `ServerWebExchange.getAttributes()`, with a key defined in `ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR`.
`CacheRequestBody` then places it in the attributes available from `ServerWebExchange.getAttributes()`, with a key defined in `ServerWebExchangeUtils.CACHE_REQUEST_BODY_OBJECT_ATTR`.

NOTE: This filter works only with HTTP (including HTTPS) requests.

Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,9 @@
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;

import static org.springframework.cloud.gateway.support.GatewayToStringStyler.filterToStringCreator;
Expand Down Expand Up @@ -70,36 +66,11 @@ public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
return chain.filter(exchange);
}

Object cachedBody = exchange.getAttribute(ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR);
if (cachedBody != null) {
return chain.filter(exchange);
}

return ServerWebExchangeUtils.cacheRequestBodyAndRequest(exchange, (serverHttpRequest) -> {
final ServerRequest serverRequest = ServerRequest
.create(exchange.mutate().request(serverHttpRequest).build(), messageReaders);
return serverRequest.bodyToMono((config.getBodyClass())).doOnNext(objectValue -> {
Object previousCachedBody = exchange.getAttributes()
.put(ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR, objectValue);
if (previousCachedBody != null) {
// store previous cached body
exchange.getAttributes().put(CACHED_ORIGINAL_REQUEST_BODY_BACKUP_ATTR, previousCachedBody);
}
}).then(Mono.defer(() -> {
ServerHttpRequest cachedRequest = exchange
.getAttribute(CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);
Assert.notNull(cachedRequest, "cache request shouldn't be null");
return ServerWebExchangeUtils.cacheRequestBodyObject(exchange, config.getBodyClass(), messageReaders,
(serverHttpRequest, cachedBody) -> {
exchange.getAttributes().remove(CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);
return chain.filter(exchange.mutate().request(cachedRequest).build()).doFinally(s -> {
//
Object backupCachedBody = exchange.getAttributes()
.get(CACHED_ORIGINAL_REQUEST_BODY_BACKUP_ATTR);
if (backupCachedBody instanceof DataBuffer dataBuffer) {
DataBufferUtils.release(dataBuffer);
}
});
}));
});
return chain.filter(exchange.mutate().request(serverHttpRequest).build());
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;

/**
Expand All @@ -43,8 +42,6 @@ public class ReadBodyRoutePredicateFactory extends AbstractRoutePredicateFactory

private static final String TEST_ATTRIBUTE = "read_body_predicate_test_attribute";

private static final String CACHE_REQUEST_BODY_OBJECT_KEY = "cachedRequestBodyObject";

private final List<HttpMessageReader<?>> messageReaders;

public ReadBodyRoutePredicateFactory() {
Expand All @@ -63,10 +60,7 @@ public AsyncPredicate<ServerWebExchange> applyAsync(Config config) {
return new AsyncPredicate<ServerWebExchange>() {
@Override
public Publisher<Boolean> apply(ServerWebExchange exchange) {
Class inClass = config.getInClass();

Object cachedBody = exchange.getAttribute(CACHE_REQUEST_BODY_OBJECT_KEY);
Mono<?> modifiedBody;
Object cachedBody = exchange.getAttribute(ServerWebExchangeUtils.CACHE_REQUEST_BODY_OBJECT_ATTR);
// We can only read the body from the request once, once that happens if
// we try to read the body again an exception will be thrown. The below
// if/else caches the body object as a request attribute in the
Expand All @@ -87,15 +81,14 @@ public Publisher<Boolean> apply(ServerWebExchange exchange) {
}
return Mono.just(false);
}
else {
return ServerWebExchangeUtils.cacheRequestBodyAndRequest(exchange,
(serverHttpRequest) -> ServerRequest
.create(exchange.mutate().request(serverHttpRequest).build(), messageReaders)
.bodyToMono(inClass)
.doOnNext(objectValue -> exchange.getAttributes()
.put(CACHE_REQUEST_BODY_OBJECT_KEY, objectValue))
.map(objectValue -> config.getPredicate().test(objectValue)));
}

return ServerWebExchangeUtils.cacheRequestBodyObject(exchange, config.getInClass(), messageReaders,
(serverHttpRequest, bodyObject) -> {
if (bodyObject == null) {
return Mono.just(false);
}
return Mono.just(config.predicate.test(bodyObject));
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;

Expand All @@ -43,12 +45,14 @@
import org.springframework.core.io.buffer.NettyDataBuffer;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;

Expand Down Expand Up @@ -174,6 +178,12 @@ public final class ServerWebExchangeUtils {
*/
public static final String CACHED_REQUEST_BODY_ATTR = "cachedRequestBody";

/**
* Cached request decoded body object key. Used when
* {@link #cacheRequestBodyObject(ServerWebExchange, Class, List, BiFunction)}
*/
public static final String CACHE_REQUEST_BODY_OBJECT_ATTR = "cachedRequestBodyObject";

/**
* Gateway LoadBalancer {@link Response} attribute name.
*/
Expand Down Expand Up @@ -316,6 +326,29 @@ public static Map<String, String> getUriTemplateVariables(ServerWebExchange exch
return exchange.getAttributeOrDefault(URI_TEMPLATE_VARIABLES_ATTRIBUTE, new HashMap<>());
}

/**
* Caches the request body, the decoded body object and the created {@link ServerHttpRequestDecorator} in
* ServerWebExchange attributes. Those attributes are
* {@link #CACHED_REQUEST_BODY_ATTR} and
* {@link #CACHE_REQUEST_BODY_OBJECT_ATTR} and
* {@link #CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR} respectively.
* @param exchange the available ServerWebExchange.
* @param bodyClass the class of the body to be decoded
* @param messageReaders the list of message readers for decoding the body.
* @param function a function to apply on the decoded body and request.
* @param <C> the class type of the decoded body.
* @param <T> generic type for the return {@link Mono}.
* @return Mono of type T created by the function parameter.
*/
public static <C, T> Mono<T> cacheRequestBodyObject(ServerWebExchange exchange, Class<C> bodyClass,
List<HttpMessageReader<?>> messageReaders, BiFunction<ServerHttpRequest, C, Mono<T>> function) {
return cacheRequestBodyAndRequest(exchange, (serverHttpRequest) -> ServerRequest
.create(exchange.mutate().request(serverHttpRequest).build(), messageReaders).bodyToMono(bodyClass)
.doOnNext(objectValue -> exchange.getAttributes().put(CACHE_REQUEST_BODY_OBJECT_ATTR, objectValue))
.flatMap(cachedBody -> function.apply(serverHttpRequest, cachedBody))
.switchIfEmpty(function.apply(serverHttpRequest, null)));
}

/**
* Caches the request body and the created {@link ServerHttpRequestDecorator} in
* ServerWebExchange attributes. Those attributes are
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.cloud.gateway.filter.factory;

import java.util.Collections;
import java.util.Map;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -118,6 +119,19 @@ public void cacheRequestBodyExists() {
.isOk();
}

@Test
public void cacheRequestBodyWithCircuitBreaker() {
testClient.post().uri("/post").header("Host", "www.cacherequestbodywithcircuitbreaker.org")
.bodyValue(BODY_VALUE).exchange().expectStatus().isOk().expectBody(Map.class)
.consumeWith(result -> {
Map<?, ?> response = result.getResponseBody();
assertThat(response).isNotNull();

String responseBody = (String) response.get("data");
assertThat(responseBody).isEqualTo(BODY_VALUE);
});
}

@Test
public void toStringFormat() {
CacheRequestBodyGatewayFilterFactory.Config config = new CacheRequestBodyGatewayFilterFactory.Config();
Expand Down Expand Up @@ -163,6 +177,18 @@ public RouteLocator testRouteLocator(RouteLocatorBuilder builder) {
.cacheRequestBody(String.class)
.filter(new AssertCachedRequestBodyGatewayFilter(BODY_CACHED_EXISTS)))
.uri(uri))
.route("cache_request_body_with_circuitbreaker_test",
r -> r.path("/post")
.and()
.host("**.cacherequestbodywithcircuitbreaker.org")
.filters(f -> f.setHostHeader("www.cacherequestbody.org")
.prefixPath("/httpbin")
.cacheRequestBody(String.class)
.filter(new AssertCachedRequestBodyGatewayFilter(BODY_VALUE))
.filter(new CheckCachedRequestBodyReleasedGatewayFilter())
.circuitBreaker(config -> config.setStatusCodes(Collections.singleton("200"))
.setFallbackUri("/post")))
.uri(uri))
.build();
}

Expand All @@ -181,7 +207,7 @@ private static class AssertCachedRequestBodyGatewayFilter implements GatewayFilt

@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
String body = exchange.getAttribute(ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR);
String body = exchange.getAttribute(ServerWebExchangeUtils.CACHE_REQUEST_BODY_OBJECT_ATTR);
if (exceptNullBody) {
assertThat(body).isNull();
}
Expand All @@ -203,7 +229,7 @@ private static class SetExchangeCachedRequestBodyGatewayFilter implements Gatewa

@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
exchange.getAttributes().put(ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR, bodyToSetCache);
exchange.getAttributes().put(ServerWebExchangeUtils.CACHE_REQUEST_BODY_OBJECT_ATTR, bodyToSetCache);
return chain.filter(exchange);
}

Expand Down
Loading