# limitations under the License.

import logging
-from typing import Final, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+    AsyncIterator,
+    Final,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import File, FileManager
from erniebot_agent.memory import Memory
-from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message
+from erniebot_agent.memory.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    Message,
+)
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager
@@ -136,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
        chat_history.append(run_input)

        for tool in self._first_tools:
-            curr_step, new_messages = await self._step(chat_history, selected_tool=tool)
+            curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
            if not isinstance(curr_step, EndStep):
                chat_history.extend(new_messages)
                num_steps_taken += 1
@@ -167,23 +181,122 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
            response = self._create_stopped_response(chat_history, steps_taken)
        return response

-    async def _step(
+    async def _call_first_tools(
        self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
    ) -> Tuple[AgentStep, List[Message]]:
-        new_messages: List[Message] = []
        input_messages = self.memory.get_messages() + chat_history
-        if selected_tool is not None:
-            tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
-            llm_resp = await self.run_llm(
-                messages=input_messages,
-                functions=[selected_tool.function_call_schema()],  # only regist one tool
-                tool_choice=tool_choice,
-            )
-        else:
+        if selected_tool is None:
            llm_resp = await self.run_llm(messages=input_messages)
+            return await self._process_step(llm_resp, chat_history)
+
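+        # Ask the LLM to call the selected tool by passing an explicit
+        # `tool_choice`; registering only this tool's schema keeps the model
+        # from picking a different function on the first round.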
+        tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
+        llm_resp = await self.run_llm(
+            messages=input_messages,
+            functions=[selected_tool.function_call_schema()],  # register only the selected tool
+            tool_choice=tool_choice,
+        )
+        return await self._process_step(llm_resp, chat_history)
+
+    async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
+        """Run a step of the agent.
+        Args:
+            chat_history: The chat history to provide to the agent.
+        Returns:
+            A tuple of an agent step and a list of new messages.
+        """
+        input_messages = self.memory.get_messages() + chat_history
+        llm_resp = await self.run_llm(messages=input_messages)
+        return await self._process_step(llm_resp, chat_history)

+    async def _step_stream(
+        self, chat_history: List[Message]
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+        """Run a step of the agent in streaming mode.
+        Args:
+            chat_history: The chat history to provide to the agent.
+        Returns:
+            An async iterator that yields tuples of an agent step and a list
+            of new messages.
+        """
+        input_messages = self.memory.get_messages() + chat_history
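+        # Each chunk of the streaming LLM response is processed into its own
+        # (step, new_messages) pair and yielded as soon as it arrives.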
+        async for llm_resp in self.run_llm_stream(messages=input_messages):
+            yield await self._process_step(llm_resp, chat_history)
+
+    async def _run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+        """Run the agent with the given prompt and files in streaming mode.
+        Args:
+            prompt: The prompt for the agent to run.
+            files: A list of files for the agent to use. If `None`, use an empty
+                list.
+        Returns:
+            An async iterator that yields tuples of an agent step and a list
+            of new messages, one step at a time.
+        """
+        chat_history: List[Message] = []
+        steps_taken: List[AgentStep] = []
+
+        run_input = await HumanMessage.create_with_files(
+            prompt, files or [], include_file_urls=self.file_needs_url
+        )
+
+        num_steps_taken = 0
+        chat_history.append(run_input)
+
+        for tool in self._first_tools:
+            curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
+            if not isinstance(curr_step, EndStep):
+                chat_history.extend(new_messages)
+                num_steps_taken += 1
+                steps_taken.append(curr_step)
+            else:
+                # If the tool choice did not take effect, skip this round.
+                _logger.warning(f"Selected tool [{tool.tool_name}] did not work")
+
+        is_finished = False
+        new_messages = []
+        end_step_msgs = []
+        while not is_finished:
+            # IMPORTANT! The shape of the streaming response depends on the finish
+            # reason: when finish_reason is function_call, run_llm_stream returns
+            # all of the info in one step, but for a normal chat completion it
+            # returns the info across multiple steps.
+            async for curr_step, new_messages in self._step_stream(chat_history):
+                if isinstance(curr_step, ToolStep):
+                    steps_taken.append(curr_step)
+                    yield curr_step, new_messages
+
+                elif isinstance(curr_step, PluginStep):
+                    steps_taken.append(curr_step)
+                    # Reserved: hook for plugins that keep the run going after
+                    # being called.
+
+                    # For now, the run ends immediately after a plugin call.
+                    curr_step = DEFAULT_FINISH_STEP
+                    yield curr_step, new_messages
+
+                elif isinstance(curr_step, EndStep):
+                    is_finished = True
+                    end_step_msgs.extend(new_messages)
+                    yield curr_step, new_messages
+                else:
+                    raise RuntimeError("Invalid step type")
+            chat_history.extend(new_messages)
+
+
+        self.memory.add_message(run_input)
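+        # Join the streamed end-step chunks into one AIMessage so that memory
+        # stores a single complete assistant reply for this run.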
+        end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
+        self.memory.add_message(end_step_msg)
+
+    async def _process_step(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
+        """Process and execute a step of the agent from the LLM response.
+        Args:
+            llm_resp: The LLM response to convert.
+            chat_history: The chat history to provide to the agent.
+        Returns:
+            A tuple of an agent step and a list of new messages.
+        """
+        new_messages: List[Message] = []
        output_message = llm_resp.message  # AIMessage
        new_messages.append(output_message)
+        # handle function call
        if output_message.function_call is not None:
            tool_name = output_message.function_call["name"]
            tool_args = output_message.function_call["arguments"]
@@ -198,6 +311,7 @@ async def _step(
                ),
                new_messages,
            )
+        # handle plugin info with input/output files
        elif output_message.plugin_info is not None:
            file_manager = self.get_file_manager()
            return (
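
For context, here is a minimal sketch of how the new streaming path might be driven from user code. It assumes a public `run_stream` wrapper around `_run_stream` on the concrete agent class, and the construction details (`FunctionAgent`, `ERNIEBot(model=...)`, the empty tool list) are illustrative rather than part of this diff:

```python
import asyncio

from erniebot_agent.agents import FunctionAgent  # assumed concrete Agent subclass
from erniebot_agent.chat_models import ERNIEBot  # assumed BaseERNIEBot implementation


async def main() -> None:
    # Construction details are illustrative; any BaseERNIEBot-compatible chat
    # model and tool list should work here.
    llm = ERNIEBot(model="ernie-3.5")
    agent = FunctionAgent(llm=llm, tools=[])

    # Each iteration yields an (AgentStep, List[Message]) pair as soon as the
    # model produces it, instead of a single final response object.
    async for step, messages in agent.run_stream("Hello!"):
        print(type(step).__name__, [msg.content for msg in messages])


asyncio.run(main())
```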