Skip to content
25 changes: 24 additions & 1 deletion pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ def dbosify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[Age
step_config=self._mcp_step_config,
)

# Replace FastMCPToolset with DBOSFastMCPToolset
try:
from pydantic_ai.toolsets.fastmcp import FastMCPToolset

from ._fastmcp_toolset import DBOSFastMCPToolset
except ImportError:
pass
else:
if isinstance(toolset, FastMCPToolset):
return DBOSFastMCPToolset(
wrapped=toolset,
step_name_prefix=dbosagent_name,
step_config=self._mcp_step_config,
)

return toolset

dbos_toolsets = [toolset.visit_and_replace(dbosify_toolset) for toolset in wrapped.toolsets]
Expand Down Expand Up @@ -336,6 +351,10 @@ async def main():
Returns:
The result of the run.
"""
if model is not None and not isinstance(model, DBOSModel):
raise UserError(
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
)
return await self.dbos_wrapped_run_workflow(
user_prompt,
output_type=output_type,
Expand Down Expand Up @@ -449,6 +468,10 @@ def run_sync(
Returns:
The result of the run.
"""
if model is not None and not isinstance(model, DBOSModel): # pragma: lax no cover
raise UserError(
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
)
return self.dbos_wrapped_run_sync_workflow(
user_prompt,
output_type=output_type,
Expand Down Expand Up @@ -838,7 +861,7 @@ async def main():
Returns:
The result of the run.
"""
if model is not None and not isinstance(model, DBOSModel):
if model is not None and not isinstance(model, DBOSModel): # pragma: lax no cover
raise UserError(
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
)
Expand Down
29 changes: 29 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_fastmcp_toolset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from pydantic_ai import ToolsetTool
from pydantic_ai.tools import AgentDepsT, ToolDefinition
from pydantic_ai.toolsets.fastmcp import FastMCPToolset

from ._mcp import DBOSMCPToolset
from ._utils import StepConfig


class DBOSFastMCPToolset(DBOSMCPToolset[AgentDepsT]):
"""A wrapper for FastMCPToolset that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""

def __init__(
self,
wrapped: FastMCPToolset[AgentDepsT],
*,
step_name_prefix: str,
step_config: StepConfig,
):
super().__init__(
wrapped,
step_name_prefix=step_name_prefix,
step_config=step_config,
)

def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, FastMCPToolset)
return self.wrapped.tool_for_tool_def(tool_def)
99 changes: 99 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from dbos import DBOS
from typing_extensions import Self

from pydantic_ai import AbstractToolset, ToolsetTool, WrapperToolset
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition

from ._utils import StepConfig

if TYPE_CHECKING:
from pydantic_ai.mcp import ToolResult


class DBOSMCPToolset(WrapperToolset[AgentDepsT], ABC):
"""A wrapper for MCP toolset that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""

def __init__(
self,
wrapped: AbstractToolset[AgentDepsT],
*,
step_name_prefix: str,
step_config: StepConfig,
):
super().__init__(wrapped)
self._step_config = step_config or {}
self._step_name_prefix = step_name_prefix
id_suffix = f'__{wrapped.id}' if wrapped.id else ''
self._name = f'{step_name_prefix}__mcp_server{id_suffix}'

# Wrap get_tools in a DBOS step.
@DBOS.step(
name=f'{self._name}.get_tools',
**self._step_config,
)
async def wrapped_get_tools_step(
ctx: RunContext[AgentDepsT],
) -> dict[str, ToolDefinition]:
# Need to return a serializable dict, so we cannot return ToolsetTool directly.
tools = await super(DBOSMCPToolset, self).get_tools(ctx)
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
return {name: tool.tool_def for name, tool in tools.items()}

self._dbos_wrapped_get_tools_step = wrapped_get_tools_step

# Wrap call_tool in a DBOS step.
@DBOS.step(
name=f'{self._name}.call_tool',
**self._step_config,
)
async def wrapped_call_tool_step(
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
return await super(DBOSMCPToolset, self).call_tool(name, tool_args, ctx, tool)

self._dbos_wrapped_call_tool_step = wrapped_call_tool_step

@abstractmethod
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
raise NotImplementedError

@property
def id(self) -> str | None:
return self.wrapped.id

async def __aenter__(self) -> Self:
# The wrapped MCP toolset enters itself around listing and calling tools
# so we don't need to enter it here (nor could we because we're not inside a DBOS step).
return self

async def __aexit__(self, *args: Any) -> bool | None:
return None

def visit_and_replace(
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
) -> AbstractToolset[AgentDepsT]:
# DBOS-ified toolsets cannot be swapped out after the fact.
return self

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
tool_defs = await self._dbos_wrapped_get_tools_step(ctx)
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
return await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)
85 changes: 12 additions & 73 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_mcp_server.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from dbos import DBOS
from typing_extensions import Self

from pydantic_ai import AbstractToolset, ToolsetTool, WrapperToolset
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai import ToolsetTool
from pydantic_ai.mcp import MCPServer
from pydantic_ai.tools import AgentDepsT, ToolDefinition

from ._mcp import DBOSMCPToolset
from ._utils import StepConfig

if TYPE_CHECKING:
from pydantic_ai.mcp import MCPServer, ToolResult


class DBOSMCPServer(WrapperToolset[AgentDepsT], ABC):
class DBOSMCPServer(DBOSMCPToolset[AgentDepsT]):
"""A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""

def __init__(
Expand All @@ -26,65 +18,12 @@ def __init__(
step_name_prefix: str,
step_config: StepConfig,
):
super().__init__(wrapped)
self._step_config = step_config or {}
self._step_name_prefix = step_name_prefix
id_suffix = f'__{wrapped.id}' if wrapped.id else ''
self._name = f'{step_name_prefix}__mcp_server{id_suffix}'

# Wrap get_tools in a DBOS step.
@DBOS.step(
name=f'{self._name}.get_tools',
**self._step_config,
)
async def wrapped_get_tools_step(
ctx: RunContext[AgentDepsT],
) -> dict[str, ToolsetTool[AgentDepsT]]:
return await super(DBOSMCPServer, self).get_tools(ctx)

self._dbos_wrapped_get_tools_step = wrapped_get_tools_step

# Wrap call_tool in a DBOS step.
@DBOS.step(
name=f'{self._name}.call_tool',
**self._step_config,
super().__init__(
wrapped,
step_name_prefix=step_name_prefix,
step_config=step_config,
)
async def wrapped_call_tool_step(
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
return await super(DBOSMCPServer, self).call_tool(name, tool_args, ctx, tool)

self._dbos_wrapped_call_tool_step = wrapped_call_tool_step

@property
def id(self) -> str | None:
return self.wrapped.id

async def __aenter__(self) -> Self:
# The wrapped MCPServer enters itself around listing and calling tools
# so we don't need to enter it here (nor could we because we're not inside a DBOS step).
return self

async def __aexit__(self, *args: Any) -> bool | None:
return None

def visit_and_replace(
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
) -> AbstractToolset[AgentDepsT]:
# DBOS-ified toolsets cannot be swapped out after the fact.
return self

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
return await self._dbos_wrapped_get_tools_step(ctx)

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
return await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, MCPServer)
return self.wrapped.tool_for_tool_def(tool_def)
Loading