-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmcpx_pydantic_ai.py
128 lines (105 loc) · 3.74 KB
/
mcpx_pydantic_ai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import mcp_run
import pydantic_ai
import pydantic
from pydantic import BaseModel, Field
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from mcp_run.client import _convert_type
from typing import TypedDict, List, Set, AsyncIterator, Any
import traceback
__all__ = ["BaseModel", "Field", "Agent", "mcp_run", "pydantic_ai", "pydantic"]
def openai_compatible_model(url: str, model: str, api_key: str | None = None):
"""
Returns an OpenAI compatible model from the provided `url`, `model` name and optional `api_key`
"""
provider = OpenAIProvider(base_url=url, api_key=api_key)
return OpenAIModel(model, provider=provider)
class Agent(pydantic_ai.Agent):
"""
A Pydantic Agent using tools from mcp.run
"""
client: mcp_run.Client
ignore_tools: Set[str]
_original_tools: list
_registered_tools: List[str]
def __init__(
self,
*args,
client: mcp_run.Client | None = None,
ignore_tools: List[str] | None = None,
**kw,
):
self.client = client or mcp_run.Client()
self._original_tools = kw.get("tools", [])
self._registered_tools = []
self.ignore_tools = set(ignore_tools or [])
super().__init__(*args, **kw)
self._update_tools()
for t in self._original_tools:
self._registered_tools.append(t.name)
def set_profile(self, profile: str):
self.client.set_profile(profile)
self._update_tools()
def register_tool(self, tool: mcp_run.Tool, f=None):
if tool.name in self.ignore_tools:
return
def wrap(tool, inner):
if inner is not None:
props = tool.input_schema["properties"]
t = {k: _convert_type(v["type"]) for k, v in props.items()}
InputType = TypedDict("Input", t)
def f(input: InputType):
try:
return inner(input)
except Exception as exc:
return f"ERROR call to tool {tool.name} failed: {traceback.format_exception(exc)}"
return f
else:
return self.client._make_pydantic_function(tool)
self._register_tool(
pydantic_ai.Tool(
wrap(tool, f),
name=tool.name,
description=tool.description,
)
)
if f is not None:
self._registered_tools.append(tool.name)
def reset_tools(self):
for k in list(self._function_tools.keys()):
if k not in self._registered_tools:
del self._function_tools[k]
def _update_tools(self):
self.reset_tools()
for tool in self.client.tools.values():
self.register_tool(tool)
async def run(self, *args, update_tools: bool = True, **kw):
if update_tools:
self._update_tools()
return await super().run(*args, **kw)
def run_sync(self, *args, update_tools: bool = True, **kw):
if update_tools:
self._update_tools()
return super().run_sync(*args, **kw)
async def run_async(self, *args, update_tools: bool = True, **kw):
if update_tools:
self._update_tools()
return await super().run_async(*args, **kw)
def run_stream(
self,
*args,
update_tools: bool = True,
**kw,
) -> AsyncIterator[Any]:
if update_tools:
self._update_tools()
return super().run_stream(*args, **kw)
def iter(
self,
*args,
update_tools: bool = True,
**kw,
) -> AsyncIterator[Any]:
if update_tools:
self._update_tools()
return super().iter(*args, **kw)