diff --git a/gradle.properties b/gradle.properties index 1db1b5cf..fd01f739 100644 --- a/gradle.properties +++ b/gradle.properties @@ -7,7 +7,7 @@ PROJECT_LICENSE=MIT PROJECT_LICENSE_URL=https://github.com/graphql-java-kickstart/spring-java-servlet/blob/master/LICENSE.md PROJECT_DEV_ID=oliemansm PROJECT_DEV_NAME=Michiel Oliemans -LIB_GRAPHQL_JAVA_VER=22.3 +LIB_GRAPHQL_JAVA_VER=25.0 LIB_JACKSON_VER=2.17.2 LIB_SLF4J_VER=2.0.16 LIB_LOMBOK_VER=1.18.34 diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/DecoratedExecutionResult.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/DecoratedExecutionResult.java index 1880c5a8..53e94491 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/DecoratedExecutionResult.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/DecoratedExecutionResult.java @@ -2,13 +2,14 @@ import graphql.ExecutionResult; import graphql.GraphQLError; +import graphql.incremental.IncrementalExecutionResult; import java.util.List; import java.util.Map; import lombok.RequiredArgsConstructor; import org.reactivestreams.Publisher; @RequiredArgsConstructor -class DecoratedExecutionResult implements ExecutionResult { +public class DecoratedExecutionResult implements ExecutionResult { private final ExecutionResult result; @@ -16,6 +17,14 @@ boolean isAsynchronous() { return result.getData() instanceof Publisher; } + boolean isIncremental() { + return result instanceof IncrementalExecutionResult; + } + + public IncrementalExecutionResult asIncrementalExecutionResult() { + return (IncrementalExecutionResult) result; + } + @Override public List getErrors() { return result.getErrors(); diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLBatchedQueryResult.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLBatchedQueryResult.java index 415530a5..30b8d321 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLBatchedQueryResult.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLBatchedQueryResult.java @@ -19,4 +19,9 @@ public boolean isBatched() { public boolean isAsynchronous() { return false; } + + @Override + public boolean isIncremental() { + return false; + } } diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLErrorQueryResult.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLErrorQueryResult.java index bf1fc9d8..71d25a73 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLErrorQueryResult.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLErrorQueryResult.java @@ -20,6 +20,11 @@ public boolean isAsynchronous() { return false; } + @Override + public boolean isIncremental() { + return false; + } + @Override public boolean isError() { return true; diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLObjectMapper.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLObjectMapper.java index 9bac536c..c465e417 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLObjectMapper.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLObjectMapper.java @@ -9,6 +9,8 @@ import graphql.ExecutionResult; import graphql.ExecutionResultImpl; import graphql.GraphQLError; +import graphql.incremental.DelayedIncrementalPartialResult; +import graphql.incremental.IncrementalPayload; import graphql.kickstart.execution.config.ConfiguringObjectMapperProvider; import graphql.kickstart.execution.config.GraphQLServletObjectMapperConfigurer; import graphql.kickstart.execution.config.ObjectMapperProvider; @@ -18,6 +20,7 @@ import java.io.InputStream; import java.io.Writer; import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -118,22 +121,35 @@ public byte[] serializeResultAsBytes(ExecutionResult executionResult) { return getJacksonMapper().writeValueAsBytes(createResultFromExecutionResult(executionResult)); } + @SneakyThrows + public byte[] serializeDelayedIncrementalResultsAsBytes(DelayedIncrementalPartialResult delayedIncrementalPartialResult) { + return getJacksonMapper().writeValueAsBytes(createResultFromDelayedIncrementalPayloadResult(delayedIncrementalPartialResult)); + } + public boolean areErrorsPresent(ExecutionResult executionResult) { return graphQLErrorHandlerSupplier.get().errorsPresent(executionResult.getErrors()); } - public ExecutionResult sanitizeErrors(ExecutionResult executionResult) { - Object data = executionResult.getData(); + public boolean areExtensionsPresent(ExecutionResult executionResult) { Map extensions = executionResult.getExtensions(); - List errors = executionResult.getErrors(); + return extensions != null && !extensions.isEmpty(); + } + public ExecutionResult sanitizeErrors(ExecutionResult executionResult) { GraphQLErrorHandler errorHandler = graphQLErrorHandlerSupplier.get(); - if (errorHandler.errorsPresent(errors)) { - errors = errorHandler.processErrors(errors); - } else { - errors = null; - } - return new ExecutionResultImpl(data, errors, extensions); + return executionResult.transform(er -> { + List errors = executionResult.getErrors(); + if (errorHandler.errorsPresent(errors)) { + errors = errorHandler.processErrors(errors); + } else { + errors = List.of(); + } + er.errors(errors); + }); + } + + public DelayedIncrementalPartialResult sanitizeErrors(DelayedIncrementalPartialResult delayedIncrementalPartialResult) { + return delayedIncrementalPartialResult; } public Map createResultFromExecutionResult(ExecutionResult executionResult) { @@ -141,28 +157,35 @@ public Map createResultFromExecutionResult(ExecutionResult execu return convertSanitizedExecutionResult(sanitizedExecutionResult); } + public Map createResultFromDelayedIncrementalPayloadResult(DelayedIncrementalPartialResult delayedIncrementalPartialResult) { + DelayedIncrementalPartialResult sanitizedDelayedIncrementalPartialResult = sanitizeErrors(delayedIncrementalPartialResult); + return convertSanitizedDelayedIncrementalPartialResult(sanitizedDelayedIncrementalPartialResult); + } + public Map convertSanitizedExecutionResult(ExecutionResult executionResult) { return convertSanitizedExecutionResult(executionResult, true); } + public Map convertSanitizedDelayedIncrementalPartialResult( + DelayedIncrementalPartialResult delayedIncrementalPartialResult) { + return delayedIncrementalPartialResult.toSpecification(); + } + public Map convertSanitizedExecutionResult( ExecutionResult executionResult, boolean includeData) { - final Map result = new LinkedHashMap<>(); - - if (areErrorsPresent(executionResult)) { - result.put( - "errors", - executionResult.getErrors().stream() - .map(GraphQLError::toSpecification) - .collect(toList())); + final Map result = new HashMap<>(executionResult.toSpecification()); + + if (!areErrorsPresent(executionResult)) { + result.remove("errors"); } - if (executionResult.getExtensions() != null && !executionResult.getExtensions().isEmpty()) { - result.put("extensions", executionResult.getExtensions()); + if (!includeData) { + result.remove("data"); } + result.putIfAbsent("data", null); - if (includeData) { - result.put("data", executionResult.getData()); + if (!areExtensionsPresent(executionResult)) { + result.remove("extensions"); } return result; diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLQueryResult.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLQueryResult.java index d2502199..52cf17d7 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLQueryResult.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLQueryResult.java @@ -23,6 +23,8 @@ static GraphQLErrorQueryResult createError(int statusCode, String message) { boolean isAsynchronous(); + boolean isIncremental(); + default DecoratedExecutionResult getResult() { return null; } diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLSingleQueryResult.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLSingleQueryResult.java index 0593116b..19cbe913 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLSingleQueryResult.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/GraphQLSingleQueryResult.java @@ -17,4 +17,9 @@ public boolean isBatched() { public boolean isAsynchronous() { return result.isAsynchronous(); } + + @Override + public boolean isIncremental() { + return result.isIncremental(); + } } diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/context/DefaultGraphQLContext.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/context/DefaultGraphQLContext.java index dda335d0..52f2b005 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/context/DefaultGraphQLContext.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/context/DefaultGraphQLContext.java @@ -36,6 +36,10 @@ public void put(Object key, Object value) { map.put(key, value); } + public void putAll(Map values) { + map.putAll(values); + } + @Override public DataLoaderRegistry getDataLoaderRegistry() { return dataLoaderRegistry; diff --git a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/DelayedIncrementalPartialResultSubscriber.java b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/DelayedIncrementalPartialResultSubscriber.java new file mode 100644 index 00000000..d3b5e1ae --- /dev/null +++ b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/DelayedIncrementalPartialResultSubscriber.java @@ -0,0 +1,73 @@ +package graphql.kickstart.servlet; + +import graphql.incremental.DelayedIncrementalPartialResult; +import graphql.kickstart.execution.GraphQLObjectMapper; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.ServletResponse; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +class DelayedIncrementalPartialResultSubscriber implements Subscriber { + + private final AtomicReference subscriptionRef; + private final AsyncContext asyncContext; + private final GraphQLObjectMapper graphQLObjectMapper; + private final CountDownLatch completedLatch = new CountDownLatch(1); + + DelayedIncrementalPartialResultSubscriber( + AtomicReference subscriptionRef, + AsyncContext asyncContext, + GraphQLObjectMapper graphQLObjectMapper) { + this.subscriptionRef = subscriptionRef; + this.asyncContext = asyncContext; + this.graphQLObjectMapper = graphQLObjectMapper; + } + + @Override + public void onSubscribe(Subscription subscription) { + subscriptionRef.set(subscription); + subscriptionRef.get().request(1); + } + + @Override + public void onNext(DelayedIncrementalPartialResult delayedIncrementalPartialResult) { + try { + ServletResponse response = asyncContext.getResponse(); + ServletOutputStream outputStream = response.getOutputStream(); + outputStream.write(HttpRequestHandler.MULTIPART_BOUNDARY.getBytes(StandardCharsets.UTF_8)); + outputStream.write(HttpRequestHandler.MULTIPART_CONTENT_TYPE.getBytes( + StandardCharsets.UTF_8)); + byte[] contentBytes = graphQLObjectMapper.serializeDelayedIncrementalResultsAsBytes(delayedIncrementalPartialResult); + outputStream.write(contentBytes); + outputStream.write("\r\n".getBytes(StandardCharsets.UTF_8)); + if (!delayedIncrementalPartialResult.hasNext()) { + outputStream.write(HttpRequestHandler.MULTIPART_BOUNDARY.getBytes(StandardCharsets.UTF_8)); + } + outputStream.flush(); + subscriptionRef.get().request(1); + } catch (IOException ignored) { + // ignore + } + } + + @Override + public void onError(Throwable t) { + asyncContext.complete(); + completedLatch.countDown(); + } + + @Override + public void onComplete() { + asyncContext.complete(); + completedLatch.countDown(); + } + + void await() throws InterruptedException { + completedLatch.await(); + } +} diff --git a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/HttpRequestHandler.java b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/HttpRequestHandler.java index 4be54ee2..07e1c189 100644 --- a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/HttpRequestHandler.java +++ b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/HttpRequestHandler.java @@ -8,6 +8,9 @@ public interface HttpRequestHandler { String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8"; String APPLICATION_EVENT_STREAM_UTF8 = "text/event-stream;charset=UTF-8"; + String MULTIPART_MIXED = "multipart/mixed; boundary=\"-\""; + String MULTIPART_BOUNDARY = "---\r\n"; + String MULTIPART_CONTENT_TYPE = "Content-Type: application/json; charset=UTF-8\r\n\r\n"; int STATUS_OK = 200; int STATUS_BAD_REQUEST = 400; diff --git a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/QueryResponseWriterFactoryImpl.java b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/QueryResponseWriterFactoryImpl.java index 0dd5d82b..39e5342f 100644 --- a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/QueryResponseWriterFactoryImpl.java +++ b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/QueryResponseWriterFactoryImpl.java @@ -23,6 +23,14 @@ public QueryResponseWriter createWriter( configuration.getObjectMapper(), configuration.getSubscriptionTimeout()); } + + if (queryResult.isIncremental()) { + return new SingleIncrementalQueryResponseWriter( + queryResult.getResult().asIncrementalExecutionResult(), + configuration.getObjectMapper(), + configuration.getSubscriptionTimeout()); + } + if (queryResult.isError()) { return new ErrorQueryResponseWriter(queryResult.getStatusCode(), queryResult.getMessage()); } diff --git a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/SingleIncrementalQueryResponseWriter.java b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/SingleIncrementalQueryResponseWriter.java new file mode 100644 index 00000000..32e57a7e --- /dev/null +++ b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/SingleIncrementalQueryResponseWriter.java @@ -0,0 +1,76 @@ +package graphql.kickstart.servlet; + +import graphql.ExecutionResult; +import graphql.incremental.DelayedIncrementalPartialResult; +import graphql.incremental.IncrementalExecutionResult; +import graphql.incremental.IncrementalPayload; +import graphql.kickstart.execution.GraphQLObjectMapper; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; + +@RequiredArgsConstructor +class SingleIncrementalQueryResponseWriter implements QueryResponseWriter { + + @Getter private final IncrementalExecutionResult result; + private final GraphQLObjectMapper graphQLObjectMapper; + private final long subscriptionTimeout; + + @Override + public void write(HttpServletRequest request, HttpServletResponse response) throws IOException { + Objects.requireNonNull(request, "Http servlet request cannot be null"); + response.setContentType(HttpRequestHandler.MULTIPART_MIXED); + response.setStatus(HttpRequestHandler.STATUS_OK); + response.setCharacterEncoding(StandardCharsets.UTF_8.name()); + + // Write the initial data + ServletOutputStream outputStream = response.getOutputStream(); + outputStream.write("\r\n".getBytes(StandardCharsets.UTF_8)); + outputStream.write(HttpRequestHandler.MULTIPART_BOUNDARY.getBytes(StandardCharsets.UTF_8)); + outputStream.write(HttpRequestHandler.MULTIPART_CONTENT_TYPE.getBytes( + StandardCharsets.UTF_8)); + byte[] contentBytes = graphQLObjectMapper.serializeResultAsBytes(result); + outputStream.write(contentBytes); + outputStream.write("\r\n".getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + + // If no more data is expected, we can just complete the response here + boolean isInAsyncThread = request.isAsyncStarted(); + AsyncContext asyncContext = + isInAsyncThread ? request.getAsyncContext() : request.startAsync(request, response); + if (!result.hasNext()) { + asyncContext.complete(); + return; + } + + // Now handle any delayed incremental payloads + asyncContext.setTimeout(subscriptionTimeout); + AtomicReference subscriptionRef = new AtomicReference<>(); + asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef)); + DelayedIncrementalPartialResultSubscriber subscriber = + new DelayedIncrementalPartialResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper); + var publisher = result.getIncrementalItemPublisher(); + publisher.subscribe(subscriber); + + if (isInAsyncThread) { + // We need to delay the completion of async context until after the subscription has + // terminated, otherwise the AsyncContext is prematurely closed. + try { + subscriber.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } +} diff --git a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/AbstractGraphQLHttpServletSpec.groovy b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/AbstractGraphQLHttpServletSpec.groovy index dde7cb46..b04083ea 100644 --- a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/AbstractGraphQLHttpServletSpec.groovy +++ b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -13,6 +13,7 @@ import org.springframework.mock.web.MockHttpServletResponse import spock.lang.Shared import spock.lang.Specification +import java.nio.charset.Charset import java.nio.charset.StandardCharsets import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit @@ -28,6 +29,7 @@ class AbstractGraphQLHttpServletSpec extends Specification { public static final int STATUS_ERROR = 500 public static final String CONTENT_TYPE_JSON_UTF8 = 'application/json;charset=UTF-8' public static final String CONTENT_TYPE_SERVER_SENT_EVENTS = 'text/event-stream;charset=UTF-8' + public static final String CONTENT_TYPE_MULTIPART_MIXED = "multipart/mixed; boundary=\"-\";charset=UTF-8" @Shared ObjectMapper mapper = new ObjectMapper() @@ -1226,4 +1228,36 @@ b servlet.getConfiguration().getObjectMapper().getJacksonMapper().writeValueAsString(stepInfo) != "{}" } -} + def "incremental query over HTTP POST body returns data"() { + setup: + request.setContent(mapper.writeValueAsBytes([ + query: 'query { echo(arg:"test") ... @defer(label: "deferredEcho") { deferredEcho: echo(arg:"test") } }' + ])) + request.setMethod("POST") + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_MULTIPART_MIXED + + + def actual = response.getContentAsString(Charset.defaultCharset()) + def expected = ''' +--- +Content-Type: application/json; charset=UTF-8 + +{"data":{"echo":"test"},"hasNext":true} +--- +Content-Type: application/json; charset=UTF-8 + +{"hasNext":false,"incremental":[{"path":[],"label":"deferredEcho","data":{"deferredEcho":"test"}}]} +--- +''' + + // Normalize CRLF -> LF on both sides + def normalize = { s -> s.replace("\r\n", "\n") } + normalize(actual) == normalize(expected) + } +} \ No newline at end of file diff --git a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/DataLoaderDispatchingSpec.groovy b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/DataLoaderDispatchingSpec.groovy index b8c7791b..8eefa339 100644 --- a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/DataLoaderDispatchingSpec.groovy +++ b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/DataLoaderDispatchingSpec.groovy @@ -80,30 +80,11 @@ class DataLoaderDispatchingSpec extends Specification { } } - def contextBuilder() { - return new GraphQLServletContextBuilder() { - @Override - GraphQLKickstartContext build(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { - new DefaultGraphQLContext(registry()) - } - - @Override - GraphQLKickstartContext build(Session session, HandshakeRequest handshakeRequest) { - new DefaultGraphQLContext(registry()) - } - - @Override - GraphQLKickstartContext build() { - new DefaultGraphQLContext(registry()) - } - } - } - def configureServlet(ContextSetting contextSetting) { servlet = TestUtils.createDataLoadingServlet(queryDataFetcher("A", loadCounterA), queryDataFetcher("B", loadCounterB), queryDataFetcher("C", loadCounterC) , contextSetting, - contextBuilder()) + TestUtils.contextBuilder(this::registry)) } def resetCounters() { diff --git a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/SingleIncrementalQueryResponseWriterTest.groovy b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/SingleIncrementalQueryResponseWriterTest.groovy new file mode 100644 index 00000000..f194f5a1 --- /dev/null +++ b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/SingleIncrementalQueryResponseWriterTest.groovy @@ -0,0 +1,37 @@ +package graphql.kickstart.servlet + +import graphql.incremental.IncrementalExecutionResult +import graphql.kickstart.execution.GraphQLObjectMapper +import jakarta.servlet.ServletOutputStream +import org.springframework.mock.web.MockAsyncContext +import spock.lang.Specification + +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletResponse + +class SingleIncrementalQueryResponseWriterTest extends Specification { + + def "result hasNext should complete"() { + given: + def result = Mock(IncrementalExecutionResult) + result.hasNext() >> false + def objectMapper = Mock(GraphQLObjectMapper) + def writer = new SingleIncrementalQueryResponseWriter(result, objectMapper, 100) + def request = Mock(HttpServletRequest) + def responseOutputStream = Mock(ServletOutputStream) + def response = Mock(HttpServletResponse) + response.getOutputStream() >> responseOutputStream + def asyncContext = new MockAsyncContext(request, response) + request.getAsyncContext() >> asyncContext + request.isAsyncStarted() >> true + + objectMapper.serializeResultAsJson(result) >> "{ }" + + when: + writer.write(request, response) + + then: + noExceptionThrown() + } + +} diff --git a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/TestUtils.groovy b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/TestUtils.groovy index 684b2cc2..519e0774 100644 --- a/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/TestUtils.groovy +++ b/graphql-java-servlet/src/test/groovy/graphql/kickstart/servlet/TestUtils.groovy @@ -1,9 +1,12 @@ package graphql.kickstart.servlet import com.google.common.io.ByteStreams +import graphql.ExperimentalApi import graphql.Scalars import graphql.execution.reactive.SingleSubscriberPublisher import graphql.kickstart.execution.context.ContextSetting +import graphql.kickstart.execution.context.DefaultGraphQLContext +import graphql.kickstart.execution.context.GraphQLKickstartContext import graphql.kickstart.servlet.apollo.ApolloScalars import graphql.kickstart.servlet.context.GraphQLServletContextBuilder import graphql.kickstart.servlet.core.GraphQLServletListener @@ -14,10 +17,16 @@ import graphql.schema.idl.SchemaGenerator import graphql.schema.idl.SchemaParser import graphql.schema.idl.TypeRuntimeWiring import graphql.schema.idl.errors.SchemaProblem +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletResponse +import jakarta.websocket.Session +import jakarta.websocket.server.HandshakeRequest import lombok.NonNull +import org.dataloader.DataLoaderRegistry import java.util.concurrent.Executor import java.util.concurrent.atomic.AtomicReference +import java.util.function.Supplier class TestUtils { @@ -99,9 +108,36 @@ class TestUtils { configBuilder.with(Arrays.asList(listeners)) } configBuilder.with(executor()); + configBuilder.with(contextBuilder()) configBuilder.build() } + static def contextBuilder(Supplier dataLoaderRegistrySupplier = DataLoaderRegistry::new, Map contextMap = [:]) { + return new GraphQLServletContextBuilder() { + @Override + GraphQLKickstartContext build(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + context() + } + + @Override + GraphQLKickstartContext build(Session session, HandshakeRequest handshakeRequest) { + context() + } + + @Override + GraphQLKickstartContext build() { + context() + } + + private context() { + var context = new DefaultGraphQLContext(dataLoaderRegistrySupplier.get()) + context.put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, true) + context.putAll(contextMap) + return context + } + } + } + private static Executor executor() { new Executor() { @Override