Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -434,22 +434,22 @@ async def publish_message(
message, type_name=message_type, data_content_type=self._payload_serialization_format
)

sender_id = sender or AgentId("unknown", "unknown")
attributes = {
_constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=self._payload_serialization_format
),
_constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type),
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.type
),
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.key
),
_constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH
),
}
if sender is not None:
attributes[_constants.AGENT_SENDER_TYPE_ATTR] = cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender.type
)
attributes[_constants.AGENT_SENDER_KEY_ATTR] = cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender.key
)

# If sending JSON we fill text_data with the serialized message
# If sending Protobuf we fill proto_data with the serialized message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from autogen_core._agent_id import AgentId
from autogen_core._runtime_impl_helpers import SubscriptionManager

from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._utils import subscription_from_proto, subscription_to_proto

Expand Down Expand Up @@ -199,13 +200,27 @@ async def _receive_message(self, client_id: ClientConnectionId, message: agent_w
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
task = asyncio.create_task(self._process_event(message.cloudEvent))
task = asyncio.create_task(self._process_event(message.cloudEvent, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("Received empty message")

async def _client_owns_agent_type(self, agent_type: str, client_id: ClientConnectionId) -> bool:
async with self._agent_type_to_client_id_lock:
registered_client_id = self._agent_type_to_client_id.get(agent_type)
return registered_client_id == client_id

async def _send_rpc_error(self, client_id: ClientConnectionId, request_id: str, error: str) -> None:
send_queue = self._data_connections.get(client_id)
if send_queue is None:
logger.error("Client %s not found, failed to send RPC error response.", client_id)
return
await send_queue.send(
agent_worker_pb2.Message(response=agent_worker_pb2.RpcResponse(request_id=request_id, error=error))
)

async def _receive_control_message(
self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage
) -> None:
Expand All @@ -230,6 +245,19 @@ async def _receive_control_message(
await target_send_queue.send(message)

async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
if request.HasField("source") and not await self._client_owns_agent_type(request.source.type, client_id):
logger.warning(
"Client %s attempted to send an RPC request as agent type %s owned by another client.",
client_id,
request.source.type,
)
await self._send_rpc_error(
client_id,
request.request_id,
f"Client {client_id} is not authorized to send as agent type {request.source.type}.",
)
return

# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.type)
Expand Down Expand Up @@ -268,16 +296,37 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse, client
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)

async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
async def _process_event(self, event: cloudevent_pb2.CloudEvent, client_id: ClientConnectionId) -> None:
event_attributes = event.attributes
sender_type_attribute = event_attributes.get(_constants.AGENT_SENDER_TYPE_ATTR)
sender_key_attribute = event_attributes.get(_constants.AGENT_SENDER_KEY_ATTR)
is_legacy_anonymous_sender = (
sender_type_attribute is not None
and sender_key_attribute is not None
and sender_type_attribute.ce_string == "unknown"
and sender_key_attribute.ce_string == "unknown"
)
if (
sender_type_attribute is not None
and not is_legacy_anonymous_sender
and not await self._client_owns_agent_type(sender_type_attribute.ce_string, client_id)
):
logger.warning(
"Client %s attempted to publish an event as agent type %s owned by another client.",
client_id,
sender_type_attribute.ce_string,
)
return

topic_id = TopicId(type=event.type, source=event.source)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock:
client_ids: Set[ClientConnectionId] = set()
for recipient in recipients:
client_id = self._agent_type_to_client_id.get(recipient.type)
if client_id is not None:
client_ids.add(client_id)
target_client_id = self._agent_type_to_client_id.get(recipient.type)
if target_client_id is not None:
client_ids.add(target_client_id)
else:
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
# Deliver the event to clients.
Expand Down
177 changes: 176 additions & 1 deletion python/packages/autogen-ext/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pyright: reportPrivateUsage=false
import asyncio
import logging
import os
Expand All @@ -20,7 +21,9 @@
try_get_known_serializers_for_type,
type_subscription,
)
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost, _constants
from autogen_ext.runtimes.grpc._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
from autogen_ext.runtimes.grpc.protos import agent_worker_pb2, cloudevent_pb2
from autogen_test_utils import (
CascadingAgent,
CascadingMessageType,
Expand All @@ -34,6 +37,178 @@
from .protos.serialization_test_pb2 import ProtoMessage


class _RecordingConnection:
def __init__(self) -> None:
self.messages: List[agent_worker_pb2.Message] = []

async def send(self, message: agent_worker_pb2.Message) -> None:
self.messages.append(message)


def _test_payload() -> agent_worker_pb2.Payload:
return agent_worker_pb2.Payload(data_type="test", data=b"{}", data_content_type="application/json")


@pytest.mark.asyncio
async def test_host_rejects_rpc_request_with_spoofed_source_agent_type() -> None:
servicer = GrpcWorkerAgentRuntimeHostServicer()
servicer._agent_type_to_client_id.update({"source_agent": "source-client", "target_agent": "target-client"})
source_connection = _RecordingConnection()
target_connection = _RecordingConnection()
servicer._data_connections["spoofing-client"] = source_connection # type: ignore[assignment]
servicer._data_connections["target-client"] = target_connection # type: ignore[assignment]

await servicer._process_request(
agent_worker_pb2.RpcRequest(
request_id="request-1",
target=agent_worker_pb2.AgentId(type="target_agent", key="default"),
source=agent_worker_pb2.AgentId(type="source_agent", key="default"),
payload=_test_payload(),
),
client_id="spoofing-client",
)

assert target_connection.messages == []
assert len(source_connection.messages) == 1
assert "not authorized to send as agent type source_agent" in source_connection.messages[0].response.error


@pytest.mark.asyncio
async def test_host_rejects_rpc_request_with_unregistered_source_agent_type() -> None:
servicer = GrpcWorkerAgentRuntimeHostServicer()
servicer._agent_type_to_client_id.update({"target_agent": "target-client"})
source_connection = _RecordingConnection()
target_connection = _RecordingConnection()
servicer._data_connections["source-client"] = source_connection # type: ignore[assignment]
servicer._data_connections["target-client"] = target_connection # type: ignore[assignment]

await servicer._process_request(
agent_worker_pb2.RpcRequest(
request_id="request-1",
target=agent_worker_pb2.AgentId(type="target_agent", key="default"),
source=agent_worker_pb2.AgentId(type="unregistered_agent", key="default"),
payload=_test_payload(),
),
client_id="source-client",
)

assert target_connection.messages == []
assert len(source_connection.messages) == 1
assert "not authorized to send as agent type unregistered_agent" in source_connection.messages[0].response.error


@pytest.mark.asyncio
async def test_host_rejects_publish_event_with_spoofed_sender_agent_type() -> None:
servicer = GrpcWorkerAgentRuntimeHostServicer()
servicer._agent_type_to_client_id.update({"source_agent": "source-client", "target_agent": "target-client"})
target_connection = _RecordingConnection()
servicer._data_connections["target-client"] = target_connection # type: ignore[assignment]
await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent"))

await servicer._process_event(
cloudevent_pb2.CloudEvent(
id="event-1",
spec_version="1.0",
type="default",
source="default",
attributes={
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string="source_agent"
),
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string="default"
),
},
binary_data=b"{}",
),
client_id="spoofing-client",
)

assert target_connection.messages == []


@pytest.mark.asyncio
async def test_host_rejects_publish_event_with_unregistered_sender_agent_type() -> None:
servicer = GrpcWorkerAgentRuntimeHostServicer()
servicer._agent_type_to_client_id.update({"target_agent": "target-client"})
target_connection = _RecordingConnection()
servicer._data_connections["target-client"] = target_connection # type: ignore[assignment]
await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent"))

await servicer._process_event(
cloudevent_pb2.CloudEvent(
id="event-1",
spec_version="1.0",
type="default",
source="default",
attributes={
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string="unregistered_agent"
),
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string="default"
),
},
binary_data=b"{}",
),
client_id="source-client",
)

