Skip to content
28 changes: 27 additions & 1 deletion pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
mcp_step_config: StepConfig | None = None,
model_step_config: StepConfig | None = None,
fastmcp_toolset_step_config: StepConfig | None = None,
):
"""Wrap an agent to enable it with DBOS durable workflows, by automatically offloading model requests, tool calls, and MCP server communication to DBOS steps.

Expand All @@ -56,6 +57,7 @@ def __init__(
event_stream_handler: Optional event stream handler to use instead of the one set on the wrapped agent.
mcp_step_config: The base DBOS step config to use for MCP server steps. If no config is provided, use the default settings of DBOS.
model_step_config: The DBOS step config to use for model request steps. If no config is provided, use the default settings of DBOS.
fastmcp_toolset_step_config: The base DBOS step config to use for FastMCP toolset steps. If no config is provided, use the default settings of DBOS.
"""
super().__init__(wrapped)

Expand All @@ -69,6 +71,7 @@ def __init__(
# Merge the config with the default DBOS config
self._mcp_step_config = mcp_step_config or {}
self._model_step_config = model_step_config or {}
self._fastmcp_toolset_step_config = fastmcp_toolset_step_config or {}

if not isinstance(wrapped.model, Model):
raise UserError(
Expand Down Expand Up @@ -101,6 +104,21 @@ def dbosify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[Age
step_config=self._mcp_step_config,
)

# Replace FastMCPToolset with DBOSFastMCPToolset
try:
from pydantic_ai.toolsets.fastmcp import FastMCPToolset

from ._fastmcp_toolset import DBOSFastMCPToolset
except ImportError:
pass
else:
if isinstance(toolset, FastMCPToolset):
return DBOSFastMCPToolset(
wrapped=toolset,
step_name_prefix=dbosagent_name,
step_config=self._fastmcp_toolset_step_config,
)

return toolset

dbos_toolsets = [toolset.visit_and_replace(dbosify_toolset) for toolset in wrapped.toolsets]
Expand Down Expand Up @@ -336,6 +354,10 @@ async def main():
Returns:
The result of the run.
"""
if model is not None and not isinstance(model, DBOSModel):
raise UserError(
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
)
return await self.dbos_wrapped_run_workflow(
user_prompt,
output_type=output_type,
Expand Down Expand Up @@ -449,6 +471,10 @@ def run_sync(
Returns:
The result of the run.
"""
if model is not None and not isinstance(model, DBOSModel): # pragma: lax no cover
raise UserError(
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
)
return self.dbos_wrapped_run_sync_workflow(
user_prompt,
output_type=output_type,
Expand Down Expand Up @@ -838,7 +864,7 @@ async def main():
Returns:
The result of the run.
"""
if model is not None and not isinstance(model, DBOSModel):
if model is not None and not isinstance(model, DBOSModel): # pragma: lax no cover
raise UserError(
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
)
Expand Down
100 changes: 100 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_fastmcp_toolset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from dbos import DBOS
from typing_extensions import Self

from pydantic_ai import AbstractToolset, ToolsetTool, WrapperToolset
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
from pydantic_ai.toolsets.fastmcp import FastMCPToolset

from ._utils import StepConfig

if TYPE_CHECKING:
from pydantic_ai.mcp import ToolResult


class DBOSFastMCPToolset(WrapperToolset[AgentDepsT], ABC):
"""A wrapper for FastMCPToolset that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""

def __init__(
self,
wrapped: FastMCPToolset[AgentDepsT],
*,
step_name_prefix: str,
step_config: StepConfig,
):
super().__init__(wrapped)
self._step_config = step_config or {}
self._step_name_prefix = step_name_prefix
id_suffix = f'__{wrapped.id}' if wrapped.id else ''
self._name = f'{step_name_prefix}__fastmcp_toolset{id_suffix}'

# Wrap get_tools in a DBOS step.
@DBOS.step(
name=f'{self._name}.get_tools',
**self._step_config,
)
async def wrapped_get_tools_step(
ctx: RunContext[AgentDepsT],
) -> dict[str, ToolDefinition]:
# Need to return a serializable dict, so we cannot return ToolsetTool directly.
tools = await super(DBOSFastMCPToolset, self).get_tools(ctx)
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
return {name: tool.tool_def for name, tool in tools.items()}

self._dbos_wrapped_get_tools_step = wrapped_get_tools_step

# Wrap call_tool in a DBOS step.
@DBOS.step(
name=f'{self._name}.call_tool',
**self._step_config,
)
async def wrapped_call_tool_step(
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
return await super(DBOSFastMCPToolset, self).call_tool(name, tool_args, ctx, tool)

self._dbos_wrapped_call_tool_step = wrapped_call_tool_step

@property
def id(self) -> str | None: # pragma: lax no cover
return self.wrapped.id

async def __aenter__(self) -> Self:
# The wrapped FastMCPToolset enters itself around listing and calling tools
# so we don't need to enter it here (nor could we because we're not inside a DBOS step).
return self

async def __aexit__(self, *args: Any) -> bool | None:
return None

def visit_and_replace(
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
) -> AbstractToolset[AgentDepsT]:
# DBOS-ified toolsets cannot be swapped out after the fact.
return self

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
tool_defs = await self._dbos_wrapped_get_tools_step(ctx)
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
return await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)

def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, FastMCPToolset)
return self.wrapped.tool_for_tool_def(tool_def)
20 changes: 15 additions & 5 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from typing_extensions import Self

from pydantic_ai import AbstractToolset, ToolsetTool, WrapperToolset
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.mcp import MCPServer
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition

from ._utils import StepConfig

if TYPE_CHECKING:
from pydantic_ai.mcp import MCPServer, ToolResult
from pydantic_ai.mcp import ToolResult


class DBOSMCPServer(WrapperToolset[AgentDepsT], ABC):
Expand All @@ -39,8 +40,12 @@ def __init__(
)
async def wrapped_get_tools_step(
ctx: RunContext[AgentDepsT],
) -> dict[str, ToolsetTool[AgentDepsT]]:
return await super(DBOSMCPServer, self).get_tools(ctx)
) -> dict[str, ToolDefinition]:
# Need to return a serializable dict, so we cannot return ToolsetTool directly.
tools = await super(DBOSMCPServer, self).get_tools(ctx)
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
return {name: tool.tool_def for name, tool in tools.items()}

self._dbos_wrapped_get_tools_step = wrapped_get_tools_step

Expand Down Expand Up @@ -78,7 +83,8 @@ def visit_and_replace(
return self

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
return await self._dbos_wrapped_get_tools_step(ctx)
tool_defs = await self._dbos_wrapped_get_tools_step(ctx)
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

async def call_tool(
self,
Expand All @@ -88,3 +94,7 @@ async def call_tool(
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
return await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)

def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, MCPServer)
return self.wrapped.tool_for_tool_def(tool_def)
Loading