Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions livekit-agents/livekit/agents/llm/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@


class MCPServer(ABC):
def __init__(self, *, client_session_timeout_seconds: float) -> None:
def __init__(
self, *, client_session_timeout_seconds: float, tools: list[str] | None = None
) -> None:
self._client: ClientSession | None = None
self._exit_stack: AsyncExitStack = AsyncExitStack()
self._read_timeout = client_session_timeout_seconds

self._cache_dirty = True
self._lk_tools: list[MCPTool] | None = None
self._tools: list[str] = tools or []

@property
def initialized(self) -> bool:
Expand Down Expand Up @@ -72,10 +75,28 @@ async def list_tools(self) -> list[MCPTool]:
if not self._cache_dirty and self._lk_tools is not None:
return self._lk_tools

tools = await self._client.list_tools()
client_tools = await self._client.list_tools()

# If a subset of tool names is configured, validate & filter
if self._tools:
requested = set(self._tools)
available_names = {tool.name for tool in client_tools.tools}

missing = requested - available_names
if missing:
raise ToolError(
f"Specified tool(s) do not exist in MCP Server: {', '.join(sorted(missing))}"
)

client_tools.tools = [
tool for tool in client_tools.tools if tool.name in requested
]

lk_tools = [
self._make_function_tool(tool.name, tool.description, tool.inputSchema, tool.meta)
for tool in tools.tools
self._make_function_tool(
tool.name, tool.description, tool.inputSchema, tool.meta
)
for tool in client_tools.tools
]

self._lk_tools = lk_tools
Expand Down Expand Up @@ -144,7 +165,8 @@ def client_streams(
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback,
]
]: ...
]:
...


class MCPServerHTTP(MCPServer):
Expand All @@ -162,12 +184,15 @@ class MCPServerHTTP(MCPServer):
def __init__(
self,
url: str,
tools: list[str] | None = None,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
client_session_timeout_seconds: float = 5,
) -> None:
super().__init__(client_session_timeout_seconds=client_session_timeout_seconds)
super().__init__(
client_session_timeout_seconds=client_session_timeout_seconds, tools=tools
)
self.url = url
self.headers = headers
self._timeout = timeout
Expand Down Expand Up @@ -223,11 +248,14 @@ def __init__(
self,
command: str,
args: list[str],
tools: list[str] | None = None,
env: dict[str, str] | None = None,
cwd: str | Path | None = None,
client_session_timeout_seconds: float = 5,
) -> None:
super().__init__(client_session_timeout_seconds=client_session_timeout_seconds)
super().__init__(
client_session_timeout_seconds=client_session_timeout_seconds, tools=tools
)
self.command = command
self.args = args
self.env = env
Expand All @@ -242,8 +270,12 @@ def client_streams(
]
]:
return stdio_client( # type: ignore[no-any-return]
StdioServerParameters(command=self.command, args=self.args, env=self.env, cwd=self.cwd)
StdioServerParameters(
command=self.command, args=self.args, env=self.env, cwd=self.cwd
)
)

def __repr__(self) -> str:
return f"MCPServerStdio(command={self.command}, args={self.args}, cwd={self.cwd})"
return (
f"MCPServerStdio(command={self.command}, args={self.args}, cwd={self.cwd})"
)
Loading