Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendConfiguration,
MessageSendParams,
Expand Down Expand Up @@ -133,6 +135,15 @@ async def get_task(
"""
return await self._transport.get_task(request, context=context)

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
return await self._transport.list_tasks(request, context=context)

async def cancel_task(
self,
request: TaskIdParams,
Expand Down
11 changes: 11 additions & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
PushNotificationConfig,
Task,
Expand Down Expand Up @@ -131,6 +133,15 @@ async def get_task(
) -> Task:
"""Retrieves the current state and history of a specific task."""

@abstractmethod
async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""

@abstractmethod
async def cancel_task(
self,
Expand Down
11 changes: 11 additions & 0 deletions src/a2a/client/transports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendParams,
Task,
Expand Down Expand Up @@ -50,6 +52,15 @@ async def get_task(
) -> Task:
"""Retrieves the current state and history of a specific task."""

@abstractmethod
async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""

@abstractmethod
async def cancel_task(
self,
Expand Down
17 changes: 17 additions & 0 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from collections.abc import AsyncGenerator

from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE


try:
import grpc
Expand All @@ -20,6 +22,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendParams,
Task,
Expand Down Expand Up @@ -145,6 +149,19 @@ async def get_task(
)
return proto_utils.FromProto.task(task)

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
response = await self.stub.ListTasks(
proto_utils.ToProto.list_tasks_request(request)
)
page_size = request.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
return proto_utils.FromProto.list_tasks_result(response, page_size)

async def cancel_task(
self,
request: TaskIdParams,
Expand Down
24 changes: 24 additions & 0 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
GetTaskRequest,
GetTaskResponse,
JSONRPCErrorResponse,
ListTasksParams,
ListTasksRequest,
ListTasksResponse,
ListTasksResult,
Message,
MessageSendParams,
SendMessageRequest,
Expand Down Expand Up @@ -222,6 +226,26 @@ async def get_task(
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
rpc_request = ListTasksRequest(params=request, id=str(uuid4()))
payload, modified_kwargs = await self._apply_interceptors(
'tasks/list',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = ListTasksResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def cancel_task(
self,
request: TaskIdParams,
Expand Down
44 changes: 44 additions & 0 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from google.protobuf.json_format import MessageToDict, Parse, ParseDict
from httpx_sse import SSEError, aconnect_sse
from pydantic import BaseModel

from a2a.client.card_resolver import A2ACardResolver
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
Expand All @@ -17,6 +18,8 @@
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
ListTasksParams,
ListTasksResult,
Message,
MessageSendParams,
Task,
Expand All @@ -27,6 +30,7 @@
TaskStatusUpdateEvent,
)
from a2a.utils import proto_utils
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
from a2a.utils.telemetry import SpanKind, trace_class


Expand Down Expand Up @@ -222,6 +226,28 @@ async def get_task(
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def list_tasks(
self,
request: ListTasksParams,
*,
context: ClientCallContext | None = None,
) -> ListTasksResult:
"""Retrieves tasks for an agent."""
_, modified_kwargs = await self._apply_interceptors(
request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
)
response_data = await self._send_get_request(
'/v1/tasks',
_model_to_query_params(request),
modified_kwargs,
)
response = a2a_pb2.ListTasksResponse()
ParseDict(response_data, response)
page_size = request.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
return proto_utils.FromProto.list_tasks_result(response, page_size)

async def cancel_task(
self,
request: TaskIdParams,
Expand Down Expand Up @@ -363,3 +389,21 @@ async def get_card(
async def close(self) -> None:
"""Closes the httpx client."""
await self.httpx_client.aclose()


def _model_to_query_params(instance: BaseModel) -> dict[str, str]:
data = instance.model_dump(mode='json', exclude_none=True)
return _json_to_query_params(data)


def _json_to_query_params(data: dict[str, Any]) -> dict[str, str]:
query_dict = {}
for key, value in data.items():
if isinstance(value, list):
query_dict[key] = ','.join(map(str, value))
elif isinstance(value, bool):
query_dict[key] = str(value).lower()
else:
query_dict[key] = str(value)

return query_dict
Loading