From 71dcf2c3d34d248a7d6eeedd3a4cb9e7cfd0a21b Mon Sep 17 00:00:00 2001 From: maxpetrusenkoagent Date: Sun, 21 Jun 2026 01:17:01 -0400 Subject: [PATCH] fix: reject spoofed grpc runtime sender types Signed-off-by: maxpetrusenkoagent --- .../runtimes/grpc/_worker_runtime.py | 14 +- .../grpc/_worker_runtime_host_servicer.py | 59 +++++- .../autogen-ext/tests/test_worker_runtime.py | 177 +++++++++++++++++- 3 files changed, 237 insertions(+), 13 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 6a3963586e18..9e5b1611b65f 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -434,22 +434,22 @@ async def publish_message( message, type_name=message_type, data_content_type=self._payload_serialization_format ) - sender_id = sender or AgentId("unknown", "unknown") attributes = { _constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( ce_string=self._payload_serialization_format ), _constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type), - _constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( - ce_string=sender_id.type - ), - _constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( - ce_string=sender_id.key - ), _constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH ), } + if sender is not None: + attributes[_constants.AGENT_SENDER_TYPE_ATTR] = cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=sender.type + ) + attributes[_constants.AGENT_SENDER_KEY_ATTR] = cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=sender.key + ) # If sending JSON we fill text_data with the serialized message # If sending Protobuf we fill proto_data with the serialized message diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index 1c0b57a440ed..1edb1bb01961 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -10,6 +10,7 @@ from autogen_core._agent_id import AgentId from autogen_core._runtime_impl_helpers import SubscriptionManager +from . import _constants from ._constants import GRPC_IMPORT_ERROR_STR from ._utils import subscription_from_proto, subscription_to_proto @@ -199,13 +200,27 @@ async def _receive_message(self, client_id: ClientConnectionId, message: agent_w task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) case "cloudEvent": - task = asyncio.create_task(self._process_event(message.cloudEvent)) + task = asyncio.create_task(self._process_event(message.cloudEvent, client_id)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) case None: logger.warning("Received empty message") + async def _client_owns_agent_type(self, agent_type: str, client_id: ClientConnectionId) -> bool: + async with self._agent_type_to_client_id_lock: + registered_client_id = self._agent_type_to_client_id.get(agent_type) + return registered_client_id == client_id + + async def _send_rpc_error(self, client_id: ClientConnectionId, request_id: str, error: str) -> None: + send_queue = self._data_connections.get(client_id) + if send_queue is None: + logger.error("Client %s not found, failed to send RPC error response.", client_id) + return + await send_queue.send( + agent_worker_pb2.Message(response=agent_worker_pb2.RpcResponse(request_id=request_id, error=error)) + ) + async def _receive_control_message( self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage ) -> None: @@ -230,6 +245,19 @@ async def _receive_control_message( await target_send_queue.send(message) async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None: + if request.HasField("source") and not await self._client_owns_agent_type(request.source.type, client_id): + logger.warning( + "Client %s attempted to send an RPC request as agent type %s owned by another client.", + client_id, + request.source.type, + ) + await self._send_rpc_error( + client_id, + request.request_id, + f"Client {client_id} is not authorized to send as agent type {request.source.type}.", + ) + return + # Deliver the message to a client given the target agent type. async with self._agent_type_to_client_id_lock: target_client_id = self._agent_type_to_client_id.get(request.target.type) @@ -268,16 +296,37 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse, client future = self._pending_responses[client_id].pop(response.request_id) future.set_result(response) - async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: + async def _process_event(self, event: cloudevent_pb2.CloudEvent, client_id: ClientConnectionId) -> None: + event_attributes = event.attributes + sender_type_attribute = event_attributes.get(_constants.AGENT_SENDER_TYPE_ATTR) + sender_key_attribute = event_attributes.get(_constants.AGENT_SENDER_KEY_ATTR) + is_legacy_anonymous_sender = ( + sender_type_attribute is not None + and sender_key_attribute is not None + and sender_type_attribute.ce_string == "unknown" + and sender_key_attribute.ce_string == "unknown" + ) + if ( + sender_type_attribute is not None + and not is_legacy_anonymous_sender + and not await self._client_owns_agent_type(sender_type_attribute.ce_string, client_id) + ): + logger.warning( + "Client %s attempted to publish an event as agent type %s owned by another client.", + client_id, + sender_type_attribute.ce_string, + ) + return + topic_id = TopicId(type=event.type, source=event.source) recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) # Get the client ids of the recipients. async with self._agent_type_to_client_id_lock: client_ids: Set[ClientConnectionId] = set() for recipient in recipients: - client_id = self._agent_type_to_client_id.get(recipient.type) - if client_id is not None: - client_ids.add(client_id) + target_client_id = self._agent_type_to_client_id.get(recipient.type) + if target_client_id is not None: + client_ids.add(target_client_id) else: logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.") # Deliver the event to clients. diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index ec57f187e821..3c7a28376e38 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -1,3 +1,4 @@ +# pyright: reportPrivateUsage=false import asyncio import logging import os @@ -20,7 +21,9 @@ try_get_known_serializers_for_type, type_subscription, ) -from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost +from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost, _constants +from autogen_ext.runtimes.grpc._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer +from autogen_ext.runtimes.grpc.protos import agent_worker_pb2, cloudevent_pb2 from autogen_test_utils import ( CascadingAgent, CascadingMessageType, @@ -34,6 +37,178 @@ from .protos.serialization_test_pb2 import ProtoMessage +class _RecordingConnection: + def __init__(self) -> None: + self.messages: List[agent_worker_pb2.Message] = [] + + async def send(self, message: agent_worker_pb2.Message) -> None: + self.messages.append(message) + + +def _test_payload() -> agent_worker_pb2.Payload: + return agent_worker_pb2.Payload(data_type="test", data=b"{}", data_content_type="application/json") + + +@pytest.mark.asyncio +async def test_host_rejects_rpc_request_with_spoofed_source_agent_type() -> None: + servicer = GrpcWorkerAgentRuntimeHostServicer() + servicer._agent_type_to_client_id.update({"source_agent": "source-client", "target_agent": "target-client"}) + source_connection = _RecordingConnection() + target_connection = _RecordingConnection() + servicer._data_connections["spoofing-client"] = source_connection # type: ignore[assignment] + servicer._data_connections["target-client"] = target_connection # type: ignore[assignment] + + await servicer._process_request( + agent_worker_pb2.RpcRequest( + request_id="request-1", + target=agent_worker_pb2.AgentId(type="target_agent", key="default"), + source=agent_worker_pb2.AgentId(type="source_agent", key="default"), + payload=_test_payload(), + ), + client_id="spoofing-client", + ) + + assert target_connection.messages == [] + assert len(source_connection.messages) == 1 + assert "not authorized to send as agent type source_agent" in source_connection.messages[0].response.error + + +@pytest.mark.asyncio +async def test_host_rejects_rpc_request_with_unregistered_source_agent_type() -> None: + servicer = GrpcWorkerAgentRuntimeHostServicer() + servicer._agent_type_to_client_id.update({"target_agent": "target-client"}) + source_connection = _RecordingConnection() + target_connection = _RecordingConnection() + servicer._data_connections["source-client"] = source_connection # type: ignore[assignment] + servicer._data_connections["target-client"] = target_connection # type: ignore[assignment] + + await servicer._process_request( + agent_worker_pb2.RpcRequest( + request_id="request-1", + target=agent_worker_pb2.AgentId(type="target_agent", key="default"), + source=agent_worker_pb2.AgentId(type="unregistered_agent", key="default"), + payload=_test_payload(), + ), + client_id="source-client", + ) + + assert target_connection.messages == [] + assert len(source_connection.messages) == 1 + assert "not authorized to send as agent type unregistered_agent" in source_connection.messages[0].response.error + + +@pytest.mark.asyncio +async def test_host_rejects_publish_event_with_spoofed_sender_agent_type() -> None: + servicer = GrpcWorkerAgentRuntimeHostServicer() + servicer._agent_type_to_client_id.update({"source_agent": "source-client", "target_agent": "target-client"}) + target_connection = _RecordingConnection() + servicer._data_connections["target-client"] = target_connection # type: ignore[assignment] + await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent")) + + await servicer._process_event( + cloudevent_pb2.CloudEvent( + id="event-1", + spec_version="1.0", + type="default", + source="default", + attributes={ + _constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string="source_agent" + ), + _constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string="default" + ), + }, + binary_data=b"{}", + ), + client_id="spoofing-client", + ) + + assert target_connection.messages == [] + + +@pytest.mark.asyncio +async def test_host_rejects_publish_event_with_unregistered_sender_agent_type() -> None: + servicer = GrpcWorkerAgentRuntimeHostServicer() + servicer._agent_type_to_client_id.update({"target_agent": "target-client"}) + target_connection = _RecordingConnection() + servicer._data_connections["target-client"] = target_connection # type: ignore[assignment] + await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent")) + + await servicer._process_event( + cloudevent_pb2.CloudEvent( + id="event-1", + spec_version="1.0", + type="default", + source="default", + attributes={ + _constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string="unregistered_agent" + ), + _constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string="default" + ), + }, + binary_data=b"{}", + ), + client_id="source-client", + ) + + assert target_connection.messages == [] + + +@pytest.mark.asyncio +async def test_host_allows_anonymous_publish_when_unknown_agent_type_is_registered_elsewhere() -> None: + servicer = GrpcWorkerAgentRuntimeHostServicer() + servicer._agent_type_to_client_id.update({"unknown": "unknown-client", "target_agent": "target-client"}) + target_connection = _RecordingConnection() + servicer._data_connections["target-client"] = target_connection # type: ignore[assignment] + await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent")) + + await servicer._process_event( + cloudevent_pb2.CloudEvent( + id="event-1", + spec_version="1.0", + type="default", + source="default", + binary_data=b"{}", + ), + client_id="source-client", + ) + + assert len(target_connection.messages) == 1 + + +@pytest.mark.asyncio +async def test_host_allows_legacy_unknown_sender_publish_for_rolling_upgrade() -> None: + servicer = GrpcWorkerAgentRuntimeHostServicer() + servicer._agent_type_to_client_id.update({"unknown": "unknown-client", "target_agent": "target-client"}) + target_connection = _RecordingConnection() + servicer._data_connections["target-client"] = target_connection # type: ignore[assignment] + await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent")) + + await servicer._process_event( + cloudevent_pb2.CloudEvent( + id="event-1", + spec_version="1.0", + type="default", + source="default", + attributes={ + _constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string="unknown" + ), + _constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string="unknown" + ), + }, + binary_data=b"{}", + ), + client_id="source-client", + ) + + assert len(target_connection.messages) == 1 + + @pytest.mark.grpc @pytest.mark.asyncio async def test_agent_types_must_be_unique_single_worker() -> None: