Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendConfiguration,
MessageSendParams,
Expand Down Expand Up @@ -133,6 +135,15 @@ async def get_task(
"""
return await self._transport.get_task(request, context=context)

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
return await self._transport.list_tasks(request, context=context)

async def cancel_task(
self,
request: TaskIdParams,
Expand Down
11 changes: 11 additions & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
PushNotificationConfig,
Task,
Expand Down Expand Up @@ -119,52 +121,61 @@
pairs, or a `Message`. Client will also send these values to any
configured `Consumer`s in the client.
"""
return
yield

@abstractmethod
async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""

@abstractmethod
async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""

@abstractmethod
async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""

@abstractmethod
async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""

@abstractmethod
async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""

@abstractmethod
async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent]:

Check notice on line 178 in src/a2a/client/client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/base.py (43-97)
"""Resubscribes to a task's event stream."""
return
yield
Expand Down
11 changes: 11 additions & 0 deletions src/a2a/client/transports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendParams,
Task,
Expand Down Expand Up @@ -38,52 +40,61 @@
Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
"""Sends a streaming message request to the agent and yields responses as they arrive."""
return
yield

@abstractmethod
async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""

@abstractmethod
async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""

@abstractmethod
async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""

@abstractmethod
async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""

@abstractmethod
async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""

@abstractmethod
async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
) -> AsyncGenerator[

Check notice on line 97 in src/a2a/client/transports/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/client.py (124-178)
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
"""Reconnects to get task updates."""
Expand Down
17 changes: 17 additions & 0 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from collections.abc import AsyncGenerator

from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE


try:
import grpc
Expand All @@ -20,6 +22,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendParams,
Task,
Expand Down Expand Up @@ -145,6 +149,19 @@ async def get_task(
)
return proto_utils.FromProto.task(task)

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
response = await self.stub.ListTasks(
proto_utils.ToProto.list_tasks_request(request)
)
page_size = request.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
return proto_utils.FromProto.list_tasks_result(response, page_size)

async def cancel_task(
self,
request: TaskIdParams,
Expand Down
24 changes: 24 additions & 0 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
GetTaskRequest,
GetTaskResponse,
JSONRPCErrorResponse,
ListTasksParams,
ListTasksRequest,
ListTasksResponse,
ListTasksResult,
Message,
MessageSendParams,
SendMessageRequest,
Expand Down Expand Up @@ -165,191 +169,211 @@
async for sse in event_source.aiter_sse():
response = SendStreamingMessageResponse.model_validate(
json.loads(sse.data)
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(
self,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
try:
response = await self.httpx_client.post(
self.url, json=rpc_request_payload, **(http_kwargs or {})
)
response.raise_for_status()
return response.json()
except httpx.ReadTimeout as e:
raise A2AClientTimeoutError('Client Request timed out') from e
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
rpc_request = GetTaskRequest(params=request, id=str(uuid4()))
payload, modified_kwargs = await self._apply_interceptors(
'tasks/get',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = GetTaskResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
rpc_request = ListTasksRequest(params=request, id=str(uuid4()))
payload, modified_kwargs = await self._apply_interceptors(
'tasks/list',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = ListTasksResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
rpc_request = CancelTaskRequest(params=request, id=str(uuid4()))
payload, modified_kwargs = await self._apply_interceptors(
'tasks/cancel',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = CancelTaskResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
rpc_request = SetTaskPushNotificationConfigRequest(
params=request, id=str(uuid4())
)
payload, modified_kwargs = await self._apply_interceptors(
'tasks/pushNotificationConfig/set',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = SetTaskPushNotificationConfigResponse.model_validate(
response_data
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
rpc_request = GetTaskPushNotificationConfigRequest(
params=request, id=str(uuid4())
)
payload, modified_kwargs = await self._apply_interceptors(
'tasks/pushNotificationConfig/get',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = GetTaskPushNotificationConfigResponse.model_validate(
response_data
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
) -> AsyncGenerator[
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
"""Reconnects to get task updates."""
rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4()))
payload, modified_kwargs = await self._apply_interceptors(
'tasks/resubscribe',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)

modified_kwargs.setdefault('timeout', None)

async with aconnect_sse(
self.httpx_client,
'POST',
self.url,
json=payload,
**modified_kwargs,
) as event_source:
try:
async for sse in event_source.aiter_sse():
response = SendStreamingMessageResponse.model_validate_json(
sse.data
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
card = self.agent_card
if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(
http_kwargs=self._get_http_args(context)
)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)

Check notice on line 376 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (347-382)
self.agent_card = card

if not self._needs_extended_card:
Expand Down
44 changes: 44 additions & 0 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from google.protobuf.json_format import MessageToDict, Parse, ParseDict
from httpx_sse import SSEError, aconnect_sse
from pydantic import BaseModel

from a2a.client.card_resolver import A2ACardResolver
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
Expand All @@ -17,6 +18,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendParams,
Task,
Expand All @@ -27,6 +30,7 @@
TaskStatusUpdateEvent,
)
from a2a.utils import proto_utils
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
from a2a.utils.telemetry import SpanKind, trace_class


Expand Down Expand Up @@ -222,6 +226,28 @@ async def get_task(
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
_, modified_kwargs = await self._apply_interceptors(
request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_get_request(
'/v1/tasks',
_model_to_query_params(request),
modified_kwargs,
)
response = a2a_pb2.ListTasksResponse()
ParseDict(response_data, response)
page_size = request.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
return proto_utils.FromProto.list_tasks_result(response, page_size)

async def cancel_task(
self,
request: TaskIdParams,
Expand Down Expand Up @@ -363,3 +389,21 @@ async def get_card(
async def close(self) -> None:
"""Closes the httpx client."""
await self.httpx_client.aclose()


def _model_to_query_params(instance: BaseModel) -> dict[str, str]:
data = instance.model_dump(mode='json', exclude_none=True)
return _json_to_query_params(data)


def _json_to_query_params(data: dict[str, Any]) -> dict[str, str]:
query_dict = {}
for key, value in data.items():
if isinstance(value, list):
query_dict[key] = ','.join(map(str, value))
elif isinstance(value, bool):
query_dict[key] = str(value).lower()
else:
query_dict[key] = str(value)

return query_dict
Loading
Loading