diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 124447575..fcfb0d72f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,25 +1,21 @@ repos: -- repo: https://github.com/psf/black.git +- repo: https://githubfast.com/psf/black.git rev: 23.3.0 hooks: - id: black files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ -- repo: https://github.com/pycqa/isort - rev: 5.11.5 - hooks: - - id: isort -- repo: https://github.com/PyCQA/flake8 +- repo: https://githubfast.com/PyCQA/flake8 rev: 5.0.4 hooks: - id: flake8 args: [--config=.flake8] -- repo: https://github.com/pre-commit/mirrors-mypy +- repo: https://githubfast.com/pre-commit/mirrors-mypy rev: v1.6.1 hooks: - id: mypy exclude: ^(setup\.py|.*tests.*)$ additional_dependencies: ['types-requests', 'types-PyYAML'] -- repo: https://github.com/pre-commit/pre-commit-hooks +- repo: https://githubfast.com/pre-commit/pre-commit-hooks rev: a11d9314b22d8f8c7556443875b731ef05965464 hooks: - id: check-merge-conflict @@ -28,7 +24,7 @@ repos: - id: trailing-whitespace - id: detect-private-key - id: check-added-large-files -- repo: https://github.com/Lucas-C/pre-commit-hooks +- repo: https://githubfast.com/Lucas-C/pre-commit-hooks rev: v1.0.1 hooks: - id: forbid-crlf diff --git a/erniebot-agent/erniebot_agent/agents/agentchat.py b/erniebot-agent/erniebot_agent/agents/agentchat.py new file mode 100644 index 000000000..d7c1aceef --- /dev/null +++ b/erniebot-agent/erniebot_agent/agents/agentchat.py @@ -0,0 +1,84 @@ +from typing import Dict, List, Optional, Union + + +class AgentChat: + """(In preview) An abstract class for AI agent. + + An agent can communicate with other agents and perform actions. + Different agents can differ in what actions they perform in the `receive` method. + """ + + def __init__(self, name: str): + """ + Args: + name (str): name of the agent. + """ + # a dictionary of conversations, default value is list + self._name = name + + @property + def name(self): + """Get the name of the agent.""" + return self._name + + def send(self, message: Union[Dict, str], recipient: "AgentChat", request_reply: Optional[bool] = None): + """(Abstract method) Send a message to another agent.""" + + async def a_send( + self, message: Union[Dict, str], recipient: "AgentChat", request_reply: Optional[bool] = None + ): + """(Abstract async method) Send a message to another agent.""" + + def receive( + self, + message: Union[Dict, str], + sender: "AgentChat", + request_reply: Optional[bool] = None, + silent: Optional[bool] = None, + ): + """(Abstract method) Receive a message from another agent.""" + + async def a_receive( + self, + message: Union[Dict, str], + sender: "AgentChat", + request_reply: Optional[bool] = None, + silent: Optional[bool] = None, + ): + """(Abstract async method) Receive a message from another agent.""" + + def reset(self): + """(Abstract method) Reset the agent.""" + + def can_execute_function(self, name: str): + """Determine whether the function can be executed""" + + def generate_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional["AgentChat"] = None, + **kwargs, + ) -> Union[str, Dict, None]: + """(Abstract method) Generate a reply based on the received messages. + + Args: + messages (list[dict]): a list of messages received. + sender: sender of an Agent instance. + Returns: + str or dict or None: the generated reply. If None, no reply is generated. + """ + + async def a_generate_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional["AgentChat"] = None, + **kwargs, + ) -> Union[str, Dict, None]: + """(Abstract async method) Generate a reply based on the received messages. + + Args: + messages (list[dict]): a list of messages received. + sender: sender of an Agent instance. + Returns: + str or dict or None: the generated reply. If None, no reply is generated. + """ diff --git a/erniebot-agent/erniebot_agent/agents/assistant_agent.py b/erniebot-agent/erniebot_agent/agents/assistant_agent.py new file mode 100644 index 000000000..ef2e359c1 --- /dev/null +++ b/erniebot-agent/erniebot_agent/agents/assistant_agent.py @@ -0,0 +1,31 @@ +from .conversable_agent import ConversableAgent +from typing import Callable, Dict, Optional + + +class AssistantAgent(ConversableAgent): + DEFAULT_SYSTEM_MESSAGE = """您是一位有用的人工智能助手。请利用您的语言技能解决任务。 + 如果需要,请逐步解决任务。如果未提供计划,请先解释您的计划。 + 当你找到答案时,请仔细验证答案。 + 如果可能,请在您的回复中包含可验证的证据。 + 当一切完成后,最后回复'终止' + """ + + def __init__( + self, + name: str, + llm_config: Dict, + system_message: str = DEFAULT_SYSTEM_MESSAGE, + is_termination_msg: Optional[Callable[[Dict], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Optional[str] = "NEVER", + **kwargs, + ): + super().__init__( + name, + llm_config, + system_message, + is_termination_msg, + max_consecutive_auto_reply, + human_input_mode, + **kwargs, + ) diff --git a/erniebot-agent/erniebot_agent/agents/conversable_agent.py b/erniebot-agent/erniebot_agent/agents/conversable_agent.py new file mode 100644 index 000000000..a16b22b09 --- /dev/null +++ b/erniebot-agent/erniebot_agent/agents/conversable_agent.py @@ -0,0 +1,730 @@ +import asyncio +from collections import defaultdict +import copy +import json +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from .agentchat import AgentChat +import erniebot +import time + +try: + from termcolor import colored +except ImportError: + + def colored(x, *args, **kwargs): + return x + + +logger = logging.getLogger(__name__) + + +class ConversableAgent(AgentChat): + MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change) + + def __init__( + self, + name: str, + llm_config: Dict, + system_message: str = "你是一个有用的人工智能助手。", + is_termination_msg: Optional[Callable[[Dict], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Optional[str] = "NEVER", + function_map: Optional[Dict[str, Callable]] = None, + default_auto_reply: Optional[Union[str, Dict, None]] = "", + ): + super().__init__(name) + self._oai_messages: Dict[AgentChat, List] = defaultdict(list) + self._oai_system_message = system_message + self._is_termination_msg = ( + is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content") == "终止") + ) + + erniebot.api_type = llm_config["api_type"] # type: ignore + erniebot.access_token = llm_config["access_token"] # type: ignore + self.model = llm_config.get("model", "ernie-bot-4") + self.human_input_mode = human_input_mode + self._max_consecutive_auto_reply = ( + max_consecutive_auto_reply + if max_consecutive_auto_reply is not None + else self.MAX_CONSECUTIVE_AUTO_REPLY + ) + self._consecutive_auto_reply_counter: Dict[AgentChat, int] = defaultdict(int) + self._max_consecutive_auto_reply_dict: Dict[AgentChat, int] = defaultdict( + self.max_consecutive_auto_reply + ) + self._function_map = {} if function_map is None else function_map + self._default_auto_reply = default_auto_reply + self._reply_func_list: List = [] + self.reply_at_receive: Dict[AgentChat, bool] = defaultdict(bool) + self.register_reply([AgentChat, None], ConversableAgent.generate_oai_reply) + self.register_reply([AgentChat, None], ConversableAgent.generate_function_call_reply) + self.register_reply([AgentChat, None], ConversableAgent.generate_async_function_call_reply) + self.register_reply([AgentChat, None], ConversableAgent.check_termination_and_human_reply) + + def register_reply( + self, + trigger: Union[Type[AgentChat], str, AgentChat, Callable[[AgentChat], bool], List], + reply_func: Callable, + position: int = 0, + config: Optional[Any] = None, + reset_config: Optional[Callable] = None, + ): + if not isinstance(trigger, (type, str, AgentChat, Callable, list)): # type: ignore + raise ValueError("trigger must be a class, a string, an agent, a callable or a list.") + self._reply_func_list.insert( + position, + { + "trigger": trigger, + "reply_func": reply_func, + "config": copy.copy(config), + "init_config": config, + "reset_config": reset_config, + }, + ) + + @property + def system_message(self): + return self._oai_system_message + + def update_system_message(self, system_message: str): + self._oai_system_message = system_message + + def update_max_consecutive_auto_reply(self, value: int, sender: Optional[AgentChat] = None): + if sender is None: + self._max_consecutive_auto_reply = value + for k in self._max_consecutive_auto_reply_dict: + self._max_consecutive_auto_reply_dict[k] = value + else: + self._max_consecutive_auto_reply_dict[sender] = value + + def max_consecutive_auto_reply(self, sender: Optional[AgentChat] = None) -> int: + return ( + self._max_consecutive_auto_reply + if sender is None + else self._max_consecutive_auto_reply_dict[sender] + ) + + @property + def chat_messages(self) -> Dict[AgentChat, List[Dict]]: + return self._oai_messages + + def last_message(self, agent: Optional[AgentChat] = None) -> Optional[Dict]: + if agent is None: + n_conversations = len(self._oai_messages) + if n_conversations == 0: + return None + if n_conversations == 1: + for conversation in self._oai_messages.values(): + return conversation[-1] + raise ValueError( + "More than one conversation is found. Please specify the sender to get the last message." + ) + if agent not in self._oai_messages.keys(): + raise KeyError( + f"The agent '{agent.name}' is not present in any conversation. " + + "No history available for this agent." + ) + return self._oai_messages[agent][-1] + + @staticmethod + def _message_to_dict(message: Union[Dict, str]) -> Dict[str, Any]: + """Convert a message to a dictionary. + + The message can be a string or a dictionary. The string will + be put in the "content" field of the new dictionary. + """ + if isinstance(message, str): + return {"content": message} + elif isinstance(message, dict): + return message + else: + return dict(message) + + def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: AgentChat) -> bool: + message = self._message_to_dict(message) + # create oai message to be appended to the oai conversation that can be passed to oai directly. + oai_message = { + k: message[k] for k in ("content", "function_call", "name", "context") if k in message + } + if "content" not in oai_message: + if "function_call" in oai_message: + oai_message["content"] = None + else: + return False + oai_message["role"] = "function" if message.get("role") == "function" else role + if "function_call" in oai_message: + oai_message["role"] = "assistant" + oai_message["function_call"] = dict(oai_message["function_call"]) + if len(self._oai_messages[conversation_id]) == 0: + oai_message["role"] = "user" + elif len(self._oai_messages[conversation_id]) > 0: + if self._oai_messages[conversation_id][-1]["role"] == "user": + oai_message["role"] = "assistant" + elif self._oai_messages[conversation_id][-1]["role"] == "assistant": + oai_message["role"] = "user" + self._oai_messages[conversation_id].append(oai_message) + return True + + def send( + self, + message: Union[Dict, str], + recipient: AgentChat, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + valid = self._append_oai_message(message, "assistant", recipient) + if valid: + recipient.receive(message=message, sender=self, request_reply=request_reply, silent=silent) + else: + raise ValueError( + "Message can't be converted into a valid ChatCompletion message." + + "Either content or function_call must be provided." + ) + + async def a_send( + self, + message: Union[Dict, str], + recipient: AgentChat, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + valid = self._append_oai_message(message, "assistant", recipient) + if valid: + await recipient.a_receive( + message=message, sender=self, request_reply=request_reply, silent=silent + ) + else: + raise ValueError( + "Message can't be converted into a valid ChatCompletion message. " + + "Either content or function_call must be provided." + ) + + def _print_received_message(self, message: Union[Dict, str], sender: AgentChat): + print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True) + message = self._message_to_dict(message) + if message.get("role") == "function": + func_print = f"***** Response from calling function \"{message['name']}\" *****" + print(colored(func_print, "green"), flush=True) + print(message["content"], flush=True) + print(colored("*" * len(func_print), "green"), flush=True) + else: + content = message.get("content") + if content is not None: + print(content, flush=True) + if "function_call" in message: + function_call = dict(message["function_call"]) + func_print = ( + f"***** Suggested function " + f"Call: {function_call.get('name', '(No function name found)')} *****" + ) + print(colored(func_print, "green"), flush=True) + print( + "Arguments: \n", + function_call.get("arguments", "(No arguments found)"), + flush=True, + sep="", + ) + print(colored("*" * len(func_print), "green"), flush=True) + print("\n", "-" * 80, flush=True, sep="") + + def _process_received_message(self, message, sender: AgentChat, silent): + message = self._message_to_dict(message) + valid = self._append_oai_message(message, "user", sender) + if not valid: + raise ValueError( + "Received message can't be converted into a valid " + + "ChatCompletion message. Either content or function_call " + + "must be provided." + ) + if not silent: + self._print_received_message(message, sender) + + def receive( # type: ignore + self, # type: ignore + sender: AgentChat, # type: ignore + message: Union[Dict, str, None] = None, # type: ignore + request_reply: Optional[bool] = None, # type: ignore + silent: Optional[bool] = False, # type: ignore + ): # type: ignore + self._process_received_message(message=message, sender=sender, silent=silent) + if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: + return + reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender) + if reply is not None: + self.send(message=reply, recipient=sender, silent=silent) + + async def a_receive( # type: ignore + self, # type: ignore + sender: AgentChat, # type: ignore + message: Union[Dict, str, None] = None, # type: ignore + request_reply: Optional[bool] = None, # type: ignore + silent: Optional[bool] = False, # type: ignore + ): # type: ignore + self._process_received_message(message, sender, silent) + if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: + return + reply = await self.a_generate_reply(sender=sender) + if reply is not None: + await self.a_send(message=reply, recipient=sender, silent=silent) + + def _prepare_chat(self, recipient, clear_history): + self.reset_consecutive_auto_reply_counter(recipient) + recipient.reset_consecutive_auto_reply_counter(self) + self.reply_at_receive[recipient] = recipient.reply_at_receive[self] = True + if clear_history: + self.clear_history(recipient) + recipient.clear_history(self) + + def initiate_chat( + self, + recipient: "ConversableAgent", + clear_history: Optional[bool] = True, + silent: Optional[bool] = False, + **context, + ): + self._prepare_chat(recipient, clear_history) + self.send(message=self.generate_init_message(**context), recipient=recipient, silent=silent) + + async def a_initiate_chat( + self, + recipient: AgentChat, + clear_history: Optional[bool] = True, + silent: Optional[bool] = False, + **context, + ): + self._prepare_chat(recipient, clear_history) + await self.a_send(message=self.generate_init_message(**context), recipient=recipient, silent=silent) + + def reset(self): + self.clear_history() + self.reset_consecutive_auto_reply_counter() + self.stop_reply_at_receive() + for reply_func_tuple in self._reply_func_list: + if reply_func_tuple["reset_config"] is not None: + reply_func_tuple["reset_config"](reply_func_tuple["config"]) + else: + reply_func_tuple["config"] = copy.copy(reply_func_tuple["init_config"]) + + def stop_reply_at_receive(self, sender: Optional[AgentChat] = None): + if sender is None: + self.reply_at_receive.clear() + else: + self.reply_at_receive[sender] = False + + def reset_consecutive_auto_reply_counter(self, sender: Optional[AgentChat] = None): + if sender is None: + self._consecutive_auto_reply_counter.clear() + else: + self._consecutive_auto_reply_counter[sender] = 0 + + def clear_history(self, agent: Optional[AgentChat] = None): + if agent is None: + self._oai_messages.clear() + else: + self._oai_messages[agent].clear() + + def generate_oai_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[AgentChat] = None, + **kwargs, + ) -> Tuple[bool, Union[str, Dict, None]]: + systems = kwargs.get("systems", "") + if messages is None: + messages = self._oai_messages[sender] # type: ignore + if len(messages) % 2 == 0: + if messages[0]["role"] == "assistant": + return False, None + else: + messages.append({"role": "user", "content": "请你思考任务是否完成,如果完成则输出'终止'即可,否则请输出完成任务的流程"}) + try: + response = erniebot.ChatCompletion.create( # type: ignore + model=self.model, messages=messages, system=self._oai_system_message + systems + ) # type: ignore + return True, response.get_result() + except Exception as e: + print(e) + time.sleep(10) + response = erniebot.ChatCompletion.create( # type: ignore + model=self.model, messages=messages, system=self._oai_system_message + systems + ) # type: ignore + return True, response.get_result() + + def generate_function_call_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[AgentChat] = None, + config: Optional[Any] = None, + ): + """Generate a reply using function call.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] # type: ignore + message = messages[-1] + if "function_call" in message: + _, func_return = self.execute_function(message["function_call"]) + return True, func_return + return False, None + + async def generate_async_function_call_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[AgentChat] = None, + config: Optional[Any] = None, + ): + """Generate a reply using async function call.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] # type: ignore + message = messages[-1] + if "function_call" in message: + func_call = message["function_call"] + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + if func and asyncio.coroutines.iscoroutinefunction(func): + _, func_return = await self.a_execute_function(func_call) + return True, func_return + + return False, None + + def check_termination_and_human_reply( + self, + sender: AgentChat, + messages: Optional[List[Dict]] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + """Check if the conversation should be terminated, and if human reply is provided.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + reply = "" + no_human_input_msg = "" + if self.human_input_mode == "ALWAYS": + reply = self.get_human_input( + f"Provide feedback to {sender.name}. Press enter to skip " + + "and use auto-reply, or type 'exit' to end the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a + # termination message, then we will terminate the conversation + reply = reply if reply or not self._is_termination_msg(message) else "exit" + else: + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + if self.human_input_mode == "NEVER": + reply = "exit" + else: + terminate = self._is_termination_msg(message) + reply = self.get_human_input( + f"Please give feedback to {sender.name}. " + + "Press enter or type 'exit' to stop the conversation: " + if terminate + else f"Please give feedback to {sender.name}. " + + "Press enter to skip and use auto-reply, " + + "or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message + # is a termination message, then we will terminate the conversation + reply = reply if reply or not terminate else "exit" + elif self._is_termination_msg(message): + if self.human_input_mode == "NEVER": + reply = "exit" + else: + reply = self.get_human_input( + f"Please give feedback to {sender.name}. " + + "Press enter or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is + # a termination message, then we will terminate + # the conversation + reply = reply or "exit" + + # print the no_human_input_msg + if no_human_input_msg: + print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, reply + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + if self.human_input_mode != "NEVER": + print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) + return False, None + + async def a_check_termination_and_human_reply( + self, + sender: AgentChat, + messages: Optional[List[Dict]] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + """(async) Check if the conversation should be terminated, + and if human reply is provided.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + reply = "" + no_human_input_msg = "" + if self.human_input_mode == "ALWAYS": + reply = await self.a_get_human_input( + f"Provide feedback to {sender.name}. " + + "Press enter to skip and use auto-reply, " + + "or type 'exit' to end the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a + # termination message, then we will terminate the conversation + reply = reply if reply or not self._is_termination_msg(message) else "exit" + else: + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + if self.human_input_mode == "NEVER": + reply = "exit" + else: + terminate = self._is_termination_msg(message) + reply = await self.a_get_human_input( + f"Please give feedback to {sender.name}. " + + "Press enter or type 'exit' to stop the conversation: " + if terminate + else f"Please give feedback to {sender.name}. " + + "Press enter to skip and use auto-reply, or " + + "type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is + # a termination message, then we will terminate the conversation + reply = reply if reply or not terminate else "exit" + elif self._is_termination_msg(message): + if self.human_input_mode == "NEVER": + reply = "exit" + else: + reply = await self.a_get_human_input( + f"Please give feedback to {sender.name}. " + + "Press enter or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message + # is a termination message, then we will + # terminate the conversation + reply = reply or "exit" + + # print the no_human_input_msg + if no_human_input_msg: + print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, reply + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + if self.human_input_mode != "NEVER": + print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) + + return False, None + + def generate_reply( # type: ignore + self, # type: ignore + sender: AgentChat, # type: ignore + messages: Optional[List[Dict]] = None, # type: ignore + exclude: Optional[List[Callable]] = None, # type: ignore + ) -> Union[str, Dict, None]: # type: ignore + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." + logger.error(error_msg) + raise AssertionError(error_msg) + + if messages is None: + messages = self._oai_messages[sender] + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if exclude and reply_func in exclude: + continue + if asyncio.coroutines.iscoroutinefunction(reply_func): + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + final, reply = reply_func( + self, messages=messages, sender=sender, config=reply_func_tuple["config"] + ) + if final: + return reply + return self._default_auto_reply + + async def a_generate_reply( # type: ignore + self, # type: ignore + sender: AgentChat, # type: ignore + messages: Optional[List[Dict]] = None, # type: ignore + exclude: Optional[List[Callable]] = None, # type: ignore + ) -> Union[str, Dict, None]: # type: ignore + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." + logger.error(error_msg) + raise AssertionError(error_msg) + if messages is None: + messages = self._oai_messages[sender] + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if exclude and reply_func in exclude: + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + if asyncio.coroutines.iscoroutinefunction(reply_func): + final, reply = await reply_func( + self, messages=messages, sender=sender, config=reply_func_tuple["config"] + ) + else: + final, reply = reply_func( + self, messages=messages, sender=sender, config=reply_func_tuple["config"] + ) + if final: + return reply + return self._default_auto_reply + + def _match_trigger(self, trigger, sender): + """Check if the sender matches the trigger.""" + if trigger is None: + return sender is None + elif isinstance(trigger, str): + return trigger == sender.name + elif isinstance(trigger, type): + return isinstance(sender, trigger) + elif isinstance(trigger, AgentChat): + return trigger == sender + elif isinstance(trigger, Callable): + return trigger(sender) + elif isinstance(trigger, list): + return any(self._match_trigger(t, sender) for t in trigger) + else: + raise ValueError(f"Unsupported trigger type: {type(trigger)}") + + def get_human_input(self, prompt: str) -> str: + reply = input(prompt) + return reply + + async def a_get_human_input(self, prompt: str) -> str: + reply = input(prompt) + return reply + + @staticmethod + def _format_json_str(jstr): + result = [] + inside_quotes = False + last_char = " " + for char in jstr: + if last_char != "\\" and char == '"': + inside_quotes = not inside_quotes + last_char = char + if not inside_quotes and char == "\n": + continue + if inside_quotes and char == "\n": + char = "\\n" + if inside_quotes and char == "\t": + char = "\\t" + result.append(char) + return "".join(result) + + def execute_function(self, func_call): + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + + is_exec_success = False + if func is not None: + # Extract arguments from a json-like string and put it into a dict. + input_string = self._format_json_str(func_call.get("arguments", "{}")) + try: + arguments = json.loads(input_string) + except json.JSONDecodeError as e: + arguments = None + content = f"Error: {e}\n You argument should follow json format." + # Try to execute the function + if arguments is not None: + print( + colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"), + flush=True, + ) + try: + content = func(**arguments) + is_exec_success = True + except Exception as e: + content = f"Error: {e}" + else: + content = f"Error: Function {func_name} not found." + return is_exec_success, { + "name": func_name, + "role": "function", + "content": str(content), + } + + async def a_execute_function(self, func_call): + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + + is_exec_success = False + if func is not None: + # Extract arguments from a json-like string and put it into a dict. + input_string = self._format_json_str(func_call.get("arguments", "{}")) + try: + arguments = json.loads(input_string) + except json.JSONDecodeError as e: + arguments = None + content = f"Error: {e}\n You argument should follow json format." + + # Try to execute the function + if arguments is not None: + print( + colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"), + flush=True, + ) + try: + if asyncio.coroutines.iscoroutinefunction(func): + content = await func(**arguments) + else: + # Fallback to sync function if the function is not async + content = func(**arguments) + is_exec_success = True + except Exception as e: + content = f"Error: {e}" + else: + content = f"Error: Function {func_name} not found." + + return is_exec_success, { + "name": func_name, + "role": "function", + "content": str(content), + } + + def generate_init_message(self, **context) -> Union[str, Dict]: + return context["message"] + + def register_function(self, function_map: Dict[str, Callable]): + self._function_map.update(function_map) + + def can_execute_function(self, name: str) -> bool: + return name in self._function_map + + @property + def function_map(self) -> Dict[str, Callable]: + return self._function_map diff --git a/erniebot-agent/erniebot_agent/agents/groupchat.py b/erniebot-agent/erniebot_agent/agents/groupchat.py new file mode 100644 index 000000000..99d07fe07 --- /dev/null +++ b/erniebot-agent/erniebot_agent/agents/groupchat.py @@ -0,0 +1,336 @@ +import logging +import sys +import random +from dataclasses import dataclass +from typing import Dict, List, Optional +import re +from .agentchat import AgentChat +from .conversable_agent import ConversableAgent + +logger = logging.getLogger(__name__) +import copy + + +@dataclass +class GroupChat: + agents: List[ConversableAgent] + messages: List[Dict] + max_round: int = 10 + admin_name: str = "Admin" + func_call_filter: bool = True + speaker_selection_method: str = "auto" + allow_repeat_speaker: bool = True + + _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"] + + @property + def agent_names(self) -> List[str]: + """Return the names of the agents in the group chat.""" + return [agent.name for agent in self.agents] + + def reset(self): + """Reset the group chat.""" + self.messages.clear() + + def agent_by_name(self, name): + """Returns the agent with a given name.""" + return self.agents[self.agent_names.index(name)] + + def next_agent(self, agent: AgentChat, agents: List[ConversableAgent]): + """Return the next agent in the list.""" + if agents == self.agents: + return agents[(self.agent_names.index(agent.name) + 1) % len(agents)] + else: + offset = self.agent_names.index(agent.name) + 1 + for i in range(len(self.agents)): + if self.agents[(offset + i) % len(self.agents)] in agents: + return self.agents[(offset + i) % len(self.agents)] + + def select_speaker_msg(self, agents: List[ConversableAgent]): + """Return the message for selecting the next speaker.""" + return f"""您正在玩角色扮演游戏。可以使用以下角色: + {self._participant_roles(agents)}. + 阅读下面的对话。 + 然后从 {[agent.name for agent in agents]}选择下一个角色去扮演。只返回角色。""" + + def manual_select_speaker(self, agents: List[ConversableAgent]): + """Manually select the next speaker.""" + + print("请从以下列表中选择下一位agent") + _n_agents = len(agents) + for i in range(_n_agents): + print(f"{i+1}: {agents[i].name}") + try_count = 0 + # Assume the user will enter a valid number within 3 tries, + # otherwise use auto selection to avoid blocking. + while try_count <= 3: + try_count += 1 + if try_count >= 3: + print(f"你已经尝试{try_count} 次了。下一个agent将自动选择。") + break + try: + strs = input("请输入你想要选择agent的序号。当输入q或者不进行输入时则自动选择下一个agent。 ") + if strs == "" or strs == "q": + break + nums = int(strs) + if nums > 0 and nums <= _n_agents: + return agents[nums - 1] + else: + raise ValueError + except ValueError: + print(f"无效输入。请输入1到{_n_agents}的数字。") + return None + + def select_speaker(self, last_speaker: ConversableAgent, selector: ConversableAgent): + """Select the next speaker.""" + if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS: + raise ValueError( + f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. " + f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). " + ) + + agents = self.agents + n_agents = len(agents) + # Warn if GroupChat is underpopulated + if n_agents < 2: + raise ValueError( + f"GroupChat is underpopulated with {n_agents} agents. " + "Please add more agents to the GroupChat or use direct communication instead." + ) + elif ( + n_agents == 2 + and self.speaker_selection_method.lower() != "round_robin" + and self.allow_repeat_speaker + ): + logger.warning( + f"GroupChat is underpopulated with {n_agents} agents. " + + "It is recommended to set speaker_selection_method to " + + "'round_robin' or allow_repeat_speaker to False." + + "Or, use direct communication instead." + ) + + if self.func_call_filter and self.messages and "function_call" in self.messages[-1]: + # find agents with the right function_map which contains the function name + agents = [ + agent + for agent in self.agents + if agent.can_execute_function(self.messages[-1]["function_call"]["name"]) + ] + if len(agents) == 1: + # only one agent can execute the function + return agents[0] + elif not agents: + # find all the agents with function_map + agents = [agent for agent in self.agents if agent.function_map] + if len(agents) == 1: + return agents[0] + elif not agents: + raise ValueError( + f"No agent can execute the function {self.messages[-1]['name']}. " + "Please check the function_map of the agents." + ) + + # remove the last speaker from the list to avoid selecting + # the same speaker if allow_repeat_speaker is False + agents = ( + agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker] + ) + if self.speaker_selection_method.lower() == "manual": + selected_agent = self.manual_select_speaker(agents) + if selected_agent: + return selected_agent + elif self.speaker_selection_method.lower() == "round_robin": + return self.next_agent(last_speaker, agents) + elif self.speaker_selection_method.lower() == "random": + return random.choice(agents) + selector.update_system_message(self.select_speaker_msg(agents)) + messages = copy.deepcopy(self.messages) + final, name = selector.generate_oai_reply( + messages, + systems=f"读一下上面的对话。然后从{[agent.name for agent in agents]}选择下一个扮演的角色。仅仅返回角色。", + ) + if not final: + # the LLM client is None, thus no reply is generated. Use round robin instead. + return self.next_agent(last_speaker, agents) + + # If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified + + mentions = self._mentioned_agents(name, agents) + if len(mentions) == 1: + name = next(iter(mentions)) + else: + logger.warning( + "GroupChat select_speaker failed to resolve " + + "the next speaker's name. This is because the " + + f"speaker selection OAI call returned:\n{name}" + ) + + # Return the result + try: + return self.agent_by_name(name) + except ValueError: + return self.next_agent(last_speaker, agents) + + def _participant_roles(self, agents): + # Default to all agents registered + if agents: + agents = self.agents + + roles = [] + for agent in agents: + if agent.system_message.strip() == "": + logger.warning( + f"The agent '{agent.name}' has an empty system_message, " + + "and may not work well with GroupChat." + ) + roles.append(f"{agent.name}: {agent.system_message}") + return "\n".join(roles) + + def _mentioned_agents(self, message_content, agents: List[ConversableAgent]): + """ + Finds and counts agent mentions in the string message_content, + taking word boundaries into account. + + Returns: A dictionary mapping agent names to mention counts + (to be included, at least one mention must occur) + """ + mentions = dict() + for agent in agents: + regex = ( + r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)" + ) # Finds agent mentions, taking word boundaries into account + count = len( + re.findall(regex, " " + message_content + " ") + ) # Pad the message to help with matching + if count > 0: + mentions[agent.name] = count + return mentions + + +class GroupChatManager(ConversableAgent): + """(In preview) A chat manager agent that can manage a group chat of multiple agents.""" + + def __init__( + self, + groupchat: GroupChat, + name: str = "chat_manager", + # unlimited consecutive auto reply by default + max_consecutive_auto_reply: Optional[int] = sys.maxsize, + human_input_mode: Optional[str] = "NEVER", + system_message: str = "Group chat manager.", + **kwargs, + ): + super().__init__( + name=name, + max_consecutive_auto_reply=max_consecutive_auto_reply, + human_input_mode=human_input_mode, + system_message=system_message, + **kwargs, + ) + # Order of register_reply is important. + # Allow sync chat if initiated using initiate_chat + self.register_reply( + AgentChat, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset + ) + # Allow async chat if initiated using a_initiate_chat + self.register_reply( + AgentChat, GroupChatManager.a_run_chat, config=groupchat, reset_config=GroupChat.reset + ) + + def run_chat( + self, + sender: ConversableAgent, + config: GroupChat, + messages: Optional[List[Dict]] = None, + ): + """Run a group chat.""" + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + for i in range(groupchat.max_round): + # set the name to speaker's name if the role is not function + if message["role"] != "function": + message["name"] = speaker.name + groupchat.messages.append(message) + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + self.send(message=message, recipient=agent, request_reply=False, silent=True) + if i == groupchat.max_round - 1: + # the last round + break + try: + # select the next speaker + speaker = groupchat.select_speaker(speaker, self) + # let the speaker speak + reply = speaker.generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = speaker.generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + if reply is None: + break + # The speaker sends the message without requesting a reply + speaker.send(message=reply, recipient=self, request_reply=False) + update_message = self.last_message(speaker) + if update_message: + message = update_message + else: + break + return True, None + + async def a_run_chat( + self, + sender: ConversableAgent, + config: GroupChat, + messages: Optional[List[Dict]] = None, + ): + """Run a group chat asynchronously.""" + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + for i in range(groupchat.max_round): + # set the name to speaker's name if the role is not function + if message["role"] != "function": + message["name"] = speaker.name + groupchat.messages.append(message) + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + await self.a_send(message=message, recipient=agent, request_reply=False, silent=True) + if i == groupchat.max_round - 1: + # the last round + break + try: + # select the next speaker + speaker = groupchat.select_speaker(speaker, self) + # let the speaker speak + reply = await speaker.a_generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = await speaker.a_generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + if reply is None: + break + # The speaker sends the message without requesting a reply + await speaker.a_send(message=reply, recipient=self, request_reply=False) + update_message = self.last_message(speaker) + if update_message: + message = update_message + else: + break + return True, None diff --git a/erniebot-agent/erniebot_agent/agents/user_proxy_agent.py b/erniebot-agent/erniebot_agent/agents/user_proxy_agent.py new file mode 100644 index 000000000..e2f8d6745 --- /dev/null +++ b/erniebot-agent/erniebot_agent/agents/user_proxy_agent.py @@ -0,0 +1,26 @@ +from .conversable_agent import ConversableAgent +from typing import Callable, Dict, Optional, Union + + +class UserProxyAgent(ConversableAgent): + def __init__( + self, + name: str, + llm_config: Dict, + is_termination_msg: Optional[Callable[[Dict], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Optional[str] = "ALWAYS", + function_map: Optional[Dict[str, Callable]] = None, + default_auto_reply: Optional[Union[str, Dict, None]] = "", + system_message: str = "", + ): + super().__init__( + name, + llm_config, + system_message, + is_termination_msg, + max_consecutive_auto_reply, + human_input_mode, + function_map, + default_auto_reply, + ) diff --git a/erniebot-agent/examples/cookbook/how_to_use_user_agent.ipynb b/erniebot-agent/examples/cookbook/how_to_use_user_agent.ipynb new file mode 100644 index 000000000..a30b5691c --- /dev/null +++ b/erniebot-agent/examples/cookbook/how_to_use_user_agent.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": true, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# 自动生成的代理聊天:通过任务执行、人工反馈来解决任务\n", + "我们演示了如何使用 AssistantAgent 和 UserProxyAgent 通过人类反馈来解决问题。这里的 AssistantAgent 是一个基于 LLM 的代理,它可以利用LLM来解决问题。 UserProxyAgent 是一个代理,充当用户来判断AssistantAgent解决问题的方式是否正确。通过正确设置 human_input_mode,UserProxyAgent 还可以提示用户向 AssistantAgent 反馈。例如,当 human_input_mode 设置为“ALWAYS”时,UserProxyAgent 将始终提示用户反馈。当用户提供反馈时,UserProxyAgent 会将反馈直接传递给 AssistantAgent。当没有用户反馈时,UserProxyAgent将执行思考问题是否已经解决。" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## 准备工作\n", + "\n", + "安装`erniebot-agent`\n", + "```\n", + "git clone https://github.com/PaddlePaddle/ERNIE-Bot-SDK.git\n", + "cd erniebot-agent\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### 构建代理" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from erniebot_agent.agents.user_proxy_agent import UserProxyAgent\n", + "from erniebot_agent.agents.assistant_agent import AssistantAgent\n", + "config_list ={'model': 'ernie-bot-4','api_type': 'aistudio',\n", + " 'access_token':''}\n", + "# create an AssistantAgent instance named \"assistant\"\n", + "assistant = AssistantAgent(\n", + " name=\"assistant\",\n", + " llm_config=config_list\n", + ")\n", + "# create a UserProxyAgent instance named \"user_proxy\"\n", + "user_proxy = UserProxyAgent(\n", + " name=\"user_proxy\",\n", + " human_input_mode=\"ALWAYS\",\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"终止\"),\n", + " llm_config=config_list,\n", + " default_auto_reply=\"任务结束,仅仅输入'终止'\"\n", + ")" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "### 执行任务\n", + "我们调用user_proxy代理的initiate_chat()方法来开始对话。当您运行下面的单元格时,系统会在收到assistant代理的消息后提示您提供反馈。如果您不提供任何反馈(直接按 Enter 键),user_proxy代理将尝试代表您判断任务是否完成,或者如果向assistant代理最后发送“TERMINATE”信号则终止的消息" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "problem= \"\"\"\n", + "如何写好一份关于机器学习的综述\n", + "\"\"\"\n", + "# the assistant receives a message from the user, which contains the task description\n", + "user_proxy.initiate_chat(assistant, message=problem)\n" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file