diff --git a/logbook-servlet/src/main/java/org/zalando/logbook/servlet/LogbookFilter.java b/logbook-servlet/src/main/java/org/zalando/logbook/servlet/LogbookFilter.java index 63ecd508f..bb02800bd 100644 --- a/logbook-servlet/src/main/java/org/zalando/logbook/servlet/LogbookFilter.java +++ b/logbook-servlet/src/main/java/org/zalando/logbook/servlet/LogbookFilter.java @@ -1,5 +1,6 @@ package org.zalando.logbook.servlet; +import jakarta.annotation.Nullable; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; @@ -14,11 +15,8 @@ import org.zalando.logbook.Logbook.ResponseWritingStage; import org.zalando.logbook.Strategy; -import jakarta.annotation.Nullable; import java.io.IOException; -import java.util.Optional; import java.util.UUID; -import java.util.concurrent.atomic.AtomicBoolean; import static jakarta.servlet.DispatcherType.ASYNC; import static lombok.AccessLevel.PRIVATE; @@ -29,13 +27,13 @@ public final class LogbookFilter implements HttpFilter { /** - * Unique per instance so we don't accidentally share stages between filter + * Unique per instance, so we don't accidentally share stages between filter * instances in the same chain. */ private final String responseProcessingStageName = ResponseProcessingStage.class.getName() + "-" + UUID.randomUUID(); - private final String responseWritingStageSynchronizationName = ResponseWritingStage.class.getName() + "-Synchronization-"+ UUID.randomUUID(); private final Logbook logbook; + @Nullable private final Strategy strategy; @With @@ -56,7 +54,6 @@ public LogbookFilter(final Logbook logbook, @Nullable final Strategy strategy) { @Override public void doFilter(final HttpServletRequest httpRequest, final HttpServletResponse httpResponse, final FilterChain chain) throws ServletException, IOException { - final RemoteRequest request = new RemoteRequest(httpRequest, formRequestMode); final LocalResponse response = new LocalResponse(httpResponse, request.getProtocolVersion()); @@ -70,28 +67,28 @@ public void doFilter(final HttpServletRequest httpRequest, final HttpServletResp } final ResponseWritingStage writing = processing.process(response); - request.setAsyncListener(Optional.of(new LogbookAsyncListener(event -> write(request, response, writing)))); - request.setAttribute(responseWritingStageSynchronizationName, new AtomicBoolean(false)); chain.doFilter(request, response); if (request.isAsyncStarted()) { + request.getAsyncContext().addListener(new LogbookAsyncListener(event -> write(response, writing))); + return; } - write(request, response, writing); + // The async writing is handled by the attached on-complete listener + if (request.getDispatcherType() != ASYNC) { + write(response, writing); + } } - private void write(RemoteRequest request, LocalResponse response, ResponseWritingStage writing) throws IOException { - final AtomicBoolean attribute = (AtomicBoolean) request.getAttribute(responseWritingStageSynchronizationName); - if (attribute != null && !attribute.getAndSet(true)) { - try { - response.flushBuffer(); - } catch (IOException e) { - // ignore and try to log the response anyway - } - writing.write(); + private void write(LocalResponse response, ResponseWritingStage writing) throws IOException { + try { + response.flushBuffer(); + } catch (IOException e) { + // ignore and try to log the response anyway } + writing.write(); } private RequestWritingStage process( diff --git a/logbook-servlet/src/main/java/org/zalando/logbook/servlet/RemoteRequest.java b/logbook-servlet/src/main/java/org/zalando/logbook/servlet/RemoteRequest.java index a7d2fe7db..98d0f9195 100644 --- a/logbook-servlet/src/main/java/org/zalando/logbook/servlet/RemoteRequest.java +++ b/logbook-servlet/src/main/java/org/zalando/logbook/servlet/RemoteRequest.java @@ -1,12 +1,18 @@ package org.zalando.logbook.servlet; -import jakarta.servlet.AsyncContext; -import jakarta.servlet.AsyncListener; import jakarta.servlet.ServletInputStream; import jakarta.servlet.ServletRequest; -import jakarta.servlet.ServletResponse; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.SneakyThrows; +import org.zalando.logbook.ContentType; +import org.zalando.logbook.HttpHeaders; +import org.zalando.logbook.HttpRequest; +import org.zalando.logbook.Origin; +import org.zalando.logbook.common.MediaTypeQuery; + import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -20,14 +26,6 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.SneakyThrows; -import org.zalando.logbook.ContentType; -import org.zalando.logbook.HttpHeaders; -import org.zalando.logbook.HttpRequest; -import org.zalando.logbook.Origin; -import org.zalando.logbook.common.MediaTypeQuery; import static java.nio.charset.StandardCharsets.ISO_8859_1; import static java.util.Collections.list; @@ -40,7 +38,6 @@ final class RemoteRequest extends HttpServletRequestWrapper implements HttpRequest { private final AtomicReference state; - private Optional asyncListener = Optional.empty(); /** * Manages the lifecycle of HTTP request body buffering for servlet requests. @@ -429,24 +426,6 @@ public byte[] getBody() { return buffer().getBody(); } - @Override - public AsyncContext startAsync() throws IllegalStateException { - final AsyncContext asyncContext = super.startAsync(); - asyncListener.ifPresent(asyncContext::addListener); - return asyncContext; - } - - @Override - public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException { - final AsyncContext asyncContext = super.startAsync(servletRequest, servletResponse); - asyncListener.ifPresent(asyncContext::addListener); - return asyncContext; - } - - public void setAsyncListener(Optional asyncListener) { - this.asyncListener = asyncListener; - } - private State buffer() { return state.updateAndGet(throwingUnaryOperator(state -> state.buffer(getRequest()))); diff --git a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/AsyncDispatchTest.java b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/AsyncDispatchTest.java index 11536e255..c96644d25 100644 --- a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/AsyncDispatchTest.java +++ b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/AsyncDispatchTest.java @@ -11,12 +11,12 @@ import org.mockito.ArgumentCaptor; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.test.context.bean.override.mockito.MockitoBean; import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.boot.web.servlet.ServletRegistrationBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; +import org.springframework.test.context.bean.override.mockito.MockitoBean; import org.springframework.web.client.RestTemplate; import org.zalando.logbook.Correlation; import org.zalando.logbook.HttpLogWriter; @@ -46,7 +46,7 @@ final class AsyncDispatchTest { @MockitoBean private HttpLogWriter writer; - static class AsyncHttpServlet extends HttpServlet{ + static class AsyncHttpServlet extends HttpServlet { @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) { @@ -54,7 +54,7 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) { } @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) { + protected void doPost(HttpServletRequest req, HttpServletResponse resp) { asyncResponse(req.startAsync(req, resp)); } @@ -136,6 +136,32 @@ void shouldFormatAsyncResponse() throws Exception { .contains("200 OK", "text/plain", "Hello, world!"); } + @Test + void shouldFormatStreamingResponse() throws Exception { + final RestTemplate template = new RestTemplate(); + template.getForObject("http://localhost:8080/api/streaming", String.class); + + waitFor(Duration.ofSeconds(1)); + + final String response = interceptResponse(); + + assertThat(response) + .contains("200 OK", "text/plain", "chunked", "Hello, world!"); + } + + @Test + void shouldFormatResponseForDispatchWithMultipleStartAsync() throws Exception { + final RestTemplate template = new RestTemplate(); + template.getForObject("http://localhost:8080/api/multi-async/step-1", String.class); + + waitFor(Duration.ofSeconds(1)); + + final String response = interceptResponse(); + + assertThat(response) + .contains("200 OK", "text/plain", "Hello, world!"); + } + @Test void shouldFormatAsyncServletResponseGet() throws Exception { final RestTemplate template = new RestTemplate(); diff --git a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/ExampleController.java b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/ExampleController.java index 0276dd00e..aeb88dced 100644 --- a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/ExampleController.java +++ b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/ExampleController.java @@ -11,12 +11,16 @@ import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; import java.io.IOException; import java.nio.CharBuffer; +import java.nio.charset.StandardCharsets; import java.util.Objects; import java.util.concurrent.Callable; +import static org.springframework.http.MediaType.TEXT_PLAIN; import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE; @RestController @@ -43,6 +47,28 @@ public Callable returnMessage() { return () -> "Hello, world!"; } + @RequestMapping(path = "/streaming", produces = TEXT_PLAIN_VALUE) + public ResponseEntity streaming() { + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.setContentType(new MediaType(TEXT_PLAIN, StandardCharsets.UTF_8)); + + return new ResponseEntity<>( + outputStream -> outputStream.write("Hello, world!".getBytes(StandardCharsets.UTF_8)), + httpHeaders, + HttpStatus.OK + ); + } + + @RequestMapping(path = "/multi-async/step-1", produces = TEXT_PLAIN_VALUE) + public Callable multiAsyncStep1() { + return () -> new ModelAndView("forward:/api/multi-async/step-2"); + } + + @RequestMapping(path = "/multi-async/step-2", produces = TEXT_PLAIN_VALUE) + public Callable multiAsyncStep2() { + return () -> "Hello, world!"; + } + @RequestMapping("/empty") public void empty() { // intentionally left blank diff --git a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/LogbookFilterTest.java b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/LogbookFilterTest.java index 8acc6cfbd..3bf45f761 100644 --- a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/LogbookFilterTest.java +++ b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/LogbookFilterTest.java @@ -6,18 +6,15 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; import org.zalando.logbook.Logbook; import java.io.IOException; import java.util.concurrent.atomic.AtomicBoolean; import static java.util.Collections.emptyEnumeration; -import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -47,7 +44,7 @@ void shouldCallDestroy() { void shouldHandleIOExceptionOnFlushBufferAndWriteResponse() throws Exception { Logbook logbook = mock(Logbook.class); Logbook.RequestWritingStage requestWritingStage = mock(Logbook.RequestWritingStage.class); - Logbook.ResponseWritingStage responseWritingStage = mock(Logbook.ResponseWritingStage .class); + Logbook.ResponseWritingStage responseWritingStage = mock(Logbook.ResponseWritingStage.class); LogbookFilter filter = new LogbookFilter(logbook); HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); @@ -66,32 +63,4 @@ void shouldHandleIOExceptionOnFlushBufferAndWriteResponse() throws Exception { verify(responseWritingStage).write(); } - - @Test - void shouldNotThrowNPEIfRequestDoesntContainWritingStageSynchronizationBoolean() throws Exception { - Logbook logbook = mock(Logbook.class); - Logbook.RequestWritingStage requestWritingStage = mock(Logbook.RequestWritingStage.class); - Logbook.ResponseWritingStage responseWritingStage = mock(Logbook.ResponseWritingStage .class); - LogbookFilter filter = new LogbookFilter(logbook); - HttpServletRequest request = mock(HttpServletRequest.class); - HttpServletResponse response = mock(HttpServletResponse.class); - FilterChain chain = mock(FilterChain.class); - ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); - - - when(logbook.process(any())).thenReturn(requestWritingStage); - when(requestWritingStage.write()).thenReturn(requestWritingStage); - when(requestWritingStage.process(any())).thenReturn(responseWritingStage); - when(request.getHeaderNames()).thenReturn(emptyEnumeration()); - when(request.getDispatcherType()).thenReturn(DispatcherType.REQUEST); - when(request.getAttribute(captor.capture())).thenReturn(null); - - doThrow(new IOException("Simulated IOException")).when(response).flushBuffer(); - - filter.doFilter(request, response, chain); - - verify(responseWritingStage, never()).write(); - assertThat(captor.getValue()).contains("-Synchronization-"); - } - } diff --git a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/RemoteRequestTest.java b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/RemoteRequestTest.java index ab2974b96..bb0f11ca9 100644 --- a/logbook-servlet/src/test/java/org/zalando/logbook/servlet/RemoteRequestTest.java +++ b/logbook-servlet/src/test/java/org/zalando/logbook/servlet/RemoteRequestTest.java @@ -1,28 +1,21 @@ package org.zalando.logbook.servlet; -import jakarta.servlet.AsyncContext; -import jakarta.servlet.AsyncListener; import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; -import java.io.UnsupportedEncodingException; -import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.io.UnsupportedEncodingException; + import static java.util.Collections.emptyEnumeration; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; class RemoteRequestTest { private final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); - private final HttpServletResponse httpServletResponse = mock(HttpServletResponse.class); - private final AsyncContext asyncContext = mock(AsyncContext.class); - private final AsyncListener asyncListener = mock(AsyncListener.class); private RemoteRequest remoteRequest; @@ -30,23 +23,6 @@ class RemoteRequestTest { void setUp() { when(httpServletRequest.getHeaderNames()).thenReturn(emptyEnumeration()); remoteRequest = new RemoteRequest(httpServletRequest, FormRequestMode.OFF); - remoteRequest.setAsyncListener(Optional.of(asyncListener)); - } - - @Test - void startAsync_noargs() { - when(httpServletRequest.startAsync()).thenReturn(asyncContext); - - assertEquals(asyncContext, remoteRequest.startAsync()); - verify(asyncContext).addListener(asyncListener); - } - - @Test - void startAsync_twoargs() { - when(httpServletRequest.startAsync(httpServletRequest, httpServletResponse)).thenReturn(asyncContext); - - assertEquals(asyncContext, remoteRequest.startAsync(httpServletRequest, httpServletResponse)); - verify(asyncContext).addListener(asyncListener); } @Test @@ -180,7 +156,7 @@ void offered_without_thenGetInputStream() throws Exception { remoteRequest.withBody(); assertEquals(4, remoteRequest.getBody().length); - jakarta.servlet.ServletInputStream stream = remoteRequest.getInputStream(); + remoteRequest.getInputStream(); assertEquals(4, remoteRequest.getBody().length); }