Skip to content

Commit a7045d2

Browse files
committed
Add first_tools_rejected_callback
1 parent 87f4a46 commit a7045d2

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

erniebot-agent/src/erniebot_agent/agents/function_agent.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,17 @@
1616
import logging
1717
from collections import deque
1818
from dataclasses import dataclass, replace
19-
from typing import Deque, Final, Iterable, List, Optional, Sequence, Tuple, Union
19+
from typing import (
20+
Callable,
21+
Deque,
22+
Final,
23+
Iterable,
24+
List,
25+
Optional,
26+
Sequence,
27+
Tuple,
28+
Union,
29+
)
2030

2131
from erniebot_agent.agents.agent import Agent
2232
from erniebot_agent.agents.callback.callback_manager import CallbackManager
@@ -82,6 +92,7 @@ def __init__(
8292
plugins: Optional[List[str]] = None,
8393
max_steps: Optional[int] = None,
8494
first_tools: Optional[Sequence[BaseTool]] = [],
95+
first_tools_rejected_callback: Optional[Callable[[BaseTool, List[Message], AgentStep], None]] = None,
8596
) -> None:
8697
"""Initialize a function agent.
8798
@@ -133,6 +144,7 @@ def __init__(
133144
raise RuntimeError("The tool in `first_tools` must be in the tools list.")
134145
else:
135146
self._first_tools = []
147+
self._first_tools_rejected_callback = first_tools_rejected_callback
136148
self._snapshots: Deque[FunctionAgentRunSnapshot] = deque(maxlen=5)
137149
self._snapshot_for_curr_run: Optional[FunctionAgentRunSnapshot] = None
138150

@@ -170,6 +182,8 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
170182
else:
171183
# If tool choice not work, skip this round
172184
_logger.warning(f"Selected tool [{tool.tool_name}] not work")
185+
if self._first_tools_rejected_callback is not None:
186+
self._first_tools_rejected_callback(tool, new_messages, curr_step)
173187

174188
while num_steps_taken < self.max_steps:
175189
curr_step, new_messages = await self._step(chat_history)

0 commit comments

Comments
 (0)