From 20392a0c8d1584607ecb52997f92faf258078ec5 Mon Sep 17 00:00:00 2001 From: David Brassely Date: Mon, 8 Sep 2025 11:07:28 +0200 Subject: [PATCH 1/2] feat: Provide a new http-client layer to create new HTTP client implementation / extensions (vertx) --- README.md | 14 +- client/base/src/main/java/io/a2a/A2A.java | 30 +- .../java/io/a2a/client/ClientBuilderTest.java | 36 +- .../client/transport/grpc/GrpcTransport.java | 20 +- .../transport/jsonrpc/JSONRPCTransport.java | 160 ++++----- .../jsonrpc/JSONRPCTransportConfig.java | 19 +- .../JSONRPCTransportConfigBuilder.java | 18 +- .../jsonrpc/JSONRPCTransportProvider.java | 19 +- .../jsonrpc/sse/SSEEventListener.java | 25 +- .../jsonrpc/sse/SSEEventListenerTest.java | 20 +- .../transport/rest/RestErrorMapper.java | 6 +- .../client/transport/rest/RestTransport.java | 179 +++++----- .../transport/rest/RestTransportConfig.java | 19 +- .../rest/RestTransportConfigBuilder.java | 21 +- .../transport/rest/RestTransportProvider.java | 21 +- .../rest/sse/RestSSEEventListener.java | 39 ++- .../transport/rest/RestTransportTest.java | 45 +-- .../spi/AbstractClientTransport.java | 31 ++ .../interceptors/ClientCallInterceptor.java | 2 +- .../spi/interceptors/PayloadAndHeaders.java | 2 +- .../interceptors/auth/AuthInterceptor.java | 2 +- extras/README.md | 3 +- extras/http-client-vertx/README.md | 76 +++++ extras/http-client-vertx/pom.xml | 57 ++++ .../client/http/vertx/VertxHttpClient.java | 216 ++++++++++++ .../http/vertx/VertxHttpClientBuilder.java | 30 ++ .../a2a/client/http/vertx/sse/SSEHandler.java | 124 +++++++ .../client/http/vertx/ClientBuilderTest.java | 62 ++++ .../http/vertx/VertxHttpClientTest.java | 13 + .../JpaPushNotificationConfigStoreTest.java | 46 +-- http-client/pom.xml | 4 +- .../io/a2a/client/http/A2ACardResolver.java | 71 ++-- .../io/a2a/client/http/A2AHttpClient.java | 42 --- .../io/a2a/client/http/A2AHttpResponse.java | 9 - .../java/io/a2a/client/http/HttpClient.java | 45 +++ .../io/a2a/client/http/HttpClientBuilder.java | 10 + .../java/io/a2a/client/http/HttpResponse.java | 17 + .../io/a2a/client/http/JdkA2AHttpClient.java | 311 ------------------ .../io/a2a/client/http/jdk/JdkHttpClient.java | 260 +++++++++++++++ .../client/http/jdk/JdkHttpClientBuilder.java | 12 + .../a2a/client/http/jdk/sse/SSEHandler.java | 120 +++++++ .../io/a2a/client/http/sse/CommentEvent.java | 54 +++ .../io/a2a/client/http/sse/DataEvent.java | 81 +++++ .../java/io/a2a/client/http/sse/Event.java | 11 + .../a2a/client/http/A2ACardResolverTest.java | 180 +++++----- .../client/http/jdk/JdkHttpClientTest.java | 31 ++ pom.xml | 21 ++ .../io/a2a/server/http/HttpClientManager.java | 59 ++++ .../tasks/BasePushNotificationSender.java | 33 +- .../server/http/HttpClientManagerTest.java | 52 +++ .../AbstractA2ARequestHandlerTest.java | 100 +++--- ...MemoryPushNotificationConfigStoreTest.java | 60 ++-- .../tasks/PushNotificationSenderTest.java | 125 ++++--- tests/client-common/pom.xml | 60 ++++ .../http/common/AbstractHttpClientTest.java | 187 +++++++++++ .../a2a/client/http/common/JsonMessages.java | 85 +++++ .../http/common/JsonStreamingMessages.java | 15 + .../server/apps/common/TestHttpClient.java | 82 ++--- .../jsonrpc/handler/JSONRPCHandlerTest.java | 3 +- 59 files changed, 2468 insertions(+), 1027 deletions(-) create mode 100644 client/transport/spi/src/main/java/io/a2a/client/transport/spi/AbstractClientTransport.java create mode 100644 extras/http-client-vertx/README.md create mode 100644 extras/http-client-vertx/pom.xml create mode 100644 extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java create mode 100644 extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClientBuilder.java create mode 100644 extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/sse/SSEHandler.java create mode 100644 extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/ClientBuilderTest.java create mode 100644 extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/VertxHttpClientTest.java delete mode 100644 http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java delete mode 100644 http-client/src/main/java/io/a2a/client/http/A2AHttpResponse.java create mode 100644 http-client/src/main/java/io/a2a/client/http/HttpClient.java create mode 100644 http-client/src/main/java/io/a2a/client/http/HttpClientBuilder.java create mode 100644 http-client/src/main/java/io/a2a/client/http/HttpResponse.java delete mode 100644 http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java create mode 100644 http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java create mode 100644 http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClientBuilder.java create mode 100644 http-client/src/main/java/io/a2a/client/http/jdk/sse/SSEHandler.java create mode 100644 http-client/src/main/java/io/a2a/client/http/sse/CommentEvent.java create mode 100644 http-client/src/main/java/io/a2a/client/http/sse/DataEvent.java create mode 100644 http-client/src/main/java/io/a2a/client/http/sse/Event.java create mode 100644 http-client/src/test/java/io/a2a/client/http/jdk/JdkHttpClientTest.java create mode 100644 server-common/src/main/java/io/a2a/server/http/HttpClientManager.java create mode 100644 server-common/src/test/java/io/a2a/server/http/HttpClientManagerTest.java create mode 100644 tests/client-common/pom.xml create mode 100644 tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java create mode 100644 tests/client-common/src/test/java/io/a2a/client/http/common/JsonMessages.java create mode 100644 tests/client-common/src/test/java/io/a2a/client/http/common/JsonStreamingMessages.java diff --git a/README.md b/README.md index 52bdea52..2316b226 100644 --- a/README.md +++ b/README.md @@ -349,13 +349,13 @@ Different transport protocols can be configured with specific settings using spe ##### JSON-RPC Transport Configuration -For the JSON-RPC transport, to use the default `JdkA2AHttpClient`, provide a `JSONRPCTransportConfig` created with its default constructor. +For the JSON-RPC transport, to use the default `JdkHttpClient`, provide a `JSONRPCTransportConfig` created with its default constructor. To use a custom HTTP client implementation, simply create a `JSONRPCTransportConfig` as follows: ```java -// Create a custom HTTP client -A2AHttpClient customHttpClient = ... +// Create a custom HTTP client builder +HttpClientBuilder httpClientBuilder = ... // Configure the client settings ClientConfig clientConfig = new ClientConfig.Builder() @@ -365,7 +365,7 @@ ClientConfig clientConfig = new ClientConfig.Builder() Client client = Client .builder(agentCard) .clientConfig(clientConfig) - .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(customHttpClient)) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(httpClientBuilder)) .build(); ``` @@ -396,13 +396,13 @@ Client client = Client ##### HTTP+JSON/REST Transport Configuration -For the HTTP+JSON/REST transport, if you'd like to use the default `JdkA2AHttpClient`, provide a `RestTransportConfig` created with its default constructor. +For the HTTP+JSON/REST transport, if you'd like to use the default `JdkHttpClient`, provide a `RestTransportConfig` created with its default constructor. To use a custom HTTP client implementation, simply create a `RestTransportConfig` as follows: ```java // Create a custom HTTP client -A2AHttpClient customHttpClient = ... +HttpClientBuilder httpClientBuilder = ... // Configure the client settings ClientConfig clientConfig = new ClientConfig.Builder() @@ -412,7 +412,7 @@ ClientConfig clientConfig = new ClientConfig.Builder() Client client = Client .builder(agentCard) .clientConfig(clientConfig) - .withTransport(RestTransport.class, new RestTransportConfig(customHttpClient)) + .withTransport(RestTransport.class, new RestTransportConfig(httpClientBuilder)) .build(); ``` diff --git a/client/base/src/main/java/io/a2a/A2A.java b/client/base/src/main/java/io/a2a/A2A.java index 063527c2..158daac1 100644 --- a/client/base/src/main/java/io/a2a/A2A.java +++ b/client/base/src/main/java/io/a2a/A2A.java @@ -3,11 +3,9 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.UUID; import io.a2a.client.http.A2ACardResolver; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; @@ -139,20 +137,7 @@ private static Message toMessage(List> parts, Message.Role role, String * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public static AgentCard getAgentCard(String agentUrl) throws A2AClientError, A2AClientJSONError { - return getAgentCard(new JdkA2AHttpClient(), agentUrl); - } - - /** - * Get the agent card for an A2A agent. - * - * @param httpClient the http client to use - * @param agentUrl the base URL for the agent whose agent card we want to retrieve - * @return the agent card - * @throws A2AClientError If an HTTP error occurs fetching the card - * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema - */ - public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl) throws A2AClientError, A2AClientJSONError { - return getAgentCard(httpClient, agentUrl, null, null); + return getAgentCard(HttpClient.createHttpClient(agentUrl), null, null); } /** @@ -160,30 +145,29 @@ public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl) * * @param agentUrl the base URL for the agent whose agent card we want to retrieve * @param relativeCardPath optional path to the agent card endpoint relative to the base - * agent URL, defaults to ".well-known/agent-card.json" + * agent URL, defaults to "/.well-known/agent-card.json" * @param authHeaders the HTTP authentication headers to use * @return the agent card * @throws A2AClientError If an HTTP error occurs fetching the card * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public static AgentCard getAgentCard(String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { - return getAgentCard(new JdkA2AHttpClient(), agentUrl, relativeCardPath, authHeaders); + return getAgentCard(HttpClient.createHttpClient(agentUrl), relativeCardPath, authHeaders); } /** * Get the agent card for an A2A agent. * * @param httpClient the http client to use - * @param agentUrl the base URL for the agent whose agent card we want to retrieve * @param relativeCardPath optional path to the agent card endpoint relative to the base - * agent URL, defaults to ".well-known/agent-card.json" + * agent URL, defaults to "/.well-known/agent-card.json" * @param authHeaders the HTTP authentication headers to use * @return the agent card * @throws A2AClientError If an HTTP error occurs fetching the card * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ - public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { - A2ACardResolver resolver = new A2ACardResolver(httpClient, agentUrl, relativeCardPath, authHeaders); + public static AgentCard getAgentCard(HttpClient httpClient, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { + A2ACardResolver resolver = new A2ACardResolver(httpClient, relativeCardPath, authHeaders); return resolver.getAgentCard(); } } diff --git a/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java b/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java index 1c7ed38a..b8f849cb 100644 --- a/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java +++ b/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java @@ -1,7 +1,8 @@ package io.a2a.client; import io.a2a.client.config.ClientConfig; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; +import io.a2a.client.http.jdk.JdkHttpClientBuilder; import io.a2a.client.transport.grpc.GrpcTransport; import io.a2a.client.transport.grpc.GrpcTransportConfigBuilder; import io.a2a.client.transport.jsonrpc.JSONRPCTransport; @@ -71,13 +72,40 @@ public void shouldNotFindConfigurationTransport() throws A2AClientException { } @Test - public void shouldCreateJSONRPCClient() throws A2AClientException { + public void shouldNotCreateJSONRPCClient_nullHttpClientFactory() throws A2AClientException { + Assertions.assertThrows(IllegalArgumentException.class, + () -> { + Client + .builder(card) + .clientConfig(new ClientConfig.Builder().setUseClientPreference(true).build()) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() + .addInterceptor(null) + .httpClientBuilder(null)) + .build(); + }); + } + + @Test + public void shouldCreateJSONRPCClient_defaultHttpClientFactory() throws A2AClientException { + Client client = Client + .builder(card) + .clientConfig(new ClientConfig.Builder().setUseClientPreference(true).build()) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() + .addInterceptor(null) + .httpClientBuilder(HttpClientBuilder.DEFAULT_FACTORY)) + .build(); + + Assertions.assertNotNull(client); + } + + @Test + public void shouldCreateJSONRPCClient_withHttpClientFactory() throws A2AClientException { Client client = Client .builder(card) .clientConfig(new ClientConfig.Builder().setUseClientPreference(true).build()) .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() .addInterceptor(null) - .httpClient(null)) + .httpClientBuilder(new JdkHttpClientBuilder())) .build(); Assertions.assertNotNull(client); @@ -88,7 +116,7 @@ public void shouldCreateClient_differentConfigurations() throws A2AClientExcepti Client client = Client .builder(card) .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder()) - .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(new JdkA2AHttpClient())) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig()) .build(); Assertions.assertNotNull(client); diff --git a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java index 2023339d..d1943f27 100644 --- a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java +++ b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java @@ -11,8 +11,8 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import io.a2a.client.transport.spi.AbstractClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallContext; -import io.a2a.client.transport.spi.ClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; import io.a2a.client.transport.spi.interceptors.auth.AuthInterceptor; @@ -50,7 +50,7 @@ import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; -public class GrpcTransport implements ClientTransport { +public class GrpcTransport extends AbstractClientTransport { private static final Metadata.Key AUTHORIZATION_METADATA_KEY = Metadata.Key.of( AuthInterceptor.AUTHORIZATION, @@ -60,7 +60,6 @@ public class GrpcTransport implements ClientTransport { Metadata.ASCII_STRING_MARSHALLER); private final A2AServiceBlockingV2Stub blockingStub; private final A2AServiceStub asyncStub; - private final List interceptors; private AgentCard agentCard; public GrpcTransport(Channel channel, AgentCard agentCard) { @@ -68,11 +67,11 @@ public GrpcTransport(Channel channel, AgentCard agentCard) { } public GrpcTransport(Channel channel, AgentCard agentCard, List interceptors) { + super(interceptors); checkNotNullParam("channel", channel); this.asyncStub = A2AServiceGrpc.newStub(channel); this.blockingStub = A2AServiceGrpc.newBlockingV2Stub(channel); this.agentCard = agentCard; - this.interceptors = interceptors; } @Override @@ -365,17 +364,4 @@ private String getTaskPushNotificationConfigName(String taskId, String pushNotif return name.toString(); } - private PayloadAndHeaders applyInterceptors(String methodName, Object payload, - AgentCard agentCard, ClientCallContext clientCallContext) { - PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, - clientCallContext != null ? clientCallContext.getHeaders() : null); - if (interceptors != null && ! interceptors.isEmpty()) { - for (ClientCallInterceptor interceptor : interceptors) { - payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), - payloadAndHeaders.getHeaders(), agentCard, clientCallContext); - } - } - return payloadAndHeaders; - } - } \ No newline at end of file diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java index 8464911f..a1ff52bb 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java @@ -3,21 +3,23 @@ import static io.a2a.util.Assert.checkNotNullParam; import java.io.IOException; +import java.net.URI; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; import java.util.function.Consumer; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import io.a2a.client.http.A2ACardResolver; +import io.a2a.client.transport.spi.AbstractClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; -import io.a2a.client.http.JdkA2AHttpClient; -import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -59,8 +61,9 @@ import java.util.concurrent.atomic.AtomicReference; import io.a2a.util.Utils; +import org.jspecify.annotations.Nullable; -public class JSONRPCTransport implements ClientTransport { +public class JSONRPCTransport extends AbstractClientTransport { private static final TypeReference SEND_MESSAGE_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference GET_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; @@ -71,9 +74,8 @@ public class JSONRPCTransport implements ClientTransport { private static final TypeReference DELETE_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference GET_AUTHENTICATED_EXTENDED_CARD_RESPONSE_REFERENCE = new TypeReference<>() {}; - private final A2AHttpClient httpClient; - private final String agentUrl; - private final List interceptors; + private final HttpClient httpClient; + private final String agentPath; private AgentCard agentCard; private boolean needsExtendedCard = false; @@ -81,21 +83,26 @@ public JSONRPCTransport(String agentUrl) { this(null, null, agentUrl, null); } - public JSONRPCTransport(AgentCard agentCard) { - this(null, agentCard, agentCard.url(), null); - } - - public JSONRPCTransport(A2AHttpClient httpClient, AgentCard agentCard, - String agentUrl, List interceptors) { - this.httpClient = httpClient == null ? new JdkA2AHttpClient() : httpClient; + public JSONRPCTransport(@Nullable HttpClient httpClient, @Nullable AgentCard agentCard, + String agentUrl, @Nullable List interceptors) { + super(interceptors); + this.httpClient = httpClient == null ? HttpClient.createHttpClient(agentUrl) : httpClient; this.agentCard = agentCard; - this.agentUrl = agentUrl; - this.interceptors = interceptors; + + String sAgentPath = URI.create(agentUrl).getPath(); + + // Strip the last slash if one is provided + if (sAgentPath.endsWith("/")) { + this.agentPath = sAgentPath.substring(0, sAgentPath.length() - 1); + } else { + this.agentPath = sAgentPath; + } + this.needsExtendedCard = agentCard == null || agentCard.supportsAuthenticatedExtendedCard(); } @Override - public EventKind sendMessage(MessageSendParams request, ClientCallContext context) throws A2AClientException { + public EventKind sendMessage(MessageSendParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); SendMessageRequest sendMessageRequest = new SendMessageRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -103,8 +110,7 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex .params(request) .build(); // id will be randomly generated - PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendMessageRequest.METHOD, sendMessageRequest, - agentCard, context); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendMessageRequest.METHOD, sendMessageRequest, agentCard, context); try { String httpResponseBody = sendPostRequest(payloadAndHeaders); @@ -119,7 +125,7 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex @Override public void sendMessageStreaming(MessageSendParams request, Consumer eventConsumer, - Consumer errorConsumer, ClientCallContext context) throws A2AClientException { + Consumer errorConsumer, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); checkNotNullParam("eventConsumer", eventConsumer); SendStreamingMessageRequest sendStreamingMessageRequest = new SendStreamingMessageRequest.Builder() @@ -128,29 +134,33 @@ public void sendMessageStreaming(MessageSendParams request, Consumer> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); SSEEventListener sseEventListener = new SSEEventListener(eventConsumer, errorConsumer); try { - A2AHttpClient.PostBuilder builder = createPostBuilder(payloadAndHeaders); - ref.set(builder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + HttpClient.PostRequestBuilder builder = createPostBuilder(payloadAndHeaders).asSSE(); + ref.set(builder.send() + .whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Send streaming message request timed out: " + e, e); } } @Override - public Task getTask(TaskQueryParams request, ClientCallContext context) throws A2AClientException { + public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskRequest getTaskRequest = new GetTaskRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -158,8 +168,7 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A .params(request) .build(); // id will be randomly generated - PayloadAndHeaders payloadAndHeaders = applyInterceptors(GetTaskRequest.METHOD, getTaskRequest, - agentCard, context); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(GetTaskRequest.METHOD, getTaskRequest, agentCard, context); try { String httpResponseBody = sendPostRequest(payloadAndHeaders); @@ -173,7 +182,7 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A } @Override - public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A2AClientException { + public Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); CancelTaskRequest cancelTaskRequest = new CancelTaskRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -181,8 +190,7 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A .params(request) .build(); // id will be randomly generated - PayloadAndHeaders payloadAndHeaders = applyInterceptors(CancelTaskRequest.METHOD, cancelTaskRequest, - agentCard, context); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(CancelTaskRequest.METHOD, cancelTaskRequest, agentCard, context); try { String httpResponseBody = sendPostRequest(payloadAndHeaders); @@ -197,7 +205,7 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A @Override public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); SetTaskPushNotificationConfigRequest setTaskPushNotificationRequest = new SetTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -222,7 +230,7 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN @Override public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskPushNotificationConfigRequest getTaskPushNotificationRequest = new GetTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -248,7 +256,7 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu @Override public List listTaskPushNotificationConfigurations( ListTaskPushNotificationConfigParams request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); ListTaskPushNotificationConfigRequest listTaskPushNotificationRequest = new ListTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -273,7 +281,7 @@ public List listTaskPushNotificationConfigurations( @Override public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); DeleteTaskPushNotificationConfigRequest deleteTaskPushNotificationRequest = new DeleteTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -296,7 +304,7 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC @Override public void resubscribe(TaskIdParams request, Consumer eventConsumer, - Consumer errorConsumer, ClientCallContext context) throws A2AClientException { + Consumer errorConsumer, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); checkNotNullParam("eventConsumer", eventConsumer); checkNotNullParam("errorConsumer", errorConsumer); @@ -309,30 +317,33 @@ public void resubscribe(TaskIdParams request, Consumer event PayloadAndHeaders payloadAndHeaders = applyInterceptors(TaskResubscriptionRequest.METHOD, taskResubscriptionRequest, agentCard, context); - AtomicReference> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); SSEEventListener sseEventListener = new SSEEventListener(eventConsumer, errorConsumer); try { - A2AHttpClient.PostBuilder builder = createPostBuilder(payloadAndHeaders); - ref.set(builder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + HttpClient.PostRequestBuilder builder = createPostBuilder(payloadAndHeaders).asSSE(); + ref.set(builder.send().whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send task resubscription request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Task resubscription request timed out: " + e, e); } } @Override - public AgentCard getAgentCard(ClientCallContext context) throws A2AClientException { - A2ACardResolver resolver; + public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { try { if (agentCard == null) { - resolver = new A2ACardResolver(httpClient, agentUrl, null, getHttpHeaders(context)); + A2ACardResolver resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); agentCard = resolver.getAgentCard(); needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); } @@ -368,30 +379,25 @@ public void close() { // no-op } - private PayloadAndHeaders applyInterceptors(String methodName, Object payload, - AgentCard agentCard, ClientCallContext clientCallContext) { - PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, getHttpHeaders(clientCallContext)); - if (interceptors != null && ! interceptors.isEmpty()) { - for (ClientCallInterceptor interceptor : interceptors) { - payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), - payloadAndHeaders.getHeaders(), agentCard, clientCallContext); + private String sendPostRequest(PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { + HttpClient.PostRequestBuilder builder = createPostBuilder(payloadAndHeaders); + try { + HttpResponse response = builder.send().get(); + if (!response.success()) { + throw new IOException("Request failed " + response.statusCode()); } - } - return payloadAndHeaders; - } + return response.body(); - private String sendPostRequest(PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { - A2AHttpClient.PostBuilder builder = createPostBuilder(payloadAndHeaders); - A2AHttpResponse response = builder.post(); - if (!response.success()) { - throw new IOException("Request failed " + response.status()); + } catch (ExecutionException e) { + if (e.getCause() instanceof IOException) { + throw (IOException) e.getCause(); + } + throw new IOException("Failed to send request", e.getCause()); } - return response.body(); } - private A2AHttpClient.PostBuilder createPostBuilder(PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException { - A2AHttpClient.PostBuilder postBuilder = httpClient.createPost() - .url(agentUrl) + private HttpClient.PostRequestBuilder createPostBuilder(PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException { + HttpClient.PostRequestBuilder postBuilder = httpClient.post(agentPath) .addHeader("Content-Type", "application/json") .body(Utils.OBJECT_MAPPER.writeValueAsString(payloadAndHeaders.getPayload())); @@ -414,7 +420,7 @@ private > T unmarshalResponse(String response, Type return value; } - private Map getHttpHeaders(ClientCallContext context) { + private Map getHttpHeaders(@Nullable ClientCallContext context) { return context != null ? context.getHeaders() : null; } } \ No newline at end of file diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java index efd3bbdf..2cdc4183 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java @@ -1,21 +1,24 @@ package io.a2a.client.transport.jsonrpc; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfig; -import io.a2a.client.http.A2AHttpClient; +import io.a2a.util.Assert; +import org.jspecify.annotations.Nullable; public class JSONRPCTransportConfig extends ClientTransportConfig { - private final A2AHttpClient httpClient; + private final HttpClientBuilder httpClientBuilder; - public JSONRPCTransportConfig() { - this.httpClient = null; + public JSONRPCTransportConfig(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; } - public JSONRPCTransportConfig(A2AHttpClient httpClient) { - this.httpClient = httpClient; + public JSONRPCTransportConfig() { + this.httpClientBuilder = HttpClientBuilder.DEFAULT_FACTORY; } - public A2AHttpClient getHttpClient() { - return httpClient; + public HttpClientBuilder getHttpClientBuilder() { + return this.httpClientBuilder; } } \ No newline at end of file diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java index 64153620..ed1956e3 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java @@ -1,27 +1,23 @@ package io.a2a.client.transport.jsonrpc; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfigBuilder; +import io.a2a.util.Assert; public class JSONRPCTransportConfigBuilder extends ClientTransportConfigBuilder { - private A2AHttpClient httpClient; + private HttpClientBuilder httpClientBuilder = HttpClientBuilder.DEFAULT_FACTORY; - public JSONRPCTransportConfigBuilder httpClient(A2AHttpClient httpClient) { - this.httpClient = httpClient; + public JSONRPCTransportConfigBuilder httpClientBuilder(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; return this; } @Override public JSONRPCTransportConfig build() { - // No HTTP client provided, fallback to the default one (JDK-based implementation) - if (httpClient == null) { - httpClient = new JdkA2AHttpClient(); - } - - JSONRPCTransportConfig config = new JSONRPCTransportConfig(httpClient); + JSONRPCTransportConfig config = new JSONRPCTransportConfig(httpClientBuilder); config.setInterceptors(this.interceptors); return config; } diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java index 97c22866..66de8dcf 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java @@ -1,6 +1,7 @@ package io.a2a.client.transport.jsonrpc; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportProvider; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -9,12 +10,20 @@ public class JSONRPCTransportProvider implements ClientTransportProvider { @Override - public JSONRPCTransport create(JSONRPCTransportConfig clientTransportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { - if (clientTransportConfig == null) { - clientTransportConfig = new JSONRPCTransportConfig(new JdkA2AHttpClient()); + public JSONRPCTransport create(JSONRPCTransportConfig transportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { + if (transportConfig == null) { + transportConfig = new JSONRPCTransportConfig(); } - return new JSONRPCTransport(clientTransportConfig.getHttpClient(), agentCard, agentUrl, clientTransportConfig.getInterceptors()); + HttpClientBuilder httpClientBuilder = transportConfig.getHttpClientBuilder(); + + try { + final HttpClient httpClient = httpClientBuilder.create(agentUrl); + + return new JSONRPCTransport(httpClient, agentCard, agentUrl, transportConfig.getInterceptors()); + } catch (Exception ex) { + throw new A2AClientException("Failed to create JSONRPC transport", ex); + } } @Override diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java index 99ca546c..af88c732 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java @@ -2,6 +2,9 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; import io.a2a.spec.JSONRPCError; import io.a2a.spec.StreamingEventKind; import io.a2a.spec.TaskStatusUpdateEvent; @@ -23,22 +26,28 @@ public SSEEventListener(Consumer eventHandler, this.errorHandler = errorHandler; } - public void onMessage(String message, Future completableFuture) { - try { - handleMessage(OBJECT_MAPPER.readTree(message),completableFuture); - } catch (JsonProcessingException e) { - log.warning("Failed to parse JSON message: " + message); + public void onMessage(Event event, Future completableFuture) { + log.fine("Streaming message received: " + event); + + if (event instanceof DataEvent) { + try { + handleMessage(OBJECT_MAPPER.readTree(((DataEvent) event).getData()), completableFuture); + } catch (JsonProcessingException e) { + log.warning("Failed to parse JSON message: " + ((DataEvent) event).getData()); + } } } - public void onError(Throwable throwable, Future future) { + public void onError(Throwable throwable, Future future) { if (errorHandler != null) { errorHandler.accept(throwable); } - future.cancel(true); // close SSE channel + if (future != null) { + future.cancel(true); // close SSE channel + } } - private void handleMessage(JsonNode jsonNode, Future future) { + private void handleMessage(JsonNode jsonNode, Future future) { try { if (jsonNode.has("error")) { JSONRPCError error = OBJECT_MAPPER.treeToValue(jsonNode.get("error"), JSONRPCError.class); diff --git a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java index 8c4c1495..0acfcea0 100644 --- a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java +++ b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java @@ -13,6 +13,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.DataEvent; import io.a2a.client.transport.jsonrpc.JsonStreamingMessages; import io.a2a.spec.Artifact; import io.a2a.spec.JSONRPCError; @@ -43,7 +45,7 @@ public void testOnEventWithTaskResult() throws Exception { JsonStreamingMessages.STREAMING_TASK_EVENT.indexOf("{")); // Call the onEvent method directly - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -68,7 +70,7 @@ public void testOnEventWithMessageResult() throws Exception { JsonStreamingMessages.STREAMING_MESSAGE_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -96,7 +98,7 @@ public void testOnEventWithTaskStatusUpdateEventEvent() throws Exception { JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -122,7 +124,7 @@ public void testOnEventWithTaskArtifactUpdateEventEvent() throws Exception { JsonStreamingMessages.STREAMING_ARTIFACT_UPDATE_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -154,7 +156,7 @@ public void testOnEventWithError() throws Exception { JsonStreamingMessages.STREAMING_ERROR_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the error was processed correctly assertNotNull(receivedError.get()); @@ -217,7 +219,7 @@ public void testOnEventWithFinalTaskStatusUpdateEventEventCancels() throws Excep // Call onEvent method CancelCapturingFuture future = new CancelCapturingFuture(); - listener.onMessage(eventData, future); + listener.onMessage(new DataEvent(null, eventData, null), future); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -232,7 +234,7 @@ public void testOnEventWithFinalTaskStatusUpdateEventEventCancels() throws Excep } - private static class CancelCapturingFuture implements Future { + private static class CancelCapturingFuture implements Future { private boolean cancelHandlerCalled; public CancelCapturingFuture() { @@ -255,12 +257,12 @@ public boolean isDone() { } @Override - public Void get() throws InterruptedException, ExecutionException { + public HttpResponse get() throws InterruptedException, ExecutionException { return null; } @Override - public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + public HttpResponse get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { return null; } } diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java index 965cc296..85bf962b 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java @@ -4,7 +4,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import io.a2a.client.http.A2AHttpResponse; +import io.a2a.client.http.HttpResponse; import io.a2a.spec.A2AClientException; import io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError; import io.a2a.spec.ContentTypeNotSupportedError; @@ -28,8 +28,8 @@ public class RestErrorMapper { private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().registerModule(new JavaTimeModule()); - public static A2AClientException mapRestError(A2AHttpResponse response) { - return RestErrorMapper.mapRestError(response.body(), response.status()); + public static A2AClientException mapRestError(HttpResponse response) { + return RestErrorMapper.mapRestError(response.body(), response.statusCode()); } public static A2AClientException mapRestError(String body, int code) { diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java index f659589b..912c0082 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java @@ -7,11 +7,10 @@ import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.util.JsonFormat; import io.a2a.client.http.A2ACardResolver; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.client.transport.rest.sse.RestSSEEventListener; -import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.client.transport.spi.AbstractClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; @@ -38,8 +37,11 @@ import io.a2a.spec.SetTaskPushNotificationConfigRequest; import io.a2a.util.Utils; import java.io.IOException; +import java.net.URI; import java.util.Collections; import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; import java.util.logging.Logger; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -47,25 +49,31 @@ import java.util.function.Consumer; import org.jspecify.annotations.Nullable; -public class RestTransport implements ClientTransport { +public class RestTransport extends AbstractClientTransport { private static final Logger log = Logger.getLogger(RestTransport.class.getName()); - private final A2AHttpClient httpClient; - private final String agentUrl; - private @Nullable final List interceptors; - private AgentCard agentCard; + private final HttpClient httpClient; + private final String agentPath; + private @Nullable AgentCard agentCard; private boolean needsExtendedCard = false; - public RestTransport(AgentCard agentCard) { - this(null, agentCard, agentCard.url(), null); + public RestTransport(String agentUrl) { + this(null, null, agentUrl, null); } - public RestTransport(@Nullable A2AHttpClient httpClient, AgentCard agentCard, + public RestTransport(@Nullable HttpClient httpClient, @Nullable AgentCard agentCard, String agentUrl, @Nullable List interceptors) { - this.httpClient = httpClient == null ? new JdkA2AHttpClient() : httpClient; + super(interceptors); + this.httpClient = httpClient == null ? HttpClient.createHttpClient(agentUrl) : httpClient; this.agentCard = agentCard; - this.agentUrl = agentUrl.endsWith("/") ? agentUrl.substring(0, agentUrl.length() - 1) : agentUrl; - this.interceptors = interceptors; + String sAgentPath = URI.create(agentUrl).getPath(); + + // Strip the last slash if one is provided + if (sAgentPath.endsWith("/")) { + this.agentPath = sAgentPath.substring(0, sAgentPath.length() - 1); + } else { + this.agentPath = sAgentPath; + } } @Override @@ -74,7 +82,7 @@ public EventKind sendMessage(MessageSendParams messageSendParams, @Nullable Clie io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(ProtoUtils.ToProto.sendMessageRequest(messageSendParams)); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, builder, agentCard, context); try { - String httpResponseBody = sendPostRequest(agentUrl + "/v1/message:send", payloadAndHeaders); + String httpResponseBody = sendPostRequest("/v1/message:send", payloadAndHeaders); io.a2a.grpc.SendMessageResponse.Builder responseBuilder = io.a2a.grpc.SendMessageResponse.newBuilder(); JsonFormat.parser().merge(httpResponseBody, responseBuilder); if (responseBuilder.hasMsg()) { @@ -86,7 +94,7 @@ public EventKind sendMessage(MessageSendParams messageSendParams, @Nullable Clie throw new A2AClientException("Failed to send message, wrong response:" + httpResponseBody); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to send message: " + e, e); } } @@ -99,20 +107,24 @@ public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); try { - A2AHttpClient.PostBuilder postBuilder = createPostBuilder(agentUrl + "/v1/message:stream", payloadAndHeaders); - ref.set(postBuilder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + HttpClient.PostRequestBuilder postBuilder = createPostBuilder("/v1/message:stream", payloadAndHeaders).asSSE(); + ref.set(postBuilder.send().whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Send streaming message request timed out: " + e, e); } } @@ -124,19 +136,20 @@ public Task getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, builder, agentCard, context); try { - String url; + String path; if (taskQueryParams.historyLength() != null) { - url = agentUrl + String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength()); + path = String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength()); } else { - url = agentUrl + String.format("/v1/tasks/%1s", taskQueryParams.id()); + path = String.format("/v1/tasks/%1s", taskQueryParams.id()); } - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -146,7 +159,7 @@ public Task getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext return ProtoUtils.FromProto.task(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to get task: " + e, e); } } @@ -159,13 +172,13 @@ public Task cancelTask(TaskIdParams taskIdParams, @Nullable ClientCallContext co PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.CancelTaskRequest.METHOD, builder, agentCard, context); try { - String httpResponseBody = sendPostRequest(agentUrl + String.format("/v1/tasks/%1s:cancel", taskIdParams.id()), payloadAndHeaders); + String httpResponseBody = sendPostRequest(String.format("/v1/tasks/%1s:cancel", taskIdParams.id()), payloadAndHeaders); io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); JsonFormat.parser().merge(httpResponseBody, responseBuilder); return ProtoUtils.FromProto.task(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to cancel task: " + e, e); } } @@ -181,13 +194,13 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN } PayloadAndHeaders payloadAndHeaders = applyInterceptors(SetTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String httpResponseBody = sendPostRequest(agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs", request.taskId()), payloadAndHeaders); + String httpResponseBody = sendPostRequest(String.format("/v1/tasks/%1s/pushNotificationConfigs", request.taskId()), payloadAndHeaders); io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); JsonFormat.parser().merge(httpResponseBody, responseBuilder); return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to set task push notification config: " + e, e); } } @@ -200,14 +213,17 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -217,7 +233,7 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to get push notifications: " + e, e); } } @@ -230,14 +246,16 @@ public List listTaskPushNotificationConfigurations(L PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.ListTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs", request.id()); - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs", request.id()); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -247,7 +265,7 @@ public List listTaskPushNotificationConfigurations(L return ProtoUtils.FromProto.listTaskPushNotificationConfigParams(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to list push notifications: " + e, e); } } @@ -259,20 +277,22 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); - A2AHttpClient.DeleteBuilder deleteBuilder = httpClient.createDelete().url(url); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + HttpClient.DeleteRequestBuilder deleteBuilder = httpClient.delete(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { deleteBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = deleteBuilder.delete(); + CompletableFuture responseFut = deleteBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to delete push notification config: " + e, e); } } @@ -285,21 +305,25 @@ public void resubscribe(TaskIdParams request, Consumer event builder.setName("tasks/" + request.id()); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.TaskResubscriptionRequest.METHOD, builder, agentCard, context); - AtomicReference> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); try { - String url = agentUrl + String.format("/v1/tasks/%1s:subscribe", request.id()); - A2AHttpClient.PostBuilder postBuilder = createPostBuilder(url, payloadAndHeaders); - ref.set(postBuilder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + String path = String.format("/v1/tasks/%1s:subscribe", request.id()); + HttpClient.PostRequestBuilder postBuilder = createPostBuilder(path, payloadAndHeaders).asSSE(); + ref.set(postBuilder.send().whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Send streaming message request timed out: " + e, e); } } @@ -308,7 +332,7 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli A2ACardResolver resolver; try { if (agentCard == null) { - resolver = new A2ACardResolver(httpClient, agentUrl, null, getHttpHeaders(context)); + resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); agentCard = resolver.getAgentCard(); needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); } @@ -317,14 +341,16 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli } PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, null, agentCard, context); - String url = agentUrl + String.format("/v1/card"); - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + "/v1/card"); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -332,7 +358,7 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli agentCard = Utils.OBJECT_MAPPER.readValue(httpResponseBody, AgentCard.class); needsExtendedCard = false; return agentCard; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to get authenticated extended agent card: " + e, e); } catch (A2AClientError e) { throw new A2AClientException("Failed to get agent card: " + e, e); @@ -344,21 +370,11 @@ public void close() { // no-op } - private PayloadAndHeaders applyInterceptors(String methodName, @Nullable MessageOrBuilder payload, - AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { - PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, getHttpHeaders(clientCallContext)); - if (interceptors != null && !interceptors.isEmpty()) { - for (ClientCallInterceptor interceptor : interceptors) { - payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), - payloadAndHeaders.getHeaders(), agentCard, clientCallContext); - } - } - return payloadAndHeaders; - } + private String sendPostRequest(String path, PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException, ExecutionException { + HttpClient.PostRequestBuilder builder = createPostBuilder(path, payloadAndHeaders); + CompletableFuture responseFut = builder.send(); - private String sendPostRequest(String url, PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { - A2AHttpClient.PostBuilder builder = createPostBuilder(url, payloadAndHeaders); - A2AHttpResponse response = builder.post(); + HttpResponse response = responseFut.get(); if (!response.success()) { log.fine("Error on POST processing " + JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); throw RestErrorMapper.mapRestError(response); @@ -366,10 +382,9 @@ private String sendPostRequest(String url, PayloadAndHeaders payloadAndHeaders) return response.body(); } - private A2AHttpClient.PostBuilder createPostBuilder(String url, PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException, InvalidProtocolBufferException { + private HttpClient.PostRequestBuilder createPostBuilder(String path, PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException, InvalidProtocolBufferException { log.fine(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); - A2AHttpClient.PostBuilder postBuilder = httpClient.createPost() - .url(url) + HttpClient.PostRequestBuilder postBuilder = httpClient.post(agentPath + path) .addHeader("Content-Type", "application/json") .body(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java index d097b010..21b694ce 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java @@ -1,22 +1,23 @@ package io.a2a.client.transport.rest; -import io.a2a.client.http.A2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfig; -import org.jspecify.annotations.Nullable; +import io.a2a.util.Assert; public class RestTransportConfig extends ClientTransportConfig { - private final @Nullable A2AHttpClient httpClient; + private final HttpClientBuilder httpClientBuilder; - public RestTransportConfig() { - this.httpClient = null; + public RestTransportConfig(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; } - public RestTransportConfig(A2AHttpClient httpClient) { - this.httpClient = httpClient; + public RestTransportConfig() { + this.httpClientBuilder = HttpClientBuilder.DEFAULT_FACTORY; } - public @Nullable A2AHttpClient getHttpClient() { - return httpClient; + public HttpClientBuilder getHttpClientBuilder() { + return httpClientBuilder; } } \ No newline at end of file diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java index 68150f18..edcbcd1c 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java @@ -1,27 +1,24 @@ package io.a2a.client.transport.rest; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfigBuilder; -import org.jspecify.annotations.Nullable; + +import io.a2a.util.Assert; public class RestTransportConfigBuilder extends ClientTransportConfigBuilder { - private @Nullable A2AHttpClient httpClient; + private HttpClientBuilder httpClientBuilder = io.a2a.client.http.HttpClientBuilder.DEFAULT_FACTORY; + + public RestTransportConfigBuilder httpClientBuilder(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; - public RestTransportConfigBuilder httpClient(A2AHttpClient httpClient) { - this.httpClient = httpClient; return this; } @Override public RestTransportConfig build() { - // No HTTP client provided, fallback to the default one (JDK-based implementation) - if (httpClient == null) { - httpClient = new JdkA2AHttpClient(); - } - - RestTransportConfig config = new RestTransportConfig(httpClient); + RestTransportConfig config = new RestTransportConfig(this.httpClientBuilder); config.setInterceptors(this.interceptors); return config; } diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java index 99d15596..cd03086c 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java @@ -1,6 +1,7 @@ package io.a2a.client.transport.rest; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportProvider; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -14,12 +15,20 @@ public String getTransportProtocol() { } @Override - public RestTransport create(RestTransportConfig clientTransportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { - RestTransportConfig transportConfig = clientTransportConfig; - if (transportConfig == null) { - transportConfig = new RestTransportConfig(new JdkA2AHttpClient()); + public RestTransport create(RestTransportConfig transportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { + if (transportConfig == null) { + transportConfig = new RestTransportConfig(); + } + + HttpClientBuilder httpClientBuilder = transportConfig.getHttpClientBuilder(); + + try { + final HttpClient httpClient = httpClientBuilder.create(agentUrl); + + return new RestTransport(httpClient, agentCard, agentUrl, transportConfig.getInterceptors()); + } catch (Exception ex) { + throw new A2AClientException("Failed to create REST transport", ex); } - return new RestTransport(clientTransportConfig.getHttpClient(), agentCard, agentUrl, transportConfig.getInterceptors()); } @Override diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java index d0b130ee..2afd586e 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java @@ -1,20 +1,19 @@ package io.a2a.client.transport.rest.sse; -import static io.a2a.grpc.StreamResponse.PayloadCase.ARTIFACT_UPDATE; -import static io.a2a.grpc.StreamResponse.PayloadCase.MSG; -import static io.a2a.grpc.StreamResponse.PayloadCase.STATUS_UPDATE; -import static io.a2a.grpc.StreamResponse.PayloadCase.TASK; - import java.util.concurrent.Future; import java.util.function.Consumer; import java.util.logging.Logger; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; import io.a2a.client.transport.rest.RestErrorMapper; import io.a2a.grpc.StreamResponse; import io.a2a.grpc.utils.ProtoUtils; import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.TaskStatusUpdateEvent; import org.jspecify.annotations.Nullable; public class RestSSEEventListener { @@ -29,18 +28,21 @@ public RestSSEEventListener(Consumer eventHandler, this.errorHandler = errorHandler; } - public void onMessage(String message, @Nullable Future completableFuture) { - try { - log.fine("Streaming message received: " + message); - io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder(); - JsonFormat.parser().merge(message, builder); - handleMessage(builder.build()); - } catch (InvalidProtocolBufferException e) { - errorHandler.accept(RestErrorMapper.mapRestError(message, 500)); + public void onMessage(Event event, @Nullable Future completableFuture) { + log.fine("Streaming message received: " + event); + + if (event instanceof DataEvent) { + try { + io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder(); + JsonFormat.parser().merge(((DataEvent) event).getData(), builder); + handleMessage(builder.build(), completableFuture); + } catch (InvalidProtocolBufferException e) { + errorHandler.accept(RestErrorMapper.mapRestError(((DataEvent) event).getData(), 500)); + } } } - public void onError(Throwable throwable, @Nullable Future future) { + public void onError(Throwable throwable, @Nullable Future future) { if (errorHandler != null) { errorHandler.accept(throwable); } @@ -49,15 +51,19 @@ public void onError(Throwable throwable, @Nullable Future future) { } } - private void handleMessage(StreamResponse response) { + private void handleMessage(StreamResponse response, @Nullable Future future) { StreamingEventKind event; switch (response.getPayloadCase()) { case MSG -> event = ProtoUtils.FromProto.message(response.getMsg()); case TASK -> event = ProtoUtils.FromProto.task(response.getTask()); - case STATUS_UPDATE -> + case STATUS_UPDATE -> { event = ProtoUtils.FromProto.taskStatusUpdateEvent(response.getStatusUpdate()); + if (((TaskStatusUpdateEvent) event).isFinal() && future != null) { + future.cancel(true); // close SSE channel + } + } case ARTIFACT_UPDATE -> event = ProtoUtils.FromProto.taskArtifactUpdateEvent(response.getArtifactUpdate()); default -> { @@ -68,5 +74,4 @@ private void handleMessage(StreamResponse response) { } eventHandler.accept(event); } - } diff --git a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java index a296553c..ae938cb4 100644 --- a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java +++ b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java @@ -1,6 +1,5 @@ package io.a2a.client.transport.rest; - import static io.a2a.client.transport.rest.JsonRestMessages.CANCEL_TASK_TEST_REQUEST; import static io.a2a.client.transport.rest.JsonRestMessages.CANCEL_TASK_TEST_RESPONSE; import static io.a2a.client.transport.rest.JsonRestMessages.GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; @@ -22,9 +21,6 @@ import static org.mockserver.model.HttpResponse.response; import io.a2a.client.transport.spi.interceptors.ClientCallContext; -import io.a2a.spec.AgentCapabilities; -import io.a2a.spec.AgentCard; -import io.a2a.spec.AgentSkill; import io.a2a.spec.Artifact; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.EventKind; @@ -67,28 +63,7 @@ public class RestTransportTest { private static final Logger log = Logger.getLogger(RestTransportTest.class.getName()); private ClientAndServer server; - private static final AgentCard CARD = new AgentCard.Builder() - .name("Hello World Agent") - .description("Just a hello world agent") - .url("http://localhost:4001") - .version("1.0.0") - .documentationUrl("http://example.com/docs") - .capabilities(new AgentCapabilities.Builder() - .streaming(true) - .pushNotifications(true) - .stateTransitionHistory(true) - .build()) - .defaultInputModes(Collections.singletonList("text")) - .defaultOutputModes(Collections.singletonList("text")) - .skills(Collections.singletonList(new AgentSkill.Builder() - .id("hello_world") - .name("Returns hello world") - .description("just returns hello world") - .tags(Collections.singletonList("hello world")) - .examples(List.of("hi", "hello world")) - .build())) - .protocolVersion("0.3.0") - .build(); + private static final String AGENT_URL = "http://localhost:4001"; @BeforeEach public void setUp() throws IOException { @@ -129,7 +104,7 @@ public void testSendMessage() throws Exception { MessageSendParams messageSendParams = new MessageSendParams(message, null, null); ClientCallContext context = null; - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); EventKind result = instance.sendMessage(messageSendParams, context); assertEquals("task", result.getKind()); Task task = (Task) result; @@ -170,7 +145,7 @@ public void testCancelTask() throws Exception { .withBody(CANCEL_TASK_TEST_RESPONSE) ); ClientCallContext context = null; - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); Task task = instance.cancelTask(new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>()), context); assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); @@ -196,7 +171,7 @@ public void testGetTask() throws Exception { ); ClientCallContext context = null; TaskQueryParams request = new TaskQueryParams("de38c76d-d54c-436c-8b9f-4c2703648d64", 10); - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); Task task = instance.getTask(request, context); assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); assertEquals(TaskState.COMPLETED, task.getStatus().state()); @@ -248,7 +223,7 @@ public void testSendMessageStreaming() throws Exception { .withBody(SEND_MESSAGE_STREAMING_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); Message message = new Message.Builder() .role(Message.Role.USER) .parts(Collections.singletonList(new TextPart("tell me some jokes"))) @@ -298,7 +273,7 @@ public void testSetTaskPushNotificationConfiguration() throws Exception { .withStatusCode(200) .withBody(SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); TaskPushNotificationConfig pushedConfig = new TaskPushNotificationConfig( "de38c76d-d54c-436c-8b9f-4c2703648d64", new PushNotificationConfig.Builder() @@ -331,7 +306,7 @@ public void testGetTaskPushNotificationConfiguration() throws Exception { .withBody(GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); TaskPushNotificationConfig taskPushNotificationConfig = client.getTaskPushNotificationConfiguration( new GetTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", "10", new HashMap<>()), null); @@ -359,7 +334,7 @@ public void testListTaskPushNotificationConfigurations() throws Exception { .withBody(LIST_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); List taskPushNotificationConfigs = client.listTaskPushNotificationConfigurations( new ListTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>()), null); assertEquals(2, taskPushNotificationConfigs.size()); @@ -395,7 +370,7 @@ public void testDeleteTaskPushNotificationConfigurations() throws Exception { .withStatusCode(200) ); ClientCallContext context = null; - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); instance.deleteTaskPushNotificationConfigurations(new DeleteTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", "10"), context); } @@ -418,7 +393,7 @@ public void testResubscribe() throws Exception { .withBody(TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); TaskIdParams taskIdParams = new TaskIdParams("task-1234"); AtomicReference receivedEvent = new AtomicReference<>(); diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/AbstractClientTransport.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/AbstractClientTransport.java new file mode 100644 index 00000000..fff6f284 --- /dev/null +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/AbstractClientTransport.java @@ -0,0 +1,31 @@ +package io.a2a.client.transport.spi; + +import io.a2a.client.transport.spi.interceptors.ClientCallContext; +import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; +import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; +import io.a2a.spec.AgentCard; +import org.jspecify.annotations.Nullable; + +import java.util.List; + +public abstract class AbstractClientTransport implements ClientTransport { + + private final @Nullable List interceptors; + + public AbstractClientTransport(@Nullable List interceptors) { + this.interceptors = interceptors; + } + + protected PayloadAndHeaders applyInterceptors(String methodName, @Nullable Object payload, + @Nullable AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { + PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, + clientCallContext != null ? clientCallContext.getHeaders() : null); + if (interceptors != null && ! interceptors.isEmpty()) { + for (ClientCallInterceptor interceptor : interceptors) { + payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), + payloadAndHeaders.getHeaders(), agentCard, clientCallContext); + } + } + return payloadAndHeaders; + } +} diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java index 41141298..b8a8de79 100644 --- a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java @@ -23,5 +23,5 @@ public abstract class ClientCallInterceptor { * @return the potentially modified payload and headers */ public abstract PayloadAndHeaders intercept(String methodName, @Nullable Object payload, Map headers, - AgentCard agentCard, @Nullable ClientCallContext clientCallContext); + @Nullable AgentCard agentCard, @Nullable ClientCallContext clientCallContext); } diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java index 4783cb71..816ad3e5 100644 --- a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java @@ -10,7 +10,7 @@ public class PayloadAndHeaders { private final @Nullable Object payload; private final Map headers; - public PayloadAndHeaders(@Nullable Object payload, Map headers) { + public PayloadAndHeaders(@Nullable Object payload, @Nullable Map headers) { this.payload = payload; this.headers = headers == null ? Collections.emptyMap() : new HashMap<>(headers); } diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java index d2f2a576..8fda4ca4 100644 --- a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java @@ -33,7 +33,7 @@ public AuthInterceptor(final CredentialService credentialService) { @Override public PayloadAndHeaders intercept(String methodName, @Nullable Object payload, Map headers, - AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { + @Nullable AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { Map updatedHeaders = new HashMap<>(headers == null ? new HashMap<>() : headers); if (agentCard == null || agentCard.security() == null || agentCard.securitySchemes() == null) { return new PayloadAndHeaders(payload, updatedHeaders); diff --git a/extras/README.md b/extras/README.md index 3f85e4f9..19807a8f 100644 --- a/extras/README.md +++ b/extras/README.md @@ -6,4 +6,5 @@ Please see the README's of each child directory for more details. [`task-store-database-jpa`](./task-store-database-jpa/README.md) - Replaces the default `InMemoryTaskStore` with a `TaskStore` backed by a RDBMS. It uses JPA to interact with the RDBMS. [`push-notification-config-store-database-jpa`](./push-notification-config-store-database-jpa/README.md) - Replaces the default `InMemoryPushNotificationConfigStore` with a `PushNotificationConfigStore` backed by a RDBMS. It uses JPA to interact with the RDBMS. -[`queue-manager-replicated`](./queue-manager-replicated/README.md) - Replaces the default `InMemoryQueueManager` with a `QueueManager` supporting replication to other A2A servers implementing the same agent. You can write your own `ReplicationStrategy`, or use the provided `MicroProfile Reactive Messaging implementation`. \ No newline at end of file +[`queue-manager-replicated`](./queue-manager-replicated/README.md) - Replaces the default `InMemoryQueueManager` with a `QueueManager` supporting replication to other A2A servers implementing the same agent. You can write your own `ReplicationStrategy`, or use the provided `MicroProfile Reactive Messaging implementation`. +[`vertx-http-client`](./vertx-http-client/README.md) - Replaces the default `HttpClient` JDK implementation with a http-client implementation backed by Vertx, better suited for Quarkus applications. \ No newline at end of file diff --git a/extras/http-client-vertx/README.md b/extras/http-client-vertx/README.md new file mode 100644 index 00000000..f18b90d3 --- /dev/null +++ b/extras/http-client-vertx/README.md @@ -0,0 +1,76 @@ +# A2A Java SDK - Vertx HTTP Client + +This module provides an HTTP client implementation of the `HttpClient` interface that relies on Vertx for the HTTP transport communication. + +By default, the A2A client is relying on the default JDK HttpClient implementation. While this one is convenient for most of use-cases, it may still +be relevant to switch to the Vertx based implementation, especially when your current code is already relying on Vertx or if your A2A server is based on Quarkus which, itself, heavily relies on Vertx. + +## Quick Start + +This section will get you up and running quickly with a `Client` using the `VertxHttpClient` implementation. + +### 1. Add Dependency + +Add this module to your project's `pom.xml`: + +```xml + + io.github.a2asdk + a2a-java-extras-http-client-vertx + ${a2a.version} + +``` + +### 2. Configure Client + +##### JSON-RPC Transport Configuration + +For the JSON-RPC transport, to use the default `JdkHttpClient`, provide a `JSONRPCTransportConfig` created with its default constructor. + +To use a custom HTTP client implementation, simply create a `JSONRPCTransportConfig` as follows: + +```java +import io.a2a.client.http.vertx.VertxHttpClientBuilder; + +// Create a Vertx HTTP client +HttpClientBuilder vertxHttpClientBuilder = new VertxHttpClientBuilder(); + +// Configure the client settings +ClientConfig clientConfig = new ClientConfig.Builder() + .setAcceptedOutputModes(List.of("text")) + .build(); + +Client client = Client + .builder(agentCard) + .clientConfig(clientConfig) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(vertxHttpClientBuilder)) + .build(); +``` + +## Configuration Options + +This implementation allows to pass the Vertx context you want to rely on, but also the HTTPClientOptions, in case +you want / need to provide some extended configuration's properties such as a better of management of SSL Context, or an HTTP proxy. + +```java +import io.a2a.client.http.vertx.VertxHttpClientBuilder; +import io.vertx.core.Vertx; +import io.vertx.core.http.HttpClientOptions; +import io.vertx.core.net.ProxyOptions; + +// Create a Vertx HTTP client +HttpClientBuilder vertxHttpClientBuilder = new VertxHttpClientBuilder() + .vertx(Vertx.vertx()) + .options(new HttpClientOptions().setProxyOptions(new ProxyOptions().setHost("host").setPort("1234"))); + + // Configure the client settings + ClientConfig clientConfig = new ClientConfig.Builder() + .setAcceptedOutputModes(List.of("text")) + .build(); + + Client client = Client + .builder(agentCard) + .clientConfig(clientConfig) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(vertxHttpClientBuilder)) + .build(); +``` diff --git a/extras/http-client-vertx/pom.xml b/extras/http-client-vertx/pom.xml new file mode 100644 index 00000000..656f1943 --- /dev/null +++ b/extras/http-client-vertx/pom.xml @@ -0,0 +1,57 @@ + + + 4.0.0 + + + io.github.a2asdk + a2a-java-sdk-parent + 0.3.0.Beta3-SNAPSHOT + ../../pom.xml + + a2a-java-extras-http-client-vertx + + jar + + Java A2A Extras: Vertx HTTP Client + Java SDK for the Agent2Agent Protocol (A2A) - Extras - Vertx HTTP Client + + + + ${project.groupId} + a2a-java-sdk-http-client + + + + ${project.groupId} + a2a-java-sdk-client + test + + + + ${project.groupId} + a2a-java-sdk-tests-client-common + test-jar + test + + + + io.vertx + vertx-core + + + + org.junit.jupiter + junit-jupiter-api + test + + + + org.wiremock + wiremock + 3.13.1 + test + + + \ No newline at end of file diff --git a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java new file mode 100644 index 00000000..62284dc6 --- /dev/null +++ b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java @@ -0,0 +1,216 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.Event; +import io.a2a.client.http.vertx.sse.SSEHandler; +import io.a2a.common.A2AErrorMessages; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.vertx.core.*; +import io.vertx.core.http.*; + +import java.io.IOException; +import java.net.*; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; +import java.util.function.Function; + +import static java.net.HttpURLConnection.HTTP_FORBIDDEN; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; + +public class VertxHttpClient implements HttpClient { + + private final io.vertx.core.http.HttpClient client; + + private final Vertx vertx; + + VertxHttpClient(String baseUrl, Vertx vertx, HttpClientOptions options) { + this.vertx = vertx; + this.client = initClient(baseUrl, options); + } + + private io.vertx.core.http.HttpClient initClient(String baseUrl, HttpClientOptions options) { + URL targetUrl = buildUrl(baseUrl); + + return this.vertx.createHttpClient(options + .setDefaultHost(targetUrl.getHost()) + .setDefaultPort(targetUrl.getPort() != -1 ? targetUrl.getPort() : targetUrl.getDefaultPort()) + .setSsl(isSecureProtocol(targetUrl.getProtocol()))); + } + + @Override + public GetRequestBuilder get(String path) { + return new VertxGetRequestBuilder(path); + } + + @Override + public PostRequestBuilder post(String path) { + return new VertxPostRequestBuilder(path); + } + + @Override + public DeleteRequestBuilder delete(String path) { + return new VertxDeleteRequestBuilder(path); + } + + private static final URLStreamHandler URL_HANDLER = new URLStreamHandler() { + protected URLConnection openConnection(URL u) { + return null; + } + }; + + private static URL buildUrl(String uri) { + try { + return new URL(null, uri, URL_HANDLER); + } catch (MalformedURLException var2) { + throw new IllegalArgumentException("URI [" + uri + "] is not valid"); + } + } + + private static boolean isSecureProtocol(String protocol) { + return protocol.charAt(protocol.length() - 1) == 's' && protocol.length() > 2; + } + + private abstract class VertxRequestBuilder> implements RequestBuilder { + protected final Future request; + protected final Map headers = new HashMap<>(); + + public VertxRequestBuilder(String path, HttpMethod method) { + this.request = client.request(method, path); + } + + @Override + public T addHeader(String name, String value) { + headers.put(name, value); + return self(); + } + + @Override + public T addHeaders(Map headers) { + if (headers != null && ! headers.isEmpty()) { + for (Map.Entry entry : headers.entrySet()) { + addHeader(entry.getKey(), entry.getValue()); + } + } + return self(); + } + + @SuppressWarnings("unchecked") + T self() { + return (T) this; + } + + protected Future sendRequest() { + return sendRequest(Optional.empty()); + } + + protected Future sendRequest(Optional body) { + return request + .compose(new Function>() { + @Override + public Future apply(HttpClientRequest request) { + // Prepare the request + request.headers().addAll(headers); + + if (body.isPresent()) { + return request.send(body.get()); + } else { + return request.send(); + } + } + }); + } + + @Override + public CompletableFuture send() { + return sendRequest() + .compose(RESPONSE_MAPPER) + .toCompletionStage() + .toCompletableFuture(); + } + } + + private class VertxGetRequestBuilder extends VertxRequestBuilder implements GetRequestBuilder { + + public VertxGetRequestBuilder(String path) { + super(path, HttpMethod.GET); + } + } + + private class VertxDeleteRequestBuilder extends VertxRequestBuilder implements DeleteRequestBuilder { + + public VertxDeleteRequestBuilder(String path) { + super(path, HttpMethod.DELETE); + } + } + + private class VertxPostRequestBuilder extends VertxRequestBuilder implements PostRequestBuilder { + String body = ""; + + public VertxPostRequestBuilder(String path) { + super(path, HttpMethod.POST); + } + + @Override + public PostRequestBuilder body(String body) { + this.body = body; + return this; + } + + @Override + public CompletableFuture send() { + return sendRequest(Optional.of(this.body)) + .compose(RESPONSE_MAPPER) + .toCompletionStage() + .toCompletableFuture(); + } + } + + private final Function> RESPONSE_MAPPER = response -> { + if (response.statusCode() == HTTP_UNAUTHORIZED) { + return Future.failedFuture(new IOException(A2AErrorMessages.AUTHENTICATION_FAILED)); + } else if (response.statusCode() == HTTP_FORBIDDEN) { + return Future.failedFuture(new IOException(A2AErrorMessages.AUTHORIZATION_FAILED)); + } + + return Future.succeededFuture(new VertxHttpResponse(response)); + }; + + private record VertxHttpResponse(HttpClientResponse response)implements HttpResponse { + + @Override + public int statusCode() { + return response.statusCode(); + } + + @Override + public String body() { + try { + return response.body().toCompletionStage().toCompletableFuture().get().toString(); + + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + String contentType = response.headers().get(HttpHeaderNames.CONTENT_TYPE.toString()); + + if (contentType != null && HttpHeaderValues.TEXT_EVENT_STREAM.contentEqualsIgnoreCase(contentType)) { + final SSEHandler handler = new SSEHandler(eventConsumer); + + response.handler(handler).exceptionHandler(errorConsumer::accept); + } else { + throw new IllegalStateException("Response is not an event-stream response."); + } + } + } +} diff --git a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClientBuilder.java b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClientBuilder.java new file mode 100644 index 00000000..c727612b --- /dev/null +++ b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClientBuilder.java @@ -0,0 +1,30 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; +import io.vertx.core.Vertx; +import io.vertx.core.http.HttpClientOptions; + +public class VertxHttpClientBuilder implements HttpClientBuilder { + + private Vertx vertx; + + private HttpClientOptions options; + + public VertxHttpClientBuilder vertx(Vertx vertx) { + this.vertx = vertx; + return this; + } + + public VertxHttpClientBuilder options(HttpClientOptions options) { + this.options = options; + return this; + } + + @Override + public HttpClient create(String url) { + return new VertxHttpClient(url, + vertx != null ? vertx : Vertx.vertx(), + options != null ? options : new HttpClientOptions()); + } +} diff --git a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/sse/SSEHandler.java b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/sse/SSEHandler.java new file mode 100644 index 00000000..d9ee3df8 --- /dev/null +++ b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/sse/SSEHandler.java @@ -0,0 +1,124 @@ +package io.a2a.client.http.vertx.sse; + +import io.a2a.client.http.sse.CommentEvent; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; +import io.vertx.core.Handler; +import io.vertx.core.buffer.Buffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.function.Consumer; + +public class SSEHandler implements Handler { + + + private static final Logger LOG = LoggerFactory.getLogger(SSEHandler.class); + + private static final String UTF8_BOM = "\uFEFF"; + + private static final String DEFAULT_EVENT_NAME = "message"; + + private String currentEventName = DEFAULT_EVENT_NAME; + private final StringBuilder dataBuffer = new StringBuilder(); + + private String lastEventId = ""; + + private final Consumer eventConsumer; + + public SSEHandler(Consumer eventConsumer) { + this.eventConsumer = eventConsumer; + } + + private void handleFieldValue(String fieldName, String value) { + switch (fieldName) { + case "event": + currentEventName = value; + break; + case "data": + dataBuffer.append(value).append("\n"); + break; + case "id": + if (!value.contains("\0")) { + lastEventId = value; + } + break; + case "retry": + // ignored + break; + } + } + + private String stripLeadingSpaceIfPresent(String field) { + if (field.charAt(0) == ' ') { + return field.substring(1); + } + return field; + } + + private String removeLeadingBom(String input) { + if (input.startsWith(UTF8_BOM)) { + return input.substring(UTF8_BOM.length()); + } + return input; + } + + private String removeTrailingNewline(String input) { + if (input.endsWith("\n")) { + return input.substring(0, input.length() - 1); + } + return input; + } + + private Buffer buffer = Buffer.buffer(); + + @Override + public void handle(Buffer chunk) { + buffer.appendBuffer(chunk); + int separatorIndex; + // The separator for events is a double newline + String separator = "\n\n"; + while ((separatorIndex = buffer.toString().indexOf(separator)) != -1) { + Buffer eventData = buffer.getBuffer(0, separatorIndex); + parse(eventData.toString()); + buffer = buffer.getBuffer(separatorIndex + separator.length(), buffer.length()); + } + } + + private void parse(String input) { + String[] parts = input.split("\n"); + + for (String part : parts) { + LOG.debug("got line `{}`", part); + String line = removeTrailingNewline(removeLeadingBom(part)); + + if (line.startsWith(":")) { + eventConsumer.accept(new CommentEvent(line.substring(1).trim())); + } else if (line.contains(":")) { + List lineParts = List.of(line.split(":", 2)); + if (lineParts.size() == 2) { + handleFieldValue(lineParts.get(0), stripLeadingSpaceIfPresent(lineParts.get(1))); + } + } else { + handleFieldValue(line, ""); + } + } + + LOG.debug( + "broadcasting new event named {} lastEventId is {}", + currentEventName, + lastEventId + ); + + if (!dataBuffer.isEmpty()) { + // Remove trailing newline + dataBuffer.setLength(dataBuffer.length() - 1); + eventConsumer.accept(new DataEvent(currentEventName, dataBuffer.toString(), lastEventId)); + } + + // reset + dataBuffer.setLength(0); + currentEventName = DEFAULT_EVENT_NAME; + } +} diff --git a/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/ClientBuilderTest.java b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/ClientBuilderTest.java new file mode 100644 index 00000000..68a8399d --- /dev/null +++ b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/ClientBuilderTest.java @@ -0,0 +1,62 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.Client; +import io.a2a.client.config.ClientConfig; +import io.a2a.client.transport.jsonrpc.JSONRPCTransport; +import io.a2a.client.transport.jsonrpc.JSONRPCTransportConfigBuilder; +import io.a2a.spec.A2AClientException; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentInterface; +import io.a2a.spec.AgentSkill; +import io.a2a.spec.TransportProtocol; +import io.vertx.core.Vertx; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; + +/** + * The purpose of this one is to make sure that the Vertx http implementation can be integrated into + * the Client builder when creating a new instance of the Client. + */ +public class ClientBuilderTest { + + private final AgentCard card = new AgentCard.Builder() + .name("Hello World Agent") + .description("Just a hello world agent") + .url("http://localhost:9999") + .version("1.0.0") + .documentationUrl("http://example.com/docs") + .capabilities(new AgentCapabilities.Builder() + .streaming(true) + .pushNotifications(true) + .stateTransitionHistory(true) + .build()) + .defaultInputModes(Collections.singletonList("text")) + .defaultOutputModes(Collections.singletonList("text")) + .skills(Collections.singletonList(new AgentSkill.Builder() + .id("hello_world") + .name("Returns hello world") + .description("just returns hello world") + .tags(Collections.singletonList("hello world")) + .examples(List.of("hi", "hello world")) + .build())) + .protocolVersion("0.3.0") + .additionalInterfaces(List.of( + new AgentInterface(TransportProtocol.JSONRPC.asString(), "http://localhost:9999"))) + .build(); + + @Test + public void shouldCreateJSONRPCClient() throws A2AClientException { + Client client = Client + .builder(card) + .clientConfig(new ClientConfig.Builder().build()) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() + .httpClientBuilder(new VertxHttpClientBuilder().vertx(Vertx.vertx()))) + .build(); + + Assertions.assertNotNull(client); + } +} diff --git a/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/VertxHttpClientTest.java b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/VertxHttpClientTest.java new file mode 100644 index 00000000..6a94f13e --- /dev/null +++ b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/VertxHttpClientTest.java @@ -0,0 +1,13 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.http.HttpClientBuilder; +import io.a2a.client.http.common.AbstractHttpClientTest; +import io.vertx.core.http.HttpClientOptions; + +public class VertxHttpClientTest extends AbstractHttpClientTest { + + protected HttpClientBuilder getHttpClientBuilder() { + return new VertxHttpClientBuilder() + .options(new HttpClientOptions().setMaxChunkSize(24)); + } +} diff --git a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java index 70f9d1e5..383c4405 100644 --- a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java +++ b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java @@ -5,13 +5,15 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.server.http.HttpClientManager; import org.mockito.ArgumentCaptor; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -22,8 +24,6 @@ import jakarta.inject.Inject; import jakarta.transaction.Transactional; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; import io.a2a.server.tasks.BasePushNotificationSender; import io.a2a.server.tasks.PushNotificationConfigStore; import io.a2a.spec.PushNotificationConfig; @@ -41,18 +41,18 @@ public class JpaPushNotificationConfigStoreTest { private BasePushNotificationSender notificationSender; @Mock - private A2AHttpClient mockHttpClient; + private HttpClientManager clientManager; @Mock - private A2AHttpClient.PostBuilder mockPostBuilder; + private HttpClient.PostRequestBuilder mockPostBuilder; @Mock - private A2AHttpResponse mockHttpResponse; + private HttpResponse mockHttpResponse; @BeforeEach public void setUp() { MockitoAnnotations.openMocks(this); - notificationSender = new BasePushNotificationSender(configStore, mockHttpClient); + notificationSender = new BasePushNotificationSender(configStore, clientManager); } @Test @@ -232,21 +232,22 @@ public void testSendNotificationSuccess() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); configStore.setInfo(taskId, config); + HttpClient mockHttpClient = mock(HttpClient.class); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); // Verify HTTP client was called ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -263,11 +264,13 @@ public void testSendNotificationWithToken() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", "unique_token"); configStore.setInfo(taskId, config); + HttpClient mockHttpClient = mock(HttpClient.class); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); @@ -279,10 +282,9 @@ public void testSendNotificationWithToken() throws Exception { // For now, just verify basic HTTP client interaction ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -299,7 +301,7 @@ public void testSendNotificationNoConfig() throws Exception { notificationSender.sendNotification(task); // Verify HTTP client was never called - verify(mockHttpClient, never()).createPost(); + verify(clientManager, never()).getOrCreate(any()); } @Test diff --git a/http-client/pom.xml b/http-client/pom.xml index 4e138b09..e8c4541e 100644 --- a/http-client/pom.xml +++ b/http-client/pom.xml @@ -29,8 +29,8 @@ - org.mock-server - mockserver-netty + org.wiremock + wiremock test diff --git a/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java b/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java index 5d94686b..d938bb93 100644 --- a/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java +++ b/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java @@ -2,10 +2,10 @@ import static io.a2a.util.Utils.unmarshalFrom; -import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; +import java.util.concurrent.ExecutionException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -15,63 +15,77 @@ import org.jspecify.annotations.Nullable; public class A2ACardResolver { - private final A2AHttpClient httpClient; - private final String url; + private final HttpClient httpClient; private final @Nullable Map authHeaders; - + private final String agentCardPath; private static final String DEFAULT_AGENT_CARD_PATH = "/.well-known/agent-card.json"; private static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; /** * Get the agent card for an A2A agent. - * The {@code JdkA2AHttpClient} will be used to fetch the agent card. + * The {@code HttpClient} will be used to fetch the agent card. * * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @throws A2AClientError if the URL for the agent is invalid */ public A2ACardResolver(String baseUrl) throws A2AClientError { - this(new JdkA2AHttpClient(), baseUrl, null, null); + this.httpClient = HttpClient.createHttpClient(baseUrl); + this.authHeaders = null; + + try { + String agentCardPath = new URI(baseUrl).getPath(); + + if (agentCardPath.endsWith("/")) { + agentCardPath = agentCardPath.substring(0, agentCardPath.length() - 1); + } + + if (agentCardPath.isEmpty()) { + this.agentCardPath = DEFAULT_AGENT_CARD_PATH; + } else if (agentCardPath.endsWith(DEFAULT_AGENT_CARD_PATH)) { + this.agentCardPath = agentCardPath; + } else { + this.agentCardPath = agentCardPath + DEFAULT_AGENT_CARD_PATH; + } + } catch (URISyntaxException e) { + throw new A2AClientError("Invalid agent URL", e); + } } /** - /**Get the agent card for an A2A agent. - * * @param httpClient the http client to use - * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl) throws A2AClientError { - this(httpClient, baseUrl, null, null); + A2ACardResolver(HttpClient httpClient) throws A2AClientError { + this(httpClient, null, null); } /** * @param httpClient the http client to use - * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent-card.json" * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath) throws A2AClientError { - this(httpClient, baseUrl, agentCardPath, null); + public A2ACardResolver(HttpClient httpClient, String agentCardPath) throws A2AClientError { + this(httpClient, agentCardPath, null); } /** * @param httpClient the http client to use - * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent-card.json" * @param authHeaders the HTTP authentication headers to use. May be {@code null} * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, @Nullable String agentCardPath, - @Nullable Map authHeaders) throws A2AClientError { + public A2ACardResolver(HttpClient httpClient, @Nullable String agentCardPath, + @Nullable Map authHeaders) throws A2AClientError { this.httpClient = httpClient; - String effectiveAgentCardPath = agentCardPath == null || agentCardPath.isEmpty() ? DEFAULT_AGENT_CARD_PATH : agentCardPath; - try { - this.url = new URI(baseUrl).resolve(effectiveAgentCardPath).toString(); - } catch (URISyntaxException e) { - throw new A2AClientError("Invalid agent URL", e); + if (agentCardPath == null || agentCardPath.isEmpty()) { + this.agentCardPath = DEFAULT_AGENT_CARD_PATH; + } else if (agentCardPath.endsWith(DEFAULT_AGENT_CARD_PATH)) { + this.agentCardPath = agentCardPath; + } else { + this.agentCardPath = agentCardPath + DEFAULT_AGENT_CARD_PATH; } this.authHeaders = authHeaders; } @@ -84,8 +98,7 @@ public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, @Nullable Strin * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { - A2AHttpClient.GetBuilder builder = httpClient.createGet() - .url(url) + HttpClient.GetRequestBuilder builder = httpClient.get(agentCardPath) .addHeader("Content-Type", "application/json"); if (authHeaders != null) { @@ -95,13 +108,14 @@ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { } String body; + try { - A2AHttpResponse response = builder.get(); + HttpResponse response = builder.send().get(); if (!response.success()) { - throw new A2AClientError("Failed to obtain agent card: " + response.status()); + throw new A2AClientError("Failed to obtain agent card: " + response.statusCode()); } body = response.body(); - } catch (IOException | InterruptedException e) { + } catch (InterruptedException | ExecutionException e) { throw new A2AClientError("Failed to obtain agent card", e); } @@ -110,8 +124,5 @@ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { } catch (JsonProcessingException e) { throw new A2AClientJSONError("Could not unmarshal agent card response", e); } - } - - } diff --git a/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java deleted file mode 100644 index 52c252a8..00000000 --- a/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java +++ /dev/null @@ -1,42 +0,0 @@ -package io.a2a.client.http; - -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -public interface A2AHttpClient { - - GetBuilder createGet(); - - PostBuilder createPost(); - - DeleteBuilder createDelete(); - - interface Builder> { - T url(String s); - T addHeaders(Map headers); - T addHeader(String name, String value); - } - - interface GetBuilder extends Builder { - A2AHttpResponse get() throws IOException, InterruptedException; - CompletableFuture getAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException; - } - - interface PostBuilder extends Builder { - PostBuilder body(String body); - A2AHttpResponse post() throws IOException, InterruptedException; - CompletableFuture postAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException; - } - - interface DeleteBuilder extends Builder { - A2AHttpResponse delete() throws IOException, InterruptedException; - } -} diff --git a/http-client/src/main/java/io/a2a/client/http/A2AHttpResponse.java b/http-client/src/main/java/io/a2a/client/http/A2AHttpResponse.java deleted file mode 100644 index 171fceeb..00000000 --- a/http-client/src/main/java/io/a2a/client/http/A2AHttpResponse.java +++ /dev/null @@ -1,9 +0,0 @@ -package io.a2a.client.http; - -public interface A2AHttpResponse { - int status(); - - boolean success(); - - String body(); -} diff --git a/http-client/src/main/java/io/a2a/client/http/HttpClient.java b/http-client/src/main/java/io/a2a/client/http/HttpClient.java new file mode 100644 index 00000000..1cb14fde --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/HttpClient.java @@ -0,0 +1,45 @@ +package io.a2a.client.http; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public interface HttpClient { + + static HttpClient createHttpClient(String baseUrl) { + return HttpClientBuilder.DEFAULT_FACTORY.create(baseUrl); + } + + GetRequestBuilder get(String path); + + PostRequestBuilder post(String path); + + DeleteRequestBuilder delete(String path); + + interface RequestBuilder> { + CompletableFuture send(); + + T addHeader(String name, String value); + + T addHeaders(Map headers); + } + + interface GetRequestBuilder extends RequestBuilder { + + } + + interface PostRequestBuilder extends RequestBuilder { + PostRequestBuilder body(String body); + + default PostRequestBuilder asSSE() { + return addHeader("Accept", "text/event-stream"); + } + + default CompletableFuture send(String body) { + return this.body(body).send(); + } + } + + interface DeleteRequestBuilder extends RequestBuilder { + + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/HttpClientBuilder.java b/http-client/src/main/java/io/a2a/client/http/HttpClientBuilder.java new file mode 100644 index 00000000..1e894a9d --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/HttpClientBuilder.java @@ -0,0 +1,10 @@ +package io.a2a.client.http; + +import io.a2a.client.http.jdk.JdkHttpClientBuilder; + +public interface HttpClientBuilder { + + HttpClientBuilder DEFAULT_FACTORY = new JdkHttpClientBuilder(); + + HttpClient create(String url); +} diff --git a/http-client/src/main/java/io/a2a/client/http/HttpResponse.java b/http-client/src/main/java/io/a2a/client/http/HttpResponse.java new file mode 100644 index 00000000..3e2f35f6 --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/HttpResponse.java @@ -0,0 +1,17 @@ +package io.a2a.client.http; + +import io.a2a.client.http.sse.Event; + +import java.util.function.Consumer; + +public interface HttpResponse { + int statusCode(); + + default boolean success() { + return statusCode() >= 200 && statusCode() < 300; + } + + String body(); + + void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer); +} diff --git a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java deleted file mode 100644 index 9b800374..00000000 --- a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java +++ /dev/null @@ -1,311 +0,0 @@ -package io.a2a.client.http; - -import static java.net.HttpURLConnection.HTTP_FORBIDDEN; -import static java.net.HttpURLConnection.HTTP_MULT_CHOICE; -import static java.net.HttpURLConnection.HTTP_OK; -import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; - -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.net.http.HttpResponse.BodyHandler; -import java.net.http.HttpResponse.BodyHandlers; -import java.net.http.HttpResponse.BodySubscribers; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.function.Consumer; -import org.jspecify.annotations.Nullable; - -import io.a2a.common.A2AErrorMessages; - -public class JdkA2AHttpClient implements A2AHttpClient { - - private final HttpClient httpClient; - - public JdkA2AHttpClient() { - httpClient = HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_2) - .followRedirects(HttpClient.Redirect.NORMAL) - .build(); - } - - @Override - public GetBuilder createGet() { - return new JdkGetBuilder(); - } - - @Override - public PostBuilder createPost() { - return new JdkPostBuilder(); - } - - @Override - public DeleteBuilder createDelete() { - return new JdkDeleteBuilder(); - } - - private abstract class JdkBuilder> implements Builder { - private String url = ""; - private Map headers = new HashMap<>(); - - @Override - public T url(String url) { - this.url = url; - return self(); - } - - @Override - public T addHeader(String name, String value) { - headers.put(name, value); - return self(); - } - - @Override - public T addHeaders(Map headers) { - if(headers != null && ! headers.isEmpty()) { - for (Map.Entry entry : headers.entrySet()) { - addHeader(entry.getKey(), entry.getValue()); - } - } - return self(); - } - - @SuppressWarnings("unchecked") - T self() { - return (T) this; - } - - protected HttpRequest.Builder createRequestBuilder() throws IOException { - HttpRequest.Builder builder = HttpRequest.newBuilder() - .uri(URI.create(url)); - for (Map.Entry headerEntry : headers.entrySet()) { - builder.header(headerEntry.getKey(), headerEntry.getValue()); - } - return builder; - } - - protected CompletableFuture asyncRequest( - HttpRequest request, - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable - ) { - Flow.Subscriber subscriber = new Flow.Subscriber() { - private Flow.@Nullable Subscription subscription; - private volatile boolean errorRaised = false; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - this.subscription.request(1); - } - - @Override - public void onNext(String item) { - // SSE messages sometimes start with "data:". Strip that off - if (item != null && item.startsWith("data:")) { - item = item.substring(5).trim(); - if (!item.isEmpty()) { - messageConsumer.accept(item); - } - } - if (subscription != null) { - subscription.request(1); - } - } - - @Override - public void onError(Throwable throwable) { - if (!errorRaised) { - errorRaised = true; - errorConsumer.accept(throwable); - } - if (subscription != null) { - subscription.cancel(); - } - } - - @Override - public void onComplete() { - if (!errorRaised) { - completeRunnable.run(); - } - if (subscription != null) { - subscription.cancel(); - } - } - }; - - // Create a custom body handler that checks status before processing body - BodyHandler bodyHandler = responseInfo -> { - // Check for authentication/authorization errors only - if (responseInfo.statusCode() == HTTP_UNAUTHORIZED || responseInfo.statusCode() == HTTP_FORBIDDEN) { - final String errorMessage; - if (responseInfo.statusCode() == HTTP_UNAUTHORIZED) { - errorMessage = A2AErrorMessages.AUTHENTICATION_FAILED; - } else { - errorMessage = A2AErrorMessages.AUTHORIZATION_FAILED; - } - // Return a body subscriber that immediately signals error - return BodySubscribers.fromSubscriber(new Flow.Subscriber>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscriber.onError(new IOException(errorMessage)); - } - - @Override - public void onNext(List item) { - // Should not be called - } - - @Override - public void onError(Throwable throwable) { - // Should not be called - } - - @Override - public void onComplete() { - // Should not be called - } - }); - } else { - // For all other status codes (including other errors), proceed with normal line subscriber - return BodyHandlers.fromLineSubscriber(subscriber).apply(responseInfo); - } - }; - - // Send the response async, and let the subscriber handle the lines. - return httpClient.sendAsync(request, bodyHandler) - .thenAccept(response -> { - // Handle non-authentication/non-authorization errors here - if (!isSuccessStatus(response.statusCode()) && - response.statusCode() != HTTP_UNAUTHORIZED && - response.statusCode() != HTTP_FORBIDDEN) { - subscriber.onError(new IOException("Request failed with status " + response.statusCode() + ":" + response.body())); - } - }); - } - } - - private class JdkGetBuilder extends JdkBuilder implements A2AHttpClient.GetBuilder { - - private HttpRequest.Builder createRequestBuilder(boolean SSE) throws IOException { - HttpRequest.Builder builder = super.createRequestBuilder().GET(); - if (SSE) { - builder.header("Accept", "text/event-stream"); - } - return builder; - } - - @Override - public A2AHttpResponse get() throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(false) - .build(); - HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); - return new JdkHttpResponse(response); - } - - @Override - public CompletableFuture getAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(true) - .build(); - return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); - } - - } - - private class JdkDeleteBuilder extends JdkBuilder implements A2AHttpClient.DeleteBuilder { - - @Override - public A2AHttpResponse delete() throws IOException, InterruptedException { - HttpRequest request = super.createRequestBuilder().DELETE().build(); - HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); - return new JdkHttpResponse(response); - } - - } - - private class JdkPostBuilder extends JdkBuilder implements A2AHttpClient.PostBuilder { - String body = ""; - - @Override - public PostBuilder body(String body) { - this.body = body; - return self(); - } - - private HttpRequest.Builder createRequestBuilder(boolean SSE) throws IOException { - HttpRequest.Builder builder = super.createRequestBuilder() - .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)); - if (SSE) { - builder.header("Accept", "text/event-stream"); - } - return builder; - } - - @Override - public A2AHttpResponse post() throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(false) - .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)) - .build(); - HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); - - if (response.statusCode() == HTTP_UNAUTHORIZED) { - throw new IOException(A2AErrorMessages.AUTHENTICATION_FAILED); - } else if (response.statusCode() == HTTP_FORBIDDEN) { - throw new IOException(A2AErrorMessages.AUTHORIZATION_FAILED); - } - - return new JdkHttpResponse(response); - } - - @Override - public CompletableFuture postAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(true) - .build(); - return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); - } - } - - private record JdkHttpResponse(HttpResponse response) implements A2AHttpResponse { - - @Override - public int status() { - return response.statusCode(); - } - - @Override - public boolean success() {// Send the request and get the response - return success(response); - } - - static boolean success(HttpResponse response) { - return response.statusCode() >= HTTP_OK && response.statusCode() < HTTP_MULT_CHOICE; - } - - @Override - public String body() { - return response.body(); - } - } - - private static boolean isSuccessStatus(int statusCode) { - return statusCode >= HTTP_OK && statusCode < HTTP_MULT_CHOICE; - } -} diff --git a/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java new file mode 100644 index 00000000..83e31208 --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java @@ -0,0 +1,260 @@ +package io.a2a.client.http.jdk; + +import static java.net.HttpURLConnection.HTTP_FORBIDDEN; +import static java.net.HttpURLConnection.HTTP_MULT_CHOICE; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.jdk.sse.SSEHandler; +import io.a2a.client.http.sse.Event; + +import java.io.IOException; +import java.net.*; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Flow; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.a2a.common.A2AErrorMessages; + +class JdkHttpClient implements HttpClient { + + private final java.net.http.HttpClient httpClient; + private final String baseUrl; + + JdkHttpClient(String baseUrl) { + this.httpClient = java.net.http.HttpClient.newBuilder() + .version(java.net.http.HttpClient.Version.HTTP_2) + .followRedirects(java.net.http.HttpClient.Redirect.NORMAL) + .build(); + + URL targetUrl = buildUrl(baseUrl); + this.baseUrl = targetUrl.getProtocol() + "://" + targetUrl.getAuthority(); + } + + String getBaseUrl() { + return baseUrl; + } + + private static final URLStreamHandler URL_HANDLER = new URLStreamHandler() { + protected URLConnection openConnection(URL u) { + return null; + } + }; + + private static URL buildUrl(String uri) { + try { + return new URL(null, uri, URL_HANDLER); + } catch (MalformedURLException var2) { + throw new IllegalArgumentException("URI [" + uri + "] is not valid"); + } + } + + @Override + public GetRequestBuilder get(String path) { + return new JdkGetRequestBuilder(path); + } + + @Override + public PostRequestBuilder post(String path) { + return new JdkPostRequestBuilder(path); + } + + @Override + public DeleteRequestBuilder delete(String path) { + return new JdkDeleteBuilder(path); + } + + private abstract class JdkRequestBuilder> implements RequestBuilder { + private final String path; + protected final Map headers = new HashMap<>(); + + public JdkRequestBuilder(String path) { + this.path = path; + } + + @Override + public T addHeader(String name, String value) { + headers.put(name, value); + return self(); + } + + @Override + public T addHeaders(Map headers) { + if (headers != null && !headers.isEmpty()) { + for (Map.Entry entry : headers.entrySet()) { + addHeader(entry.getKey(), entry.getValue()); + } + } + return self(); + } + + @SuppressWarnings("unchecked") + T self() { + return (T) this; + } + + protected HttpRequest.Builder createRequestBuilder() { + HttpRequest.Builder builder = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + path)); + for (Map.Entry headerEntry : headers.entrySet()) { + builder.header(headerEntry.getKey(), headerEntry.getValue()); + } + return builder; + } + } + + private class JdkGetRequestBuilder extends JdkRequestBuilder implements GetRequestBuilder { + + public JdkGetRequestBuilder(String path) { + super(path); + } + + @Override + public CompletableFuture send() { + HttpRequest request = super.createRequestBuilder().GET().build(); + return httpClient + .sendAsync(request, BodyHandlers.ofString(StandardCharsets.UTF_8)) + .thenCompose(RESPONSE_MAPPER); + } + } + + private class JdkDeleteBuilder extends JdkRequestBuilder implements DeleteRequestBuilder { + + public JdkDeleteBuilder(String path) { + super(path); + } + + @Override + public CompletableFuture send() { + HttpRequest request = super.createRequestBuilder().DELETE().build(); + return httpClient + .sendAsync(request, BodyHandlers.ofString(StandardCharsets.UTF_8)) + .thenCompose(RESPONSE_MAPPER); + } + } + + private class JdkPostRequestBuilder extends JdkRequestBuilder implements PostRequestBuilder { + String body = ""; + + public JdkPostRequestBuilder(String path) { + super(path); + } + + @Override + public PostRequestBuilder body(String body) { + this.body = body; + return this; + } + + @Override + public CompletableFuture send() { + final HttpRequest request = super.createRequestBuilder() + .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)) + .build(); + + final BodyHandler bodyHandler; + + final String contentTypeHeader = this.headers.get("Accept"); + if ("text/event-stream".equalsIgnoreCase(contentTypeHeader)) { + bodyHandler = BodyHandlers.ofPublisher(); + } else { + bodyHandler = BodyHandlers.ofString(StandardCharsets.UTF_8); + } + + return httpClient.sendAsync(request, bodyHandler).thenCompose(RESPONSE_MAPPER); + } + } + + private final static Function, CompletionStage> RESPONSE_MAPPER = response -> { + if (response.statusCode() == HTTP_UNAUTHORIZED) { + return CompletableFuture.failedStage(new IOException(A2AErrorMessages.AUTHENTICATION_FAILED)); + } else if (response.statusCode() == HTTP_FORBIDDEN) { + return CompletableFuture.failedStage(new IOException(A2AErrorMessages.AUTHORIZATION_FAILED)); + } + + return CompletableFuture.completedFuture(new JdkHttpResponse(response)); + }; + + private record JdkHttpResponse(java.net.http.HttpResponse response) implements HttpResponse { + + @Override + public int statusCode() { + return response.statusCode(); + } + + static boolean success(java.net.http.HttpResponse response) { + return response.statusCode() >= HTTP_OK && response.statusCode() < HTTP_MULT_CHOICE; + } + + @Override + public String body() { + if (response.body() instanceof String) { + return (String) response.body(); + } + + throw new IllegalStateException(); + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + if (success()) { + Optional contentTypeOpt = response.headers().firstValue("Content-Type"); + + if (contentTypeOpt.isPresent() && contentTypeOpt.get().equalsIgnoreCase("text/event-stream")) { + Flow.Publisher> publisher = (Flow.Publisher>) response.body(); + + SSEHandler sseHandler = new SSEHandler(); + sseHandler.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(Event item) { + eventConsumer.accept(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + errorConsumer.accept(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + publisher.subscribe(java.net.http.HttpResponse.BodySubscribers.fromLineSubscriber(sseHandler)); + } else { + errorConsumer.accept(new IOException("Response is not an event-stream response: Content-Type[" + contentTypeOpt.orElse("unknown") + "]")); + } + } else { + errorConsumer.accept(new IOException("Request failed: status[" + response.statusCode() + "]")); + } + } + } + + private static boolean isSuccessStatus(int statusCode) { + return statusCode >= HTTP_OK && statusCode < HTTP_MULT_CHOICE; + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClientBuilder.java b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClientBuilder.java new file mode 100644 index 00000000..21f50ade --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClientBuilder.java @@ -0,0 +1,12 @@ +package io.a2a.client.http.jdk; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; + +public class JdkHttpClientBuilder implements HttpClientBuilder { + + @Override + public HttpClient create(String url) { + return new JdkHttpClient(url); + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/jdk/sse/SSEHandler.java b/http-client/src/main/java/io/a2a/client/http/jdk/sse/SSEHandler.java new file mode 100644 index 00000000..b7975bae --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/jdk/sse/SSEHandler.java @@ -0,0 +1,120 @@ +package io.a2a.client.http.jdk.sse; + +import io.a2a.client.http.sse.CommentEvent; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.concurrent.Flow; +import java.util.concurrent.SubmissionPublisher; + +public class SSEHandler extends SubmissionPublisher + implements Flow.Processor { + + public static final String EVENT_STREAM_MEDIA_TYPE = "text/event-stream"; + + private static final Logger LOG = LoggerFactory.getLogger(SSEHandler.class); + + private static final String UTF8_BOM = "\uFEFF"; + + private static final String DEFAULT_EVENT_NAME = "message"; + + private Flow.Subscription subscription; + + private String currentEventName = DEFAULT_EVENT_NAME; + private final StringBuilder dataBuffer = new StringBuilder(); + + private String lastEventId = ""; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String input) { + LOG.debug("got line `{}`", input); + String line = removeTrailingNewline(removeLeadingBom(input)); + + if (line.startsWith(":")) { + submit(new CommentEvent(line.substring(1).trim())); + } else if (line.isBlank()) { + LOG.debug( + "broadcasting new event named {} lastEventId is {}", + currentEventName, + lastEventId + ); + + String dataString = dataBuffer.toString(); + if (!dataString.isEmpty()) { + submit(new DataEvent(currentEventName, dataBuffer.toString(), lastEventId)); + } + //reset things + dataBuffer.setLength(0); + currentEventName = DEFAULT_EVENT_NAME; + } else if (line.contains(":")) { + List lineParts = List.of(line.split(":", 2)); + if (lineParts.size() == 2) { + handleFieldValue(lineParts.get(0), stripLeadingSpaceIfPresent(lineParts.get(1))); + } + } else { + handleFieldValue(line, ""); + } + subscription.request(1); + } + + private void handleFieldValue(String fieldName, String value) { + switch (fieldName) { + case "event": + currentEventName = value; + break; + case "data": + dataBuffer.append(value).append("\n"); + break; + case "id": + if (!value.contains("\0")) { + lastEventId = value; + } + break; + case "retry": + // ignored + break; + } + } + + @Override + public void onError(Throwable throwable) { + LOG.debug("Error in SSE handler {}", throwable.getMessage()); + closeExceptionally(throwable); + } + + @Override + public void onComplete() { + LOG.debug("SSE handler complete"); + close(); + } + + private String stripLeadingSpaceIfPresent(String field) { + if (field.charAt(0) == ' ') { + return field.substring(1); + } + return field; + } + + private String removeLeadingBom(String input) { + if (input.startsWith(UTF8_BOM)) { + return input.substring(UTF8_BOM.length()); + } + return input; + } + + private String removeTrailingNewline(String input) { + if (input.endsWith("\n")) { + return input.substring(0, input.length() - 1); + } + return input; + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/sse/CommentEvent.java b/http-client/src/main/java/io/a2a/client/http/sse/CommentEvent.java new file mode 100644 index 00000000..0a0b3e68 --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/sse/CommentEvent.java @@ -0,0 +1,54 @@ +package io.a2a.client.http.sse; + +import java.util.Objects; +import java.util.StringJoiner; + +/** + * Represents an SSE Comment + * This is a line starting with a colon (:) + */ +public class CommentEvent extends Event { + + private final String comment; + + @Override + Type getType() { + return Type.COMMENT; + } + + public CommentEvent(String comment) { + this.comment = comment; + } + + /** + * + * @return the contents of the last line starting with `:` (omitting the colon) + */ + public String getComment() { + return comment; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CommentEvent that = (CommentEvent) o; + return Objects.equals(comment, that.comment); + } + + @Override + public int hashCode() { + return Objects.hash(comment); + } + + @Override + public String toString() { + return new StringJoiner(", ", CommentEvent.class.getSimpleName() + "[", "]") + .add("comment='" + comment + "'") + .toString(); + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/sse/DataEvent.java b/http-client/src/main/java/io/a2a/client/http/sse/DataEvent.java new file mode 100644 index 00000000..daae9c5c --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/sse/DataEvent.java @@ -0,0 +1,81 @@ +package io.a2a.client.http.sse; + +import java.util.Objects; +import java.util.StringJoiner; + +/** + * Represents an SSE DataEvent + * It contains three fields: event name, data, and lastEventId + */ +public class DataEvent extends Event { + + private final String eventName; + private final String data; + private final String lastEventId; + + public DataEvent(String eventName, String data, String lastEventId) { + this.eventName = eventName; + this.data = data; + this.lastEventId = lastEventId; + } + + @Override + Type getType() { + return Type.DATA; + } + + /** + * + * @return the content of the last line starting with `event:` + */ + public String getEventName() { + return eventName; + } + + /** + * + * @return the accumulated contents of data buffers from lines starting with `data:` + */ + public String getData() { + return data; + } + + /** + * + * @return the last event id sent in a line starting with `id:` + */ + public String getLastEventId() { + return lastEventId; + } + + @Override + public String toString() { + return new StringJoiner(", ", DataEvent.class.getSimpleName() + "[", "]") + .add("eventName='" + eventName + "'") + .add("data='" + data + "'") + .add("lastEventId='" + lastEventId + "'") + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DataEvent event = (DataEvent) o; + return ( + Objects.equals(getType(), event.getType()) && + Objects.equals(eventName, event.eventName) && + Objects.equals(data, event.data) && + Objects.equals(lastEventId, event.lastEventId) + ); + } + + @Override + public int hashCode() { + return Objects.hash(getType(), eventName, data, lastEventId); + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/sse/Event.java b/http-client/src/main/java/io/a2a/client/http/sse/Event.java new file mode 100644 index 00000000..66920e9c --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/sse/Event.java @@ -0,0 +1,11 @@ +package io.a2a.client.http.sse; + +public abstract class Event { + + enum Type { + COMMENT, + DATA, + } + + abstract Type getType(); +} diff --git a/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java b/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java index 99d26ada..6f921c8d 100644 --- a/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java +++ b/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java @@ -1,20 +1,20 @@ package io.a2a.client.http; +import static com.github.tomakehurst.wiremock.client.WireMock.*; import static io.a2a.util.Utils.OBJECT_MAPPER; import static io.a2a.util.Utils.unmarshalFrom; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; -import java.io.IOException; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; -import java.util.Map; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class A2ACardResolverTest { @@ -22,54 +22,90 @@ public class A2ACardResolverTest { private static final String AGENT_CARD_PATH = "/.well-known/agent-card.json"; private static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; + private WireMockServer server; + + @BeforeEach + public void setUp() { + server = new WireMockServer(WireMockConfiguration.options().dynamicPort()); + server.start(); + + configureFor("localhost", server.port()); + } + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(); + } + } + @Test public void testConstructorStripsSlashes() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.body = JsonMessages.AGENT_CARD; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); + + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); + + givenThat(get(urlPathEqualTo("/subpath" + AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + A2ACardResolver resolver = new A2ACardResolver(client); AgentCard card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - resolver = new A2ACardResolver(client, "http://example.com"); + resolver = new A2ACardResolver(client, AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl with trailing slash, agentCardParth with leading slash - resolver = new A2ACardResolver(client, "http://example.com/", AGENT_CARD_PATH); + + resolver = new A2ACardResolver("http://localhost:" + server.port()); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl without trailing slash, agentCardPath with leading slash - resolver = new A2ACardResolver(client, "http://example.com", AGENT_CARD_PATH); + resolver = new A2ACardResolver("http://localhost:" + server.port() + AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl with trailing slash, agentCardPath without leading slash - resolver = new A2ACardResolver(client, "http://example.com/", AGENT_CARD_PATH.substring(1)); + // baseUrl with trailing slash + resolver = new A2ACardResolver("http://localhost:" + server.port() + "/"); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl without trailing slash, agentCardPath without leading slash - resolver = new A2ACardResolver(client, "http://example.com", AGENT_CARD_PATH.substring(1)); + // Sub-path + // baseUrl with trailing slash + resolver = new A2ACardResolver("http://localhost:" + server.port() + "/subpath"); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo("/subpath" + AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } @Test public void testGetAgentCardSuccess() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.body = JsonMessages.AGENT_CARD; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); + + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + A2ACardResolver resolver = new A2ACardResolver(client); AgentCard card = resolver.getAgentCard(); AgentCard expectedCard = unmarshalFrom(JsonMessages.AGENT_CARD, AGENT_CARD_TYPE_REFERENCE); @@ -77,14 +113,19 @@ public void testGetAgentCardSuccess() throws Exception { String requestCardString = OBJECT_MAPPER.writeValueAsString(card); assertEquals(expected, requestCardString); + + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } @Test public void testGetAgentCardJsonDecodeError() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.body = "X" + JsonMessages.AGENT_CARD; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", "X" + JsonMessages.AGENT_CARD))); + + A2ACardResolver resolver = new A2ACardResolver(client); boolean success = false; try { @@ -93,15 +134,20 @@ public void testGetAgentCardJsonDecodeError() throws Exception { } catch (A2AClientJSONError expected) { } assertFalse(success); + + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } @Test public void testGetAgentCardRequestError() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.status = 503; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); + + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(status(503))); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + A2ACardResolver resolver = new A2ACardResolver(client); String msg = null; try { @@ -110,71 +156,9 @@ public void testGetAgentCardRequestError() throws Exception { msg = expected.getMessage(); } assertTrue(msg.contains("503")); - } - - private static class TestHttpClient implements A2AHttpClient { - int status = 200; - String body; - String url; - - @Override - public GetBuilder createGet() { - return new TestGetBuilder(); - } - - @Override - public PostBuilder createPost() { - return null; - } - @Override - public DeleteBuilder createDelete() { - return null; - } - - class TestGetBuilder implements A2AHttpClient.GetBuilder { - - @Override - public A2AHttpResponse get() throws IOException, InterruptedException { - return new A2AHttpResponse() { - @Override - public int status() { - return status; - } - - @Override - public boolean success() { - return status == 200; - } - - @Override - public String body() { - return body; - } - }; - } - - @Override - public CompletableFuture getAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; - } - - @Override - public GetBuilder url(String s) { - url = s; - return this; - } - - @Override - public GetBuilder addHeader(String name, String value) { - return this; - } - - @Override - public GetBuilder addHeaders(Map headers) { - return this; - } - } + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } } diff --git a/http-client/src/test/java/io/a2a/client/http/jdk/JdkHttpClientTest.java b/http-client/src/test/java/io/a2a/client/http/jdk/JdkHttpClientTest.java new file mode 100644 index 00000000..7ca4e3d4 --- /dev/null +++ b/http-client/src/test/java/io/a2a/client/http/jdk/JdkHttpClientTest.java @@ -0,0 +1,31 @@ +package io.a2a.client.http.jdk; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class JdkHttpClientTest { + + @Test + public void testBaseUrlNormalization() { + String baseUrl = "http://localhost:8080"; + + JdkHttpClient client = new JdkHttpClient(baseUrl); + Assertions.assertEquals(baseUrl, client.getBaseUrl()); + + baseUrl = "http://localhost"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("http://localhost", client.getBaseUrl()); + + baseUrl = "https://localhost"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("https://localhost", client.getBaseUrl()); + + baseUrl = "https://localhost:443"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("https://localhost:443", client.getBaseUrl()); + + baseUrl = "https://localhost:80/test"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("https://localhost:80", client.getBaseUrl()); + } +} \ No newline at end of file diff --git a/pom.xml b/pom.xml index 913b0c8c..74dc84f3 100644 --- a/pom.xml +++ b/pom.xml @@ -55,6 +55,7 @@ 3.1.0 5.13.4 5.17.0 + 3.13.1 5.15.0 1.1.1 1.7.1 @@ -247,6 +248,12 @@ ${mockserver.version} test + + org.wiremock + wiremock + ${wiremock.version} + test + ch.qos.logback logback-classic @@ -265,6 +272,18 @@ test ${project.version} + + ${project.groupId} + a2a-java-sdk-tests-client-common + ${project.version} + + + ${project.groupId} + a2a-java-sdk-tests-client-common + test-jar + test + ${project.version} + ${project.groupId} a2a-java-sdk-server-common @@ -437,6 +456,7 @@ extras/task-store-database-jpa extras/push-notification-config-store-database-jpa extras/queue-manager-replicated + extras/http-client-vertx http-client reference/common reference/grpc @@ -447,6 +467,7 @@ spec-grpc tck tests/server-common + tests/client-common transport/jsonrpc transport/grpc transport/rest diff --git a/server-common/src/main/java/io/a2a/server/http/HttpClientManager.java b/server-common/src/main/java/io/a2a/server/http/HttpClientManager.java new file mode 100644 index 00000000..fd02caf1 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/http/HttpClientManager.java @@ -0,0 +1,59 @@ +package io.a2a.server.http; + +import io.a2a.client.http.HttpClient; +import io.a2a.util.Assert; +import jakarta.enterprise.context.ApplicationScoped; + +import java.net.URI; +import java.net.URL; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +@ApplicationScoped +public class HttpClientManager { + + private final Map clients = new ConcurrentHashMap<>(); + + public HttpClient getOrCreate(String url) { + Assert.checkNotNullParam("url", url); + + try { + return clients.computeIfAbsent(Endpoint.from(URI.create(url).toURL()), new Function() { + @Override + public HttpClient apply(Endpoint edpt) { + return HttpClient.createHttpClient(url); + } + }); + } catch (Exception ex) { + throw new IllegalArgumentException("URL is malformed: [" + url + "]"); + } + } + + private static class Endpoint { + private final String host; + private final int port; + + public Endpoint(String host, int port) { + this.host = host; + this.port = port; + } + + public static Endpoint from(URL url) { + return new Endpoint(url.getHost(), url.getPort() != -1 ? url.getPort() : url.getDefaultPort()); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Endpoint endpoint = (Endpoint) o; + return port == endpoint.port && Objects.equals(host, endpoint.host); + } + + @Override + public int hashCode() { + return Objects.hash(host, port); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java index 4afaf3b4..bb304b44 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java +++ b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java @@ -1,18 +1,19 @@ package io.a2a.server.tasks; import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN; + +import io.a2a.server.http.HttpClientManager; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; -import java.io.IOException; +import java.net.URI; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import com.fasterxml.jackson.core.JsonProcessingException; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.Task; import io.a2a.util.Utils; @@ -25,18 +26,13 @@ public class BasePushNotificationSender implements PushNotificationSender { private static final Logger LOGGER = LoggerFactory.getLogger(BasePushNotificationSender.class); - private final A2AHttpClient httpClient; private final PushNotificationConfigStore configStore; + private final HttpClientManager clientManager; @Inject - public BasePushNotificationSender(PushNotificationConfigStore configStore) { - this.httpClient = new JdkA2AHttpClient(); - this.configStore = configStore; - } - - public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHttpClient httpClient) { + public BasePushNotificationSender(PushNotificationConfigStore configStore, HttpClientManager clientManager) { this.configStore = configStore; - this.httpClient = httpClient; + this.clientManager = clientManager; } @Override @@ -68,10 +64,13 @@ private CompletableFuture dispatch(Task task, PushNotificationConfig pu } private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) { - String url = pushInfo.url(); - String token = pushInfo.token(); + final String url = pushInfo.url(); + final String token = pushInfo.token(); - A2AHttpClient.PostBuilder postBuilder = httpClient.createPost(); + // Delegate to the HTTP client manager to better manage client's connection pool. + final HttpClient client = clientManager.getOrCreate(url); + final URI uri = URI.create(url); + HttpClient.PostRequestBuilder postBuilder = client.post(uri.getPath()); if (token != null && !token.isBlank()) { postBuilder.addHeader(X_A2A_NOTIFICATION_TOKEN, token); } @@ -89,10 +88,10 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) try { postBuilder - .url(url) .body(body) - .post(); - } catch (IOException | InterruptedException e) { + .send() + .get(); + } catch (ExecutionException | InterruptedException e) { LOGGER.debug("Error pushing data to " + url + ": {}", e.getMessage(), e); return false; } diff --git a/server-common/src/test/java/io/a2a/server/http/HttpClientManagerTest.java b/server-common/src/test/java/io/a2a/server/http/HttpClientManagerTest.java new file mode 100644 index 00000000..f244dfdf --- /dev/null +++ b/server-common/src/test/java/io/a2a/server/http/HttpClientManagerTest.java @@ -0,0 +1,52 @@ +package io.a2a.server.http; + +import io.a2a.client.http.HttpClient; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class HttpClientManagerTest { + + private final HttpClientManager clientManager = new HttpClientManager(); + + @Test + public void testThrowsIllegalArgument() { + Assertions.assertThrows( + IllegalArgumentException.class, + () -> clientManager.getOrCreate(null) + ); + } + + @Test + public void testValidateCacheInstance() { + HttpClient client1 = clientManager.getOrCreate("http://localhost:8000"); + HttpClient client2 = clientManager.getOrCreate("http://localhost:8000"); + HttpClient client3 = clientManager.getOrCreate("http://localhost:8001"); + HttpClient client4 = clientManager.getOrCreate("http://remote_agent:8001"); + + Assertions.assertSame(client1, client2); + Assertions.assertNotSame(client1, client3); + Assertions.assertNotSame(client1, client4); + Assertions.assertNotSame(client3, client4); + } + + @Test + public void testValidateCacheNoPort() { + HttpClient client1 = clientManager.getOrCreate("https://localhost"); + HttpClient client2 = clientManager.getOrCreate("https://localhost:443"); + HttpClient client3 = clientManager.getOrCreate("http://localhost"); + HttpClient client4 = clientManager.getOrCreate("http://localhost:80"); + + Assertions.assertSame(client1, client2); + Assertions.assertNotSame(client1, client3); + Assertions.assertSame(client3, client4); + Assertions.assertNotSame(client2, client4); + } + + @Test + public void testThrowsInvalidUrl() { + Assertions.assertThrows( + IllegalArgumentException.class, + () -> clientManager.getOrCreate("this_is_invalid") + ); + } +} diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index 9f12ee79..3dd5a97c 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -3,6 +3,10 @@ import java.io.IOException; import java.io.InputStream; import java.net.URL; + +import io.a2a.client.http.sse.Event; +import io.a2a.server.http.HttpClientManager; +import jakarta.enterprise.context.Dependent; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -13,10 +17,8 @@ import java.util.concurrent.Executors; import java.util.function.Consumer; -import jakarta.enterprise.context.Dependent; - -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.server.agentexecution.AgentExecutor; import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; @@ -42,6 +44,11 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; public class AbstractA2ARequestHandlerTest { @@ -61,6 +68,9 @@ public class AbstractA2ARequestHandlerTest { private static final String PREFERRED_TRANSPORT = "preferred-transport"; private static final String A2A_REQUESTHANDLER_TEST_PROPERTIES = "/a2a-requesthandler-test.properties"; + @Mock + private HttpClientManager clientManager; + protected AgentExecutor executor; protected TaskStore taskStore; protected RequestHandler requestHandler; @@ -73,6 +83,8 @@ public class AbstractA2ARequestHandlerTest { @BeforeEach public void init() { + MockitoAnnotations.openMocks(this); + executor = new AgentExecutor() { @Override public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { @@ -92,8 +104,10 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC taskStore = new InMemoryTaskStore(); queueManager = new InMemoryQueueManager(); httpClient = new TestHttpClient(); + + Mockito.when(clientManager.getOrCreate(any())).thenReturn(httpClient); PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); - PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient); + PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, clientManager); requestHandler = new DefaultRequestHandler(executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor); } @@ -148,75 +162,79 @@ protected interface AgentExecutorMethod { @Dependent @IfBuildProfile("test") - protected static class TestHttpClient implements A2AHttpClient { + protected static class TestHttpClient implements HttpClient { public final List tasks = Collections.synchronizedList(new ArrayList<>()); public volatile CountDownLatch latch; @Override - public GetBuilder createGet() { + public GetRequestBuilder get(String path) { return null; } @Override - public PostBuilder createPost() { - return new TestHttpClient.TestPostBuilder(); + public PostRequestBuilder post(String path) { + return new TestPostRequestBuilder(); } @Override - public DeleteBuilder createDelete() { + public DeleteRequestBuilder delete(String path) { return null; } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostRequestBuilder implements PostRequestBuilder { + private volatile String body; @Override - public PostBuilder body(String body) { + public PostRequestBuilder body(String body) { this.body = body; return this; } @Override - public A2AHttpResponse post() throws IOException, InterruptedException { - tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + public CompletableFuture send() { + CompletableFuture future = new CompletableFuture<>(); + try { - return new A2AHttpResponse() { - @Override - public int status() { - return 200; - } - - @Override - public boolean success() { - return true; - } - - @Override - public String body() { - return ""; - } - }; + tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + + future.complete( + new HttpResponse() { + @Override + public int statusCode() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + + } + }); + } catch (Exception ex) { + future.completeExceptionally(ex); } finally { latch.countDown(); } - } - - @Override - public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; - } - @Override - public PostBuilder url(String s) { - return this; + return future; } @Override - public PostBuilder addHeader(String name, String value) { + public PostRequestBuilder addHeader(String name, String value) { return this; } @Override - public PostBuilder addHeaders(Map headers) { + public PostRequestBuilder addHeaders(Map headers) { return this; } diff --git a/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java b/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java index 9156f78b..81d27be1 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java @@ -9,17 +9,19 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.server.http.HttpClientManager; import org.mockito.ArgumentCaptor; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; import io.a2a.common.A2AHeaders; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.Task; @@ -32,35 +34,39 @@ class InMemoryPushNotificationConfigStoreTest { private BasePushNotificationSender notificationSender; @Mock - private A2AHttpClient mockHttpClient; + private HttpClientManager clientManager; @Mock - private A2AHttpClient.PostBuilder mockPostBuilder; + private HttpClient mockHttpClient; @Mock - private A2AHttpResponse mockHttpResponse; + private HttpClient.PostRequestBuilder mockPostBuilder; + + @Mock + private HttpResponse mockHttpResponse; @BeforeEach public void setUp() { MockitoAnnotations.openMocks(this); configStore = new InMemoryPushNotificationConfigStore(); - notificationSender = new BasePushNotificationSender(configStore, mockHttpClient); + notificationSender = new BasePushNotificationSender(configStore, clientManager); } private void setupBasicMockHttpResponse() throws Exception { - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); +// when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); } private void verifyHttpCallWithoutToken(PushNotificationConfig config, Task task, String expectedToken) throws Exception { ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); +// verify(mockPostBuilder).url(config.url()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify that addHeader was never called for authentication token verify(mockPostBuilder, never()).addHeader(A2AHeaders.X_A2A_NOTIFICATION_TOKEN, expectedToken); @@ -229,21 +235,23 @@ public void testSendNotificationSuccess() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); configStore.setInfo(taskId, config); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); +// when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); // Verify HTTP client was called ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); +// verify(mockPostBuilder).url(config.url()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -258,24 +266,26 @@ public void testSendNotificationWithToken() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", "unique_token"); configStore.setInfo(taskId, config); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); +// when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.addHeader(any(String.class), any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); // Verify HTTP client was called with proper authentication ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); +// verify(mockPostBuilder).url(config.url()); verify(mockPostBuilder).body(bodyCaptor.capture()); // Verify that the token is included in request headers as X-A2A-Notification-Token verify(mockPostBuilder).addHeader(A2AHeaders.X_A2A_NOTIFICATION_TOKEN, config.token()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -291,7 +301,7 @@ public void testSendNotificationNoConfig() throws Exception { notificationSender.sendNotification(task); // Verify HTTP client was never called - verify(mockHttpClient, never()).createPost(); + verify(mockHttpClient, never()).post(any()); } @Test diff --git a/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java b/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java index 2ab974ed..f01d2422 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java @@ -2,6 +2,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; import java.io.IOException; import java.util.ArrayList; @@ -13,20 +15,27 @@ import java.util.concurrent.TimeUnit; import java.util.function.Consumer; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.Event; +import io.a2a.server.http.HttpClientManager; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; import io.a2a.common.A2AHeaders; import io.a2a.util.Utils; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; public class PushNotificationSenderTest { + @Mock + private HttpClientManager clientManager; + private TestHttpClient testHttpClient; private InMemoryPushNotificationConfigStore configStore; private BasePushNotificationSender sender; @@ -34,7 +43,7 @@ public class PushNotificationSenderTest { /** * Simple test implementation of A2AHttpClient that captures HTTP calls for verification */ - private static class TestHttpClient implements A2AHttpClient { + private static class TestHttpClient implements HttpClient { final List tasks = Collections.synchronizedList(new ArrayList<>()); final List urls = Collections.synchronizedList(new ArrayList<>()); final List> headers = Collections.synchronizedList(new ArrayList<>()); @@ -42,85 +51,85 @@ private static class TestHttpClient implements A2AHttpClient { volatile boolean shouldThrowException = false; @Override - public GetBuilder createGet() { + public GetRequestBuilder get(String path) { return null; } @Override - public PostBuilder createPost() { + public PostRequestBuilder post(String path) { return new TestPostBuilder(); } @Override - public DeleteBuilder createDelete() { + public DeleteRequestBuilder delete(String path) { return null; } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostBuilder implements HttpClient.PostRequestBuilder { private volatile String body; - private volatile String url; private final Map requestHeaders = new java.util.HashMap<>(); @Override - public PostBuilder body(String body) { + public PostRequestBuilder body(String body) { this.body = body; return this; } @Override - public A2AHttpResponse post() throws IOException, InterruptedException { + public CompletableFuture send() { + CompletableFuture future = new CompletableFuture<>(); + if (shouldThrowException) { - throw new IOException("Simulated network error"); + future.completeExceptionally(new IOException("Simulated network error")); + return future; } try { Task task = Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE); tasks.add(task); - urls.add(url); headers.add(new java.util.HashMap<>(requestHeaders)); - - return new A2AHttpResponse() { - @Override - public int status() { - return 200; - } - - @Override - public boolean success() { - return true; - } - - @Override - public String body() { - return ""; - } - }; + + future.complete( + new HttpResponse() { + @Override + public int statusCode() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + + } + }); + } catch (Exception e) { + future.completeExceptionally(e); } finally { if (latch != null) { latch.countDown(); } } - } - @Override - public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; + return future; } @Override - public PostBuilder url(String url) { - this.url = url; - return this; - } - - @Override - public PostBuilder addHeader(String name, String value) { + public PostRequestBuilder addHeader(String name, String value) { requestHeaders.put(name, value); return this; } @Override - public PostBuilder addHeaders(Map headers) { + public PostRequestBuilder addHeaders(Map headers) { requestHeaders.putAll(headers); return this; } @@ -129,9 +138,10 @@ public PostBuilder addHeaders(Map headers) { @BeforeEach public void setUp() { + MockitoAnnotations.openMocks(this); testHttpClient = new TestHttpClient(); configStore = new InMemoryPushNotificationConfigStore(); - sender = new BasePushNotificationSender(configStore, testHttpClient); + sender = new BasePushNotificationSender(configStore, clientManager); } private void testSendNotificationWithInvalidToken(String token, String testName) throws InterruptedException { @@ -141,7 +151,9 @@ private void testSendNotificationWithInvalidToken(String token, String testName) // Set up the configuration in the store configStore.setInfo(taskId, config); - + + when(clientManager.getOrCreate(any())).thenReturn(testHttpClient); + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -185,7 +197,9 @@ public void testSendNotificationSuccess() throws InterruptedException { // Set up the configuration in the store configStore.setInfo(taskId, config); - + + when(clientManager.getOrCreate(any())).thenReturn(testHttpClient); + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -210,7 +224,9 @@ public void testSendNotificationWithTokenSuccess() throws InterruptedException { // Set up the configuration in the store configStore.setInfo(taskId, config); - + + when(clientManager.getOrCreate(any())).thenReturn(testHttpClient); + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -263,22 +279,27 @@ public void testSendNotificationMultipleConfigs() throws InterruptedException { // Set up multiple configurations in the store configStore.setInfo(taskId, config1); configStore.setInfo(taskId, config2); - + + TestHttpClient httpClient = spy(testHttpClient); + when(clientManager.getOrCreate(any())).thenReturn(httpClient); + // Set up latch to wait for async completion (2 calls expected) - testHttpClient.latch = new CountDownLatch(2); + httpClient.latch = new CountDownLatch(2); sender.sendNotification(taskData); // Wait for the async operations to complete - assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP calls should complete within 5 seconds"); + assertTrue(httpClient.latch.await(5, TimeUnit.SECONDS), "HTTP calls should complete within 5 seconds"); // Verify both tasks were sent via HTTP - assertEquals(2, testHttpClient.tasks.size()); - assertEquals(2, testHttpClient.urls.size()); - assertTrue(testHttpClient.urls.containsAll(java.util.List.of("http://notify.me/cfg1", "http://notify.me/cfg2"))); + assertEquals(2, httpClient.tasks.size()); + //assertEquals(2, testHttpClient.urls.size()); + verify(httpClient).post("/cfg1"); + verify(httpClient).post("/cfg2"); + // assertTrue(testHttpClient.urls.containsAll(java.util.List.of("http://notify.me/cfg1", "http://notify.me/cfg2"))); // Both tasks should be identical (same task sent to different endpoints) - for (Task sentTask : testHttpClient.tasks) { + for (Task sentTask : httpClient.tasks) { assertEquals(taskData.getId(), sentTask.getId()); assertEquals(taskData.getContextId(), sentTask.getContextId()); assertEquals(taskData.getStatus().state(), sentTask.getStatus().state()); diff --git a/tests/client-common/pom.xml b/tests/client-common/pom.xml new file mode 100644 index 00000000..f003c7b6 --- /dev/null +++ b/tests/client-common/pom.xml @@ -0,0 +1,60 @@ + + + 4.0.0 + + + io.github.a2asdk + a2a-java-sdk-parent + 0.3.0.Beta3-SNAPSHOT + ../../pom.xml + + a2a-java-sdk-tests-client-common + + jar + + Java A2A SDK Client Tests Common + Java SDK for the Agent2Agent Protocol (A2A) - SDK - Client Tests Common + + + + ${project.groupId} + a2a-java-sdk-http-client + test + + + org.junit.jupiter + junit-jupiter-api + test + + + org.wiremock + wiremock + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + + + + \ No newline at end of file diff --git a/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java b/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java new file mode 100644 index 00000000..9b21d4f1 --- /dev/null +++ b/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java @@ -0,0 +1,187 @@ +package io.a2a.client.http.common; + +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import io.a2a.client.http.HttpClientBuilder; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.Event; +import org.junit.jupiter.api.*; + +import java.net.HttpURLConnection; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static org.junit.jupiter.api.Assertions.*; + +public abstract class AbstractHttpClientTest { + + private static final String AGENT_CARD_PATH = "/.well-known/agent-card.json"; + + private WireMockServer server; + + @BeforeEach + public void setUp() { + server = new WireMockServer(WireMockConfiguration.options().dynamicPort()); + server.start(); + + configureFor("localhost", server.port()); + } + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(); + } + } + + protected abstract HttpClientBuilder getHttpClientBuilder(); + + private String getServerUrl() { + return "http://localhost:" + server.port(); + } + + /** + * This test is disabled until we can make the http-client layer fully async + */ + @Test + @Disabled + public void testGetWithBodyResponse() throws Exception { + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); + + CountDownLatch latch = new CountDownLatch(1); + getHttpClientBuilder() + .create(getServerUrl()) + .get(AGENT_CARD_PATH) + .send() + .thenAccept(new Consumer() { + @Override + public void accept(HttpResponse httpResponse) { + String body = httpResponse.body(); + + Assertions.assertEquals(JsonMessages.AGENT_CARD, body); + latch.countDown(); + } + }); + + boolean dataReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue(dataReceived); + + } + + @Test + public void testA2AClientSendStreamingMessage() throws Exception { + String eventStream = + JsonStreamingMessages.SEND_MESSAGE_STREAMING_TEST_RESPONSE + + JsonStreamingMessages.TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE; + + givenThat(post(urlPathEqualTo("/")) + .willReturn(okForContentType("text/event-stream", eventStream))); + + CountDownLatch latch = new CountDownLatch(2); + AtomicReference errorRef = new AtomicReference<>(); + + getHttpClientBuilder() + .create(getServerUrl()) + .post("/") + .send() + .thenAccept(new Consumer() { + @Override + public void accept(HttpResponse httpResponse) { + httpResponse.bodyAsSse(new Consumer() { + @Override + public void accept(Event event) { + System.out.println(event); + latch.countDown(); + } + }, new Consumer() { + @Override + public void accept(Throwable throwable) { + errorRef.set(throwable); + latch.countDown(); + } + }); + } + }); + + boolean dataReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue(dataReceived); + assertNull(errorRef.get(), "Should not receive errors during SSE stream"); + } + + @Test + public void testUnauthorizedClient_post() throws Exception { + givenThat(post(urlPathEqualTo("/")) + .willReturn(aResponse().withStatus(HttpURLConnection.HTTP_UNAUTHORIZED))); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference errorRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + + getHttpClientBuilder() + // Enforce that the client will be receiving the SSE stream into multiple chunks + // .options(new HttpClientOptions().setMaxChunkSize(24)) + .create(getServerUrl()) + .post("/") + .send() + .whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (throwable != null) { + errorRef.set(throwable); + } + + if (httpResponse != null) { + responseRef.set(httpResponse); + } + + latch.countDown(); + } + }); + + boolean callCompleted = latch.await(5, TimeUnit.SECONDS); + assertTrue(callCompleted); + assertNull(responseRef.get(), "Should not receive response when unauthorized"); + assertNotNull(errorRef.get(), "Should not receive errors during SSE stream"); + } + + @Test + public void testUnauthorizedClient_get() throws Exception { + givenThat(get(urlPathEqualTo("/")) + .willReturn(aResponse().withStatus(HttpURLConnection.HTTP_UNAUTHORIZED))); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference errorRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + + getHttpClientBuilder() + // Enforce that the client will be receiving the SSE stream into multiple chunks + // .options(new HttpClientOptions().setMaxChunkSize(24)) + .create(getServerUrl()) + .get("/") + .send() + .whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (throwable != null) { + errorRef.set(throwable); + } + + if (httpResponse != null) { + responseRef.set(httpResponse); + } + + latch.countDown(); + } + }); + + boolean callCompleted = latch.await(5, TimeUnit.SECONDS); + assertTrue(callCompleted); + assertNull(responseRef.get(), "Should not receive response when unauthorized"); + assertNotNull(errorRef.get(), "Should not receive errors during SSE stream"); + } +} diff --git a/tests/client-common/src/test/java/io/a2a/client/http/common/JsonMessages.java b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonMessages.java new file mode 100644 index 00000000..0ab9d811 --- /dev/null +++ b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonMessages.java @@ -0,0 +1,85 @@ +package io.a2a.client.http.common; + +/** + * Request and response messages used by the tests. These have been created following examples from + * the A2A sample messages. + */ +public class JsonMessages { + + static final String AGENT_CARD = """ + { + "protocolVersion": "0.2.9", + "name": "GeoSpatial Route Planner Agent", + "description": "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + "url": "https://georoute-agent.example.com/a2a/v1", + "preferredTransport": "JSONRPC", + "additionalInterfaces" : [ + {"url": "https://georoute-agent.example.com/a2a/v1", "transport": "JSONRPC"}, + {"url": "https://georoute-agent.example.com/a2a/grpc", "transport": "GRPC"}, + {"url": "https://georoute-agent.example.com/a2a/json", "transport": "HTTP+JSON"} + ], + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "iconUrl": "https://georoute-agent.example.com/icon.png", + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "type": "openIdConnect", + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration" + } + }, + "security": [{ "google": ["openid", "profile", "email"] }], + "defaultInputModes": ["application/json", "text/plain"], + "defaultOutputModes": ["application/json", "image/png"], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": ["maps", "routing", "navigation", "directions", "traffic"], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\\"origin\\": {\\"lat\\": 37.422, \\"lng\\": -122.084}, \\"destination\\": {\\"lat\\": 37.7749, \\"lng\\": -122.4194}, \\"preferences\\": [\\"avoid_ferries\\"]}" + ], + "inputModes": ["application/json", "text/plain"], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": ["maps", "customization", "visualization", "cartography"], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": ["application/json"], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + } + ], + "supportsAuthenticatedExtendedCard": true, + "signatures": [ + { + "protected": "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0", + "signature": "QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ" + } + ] + }"""; +} \ No newline at end of file diff --git a/tests/client-common/src/test/java/io/a2a/client/http/common/JsonStreamingMessages.java b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonStreamingMessages.java new file mode 100644 index 00000000..15ae5c38 --- /dev/null +++ b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonStreamingMessages.java @@ -0,0 +1,15 @@ +package io.a2a.client.http.common; + +/** + * Contains JSON strings for testing SSE streaming. + */ +public class JsonStreamingMessages { + + static final String SEND_MESSAGE_STREAMING_TEST_RESPONSE = + "event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"id\":\"2\",\"contextId\":\"context-1234\",\"status\":{\"state\":\"completed\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"kind\":\"text\",\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{},\"kind\":\"task\"}}\n\n"; + + static final String TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE = + "event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"id\":\"2\",\"contextId\":\"context-5678\",\"status\":{\"state\":\"completed\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"kind\":\"text\",\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{},\"kind\":\"task\"}}\n\n"; +} \ No newline at end of file diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java index f161307a..d79c5df5 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java @@ -1,6 +1,5 @@ package io.a2a.server.apps.common; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -8,86 +7,91 @@ import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; +import io.a2a.client.http.sse.Event; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.inject.Alternative; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.spec.Task; import io.a2a.util.Utils; import java.util.Map; @Dependent @Alternative -public class TestHttpClient implements A2AHttpClient { +public class TestHttpClient implements HttpClient { final List tasks = Collections.synchronizedList(new ArrayList<>()); volatile CountDownLatch latch; @Override - public GetBuilder createGet() { + public GetRequestBuilder get(String path) { return null; } @Override - public PostBuilder createPost() { - return new TestPostBuilder(); + public PostRequestBuilder post(String path) { + return new TestPostRequestBuilder(); } @Override - public DeleteBuilder createDelete() { + public DeleteRequestBuilder delete(String path) { return null; } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostRequestBuilder implements PostRequestBuilder { + private volatile String body; @Override - public PostBuilder body(String body) { + public PostRequestBuilder body(String body) { this.body = body; return this; } @Override - public A2AHttpResponse post() throws IOException, InterruptedException { - tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + public CompletableFuture send() { + CompletableFuture future = new CompletableFuture<>(); + try { - return new A2AHttpResponse() { - @Override - public int status() { - return 200; - } - - @Override - public boolean success() { - return true; - } - - @Override - public String body() { - return ""; - } - }; + tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + + future.complete( + new HttpResponse() { + @Override + public int statusCode() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + + } + }); + } catch (Exception ex) { + future.completeExceptionally(ex); } finally { latch.countDown(); } - } - - @Override - public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; - } - @Override - public PostBuilder url(String s) { - return this; + return future; } @Override - public PostBuilder addHeader(String name, String value) { + public PostRequestBuilder addHeader(String name, String value) { return this; } @Override - public PostBuilder addHeaders(Map headers) { + public PostRequestBuilder addHeaders(Map headers) { return this; } } diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java index 9d12824b..e330180c 100644 --- a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java +++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java @@ -22,8 +22,7 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.requesthandlers.AbstractA2ARequestHandlerTest; import io.a2a.server.requesthandlers.DefaultRequestHandler; -import io.a2a.server.tasks.ResultAggregator; -import io.a2a.server.tasks.TaskUpdater; +import io.a2a.server.tasks.*; import io.a2a.spec.AgentCard; import io.a2a.spec.Artifact; import io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError; From 306cebee41f6ee826e04afa0071cdd3258b29ed3 Mon Sep 17 00:00:00 2001 From: David Brassely Date: Tue, 14 Oct 2025 22:30:55 +0200 Subject: [PATCH 2/2] --wip-- [skip ci] --- .../java/io/a2a/client/AbstractClient.java | 349 +-------- .../main/java/io/a2a/client/AsyncClient.java | 220 ++++++ .../src/main/java/io/a2a/client/Client.java | 481 ++++++++----- .../java/io/a2a/client/ClientBuilder.java | 24 +- .../main/java/io/a2a/client/SyncClient.java | 242 +++++++ .../client/transport/grpc/GrpcTransport.java | 154 ++-- .../grpc/SingleValueStreamObserver.java | 68 ++ .../transport/jsonrpc/JSONRPCTransport.java | 253 +++++-- .../jsonrpc/JSONRPCTransportTest.java | 674 +++++++++++------- .../transport/rest/RestErrorMapper.java | 6 +- .../client/transport/rest/RestTransport.java | 417 ++++++----- .../transport/rest/RestTransportTest.java | 271 ++++--- .../client/transport/spi/ClientTransport.java | 17 +- .../client/http/vertx/VertxHttpClient.java | 13 +- .../io/a2a/client/http/A2ACardResolver.java | 2 +- .../java/io/a2a/client/http/HttpClient.java | 4 +- .../java/io/a2a/client/http/HttpResponse.java | 3 +- .../io/a2a/client/http/jdk/JdkHttpClient.java | 13 +- .../http/common/AbstractHttpClientTest.java | 9 +- 19 files changed, 2002 insertions(+), 1218 deletions(-) create mode 100644 client/base/src/main/java/io/a2a/client/AsyncClient.java create mode 100644 client/base/src/main/java/io/a2a/client/SyncClient.java create mode 100644 client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/SingleValueStreamObserver.java diff --git a/client/base/src/main/java/io/a2a/client/AbstractClient.java b/client/base/src/main/java/io/a2a/client/AbstractClient.java index b6a41a35..68480904 100644 --- a/client/base/src/main/java/io/a2a/client/AbstractClient.java +++ b/client/base/src/main/java/io/a2a/client/AbstractClient.java @@ -1,26 +1,14 @@ package io.a2a.client; -import static io.a2a.util.Assert.checkNotNullParam; +import io.a2a.spec.AgentCard; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; import java.util.List; -import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Consumer; -import io.a2a.client.transport.spi.interceptors.ClientCallContext; -import io.a2a.spec.A2AClientException; -import io.a2a.spec.AgentCard; -import io.a2a.spec.DeleteTaskPushNotificationConfigParams; -import io.a2a.spec.GetTaskPushNotificationConfigParams; -import io.a2a.spec.ListTaskPushNotificationConfigParams; -import io.a2a.spec.Message; -import io.a2a.spec.PushNotificationConfig; -import io.a2a.spec.Task; -import io.a2a.spec.TaskIdParams; -import io.a2a.spec.TaskPushNotificationConfig; -import io.a2a.spec.TaskQueryParams; -import org.jspecify.annotations.NonNull; -import org.jspecify.annotations.Nullable; +import static io.a2a.util.Assert.checkNotNullParam; /** * Abstract class representing an A2A client. Provides a standard set @@ -43,334 +31,6 @@ public AbstractClient(@NonNull List> consumer this.streamingErrorHandler = streamingErrorHandler; } - /** - * Send a message to the remote agent. This method will automatically use - * the streaming or non-streaming approach as determined by the server's - * agent card and the client configuration. The configured client consumers - * will be used to handle messages, tasks, and update events received - * from the remote agent. The configured streaming error handler will be used - * if an error occurs during streaming. The configured client push notification - * configuration will get used for streaming. - * - * @param request the message - * @throws A2AClientException if sending the message fails for any reason - */ - public void sendMessage(Message request) throws A2AClientException { - sendMessage(request, null); - } - - /** - * Send a message to the remote agent. This method will automatically use - * the streaming or non-streaming approach as determined by the server's - * agent card and the client configuration. The configured client consumers - * will be used to handle messages, tasks, and update events received - * from the remote agent. The configured streaming error handler will be used - * if an error occurs during streaming. The configured client push notification - * configuration will get used for streaming. - * - * @param request the message - * @param context optional client call context for the request (may be {@code null}) - * @throws A2AClientException if sending the message fails for any reason - */ - public abstract void sendMessage(Message request, @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Send a message to the remote agent. This method will automatically use - * the streaming or non-streaming approach as determined by the server's - * agent card and the client configuration. The specified client consumers - * will be used to handle messages, tasks, and update events received - * from the remote agent. The specified streaming error handler will be used - * if an error occurs during streaming. The configured client push notification - * configuration will get used for streaming. - * - * @param request the message - * @param consumers a list of consumers to pass responses from the remote agent to - * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs - * @throws A2AClientException if sending the message fails for any reason - */ - public void sendMessage(Message request, - List> consumers, - Consumer streamingErrorHandler) throws A2AClientException { - sendMessage(request, consumers, streamingErrorHandler, null); - } - - /** - * Send a message to the remote agent. This method will automatically use - * the streaming or non-streaming approach as determined by the server's - * agent card and the client configuration. The specified client consumers - * will be used to handle messages, tasks, and update events received - * from the remote agent. The specified streaming error handler will be used - * if an error occurs during streaming. The configured client push notification - * configuration will get used for streaming. - * - * @param request the message - * @param consumers a list of consumers to pass responses from the remote agent to - * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs - * @param context optional client call context for the request (may be {@code null}) - * @throws A2AClientException if sending the message fails for any reason - */ - public abstract void sendMessage(Message request, - List> consumers, - Consumer streamingErrorHandler, - @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Send a message to the remote agent. This method will automatically use - * the streaming or non-streaming approach as determined by the server's - * agent card and the client configuration. The configured client consumers - * will be used to handle messages, tasks, and update events received from - * the remote agent. The configured streaming error handler will be used - * if an error occurs during streaming. - * - * @param request the message - * @param pushNotificationConfiguration the push notification configuration that should be - * used if the streaming approach is used - * @param metadata the optional metadata to include when sending the message - * @throws A2AClientException if sending the message fails for any reason - */ - public void sendMessage(Message request, PushNotificationConfig pushNotificationConfiguration, - Map metadata) throws A2AClientException { - sendMessage(request, pushNotificationConfiguration, metadata, null); - } - - /** - * Send a message to the remote agent. This method will automatically use - * the streaming or non-streaming approach as determined by the server's - * agent card and the client configuration. The configured client consumers - * will be used to handle messages, tasks, and update events received from - * the remote agent. The configured streaming error handler will be used - * if an error occurs during streaming. - * - * @param request the message - * @param pushNotificationConfiguration the push notification configuration that should be - * used if the streaming approach is used - * @param metadata the optional metadata to include when sending the message - * @param context optional client call context for the request (may be {@code null}) - * @throws A2AClientException if sending the message fails for any reason - */ - public abstract void sendMessage(Message request, PushNotificationConfig pushNotificationConfiguration, - Map metadata, @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Retrieve the current state and history of a specific task. - * - * @param request the task query parameters specifying which task to retrieve - * @return the task - * @throws A2AClientException if retrieving the task fails for any reason - */ - public Task getTask(TaskQueryParams request) throws A2AClientException { - return getTask(request, null); - } - - /** - * Retrieve the current state and history of a specific task. - * - * @param request the task query parameters specifying which task to retrieve - * @param context optional client call context for the request (may be {@code null}) - * @return the task - * @throws A2AClientException if retrieving the task fails for any reason - */ - public abstract Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Request the agent to cancel a specific task. - * - * @param request the task ID parameters specifying which task to cancel - * @return the cancelled task - * @throws A2AClientException if cancelling the task fails for any reason - */ - public Task cancelTask(TaskIdParams request) throws A2AClientException { - return cancelTask(request, null); - } - - /** - * Request the agent to cancel a specific task. - * - * @param request the task ID parameters specifying which task to cancel - * @param context optional client call context for the request (may be {@code null}) - * @return the cancelled task - * @throws A2AClientException if cancelling the task fails for any reason - */ - public abstract Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Set or update the push notification configuration for a specific task. - * - * @param request the push notification configuration to set for the task - * @return the configured TaskPushNotificationConfig - * @throws A2AClientException if setting the task push notification configuration fails for any reason - */ - public TaskPushNotificationConfig setTaskPushNotificationConfiguration( - TaskPushNotificationConfig request) throws A2AClientException { - return setTaskPushNotificationConfiguration(request, null); - } - - /** - * Set or update the push notification configuration for a specific task. - * - * @param request the push notification configuration to set for the task - * @param context optional client call context for the request (may be {@code null}) - * @return the configured TaskPushNotificationConfig - * @throws A2AClientException if setting the task push notification configuration fails for any reason - */ - public abstract TaskPushNotificationConfig setTaskPushNotificationConfiguration( - TaskPushNotificationConfig request, - @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Retrieve the push notification configuration for a specific task. - * - * @param request the parameters specifying which task's notification config to retrieve - * @return the task push notification config - * @throws A2AClientException if getting the task push notification config fails for any reason - */ - public TaskPushNotificationConfig getTaskPushNotificationConfiguration( - GetTaskPushNotificationConfigParams request) throws A2AClientException { - return getTaskPushNotificationConfiguration(request, null); - } - - /** - * Retrieve the push notification configuration for a specific task. - * - * @param request the parameters specifying which task's notification config to retrieve - * @param context optional client call context for the request (may be {@code null}) - * @return the task push notification config - * @throws A2AClientException if getting the task push notification config fails for any reason - */ - public abstract TaskPushNotificationConfig getTaskPushNotificationConfiguration( - GetTaskPushNotificationConfigParams request, - @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Retrieve the list of push notification configurations for a specific task. - * - * @param request the parameters specifying which task's notification configs to retrieve - * @return the list of task push notification configs - * @throws A2AClientException if getting the task push notification configs fails for any reason - */ - public List listTaskPushNotificationConfigurations( - ListTaskPushNotificationConfigParams request) throws A2AClientException { - return listTaskPushNotificationConfigurations(request, null); - } - - /** - * Retrieve the list of push notification configurations for a specific task. - * - * @param request the parameters specifying which task's notification configs to retrieve - * @param context optional client call context for the request (may be {@code null}) - * @return the list of task push notification configs - * @throws A2AClientException if getting the task push notification configs fails for any reason - */ - public abstract List listTaskPushNotificationConfigurations( - ListTaskPushNotificationConfigParams request, - @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Delete the list of push notification configurations for a specific task. - * - * @param request the parameters specifying which task's notification configs to delete - * @throws A2AClientException if deleting the task push notification configs fails for any reason - */ - public void deleteTaskPushNotificationConfigurations( - DeleteTaskPushNotificationConfigParams request) throws A2AClientException { - deleteTaskPushNotificationConfigurations(request, null); - } - - /** - * Delete the list of push notification configurations for a specific task. - * - * @param request the parameters specifying which task's notification configs to delete - * @param context optional client call context for the request (may be {@code null}) - * @throws A2AClientException if deleting the task push notification configs fails for any reason - */ - public abstract void deleteTaskPushNotificationConfigurations( - DeleteTaskPushNotificationConfigParams request, - @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Resubscribe to a task's event stream. - * This is only available if both the client and server support streaming. - * The configured client consumers will be used to handle messages, tasks, - * and update events received from the remote agent. The configured streaming - * error handler will be used if an error occurs during streaming. - * - * @param request the parameters specifying which task's notification configs to delete - * @throws A2AClientException if resubscribing fails for any reason - */ - public void resubscribe(TaskIdParams request) throws A2AClientException { - resubscribe(request, null); - } - - /** - * Resubscribe to a task's event stream. - * This is only available if both the client and server support streaming. - * The configured client consumers will be used to handle messages, tasks, - * and update events received from the remote agent. The configured streaming - * error handler will be used if an error occurs during streaming. - * - * @param request the parameters specifying which task's notification configs to delete - * @param context optional client call context for the request (may be {@code null}) - * @throws A2AClientException if resubscribing fails for any reason - */ - public abstract void resubscribe(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Resubscribe to a task's event stream. - * This is only available if both the client and server support streaming. - * The specified client consumers will be used to handle messages, tasks, and - * update events received from the remote agent. The specified streaming error - * handler will be used if an error occurs during streaming. - * - * @param request the parameters specifying which task's notification configs to delete - * @param consumers a list of consumers to pass responses from the remote agent to - * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs - * @throws A2AClientException if resubscribing fails for any reason - */ - public void resubscribe(TaskIdParams request, List> consumers, - Consumer streamingErrorHandler) throws A2AClientException { - resubscribe(request, consumers, streamingErrorHandler, null); - } - - /** - * Resubscribe to a task's event stream. - * This is only available if both the client and server support streaming. - * The specified client consumers will be used to handle messages, tasks, and - * update events received from the remote agent. The specified streaming error - * handler will be used if an error occurs during streaming. - * - * @param request the parameters specifying which task's notification configs to delete - * @param consumers a list of consumers to pass responses from the remote agent to - * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs - * @param context optional client call context for the request (may be {@code null}) - * @throws A2AClientException if resubscribing fails for any reason - */ - public abstract void resubscribe(TaskIdParams request, List> consumers, - Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException; - - /** - * Retrieve the AgentCard. - * - * @return the AgentCard - * @throws A2AClientException if retrieving the agent card fails for any reason - */ - public AgentCard getAgentCard() throws A2AClientException { - return getAgentCard(null); - } - - /** - * Retrieve the AgentCard. - * - * @param context optional client call context for the request (may be {@code null}) - * @return the AgentCard - * @throws A2AClientException if retrieving the agent card fails for any reason - */ - public abstract AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException; - - /** - * Close the transport and release any associated resources. - */ - public abstract void close(); - /** * Process the event using all configured consumers. */ @@ -388,5 +48,4 @@ void consume(ClientEvent clientEventOrMessage, AgentCard agentCard) { public @Nullable Consumer getStreamingErrorHandler() { return streamingErrorHandler; } - } \ No newline at end of file diff --git a/client/base/src/main/java/io/a2a/client/AsyncClient.java b/client/base/src/main/java/io/a2a/client/AsyncClient.java new file mode 100644 index 00000000..049828de --- /dev/null +++ b/client/base/src/main/java/io/a2a/client/AsyncClient.java @@ -0,0 +1,220 @@ +package io.a2a.client; + +import io.a2a.client.config.ClientConfig; +import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.client.transport.spi.interceptors.ClientCallContext; +import io.a2a.spec.*; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static io.a2a.util.Assert.checkNotNullParam; + +public class AsyncClient extends AbstractClient { + + private final ClientConfig clientConfig; + private final ClientTransport clientTransport; + private AgentCard agentCard; + + AsyncClient(AgentCard agentCard, ClientConfig clientConfig, ClientTransport clientTransport, + List> consumers, @Nullable Consumer streamingErrorHandler) { + super(consumers, streamingErrorHandler); + checkNotNullParam("agentCard", agentCard); + + this.agentCard = agentCard; + this.clientConfig = clientConfig; + this.clientTransport = clientTransport; + } + + public void sendMessage(Message request, @Nullable ClientCallContext context) throws A2AClientException { + MessageSendParams messageSendParams = getMessageSendParams(request, clientConfig); + sendMessage(messageSendParams, null, null, context); + } + + public void sendMessage(Message request, List> consumers, + Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException { + MessageSendParams messageSendParams = getMessageSendParams(request, clientConfig); + sendMessage(messageSendParams, consumers, streamingErrorHandler, context); + } + + public void sendMessage(Message request, PushNotificationConfig pushNotificationConfiguration, + Map metatadata, @Nullable ClientCallContext context) throws A2AClientException { + MessageSendConfiguration messageSendConfiguration = createMessageSendConfiguration(pushNotificationConfiguration); + + MessageSendParams messageSendParams = new MessageSendParams.Builder() + .message(request) + .configuration(messageSendConfiguration) + .metadata(metatadata) + .build(); + + sendMessage(messageSendParams, null, null, context); + } + + public CompletableFuture getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException { + return clientTransport.getTask(request, context); + } + + public CompletableFuture cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { + return clientTransport.cancelTask(request, context); + } + + public CompletableFuture setTaskPushNotificationConfiguration( + TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException { + return clientTransport.setTaskPushNotificationConfiguration(request, context); + } + + public CompletableFuture getTaskPushNotificationConfiguration( + GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + return clientTransport.getTaskPushNotificationConfiguration(request, context); + } + + public CompletableFuture> listTaskPushNotificationConfigurations( + ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + return clientTransport.listTaskPushNotificationConfigurations(request, context); + } + + public CompletableFuture deleteTaskPushNotificationConfigurations( + DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + return clientTransport.deleteTaskPushNotificationConfigurations(request, context); + } + + public void resubscribe(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { + resubscribeToTask(request, null, null, context); + } + + public void resubscribe(TaskIdParams request, @Nullable List> consumers, + @Nullable Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException { + resubscribeToTask(request, consumers, streamingErrorHandler, context); + } + + public CompletableFuture getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { + return clientTransport.getAgentCard(context) + .whenComplete(new BiConsumer() { + @Override + public void accept(AgentCard agentCard, Throwable throwable) { + if (agentCard != null) { + AsyncClient.this.agentCard = agentCard; + } + } + }); + } + + public void close() { + clientTransport.close(); + } + + private ClientEvent getClientEvent(StreamingEventKind event, ClientTaskManager taskManager) throws A2AClientError { + if (event instanceof Message message) { + return new MessageEvent(message); + } else if (event instanceof Task task) { + taskManager.saveTaskEvent(task); + return new TaskEvent(taskManager.getCurrentTask()); + } else if (event instanceof TaskStatusUpdateEvent updateEvent) { + taskManager.saveTaskEvent(updateEvent); + return new TaskUpdateEvent(taskManager.getCurrentTask(), updateEvent); + } else if (event instanceof TaskArtifactUpdateEvent updateEvent) { + taskManager.saveTaskEvent(updateEvent); + return new TaskUpdateEvent(taskManager.getCurrentTask(), updateEvent); + } else { + throw new A2AClientInvalidStateError("Invalid client event"); + } + } + + private MessageSendConfiguration createMessageSendConfiguration(@Nullable PushNotificationConfig pushNotificationConfig) { + return new MessageSendConfiguration.Builder() + .acceptedOutputModes(clientConfig.getAcceptedOutputModes()) + .blocking(!clientConfig.isPolling()) + .historyLength(clientConfig.getHistoryLength()) + .pushNotificationConfig(pushNotificationConfig) + .build(); + } + + private void sendMessage(MessageSendParams messageSendParams, @Nullable List> consumers, + @Nullable Consumer errorHandler, @Nullable ClientCallContext context) throws A2AClientException { + if (! clientConfig.isStreaming() || ! agentCard.capabilities().streaming()) { + clientTransport.sendMessage(messageSendParams, context) + .thenAccept(new Consumer() { + @Override + public void accept(EventKind eventKind) { + ClientEvent clientEvent; + if (eventKind instanceof Task task) { + clientEvent = new TaskEvent(task); + } else { + // must be a message + clientEvent = new MessageEvent((Message) eventKind); + } + consume(clientEvent, agentCard, consumers); + } + }); + } else { + ClientTaskManager tracker = new ClientTaskManager(); + Consumer overriddenErrorHandler = getOverriddenErrorHandler(errorHandler); + Consumer eventHandler = event -> { + try { + ClientEvent clientEvent = getClientEvent(event, tracker); + consume(clientEvent, agentCard, consumers); + } catch (A2AClientError e) { + overriddenErrorHandler.accept(e); + } + }; + clientTransport.sendMessageStreaming(messageSendParams, eventHandler, overriddenErrorHandler, context); + } + } + + private void resubscribeToTask(TaskIdParams request, @Nullable List> consumers, + @Nullable Consumer errorHandler, @Nullable ClientCallContext context) throws A2AClientException { + if (! clientConfig.isStreaming() || ! agentCard.capabilities().streaming()) { + throw new A2AClientException("Client and/or server does not support resubscription"); + } + ClientTaskManager tracker = new ClientTaskManager(); + Consumer overriddenErrorHandler = getOverriddenErrorHandler(errorHandler); + Consumer eventHandler = event -> { + try { + ClientEvent clientEvent = getClientEvent(event, tracker); + consume(clientEvent, agentCard, consumers); + } catch (A2AClientError e) { + overriddenErrorHandler.accept(e); + } + }; + clientTransport.resubscribe(request, eventHandler, overriddenErrorHandler, context); + } + + private @NonNull Consumer getOverriddenErrorHandler(@Nullable Consumer errorHandler) { + return e -> { + if (errorHandler != null) { + errorHandler.accept(e); + } else { + if (getStreamingErrorHandler() != null) { + getStreamingErrorHandler().accept(e); + } + } + }; + } + + private void consume(ClientEvent clientEvent, AgentCard agentCard, @Nullable List> consumers) { + if (consumers != null) { + // use specified consumers + for (BiConsumer consumer : consumers) { + consumer.accept(clientEvent, agentCard); + } + } else { + // use configured consumers + consume(clientEvent, agentCard); + } + } + + private MessageSendParams getMessageSendParams(Message request, ClientConfig clientConfig) { + MessageSendConfiguration messageSendConfiguration = createMessageSendConfiguration(clientConfig.getPushNotificationConfig()); + + return new MessageSendParams.Builder() + .message(request) + .configuration(messageSendConfiguration) + .metadata(clientConfig.getMetadata()) + .build(); + } +} diff --git a/client/base/src/main/java/io/a2a/client/Client.java b/client/base/src/main/java/io/a2a/client/Client.java index ab222266..5a26f6d8 100644 --- a/client/base/src/main/java/io/a2a/client/Client.java +++ b/client/base/src/main/java/io/a2a/client/Client.java @@ -1,46 +1,37 @@ package io.a2a.client; +import static io.a2a.util.Assert.checkNotNullParam; + import java.util.List; import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Consumer; import io.a2a.client.config.ClientConfig; -import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.client.transport.spi.ClientTransport; -import io.a2a.spec.A2AClientError; +import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.spec.A2AClientException; -import io.a2a.spec.A2AClientInvalidStateError; import io.a2a.spec.AgentCard; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; -import io.a2a.spec.EventKind; import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.ListTaskPushNotificationConfigParams; import io.a2a.spec.Message; -import io.a2a.spec.MessageSendConfiguration; -import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; -import io.a2a.spec.StreamingEventKind; import io.a2a.spec.Task; -import io.a2a.spec.TaskArtifactUpdateEvent; import io.a2a.spec.TaskIdParams; import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; -import io.a2a.spec.TaskStatusUpdateEvent; - -import static io.a2a.util.Assert.checkNotNullParam; - import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; -public class Client extends AbstractClient { +public abstract class Client extends AbstractClient { - private final ClientConfig clientConfig; - private final ClientTransport clientTransport; - private AgentCard agentCard; + protected final ClientConfig clientConfig; + protected final ClientTransport clientTransport; + protected AgentCard agentCard; Client(AgentCard agentCard, ClientConfig clientConfig, ClientTransport clientTransport, - List> consumers, @Nullable Consumer streamingErrorHandler) { + List> consumers, @Nullable Consumer streamingErrorHandler) { super(consumers, streamingErrorHandler); checkNotNullParam("agentCard", agentCard); @@ -53,191 +44,331 @@ public static ClientBuilder builder(AgentCard agentCard) { return new ClientBuilder(agentCard); } - @Override - public void sendMessage(Message request, @Nullable ClientCallContext context) throws A2AClientException { - MessageSendParams messageSendParams = getMessageSendParams(request, clientConfig); - sendMessage(messageSendParams, null, null, context); + /** + * Send a message to the remote agent. This method will automatically use + * the streaming or non-streaming approach as determined by the server's + * agent card and the client configuration. The configured client consumers + * will be used to handle messages, tasks, and update events received + * from the remote agent. The configured streaming error handler will be used + * if an error occurs during streaming. The configured client push notification + * configuration will get used for streaming. + * + * @param request the message + * @throws A2AClientException if sending the message fails for any reason + */ + public void sendMessage(Message request) throws A2AClientException { + sendMessage(request, null); } - @Override - public void sendMessage(Message request, List> consumers, - Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException { - MessageSendParams messageSendParams = getMessageSendParams(request, clientConfig); - sendMessage(messageSendParams, consumers, streamingErrorHandler, context); + /** + * Send a message to the remote agent. This method will automatically use + * the streaming or non-streaming approach as determined by the server's + * agent card and the client configuration. The configured client consumers + * will be used to handle messages, tasks, and update events received + * from the remote agent. The configured streaming error handler will be used + * if an error occurs during streaming. The configured client push notification + * configuration will get used for streaming. + * + * @param request the message + * @param context optional client call context for the request (may be {@code null}) + * @throws A2AClientException if sending the message fails for any reason + */ + public abstract void sendMessage(Message request, @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Send a message to the remote agent. This method will automatically use + * the streaming or non-streaming approach as determined by the server's + * agent card and the client configuration. The specified client consumers + * will be used to handle messages, tasks, and update events received + * from the remote agent. The specified streaming error handler will be used + * if an error occurs during streaming. The configured client push notification + * configuration will get used for streaming. + * + * @param request the message + * @param consumers a list of consumers to pass responses from the remote agent to + * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs + * @throws A2AClientException if sending the message fails for any reason + */ + public void sendMessage(Message request, + List> consumers, + Consumer streamingErrorHandler) throws A2AClientException { + sendMessage(request, consumers, streamingErrorHandler, null); } - @Override + /** + * Send a message to the remote agent. This method will automatically use + * the streaming or non-streaming approach as determined by the server's + * agent card and the client configuration. The specified client consumers + * will be used to handle messages, tasks, and update events received + * from the remote agent. The specified streaming error handler will be used + * if an error occurs during streaming. The configured client push notification + * configuration will get used for streaming. + * + * @param request the message + * @param consumers a list of consumers to pass responses from the remote agent to + * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs + * @param context optional client call context for the request (may be {@code null}) + * @throws A2AClientException if sending the message fails for any reason + */ + public abstract void sendMessage(Message request, + List> consumers, + Consumer streamingErrorHandler, + @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Send a message to the remote agent. This method will automatically use + * the streaming or non-streaming approach as determined by the server's + * agent card and the client configuration. The configured client consumers + * will be used to handle messages, tasks, and update events received from + * the remote agent. The configured streaming error handler will be used + * if an error occurs during streaming. + * + * @param request the message + * @param pushNotificationConfiguration the push notification configuration that should be + * used if the streaming approach is used + * @param metadata the optional metadata to include when sending the message + * @throws A2AClientException if sending the message fails for any reason + */ public void sendMessage(Message request, PushNotificationConfig pushNotificationConfiguration, - Map metatadata, @Nullable ClientCallContext context) throws A2AClientException { - MessageSendConfiguration messageSendConfiguration = createMessageSendConfiguration(pushNotificationConfiguration); - - MessageSendParams messageSendParams = new MessageSendParams.Builder() - .message(request) - .configuration(messageSendConfiguration) - .metadata(metatadata) - .build(); - - sendMessage(messageSendParams, null, null, context); + Map metadata) throws A2AClientException { + sendMessage(request, pushNotificationConfiguration, metadata, null); } - @Override - public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException { - return clientTransport.getTask(request, context); + /** + * Send a message to the remote agent. This method will automatically use + * the streaming or non-streaming approach as determined by the server's + * agent card and the client configuration. The configured client consumers + * will be used to handle messages, tasks, and update events received from + * the remote agent. The configured streaming error handler will be used + * if an error occurs during streaming. + * + * @param request the message + * @param pushNotificationConfiguration the push notification configuration that should be + * used if the streaming approach is used + * @param metadata the optional metadata to include when sending the message + * @param context optional client call context for the request (may be {@code null}) + * @throws A2AClientException if sending the message fails for any reason + */ + public abstract void sendMessage(Message request, PushNotificationConfig pushNotificationConfiguration, + Map metadata, @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Retrieve the current state and history of a specific task. + * + * @param request the task query parameters specifying which task to retrieve + * @return the task + * @throws A2AClientException if retrieving the task fails for any reason + */ + public Task getTask(TaskQueryParams request) throws A2AClientException { + return getTask(request, null); } - @Override - public Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { - return clientTransport.cancelTask(request, context); + /** + * Retrieve the current state and history of a specific task. + * + * @param request the task query parameters specifying which task to retrieve + * @param context optional client call context for the request (may be {@code null}) + * @return the task + * @throws A2AClientException if retrieving the task fails for any reason + */ + public abstract Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Request the agent to cancel a specific task. + * + * @param request the task ID parameters specifying which task to cancel + * @return the cancelled task + * @throws A2AClientException if cancelling the task fails for any reason + */ + public Task cancelTask(TaskIdParams request) throws A2AClientException { + return cancelTask(request, null); } - @Override + /** + * Request the agent to cancel a specific task. + * + * @param request the task ID parameters specifying which task to cancel + * @param context optional client call context for the request (may be {@code null}) + * @return the cancelled task + * @throws A2AClientException if cancelling the task fails for any reason + */ + public abstract Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Set or update the push notification configuration for a specific task. + * + * @param request the push notification configuration to set for the task + * @return the configured TaskPushNotificationConfig + * @throws A2AClientException if setting the task push notification configuration fails for any reason + */ public TaskPushNotificationConfig setTaskPushNotificationConfiguration( - TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException { - return clientTransport.setTaskPushNotificationConfiguration(request, context); + TaskPushNotificationConfig request) throws A2AClientException { + return setTaskPushNotificationConfiguration(request, null); } - @Override + /** + * Set or update the push notification configuration for a specific task. + * + * @param request the push notification configuration to set for the task + * @param context optional client call context for the request (may be {@code null}) + * @return the configured TaskPushNotificationConfig + * @throws A2AClientException if setting the task push notification configuration fails for any reason + */ + public abstract TaskPushNotificationConfig setTaskPushNotificationConfiguration( + TaskPushNotificationConfig request, + @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Retrieve the push notification configuration for a specific task. + * + * @param request the parameters specifying which task's notification config to retrieve + * @return the task push notification config + * @throws A2AClientException if getting the task push notification config fails for any reason + */ public TaskPushNotificationConfig getTaskPushNotificationConfiguration( - GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { - return clientTransport.getTaskPushNotificationConfiguration(request, context); + GetTaskPushNotificationConfigParams request) throws A2AClientException { + return getTaskPushNotificationConfiguration(request, null); } - @Override + /** + * Retrieve the push notification configuration for a specific task. + * + * @param request the parameters specifying which task's notification config to retrieve + * @param context optional client call context for the request (may be {@code null}) + * @return the task push notification config + * @throws A2AClientException if getting the task push notification config fails for any reason + */ + public abstract TaskPushNotificationConfig getTaskPushNotificationConfiguration( + GetTaskPushNotificationConfigParams request, + @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Retrieve the list of push notification configurations for a specific task. + * + * @param request the parameters specifying which task's notification configs to retrieve + * @return the list of task push notification configs + * @throws A2AClientException if getting the task push notification configs fails for any reason + */ public List listTaskPushNotificationConfigurations( - ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { - return clientTransport.listTaskPushNotificationConfigurations(request, context); + ListTaskPushNotificationConfigParams request) throws A2AClientException { + return listTaskPushNotificationConfigurations(request, null); } - @Override + /** + * Retrieve the list of push notification configurations for a specific task. + * + * @param request the parameters specifying which task's notification configs to retrieve + * @param context optional client call context for the request (may be {@code null}) + * @return the list of task push notification configs + * @throws A2AClientException if getting the task push notification configs fails for any reason + */ + public abstract List listTaskPushNotificationConfigurations( + ListTaskPushNotificationConfigParams request, + @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Delete the list of push notification configurations for a specific task. + * + * @param request the parameters specifying which task's notification configs to delete + * @throws A2AClientException if deleting the task push notification configs fails for any reason + */ public void deleteTaskPushNotificationConfigurations( - DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { - clientTransport.deleteTaskPushNotificationConfigurations(request, context); - } - - @Override - public void resubscribe(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { - resubscribeToTask(request, null, null, context); - } - - @Override - public void resubscribe(TaskIdParams request, @Nullable List> consumers, - @Nullable Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException { - resubscribeToTask(request, consumers, streamingErrorHandler, context); + DeleteTaskPushNotificationConfigParams request) throws A2AClientException { + deleteTaskPushNotificationConfigurations(request, null); } - @Override - public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { - agentCard = clientTransport.getAgentCard(context); - return agentCard; + /** + * Delete the list of push notification configurations for a specific task. + * + * @param request the parameters specifying which task's notification configs to delete + * @param context optional client call context for the request (may be {@code null}) + * @throws A2AClientException if deleting the task push notification configs fails for any reason + */ + public abstract void deleteTaskPushNotificationConfigurations( + DeleteTaskPushNotificationConfigParams request, + @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Resubscribe to a task's event stream. + * This is only available if both the client and server support streaming. + * The configured client consumers will be used to handle messages, tasks, + * and update events received from the remote agent. The configured streaming + * error handler will be used if an error occurs during streaming. + * + * @param request the parameters specifying which task's notification configs to delete + * @throws A2AClientException if resubscribing fails for any reason + */ + public void resubscribe(TaskIdParams request) throws A2AClientException { + resubscribe(request, null); } - @Override - public void close() { - clientTransport.close(); + /** + * Resubscribe to a task's event stream. + * This is only available if both the client and server support streaming. + * The configured client consumers will be used to handle messages, tasks, + * and update events received from the remote agent. The configured streaming + * error handler will be used if an error occurs during streaming. + * + * @param request the parameters specifying which task's notification configs to delete + * @param context optional client call context for the request (may be {@code null}) + * @throws A2AClientException if resubscribing fails for any reason + */ + public abstract void resubscribe(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Resubscribe to a task's event stream. + * This is only available if both the client and server support streaming. + * The specified client consumers will be used to handle messages, tasks, and + * update events received from the remote agent. The specified streaming error + * handler will be used if an error occurs during streaming. + * + * @param request the parameters specifying which task's notification configs to delete + * @param consumers a list of consumers to pass responses from the remote agent to + * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs + * @throws A2AClientException if resubscribing fails for any reason + */ + public void resubscribe(TaskIdParams request, List> consumers, + Consumer streamingErrorHandler) throws A2AClientException { + resubscribe(request, consumers, streamingErrorHandler, null); } - private ClientEvent getClientEvent(StreamingEventKind event, ClientTaskManager taskManager) throws A2AClientError { - if (event instanceof Message message) { - return new MessageEvent(message); - } else if (event instanceof Task task) { - taskManager.saveTaskEvent(task); - return new TaskEvent(taskManager.getCurrentTask()); - } else if (event instanceof TaskStatusUpdateEvent updateEvent) { - taskManager.saveTaskEvent(updateEvent); - return new TaskUpdateEvent(taskManager.getCurrentTask(), updateEvent); - } else if (event instanceof TaskArtifactUpdateEvent updateEvent) { - taskManager.saveTaskEvent(updateEvent); - return new TaskUpdateEvent(taskManager.getCurrentTask(), updateEvent); - } else { - throw new A2AClientInvalidStateError("Invalid client event"); - } + /** + * Resubscribe to a task's event stream. + * This is only available if both the client and server support streaming. + * The specified client consumers will be used to handle messages, tasks, and + * update events received from the remote agent. The specified streaming error + * handler will be used if an error occurs during streaming. + * + * @param request the parameters specifying which task's notification configs to delete + * @param consumers a list of consumers to pass responses from the remote agent to + * @param streamingErrorHandler an error handler that should be used for the streaming case if an error occurs + * @param context optional client call context for the request (may be {@code null}) + * @throws A2AClientException if resubscribing fails for any reason + */ + public abstract void resubscribe(TaskIdParams request, List> consumers, + Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException; + + /** + * Retrieve the AgentCard. + * + * @return the AgentCard + * @throws A2AClientException if retrieving the agent card fails for any reason + */ + public AgentCard getAgentCard() throws A2AClientException { + return getAgentCard(null); } - private MessageSendConfiguration createMessageSendConfiguration(@Nullable PushNotificationConfig pushNotificationConfig) { - return new MessageSendConfiguration.Builder() - .acceptedOutputModes(clientConfig.getAcceptedOutputModes()) - .blocking(!clientConfig.isPolling()) - .historyLength(clientConfig.getHistoryLength()) - .pushNotificationConfig(pushNotificationConfig) - .build(); - } - - private void sendMessage(MessageSendParams messageSendParams, @Nullable List> consumers, - @Nullable Consumer errorHandler, @Nullable ClientCallContext context) throws A2AClientException { - if (! clientConfig.isStreaming() || ! agentCard.capabilities().streaming()) { - EventKind eventKind = clientTransport.sendMessage(messageSendParams, context); - ClientEvent clientEvent; - if (eventKind instanceof Task task) { - clientEvent = new TaskEvent(task); - } else { - // must be a message - clientEvent = new MessageEvent((Message) eventKind); - } - consume(clientEvent, agentCard, consumers); - } else { - ClientTaskManager tracker = new ClientTaskManager(); - Consumer overriddenErrorHandler = getOverriddenErrorHandler(errorHandler); - Consumer eventHandler = event -> { - try { - ClientEvent clientEvent = getClientEvent(event, tracker); - consume(clientEvent, agentCard, consumers); - } catch (A2AClientError e) { - overriddenErrorHandler.accept(e); - } - }; - clientTransport.sendMessageStreaming(messageSendParams, eventHandler, overriddenErrorHandler, context); - } - } - - private void resubscribeToTask(TaskIdParams request, @Nullable List> consumers, - @Nullable Consumer errorHandler, @Nullable ClientCallContext context) throws A2AClientException { - if (! clientConfig.isStreaming() || ! agentCard.capabilities().streaming()) { - throw new A2AClientException("Client and/or server does not support resubscription"); - } - ClientTaskManager tracker = new ClientTaskManager(); - Consumer overriddenErrorHandler = getOverriddenErrorHandler(errorHandler); - Consumer eventHandler = event -> { - try { - ClientEvent clientEvent = getClientEvent(event, tracker); - consume(clientEvent, agentCard, consumers); - } catch (A2AClientError e) { - overriddenErrorHandler.accept(e); - } - }; - clientTransport.resubscribe(request, eventHandler, overriddenErrorHandler, context); - } - - private @NonNull Consumer getOverriddenErrorHandler(@Nullable Consumer errorHandler) { - return e -> { - if (errorHandler != null) { - errorHandler.accept(e); - } else { - if (getStreamingErrorHandler() != null) { - getStreamingErrorHandler().accept(e); - } - } - }; - } - - private void consume(ClientEvent clientEvent, AgentCard agentCard, @Nullable List> consumers) { - if (consumers != null) { - // use specified consumers - for (BiConsumer consumer : consumers) { - consumer.accept(clientEvent, agentCard); - } - } else { - // use configured consumers - consume(clientEvent, agentCard); - } - } - - private MessageSendParams getMessageSendParams(Message request, ClientConfig clientConfig) { - MessageSendConfiguration messageSendConfiguration = createMessageSendConfiguration(clientConfig.getPushNotificationConfig()); - - return new MessageSendParams.Builder() - .message(request) - .configuration(messageSendConfiguration) - .metadata(clientConfig.getMetadata()) - .build(); - } + /** + * Retrieve the AgentCard. + * + * @param context optional client call context for the request (may be {@code null}) + * @return the AgentCard + * @throws A2AClientException if retrieving the agent card fails for any reason + */ + public abstract AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException; + + /** + * Close the transport and release any associated resources. + */ + public abstract void close(); } diff --git a/client/base/src/main/java/io/a2a/client/ClientBuilder.java b/client/base/src/main/java/io/a2a/client/ClientBuilder.java index da61dfa3..accfaaec 100644 --- a/client/base/src/main/java/io/a2a/client/ClientBuilder.java +++ b/client/base/src/main/java/io/a2a/client/ClientBuilder.java @@ -76,14 +76,34 @@ public ClientBuilder clientConfig(@NonNull ClientConfig clientConfig) { return this; } - public Client build() throws A2AClientException { + /** + * Keep this method to maintain backward compatibility. + * @return A synchronous version of the A2AClient + * @throws A2AClientException + * @deprecated Instead use {@link ClientBuilder#sync()} + */ + public SyncClient build() throws A2AClientException { + return sync(); + } + + public SyncClient sync() throws A2AClientException { + if (this.clientConfig == null) { + this.clientConfig = new ClientConfig.Builder().build(); + } + + ClientTransport clientTransport = buildClientTransport(); + + return new SyncClient(agentCard, clientConfig, clientTransport, consumers, streamErrorHandler); + } + + public AsyncClient async() throws A2AClientException { if (this.clientConfig == null) { this.clientConfig = new ClientConfig.Builder().build(); } ClientTransport clientTransport = buildClientTransport(); - return new Client(agentCard, clientConfig, clientTransport, consumers, streamErrorHandler); + return new AsyncClient(agentCard, clientConfig, clientTransport, consumers, streamErrorHandler); } @SuppressWarnings("unchecked") diff --git a/client/base/src/main/java/io/a2a/client/SyncClient.java b/client/base/src/main/java/io/a2a/client/SyncClient.java new file mode 100644 index 00000000..b3c52b27 --- /dev/null +++ b/client/base/src/main/java/io/a2a/client/SyncClient.java @@ -0,0 +1,242 @@ +package io.a2a.client; + +import io.a2a.client.config.ClientConfig; +import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.client.transport.spi.interceptors.ClientCallContext; +import io.a2a.spec.*; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +public class SyncClient extends Client { + + SyncClient(AgentCard agentCard, ClientConfig clientConfig, ClientTransport clientTransport, + List> consumers, @Nullable Consumer streamingErrorHandler) { + super(agentCard, clientConfig, clientTransport, consumers, streamingErrorHandler); + } + + public static ClientBuilder builder(AgentCard agentCard) { + return new ClientBuilder(agentCard); + } + + public void sendMessage(Message request, @Nullable ClientCallContext context) throws A2AClientException { + MessageSendParams messageSendParams = getMessageSendParams(request, clientConfig); + sendMessage(messageSendParams, null, null, context); + } + + public void sendMessage(Message request, List> consumers, + Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException { + MessageSendParams messageSendParams = getMessageSendParams(request, clientConfig); + sendMessage(messageSendParams, consumers, streamingErrorHandler, context); + } + + public void sendMessage(Message request, PushNotificationConfig pushNotificationConfiguration, + Map metatadata, @Nullable ClientCallContext context) throws A2AClientException { + MessageSendConfiguration messageSendConfiguration = createMessageSendConfiguration(pushNotificationConfiguration); + + MessageSendParams messageSendParams = new MessageSendParams.Builder() + .message(request) + .configuration(messageSendConfiguration) + .metadata(metatadata) + .build(); + + sendMessage(messageSendParams, null, null, context); + } + + public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException { + try { + return clientTransport.getTask(request, context).get(); + } catch (ExecutionException | InterruptedException e) { + throw new A2AClientException("Failed to get task: " + e, e); + } + } + + public Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { + try { + return clientTransport.cancelTask(request, context).get(); + } catch (ExecutionException | InterruptedException e) { + throw new A2AClientException("Failed to cancel task: " + e, e); + } + } + + public TaskPushNotificationConfig setTaskPushNotificationConfiguration( + TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException { + try { + return clientTransport.setTaskPushNotificationConfiguration(request, context).get(); + } catch (ExecutionException | InterruptedException e) { + throw new A2AClientException("Failed to set task push notification config: " + e, e); + } + } + + public TaskPushNotificationConfig getTaskPushNotificationConfiguration( + GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + try { + return clientTransport.getTaskPushNotificationConfiguration(request, context).get(); + } catch (ExecutionException | InterruptedException e) { + throw new A2AClientException("Failed to get task push notification config: " + e, e); + } + } + + public List listTaskPushNotificationConfigurations( + ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + try { + return clientTransport.listTaskPushNotificationConfigurations(request, context).get(); + } catch (ExecutionException | InterruptedException e) { + throw new A2AClientException("Failed to list task push notification configs: " + e, e); + } + } + + public void deleteTaskPushNotificationConfigurations( + DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + try { + clientTransport.deleteTaskPushNotificationConfigurations(request, context).get(); + } catch (ExecutionException | InterruptedException e) { + throw new A2AClientException("Failed to delete task push notification config: " + e, e); + } + } + + public void resubscribe(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { + resubscribeToTask(request, null, null, context); + } + + @Override + public void resubscribe(TaskIdParams request, @Nullable List> consumers, + @Nullable Consumer streamingErrorHandler, @Nullable ClientCallContext context) throws A2AClientException { + resubscribeToTask(request, consumers, streamingErrorHandler, context); + } + + @Override + public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { + try { + agentCard = clientTransport.getAgentCard(context).get(); + return agentCard; + } catch (ExecutionException | InterruptedException e) { + throw new A2AClientException("Failed to get agend card: " + e, e); + } + } + + @Override + public void close() { + clientTransport.close(); + } + + private ClientEvent getClientEvent(StreamingEventKind event, ClientTaskManager taskManager) throws A2AClientError { + if (event instanceof Message message) { + return new MessageEvent(message); + } else if (event instanceof Task task) { + taskManager.saveTaskEvent(task); + return new TaskEvent(taskManager.getCurrentTask()); + } else if (event instanceof TaskStatusUpdateEvent updateEvent) { + taskManager.saveTaskEvent(updateEvent); + return new TaskUpdateEvent(taskManager.getCurrentTask(), updateEvent); + } else if (event instanceof TaskArtifactUpdateEvent updateEvent) { + taskManager.saveTaskEvent(updateEvent); + return new TaskUpdateEvent(taskManager.getCurrentTask(), updateEvent); + } else { + throw new A2AClientInvalidStateError("Invalid client event"); + } + } + + private MessageSendConfiguration createMessageSendConfiguration(@Nullable PushNotificationConfig pushNotificationConfig) { + return new MessageSendConfiguration.Builder() + .acceptedOutputModes(clientConfig.getAcceptedOutputModes()) + .blocking(!clientConfig.isPolling()) + .historyLength(clientConfig.getHistoryLength()) + .pushNotificationConfig(pushNotificationConfig) + .build(); + } + + private void sendMessage(MessageSendParams messageSendParams, @Nullable List> consumers, + @Nullable Consumer errorHandler, @Nullable ClientCallContext context) throws A2AClientException { + if (! clientConfig.isStreaming() || ! agentCard.capabilities().streaming()) { + try { + EventKind eventKind = clientTransport.sendMessage(messageSendParams, context).get(); + ClientEvent clientEvent; + if (eventKind instanceof Task task) { + clientEvent = new TaskEvent(task); + } else { + // must be a message + clientEvent = new MessageEvent((Message) eventKind); + } + consume(clientEvent, agentCard, consumers); + + } catch (InterruptedException | ExecutionException e) { + if (e.getCause() instanceof A2AClientException) { + throw (A2AClientException) e.getCause(); + } + + throw new A2AClientException("Unable to send message", e); + } + } else { + ClientTaskManager tracker = new ClientTaskManager(); + Consumer overriddenErrorHandler = getOverriddenErrorHandler(errorHandler); + Consumer eventHandler = event -> { + try { + ClientEvent clientEvent = getClientEvent(event, tracker); + consume(clientEvent, agentCard, consumers); + } catch (A2AClientError e) { + overriddenErrorHandler.accept(e); + } + }; + clientTransport.sendMessageStreaming(messageSendParams, eventHandler, overriddenErrorHandler, context); + } + } + + private void resubscribeToTask(TaskIdParams request, @Nullable List> consumers, + @Nullable Consumer errorHandler, @Nullable ClientCallContext context) throws A2AClientException { + if (! clientConfig.isStreaming() || ! agentCard.capabilities().streaming()) { + throw new A2AClientException("Client and/or server does not support resubscription"); + } + ClientTaskManager tracker = new ClientTaskManager(); + Consumer overriddenErrorHandler = getOverriddenErrorHandler(errorHandler); + Consumer eventHandler = event -> { + try { + ClientEvent clientEvent = getClientEvent(event, tracker); + consume(clientEvent, agentCard, consumers); + } catch (A2AClientError e) { + overriddenErrorHandler.accept(e); + } + }; + clientTransport.resubscribe(request, eventHandler, overriddenErrorHandler, context); + } + + private @NonNull Consumer getOverriddenErrorHandler(@Nullable Consumer errorHandler) { + return e -> { + if (errorHandler != null) { + errorHandler.accept(e); + } else { + if (getStreamingErrorHandler() != null) { + getStreamingErrorHandler().accept(e); + } + } + }; + } + + private void consume(ClientEvent clientEvent, AgentCard agentCard, @Nullable List> consumers) { + if (consumers != null) { + // use specified consumers + for (BiConsumer consumer : consumers) { + consumer.accept(clientEvent, agentCard); + } + } else { + // use configured consumers + consume(clientEvent, agentCard); + } + } + + private MessageSendParams getMessageSendParams(Message request, ClientConfig clientConfig) { + MessageSendConfiguration messageSendConfiguration = createMessageSendConfiguration(clientConfig.getPushNotificationConfig()); + + return new MessageSendParams.Builder() + .message(request) + .configuration(messageSendConfiguration) + .metadata(clientConfig.getMetadata()) + .build(); + } +} diff --git a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java index d1943f27..e6ff4c0a 100644 --- a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java +++ b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java @@ -8,9 +8,14 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Collectors; +import com.google.protobuf.Empty; import io.a2a.client.transport.spi.AbstractClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; @@ -25,25 +30,10 @@ import io.a2a.grpc.GetTaskRequest; import io.a2a.grpc.ListTaskPushNotificationConfigRequest; import io.a2a.grpc.SendMessageRequest; -import io.a2a.grpc.SendMessageResponse; import io.a2a.grpc.StreamResponse; import io.a2a.grpc.TaskSubscriptionRequest; -import io.a2a.spec.A2AClientException; -import io.a2a.spec.AgentCard; -import io.a2a.spec.DeleteTaskPushNotificationConfigParams; -import io.a2a.spec.EventKind; -import io.a2a.spec.GetTaskPushNotificationConfigParams; -import io.a2a.spec.ListTaskPushNotificationConfigParams; -import io.a2a.spec.MessageSendParams; -import io.a2a.spec.SendStreamingMessageRequest; -import io.a2a.spec.SetTaskPushNotificationConfigRequest; -import io.a2a.spec.StreamingEventKind; -import io.a2a.spec.Task; -import io.a2a.spec.TaskIdParams; -import io.a2a.spec.TaskPushNotificationConfig; -import io.a2a.spec.TaskQueryParams; -import io.a2a.spec.TaskResubscriptionRequest; +import io.a2a.spec.*; import io.grpc.Channel; import io.grpc.Metadata; import io.grpc.StatusRuntimeException; @@ -75,26 +65,30 @@ public GrpcTransport(Channel channel, AgentCard agentCard, List sendMessage(MessageSendParams request, ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, sendMessageRequest, agentCard, context); - try { - A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders); - SendMessageResponse response = stubWithMetadata.sendMessage(sendMessageRequest); - if (response.hasMsg()) { - return FromProto.message(response.getMsg()); - } else if (response.hasTask()) { - return FromProto.task(response.getTask()); - } else { - throw new A2AClientException("Server response did not contain a message or task"); - } - } catch (StatusRuntimeException e) { - throw GrpcErrorMapper.mapGrpcError(e, "Failed to send message: "); - } + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders); + SingleValueStreamObserver observer = new SingleValueStreamObserver<>(); + stubWithMetadata.sendMessage(sendMessageRequest, observer); + + return observer.completionStage() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(io.a2a.grpc.SendMessageResponse response) { + if (response.hasMsg()) { + return CompletableFuture.completedFuture(FromProto.message(response.getMsg())); + } else if (response.hasTask()) { + return CompletableFuture.completedFuture(FromProto.task(response.getTask())); + } else { + return CompletableFuture.failedFuture(new A2AClientException("Server response did not contain a message or task")); + } + } + }).toCompletableFuture(); } @Override @@ -116,7 +110,7 @@ public void sendMessageStreaming(MessageSendParams request, Consumer getTask(TaskQueryParams request, ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskRequest.Builder requestBuilder = GetTaskRequest.newBuilder(); @@ -128,16 +122,17 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, getTaskRequest, agentCard, context); - try { - A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders); - return FromProto.task(stubWithMetadata.getTask(getTaskRequest)); - } catch (StatusRuntimeException e) { - throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task: "); - } + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders); + SingleValueStreamObserver observer = new SingleValueStreamObserver<>(); + stubWithMetadata.getTask(getTaskRequest, observer); + + return observer.completionStage() + .thenCompose(task -> CompletableFuture.completedFuture(FromProto.task(task))) + .toCompletableFuture(); } @Override - public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A2AClientException { + public CompletableFuture cancelTask(TaskIdParams request, ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); CancelTaskRequest cancelTaskRequest = CancelTaskRequest.newBuilder() @@ -146,16 +141,17 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.CancelTaskRequest.METHOD, cancelTaskRequest, agentCard, context); - try { - A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders); - return FromProto.task(stubWithMetadata.cancelTask(cancelTaskRequest)); - } catch (StatusRuntimeException e) { - throw GrpcErrorMapper.mapGrpcError(e, "Failed to cancel task: "); - } + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders); + SingleValueStreamObserver observer = new SingleValueStreamObserver<>(); + stubWithMetadata.cancelTask(cancelTaskRequest, observer); + + return observer.completionStage() + .thenCompose(task -> CompletableFuture.completedFuture(FromProto.task(task))) + .toCompletableFuture(); } @Override - public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, + public CompletableFuture setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); @@ -168,16 +164,17 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN PayloadAndHeaders payloadAndHeaders = applyInterceptors(SetTaskPushNotificationConfigRequest.METHOD, grpcRequest, agentCard, context); - try { - A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders); - return FromProto.taskPushNotificationConfig(stubWithMetadata.createTaskPushNotificationConfig(grpcRequest)); - } catch (StatusRuntimeException e) { - throw GrpcErrorMapper.mapGrpcError(e, "Failed to create task push notification config: "); - } + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders); + SingleValueStreamObserver observer = new SingleValueStreamObserver<>(); + stubWithMetadata.createTaskPushNotificationConfig(grpcRequest, observer); + + return observer.completionStage() + .thenCompose(taskPushNotificationConfig -> CompletableFuture.completedFuture(FromProto.taskPushNotificationConfig(taskPushNotificationConfig))) + .toCompletableFuture(); } @Override - public TaskPushNotificationConfig getTaskPushNotificationConfiguration( + public CompletableFuture getTaskPushNotificationConfiguration( GetTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); @@ -188,16 +185,17 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration( PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskPushNotificationConfigRequest.METHOD, grpcRequest, agentCard, context); - try { - A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders); - return FromProto.taskPushNotificationConfig(stubWithMetadata.getTaskPushNotificationConfig(grpcRequest)); - } catch (StatusRuntimeException e) { - throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task push notification config: "); - } + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders); + SingleValueStreamObserver observer = new SingleValueStreamObserver<>(); + stubWithMetadata.getTaskPushNotificationConfig(grpcRequest, observer); + + return observer.completionStage() + .thenCompose(taskPushNotificationConfig -> CompletableFuture.completedFuture(FromProto.taskPushNotificationConfig(taskPushNotificationConfig))) + .toCompletableFuture(); } @Override - public List listTaskPushNotificationConfigurations( + public CompletableFuture> listTaskPushNotificationConfigurations( ListTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); @@ -208,18 +206,35 @@ public List listTaskPushNotificationConfigurations( PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.ListTaskPushNotificationConfigRequest.METHOD, grpcRequest, agentCard, context); + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders); + SingleValueStreamObserver observer = new SingleValueStreamObserver<>(); + stubWithMetadata.listTaskPushNotificationConfig(grpcRequest, observer); + + return observer.completionStage() + .thenCompose(new Function>>() { + @Override + public CompletionStage> apply(io.a2a.grpc.ListTaskPushNotificationConfigResponse listTaskPushNotificationConfigResponse) { + return CompletableFuture.completedFuture( + listTaskPushNotificationConfigResponse.getConfigsList().stream() + .map(FromProto::taskPushNotificationConfig).collect(Collectors.toList())); + } + }) + .toCompletableFuture(); + + /* try { A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders); return stubWithMetadata.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream() .map(FromProto::taskPushNotificationConfig) - .collect(Collectors.toList()); + ; } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to list task push notification config: "); } + */ } @Override - public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, + public CompletableFuture deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); @@ -229,12 +244,13 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD, grpcRequest, agentCard, context); - try { - A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders); - stubWithMetadata.deleteTaskPushNotificationConfig(grpcRequest); - } catch (StatusRuntimeException e) { - throw GrpcErrorMapper.mapGrpcError(e, "Failed to delete task push notification config: "); - } + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders); + SingleValueStreamObserver observer = new SingleValueStreamObserver<>(); + stubWithMetadata.deleteTaskPushNotificationConfig(grpcRequest, observer); + + return observer.completionStage() + .thenApply((Function) empty -> null) + .toCompletableFuture(); } @Override @@ -260,9 +276,9 @@ public void resubscribe(TaskIdParams request, Consumer event } @Override - public AgentCard getAgentCard(ClientCallContext context) throws A2AClientException { + public CompletableFuture getAgentCard(ClientCallContext context) throws A2AClientException { // TODO: Determine how to handle retrieving the authenticated extended agent card - return agentCard; + return CompletableFuture.completedFuture(agentCard); } @Override diff --git a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/SingleValueStreamObserver.java b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/SingleValueStreamObserver.java new file mode 100644 index 00000000..187ca289 --- /dev/null +++ b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/SingleValueStreamObserver.java @@ -0,0 +1,68 @@ +package io.a2a.client.transport.grpc; + +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; + +/** + * A simple {@link StreamObserver} adapter class that completes + * a {@link CompletableFuture} when the observer is completed. + *

