Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from typing import Any, Literal, Protocol

from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langchain_core.messages import AIMessage, RemoveMessage, ToolCall, ToolMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict
Expand Down Expand Up @@ -269,6 +270,42 @@ def _process_decision(
)
raise ValueError(msg)

def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Before the agent runs, handle dangling tool calls from the most recent AIMessage."""
messages = state["messages"]
if not messages or len(messages) == 0:
return None

patched_messages = []
# Iterate over the messages and add any dangling tool calls
for i, msg in enumerate(messages):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to only check the last AIMessage? I have a helper for finding the last AI message + any tool messages after the fact

patched_messages.append(msg)
if msg.type == "ai" and msg.tool_calls:
for tool_call in msg.tool_calls:
corresponding_tool_msg = next(
(
msg
for msg in messages[i:]
if msg.type == "tool" and msg.tool_call_id == tool_call["id"]
),
None,
)
if corresponding_tool_msg is None:
# We have a dangling tool call which needs a ToolMessage
tool_msg = (
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
"cancelled - another message came in before it could be completed."
)
patched_messages.append(
ToolMessage(
content=tool_msg,
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)

return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *patched_messages]}

def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Trigger interrupt flows for relevant tool calls after an AIMessage."""
messages = state["messages"]
Expand Down
142 changes: 142 additions & 0 deletions libs/langchain_v1/tests/unit_tests/agents/middleware/test_hitl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain_core.messages import (
SystemMessage,
HumanMessage,
AIMessage,
ToolMessage,
ToolCall,
RemoveMessage,
)
from langgraph.graph.message import add_messages


class TestHumanInTheLoopMiddlewareBeforeModel:
"""Test HumanInTheLoopMiddleware before_model behavior."""

def test_first_message(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 3
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1].type == "system"
assert state_update["messages"][1].content == "You are a helpful assistant."
assert state_update["messages"][2].type == "human"
assert state_update["messages"][2].content == "Hello, how are you?"
assert state_update["messages"][2].id == "2"

def test_missing_tool_call(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="3",
),
HumanMessage(content="What is the weather in Tokyo?", id="4"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 6
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1] == input_messages[0]
assert state_update["messages"][2] == input_messages[1]
assert state_update["messages"][3] == input_messages[2]
assert state_update["messages"][4].type == "tool"
assert state_update["messages"][4].tool_call_id == "123"
assert state_update["messages"][4].name == "get_events_for_days"
assert state_update["messages"][5] == input_messages[3]
updated_messages = add_messages(input_messages, state_update["messages"])
assert len(updated_messages) == 5
assert updated_messages[0] == input_messages[0]
assert updated_messages[1] == input_messages[1]
assert updated_messages[2] == input_messages[2]
assert updated_messages[3].type == "tool"
assert updated_messages[3].tool_call_id == "123"
assert updated_messages[3].name == "get_events_for_days"
assert updated_messages[4] == input_messages[3]

def test_no_missing_tool_calls(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="3",
),
ToolMessage(content="I have no events for that date.", tool_call_id="123", id="4"),
HumanMessage(content="What is the weather in Tokyo?", id="5"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 6
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1:] == input_messages
updated_messages = add_messages(input_messages, state_update["messages"])
assert len(updated_messages) == 5
assert updated_messages == input_messages

def test_two_missing_tool_calls(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="3",
),
HumanMessage(content="What is the weather in Tokyo?", id="4"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="456", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="5",
),
HumanMessage(content="What is the weather in Tokyo?", id="6"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 9
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1] == input_messages[0]
assert state_update["messages"][2] == input_messages[1]
assert state_update["messages"][3] == input_messages[2]
assert state_update["messages"][4].type == "tool"
assert state_update["messages"][4].tool_call_id == "123"
assert state_update["messages"][4].name == "get_events_for_days"
assert state_update["messages"][5] == input_messages[3]
assert state_update["messages"][6] == input_messages[4]
assert state_update["messages"][7].type == "tool"
assert state_update["messages"][7].tool_call_id == "456"
assert state_update["messages"][7].name == "get_events_for_days"
assert state_update["messages"][8] == input_messages[5]
updated_messages = add_messages(input_messages, state_update["messages"])
assert len(updated_messages) == 8
assert updated_messages[0] == input_messages[0]
assert updated_messages[1] == input_messages[1]
assert updated_messages[2] == input_messages[2]
assert updated_messages[3].type == "tool"
assert updated_messages[3].tool_call_id == "123"
assert updated_messages[3].name == "get_events_for_days"
assert updated_messages[4] == input_messages[3]
assert updated_messages[5] == input_messages[4]
assert updated_messages[6].type == "tool"
assert updated_messages[6].tool_call_id == "456"
assert updated_messages[6].name == "get_events_for_days"
assert updated_messages[7] == input_messages[5]