[New Feature] Support streaming steps and messages from the stream agent #345

Merged: 8 commits, May 7, 2024
Check and optimize the code with python -m mypy src
xiabo0816 committed Apr 28, 2024
commit 72f97f6fc7dfcefc5285526b1721aa43c98d1daa
30 changes: 23 additions & 7 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -22,6 +22,7 @@
 from erniebot_agent.agents.callback.handlers.base import CallbackHandler
 from erniebot_agent.agents.mixins import GradioMixin
 from erniebot_agent.agents.schema import (
+    DEFAULT_FINISH_STEP,
     AgentResponse,
     AgentStep,
     LLMResponse,
@@ -140,7 +141,7 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
     @final
     async def run_stream(
         self, prompt: str, files: Optional[Sequence[File]] = None
-    ) -> AsyncIterator[AgentStep]:
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
         """Run the agent asynchronously.

         Args:
@@ -154,14 +155,21 @@ async def run_stream(
         await self._ensure_managed_files(files)
         await self._callback_manager.on_run_start(agent=self, prompt=prompt)
         try:
-            async for step in self._run_stream(prompt, files):
-                yield step
+            async for step, msg in self._run_stream(prompt, files):
+                yield (step, msg)
         except BaseException as e:
             await self._callback_manager.on_run_error(agent=self, error=e)
             raise e
         else:
-            await self._callback_manager.on_run_end(agent=self, response=None)
-            return
+            await self._callback_manager.on_run_end(
+                agent=self,
+                response=AgentResponse(
+                    text="Agent run stopped early.",
+                    chat_history=self.memory.get_messages(),
+                    steps=[step],
+                    status="STOPPED",
+                ),
+            )

     @final
     async def run_llm(
@@ -284,8 +292,16 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
     @abc.abstractmethod
     async def _run_stream(
         self, prompt: str, files: Optional[Sequence[File]] = None
-    ) -> AsyncIterator[AgentStep]:
-        raise NotImplementedError
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+        """
+        Abstract asynchronous generator method that should be implemented by subclasses.
+        This method should yield a sequence of (AgentStep, List[Message]) tuples based on
+        the given prompt and the optional accompanying files.
+        """
+        if False:
Member:

Why is this here?

Contributor Author:

Really sorry about this one; it is a concession to mypy:

A plain pass or raise here makes mypy report an error. The reason seems to be that when a function is annotated as returning an AsyncIterator but its body contains no yield, mypy infers the return value to be a Coroutine, which then conflicts with the return type of the subclass override (FunctionAgent._run_stream). Since I never managed to find a return type annotation for the abstract method @abc.abstractmethod async def _run_stream that passes the mypy check, I fell back to this ugly if False:. Sob.

Member (@Bobholamovic, Apr 30, 2024):

Oh, I see, fair enough. Then I think we can keep this for now (though perhaps if False could be replaced with if typing.TYPE_CHECKING, and a HACK marker could be added to the comment).
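To illustrate the reviewer's suggestion, here is a minimal, self-contained sketch of the typing.TYPE_CHECKING variant. AgentStep, Message, and DEFAULT_FINISH_STEP are stand-ins for the real types in erniebot_agent.agents.schema; only the typing pattern itself is the point.

import abc
import typing
from typing import AsyncIterator, List, Tuple


class AgentStep:  # stand-in for erniebot_agent.agents.schema.AgentStep
    pass


class Message:  # stand-in for erniebot_agent.agents.schema.Message
    pass


DEFAULT_FINISH_STEP = AgentStep()


class Agent(abc.ABC):
    @abc.abstractmethod
    async def _run_stream(
        self, prompt: str
    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
        # HACK: this block is never executed at runtime (typing.TYPE_CHECKING
        # is False), but the mere presence of `yield` makes both mypy and the
        # interpreter treat this method as an async generator, matching the
        # subclass overrides, instead of a coroutine returning an AsyncIterator.
        if typing.TYPE_CHECKING:
            yield (DEFAULT_FINISH_STEP, [])

Because generator-ness is decided by the presence of yield anywhere in the body, not by reachability, this behaves identically to if False: at runtime while reading as an explicit type-checking hack.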

+            # This conditional block is strictly for static type-checking purposes (e.g., mypy) and will not be executed.
+            only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, [])
+            yield only_for_mypy_type_check

     async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
         for reserved_opt in ("stream", "system", "plugins"):
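For context, a caller would consume the new tuple-yielding API roughly as follows. This is an illustrative sketch, not code from the PR: make_agent() is a hypothetical factory standing in for the construction of any concrete Agent subclass.

import asyncio


async def main() -> None:
    agent = make_agent()  # hypothetical: builds some concrete Agent subclass
    async for step, messages in agent.run_stream("What's the weather like today?"):
        # Each iteration delivers the step the agent just took together
        # with the new messages that step produced.
        print(type(step).__name__, [msg.content for msg in messages])


asyncio.run(main())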
9 changes: 5 additions & 4 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -182,7 +182,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
         return response

     async def _first_tool_step(
-        self, chat_history: List[Message], selected_tool: BaseTool = None
+        self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
     ) -> Tuple[AgentStep, List[Message]]:
         input_messages = self.memory.get_messages() + chat_history
         if selected_tool is None:
@@ -254,7 +254,6 @@ async def _run_stream(
             _logger.warning(f"Selected tool [{tool.tool_name}] not work")

         is_finished = False
-        curr_step = None
         new_messages = []
         end_step_msgs = []
         while is_finished is False:
@@ -274,17 +273,19 @@ async def _run_stream(
                 curr_step = DEFAULT_FINISH_STEP
                 yield curr_step, new_messages

-            if isinstance(curr_step, EndStep):
+            elif isinstance(curr_step, EndStep):
                 is_finished = True
                 end_step_msgs.extend(new_messages)
                 yield curr_step, new_messages
+            else:
+                raise RuntimeError("Invalid step type")
             chat_history.extend(new_messages)

         self.memory.add_message(run_input)
         end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
         self.memory.add_message(end_step_msg)

-    async def _schema_format(self, llm_resp, chat_history):
+    async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
         """Convert the LLM response to the agent response schema.

         Args:
             llm_resp: The LLM response to convert.
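Putting the two files together, the streaming contract can be summarized in a simplified sketch of a concrete _run_stream implementation. This is not the PR's exact code: the tool-planning and LLM plumbing are elided, HumanMessage is assumed from the project's schema, and only the yield-per-step loop and the end-of-run memory handling mirror the diff above.

# Simplified sketch (not the PR's exact code) of how a concrete subclass
# satisfies the abstract _run_stream contract from agent.py.
async def _run_stream(self, prompt, files=None):
    run_input = HumanMessage(content=prompt)  # assumed schema type
    chat_history = [run_input]
    end_step_msgs = []
    is_finished = False
    while not is_finished:
        # Stand-in for the real planning/tool logic in FunctionAgent.
        curr_step, new_messages = await self._first_tool_step(chat_history)
        if isinstance(curr_step, EndStep):
            is_finished = True
            end_step_msgs.extend(new_messages)
        yield curr_step, new_messages
        chat_history.extend(new_messages)
    # Mirror the diff: persist the user turn, then fold all end-step
    # messages into a single AIMessage so memory stores one assistant
    # turn per run.
    self.memory.add_message(run_input)
    end_step_msg = AIMessage(content="".join(m.content for m in end_step_msgs))
    self.memory.add_message(end_step_msg)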