+ * This observer uses the value passed to its {@link #onNext(Object)} method to complete + * the {@link CompletableFuture}. + *

+ * This observer should only be used in cases where a single result is expected. If more + * that one call is made to {@link #onNext(Object)} then future will be completed with + * an exception. + * + * @param The type of objects received in this stream. + */ +public class SingleValueStreamObserver implements StreamObserver { + + private int count; + + private T result; + + private final CompletableFuture resultFuture = new CompletableFuture<>(); + + /** + * Create a SingleValueStreamObserver. + */ + public SingleValueStreamObserver() { + } + + /** + * Obtain the {@link CompletableFuture} that will be completed + * when the {@link StreamObserver} completes. + * + * @return The CompletableFuture + */ + public CompletionStage completionStage() { + return resultFuture; + } + + @Override + public void onNext(T value) { + if (count++ == 0) { + result = value; + } else { + resultFuture.completeExceptionally(new IllegalStateException("More than one result received.")); + } + } + + @Override + public void onError(Throwable t) { + if (t instanceof StatusRuntimeException) { + resultFuture.completeExceptionally(GrpcErrorMapper.mapGrpcError((StatusRuntimeException) t)); + } else { + resultFuture.completeExceptionally(t); + } + } + + @Override + public void onCompleted() { + resultFuture.complete(result); + } +} diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java index a1ff52bb..77b9dcd2 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java @@ -1,12 +1,14 @@ package io.a2a.client.transport.jsonrpc; import static io.a2a.util.Assert.checkNotNullParam; +import static java.net.HttpURLConnection.HTTP_FORBIDDEN; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; import java.io.IOException; import java.net.URI; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.CompletionStage; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -20,6 +22,7 @@ import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; import io.a2a.client.http.HttpClient; import io.a2a.client.http.HttpResponse; +import io.a2a.common.A2AErrorMessages; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -59,12 +62,16 @@ import io.a2a.client.transport.jsonrpc.sse.SSEEventListener; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.logging.Logger; import io.a2a.util.Utils; import org.jspecify.annotations.Nullable; public class JSONRPCTransport extends AbstractClientTransport { + private static final Logger log = Logger.getLogger(JSONRPCTransport.class.getName()); + private static final TypeReference SEND_MESSAGE_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference GET_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference CANCEL_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; @@ -102,7 +109,7 @@ public JSONRPCTransport(@Nullable HttpClient httpClient, @Nullable AgentCard age } @Override - public EventKind sendMessage(MessageSendParams request, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture sendMessage(MessageSendParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); SendMessageRequest sendMessageRequest = new SendMessageRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -113,13 +120,21 @@ public EventKind sendMessage(MessageSendParams request, @Nullable ClientCallCont PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendMessageRequest.METHOD, sendMessageRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - SendMessageResponse response = unmarshalResponse(httpResponseBody, SEND_MESSAGE_RESPONSE_REFERENCE); - return response.getResult(); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to send message: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(unmarshalResponse(httpResponseBody, SEND_MESSAGE_RESPONSE_REFERENCE).getResult()); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to send message: " + e, e)); + } + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } } @@ -160,7 +175,7 @@ public void accept(HttpResponse httpResponse, Throwable throwable) { } @Override - public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskRequest getTaskRequest = new GetTaskRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -171,18 +186,26 @@ public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context PayloadAndHeaders payloadAndHeaders = applyInterceptors(GetTaskRequest.METHOD, getTaskRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - GetTaskResponse response = unmarshalResponse(httpResponseBody, GET_TASK_RESPONSE_REFERENCE); - return response.getResult(); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to get task: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(unmarshalResponse(httpResponseBody, GET_TASK_RESPONSE_REFERENCE).getResult()); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get task: " + e, e)); + } + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } } @Override - public Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); CancelTaskRequest cancelTaskRequest = new CancelTaskRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -193,18 +216,26 @@ public Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context PayloadAndHeaders payloadAndHeaders = applyInterceptors(CancelTaskRequest.METHOD, cancelTaskRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - CancelTaskResponse response = unmarshalResponse(httpResponseBody, CANCEL_TASK_RESPONSE_REFERENCE); - return response.getResult(); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to cancel task: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(unmarshalResponse(httpResponseBody, CANCEL_TASK_RESPONSE_REFERENCE).getResult()); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to cancel task: " + e, e)); + } + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } } @Override - public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, + public CompletableFuture setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); SetTaskPushNotificationConfigRequest setTaskPushNotificationRequest = new SetTaskPushNotificationConfigRequest.Builder() @@ -217,19 +248,26 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN setTaskPushNotificationRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - SetTaskPushNotificationConfigResponse response = unmarshalResponse(httpResponseBody, - SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); - return response.getResult(); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to set task push notification config: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(unmarshalResponse(httpResponseBody, SET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE).getResult()); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to set task push notification config: " + e, e)); + } + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } } @Override - public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, + public CompletableFuture getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskPushNotificationConfigRequest getTaskPushNotificationRequest = new GetTaskPushNotificationConfigRequest.Builder() @@ -242,19 +280,26 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu getTaskPushNotificationRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - GetTaskPushNotificationConfigResponse response = unmarshalResponse(httpResponseBody, - GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); - return response.getResult(); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to get task push notification config: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(unmarshalResponse(httpResponseBody, GET_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE).getResult()); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get task push notification config: " + e, e)); + } + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } } @Override - public List listTaskPushNotificationConfigurations( + public CompletableFuture> listTaskPushNotificationConfigurations( ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); @@ -268,19 +313,26 @@ public List listTaskPushNotificationConfigurations( listTaskPushNotificationRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - ListTaskPushNotificationConfigResponse response = unmarshalResponse(httpResponseBody, - LIST_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); - return response.getResult(); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to list task push notification configs: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>>() { + @Override + public CompletionStage> apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(unmarshalResponse(httpResponseBody, LIST_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE).getResult()); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to list task push notification configs: " + e, e)); + } + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } } @Override - public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, + public CompletableFuture deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); DeleteTaskPushNotificationConfigRequest deleteTaskPushNotificationRequest = new DeleteTaskPushNotificationConfigRequest.Builder() @@ -293,12 +345,22 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC deleteTaskPushNotificationRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - unmarshalResponse(httpResponseBody, DELETE_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to delete task push notification configs: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + unmarshalResponse(httpResponseBody, DELETE_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE); + return CompletableFuture.completedFuture(null); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to delete task push notification configs: " + e, e)); + } + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } } @@ -340,15 +402,18 @@ public void accept(HttpResponse httpResponse, Throwable throwable) { } @Override - public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { - try { + public CompletableFuture getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { if (agentCard == null) { - A2ACardResolver resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); - agentCard = resolver.getAgentCard(); - needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); + try { + A2ACardResolver resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); + agentCard = resolver.getAgentCard(); + needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); + } catch (A2AClientError e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get agent card: " + e, e)); + } } if (!needsExtendedCard) { - return agentCard; + return CompletableFuture.completedFuture(agentCard); } GetAuthenticatedExtendedCardRequest getExtendedAgentCardRequest = new GetAuthenticatedExtendedCardRequest.Builder() @@ -360,18 +425,28 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli getExtendedAgentCardRequest, agentCard, context); try { - String httpResponseBody = sendPostRequest(payloadAndHeaders); - GetAuthenticatedExtendedCardResponse response = unmarshalResponse(httpResponseBody, - GET_AUTHENTICATED_EXTENDED_CARD_RESPONSE_REFERENCE); - agentCard = response.getResult(); - needsExtendedCard = false; - return agentCard; - } catch (IOException | InterruptedException e) { - throw new A2AClientException("Failed to get authenticated extended agent card: " + e, e); + return sendPostRequest(payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(unmarshalResponse(httpResponseBody, GET_AUTHENTICATED_EXTENDED_CARD_RESPONSE_REFERENCE).getResult()); + } catch (A2AClientException e) { + return CompletableFuture.failedFuture(e); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get authenticated extended agent card: " + e, e)); + } + } + }).whenComplete(new BiConsumer() { + @Override + public void accept(AgentCard agentCard, Throwable throwable) { + JSONRPCTransport.this.agentCard = agentCard; + needsExtendedCard = false; + } + }); + } catch (IOException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to prepare request: " + e, e)); } - } catch(A2AClientError e){ - throw new A2AClientException("Failed to get agent card: " + e, e); - } } @Override @@ -379,12 +454,31 @@ public void close() { // no-op } - private String sendPostRequest(PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { - HttpClient.PostRequestBuilder builder = createPostBuilder(payloadAndHeaders); + private CompletableFuture sendPostRequest(PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException { + return createPostBuilder(payloadAndHeaders) + .send() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(HttpResponse response) { + if (!response.success()) { + log.fine("Error on POST processing " + payloadAndHeaders.getPayload()); + if (response.statusCode() == HTTP_UNAUTHORIZED) { + return CompletableFuture.failedStage(new A2AClientException(A2AErrorMessages.AUTHENTICATION_FAILED)); + } else if (response.statusCode() == HTTP_FORBIDDEN) { + return CompletableFuture.failedStage(new A2AClientException(A2AErrorMessages.AUTHORIZATION_FAILED)); + } + + return CompletableFuture.failedFuture(new A2AClientException("Request failed " + response.statusCode())); + } + + return response.body(); + } + }); + /* try { HttpResponse response = builder.send().get(); if (!response.success()) { - throw new IOException("Request failed " + response.statusCode()); + throw ; } return response.body(); @@ -394,6 +488,7 @@ private String sendPostRequest(PayloadAndHeaders payloadAndHeaders) throws IOExc } throw new IOException("Failed to send request", e.getCause()); } + */ } private HttpClient.PostRequestBuilder createPostBuilder(PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException { diff --git a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java index 25de3294..f5b8ba32 100644 --- a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java +++ b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportTest.java @@ -24,12 +24,7 @@ import static io.a2a.client.transport.jsonrpc.JsonMessages.SEND_MESSAGE_WITH_MIXED_PARTS_TEST_RESPONSE; import static io.a2a.client.transport.jsonrpc.JsonMessages.SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_REQUEST; import static io.a2a.client.transport.jsonrpc.JsonMessages.SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.*; import static org.mockserver.model.HttpRequest.request; import static org.mockserver.model.HttpResponse.response; @@ -37,6 +32,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -103,6 +101,8 @@ public void testA2AClientSendMessage() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); + CountDownLatch latch = new CountDownLatch(1); + Message message = new Message.Builder() .role(Message.Role.USER) .parts(Collections.singletonList(new TextPart("tell me a joke"))) @@ -118,21 +118,32 @@ public void testA2AClientSendMessage() throws Exception { .configuration(configuration) .build(); - EventKind result = client.sendMessage(params, null); - assertInstanceOf(Task.class, result); - Task task = (Task) result; - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertNotNull(task.getContextId()); - assertEquals(TaskState.COMPLETED,task.getStatus().state()); - assertEquals(1, task.getArtifacts().size()); - Artifact artifact = task.getArtifacts().get(0); - assertEquals("artifact-1", artifact.artifactId()); - assertEquals("joke", artifact.name()); - assertEquals(1, artifact.parts().size()); - Part part = artifact.parts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); - assertTrue(task.getMetadata().isEmpty()); + client.sendMessage(params, null) + .whenComplete(new BiConsumer() { + @Override + public void accept(EventKind result, Throwable throwable) { + assertNull(throwable); + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED,task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("joke", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -151,6 +162,8 @@ public void testA2AClientSendMessageWithMessageResponse() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); + CountDownLatch latch = new CountDownLatch(1); + Message message = new Message.Builder() .role(Message.Role.USER) .parts(Collections.singletonList(new TextPart("tell me a joke"))) @@ -166,14 +179,26 @@ public void testA2AClientSendMessageWithMessageResponse() throws Exception { .configuration(configuration) .build(); - EventKind result = client.sendMessage(params, null); - assertInstanceOf(Message.class, result); - Message agentMessage = (Message) result; - assertEquals(Message.Role.AGENT, agentMessage.getRole()); - Part part = agentMessage.getParts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); - assertEquals("msg-456", agentMessage.getMessageId()); + client.sendMessage(params, null) + .whenComplete(new BiConsumer() { + @Override + public void accept(EventKind result, Throwable throwable) { + assertNull(throwable); + + assertInstanceOf(Message.class, result); + Message agentMessage = (Message) result; + assertEquals(Message.Role.AGENT, agentMessage.getRole()); + Part part = agentMessage.getParts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + assertEquals("msg-456", agentMessage.getMessageId()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @@ -193,6 +218,8 @@ public void testA2AClientSendMessageWithError() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); + CountDownLatch latch = new CountDownLatch(1); + Message message = new Message.Builder() .role(Message.Role.USER) .parts(Collections.singletonList(new TextPart("tell me a joke"))) @@ -208,12 +235,21 @@ public void testA2AClientSendMessageWithError() throws Exception { .configuration(configuration) .build(); - try { - client.sendMessage(params, null); - fail(); // should not reach here - } catch (A2AClientException e) { - assertTrue(e.getMessage().contains("Invalid parameters: Hello world")); - } + client.sendMessage(params, null) + .whenComplete(new BiConsumer() { + @Override + public void accept(EventKind eventKind, Throwable throwable) { + assertNull(eventKind); + + assertNotNull(throwable); + assertTrue(throwable.getMessage().contains("Invalid parameters: Hello world")); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -232,41 +268,55 @@ public void testA2AClientGetTask() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); - Task task = client.getTask(new TaskQueryParams("de38c76d-d54c-436c-8b9f-4c2703648d64", - 10), null); - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertEquals("c295ea44-7543-4f78-b524-7a38915ad6e4", task.getContextId()); - assertEquals(TaskState.COMPLETED, task.getStatus().state()); - assertEquals(1, task.getArtifacts().size()); - Artifact artifact = task.getArtifacts().get(0); - assertEquals(1, artifact.parts().size()); - assertEquals("artifact-1", artifact.artifactId()); - Part part = artifact.parts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); - assertTrue(task.getMetadata().isEmpty()); - List history = task.getHistory(); - assertNotNull(history); - assertEquals(1, history.size()); - Message message = history.get(0); - assertEquals(Message.Role.USER, message.getRole()); - List> parts = message.getParts(); - assertNotNull(parts); - assertEquals(3, parts.size()); - part = parts.get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("tell me a joke", ((TextPart)part).getText()); - part = parts.get(1); - assertEquals(Part.Kind.FILE, part.getKind()); - FileContent filePart = ((FilePart) part).getFile(); - assertEquals("file:///path/to/file.txt", ((FileWithUri) filePart).uri()); - assertEquals("text/plain", filePart.mimeType()); - part = parts.get(2); - assertEquals(Part.Kind.FILE, part.getKind()); - filePart = ((FilePart) part).getFile(); - assertEquals("aGVsbG8=", ((FileWithBytes) filePart).bytes()); - assertEquals("hello.txt", filePart.name()); - assertTrue(task.getMetadata().isEmpty()); + CountDownLatch latch = new CountDownLatch(1); + + client.getTask(new TaskQueryParams("de38c76d-d54c-436c-8b9f-4c2703648d64", + 10), null) + .whenComplete(new BiConsumer() { + @Override + public void accept(Task task, Throwable throwable) { + assertNull(throwable); + + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals("c295ea44-7543-4f78-b524-7a38915ad6e4", task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals(1, artifact.parts().size()); + assertEquals("artifact-1", artifact.artifactId()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + List history = task.getHistory(); + assertNotNull(history); + assertEquals(1, history.size()); + Message message = history.get(0); + assertEquals(Message.Role.USER, message.getRole()); + List> parts = message.getParts(); + assertNotNull(parts); + assertEquals(3, parts.size()); + part = parts.get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("tell me a joke", ((TextPart)part).getText()); + part = parts.get(1); + assertEquals(Part.Kind.FILE, part.getKind()); + FileContent filePart = ((FilePart) part).getFile(); + assertEquals("file:///path/to/file.txt", ((FileWithUri) filePart).uri()); + assertEquals("text/plain", filePart.mimeType()); + part = parts.get(2); + assertEquals(Part.Kind.FILE, part.getKind()); + filePart = ((FilePart) part).getFile(); + assertEquals("aGVsbG8=", ((FileWithBytes) filePart).bytes()); + assertEquals("hello.txt", filePart.name()); + assertTrue(task.getMetadata().isEmpty()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -285,12 +335,26 @@ public void testA2AClientCancelTask() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); - Task task = client.cancelTask(new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", - new HashMap<>()), null); - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertEquals("c295ea44-7543-4f78-b524-7a38915ad6e4", task.getContextId()); - assertEquals(TaskState.CANCELED, task.getStatus().state()); - assertTrue(task.getMetadata().isEmpty()); + CountDownLatch latch = new CountDownLatch(1); + + client.cancelTask(new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", + new HashMap<>()), null) + .whenComplete(new BiConsumer() { + @Override + public void accept(Task task, Throwable throwable) { + assertNull(throwable); + + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals("c295ea44-7543-4f78-b524-7a38915ad6e4", task.getContextId()); + assertEquals(TaskState.CANCELED, task.getStatus().state()); + assertTrue(task.getMetadata().isEmpty()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -309,15 +373,31 @@ public void testA2AClientGetTaskPushNotificationConfig() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); - TaskPushNotificationConfig taskPushNotificationConfig = client.getTaskPushNotificationConfiguration( + CountDownLatch latch = new CountDownLatch(1); + + client.getTaskPushNotificationConfiguration( new GetTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", null, - new HashMap<>()), null); - PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); - assertNotNull(pushNotificationConfig); - assertEquals("https://example.com/callback", pushNotificationConfig.url()); - PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); - assertTrue(authenticationInfo.schemes().size() == 1); - assertEquals("jwt", authenticationInfo.schemes().get(0)); + new HashMap<>()), null) + .whenComplete(new BiConsumer() { + @Override + public void accept(TaskPushNotificationConfig taskPushNotificationConfig, Throwable throwable) { + assertNull(throwable); + + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertTrue(authenticationInfo.schemes().size() == 1); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + + + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -336,19 +416,33 @@ public void testA2AClientSetTaskPushNotificationConfig() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); - TaskPushNotificationConfig taskPushNotificationConfig = client.setTaskPushNotificationConfiguration( + CountDownLatch latch = new CountDownLatch(1); + + client.setTaskPushNotificationConfiguration( new TaskPushNotificationConfig("de38c76d-d54c-436c-8b9f-4c2703648d64", new PushNotificationConfig.Builder() .url("https://example.com/callback") .authenticationInfo(new PushNotificationAuthenticationInfo(Collections.singletonList("jwt"), null)) - .build()), null); - PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); - assertNotNull(pushNotificationConfig); - assertEquals("https://example.com/callback", pushNotificationConfig.url()); - PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); - assertEquals(1, authenticationInfo.schemes().size()); - assertEquals("jwt", authenticationInfo.schemes().get(0)); + .build()), null) + .whenComplete(new BiConsumer() { + @Override + public void accept(TaskPushNotificationConfig taskPushNotificationConfig, Throwable throwable) { + assertNull(throwable); + + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertEquals(1, authenticationInfo.schemes().size()); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @@ -366,68 +460,82 @@ public void testA2AClientGetAgentCard() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); - AgentCard agentCard = client.getAgentCard(null); - assertEquals("GeoSpatial Route Planner Agent", agentCard.name()); - assertEquals("Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", agentCard.description()); - assertEquals("https://georoute-agent.example.com/a2a/v1", agentCard.url()); - assertEquals("Example Geo Services Inc.", agentCard.provider().organization()); - assertEquals("https://www.examplegeoservices.com", agentCard.provider().url()); - assertEquals("1.2.0", agentCard.version()); - assertEquals("https://docs.examplegeoservices.com/georoute-agent/api", agentCard.documentationUrl()); - assertTrue(agentCard.capabilities().streaming()); - assertTrue(agentCard.capabilities().pushNotifications()); - assertFalse(agentCard.capabilities().stateTransitionHistory()); - Map securitySchemes = agentCard.securitySchemes(); - assertNotNull(securitySchemes); - OpenIdConnectSecurityScheme google = (OpenIdConnectSecurityScheme) securitySchemes.get("google"); - assertEquals("openIdConnect", google.getType()); - assertEquals("https://accounts.google.com/.well-known/openid-configuration", google.getOpenIdConnectUrl()); - List>> security = agentCard.security(); - assertEquals(1, security.size()); - Map> securityMap = security.get(0); - List scopes = securityMap.get("google"); - List expectedScopes = List.of("openid", "profile", "email"); - assertEquals(expectedScopes, scopes); - List defaultInputModes = List.of("application/json", "text/plain"); - assertEquals(defaultInputModes, agentCard.defaultInputModes()); - List defaultOutputModes = List.of("application/json", "image/png"); - assertEquals(defaultOutputModes, agentCard.defaultOutputModes()); - List skills = agentCard.skills(); - assertEquals("route-optimizer-traffic", skills.get(0).id()); - assertEquals("Traffic-Aware Route Optimizer", skills.get(0).name()); - assertEquals("Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", skills.get(0).description()); - List tags = List.of("maps", "routing", "navigation", "directions", "traffic"); - assertEquals(tags, skills.get(0).tags()); - List examples = List.of("Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", - "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}"); - assertEquals(examples, skills.get(0).examples()); - assertEquals(defaultInputModes, skills.get(0).inputModes()); - List outputModes = List.of("application/json", "application/vnd.geo+json", "text/html"); - assertEquals(outputModes, skills.get(0).outputModes()); - assertEquals("custom-map-generator", skills.get(1).id()); - assertEquals("Personalized Map Generator", skills.get(1).name()); - assertEquals("Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", skills.get(1).description()); - tags = List.of("maps", "customization", "visualization", "cartography"); - assertEquals(tags, skills.get(1).tags()); - examples = List.of("Generate a map of my upcoming road trip with all planned stops highlighted.", - "Show me a map visualizing all coffee shops within a 1-mile radius of my current location."); - assertEquals(examples, skills.get(1).examples()); - List inputModes = List.of("application/json"); - assertEquals(inputModes, skills.get(1).inputModes()); - outputModes = List.of("image/png", "image/jpeg", "application/json", "text/html"); - assertEquals(outputModes, skills.get(1).outputModes()); - assertFalse(agentCard.supportsAuthenticatedExtendedCard()); - assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); - assertEquals("0.2.9", agentCard.protocolVersion()); - assertEquals("JSONRPC", agentCard.preferredTransport()); - List additionalInterfaces = agentCard.additionalInterfaces(); - assertEquals(3, additionalInterfaces.size()); - AgentInterface jsonrpc = new AgentInterface(TransportProtocol.JSONRPC.asString(), "https://georoute-agent.example.com/a2a/v1"); - AgentInterface grpc = new AgentInterface(TransportProtocol.GRPC.asString(), "https://georoute-agent.example.com/a2a/grpc"); - AgentInterface httpJson = new AgentInterface(TransportProtocol.HTTP_JSON.asString(), "https://georoute-agent.example.com/a2a/json"); - assertEquals(jsonrpc, additionalInterfaces.get(0)); - assertEquals(grpc, additionalInterfaces.get(1)); - assertEquals(httpJson, additionalInterfaces.get(2)); + CountDownLatch latch = new CountDownLatch(1); + + client.getAgentCard(null) + .whenComplete(new BiConsumer() { + @Override + public void accept(AgentCard agentCard, Throwable throwable) { + assertNull(throwable); + + assertEquals("GeoSpatial Route Planner Agent", agentCard.name()); + assertEquals("Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", agentCard.description()); + assertEquals("https://georoute-agent.example.com/a2a/v1", agentCard.url()); + assertEquals("Example Geo Services Inc.", agentCard.provider().organization()); + assertEquals("https://www.examplegeoservices.com", agentCard.provider().url()); + assertEquals("1.2.0", agentCard.version()); + assertEquals("https://docs.examplegeoservices.com/georoute-agent/api", agentCard.documentationUrl()); + assertTrue(agentCard.capabilities().streaming()); + assertTrue(agentCard.capabilities().pushNotifications()); + assertFalse(agentCard.capabilities().stateTransitionHistory()); + Map securitySchemes = agentCard.securitySchemes(); + assertNotNull(securitySchemes); + OpenIdConnectSecurityScheme google = (OpenIdConnectSecurityScheme) securitySchemes.get("google"); + assertEquals("openIdConnect", google.getType()); + assertEquals("https://accounts.google.com/.well-known/openid-configuration", google.getOpenIdConnectUrl()); + List>> security = agentCard.security(); + assertEquals(1, security.size()); + Map> securityMap = security.get(0); + List scopes = securityMap.get("google"); + List expectedScopes = List.of("openid", "profile", "email"); + assertEquals(expectedScopes, scopes); + List defaultInputModes = List.of("application/json", "text/plain"); + assertEquals(defaultInputModes, agentCard.defaultInputModes()); + List defaultOutputModes = List.of("application/json", "image/png"); + assertEquals(defaultOutputModes, agentCard.defaultOutputModes()); + List skills = agentCard.skills(); + assertEquals("route-optimizer-traffic", skills.get(0).id()); + assertEquals("Traffic-Aware Route Optimizer", skills.get(0).name()); + assertEquals("Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", skills.get(0).description()); + List tags = List.of("maps", "routing", "navigation", "directions", "traffic"); + assertEquals(tags, skills.get(0).tags()); + List examples = List.of("Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}"); + assertEquals(examples, skills.get(0).examples()); + assertEquals(defaultInputModes, skills.get(0).inputModes()); + List outputModes = List.of("application/json", "application/vnd.geo+json", "text/html"); + assertEquals(outputModes, skills.get(0).outputModes()); + assertEquals("custom-map-generator", skills.get(1).id()); + assertEquals("Personalized Map Generator", skills.get(1).name()); + assertEquals("Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", skills.get(1).description()); + tags = List.of("maps", "customization", "visualization", "cartography"); + assertEquals(tags, skills.get(1).tags()); + examples = List.of("Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location."); + assertEquals(examples, skills.get(1).examples()); + List inputModes = List.of("application/json"); + assertEquals(inputModes, skills.get(1).inputModes()); + outputModes = List.of("image/png", "image/jpeg", "application/json", "text/html"); + assertEquals(outputModes, skills.get(1).outputModes()); + assertFalse(agentCard.supportsAuthenticatedExtendedCard()); + assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); + assertEquals("0.2.9", agentCard.protocolVersion()); + assertEquals("JSONRPC", agentCard.preferredTransport()); + List additionalInterfaces = agentCard.additionalInterfaces(); + assertEquals(3, additionalInterfaces.size()); + AgentInterface jsonrpc = new AgentInterface(TransportProtocol.JSONRPC.asString(), "https://georoute-agent.example.com/a2a/v1"); + AgentInterface grpc = new AgentInterface(TransportProtocol.GRPC.asString(), "https://georoute-agent.example.com/a2a/grpc"); + AgentInterface httpJson = new AgentInterface(TransportProtocol.HTTP_JSON.asString(), "https://georoute-agent.example.com/a2a/json"); + assertEquals(jsonrpc, additionalInterfaces.get(0)); + assertEquals(grpc, additionalInterfaces.get(1)); + assertEquals(httpJson, additionalInterfaces.get(2)); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -455,63 +563,77 @@ public void testA2AClientGetAuthenticatedExtendedAgentCard() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); - AgentCard agentCard = client.getAgentCard(null); - assertEquals("GeoSpatial Route Planner Agent Extended", agentCard.name()); - assertEquals("Extended description", agentCard.description()); - assertEquals("https://georoute-agent.example.com/a2a/v1", agentCard.url()); - assertEquals("Example Geo Services Inc.", agentCard.provider().organization()); - assertEquals("https://www.examplegeoservices.com", agentCard.provider().url()); - assertEquals("1.2.0", agentCard.version()); - assertEquals("https://docs.examplegeoservices.com/georoute-agent/api", agentCard.documentationUrl()); - assertTrue(agentCard.capabilities().streaming()); - assertTrue(agentCard.capabilities().pushNotifications()); - assertFalse(agentCard.capabilities().stateTransitionHistory()); - Map securitySchemes = agentCard.securitySchemes(); - assertNotNull(securitySchemes); - OpenIdConnectSecurityScheme google = (OpenIdConnectSecurityScheme) securitySchemes.get("google"); - assertEquals("openIdConnect", google.getType()); - assertEquals("https://accounts.google.com/.well-known/openid-configuration", google.getOpenIdConnectUrl()); - List>> security = agentCard.security(); - assertEquals(1, security.size()); - Map> securityMap = security.get(0); - List scopes = securityMap.get("google"); - List expectedScopes = List.of("openid", "profile", "email"); - assertEquals(expectedScopes, scopes); - List defaultInputModes = List.of("application/json", "text/plain"); - assertEquals(defaultInputModes, agentCard.defaultInputModes()); - List defaultOutputModes = List.of("application/json", "image/png"); - assertEquals(defaultOutputModes, agentCard.defaultOutputModes()); - List skills = agentCard.skills(); - assertEquals("route-optimizer-traffic", skills.get(0).id()); - assertEquals("Traffic-Aware Route Optimizer", skills.get(0).name()); - assertEquals("Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", skills.get(0).description()); - List tags = List.of("maps", "routing", "navigation", "directions", "traffic"); - assertEquals(tags, skills.get(0).tags()); - List examples = List.of("Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", - "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}"); - assertEquals(examples, skills.get(0).examples()); - assertEquals(defaultInputModes, skills.get(0).inputModes()); - List outputModes = List.of("application/json", "application/vnd.geo+json", "text/html"); - assertEquals(outputModes, skills.get(0).outputModes()); - assertEquals("custom-map-generator", skills.get(1).id()); - assertEquals("Personalized Map Generator", skills.get(1).name()); - assertEquals("Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", skills.get(1).description()); - tags = List.of("maps", "customization", "visualization", "cartography"); - assertEquals(tags, skills.get(1).tags()); - examples = List.of("Generate a map of my upcoming road trip with all planned stops highlighted.", - "Show me a map visualizing all coffee shops within a 1-mile radius of my current location."); - assertEquals(examples, skills.get(1).examples()); - List inputModes = List.of("application/json"); - assertEquals(inputModes, skills.get(1).inputModes()); - outputModes = List.of("image/png", "image/jpeg", "application/json", "text/html"); - assertEquals(outputModes, skills.get(1).outputModes()); - assertEquals("skill-extended", skills.get(2).id()); - assertEquals("Extended Skill", skills.get(2).name()); - assertEquals("This is an extended skill.", skills.get(2).description()); - assertEquals(List.of("extended"), skills.get(2).tags()); - assertTrue(agentCard.supportsAuthenticatedExtendedCard()); - assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); - assertEquals("0.2.5", agentCard.protocolVersion()); + CountDownLatch latch = new CountDownLatch(1); + + client.getAgentCard(null) + .whenComplete(new BiConsumer() { + @Override + public void accept(AgentCard agentCard, Throwable throwable) { + assertNull(throwable); + + assertEquals("GeoSpatial Route Planner Agent Extended", agentCard.name()); + assertEquals("Extended description", agentCard.description()); + assertEquals("https://georoute-agent.example.com/a2a/v1", agentCard.url()); + assertEquals("Example Geo Services Inc.", agentCard.provider().organization()); + assertEquals("https://www.examplegeoservices.com", agentCard.provider().url()); + assertEquals("1.2.0", agentCard.version()); + assertEquals("https://docs.examplegeoservices.com/georoute-agent/api", agentCard.documentationUrl()); + assertTrue(agentCard.capabilities().streaming()); + assertTrue(agentCard.capabilities().pushNotifications()); + assertFalse(agentCard.capabilities().stateTransitionHistory()); + Map securitySchemes = agentCard.securitySchemes(); + assertNotNull(securitySchemes); + OpenIdConnectSecurityScheme google = (OpenIdConnectSecurityScheme) securitySchemes.get("google"); + assertEquals("openIdConnect", google.getType()); + assertEquals("https://accounts.google.com/.well-known/openid-configuration", google.getOpenIdConnectUrl()); + List>> security = agentCard.security(); + assertEquals(1, security.size()); + Map> securityMap = security.get(0); + List scopes = securityMap.get("google"); + List expectedScopes = List.of("openid", "profile", "email"); + assertEquals(expectedScopes, scopes); + List defaultInputModes = List.of("application/json", "text/plain"); + assertEquals(defaultInputModes, agentCard.defaultInputModes()); + List defaultOutputModes = List.of("application/json", "image/png"); + assertEquals(defaultOutputModes, agentCard.defaultOutputModes()); + List skills = agentCard.skills(); + assertEquals("route-optimizer-traffic", skills.get(0).id()); + assertEquals("Traffic-Aware Route Optimizer", skills.get(0).name()); + assertEquals("Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", skills.get(0).description()); + List tags = List.of("maps", "routing", "navigation", "directions", "traffic"); + assertEquals(tags, skills.get(0).tags()); + List examples = List.of("Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}"); + assertEquals(examples, skills.get(0).examples()); + assertEquals(defaultInputModes, skills.get(0).inputModes()); + List outputModes = List.of("application/json", "application/vnd.geo+json", "text/html"); + assertEquals(outputModes, skills.get(0).outputModes()); + assertEquals("custom-map-generator", skills.get(1).id()); + assertEquals("Personalized Map Generator", skills.get(1).name()); + assertEquals("Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", skills.get(1).description()); + tags = List.of("maps", "customization", "visualization", "cartography"); + assertEquals(tags, skills.get(1).tags()); + examples = List.of("Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location."); + assertEquals(examples, skills.get(1).examples()); + List inputModes = List.of("application/json"); + assertEquals(inputModes, skills.get(1).inputModes()); + outputModes = List.of("image/png", "image/jpeg", "application/json", "text/html"); + assertEquals(outputModes, skills.get(1).outputModes()); + assertEquals("skill-extended", skills.get(2).id()); + assertEquals("Extended Skill", skills.get(2).name()); + assertEquals("This is an extended skill.", skills.get(2).description()); + assertEquals(List.of("extended"), skills.get(2).tags()); + assertTrue(agentCard.supportsAuthenticatedExtendedCard()); + assertEquals("https://georoute-agent.example.com/icon.png", agentCard.iconUrl()); + assertEquals("0.2.5", agentCard.protocolVersion()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -530,6 +652,8 @@ public void testA2AClientSendMessageWithFilePart() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); + CountDownLatch latch = new CountDownLatch(1); + Message message = new Message.Builder() .role(Message.Role.USER) .parts(List.of( @@ -548,21 +672,33 @@ public void testA2AClientSendMessageWithFilePart() throws Exception { .configuration(configuration) .build(); - EventKind result = client.sendMessage(params, null); - assertInstanceOf(Task.class, result); - Task task = (Task) result; - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertNotNull(task.getContextId()); - assertEquals(TaskState.COMPLETED, task.getStatus().state()); - assertEquals(1, task.getArtifacts().size()); - Artifact artifact = task.getArtifacts().get(0); - assertEquals("artifact-1", artifact.artifactId()); - assertEquals("image-analysis", artifact.name()); - assertEquals(1, artifact.parts().size()); - Part part = artifact.parts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("This is an image of a cat sitting on a windowsill.", ((TextPart) part).getText()); - assertTrue(task.getMetadata().isEmpty()); + client.sendMessage(params, null) + .whenComplete(new BiConsumer() { + @Override + public void accept(EventKind result, Throwable throwable) { + assertNull(throwable); + + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("image-analysis", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("This is an image of a cat sitting on a windowsill.", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -581,6 +717,7 @@ public void testA2AClientSendMessageWithDataPart() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); + CountDownLatch latch = new CountDownLatch(1); Map data = new HashMap<>(); data.put("temperature", 25.5); @@ -606,21 +743,33 @@ public void testA2AClientSendMessageWithDataPart() throws Exception { .configuration(configuration) .build(); - EventKind result = client.sendMessage(params, null); - assertInstanceOf(Task.class, result); - Task task = (Task) result; - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertNotNull(task.getContextId()); - assertEquals(TaskState.COMPLETED, task.getStatus().state()); - assertEquals(1, task.getArtifacts().size()); - Artifact artifact = task.getArtifacts().get(0); - assertEquals("artifact-1", artifact.artifactId()); - assertEquals("data-analysis", artifact.name()); - assertEquals(1, artifact.parts().size()); - Part part = artifact.parts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("Processed weather data: Temperature is 25.5°C, humidity is 60.2% in San Francisco.", ((TextPart) part).getText()); - assertTrue(task.getMetadata().isEmpty()); + client.sendMessage(params, null) + .whenComplete(new BiConsumer() { + @Override + public void accept(EventKind result, Throwable throwable) { + assertNull(throwable); + + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("data-analysis", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Processed weather data: Temperature is 25.5°C, humidity is 60.2% in San Francisco.", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } @Test @@ -639,6 +788,7 @@ public void testA2AClientSendMessageWithMixedParts() throws Exception { ); JSONRPCTransport client = new JSONRPCTransport("http://localhost:4001"); + CountDownLatch latch = new CountDownLatch(1); Map data = new HashMap<>(); data.put("chartType", "bar"); @@ -664,20 +814,32 @@ public void testA2AClientSendMessageWithMixedParts() throws Exception { .configuration(configuration) .build(); - EventKind result = client.sendMessage(params, null); - assertInstanceOf(Task.class, result); - Task task = (Task) result; - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertNotNull(task.getContextId()); - assertEquals(TaskState.COMPLETED, task.getStatus().state()); - assertEquals(1, task.getArtifacts().size()); - Artifact artifact = task.getArtifacts().get(0); - assertEquals("artifact-1", artifact.artifactId()); - assertEquals("mixed-analysis", artifact.name()); - assertEquals(1, artifact.parts().size()); - Part part = artifact.parts().get(0); - assertEquals(Part.Kind.TEXT, part.getKind()); - assertEquals("Analyzed chart image and data: Bar chart showing quarterly data with values [10, 20, 30, 40].", ((TextPart) part).getText()); - assertTrue(task.getMetadata().isEmpty()); + client.sendMessage(params, null) + .whenComplete(new BiConsumer() { + @Override + public void accept(EventKind result, Throwable throwable) { + assertNull(throwable); + + assertInstanceOf(Task.class, result); + Task task = (Task) result; + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertNotNull(task.getContextId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("mixed-analysis", artifact.name()); + assertEquals(1, artifact.parts().size()); + Part part = artifact.parts().get(0); + assertEquals(Part.Kind.TEXT, part.getKind()); + assertEquals("Analyzed chart image and data: Bar chart showing quarterly data with values [10, 20, 30, 40].", ((TextPart) part).getText()); + assertTrue(task.getMetadata().isEmpty()); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } } \ No newline at end of file diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java index 85bf962b..bf4a7a53 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java @@ -18,6 +18,8 @@ import io.a2a.spec.TaskNotCancelableError; import io.a2a.spec.TaskNotFoundError; import io.a2a.spec.UnsupportedOperationError; + +import java.util.concurrent.CompletableFuture; import java.util.logging.Level; import java.util.logging.Logger; @@ -28,8 +30,8 @@ public class RestErrorMapper { private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().registerModule(new JavaTimeModule()); - public static A2AClientException mapRestError(HttpResponse response) { - return RestErrorMapper.mapRestError(response.body(), response.statusCode()); + public static CompletableFuture mapRestError(HttpResponse response) { + return response.body().thenCompose(responseBody -> CompletableFuture.failedFuture(RestErrorMapper.mapRestError(responseBody, response.statusCode()))); } public static A2AClientException mapRestError(String body, int code) { diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java index 912c0082..6258f433 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java @@ -4,7 +4,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.util.JsonFormat; import io.a2a.client.http.A2ACardResolver; import io.a2a.client.http.HttpClient; @@ -19,29 +18,15 @@ import io.a2a.grpc.GetTaskPushNotificationConfigRequest; import io.a2a.grpc.GetTaskRequest; import io.a2a.grpc.ListTaskPushNotificationConfigRequest; -import io.a2a.spec.TaskPushNotificationConfig; -import io.a2a.spec.A2AClientException; -import io.a2a.spec.AgentCard; -import io.a2a.spec.DeleteTaskPushNotificationConfigParams; -import io.a2a.spec.EventKind; -import io.a2a.spec.GetTaskPushNotificationConfigParams; -import io.a2a.spec.ListTaskPushNotificationConfigParams; -import io.a2a.spec.MessageSendParams; -import io.a2a.spec.StreamingEventKind; -import io.a2a.spec.Task; -import io.a2a.spec.TaskIdParams; -import io.a2a.spec.TaskQueryParams; +import io.a2a.spec.*; import io.a2a.grpc.utils.ProtoUtils; -import io.a2a.spec.A2AClientError; -import io.a2a.spec.SendStreamingMessageRequest; -import io.a2a.spec.SetTaskPushNotificationConfigRequest; import io.a2a.util.Utils; -import java.io.IOException; import java.net.URI; import java.util.Collections; import java.util.List; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.CompletionStage; import java.util.function.BiConsumer; +import java.util.function.Function; import java.util.logging.Logger; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -77,26 +62,31 @@ public RestTransport(@Nullable HttpClient httpClient, @Nullable AgentCard agentC } @Override - public EventKind sendMessage(MessageSendParams messageSendParams, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture sendMessage(MessageSendParams messageSendParams, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("messageSendParams", messageSendParams); io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(ProtoUtils.ToProto.sendMessageRequest(messageSendParams)); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, builder, agentCard, context); - try { - String httpResponseBody = sendPostRequest("/v1/message:send", payloadAndHeaders); - io.a2a.grpc.SendMessageResponse.Builder responseBuilder = io.a2a.grpc.SendMessageResponse.newBuilder(); - JsonFormat.parser().merge(httpResponseBody, responseBuilder); - if (responseBuilder.hasMsg()) { - return ProtoUtils.FromProto.message(responseBuilder.getMsg()); - } - if (responseBuilder.hasTask()) { - return ProtoUtils.FromProto.task(responseBuilder.getTask()); - } - throw new A2AClientException("Failed to send message, wrong response:" + httpResponseBody); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException | ExecutionException e) { - throw new A2AClientException("Failed to send message: " + e, e); - } + return sendPostRequest("/v1/message:send", payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + io.a2a.grpc.SendMessageResponse.Builder responseBuilder = io.a2a.grpc.SendMessageResponse.newBuilder(); + try { + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + } catch (InvalidProtocolBufferException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to send message: " + e, e)); + } + + if (responseBuilder.hasMsg()) { + return CompletableFuture.completedFuture(ProtoUtils.FromProto.message(responseBuilder.getMsg())); + } + if (responseBuilder.hasTask()) { + return CompletableFuture.completedFuture(ProtoUtils.FromProto.task(responseBuilder.getTask())); + } + + return CompletableFuture.failedFuture(new A2AClientException("Failed to send message, wrong response:" + httpResponseBody)); + } + }); } @Override @@ -109,7 +99,7 @@ public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer> ref = new AtomicReference<>(); RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); - try { + // try { HttpClient.PostRequestBuilder postBuilder = createPostBuilder("/v1/message:stream", payloadAndHeaders).asSSE(); ref.set(postBuilder.send().whenComplete(new BiConsumer() { @Override @@ -123,68 +113,84 @@ public void accept(HttpResponse httpResponse, Throwable throwable) { } } })); + /* } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); } + + */ } @Override - public Task getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("taskQueryParams", taskQueryParams); GetTaskRequest.Builder builder = GetTaskRequest.newBuilder(); builder.setName("tasks/" + taskQueryParams.id()); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, builder, agentCard, context); - try { - String path; - if (taskQueryParams.historyLength() != null) { - path = String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength()); - } else { - path = String.format("/v1/tasks/%1s", taskQueryParams.id()); - } - HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); - if (payloadAndHeaders.getHeaders() != null) { - for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { - getBuilder.addHeader(entry.getKey(), entry.getValue()); - } - } - CompletableFuture responseFut = getBuilder.send(); - HttpResponse response = responseFut.get(); - if (!response.success()) { - throw RestErrorMapper.mapRestError(response); + + String path; + if (taskQueryParams.historyLength() != null) { + path = String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength()); + } else { + path = String.format("/v1/tasks/%1s", taskQueryParams.id()); + } + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + getBuilder.addHeader(entry.getKey(), entry.getValue()); } - String httpResponseBody = response.body(); - io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); - JsonFormat.parser().merge(httpResponseBody, responseBuilder); - return ProtoUtils.FromProto.task(responseBuilder); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException | ExecutionException e) { - throw new A2AClientException("Failed to get task: " + e, e); } + + return getBuilder.send() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(HttpResponse response) { + if (!response.success()) { + return RestErrorMapper.mapRestError(response); + } + + return response.body(); + } + }).thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); + try { + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return CompletableFuture.completedFuture(ProtoUtils.FromProto.task(responseBuilder)); + } catch (InvalidProtocolBufferException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get task: " + e, e)); + } + } + }); } @Override - public Task cancelTask(TaskIdParams taskIdParams, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture cancelTask(TaskIdParams taskIdParams, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("taskIdParams", taskIdParams); CancelTaskRequest.Builder builder = CancelTaskRequest.newBuilder(); builder.setName("tasks/" + taskIdParams.id()); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.CancelTaskRequest.METHOD, builder, agentCard, context); - try { - String httpResponseBody = sendPostRequest(String.format("/v1/tasks/%1s:cancel", taskIdParams.id()), payloadAndHeaders); - io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); - JsonFormat.parser().merge(httpResponseBody, responseBuilder); - return ProtoUtils.FromProto.task(responseBuilder); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException | ExecutionException e) { - throw new A2AClientException("Failed to cancel task: " + e, e); - } + + return sendPostRequest(String.format("/v1/tasks/%1s:cancel", taskIdParams.id()), payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); + try { + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return CompletableFuture.completedFuture(ProtoUtils.FromProto.task(responseBuilder)); + } catch (InvalidProtocolBufferException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to cancel task: " + e, e)); + } + } + }); } @Override - public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); CreateTaskPushNotificationConfigRequest.Builder builder = CreateTaskPushNotificationConfigRequest.newBuilder(); builder.setConfig(ProtoUtils.ToProto.taskPushNotificationConfig(request)) @@ -193,97 +199,138 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN builder.setConfigId(request.pushNotificationConfig().id()); } PayloadAndHeaders payloadAndHeaders = applyInterceptors(SetTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); - try { - String httpResponseBody = sendPostRequest(String.format("/v1/tasks/%1s/pushNotificationConfigs", request.taskId()), payloadAndHeaders); - io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); - JsonFormat.parser().merge(httpResponseBody, responseBuilder); - return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException | ExecutionException e) { - throw new A2AClientException("Failed to set task push notification config: " + e, e); - } + + return sendPostRequest(String.format("/v1/tasks/%1s/pushNotificationConfigs", request.taskId()), payloadAndHeaders) + .thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); + try { + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return CompletableFuture.completedFuture(ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder)); + } catch (InvalidProtocolBufferException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to set task push notification config: " + e, e)); + } + } + }); } @Override - public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskPushNotificationConfigRequest.Builder builder = GetTaskPushNotificationConfigRequest.newBuilder(); builder.setName(String.format("/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId())); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); - try { - String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); - HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); - if (payloadAndHeaders.getHeaders() != null) { - for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { - getBuilder.addHeader(entry.getKey(), entry.getValue()); - } - } - CompletableFuture responseFut = getBuilder.send(); - HttpResponse response = responseFut.get(); - - if (!response.success()) { - throw RestErrorMapper.mapRestError(response); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + getBuilder.addHeader(entry.getKey(), entry.getValue()); } - String httpResponseBody = response.body(); - io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); - JsonFormat.parser().merge(httpResponseBody, responseBuilder); - return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException | ExecutionException e) { - throw new A2AClientException("Failed to get push notifications: " + e, e); } + return getBuilder.send() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(HttpResponse response) { + if (!response.success()) { + if (!response.success()) { + return RestErrorMapper.mapRestError(response); + } + } + + return response.body(); + } + }).thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); + try { + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return CompletableFuture.completedFuture(ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder)); + } catch (InvalidProtocolBufferException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get push notifications: " + e, e)); + } + } + }); } @Override - public List listTaskPushNotificationConfigurations(ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture> listTaskPushNotificationConfigurations(ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); ListTaskPushNotificationConfigRequest.Builder builder = ListTaskPushNotificationConfigRequest.newBuilder(); builder.setParent(String.format("/tasks/%1s/pushNotificationConfigs", request.id())); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.ListTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); - try { - String path = String.format("/v1/tasks/%1s/pushNotificationConfigs", request.id()); - HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); - if (payloadAndHeaders.getHeaders() != null) { - for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { - getBuilder.addHeader(entry.getKey(), entry.getValue()); - } - } - CompletableFuture responseFut = getBuilder.send(); - HttpResponse response = responseFut.get(); - if (!response.success()) { - throw RestErrorMapper.mapRestError(response); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs", request.id()); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + getBuilder.addHeader(entry.getKey(), entry.getValue()); } - String httpResponseBody = response.body(); - io.a2a.grpc.ListTaskPushNotificationConfigResponse.Builder responseBuilder = io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder(); - JsonFormat.parser().merge(httpResponseBody, responseBuilder); - return ProtoUtils.FromProto.listTaskPushNotificationConfigParams(responseBuilder); - } catch (A2AClientException e) { - throw e; - } catch (IOException | InterruptedException | ExecutionException e) { - throw new A2AClientException("Failed to list push notifications: " + e, e); } + + return getBuilder.send() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(HttpResponse response) { + if (!response.success()) { + if (!response.success()) { + return RestErrorMapper.mapRestError(response); + } + } + + return response.body(); + } + }).thenCompose(new Function>>() { + @Override + public CompletionStage> apply(String httpResponseBody) { + io.a2a.grpc.ListTaskPushNotificationConfigResponse.Builder responseBuilder = io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder(); + try { + JsonFormat.parser().merge(httpResponseBody, responseBuilder); + return CompletableFuture.completedFuture(ProtoUtils.FromProto.listTaskPushNotificationConfigParams(responseBuilder)); + } catch (InvalidProtocolBufferException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to list push notifications: " + e, e)); + } + } + }); } @Override - public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); io.a2a.grpc.DeleteTaskPushNotificationConfigRequestOrBuilder builder = io.a2a.grpc.DeleteTaskPushNotificationConfigRequest.newBuilder(); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); - try { - String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); - HttpClient.DeleteRequestBuilder deleteBuilder = httpClient.delete(agentPath + path); - if (payloadAndHeaders.getHeaders() != null) { - for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { - deleteBuilder.addHeader(entry.getKey(), entry.getValue()); - } + + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + HttpClient.DeleteRequestBuilder deleteBuilder = httpClient.delete(agentPath + path); + if (payloadAndHeaders.getHeaders() != null) { + for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { + deleteBuilder.addHeader(entry.getKey(), entry.getValue()); } + } + + return deleteBuilder + .send() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(HttpResponse response) { + if (!response.success()) { + if (!response.success()) { + return RestErrorMapper.mapRestError(response); + } + } + + return response.body(); + } + }) + .thenApply(s -> null); + /* + try { + CompletableFuture responseFut = deleteBuilder.send(); HttpResponse response = responseFut.get(); @@ -295,6 +342,7 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to delete push notification config: " + e, e); } + */ } @Override @@ -307,7 +355,7 @@ public void resubscribe(TaskIdParams request, Consumer event agentCard, context); AtomicReference> ref = new AtomicReference<>(); RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); - try { + // try { String path = String.format("/v1/tasks/%1s:subscribe", request.id()); HttpClient.PostRequestBuilder postBuilder = createPostBuilder(path, payloadAndHeaders).asSSE(); ref.set(postBuilder.send().whenComplete(new BiConsumer() { @@ -322,22 +370,28 @@ public void accept(HttpResponse httpResponse, Throwable throwable) { } } })); + /* } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); } + */ } @Override - public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { + public CompletableFuture getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { A2ACardResolver resolver; - try { + if (agentCard == null) { - resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); - agentCard = resolver.getAgentCard(); - needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); + try { + resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); + agentCard = resolver.getAgentCard(); + needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); + } catch (A2AClientError e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get agent card: " + e, e)); + } } if (!needsExtendedCard) { - return agentCard; + return CompletableFuture.completedFuture(agentCard); } PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, null, agentCard, context); @@ -348,21 +402,32 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - CompletableFuture responseFut = getBuilder.send(); - HttpResponse response = responseFut.get(); + return getBuilder.send() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(HttpResponse response) { + if (!response.success()) { + return RestErrorMapper.mapRestError(response); + } - if (!response.success()) { - throw RestErrorMapper.mapRestError(response); - } - String httpResponseBody = response.body(); - agentCard = Utils.OBJECT_MAPPER.readValue(httpResponseBody, AgentCard.class); - needsExtendedCard = false; - return agentCard; - } catch (IOException | InterruptedException | ExecutionException e) { - throw new A2AClientException("Failed to get authenticated extended agent card: " + e, e); - } catch (A2AClientError e) { - throw new A2AClientException("Failed to get agent card: " + e, e); - } + return response.body(); + } + }).thenCompose(new Function>() { + @Override + public CompletionStage apply(String httpResponseBody) { + try { + return CompletableFuture.completedFuture(Utils.OBJECT_MAPPER.readValue(httpResponseBody, AgentCard.class)); + } catch (JsonProcessingException e) { + return CompletableFuture.failedFuture(new A2AClientException("Failed to get authenticated extended agent card: " + e, e)); + } + } + }).whenComplete(new BiConsumer() { + @Override + public void accept(AgentCard agentCard, Throwable throwable) { + RestTransport.this.agentCard = agentCard; + needsExtendedCard = false; + } + }); } @Override @@ -370,23 +435,27 @@ public void close() { // no-op } - private String sendPostRequest(String path, PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException, ExecutionException { - HttpClient.PostRequestBuilder builder = createPostBuilder(path, payloadAndHeaders); - CompletableFuture responseFut = builder.send(); + private CompletableFuture sendPostRequest(String path, PayloadAndHeaders payloadAndHeaders) { + return createPostBuilder(path, payloadAndHeaders) + .send() + .thenCompose(new Function>() { + @Override + public CompletionStage apply(HttpResponse response) { + if (!response.success()) { + log.fine("Error on POST processing " + convertToJsonString(payloadAndHeaders.getPayload())); + return RestErrorMapper.mapRestError(response); + } - HttpResponse response = responseFut.get(); - if (!response.success()) { - log.fine("Error on POST processing " + JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); - throw RestErrorMapper.mapRestError(response); - } - return response.body(); + return response.body(); + } + }); } - private HttpClient.PostRequestBuilder createPostBuilder(String path, PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException, InvalidProtocolBufferException { - log.fine(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); + private HttpClient.PostRequestBuilder createPostBuilder(String path, PayloadAndHeaders payloadAndHeaders) { + log.fine(convertToJsonString(payloadAndHeaders.getPayload())); HttpClient.PostRequestBuilder postBuilder = httpClient.post(agentPath + path) .addHeader("Content-Type", "application/json") - .body(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); + .body(convertToJsonString(payloadAndHeaders.getPayload())); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { @@ -399,4 +468,16 @@ private HttpClient.PostRequestBuilder createPostBuilder(String path, PayloadAndH private Map getHttpHeaders(@Nullable ClientCallContext context) { return context != null ? context.getHeaders() : Collections.emptyMap(); } + + private @Nullable String convertToJsonString(@Nullable Object obj) { + if (obj != null) { + try { + return JsonFormat.printer().print((com.google.protobuf.MessageOrBuilder) obj); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + return null; + } } diff --git a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java index ae938cb4..75388341 100644 --- a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java +++ b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java @@ -50,6 +50,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.logging.Logger; import org.junit.jupiter.api.AfterEach; @@ -105,27 +106,42 @@ public void testSendMessage() throws Exception { ClientCallContext context = null; RestTransport instance = new RestTransport(AGENT_URL); - EventKind result = instance.sendMessage(messageSendParams, context); - assertEquals("task", result.getKind()); - Task task = (Task) result; - assertEquals("9b511af4-b27c-47fa-aecf-2a93c08a44f8", task.getId()); - assertEquals("context-1234", task.getContextId()); - assertEquals(TaskState.SUBMITTED, task.getStatus().state()); - assertNull(task.getStatus().message()); - assertNull(task.getMetadata()); - assertEquals(true, task.getArtifacts().isEmpty()); - assertEquals(1, task.getHistory().size()); - Message history = task.getHistory().get(0); - assertEquals("message", history.getKind()); - assertEquals(Message.Role.USER, history.getRole()); - assertEquals("context-1234", history.getContextId()); - assertEquals("message-1234", history.getMessageId()); - assertEquals("9b511af4-b27c-47fa-aecf-2a93c08a44f8", history.getTaskId()); - assertEquals(1, history.getParts().size()); - assertEquals(Kind.TEXT, history.getParts().get(0).getKind()); - assertEquals("tell me a joke", ((TextPart) history.getParts().get(0)).getText()); - assertNull(history.getMetadata()); - assertNull(history.getReferenceTaskIds()); + + CountDownLatch latch = new CountDownLatch(1); + + instance.sendMessage(messageSendParams, context) + .whenComplete(new BiConsumer() { + @Override + public void accept(EventKind result, Throwable throwable) { + assertNull(throwable); + + assertEquals("task", result.getKind()); + Task task = (Task) result; + assertEquals("9b511af4-b27c-47fa-aecf-2a93c08a44f8", task.getId()); + assertEquals("context-1234", task.getContextId()); + assertEquals(TaskState.SUBMITTED, task.getStatus().state()); + assertNull(task.getStatus().message()); + assertNull(task.getMetadata()); + assertEquals(true, task.getArtifacts().isEmpty()); + assertEquals(1, task.getHistory().size()); + Message history = task.getHistory().get(0); + assertEquals("message", history.getKind()); + assertEquals(Message.Role.USER, history.getRole()); + assertEquals("context-1234", history.getContextId()); + assertEquals("message-1234", history.getMessageId()); + assertEquals("9b511af4-b27c-47fa-aecf-2a93c08a44f8", history.getTaskId()); + assertEquals(1, history.getParts().size()); + assertEquals(Kind.TEXT, history.getParts().get(0).getKind()); + assertEquals("tell me a joke", ((TextPart) history.getParts().get(0)).getText()); + assertNull(history.getMetadata()); + assertNull(history.getReferenceTaskIds()); + + latch.countDown(); + } + }); + + boolean callCompleted = latch.await(5, TimeUnit.SECONDS); + assertTrue(callCompleted); } /** @@ -146,12 +162,27 @@ public void testCancelTask() throws Exception { ); ClientCallContext context = null; RestTransport instance = new RestTransport(AGENT_URL); - Task task = instance.cancelTask(new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", - new HashMap<>()), context); - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertEquals(TaskState.CANCELED, task.getStatus().state()); - assertNull(task.getStatus().message()); - assertNull(task.getMetadata()); + + CountDownLatch latch = new CountDownLatch(1); + + instance.cancelTask(new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", + new HashMap<>()), context) + .whenComplete(new BiConsumer() { + @Override + public void accept(Task task, Throwable throwable) { + assertNull(throwable); + + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals(TaskState.CANCELED, task.getStatus().state()); + assertNull(task.getStatus().message()); + assertNull(task.getMetadata()); + + latch.countDown(); + } + }); + + boolean callCompleted = latch.await(5, TimeUnit.SECONDS); + assertTrue(callCompleted); } /** @@ -172,37 +203,52 @@ public void testGetTask() throws Exception { ClientCallContext context = null; TaskQueryParams request = new TaskQueryParams("de38c76d-d54c-436c-8b9f-4c2703648d64", 10); RestTransport instance = new RestTransport(AGENT_URL); - Task task = instance.getTask(request, context); - assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); - assertEquals(TaskState.COMPLETED, task.getStatus().state()); - assertNull(task.getStatus().message()); - assertNull(task.getMetadata()); - assertEquals(false, task.getArtifacts().isEmpty()); - assertEquals(1, task.getArtifacts().size()); - Artifact artifact = task.getArtifacts().get(0); - assertEquals("artifact-1", artifact.artifactId()); - assertEquals("", artifact.name()); - assertEquals(false, artifact.parts().isEmpty()); - assertEquals(Kind.TEXT, artifact.parts().get(0).getKind()); - assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) artifact.parts().get(0)).getText()); - assertEquals(1, task.getHistory().size()); - Message history = task.getHistory().get(0); - assertEquals("message", history.getKind()); - assertEquals(Message.Role.USER, history.getRole()); - assertEquals("message-123", history.getMessageId()); - assertEquals(3, history.getParts().size()); - assertEquals(Kind.TEXT, history.getParts().get(0).getKind()); - assertEquals("tell me a joke", ((TextPart) history.getParts().get(0)).getText()); - assertEquals(Kind.FILE, history.getParts().get(1).getKind()); - FilePart part = (FilePart) history.getParts().get(1); - assertEquals("text/plain", part.getFile().mimeType()); - assertEquals("file:///path/to/file.txt", ((FileWithUri) part.getFile()).uri()); - part = (FilePart) history.getParts().get(2); - assertEquals(Kind.FILE, part.getKind()); - assertEquals("text/plain", part.getFile().mimeType()); - assertEquals("hello", ((FileWithBytes) part.getFile()).bytes()); - assertNull(history.getMetadata()); - assertNull(history.getReferenceTaskIds()); + CountDownLatch latch = new CountDownLatch(1); + + instance.getTask(request, context) + .whenComplete(new BiConsumer() { + @Override + public void accept(Task task, Throwable throwable) { + assertNull(throwable); + + assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); + assertEquals(TaskState.COMPLETED, task.getStatus().state()); + assertNull(task.getStatus().message()); + assertNull(task.getMetadata()); + assertEquals(false, task.getArtifacts().isEmpty()); + assertEquals(1, task.getArtifacts().size()); + Artifact artifact = task.getArtifacts().get(0); + assertEquals("artifact-1", artifact.artifactId()); + assertEquals("", artifact.name()); + assertEquals(false, artifact.parts().isEmpty()); + assertEquals(Kind.TEXT, artifact.parts().get(0).getKind()); + assertEquals("Why did the chicken cross the road? To get to the other side!", ((TextPart) artifact.parts().get(0)).getText()); + assertEquals(1, task.getHistory().size()); + Message history = task.getHistory().get(0); + assertEquals("message", history.getKind()); + assertEquals(Message.Role.USER, history.getRole()); + assertEquals("message-123", history.getMessageId()); + assertEquals(3, history.getParts().size()); + assertEquals(Kind.TEXT, history.getParts().get(0).getKind()); + assertEquals("tell me a joke", ((TextPart) history.getParts().get(0)).getText()); + assertEquals(Kind.FILE, history.getParts().get(1).getKind()); + FilePart part = (FilePart) history.getParts().get(1); + assertEquals("text/plain", part.getFile().mimeType()); + assertEquals("file:///path/to/file.txt", ((FileWithUri) part.getFile()).uri()); + part = (FilePart) history.getParts().get(2); + assertEquals(Kind.FILE, part.getKind()); + assertEquals("text/plain", part.getFile().mimeType()); + assertEquals("hello", ((FileWithBytes) part.getFile()).bytes()); + assertNull(history.getMetadata()); + assertNull(history.getReferenceTaskIds()); + + latch.countDown(); + } + }); + + boolean callCompleted = latch.await(5, TimeUnit.SECONDS); + assertTrue(callCompleted); + } /** @@ -274,6 +320,9 @@ public void testSetTaskPushNotificationConfiguration() throws Exception { .withBody(SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) ); RestTransport client = new RestTransport(AGENT_URL); + + CountDownLatch latch = new CountDownLatch(1); + TaskPushNotificationConfig pushedConfig = new TaskPushNotificationConfig( "de38c76d-d54c-436c-8b9f-4c2703648d64", new PushNotificationConfig.Builder() @@ -281,13 +330,26 @@ public void testSetTaskPushNotificationConfiguration() throws Exception { .authenticationInfo( new PushNotificationAuthenticationInfo(Collections.singletonList("jwt"), null)) .build()); - TaskPushNotificationConfig taskPushNotificationConfig = client.setTaskPushNotificationConfiguration(pushedConfig, null); - PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); - assertNotNull(pushNotificationConfig); - assertEquals("https://example.com/callback", pushNotificationConfig.url()); - PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); - assertEquals(1, authenticationInfo.schemes().size()); - assertEquals("jwt", authenticationInfo.schemes().get(0)); + + client.setTaskPushNotificationConfiguration(pushedConfig, null) + .whenComplete(new BiConsumer() { + @Override + public void accept(TaskPushNotificationConfig taskPushNotificationConfig, Throwable throwable) { + assertNull(throwable); + + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertEquals(1, authenticationInfo.schemes().size()); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } /** @@ -307,15 +369,30 @@ public void testGetTaskPushNotificationConfiguration() throws Exception { ); RestTransport client = new RestTransport(AGENT_URL); - TaskPushNotificationConfig taskPushNotificationConfig = client.getTaskPushNotificationConfiguration( + + CountDownLatch latch = new CountDownLatch(1); + + client.getTaskPushNotificationConfiguration( new GetTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", "10", - new HashMap<>()), null); - PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); - assertNotNull(pushNotificationConfig); - assertEquals("https://example.com/callback", pushNotificationConfig.url()); - PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); - assertTrue(authenticationInfo.schemes().size() == 1); - assertEquals("jwt", authenticationInfo.schemes().get(0)); + new HashMap<>()), null) + .whenComplete(new BiConsumer() { + @Override + public void accept(TaskPushNotificationConfig taskPushNotificationConfig, Throwable throwable) { + assertNull(throwable); + + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfig.pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertTrue(authenticationInfo.schemes().size() == 1); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } /** @@ -335,23 +412,37 @@ public void testListTaskPushNotificationConfigurations() throws Exception { ); RestTransport client = new RestTransport(AGENT_URL); - List taskPushNotificationConfigs = client.listTaskPushNotificationConfigurations( - new ListTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>()), null); - assertEquals(2, taskPushNotificationConfigs.size()); - PushNotificationConfig pushNotificationConfig = taskPushNotificationConfigs.get(0).pushNotificationConfig(); - assertNotNull(pushNotificationConfig); - assertEquals("https://example.com/callback", pushNotificationConfig.url()); - assertEquals("10", pushNotificationConfig.id()); - PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); - assertTrue(authenticationInfo.schemes().size() == 1); - assertEquals("jwt", authenticationInfo.schemes().get(0)); - assertEquals("", authenticationInfo.credentials()); - pushNotificationConfig = taskPushNotificationConfigs.get(1).pushNotificationConfig(); - assertNotNull(pushNotificationConfig); - assertEquals("https://test.com/callback", pushNotificationConfig.url()); - assertEquals("5", pushNotificationConfig.id()); - authenticationInfo = pushNotificationConfig.authentication(); - assertNull(authenticationInfo); + CountDownLatch latch = new CountDownLatch(1); + + client.listTaskPushNotificationConfigurations( + new ListTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>()), null) + .whenComplete(new BiConsumer, Throwable>() { + @Override + public void accept(List taskPushNotificationConfigs, Throwable throwable) { + assertNull(throwable); + + assertEquals(2, taskPushNotificationConfigs.size()); + PushNotificationConfig pushNotificationConfig = taskPushNotificationConfigs.get(0).pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://example.com/callback", pushNotificationConfig.url()); + assertEquals("10", pushNotificationConfig.id()); + PushNotificationAuthenticationInfo authenticationInfo = pushNotificationConfig.authentication(); + assertTrue(authenticationInfo.schemes().size() == 1); + assertEquals("jwt", authenticationInfo.schemes().get(0)); + assertEquals("", authenticationInfo.credentials()); + pushNotificationConfig = taskPushNotificationConfigs.get(1).pushNotificationConfig(); + assertNotNull(pushNotificationConfig); + assertEquals("https://test.com/callback", pushNotificationConfig.url()); + assertEquals("5", pushNotificationConfig.id()); + authenticationInfo = pushNotificationConfig.authentication(); + assertNull(authenticationInfo); + + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed); } /** diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/ClientTransport.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/ClientTransport.java index 56a4067c..3456879c 100644 --- a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/ClientTransport.java +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/ClientTransport.java @@ -1,6 +1,7 @@ package io.a2a.client.transport.spi; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import io.a2a.client.transport.spi.interceptors.ClientCallContext; @@ -31,7 +32,7 @@ public interface ClientTransport { * @return the response, either a Task or Message * @throws A2AClientException if sending the message fails for any reason */ - EventKind sendMessage(MessageSendParams request, @Nullable ClientCallContext context) + CompletableFuture sendMessage(MessageSendParams request, @Nullable ClientCallContext context) throws A2AClientException; /** @@ -54,7 +55,7 @@ void sendMessageStreaming(MessageSendParams request, Consumer getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException; /** * Request the agent to cancel a specific task. @@ -64,7 +65,7 @@ void sendMessageStreaming(MessageSendParams request, Consumer cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException; /** * Set or update the push notification configuration for a specific task. @@ -74,7 +75,7 @@ void sendMessageStreaming(MessageSendParams request, Consumer setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException; /** @@ -85,7 +86,7 @@ TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotifica * @return the task push notification config * @throws A2AClientException if getting the task push notification config fails for any reason */ - TaskPushNotificationConfig getTaskPushNotificationConfiguration( + CompletableFuture getTaskPushNotificationConfiguration( GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException; @@ -97,7 +98,7 @@ TaskPushNotificationConfig getTaskPushNotificationConfiguration( * @return the list of task push notification configs * @throws A2AClientException if getting the task push notification configs fails for any reason */ - List listTaskPushNotificationConfigurations( + CompletableFuture> listTaskPushNotificationConfigurations( ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException; @@ -108,7 +109,7 @@ List listTaskPushNotificationConfigurations( * @param context optional client call context for the request (may be {@code null}) * @throws A2AClientException if deleting the task push notification configs fails for any reason */ - void deleteTaskPushNotificationConfigurations( + CompletableFuture deleteTaskPushNotificationConfigurations( DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException; @@ -131,7 +132,7 @@ void resubscribe(TaskIdParams request, Consumer eventConsume * @return the AgentCard * @throws A2AClientException if retrieving the agent card fails for any reason */ - AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException; + CompletableFuture getAgentCard(@Nullable ClientCallContext context) throws A2AClientException; /** * Close the transport and release any associated resources. diff --git a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java index 62284dc6..d9295d1d 100644 --- a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java +++ b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java @@ -8,6 +8,7 @@ import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.vertx.core.*; +import io.vertx.core.buffer.Buffer; import io.vertx.core.http.*; import java.io.IOException; @@ -16,7 +17,6 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.function.Consumer; import java.util.function.Function; @@ -189,15 +189,8 @@ public int statusCode() { } @Override - public String body() { - try { - return response.body().toCompletionStage().toCompletableFuture().get().toString(); - - } catch (InterruptedException e) { - throw new RuntimeException(e); - } catch (ExecutionException e) { - throw new RuntimeException(e); - } + public CompletableFuture body() { + return response.body().map(Buffer::toString).toCompletionStage().toCompletableFuture(); } @Override diff --git a/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java b/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java index d938bb93..7685ce65 100644 --- a/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java +++ b/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java @@ -114,7 +114,7 @@ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { if (!response.success()) { throw new A2AClientError("Failed to obtain agent card: " + response.statusCode()); } - body = response.body(); + body = response.body().get(); } catch (InterruptedException | ExecutionException e) { throw new A2AClientError("Failed to obtain agent card", e); } diff --git a/http-client/src/main/java/io/a2a/client/http/HttpClient.java b/http-client/src/main/java/io/a2a/client/http/HttpClient.java index 1cb14fde..fe16d7b0 100644 --- a/http-client/src/main/java/io/a2a/client/http/HttpClient.java +++ b/http-client/src/main/java/io/a2a/client/http/HttpClient.java @@ -1,5 +1,7 @@ package io.a2a.client.http; +import org.jspecify.annotations.Nullable; + import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -28,7 +30,7 @@ interface GetRequestBuilder extends RequestBuilder { } interface PostRequestBuilder extends RequestBuilder { - PostRequestBuilder body(String body); + PostRequestBuilder body(@Nullable String body); default PostRequestBuilder asSSE() { return addHeader("Accept", "text/event-stream"); diff --git a/http-client/src/main/java/io/a2a/client/http/HttpResponse.java b/http-client/src/main/java/io/a2a/client/http/HttpResponse.java index 3e2f35f6..39a009ae 100644 --- a/http-client/src/main/java/io/a2a/client/http/HttpResponse.java +++ b/http-client/src/main/java/io/a2a/client/http/HttpResponse.java @@ -2,6 +2,7 @@ import io.a2a.client.http.sse.Event; +import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; public interface HttpResponse { @@ -11,7 +12,7 @@ default boolean success() { return statusCode() >= 200 && statusCode() < 300; } - String body(); + CompletableFuture body(); void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer); } diff --git a/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java index 83e31208..d786ad8b 100644 --- a/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java +++ b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java @@ -127,7 +127,7 @@ public CompletableFuture send() { HttpRequest request = super.createRequestBuilder().GET().build(); return httpClient .sendAsync(request, BodyHandlers.ofString(StandardCharsets.UTF_8)) - .thenCompose(RESPONSE_MAPPER); + .thenApply((Function, HttpResponse>) JdkHttpResponse::new); } } @@ -142,7 +142,7 @@ public CompletableFuture send() { HttpRequest request = super.createRequestBuilder().DELETE().build(); return httpClient .sendAsync(request, BodyHandlers.ofString(StandardCharsets.UTF_8)) - .thenCompose(RESPONSE_MAPPER); + .thenApply((Function, HttpResponse>) JdkHttpResponse::new); } } @@ -174,7 +174,8 @@ public CompletableFuture send() { bodyHandler = BodyHandlers.ofString(StandardCharsets.UTF_8); } - return httpClient.sendAsync(request, bodyHandler).thenCompose(RESPONSE_MAPPER); + return httpClient.sendAsync(request, bodyHandler) + .thenApply((Function, HttpResponse>) JdkHttpResponse::new); } } @@ -200,12 +201,12 @@ static boolean success(java.net.http.HttpResponse response) { } @Override - public String body() { + public CompletableFuture body() { if (response.body() instanceof String) { - return (String) response.body(); + return CompletableFuture.completedFuture((String) response.body()); } - throw new IllegalStateException(); + return CompletableFuture.failedFuture(new IllegalStateException()); } @Override diff --git a/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java b/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java index 9b21d4f1..dbb34884 100644 --- a/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java +++ b/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java @@ -48,7 +48,6 @@ private String getServerUrl() { * This test is disabled until we can make the http-client layer fully async */ @Test - @Disabled public void testGetWithBodyResponse() throws Exception { givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); @@ -58,12 +57,12 @@ public void testGetWithBodyResponse() throws Exception { .create(getServerUrl()) .get(AGENT_CARD_PATH) .send() - .thenAccept(new Consumer() { + .thenCompose(HttpResponse::body) + .thenAccept(new Consumer() { @Override - public void accept(HttpResponse httpResponse) { - String body = httpResponse.body(); + public void accept(String responseBody) { + Assertions.assertEquals(JsonMessages.AGENT_CARD, responseBody); - Assertions.assertEquals(JsonMessages.AGENT_CARD, body); latch.countDown(); } });