Skip to content

Commit a301125

Browse files
committed
Add typed Tool/ToolError classes per Pankit's feedback
- Add Tool dataclass with name, description, input_schema, output_schema - Add ToolError and ToolErrorType for structured error handling - Update ListToolsObservation to use List[Tool] instead of dicts - Add MCP deps (mcp, fastmcp) to src/pyproject.toml - Add pytest config to root pyproject.toml - Remove test __init__.py files that shadowed src/core
1 parent 7a74252 commit a301125

File tree

8 files changed

+73
-63
lines changed

8 files changed

+73
-63
lines changed

src/core/env_server/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
CallToolObservation,
1616
ListToolsAction,
1717
ListToolsObservation,
18+
Tool,
19+
ToolError,
20+
ToolErrorType,
1821
)
1922
from .types import Action, Observation, State
2023
from .web_interface import create_web_interface_app, WebInterfaceManager
@@ -33,6 +36,9 @@
3336
"Observation",
3437
"ListToolsObservation",
3538
"CallToolObservation",
39+
"Tool",
40+
"ToolError",
41+
"ToolErrorType",
3642
"State",
3743
# Base transforms
3844
"CompositeTransform",

src/core/env_server/mcp_environment.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,13 @@ async def _handle_mcp_action(self, action: Action) -> Observation:
126126
Raises:
127127
ValueError: If MCP client not configured or action type invalid
128128
"""
129-
from .mcp_types import CallToolObservation, ListToolsObservation
129+
from .mcp_types import (
130+
CallToolObservation,
131+
ListToolsObservation,
132+
Tool,
133+
ToolError,
134+
ToolErrorType,
135+
)
130136

131137
if self.mcp_client is None:
132138
raise ValueError("MCP client not configured for this environment")
@@ -137,11 +143,11 @@ async def _handle_mcp_action(self, action: Action) -> Observation:
137143
return ListToolsObservation(
138144
done=False,
139145
tools=[
140-
{
141-
"name": tool.name,
142-
"description": tool.description,
143-
"inputSchema": tool.inputSchema,
144-
}
146+
Tool(
147+
name=tool.name,
148+
description=tool.description or "",
149+
input_schema=tool.inputSchema or {},
150+
)
145151
for tool in tools
146152
],
147153
)
@@ -151,14 +157,18 @@ async def _handle_mcp_action(self, action: Action) -> Observation:
151157
result = await self.mcp_client.call_tool(
152158
action.tool_name, action.parameters
153159
)
154-
# Extract data from CallToolResult (FastMCP wraps results)
155160
result_data = result.data if hasattr(result, "data") else result
156161
return CallToolObservation(
157162
done=False, result=result_data, tool_name=action.tool_name
158163
)
159164
except Exception as e:
160165
return CallToolObservation(
161-
done=False, error=str(e), tool_name=action.tool_name
166+
done=False,
167+
tool_name=action.tool_name,
168+
error=ToolError(
169+
error_type=ToolErrorType.EXECUTION_ERROR,
170+
message=str(e),
171+
),
162172
)
163173

164174
else:

src/core/env_server/mcp_types.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,49 @@
1212
"""
1313

1414
from dataclasses import dataclass, field
15+
from enum import Enum
1516
from typing import Any, Dict, List, Optional
1617

1718
from .types import Action, Observation
1819

1920

21+
class ToolErrorType(Enum):
22+
"""Types of errors that can occur during tool execution."""
23+
24+
INVALID_ARGUMENTS = "invalid_arguments"
25+
TOOL_NOT_FOUND = "tool_not_found"
26+
EXECUTION_ERROR = "execution_error"
27+
TIMEOUT = "timeout"
28+
29+
30+
@dataclass
31+
class ToolError:
32+
"""
33+
Structured error for tool call failures.
34+
35+
Used for transport/validation errors. Tool execution errors that are
36+
part of normal operation should be returned in the result field.
37+
"""
38+
39+
error_type: ToolErrorType
40+
message: str
41+
details: Optional[Dict[str, Any]] = None
42+
43+
44+
@dataclass
45+
class Tool:
46+
"""
47+
Strongly typed representation of an MCP tool.
48+
49+
Follows the MCP specification for tool definitions.
50+
"""
51+
52+
name: str
53+
description: str
54+
input_schema: Dict[str, Any]
55+
output_schema: Optional[Dict[str, Any]] = None
56+
57+
2058
@dataclass(kw_only=True)
2159
class ListToolsAction(Action):
2260
"""
@@ -49,17 +87,19 @@ class ListToolsObservation(Observation):
4987
Contains the list of available tools with their schemas.
5088
"""
5189

52-
tools: List[Dict[str, Any]] = field(default_factory=list)
90+
tools: List[Tool] = field(default_factory=list)
5391

5492

5593
@dataclass(kw_only=True)
5694
class CallToolObservation(Observation):
5795
"""
5896
Observation returned from CallToolAction.
5997
60-
Contains the result of calling a tool, or an error if the call failed.
98+
Contains the result of calling a tool. The error field is for
99+
transport/validation errors only - tool execution errors should
100+
be part of the result.
61101
"""
62102

63-
result: Optional[Any] = None
64-
error: Optional[str] = None
65-
tool_name: Optional[str] = None
103+
tool_name: str
104+
result: Any = None
105+
error: Optional[ToolError] = None

src/core/pyproject.toml

Lines changed: 0 additions & 49 deletions
This file was deleted.

src/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ dependencies = [
2525
"rich>=13.0.0",
2626
"tomli>=2.0.1",
2727
"tomli-w>=1.0.0",
28+
"mcp>=1.0.0",
29+
"fastmcp>=0.1.0",
2830
]
2931

3032
[project.optional-dependencies]
3133
dev = [
3234
"pytest>=7.0.0",
35+
"pytest-asyncio>=0.21.0",
3336
"black>=23.0.0",
3437
"ruff>=0.1.0",
3538
"mypy>=1.0.0",

tests/core/__init__.py

Whitespace-only changes.

tests/core/mcp/__init__.py

Whitespace-only changes.

tests/core/mcp/test_mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def test_echo_env_mcp_integration():
5858
assert not obs.done
5959
assert hasattr(obs, "tools")
6060
assert len(obs.tools) == 1
61-
assert obs.tools[0]["name"] == "echo_message"
61+
assert obs.tools[0].name == "echo_message"
6262

6363
# Test CallToolAction
6464
call_action = CallToolAction(

0 commit comments

Comments
 (0)