181 changes: 136 additions & 45 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
@@ -2,18 +2,22 @@

import logging
from dataclasses import dataclass
from typing import Any
from typing import Any, cast

from agent_framework import FunctionApprovalRequestContent, FunctionApprovalResponseContent

from .._agents import AgentProtocol, ChatAgent
from .._threads import AgentThread
from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from ._conversation_state import encode_chat_messages
from ._events import (
AgentRunEvent,
AgentRunUpdateEvent, # type: ignore[reportPrivateUsage]
)
from ._executor import Executor, handler
from ._message_utils import normalize_messages_input
from ._request_info_mixin import response_handler
from ._workflow_context import WorkflowContext

logger = logging.getLogger(__name__)
@@ -83,6 +87,8 @@ def __init__(
super().__init__(exec_id)
self._agent = agent
self._agent_thread = agent_thread or self._agent.get_new_thread()
self._pending_agent_requests: dict[str, FunctionApprovalRequestContent] = {}
self._pending_responses_to_agent: list[FunctionApprovalResponseContent] = []
self._output_response = output_response
self._cache: list[ChatMessage] = []

@@ -93,50 +99,6 @@ def workflow_output_types(self) -> list[type[Any]]:
return [AgentRunResponse]
return []

async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
"""Execute the underlying agent, emit events, and enqueue response.

Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent
events (streaming mode) or a single AgentRunEvent (non-streaming mode).
"""
if ctx.is_streaming():
# Streaming mode: emit incremental updates
updates: list[AgentRunResponseUpdate] = []
async for update in self._agent.run_stream(
self._cache,
thread=self._agent_thread,
):
updates.append(update)
await ctx.add_event(AgentRunUpdateEvent(self.id, update))

if isinstance(self._agent, ChatAgent):
response_format = self._agent.chat_options.response_format
response = AgentRunResponse.from_agent_run_response_updates(
updates,
output_format_type=response_format,
)
else:
response = AgentRunResponse.from_agent_run_response_updates(updates)
else:
# Non-streaming mode: use run() and emit single event
response = await self._agent.run(
self._cache,
thread=self._agent_thread,
)
await ctx.add_event(AgentRunEvent(self.id, response))

if self._output_response:
await ctx.yield_output(response)

# Always construct a full conversation snapshot from inputs (cache)
# plus agent outputs (agent_run_response.messages). Do not mutate
# response.messages so AgentRunEvent remains faithful to the raw output.
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)

agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
await ctx.send_message(agent_response)
self._cache.clear()

@handler
async def run(
self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]
@@ -192,6 +154,31 @@ async def from_messages(
self._cache = normalize_messages_input(messages)
await self._run_agent_and_emit(ctx)

@response_handler
async def handle_user_input_response(
self,
original_request: FunctionApprovalRequestContent,
response: FunctionApprovalResponseContent,
ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse],
) -> None:
"""Handle user input responses for function approvals during agent execution.

This holds the executor's execution until all pending user input requests are resolved.

Args:
original_request: The original function approval request sent by the agent.
response: The user's response to the function approval request.
ctx: The workflow context for emitting events and outputs.
"""
self._pending_responses_to_agent.append(response)
self._pending_agent_requests.pop(original_request.id, None)

if not self._pending_agent_requests:
# All pending requests have been resolved; resume agent execution
self._cache = normalize_messages_input(ChatMessage(role="user", contents=list(self._pending_responses_to_agent)))
self._pending_responses_to_agent.clear()
await self._run_agent_and_emit(ctx)

async def snapshot_state(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.

@@ -226,6 +213,8 @@ async def snapshot_state(self) -> dict[str, Any]:
return {
"cache": encode_chat_messages(self._cache),
"agent_thread": serialized_thread,
"pending_agent_requests": encode_checkpoint_value(self._pending_agent_requests),
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
}

async def restore_state(self, state: dict[str, Any]) -> None:
@@ -258,7 +247,109 @@ async def restore_state(self, state: dict[str, Any]) -> None:
else:
self._agent_thread = self._agent.get_new_thread()

pending_requests_payload = state.get("pending_agent_requests")
if pending_requests_payload:
self._pending_agent_requests = decode_checkpoint_value(pending_requests_payload)

pending_responses_payload = state.get("pending_responses_to_agent")
if pending_responses_payload:
self._pending_responses_to_agent = decode_checkpoint_value(pending_responses_payload)

def reset(self) -> None:
"""Reset the internal cache of the executor."""
logger.debug("AgentExecutor %s: Resetting cache", self.id)
self._cache.clear()

async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
"""Execute the underlying agent, emit events, and enqueue response.

Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent
events (streaming mode) or a single AgentRunEvent (non-streaming mode).
"""
if ctx.is_streaming():
# Streaming mode: emit incremental updates
response = await self._run_agent_streaming(cast(WorkflowContext, ctx))
else:
# Non-streaming mode: use run() and emit single event
response = await self._run_agent(cast(WorkflowContext, ctx))

if response is None:
# Agent did not complete (e.g., waiting for user input); do not emit response
logger.info("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
return

if self._output_response:
await ctx.yield_output(response)

# Always construct a full conversation snapshot from inputs (cache)
# plus agent outputs (agent_run_response.messages). Do not mutate
# response.messages so AgentRunEvent remains faithful to the raw output.
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)

agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
await ctx.send_message(agent_response)
self._cache.clear()

async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None:
"""Execute the underlying agent in non-streaming mode.

Args:
ctx: The workflow context for emitting events.

Returns:
The complete AgentRunResponse, or None if waiting for user input.
"""
response = await self._agent.run(
self._cache,
thread=self._agent_thread,
)
await ctx.add_event(AgentRunEvent(self.id, response))

# Handle any user input requests
if response.user_input_requests:
for user_input_request in response.user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request
await ctx.request_info(user_input_request, FunctionApprovalResponseContent)
return None

return response

async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse | None:
"""Execute the underlying agent in streaming mode and collect the full response.

Args:
ctx: The workflow context for emitting events.

Returns:
The complete AgentRunResponse, or None if waiting for user input.
"""
updates: list[AgentRunResponseUpdate] = []
user_input_requests: list[FunctionApprovalRequestContent] = []
async for update in self._agent.run_stream(
self._cache,
thread=self._agent_thread,
):
updates.append(update)
await ctx.add_event(AgentRunUpdateEvent(self.id, update))

if update.user_input_requests:
user_input_requests.extend(update.user_input_requests)

# Build the final AgentRunResponse from the collected updates
if isinstance(self._agent, ChatAgent):
response_format = self._agent.chat_options.response_format
response = AgentRunResponse.from_agent_run_response_updates(
updates,
output_format_type=response_format,
)
else:
response = AgentRunResponse.from_agent_run_response_updates(updates)

# Handle any user input requests after the streaming completes
if user_input_requests:
for user_input_request in user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request
await ctx.request_info(user_input_request, FunctionApprovalResponseContent)
return None

return response
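
The heart of this change is the gate in handle_user_input_response: the executor only resumes the agent once every outstanding approval request has an answer. A minimal, self-contained sketch of that gating pattern follows (plain Python; ApprovalGate and all names here are illustrative stand-ins, not part of agent_framework):

from dataclasses import dataclass, field


@dataclass
class ApprovalGate:
    """Holds execution until every pending request has been answered."""

    pending: dict[str, str] = field(default_factory=dict)  # request_id -> description
    responses: list[str] = field(default_factory=list)  # collected answers

    def request(self, request_id: str, description: str) -> None:
        # Mirrors self._pending_agent_requests[request.id] = request in the diff.
        self.pending[request_id] = description

    def respond(self, request_id: str, answer: str) -> bool:
        # Mirrors handle_user_input_response: record the answer, drop the
        # pending entry, and report whether the gate is now open.
        self.responses.append(answer)
        self.pending.pop(request_id, None)
        return not self.pending


gate = ApprovalGate()
gate.request("call-1", "delete_file('/tmp/a')")
gate.request("call-2", "send_email('bob')")
assert gate.respond("call-1", "approved") is False  # still waiting on call-2
assert gate.respond("call-2", "denied") is True  # all resolved; agent may resume

The same "return None while waiting" convention shows up in _run_agent and _run_agent_streaming above: a None result signals "blocked on user input", and _run_agent_and_emit skips emitting a response in that case.
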
4 changes: 4 additions & 0 deletions python/packages/core/tests/workflow/test_agent_executor.py
@@ -111,6 +111,10 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
chat_store_state = thread_state["chat_message_store_state"] # type: ignore[index]
assert "messages" in chat_store_state, "Message store state should include messages"

# Verify checkpoint contains pending requests from agents and responses to be sent
assert "pending_agent_requests" in executor_state
assert "pending_responses_to_agent" in executor_state

# Create a new agent and executor for restoration
# This simulates starting from a fresh state and restoring from checkpoint
restored_agent = _CountingAgent(id="test_agent", name="TestAgent")
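
The two new assertions above only check that the keys exist in the checkpoint; the full round trip looks roughly like the sketch below. This is hypothetical: the encode/decode functions are identity stand-ins for encode_checkpoint_value / decode_checkpoint_value, and the payload shapes are invented for illustration.

def encode_checkpoint_value(value):
    # Identity stand-in for agent_framework's real checkpoint encoder.
    return value


def decode_checkpoint_value(payload):
    # Identity stand-in for agent_framework's real checkpoint decoder.
    return payload


pending_requests = {"req-1": {"function": "delete_file", "arguments": {"path": "/tmp/a"}}}
pending_responses = [{"request_id": "req-1", "approved": True}]

# What snapshot_state() writes alongside "cache" and "agent_thread".
state = {
    "pending_agent_requests": encode_checkpoint_value(pending_requests),
    "pending_responses_to_agent": encode_checkpoint_value(pending_responses),
}

# restore_state() mirrors the `if payload:` guards in the diff: a missing or
# empty payload leaves the executor's defaults untouched.
payload = state.get("pending_agent_requests")
restored_requests = decode_checkpoint_value(payload) if payload else {}
assert restored_requests == pending_requests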