Skip to content

Commit 47e3889

Browse files
nitpicker55555fengju0213Wendong-Fan
authored
fix: message integration (#3269)
Co-authored-by: Tao Sun <[email protected]> Co-authored-by: Sun Tao <[email protected]> Co-authored-by: Wendong-Fan <[email protected]>
1 parent 5e34848 commit 47e3889

File tree

3 files changed

+46
-33
lines changed

3 files changed

+46
-33
lines changed

camel/agents/chat_agent.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4519,23 +4519,29 @@ def _clone_tools(
45194519
# Toolkit doesn't support cloning, use original
45204520
cloned_toolkits[toolkit_id] = toolkit_instance
45214521

4522-
if getattr(
4523-
tool.func, "__message_integration_enhanced__", False
4524-
):
4525-
cloned_tools.append(
4526-
FunctionTool(
4527-
func=tool.func,
4528-
openai_tool_schema=tool.get_openai_tool_schema(),
4529-
)
4530-
)
4531-
continue
4532-
45334522
# Get the method from the cloned (or original) toolkit
45344523
toolkit = cloned_toolkits[toolkit_id]
45354524
method_name = tool.func.__name__
45364525

4526+
# Check if toolkit was actually cloned or just reused
4527+
toolkit_was_cloned = toolkit is not toolkit_instance
4528+
45374529
if hasattr(toolkit, method_name):
45384530
new_method = getattr(toolkit, method_name)
4531+
4532+
# If toolkit wasn't cloned (stateless), preserve the
4533+
# original function to maintain any enhancements/wrappers
4534+
if not toolkit_was_cloned:
4535+
# Toolkit is stateless, safe to reuse original function
4536+
cloned_tools.append(
4537+
FunctionTool(
4538+
func=tool.func,
4539+
openai_tool_schema=tool.get_openai_tool_schema(),
4540+
)
4541+
)
4542+
continue
4543+
4544+
# Toolkit was cloned, use the new method
45394545
# Wrap cloned method into a new FunctionTool,
45404546
# preserving schema
45414547
try:

camel/toolkits/message_integration.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,10 @@ def get_message_tool(self) -> FunctionTool:
148148
"""
149149
return FunctionTool(self.send_message_to_user)
150150

151-
def register_toolkits(
152-
self, toolkit: BaseToolkit, tool_names: Optional[List[str]] = None
153-
) -> BaseToolkit:
154-
r"""Add messaging capabilities to toolkit methods.
151+
def register_toolkits(self, toolkit: BaseToolkit) -> BaseToolkit:
152+
r"""Add messaging capabilities to all toolkit methods.
155153
156-
This method modifies a toolkit so that specified tools can send
154+
This method modifies a toolkit so that all its tools can send
157155
status messages to users while executing their primary function.
158156
The tools will accept optional messaging parameters:
159157
- message_title: Title of the status message
@@ -162,20 +160,18 @@ def register_toolkits(
162160
163161
Args:
164162
toolkit: The toolkit to add messaging capabilities to
165-
tool_names: List of specific tool names to modify.
166-
If None, messaging is added to all tools.
167163
168164
Returns:
169-
The toolkit with messaging capabilities added
165+
The same toolkit instance with messaging capabilities added to
166+
all methods.
170167
"""
171168
original_tools = toolkit.get_tools()
172169
enhanced_methods = {}
173170
for tool in original_tools:
174171
method_name = tool.func.__name__
175-
if tool_names is None or method_name in tool_names:
176-
enhanced_func = self._add_messaging_to_tool(tool.func)
177-
enhanced_methods[method_name] = enhanced_func
178-
setattr(toolkit, method_name, enhanced_func)
172+
enhanced_func = self._add_messaging_to_tool(tool.func)
173+
enhanced_methods[method_name] = enhanced_func
174+
setattr(toolkit, method_name, enhanced_func)
179175
original_get_tools_method = toolkit.get_tools
180176

181177
def enhanced_get_tools() -> List[FunctionTool]:
@@ -201,7 +197,7 @@ def enhanced_get_tools() -> List[FunctionTool]:
201197
def enhanced_clone_for_new_session(new_session_id=None):
202198
cloned_toolkit = original_clone_method(new_session_id)
203199
return message_integration_instance.register_toolkits(
204-
cloned_toolkit, tool_names
200+
cloned_toolkit
205201
)
206202

207203
toolkit.clone_for_new_session = enhanced_clone_for_new_session
@@ -300,6 +296,12 @@ def _add_messaging_to_tool(self, func: Callable) -> Callable:
300296
This internal method modifies the function signature and docstring
301297
to include optional messaging parameters that trigger status updates.
302298
"""
299+
if getattr(func, "__message_integration_enhanced__", False):
300+
logger.debug(
301+
f"Function {func.__name__} already enhanced, skipping"
302+
)
303+
return func
304+
303305
# Get the original signature
304306
original_sig = inspect.signature(func)
305307

test/toolkits/test_message_integration.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,16 @@ def test_register_all_tools(self):
9090
self.assertIn('message_attachment', params)
9191

9292
def test_register_specific_tools(self):
93-
r"""Test adding messaging to specific tools only."""
94-
enhanced_toolkit = self.message_integration.register_toolkits(
95-
self.toolkit, tool_names=['search_web']
93+
r"""Test adding messaging to specific tools only using
94+
register_functions."""
95+
# Use register_functions to enhance only one method
96+
enhanced_tools = self.message_integration.register_functions(
97+
[self.toolkit.search_web]
9698
)
9799

98-
tools = enhanced_toolkit.get_tools()
99-
search_tool = next(t for t in tools if t.func.__name__ == 'search_web')
100-
analyze_tool = next(
101-
t for t in tools if t.func.__name__ == 'analyze_data'
102-
)
100+
# Should get one enhanced tool
101+
self.assertEqual(len(enhanced_tools), 1)
102+
search_tool = enhanced_tools[0]
103103

104104
# Check search_web has message parameters
105105
search_schema = search_tool.get_openai_tool_schema()
@@ -108,7 +108,12 @@ def test_register_specific_tools(self):
108108
search_schema['function']['parameters']['properties'],
109109
)
110110

111-
# Check analyze_data doesn't have message parameters
111+
# Check the original toolkit's analyze_data doesn't have message
112+
# parameters
113+
original_tools = self.toolkit.get_tools()
114+
analyze_tool = next(
115+
t for t in original_tools if t.func.__name__ == 'analyze_data'
116+
)
112117
analyze_schema = analyze_tool.get_openai_tool_schema()
113118
self.assertNotIn(
114119
'message_title',

0 commit comments

Comments
 (0)