Agents: Small fixes in streaming to gradio + add tests (huggingface#34549)

* Better support transformers.agents in gradio: small fixes and additional tests
aymeric-roucher authored Nov 11, 2024
1 parent 6de2a4d commit 33eef99
Showing 5 changed files with 138 additions and 41 deletions.
3 changes: 1 addition & 2 deletions src/transformers/agents/agents.py
@@ -1141,11 +1141,10 @@ def step(self):
             )
             self.logger.warning("Print outputs:")
             self.logger.log(32, self.state["print_outputs"])
+            observation = "Print outputs:\n" + self.state["print_outputs"]
             if result is not None:
                 self.logger.warning("Last output from code snippet:")
                 self.logger.log(32, str(result))
-            observation = "Print outputs:\n" + self.state["print_outputs"]
-            if result is not None:
                 observation += "Last output from code snippet:\n" + str(result)[:100000]
             current_step_logs["observation"] = observation
         except Exception as e:
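The net effect of this hunk: `observation` is now initialized from the print outputs unconditionally and only extended when the snippet returned a value, which removes the duplicated `if result is not None:` check. A minimal sketch of the resulting flow (values are illustrative, not from the commit):

    result = None  # e.g. the executed snippet only printed and returned nothing
    print_outputs = "hello\n"

    observation = "Print outputs:\n" + print_outputs  # always set first now
    if result is not None:
        observation += "Last output from code snippet:\n" + str(result)[:100000]

    # With result None, observation == "Print outputs:\nhello\n"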
42 changes: 30 additions & 12 deletions src/transformers/agents/monitoring.py
@@ -18,11 +18,19 @@
 from .agents import ReactAgent


-def pull_message(step_log: dict):
+def pull_message(step_log: dict, test_mode: bool = True):
     try:
         from gradio import ChatMessage
     except ImportError:
-        raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+        if test_mode:
+
+            class ChatMessage:
+                def __init__(self, role, content, metadata=None):
+                    self.role = role
+                    self.content = content
+                    self.metadata = metadata
+        else:
+            raise ImportError("Gradio should be installed in order to launch a gradio demo.")

     if step_log.get("rationale"):
         yield ChatMessage(role="assistant", content=step_log["rationale"])
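With `test_mode=True`, a missing gradio install falls back to the lightweight stand-in `ChatMessage` above, so the generator can be exercised without the dependency. A hypothetical call (the step_log content is made up):

    step_log = {"rationale": "I should call the calculator tool."}
    messages = list(pull_message(step_log, test_mode=True))
    # messages[0].role == "assistant"
    # messages[0].content == "I should call the calculator tool."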
@@ -46,30 +54,40 @@ def pull_message(step_log: dict):
         )


-def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
+def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""

     try:
         from gradio import ChatMessage
     except ImportError:
-        raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+        if test_mode:
+
+            class ChatMessage:
+                def __init__(self, role, content, metadata=None):
+                    self.role = role
+                    self.content = content
+                    self.metadata = metadata
+        else:
+            raise ImportError("Gradio should be installed in order to launch a gradio demo.")

     for step_log in agent.run(task, stream=True, **kwargs):
         if isinstance(step_log, dict):
-            for message in pull_message(step_log):
+            for message in pull_message(step_log, test_mode=test_mode):
                 yield message

-    if isinstance(step_log, AgentText):
-        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
-    elif isinstance(step_log, AgentImage):
+    final_answer = step_log  # Last log is the run's final_answer
+
+    if isinstance(final_answer, AgentText):
+        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
+    elif isinstance(final_answer, AgentImage):
         yield ChatMessage(
             role="assistant",
-            content={"path": step_log.to_string(), "mime_type": "image/png"},
+            content={"path": final_answer.to_string(), "mime_type": "image/png"},
         )
-    elif isinstance(step_log, AgentAudio):
+    elif isinstance(final_answer, AgentAudio):
         yield ChatMessage(
             role="assistant",
-            content={"path": step_log.to_string(), "mime_type": "audio/wav"},
+            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
         )
     else:
-        yield ChatMessage(role="assistant", content=str(step_log))
+        yield ChatMessage(role="assistant", content=str(final_answer))
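For context, a minimal sketch of how `stream_to_gradio` is typically wired into a Gradio app. This is not part of the commit: `my_llm_engine` is a placeholder for your LLM callable, and `gr.Chatbot(type="messages")` plus `gr.ChatMessage` assume a reasonably recent gradio release.

    import gradio as gr

    from transformers.agents.agents import ReactCodeAgent
    from transformers.agents.monitoring import stream_to_gradio

    agent = ReactCodeAgent(tools=[], llm_engine=my_llm_engine)  # my_llm_engine: placeholder

    def interact_with_agent(prompt, messages):
        # Echo the user turn, then stream the agent's ChatMessages as they arrive.
        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages
        for msg in stream_to_gradio(agent, task=prompt):
            messages.append(msg)
            yield messages

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")
        textbox = gr.Textbox(label="Your request")
        textbox.submit(interact_with_agent, [textbox, chatbot], [chatbot])

    demo.launch()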
36 changes: 15 additions & 21 deletions src/transformers/agents/python_interpreter.py
@@ -848,6 +848,13 @@ def evaluate_ast(
     raise InterpreterError(f"{expression.__class__.__name__} is not supported.")


+def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
+    if len(print_outputs) < max_len_outputs:
+        return print_outputs
+    else:
+        return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
+
+
 def evaluate_python_code(
     code: str,
     static_tools: Optional[Dict[str, Callable]] = None,
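The helper's two branches in miniature (illustrative values). Note the asymmetry: the "Print outputs:" header is only prepended on the truncated path; on the short path, callers prepend their own header.

    truncate_print_outputs("short", max_len_outputs=50)
    # -> "short"  (under the limit: returned unchanged)

    truncate_print_outputs("x" * 60, max_len_outputs=50)
    # -> "Print outputs:\n" + "x" * 50
    #    + "\n_Print outputs have been truncated over the limit of 50 characters._\n"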
@@ -890,25 +897,12 @@ def evaluate_python_code(
     PRINT_OUTPUTS = ""
     global OPERATIONS_COUNT
     OPERATIONS_COUNT = 0
-    for node in expression.body:
-        try:
+    try:
+        for node in expression.body:
             result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
-        except InterpreterError as e:
-            msg = ""
-            if len(PRINT_OUTPUTS) > 0:
-                if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
-                    msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
-                else:
-                    msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
-            msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
-            raise InterpreterError(msg)
-        finally:
-            if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
-                state["print_outputs"] = PRINT_OUTPUTS
-            else:
-                state["print_outputs"] = (
-                    PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
-                    + f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
-                )
-
-    return result
+        state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
+        return result
+    except InterpreterError as e:
+        msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
+        msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
+        raise InterpreterError(msg)
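A usage sketch under stated assumptions: passing `print` through `static_tools` follows this module's own tests, though the exact tool wiring may differ between versions.

    state = {}
    result = evaluate_python_code(
        "print('hello')\n1 + 1",
        static_tools={"print": print},
        state=state,
    )
    # Expected: result == 2, and state["print_outputs"] holds the captured
    # (possibly truncated) prints. On failure, the InterpreterError message
    # now begins with the same truncated print outputs.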
16 changes: 10 additions & 6 deletions src/transformers/agents/tools.py
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import ast
 import base64
 import importlib
 import inspect

@@ -141,15 +142,19 @@ def validate_arguments(self, do_validate_forward: bool = True):
         required_attributes = {
             "description": str,
             "name": str,
-            "inputs": Dict,
+            "inputs": dict,
             "output_type": str,
         }
         authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]

         for attr, expected_type in required_attributes.items():
             attr_value = getattr(self, attr, None)
             if attr_value is None:
                 raise TypeError(f"You must set an attribute {attr}.")
             if not isinstance(attr_value, expected_type):
-                raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
+                raise TypeError(
+                    f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
+                )
         for input_name, input_content in self.inputs.items():
             assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
             assert (
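Illustrative only (a hypothetical subclass, not from the commit): a tool whose `inputs` is accidentally a string now fails with the more descriptive message.

    class BadTool(Tool):
        name = "bad_tool"
        description = "a tool with malformed inputs"
        inputs = "not a dict"  # wrong on purpose
        output_type = "string"

    BadTool().validate_arguments()
    # TypeError: Attribute inputs should have type dict, got <class 'str'> instead.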
Expand Down Expand Up @@ -248,7 +253,6 @@ def save(self, output_dir):
def from_hub(
cls,
repo_id: str,
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
**kwargs,
):
@@ -266,9 +270,6 @@ def from_hub(
         Args:
             repo_id (`str`):
                 The name of the repo on the Hub where your tool is defined.
-            model_repo_id (`str`, *optional*):
-                If your tool uses a model and you want to use a different model than the default, you can pass a second
-                repo ID or an endpoint url to this argument.
             token (`str`, *optional*):
                 The token to identify you on hf.co. If unset, will use the token generated when running
                 `huggingface-cli login` (stored in `~/.huggingface`).
@@ -354,6 +355,9 @@ def from_hub(
         if tool_class.output_type != custom_tool["output_type"]:
             tool_class.output_type = custom_tool["output_type"]

+        if not isinstance(tool_class.inputs, dict):
+            tool_class.inputs = ast.literal_eval(tool_class.inputs)
+
         return tool_class(**kwargs)

     def push_to_hub(
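Why `ast.literal_eval`: a tool loaded from the Hub may carry its `inputs` schema serialized as a string, and `literal_eval` parses such a literal back into a dict without executing arbitrary code. A sketch (the schema shown is made up):

    import ast

    serialized = "{'prompt': {'type': 'string', 'description': 'the task prompt'}}"
    inputs = ast.literal_eval(serialized)  # safe: parses Python literals only
    assert isinstance(inputs, dict)
    assert inputs["prompt"]["type"] == "string"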
82 changes: 82 additions & 0 deletions tests/agents/test_monitoring.py
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.agents.agent_types import AgentImage
from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
from transformers.agents.monitoring import stream_to_gradio


class MonitoringTester(unittest.TestCase):
    def test_streaming_agent_text_output(self):
        def dummy_llm_engine(prompt, **kwargs):
            return """
Code:
````
final_answer('This is the final answer.')
```"""

        agent = ReactCodeAgent(
            tools=[],
            llm_engine=dummy_llm_engine,
            max_iterations=1,
        )

        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

        self.assertEqual(len(outputs), 3)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("This is the final answer.", final_message.content)

    def test_streaming_agent_image_output(self):
        def dummy_llm_engine(prompt, **kwargs):
            return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'

        agent = ReactJsonAgent(
            tools=[],
            llm_engine=dummy_llm_engine,
            max_iterations=1,
        )

        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))

        self.assertEqual(len(outputs), 2)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIsInstance(final_message.content, dict)
        self.assertEqual(final_message.content["path"], "path.png")
        self.assertEqual(final_message.content["mime_type"], "image/png")

    def test_streaming_with_agent_error(self):
        def dummy_llm_engine(prompt, **kwargs):
            raise AgentError("Simulated agent error")

        agent = ReactCodeAgent(
            tools=[],
            llm_engine=dummy_llm_engine,
            max_iterations=1,
        )

        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

        self.assertEqual(len(outputs), 3)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("Simulated agent error", final_message.content)
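These tests exercise the `test_mode` stub end to end, so they should run even without gradio installed. From the repository root they can be invoked with `python -m pytest tests/agents/test_monitoring.py -v` (assuming the usual transformers dev setup).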
