Commit bb51e61
Allow passing metadata to LangChainAgentNode._run_single (nv-morpheus#1710)
* Allows passing arbitrary `metadata` into the agent.
* Updates a few imports to lower the number of deprecation warnings.

Closes nv-morpheus#1706

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
- David Gardner (https://github.com/dagardner-nv)

Approvers:
- Michael Demoret (https://github.com/mdemoret-nv)

URL: nv-morpheus#1710
1 parent 580be43 commit bb51e61
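For context, a minimal sketch of the new call shape (illustrative only; `_run_single` is normally invoked by the LLM engine rather than called directly, and `agent_executor` is assumed to be a pre-built LangChain `AgentExecutor`):

```python
from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode


async def run_with_metadata(agent_executor):
    # `agent_executor` is assumed to be an existing LangChain AgentExecutor.
    node = LangChainAgentNode(agent_executor=agent_executor)

    # Single input: the metadata dict is forwarded as-is to AgentExecutor.arun(),
    # which exposes it to tools via their callback run manager.
    await node._run_single(input="name a reptile", metadata={"morpheus": "unittest"})

    # Batched inputs: metadata lists matching the input length are split up,
    # one dict per input; anything else is shared across all inputs.
    await node._run_single(input=["q0", "q1", "q2"],
                           metadata={"morpheus": [f"unittest_{i}" for i in range(3)]})
```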

File tree

4 files changed (+153, -14 lines):

- morpheus/llm/nodes/langchain_agent_node.py
- tests/llm/nodes/test_langchain_agent_node.py
- tests/llm/nodes/test_langchain_agent_node_pipe.py
- tests/llm/test_agents_simple_pipe.py


morpheus/llm/nodes/langchain_agent_node.py (+22, -5)

```diff
@@ -45,18 +45,35 @@ def __init__(self, agent_executor: "AgentExecutor"):
     def get_input_names(self):
         return self._input_names
 
-    async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
+    @staticmethod
+    def _is_all_lists(data: dict[str, typing.Any]) -> bool:
+        return all(isinstance(v, list) for v in data.values())
 
-        all_lists = all(isinstance(v, list) for v in kwargs.values())
+    @staticmethod
+    def _transform_dict_of_lists(data: dict[str, typing.Any]) -> list[dict[str, typing.Any]]:
+        return [dict(zip(data, t)) for t in zip(*data.values())]
+
+    async def _run_single(self, metadata: dict[str, typing.Any] = None, **kwargs) -> dict[str, typing.Any]:
+
+        all_lists = self._is_all_lists(kwargs)
 
         # Check if all values are a list
         if all_lists:
 
             # Transform from dict[str, list[Any]] to list[dict[str, Any]]
-            input_list = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]
+            input_list = self._transform_dict_of_lists(kwargs)
+
+            # If all metadata values are lists of the same length and the same length as the input list
+            # then transform them the same way as the input list
+            if (metadata is not None and self._is_all_lists(metadata)
+                    and all(len(v) == len(input_list) for v in metadata.values())):
+                metadata_list = self._transform_dict_of_lists(metadata)
+
+            else:
+                metadata_list = [metadata] * len(input_list)
 
             # Run multiple again
-            results_async = [self._run_single(**x) for x in input_list]
+            results_async = [self._run_single(metadata=metadata_list[i], **x) for (i, x) in enumerate(input_list)]
 
             results = await asyncio.gather(*results_async, return_exceptions=True)
 
@@ -67,7 +84,7 @@ async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing
 
         # We are not dealing with a list, so run single
         try:
-            return await self._agent_executor.arun(**kwargs)
+            return await self._agent_executor.arun(metadata=metadata, **kwargs)
         except Exception as e:
             logger.exception("Error running agent: %s", e)
             return e
```
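To make the broadcast rule concrete, here is a small worked example of the two helpers above (plain Python mirroring `_is_all_lists` and `_transform_dict_of_lists`; the values are made up):

```python
# Batched inputs arrive as a dict of equal-length lists.
kwargs = {'a': ['b', 'c', 'd'], 'c': ['d', 'e', 'f']}

input_list = [dict(zip(kwargs, t)) for t in zip(*kwargs.values())]
# -> [{'a': 'b', 'c': 'd'}, {'a': 'c', 'c': 'e'}, {'a': 'd', 'c': 'f'}]

# Metadata whose values are all lists of the same length as input_list is
# split the same way, yielding one metadata dict per input:
metadata = {"morpheus": ["m0", "m1", "m2"]}
metadata_list = [dict(zip(metadata, t)) for t in zip(*metadata.values())]
# -> [{'morpheus': 'm0'}, {'morpheus': 'm1'}, {'morpheus': 'm2'}]

# Any other metadata (e.g. scalar values, or mismatched lengths) is shared
# across every input:
metadata = {"morpheus": "unittest"}
metadata_list = [metadata] * len(input_list)
# -> the same dict, repeated once per input
```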

tests/llm/nodes/test_langchain_agent_node.py (+124, -3)

```diff
@@ -13,13 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
+import typing
+from operator import itemgetter
 from unittest import mock
 
 import pytest
 from langchain.agents import AgentType
 from langchain.agents import Tool
 from langchain.agents import initialize_agent
-from langchain.chat_models import ChatOpenAI  # pylint: disable=no-name-in-module
+from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
+from langchain.callbacks.manager import CallbackManagerForToolRun
+from langchain_community.chat_models import ChatOpenAI
+from langchain_core.tools import BaseTool
 
 from _utils.llm import execute_node
 from _utils.llm import mk_mock_langchain_tool
@@ -42,12 +48,16 @@ def test_get_input_names(mock_agent_executor: mock.MagicMock):
     "values,arun_return,expected_output,expected_calls",
     [({
         'prompt': "prompt1"
-    }, list(range(3)), list(range(3)), [mock.call(prompt="prompt1")]),
+    }, list(range(3)), list(range(3)), [mock.call(prompt="prompt1", metadata=None)]),
     ({
         'a': ['b', 'c', 'd'], 'c': ['d', 'e', 'f'], 'e': ['f', 'g', 'h']
     },
      list(range(3)), [list(range(3))] * 3,
-     [mock.call(a='b', c='d', e='f'), mock.call(a='c', c='e', e='g'), mock.call(a='d', c='f', e='h')])],
+     [
+         mock.call(a='b', c='d', e='f', metadata=None),
+         mock.call(a='c', c='e', e='g', metadata=None),
+         mock.call(a='d', c='f', e='h', metadata=None)
+     ])],
     ids=["not-lists", "all-lists"])
 def test_execute(
     mock_agent_executor: mock.MagicMock,
@@ -143,3 +153,114 @@ def test_execute_error(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMoc
 
     node = LangChainAgentNode(agent_executor=agent)
     assert isinstance(execute_node(node, input="input1"), RuntimeError)
+
+
+class MetadataSaverTool(BaseTool):
+    # The base class defines *args and **kwargs in the signature for _run and _arun requiring the arguments-differ
+    # pylint: disable=arguments-differ
+    name: str = "MetadataSaverTool"
+    description: str = "useful for when you need to know the name of a reptile"
+
+    saved_metadata: list[dict] = []
+
+    def _run(
+        self,
+        query: str,
+        run_manager: typing.Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        raise NotImplementedError("This tool only supports async")
+
+    async def _arun(
+        self,
+        query: str,
+        run_manager: typing.Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        assert query is not None  # avoiding unused-argument
+        assert run_manager is not None
+        self.saved_metadata.append(run_manager.metadata.copy())
+        return "frog"
+
+
+@pytest.mark.parametrize("metadata",
+                         [{
+                             "morpheus": "unittest"
+                         }, {
+                             "morpheus": ["unittest"]
+                         }, {
+                             "morpheus": [f"unittest_{i}" for i in range(3)]
+                         }],
+                         ids=["single-metadata", "single-metadata-list", "multiple-metadata-list"])
+def test_metadata(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], metadata: dict):
+    if isinstance(metadata['morpheus'], list):
+        num_meta = len(metadata['morpheus'])
+        input_data = [f"input_{i}" for i in range(num_meta)]
+        expected_result = [f"{input_val}: Yes!" for input_val in input_data]
+        expected_saved_metadata = [{"morpheus": meta} for meta in metadata['morpheus']]
+        response_per_input_counter = {input_val: 0 for input_val in input_data}
+    else:
+        num_meta = 1
+        input_data = "input_0"
+        expected_result = "input_0: Yes!"
+        expected_saved_metadata = [metadata.copy()]
+        response_per_input_counter = {input_data: 0}
+
+    check_tool_response = 'I should check Tool1\nAction: MetadataSaverTool\nAction Input: "name a reptile"'
+    final_response = 'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: {}: Yes!'
+
+    # Tests the execute method of the LangChainAgentNode with a mocked tool and chat completion
+    (_, mock_async_client) = mock_chat_completion
+
+    # Regex to find the actual prompt from the input which includes the REACT and tool description boilerplate
+    input_re = re.compile(r'^Question: (input_\d+)$', re.MULTILINE)
+
+    def mock_llm_chat(*_, messages, **__):
+        """
+        This method avoids a race condition when running in async mode over multiple inputs, ensuring that the
+        final response is only given for an input after the initial check-tool response.
+        """
+
+        query = None
+        for msg in messages:
+            if msg['role'] == 'user':
+                query = msg['content']
+
+        assert query is not None
+
+        match = input_re.search(query)
+        assert match is not None
+
+        input_key = match.group(1)
+
+        call_count = response_per_input_counter[input_key]
+
+        if call_count == 0:
+            response = check_tool_response
+        else:
+            response = final_response.format(input_key)
+
+        response_per_input_counter[input_key] += 1
+
+        return mk_mock_openai_response([response])
+
+    mock_async_client.chat.completions.create.side_effect = mock_llm_chat
+
+    llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key")
+
+    metadata_saver_tool = MetadataSaverTool()
+
+    tools = [metadata_saver_tool]
+
+    agent = initialize_agent(tools,
+                             llm_chat,
+                             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+                             verbose=True,
+                             handle_parsing_errors=True,
+                             early_stopping_method="generate",
+                             return_intermediate_steps=False)
+
+    node = LangChainAgentNode(agent_executor=agent)
+
+    assert execute_node(node, input=input_data, metadata=metadata) == expected_result
+
+    # Since we are running in async mode, we will need to sort saved metadata
+    assert sorted(metadata_saver_tool.saved_metadata, key=itemgetter('morpheus')) == expected_saved_metadata
```
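For the `multiple-metadata-list` case, the expected values the test builds work out as follows (a worked expansion of the setup above, not new behavior):

```python
metadata = {"morpheus": [f"unittest_{i}" for i in range(3)]}

input_data = ["input_0", "input_1", "input_2"]
expected_result = ["input_0: Yes!", "input_1: Yes!", "input_2: Yes!"]
expected_saved_metadata = [{"morpheus": "unittest_0"},
                           {"morpheus": "unittest_1"},
                           {"morpheus": "unittest_2"}]

# MetadataSaverTool appends in completion order, which is nondeterministic when
# the inputs run concurrently, hence the sorted(...) comparison in the test.
```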

tests/llm/nodes/test_langchain_agent_node_pipe.py (+1, -1)

```diff
@@ -45,7 +45,7 @@ def test_pipeline(config: Config, dataset_cudf: DatasetManager, mock_agent_execu
 
     mock_agent_executor.arun.return_value = 'frogs'
     expected_df['response'] = 'frogs'
-    expected_calls = [mock.call(prompt=x) for x in expected_df['v3'].values_host]
+    expected_calls = [mock.call(prompt=x, metadata=None) for x in expected_df['v3'].values_host]
 
     task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": ['v3']}}
 
```

tests/llm/test_agents_simple_pipe.py (+6, -5)

```diff
@@ -17,13 +17,13 @@
 import re
 from unittest import mock
 
-import langchain
 import pytest
 from langchain.agents import AgentType
 from langchain.agents import initialize_agent
 from langchain.agents import load_tools
 from langchain.agents.tools import Tool
-from langchain.utilities import serpapi
+from langchain_community.llms import OpenAI  # pylint: disable=no-name-in-module
+from langchain_community.utilities import serpapi
 
 import cudf
 
@@ -50,7 +50,7 @@ def questions_fixture():
 
 def _build_agent_executor(model_name: str):
 
-    llm = langchain.OpenAI(model=model_name, temperature=0, cache=False)
+    llm = OpenAI(model=model_name, temperature=0, cache=False)
 
     # Explicitly construct the serpapi tool, loading it via load_tools makes it too difficult to mock
     tools = [
@@ -125,8 +125,9 @@ def test_agents_simple_pipe_integration_openai(config: Config, questions: list[s
 
 
 @pytest.mark.usefixtures("openai", "restore_environ")
-@mock.patch("langchain.utilities.serpapi.SerpAPIWrapper.aresults")
-@mock.patch("langchain.OpenAI._agenerate", autospec=True)  # autospec is needed as langchain will inspect the function
+@mock.patch("langchain_community.utilities.serpapi.SerpAPIWrapper.aresults")
+@mock.patch("langchain_community.llms.OpenAI._agenerate",
+            autospec=True)  # autospec is needed as langchain will inspect the function
 def test_agents_simple_pipe(mock_openai_agenerate: mock.AsyncMock,
                             mock_serpapi_aresults: mock.AsyncMock,
                             config: Config,
```