diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index a29b23e7ec620..a547e2ce95576 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -2,7 +2,8 @@ from typing import Any, Literal, Protocol -from langchain_core.messages import AIMessage, ToolCall, ToolMessage +from langchain_core.messages import AIMessage, RemoveMessage, ToolCall, ToolMessage +from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.runtime import Runtime from langgraph.types import interrupt from typing_extensions import NotRequired, TypedDict @@ -269,6 +270,42 @@ def _process_decision( ) raise ValueError(msg) + def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002 + """Before the agent runs, handle dangling tool calls from the most recent AIMessage.""" + messages = state["messages"] + if not messages or len(messages) == 0: + return None + + patched_messages = [] + # Iterate over the messages and add any dangling tool calls + for i, msg in enumerate(messages): + patched_messages.append(msg) + if msg.type == "ai" and msg.tool_calls: + for tool_call in msg.tool_calls: + corresponding_tool_msg = next( + ( + msg + for msg in messages[i:] + if msg.type == "tool" and msg.tool_call_id == tool_call["id"] + ), + None, + ) + if corresponding_tool_msg is None: + # We have a dangling tool call which needs a ToolMessage + tool_msg = ( + f"Tool call {tool_call['name']} with id {tool_call['id']} was " + "cancelled - another message came in before it could be completed." + ) + patched_messages.append( + ToolMessage( + content=tool_msg, + name=tool_call["name"], + tool_call_id=tool_call["id"], + ) + ) + + return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *patched_messages]} + def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: """Trigger interrupt flows for relevant tool calls after an AIMessage.""" messages = state["messages"] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_hitl.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_hitl.py new file mode 100644 index 0000000000000..1ed99af26252b --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_hitl.py @@ -0,0 +1,142 @@ +from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware +from langchain_core.messages import ( + SystemMessage, + HumanMessage, + AIMessage, + ToolMessage, + ToolCall, + RemoveMessage, +) +from langgraph.graph.message import add_messages + + +class TestHumanInTheLoopMiddlewareBeforeModel: + """Test HumanInTheLoopMiddleware before_model behavior.""" + + def test_first_message(self) -> None: + input_messages = [ + SystemMessage(content="You are a helpful assistant.", id="1"), + HumanMessage(content="Hello, how are you?", id="2"), + ] + middleware = HumanInTheLoopMiddleware(interrupt_on={}) + state_update = middleware.before_agent({"messages": input_messages}, None) + assert state_update is not None + assert len(state_update["messages"]) == 3 + assert state_update["messages"][0].type == "remove" + assert state_update["messages"][1].type == "system" + assert state_update["messages"][1].content == "You are a helpful assistant." + assert state_update["messages"][2].type == "human" + assert state_update["messages"][2].content == "Hello, how are you?" + assert state_update["messages"][2].id == "2" + + def test_missing_tool_call(self) -> None: + input_messages = [ + SystemMessage(content="You are a helpful assistant.", id="1"), + HumanMessage(content="Hello, how are you?", id="2"), + AIMessage( + content="I'm doing well, thank you!", + tool_calls=[ + ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"}) + ], + id="3", + ), + HumanMessage(content="What is the weather in Tokyo?", id="4"), + ] + middleware = HumanInTheLoopMiddleware(interrupt_on={}) + state_update = middleware.before_agent({"messages": input_messages}, None) + assert state_update is not None + assert len(state_update["messages"]) == 6 + assert state_update["messages"][0].type == "remove" + assert state_update["messages"][1] == input_messages[0] + assert state_update["messages"][2] == input_messages[1] + assert state_update["messages"][3] == input_messages[2] + assert state_update["messages"][4].type == "tool" + assert state_update["messages"][4].tool_call_id == "123" + assert state_update["messages"][4].name == "get_events_for_days" + assert state_update["messages"][5] == input_messages[3] + updated_messages = add_messages(input_messages, state_update["messages"]) + assert len(updated_messages) == 5 + assert updated_messages[0] == input_messages[0] + assert updated_messages[1] == input_messages[1] + assert updated_messages[2] == input_messages[2] + assert updated_messages[3].type == "tool" + assert updated_messages[3].tool_call_id == "123" + assert updated_messages[3].name == "get_events_for_days" + assert updated_messages[4] == input_messages[3] + + def test_no_missing_tool_calls(self) -> None: + input_messages = [ + SystemMessage(content="You are a helpful assistant.", id="1"), + HumanMessage(content="Hello, how are you?", id="2"), + AIMessage( + content="I'm doing well, thank you!", + tool_calls=[ + ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"}) + ], + id="3", + ), + ToolMessage(content="I have no events for that date.", tool_call_id="123", id="4"), + HumanMessage(content="What is the weather in Tokyo?", id="5"), + ] + middleware = HumanInTheLoopMiddleware(interrupt_on={}) + state_update = middleware.before_agent({"messages": input_messages}, None) + assert state_update is not None + assert len(state_update["messages"]) == 6 + assert state_update["messages"][0].type == "remove" + assert state_update["messages"][1:] == input_messages + updated_messages = add_messages(input_messages, state_update["messages"]) + assert len(updated_messages) == 5 + assert updated_messages == input_messages + + def test_two_missing_tool_calls(self) -> None: + input_messages = [ + SystemMessage(content="You are a helpful assistant.", id="1"), + HumanMessage(content="Hello, how are you?", id="2"), + AIMessage( + content="I'm doing well, thank you!", + tool_calls=[ + ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"}) + ], + id="3", + ), + HumanMessage(content="What is the weather in Tokyo?", id="4"), + AIMessage( + content="I'm doing well, thank you!", + tool_calls=[ + ToolCall(id="456", name="get_events_for_days", args={"date_str": "2025-01-01"}) + ], + id="5", + ), + HumanMessage(content="What is the weather in Tokyo?", id="6"), + ] + middleware = HumanInTheLoopMiddleware(interrupt_on={}) + state_update = middleware.before_agent({"messages": input_messages}, None) + assert state_update is not None + assert len(state_update["messages"]) == 9 + assert state_update["messages"][0].type == "remove" + assert state_update["messages"][1] == input_messages[0] + assert state_update["messages"][2] == input_messages[1] + assert state_update["messages"][3] == input_messages[2] + assert state_update["messages"][4].type == "tool" + assert state_update["messages"][4].tool_call_id == "123" + assert state_update["messages"][4].name == "get_events_for_days" + assert state_update["messages"][5] == input_messages[3] + assert state_update["messages"][6] == input_messages[4] + assert state_update["messages"][7].type == "tool" + assert state_update["messages"][7].tool_call_id == "456" + assert state_update["messages"][7].name == "get_events_for_days" + assert state_update["messages"][8] == input_messages[5] + updated_messages = add_messages(input_messages, state_update["messages"]) + assert len(updated_messages) == 8 + assert updated_messages[0] == input_messages[0] + assert updated_messages[1] == input_messages[1] + assert updated_messages[2] == input_messages[2] + assert updated_messages[3].type == "tool" + assert updated_messages[3].tool_call_id == "123" + assert updated_messages[3].name == "get_events_for_days" + assert updated_messages[4] == input_messages[3] + assert updated_messages[5] == input_messages[4] + assert updated_messages[6].type == "tool" + assert updated_messages[6].tool_call_id == "456" + assert updated_messages[6].name == "get_events_for_days" + assert updated_messages[7] == input_messages[5]