assert target_connection.messages == []


@pytest.mark.asyncio
async def test_host_allows_anonymous_publish_when_unknown_agent_type_is_registered_elsewhere() -> None:
servicer = GrpcWorkerAgentRuntimeHostServicer()
servicer._agent_type_to_client_id.update({"unknown": "unknown-client", "target_agent": "target-client"})
target_connection = _RecordingConnection()
servicer._data_connections["target-client"] = target_connection # type: ignore[assignment]
await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent"))

await servicer._process_event(
cloudevent_pb2.CloudEvent(
id="event-1",
spec_version="1.0",
type="default",
source="default",
binary_data=b"{}",
),
client_id="source-client",
)

assert len(target_connection.messages) == 1


@pytest.mark.asyncio
async def test_host_allows_legacy_unknown_sender_publish_for_rolling_upgrade() -> None:
servicer = GrpcWorkerAgentRuntimeHostServicer()
servicer._agent_type_to_client_id.update({"unknown": "unknown-client", "target_agent": "target-client"})
target_connection = _RecordingConnection()
servicer._data_connections["target-client"] = target_connection # type: ignore[assignment]
await servicer._subscription_manager.add_subscription(TypeSubscription("default", "target_agent"))

await servicer._process_event(
cloudevent_pb2.CloudEvent(
id="event-1",
spec_version="1.0",
type="default",
source="default",
attributes={
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string="unknown"
),
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string="unknown"
),
},
binary_data=b"{}",
),
client_id="source-client",
)

assert len(target_connection.messages) == 1


@pytest.mark.grpc
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_single_worker() -> None:
Expand Down