Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.zalando.logbook.servlet;

import jakarta.annotation.Nullable;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
Expand All @@ -14,11 +15,8 @@
import org.zalando.logbook.Logbook.ResponseWritingStage;
import org.zalando.logbook.Strategy;

import jakarta.annotation.Nullable;
import java.io.IOException;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;

import static jakarta.servlet.DispatcherType.ASYNC;
import static lombok.AccessLevel.PRIVATE;
Expand All @@ -29,13 +27,13 @@
public final class LogbookFilter implements HttpFilter {

/**
* Unique per instance so we don't accidentally share stages between filter
* Unique per instance, so we don't accidentally share stages between filter
* instances in the same chain.
*/
private final String responseProcessingStageName = ResponseProcessingStage.class.getName() + "-" + UUID.randomUUID();
private final String responseWritingStageSynchronizationName = ResponseWritingStage.class.getName() + "-Synchronization-"+ UUID.randomUUID();

private final Logbook logbook;
@Nullable
private final Strategy strategy;

@With
Expand All @@ -56,7 +54,6 @@ public LogbookFilter(final Logbook logbook, @Nullable final Strategy strategy) {
@Override
public void doFilter(final HttpServletRequest httpRequest, final HttpServletResponse httpResponse,
final FilterChain chain) throws ServletException, IOException {

final RemoteRequest request = new RemoteRequest(httpRequest, formRequestMode);
final LocalResponse response = new LocalResponse(httpResponse, request.getProtocolVersion());

Expand All @@ -70,28 +67,28 @@ public void doFilter(final HttpServletRequest httpRequest, final HttpServletResp
}

final ResponseWritingStage writing = processing.process(response);
request.setAsyncListener(Optional.of(new LogbookAsyncListener(event -> write(request, response, writing))));
request.setAttribute(responseWritingStageSynchronizationName, new AtomicBoolean(false));

chain.doFilter(request, response);

if (request.isAsyncStarted()) {
request.getAsyncContext().addListener(new LogbookAsyncListener(event -> write(response, writing)));

return;
}

write(request, response, writing);
// The async writing is handled by the attached on-complete listener
if (request.getDispatcherType() != ASYNC) {
write(response, writing);
}
}

private void write(RemoteRequest request, LocalResponse response, ResponseWritingStage writing) throws IOException {
final AtomicBoolean attribute = (AtomicBoolean) request.getAttribute(responseWritingStageSynchronizationName);
if (attribute != null && !attribute.getAndSet(true)) {
try {
response.flushBuffer();
} catch (IOException e) {
// ignore and try to log the response anyway
}
writing.write();
private void write(LocalResponse response, ResponseWritingStage writing) throws IOException {
try {
response.flushBuffer();
} catch (IOException e) {
// ignore and try to log the response anyway
}
writing.write();
}

private RequestWritingStage process(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package org.zalando.logbook.servlet;

import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.SneakyThrows;
import org.zalando.logbook.ContentType;
import org.zalando.logbook.HttpHeaders;
import org.zalando.logbook.HttpRequest;
import org.zalando.logbook.Origin;
import org.zalando.logbook.common.MediaTypeQuery;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
Expand All @@ -20,14 +26,6 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.SneakyThrows;
import org.zalando.logbook.ContentType;
import org.zalando.logbook.HttpHeaders;
import org.zalando.logbook.HttpRequest;
import org.zalando.logbook.Origin;
import org.zalando.logbook.common.MediaTypeQuery;

import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.util.Collections.list;
Expand All @@ -40,7 +38,6 @@
final class RemoteRequest extends HttpServletRequestWrapper implements HttpRequest {

private final AtomicReference<State> state;
private Optional<AsyncListener> asyncListener = Optional.empty();

/**
* Manages the lifecycle of HTTP request body buffering for servlet requests.
Expand Down Expand Up @@ -429,24 +426,6 @@ public byte[] getBody() {
return buffer().getBody();
}

@Override
public AsyncContext startAsync() throws IllegalStateException {
final AsyncContext asyncContext = super.startAsync();
asyncListener.ifPresent(asyncContext::addListener);
return asyncContext;
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
final AsyncContext asyncContext = super.startAsync(servletRequest, servletResponse);
asyncListener.ifPresent(asyncContext::addListener);
return asyncContext;
}

public void setAsyncListener(Optional<AsyncListener> asyncListener) {
this.asyncListener = asyncListener;
}

private State buffer() {
return state.updateAndGet(throwingUnaryOperator(state ->
state.buffer(getRequest())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import org.mockito.ArgumentCaptor;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.bean.override.mockito.MockitoBean;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.test.context.bean.override.mockito.MockitoBean;
import org.springframework.web.client.RestTemplate;
import org.zalando.logbook.Correlation;
import org.zalando.logbook.HttpLogWriter;
Expand Down Expand Up @@ -46,15 +46,15 @@ final class AsyncDispatchTest {
@MockitoBean
private HttpLogWriter writer;

static class AsyncHttpServlet extends HttpServlet{
static class AsyncHttpServlet extends HttpServlet {

@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
asyncResponse(req.startAsync(req, resp));
}

@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
asyncResponse(req.startAsync(req, resp));
}

Expand Down Expand Up @@ -136,6 +136,32 @@ void shouldFormatAsyncResponse() throws Exception {
.contains("200 OK", "text/plain", "Hello, world!");
}

@Test
void shouldFormatStreamingResponse() throws Exception {
final RestTemplate template = new RestTemplate();
template.getForObject("http://localhost:8080/api/streaming", String.class);

waitFor(Duration.ofSeconds(1));

final String response = interceptResponse();

assertThat(response)
.contains("200 OK", "text/plain", "chunked", "Hello, world!");
}

@Test
void shouldFormatResponseForDispatchWithMultipleStartAsync() throws Exception {
final RestTemplate template = new RestTemplate();
template.getForObject("http://localhost:8080/api/multi-async/step-1", String.class);

waitFor(Duration.ofSeconds(1));

final String response = interceptResponse();

assertThat(response)
.contains("200 OK", "text/plain", "Hello, world!");
}

@Test
void shouldFormatAsyncServletResponseGet() throws Exception {
final RestTemplate template = new RestTemplate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;

import java.io.IOException;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.Callable;

import static org.springframework.http.MediaType.TEXT_PLAIN;
import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE;

@RestController
Expand All @@ -43,6 +47,28 @@ public Callable<String> returnMessage() {
return () -> "Hello, world!";
}

@RequestMapping(path = "/streaming", produces = TEXT_PLAIN_VALUE)
public ResponseEntity<StreamingResponseBody> streaming() {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.setContentType(new MediaType(TEXT_PLAIN, StandardCharsets.UTF_8));

return new ResponseEntity<>(
outputStream -> outputStream.write("Hello, world!".getBytes(StandardCharsets.UTF_8)),
httpHeaders,
HttpStatus.OK
);
}

@RequestMapping(path = "/multi-async/step-1", produces = TEXT_PLAIN_VALUE)
public Callable<ModelAndView> multiAsyncStep1() {
return () -> new ModelAndView("forward:/api/multi-async/step-2");
}

@RequestMapping(path = "/multi-async/step-2", produces = TEXT_PLAIN_VALUE)
public Callable<String> multiAsyncStep2() {
return () -> "Hello, world!";
}

@RequestMapping("/empty")
public void empty() {
// intentionally left blank
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.zalando.logbook.Logbook;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.Collections.emptyEnumeration;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -47,7 +44,7 @@ void shouldCallDestroy() {
void shouldHandleIOExceptionOnFlushBufferAndWriteResponse() throws Exception {
Logbook logbook = mock(Logbook.class);
Logbook.RequestWritingStage requestWritingStage = mock(Logbook.RequestWritingStage.class);
Logbook.ResponseWritingStage responseWritingStage = mock(Logbook.ResponseWritingStage .class);
Logbook.ResponseWritingStage responseWritingStage = mock(Logbook.ResponseWritingStage.class);
LogbookFilter filter = new LogbookFilter(logbook);
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
Expand All @@ -66,32 +63,4 @@ void shouldHandleIOExceptionOnFlushBufferAndWriteResponse() throws Exception {

verify(responseWritingStage).write();
}

@Test
void shouldNotThrowNPEIfRequestDoesntContainWritingStageSynchronizationBoolean() throws Exception {
Logbook logbook = mock(Logbook.class);
Logbook.RequestWritingStage requestWritingStage = mock(Logbook.RequestWritingStage.class);
Logbook.ResponseWritingStage responseWritingStage = mock(Logbook.ResponseWritingStage .class);
LogbookFilter filter = new LogbookFilter(logbook);
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain chain = mock(FilterChain.class);
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);


when(logbook.process(any())).thenReturn(requestWritingStage);
when(requestWritingStage.write()).thenReturn(requestWritingStage);
when(requestWritingStage.process(any())).thenReturn(responseWritingStage);
when(request.getHeaderNames()).thenReturn(emptyEnumeration());
when(request.getDispatcherType()).thenReturn(DispatcherType.REQUEST);
when(request.getAttribute(captor.capture())).thenReturn(null);

doThrow(new IOException("Simulated IOException")).when(response).flushBuffer();

filter.doFilter(request, response, chain);

verify(responseWritingStage, never()).write();
assertThat(captor.getValue()).contains("-Synchronization-");
}

}
Original file line number Diff line number Diff line change
@@ -1,52 +1,28 @@
package org.zalando.logbook.servlet;

import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.UnsupportedEncodingException;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.UnsupportedEncodingException;

import static java.util.Collections.emptyEnumeration;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class RemoteRequestTest {

private final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
private final HttpServletResponse httpServletResponse = mock(HttpServletResponse.class);
private final AsyncContext asyncContext = mock(AsyncContext.class);
private final AsyncListener asyncListener = mock(AsyncListener.class);

private RemoteRequest remoteRequest;

@BeforeEach
void setUp() {
when(httpServletRequest.getHeaderNames()).thenReturn(emptyEnumeration());
remoteRequest = new RemoteRequest(httpServletRequest, FormRequestMode.OFF);
remoteRequest.setAsyncListener(Optional.of(asyncListener));
}

@Test
void startAsync_noargs() {
when(httpServletRequest.startAsync()).thenReturn(asyncContext);

assertEquals(asyncContext, remoteRequest.startAsync());
verify(asyncContext).addListener(asyncListener);
}

@Test
void startAsync_twoargs() {
when(httpServletRequest.startAsync(httpServletRequest, httpServletResponse)).thenReturn(asyncContext);

assertEquals(asyncContext, remoteRequest.startAsync(httpServletRequest, httpServletResponse));
verify(asyncContext).addListener(asyncListener);
}

@Test
Expand Down Expand Up @@ -180,7 +156,7 @@ void offered_without_thenGetInputStream() throws Exception {
remoteRequest.withBody();
assertEquals(4, remoteRequest.getBody().length);

jakarta.servlet.ServletInputStream stream = remoteRequest.getInputStream();
remoteRequest.getInputStream();
assertEquals(4, remoteRequest.getBody().length);
}

Expand Down