[New Feature] Support stream agent streaming return of steps and messages #345

Merged: 8 commits, May 7, 2024
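For context, the streaming API added here can be consumed roughly as follows. This is a minimal sketch, not code from the PR: it assumes a previously constructed agent (such as the FunctionAgent modified below) and that ToolStep and EndStep are importable from erniebot_agent.agents.schema, as they are used elsewhere in this diff.

from erniebot_agent.agents.schema import EndStep, ToolStep


async def demo(agent):
    # Each iteration yields one agent step plus the new messages it produced.
    async for step, messages in agent.run_stream("How is the weather in Beijing?"):
        if isinstance(step, ToolStep):
            print("tool step:", step)
        elif isinstance(step, EndStep):
            # Plain chat answers arrive as a stream of EndStep chunks.
            for msg in messages:
                print("chunk:", msg.content)

# asyncio.run(demo(my_agent))  # with `my_agent` constructed elsewhere (hypothetical)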
128 changes: 126 additions & 2 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -2,7 +2,9 @@
import json
import logging
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Final,
Iterable,
@@ -20,7 +22,13 @@
from erniebot_agent.agents.callback.default import get_default_callbacks
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.mixins import GradioMixin
from erniebot_agent.agents.schema import (
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
LLMResponse,
ToolResponse,
)
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import (
File,
@@ -131,13 +139,46 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
return agent_resp

@final
async def run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent asynchronously, returning an async iterator of responses.

Args:
prompt: A natural language text describing the task that the agent
should perform.
files: A list of files that the agent can use to perform the task.
Returns:
An async iterator of (agent step, new messages) tuples from the agent.
"""
if files:
await self._ensure_managed_files(files)
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
try:
async for step, msg in self._run_stream(prompt, files):
yield (step, msg)
except BaseException as e:
await self._callback_manager.on_run_error(agent=self, error=e)
raise e
else:
await self._callback_manager.on_run_end(
agent=self,
response=AgentResponse(
text="Agent run stopped.",
chat_history=self.memory.get_messages(),
steps=[step],
status="STOPPED",
),
)
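The try/except/else above mirrors the non-streaming `run`, firing the same run-level callbacks around the whole stream. A sketch of a handler observing those hooks follows; the method names are taken from the calls in this diff, while the exact base-class signatures are an assumption.

from erniebot_agent.agents.callback.handlers.base import CallbackHandler


class RunLoggingHandler(CallbackHandler):
    async def on_run_start(self, agent, prompt):
        print(f"run started: {prompt!r}")

    async def on_run_error(self, agent, error):
        print(f"run failed: {error!r}")

    async def on_run_end(self, agent, response):
        print(f"run ended with status {response.status}")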

@final
async def run_llm(
self,
messages: List[Message],
**llm_opts: Any,
) -> LLMResponse:
"""Run the LLM asynchronously.
"""Run the LLM asynchronously, returning final response.

Args:
messages: The input messages.
@@ -156,6 +197,34 @@ async def run_llm(
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
return llm_resp

@final
async def run_llm_stream(
self,
messages: List[Message],
**llm_opts: Any,
) -> AsyncIterator[LLMResponse]:
"""Run the LLM asynchronously, returning an async iterator of responses

Args:
messages: The input messages.
llm_opts: Options to pass to the LLM.

Returns:
An async iterator of responses from the LLM.
"""
llm_resp = None
await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
try:
# The LLM will return an async iterator.
async for llm_resp in self._run_llm_stream(messages, **(llm_opts or {})):
yield llm_resp
except (Exception, KeyboardInterrupt) as e:
await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e)
raise e
else:
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
return
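For example, a caller could assemble the full reply from the stream as below. This is a sketch: it assumes `HumanMessage(content=...)` as imported elsewhere in this PR, and that each chunk carries incremental content, as implied by the join in `_run_stream` further down.

from erniebot_agent.memory.messages import HumanMessage


async def stream_reply(agent, prompt: str) -> str:
    parts = []
    # Each yielded LLMResponse wraps one incremental chunk of the reply.
    async for llm_resp in agent.run_llm_stream([HumanMessage(content=prompt)]):
        parts.append(llm_resp.message.content)
    return "".join(parts)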

@final
async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
"""Run the specified tool asynchronously.
@@ -221,7 +290,32 @@ def get_file_manager(self) -> FileManager:
async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
raise NotImplementedError

@abc.abstractmethod
async def _run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""
Abstract asynchronous generator method that should be implemented by subclasses.
This method should yield a sequence of (AgentStep, List[Message]) tuples based on the given
prompt and optional accompanying files.
"""
if TYPE_CHECKING:
# HACK
# This conditional block is strictly for static type-checking purposes (e.g., mypy)
# and will not be executed.
only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, [])
yield only_for_mypy_type_check
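The unreachable `yield` matters twice over: any `yield` in a function body makes it a generator function (even under a branch that never runs), and it lets the type checker infer the async-iterator return type. The same trick in isolation, as a hypothetical minimal example:

import abc
from typing import TYPE_CHECKING, AsyncIterator


class Base(abc.ABC):
    @abc.abstractmethod
    async def items(self) -> AsyncIterator[int]:
        if TYPE_CHECKING:
            # Never executed; its mere presence makes this an async
            # generator function and satisfies the type checker.
            yield 0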

async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
"""Run the LLM with the given messages and options.

Args:
messages: The input messages.
opts: Options to pass to the LLM.

Returns:
Response from the LLM.
"""
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")
@@ -241,6 +335,36 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts)
return LLMResponse(message=llm_ret)

async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
"""Run the LLM, yielding an async iterator of responses.

Args:
messages: The input messages.
opts: Options to pass to the LLM.

Returns:
Async iterator of responses from the LLM.
"""
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")

if "functions" not in opts:
functions = self._tool_manager.get_tool_schemas()
else:
functions = opts.pop("functions")

if hasattr(self.llm, "system"):
_logger.warning(
"The `system` message has already been set in the agent;"
"the `system` message configured in ERNIEBot will become ineffective."
)
opts["system"] = self.system.content if self.system is not None else None
opts["plugins"] = self._plugins
llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts)
async for msg in llm_ret:
yield LLMResponse(message=msg)
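As with `_run_llm`, the reserved-option guard surfaces misuse early rather than silently overriding the caller. A hypothetical usage showing the failure mode:

async def demo_reserved_opt(agent, messages):
    try:
        async for _ in agent.run_llm_stream(messages, stream=True):
            pass
    except TypeError as err:
        print(err)  # -> `stream` should not be set.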

async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
parsed_tool_args = self._parse_tool_args(tool_args)
file_manager = self.get_file_manager()
140 changes: 127 additions & 13 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -13,7 +13,16 @@
# limitations under the License.

import logging
from typing import (
AsyncIterator,
Final,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.callback.callback_manager import CallbackManager
@@ -31,7 +40,12 @@
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import File, FileManager
from erniebot_agent.memory import Memory
from erniebot_agent.memory.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
Message,
)
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager

@@ -136,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
@@ -167,23 +181,122 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
response = self._create_stopped_response(chat_history, steps_taken)
return response

async def _call_first_tools(
self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
) -> Tuple[AgentStep, List[Message]]:
new_messages: List[Message] = []
input_messages = self.memory.get_messages() + chat_history
if selected_tool is None:
llm_resp = await self.run_llm(messages=input_messages)
return await self._process_step(llm_resp, chat_history)

tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
llm_resp = await self.run_llm(
messages=input_messages,
functions=[selected_tool.function_call_schema()],  # only register one tool
tool_choice=tool_choice,
)
return await self._process_step(llm_resp, chat_history)

async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
"""Run a step of the agent.
Args:
chat_history: The chat history to provide to the agent.
Returns:
A tuple of an agent step and a list of new messages.
"""
input_messages = self.memory.get_messages() + chat_history
llm_resp = await self.run_llm(messages=input_messages)
return await self._process_step(llm_resp, chat_history)

async def _step_stream(
self, chat_history: List[Message]
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run a step of the agent in streaming mode.
Args:
chat_history: The chat history to provide to the agent.
Returns:
An async iterator that yields tuples of an agent step and a list of new messages.
"""
input_messages = self.memory.get_messages() + chat_history
async for llm_resp in self.run_llm_stream(messages=input_messages):
yield await self._process_step(llm_resp, chat_history)

async def _run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent with the given prompt and files in streaming mode.
Args:
prompt: The prompt for the agent to run.
files: A list of files for the agent to use. If `None`, use an empty
list.
Returns:
An async iterator that yields (agent step, new messages) tuples one by one.
"""
chat_history: List[Message] = []
steps_taken: List[AgentStep] = []

run_input = await HumanMessage.create_with_files(
prompt, files or [], include_file_urls=self.file_needs_url
)

num_steps_taken = 0
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
steps_taken.append(curr_step)
else:
# If the tool choice did not take effect, skip this round
_logger.warning(f"Selected tool [{tool.tool_name}] did not take effect")

is_finished = False
new_messages = []
end_step_msgs = []
while not is_finished:
# IMPORTANT! The following code gets the response from the LLM.
# When finish_reason is function_call, run_llm_stream returns all info in one step;
# when finish_reason is normal chat, run_llm_stream returns info across multiple steps.
async for curr_step, new_messages in self._step_stream(chat_history):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)
yield curr_step, new_messages

elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
# Reserved: an interface for plugins that do not end the run after being called

# This branch handles plugins that end the run immediately after being called
curr_step = DEFAULT_FINISH_STEP
yield curr_step, new_messages

elif isinstance(curr_step, EndStep):
is_finished = True
end_step_msgs.extend(new_messages)
yield curr_step, new_messages
else:
raise RuntimeError("Invalid step type")
chat_history.extend(new_messages)

self.memory.add_message(run_input)
end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
self.memory.add_message(end_step_msg)
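Note that memory is written exactly once, after the stream is drained: the human input plus a single AIMessage concatenating every streamed chunk. A sketch of observing that behavior, assuming `memory.get_messages()` as used elsewhere in this diff:

async def run_and_inspect(agent, prompt: str):
    async for _step, _messages in agent.run_stream(prompt):
        pass  # drain the stream
    history = agent.memory.get_messages()
    # The final entry is one AIMessage holding the joined chunk contents.
    print(type(history[-1]).__name__, history[-1].content)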

async def _process_step(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
"""Process and execute a step of the agent from LLM response.
Args:
llm_resp: The LLM response to convert.
chat_history: The chat history to provide to the agent.
Returns:
A tuple of an agent step and a list of new messages.
"""
new_messages: List[Message] = []
output_message = llm_resp.message # AIMessage
new_messages.append(output_message)
# handle function call
if output_message.function_call is not None:
tool_name = output_message.function_call["name"]
tool_args = output_message.function_call["arguments"]
@@ -198,6 +311,7 @@
),
new_messages,
)
# handle plugin info with input/output files
elif output_message.plugin_info is not None:
file_manager = self.get_file_manager()
return (