Commit 3f2ecc0

[New Feature] Support the stream agent returning steps and messages in streaming mode (#345)
* stream agent: return steps and messages in streaming mode
* Tidy the code with make format-check
* Check the code with make format-check and format it with make format
* Check and clean up the code with `python -m mypy src`
* Check and clean up the code with `make lint`
* Address review comments
* Check formatting with `python -m black --check`
* Use typing.TYPE_CHECKING to satisfy the static type checker
1 parent 621ce45 commit 3f2ecc0

3 files changed: +351 -15 lines


erniebot-agent/src/erniebot_agent/agents/agent.py

+126 -2
@@ -2,7 +2,9 @@
 import json
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Dict,
     Final,
     Iterable,
@@ -20,7 +22,13 @@
 from erniebot_agent.agents.callback.default import get_default_callbacks
 from erniebot_agent.agents.callback.handlers.base import CallbackHandler
 from erniebot_agent.agents.mixins import GradioMixin
-from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
+from erniebot_agent.agents.schema import (
+    DEFAULT_FINISH_STEP,
+    AgentResponse,
+    AgentStep,
+    LLMResponse,
+    ToolResponse,
+)
 from erniebot_agent.chat_models.erniebot import BaseERNIEBot
 from erniebot_agent.file import (
     File,
@@ -131,13 +139,46 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
         await self._callback_manager.on_run_end(agent=self, response=agent_resp)
         return agent_resp

+    @final
+    async def run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+        """Run the agent asynchronously, returning an async iterator of responses.
+
+        Args:
+            prompt: A natural language text describing the task that the agent
+                should perform.
+            files: A list of files that the agent can use to perform the task.
+        Returns:
+            Iterator of responses from the agent.
+        """
+        if files:
+            await self._ensure_managed_files(files)
+        await self._callback_manager.on_run_start(agent=self, prompt=prompt)
+        try:
+            async for step, msg in self._run_stream(prompt, files):
+                yield (step, msg)
+        except BaseException as e:
+            await self._callback_manager.on_run_error(agent=self, error=e)
+            raise e
+        else:
+            await self._callback_manager.on_run_end(
+                agent=self,
+                response=AgentResponse(
+                    text="Agent run stopped.",
+                    chat_history=self.memory.get_messages(),
+                    steps=[step],
+                    status="STOPPED",
+                ),
+            )
+
     @final
     async def run_llm(
         self,
         messages: List[Message],
         **llm_opts: Any,
     ) -> LLMResponse:
-        """Run the LLM asynchronously.
+        """Run the LLM asynchronously, returning final response.

         Args:
             messages: The input messages.
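As an editorial aside (not part of the diff): a minimal consumer-side sketch of the new `run_stream` API. It assumes an already constructed `Agent` subclass such as the `FunctionAgent` changed below, and that each yielded `Message` exposes a `content` attribute, as elsewhere in this commit.

import asyncio

from erniebot_agent.agents.agent import Agent


async def show_stream(agent: Agent, prompt: str) -> None:
    # Each iteration yields the step the agent just took together with the new
    # messages produced in that step; callbacks fire around the run as in `run`.
    async for step, messages in agent.run_stream(prompt):
        print(type(step).__name__, [msg.content for msg in messages])


# Assuming `my_agent` is a constructed agent instance:
# asyncio.run(show_stream(my_agent, "What is the weather like today?"))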
@@ -156,6 +197,34 @@ async def run_llm(
         await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
         return llm_resp

+    @final
+    async def run_llm_stream(
+        self,
+        messages: List[Message],
+        **llm_opts: Any,
+    ) -> AsyncIterator[LLMResponse]:
+        """Run the LLM asynchronously, returning an async iterator of responses
+
+        Args:
+            messages: The input messages.
+            llm_opts: Options to pass to the LLM.
+
+        Returns:
+            Iterator of responses from the LLM.
+        """
+        llm_resp = None
+        await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
+        try:
+            # The LLM will return an async iterator.
+            async for llm_resp in self._run_llm_stream(messages, **(llm_opts or {})):
+                yield llm_resp
+        except (Exception, KeyboardInterrupt) as e:
+            await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e)
+            raise e
+        else:
+            await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
+            return
+
     @final
     async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
         """Run the specified tool asynchronously.
@@ -221,7 +290,32 @@ def get_file_manager(self) -> FileManager:
     async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
         raise NotImplementedError

+    @abc.abstractmethod
+    async def _run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+        """
+        Abstract asynchronous generator method that should be implemented by subclasses.
+        This method should yield a sequence of (AgentStep, List[Message]) tuples based on the given
+        prompt and optionally accompanying files.
+        """
+        if TYPE_CHECKING:
+            # HACK
+            # This conditional block is strictly for static type-checking purposes (e.g., mypy)
+            # and will not be executed.
+            only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, [])
+            yield only_for_mypy_type_check
+
     async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
+        """Run the LLM with the given messages and options.
+
+        Args:
+            messages: The input messages.
+            opts: Options to pass to the LLM.
+
+        Returns:
+            Response from the LLM.
+        """
         for reserved_opt in ("stream", "system", "plugins"):
             if reserved_opt in opts:
                 raise TypeError(f"`{reserved_opt}` should not be set.")
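A note on the `if TYPE_CHECKING:` block above: the mere presence of a `yield` makes `_run_stream` an async generator function, both for mypy and at runtime, so call sites such as `async for step, msg in self._run_stream(...)` type-check without an extra `await`, while the guarded block itself never executes. A minimal standalone sketch of the same pattern (names are illustrative, not from this repository):

from typing import TYPE_CHECKING, AsyncIterator


class Base:
    async def numbers(self) -> AsyncIterator[int]:
        """Subclasses should yield ints; the base body is intentionally empty."""
        if TYPE_CHECKING:
            # Never executed: the yield only makes this an async generator for the
            # type checker (and, harmlessly, at runtime).
            yield 0


class Counter(Base):
    async def numbers(self) -> AsyncIterator[int]:
        for i in range(3):
            yield i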
@@ -241,6 +335,36 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
         llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts)
         return LLMResponse(message=llm_ret)

+    async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
+        """Run the LLM, yielding an async iterator of responses.
+
+        Args:
+            messages: The input messages.
+            opts: Options to pass to the LLM.
+
+        Returns:
+            Async iterator of responses from the LLM.
+        """
+        for reserved_opt in ("stream", "system", "plugins"):
+            if reserved_opt in opts:
+                raise TypeError(f"`{reserved_opt}` should not be set.")
+
+        if "functions" not in opts:
+            functions = self._tool_manager.get_tool_schemas()
+        else:
+            functions = opts.pop("functions")
+
+        if hasattr(self.llm, "system"):
+            _logger.warning(
+                "The `system` message has already been set in the agent;"
+                "the `system` message configured in ERNIEBot will become ineffective."
+            )
+        opts["system"] = self.system.content if self.system is not None else None
+        opts["plugins"] = self._plugins
+        llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts)
+        async for msg in llm_ret:
+            yield LLMResponse(message=msg)
+
     async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
         parsed_tool_args = self._parse_tool_args(tool_args)
         file_manager = self.get_file_manager()

erniebot-agent/src/erniebot_agent/agents/function_agent.py

+127 -13
@@ -13,7 +13,16 @@
 # limitations under the License.

 import logging
-from typing import Final, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+    AsyncIterator,
+    Final,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)

 from erniebot_agent.agents.agent import Agent
 from erniebot_agent.agents.callback.callback_manager import CallbackManager
@@ -31,7 +40,12 @@
 from erniebot_agent.chat_models.erniebot import BaseERNIEBot
 from erniebot_agent.file import File, FileManager
 from erniebot_agent.memory import Memory
-from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message
+from erniebot_agent.memory.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    Message,
+)
 from erniebot_agent.tools.base import BaseTool
 from erniebot_agent.tools.tool_manager import ToolManager

@@ -136,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
         chat_history.append(run_input)

         for tool in self._first_tools:
-            curr_step, new_messages = await self._step(chat_history, selected_tool=tool)
+            curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
             if not isinstance(curr_step, EndStep):
                 chat_history.extend(new_messages)
                 num_steps_taken += 1
@@ -167,23 +181,122 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
             response = self._create_stopped_response(chat_history, steps_taken)
             return response

-    async def _step(
+    async def _call_first_tools(
         self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
     ) -> Tuple[AgentStep, List[Message]]:
-        new_messages: List[Message] = []
         input_messages = self.memory.get_messages() + chat_history
-        if selected_tool is not None:
-            tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
-            llm_resp = await self.run_llm(
-                messages=input_messages,
-                functions=[selected_tool.function_call_schema()],  # only regist one tool
-                tool_choice=tool_choice,
-            )
-        else:
+        if selected_tool is None:
             llm_resp = await self.run_llm(messages=input_messages)
+            return await self._process_step(llm_resp, chat_history)
+
+        tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
+        llm_resp = await self.run_llm(
+            messages=input_messages,
+            functions=[selected_tool.function_call_schema()],  # only regist one tool
+            tool_choice=tool_choice,
+        )
+        return await self._process_step(llm_resp, chat_history)
+
+    async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
+        """Run a step of the agent.
+        Args:
+            chat_history: The chat history to provide to the agent.
+        Returns:
+            A tuple of an agent step and a list of new messages.
+        """
+        input_messages = self.memory.get_messages() + chat_history
+        llm_resp = await self.run_llm(messages=input_messages)
+        return await self._process_step(llm_resp, chat_history)

+    async def _step_stream(
+        self, chat_history: List[Message]
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+        """Run a step of the agent in streaming mode.
+        Args:
+            chat_history: The chat history to provide to the agent.
+        Returns:
+            An async iterator that yields a tuple of an agent step and a list of new messages.
+        """
+        input_messages = self.memory.get_messages() + chat_history
+        async for llm_resp in self.run_llm_stream(messages=input_messages):
+            yield await self._process_step(llm_resp, chat_history)
+
+    async def _run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+        """Run the agent with the given prompt and files in streaming mode.
+        Args:
+            prompt: The prompt for the agent to run.
+            files: A list of files for the agent to use. If `None`, use an empty
+                list.
+        Returns:
+            If `stream` is `False`, an agent response object. If `stream` is
+            `True`, an async iterator that yields agent steps one by one.
+        """
+        chat_history: List[Message] = []
+        steps_taken: List[AgentStep] = []
+
+        run_input = await HumanMessage.create_with_files(
+            prompt, files or [], include_file_urls=self.file_needs_url
+        )
+
+        num_steps_taken = 0
+        chat_history.append(run_input)
+
+        for tool in self._first_tools:
+            curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
+            if not isinstance(curr_step, EndStep):
+                chat_history.extend(new_messages)
+                num_steps_taken += 1
+                steps_taken.append(curr_step)
+            else:
+                # If the tool choice did not work, skip this round
+                _logger.warning(f"Selected tool [{tool.tool_name}] not work")
+
+        is_finished = False
+        new_messages = []
+        end_step_msgs = []
+        while is_finished is False:
+            # IMPORTANT! We use the following code to get the response from the LLM.
+            # When finish_reason is function_call, run_llm_stream returns all info in one step, but
+            # when finish_reason is normal chat, run_llm_stream returns info in multiple steps.
+            async for curr_step, new_messages in self._step_stream(chat_history):
+                if isinstance(curr_step, ToolStep):
+                    steps_taken.append(curr_step)
+                    yield curr_step, new_messages
+
+                elif isinstance(curr_step, PluginStep):
+                    steps_taken.append(curr_step)
+                    # Reserved for a hook that keeps running after a plugin call
+
+                    # Here the plugin ends the run right after being called
+                    curr_step = DEFAULT_FINISH_STEP
+                    yield curr_step, new_messages
+
+                elif isinstance(curr_step, EndStep):
+                    is_finished = True
+                    end_step_msgs.extend(new_messages)
+                    yield curr_step, new_messages
+                else:
+                    raise RuntimeError("Invalid step type")
+            chat_history.extend(new_messages)
+
+        self.memory.add_message(run_input)
+        end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
+        self.memory.add_message(end_step_msg)
+
+    async def _process_step(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
+        """Process and execute a step of the agent from the LLM response.
+        Args:
+            llm_resp: The LLM response to convert.
+            chat_history: The chat history to provide to the agent.
+        Returns:
+            A tuple of an agent step and a list of new messages.
+        """
+        new_messages: List[Message] = []
         output_message = llm_resp.message  # AIMessage
         new_messages.append(output_message)
+        # handle function call
         if output_message.function_call is not None:
             tool_name = output_message.function_call["name"]
             tool_args = output_message.function_call["arguments"]
@@ -198,6 +311,7 @@ async def _step(
                 ),
                 new_messages,
             )
+        # handle plugin info with input/output files
         elif output_message.plugin_info is not None:
             file_manager = self.get_file_manager()
             return (
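To make the `IMPORTANT` comment in `_run_stream` above concrete, here is a hedged consumer-side sketch (not part of the commit): a function call arrives as a single `ToolStep` chunk, while a plain chat answer arrives as several `EndStep` chunks whose message contents concatenate into the final reply, mirroring how `_run_stream` builds `end_step_msg`. The import path for the step classes is assumed from their use in function_agent.py.

from erniebot_agent.agents.schema import EndStep, ToolStep


async def render_stream(agent, prompt: str) -> str:
    reply_parts = []
    async for step, messages in agent.run_stream(prompt):
        if isinstance(step, ToolStep):
            # Tool calls are delivered in one chunk; surface them as intermediate events.
            print("tool step:", messages[-1].content)
        elif isinstance(step, EndStep):
            # Plain-chat chunks: collect the partial texts as they stream in.
            reply_parts.extend(msg.content for msg in messages)
    return "".join(reply_parts)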
