|
16 | 16 | import logging
|
17 | 17 | from collections import deque
|
18 | 18 | from dataclasses import dataclass, replace
|
19 |
| -from typing import Deque, Final, Iterable, List, Optional, Sequence, Tuple, Union |
| 19 | +from typing import ( |
| 20 | + Callable, |
| 21 | + Deque, |
| 22 | + Final, |
| 23 | + Iterable, |
| 24 | + List, |
| 25 | + Optional, |
| 26 | + Sequence, |
| 27 | + Tuple, |
| 28 | + Union, |
| 29 | +) |
20 | 30 |
|
21 | 31 | from erniebot_agent.agents.agent import Agent
|
22 | 32 | from erniebot_agent.agents.callback.callback_manager import CallbackManager
|
@@ -82,6 +92,7 @@ def __init__(
|
82 | 92 | plugins: Optional[List[str]] = None,
|
83 | 93 | max_steps: Optional[int] = None,
|
84 | 94 | first_tools: Optional[Sequence[BaseTool]] = [],
|
| 95 | + first_tools_rejected_callback: Optional[Callable[[BaseTool, List[Message], AgentStep], None]] = None, |
85 | 96 | ) -> None:
|
86 | 97 | """Initialize a function agent.
|
87 | 98 |
|
@@ -133,6 +144,7 @@ def __init__(
|
133 | 144 | raise RuntimeError("The tool in `first_tools` must be in the tools list.")
|
134 | 145 | else:
|
135 | 146 | self._first_tools = []
|
| 147 | + self._first_tools_rejected_callback = first_tools_rejected_callback |
136 | 148 | self._snapshots: Deque[FunctionAgentRunSnapshot] = deque(maxlen=5)
|
137 | 149 | self._snapshot_for_curr_run: Optional[FunctionAgentRunSnapshot] = None
|
138 | 150 |
|
@@ -170,6 +182,8 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
|
170 | 182 | else:
|
171 | 183 | # If tool choice not work, skip this round
|
172 | 184 | _logger.warning(f"Selected tool [{tool.tool_name}] not work")
|
| 185 | + if self._first_tools_rejected_callback is not None: |
| 186 | + self._first_tools_rejected_callback(tool, new_messages, curr_step) |
173 | 187 |
|
174 | 188 | while num_steps_taken < self.max_steps:
|
175 | 189 | curr_step, new_messages = await self._step(chat_history)
|
|
0 commit comments