Skip to content
Open
5 changes: 5 additions & 0 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,14 @@ initdb
inmemory
INR
isready
jku
JPY
JSONRPCt
jwk
jwks
JWS
jws
kid
kwarg
langgraph
lifecycles
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"]
postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
signing = ["PyJWT>=2.0.0"]

sql = ["a2a-sdk[postgresql,mysql,sqlite]"]

Expand All @@ -45,6 +46,7 @@ all = [
"a2a-sdk[encryption]",
"a2a-sdk[grpc]",
"a2a-sdk[telemetry]",
"a2a-sdk[signing]",
]

[project.urls]
Expand Down
8 changes: 6 additions & 2 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from typing import Any

from a2a.client.client import (
Expand Down Expand Up @@ -261,6 +261,7 @@ async def get_card(
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.

Expand All @@ -270,12 +271,15 @@ async def get_card(
Args:
context: The client call context.
extensions: List of extensions to be activated.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
The `AgentCard` for the agent.
"""
card = await self._transport.get_card(
context=context, extensions=extensions
context=context,
extensions=extensions,
signature_verifier=signature_verifier,
)
self._card = card
return card
Expand Down
1 change: 1 addition & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,18 @@
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream."""
return
yield

@abstractmethod
async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""

Check notice on line 190 in src/a2a/client/client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/base.py (97-108)

async def add_event_consumer(self, consumer: Consumer) -> None:
"""Attaches additional consumers to the `Client`."""
Expand Down
3 changes: 2 additions & 1 deletion src/a2a/client/transports/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable

from a2a.client.middleware import ClientCallContext
from a2a.types import (
Expand Down Expand Up @@ -94,17 +94,18 @@
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
"""Reconnects to get task updates."""
return
yield

@abstractmethod
async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the AgentCard."""

Check notice on line 108 in src/a2a/client/transports/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/client.py (179-190)

@abstractmethod
async def close(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable


try:
Expand Down Expand Up @@ -223,6 +223,7 @@ async def get_card(
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
card = self.agent_card
Expand All @@ -236,6 +237,9 @@ async def get_card(
metadata=self._get_grpc_metadata(extensions),
)
card = proto_utils.FromProto.agent_card(card_pb)
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return card
Expand Down
14 changes: 11 additions & 3 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any
from uuid import uuid4

Expand Down Expand Up @@ -363,41 +363,45 @@
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))

Check notice on line 404 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (162-396)
payload, modified_kwargs = await self._apply_interceptors(
request.method,
request.model_dump(mode='json', exclude_none=True),
Expand All @@ -413,9 +417,13 @@
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
self.agent_card = response.root.result
card = response.root.result
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return self.agent_card
return card

async def close(self) -> None:
"""Closes the httpx client."""
Expand Down
9 changes: 8 additions & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any

import httpx
Expand Down Expand Up @@ -159,237 +159,241 @@
yield proto_utils.FromProto.stream_response(event)
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_post_request(
self,
target: str,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'POST',
f'{self.url}{target}',
json=rpc_request_payload,
**(http_kwargs or {}),
)
)

async def _send_get_request(
self,
target: str,
query_params: dict[str, str],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'GET',
f'{self.url}{target}',
params=query_params,
**(http_kwargs or {}),
)
)

async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
_payload, modified_kwargs = await self._apply_interceptors(
request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}',
{'historyLength': str(request.history_length)}
if request.history_length is not None
else {},
modified_kwargs,
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
pb = a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'tasks/{request.task_id}',
config_id=request.push_notification_config.id,
config=proto_utils.ToProto.task_push_notification_config(request),
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload, modified_kwargs, context
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.task_id}/pushNotificationConfigs',
payload,
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
pb = a2a_pb2.GetTaskPushNotificationConfigRequest(
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
{},
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncGenerator[
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
]:
"""Reconnects to get task updates."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
modified_kwargs.setdefault('timeout', None)

async with aconnect_sse(
self.httpx_client,
'GET',
f'{self.url}/v1/tasks/{request.id}:subscribe',
**modified_kwargs,
) as event_source:
try:
async for sse in event_source.aiter_sse():
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

_, modified_kwargs = await self._apply_interceptors(

Check notice on line 396 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (366-404)
{},
modified_kwargs,
context,
Expand All @@ -398,6 +402,9 @@
'/v1/card', {}, modified_kwargs
)
card = AgentCard.model_validate(response_data)
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return card
Expand Down
28 changes: 28 additions & 0 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,21 @@ def agent_card(
]
if card.additional_interfaces
else None,
signatures=[cls.agent_card_signature(x) for x in card.signatures]
if card.signatures
else None,
)

@classmethod
def agent_card_signature(
cls, signature: types.AgentCardSignature
) -> a2a_pb2.AgentCardSignature:
return a2a_pb2.AgentCardSignature(
protected=signature.protected,
signature=signature.signature,
header=dict_to_struct(signature.header)
if signature.header is not None
else None,
)

@classmethod
Expand Down Expand Up @@ -865,6 +880,19 @@ def agent_card(
]
if card.additional_interfaces
else None,
signatures=[cls.agent_card_signature(x) for x in card.signatures]
if card.signatures
else None,
)

@classmethod
def agent_card_signature(
cls, signature: a2a_pb2.AgentCardSignature
) -> types.AgentCardSignature:
return types.AgentCardSignature(
protected=signature.protected,
signature=signature.signature,
header=json_format.MessageToDict(signature.header),
)

@classmethod
Expand Down
Loading
Loading