# See the License for the specific language governing permissions and
# limitations under the License.

+ import re
+ import typing
+ from operator import itemgetter
from unittest import mock

import pytest
from langchain.agents import AgentType
from langchain.agents import Tool
from langchain.agents import initialize_agent
- from langchain.chat_models import ChatOpenAI  # pylint: disable=no-name-in-module
+ from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
+ from langchain.callbacks.manager import CallbackManagerForToolRun
+ from langchain_community.chat_models import ChatOpenAI
+ from langchain_core.tools import BaseTool

from _utils.llm import execute_node
from _utils.llm import mk_mock_langchain_tool
@@ -42,12 +48,16 @@ def test_get_input_names(mock_agent_executor: mock.MagicMock):
    "values,arun_return,expected_output,expected_calls",
    [({
        'prompt': "prompt1"
-     }, list(range(3)), list(range(3)), [mock.call(prompt="prompt1")]),
+     }, list(range(3)), list(range(3)), [mock.call(prompt="prompt1", metadata=None)]),
     ({
         'a': ['b', 'c', 'd'], 'c': ['d', 'e', 'f'], 'e': ['f', 'g', 'h']
     },
      list(range(3)), [list(range(3))] * 3,
-      [mock.call(a='b', c='d', e='f'), mock.call(a='c', c='e', e='g'), mock.call(a='d', c='f', e='h')])],
+      [
+          mock.call(a='b', c='d', e='f', metadata=None),
+          mock.call(a='c', c='e', e='g', metadata=None),
+          mock.call(a='d', c='f', e='h', metadata=None)
+      ])],
    ids=["not-lists", "all-lists"])
def test_execute(
        mock_agent_executor: mock.MagicMock,
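The expected_calls above assume the node forwards each input dictionary to AgentExecutor.arun as keyword arguments, together with an explicit metadata keyword that defaults to None. A rough sketch of that call shape (an illustration only, not the actual LangChainAgentNode implementation; the helper name _arun_single is made up for this example):

async def _arun_single(agent_executor, values: dict, metadata=None):
    # Each parametrized case expects one arun() call per input dict, with metadata always
    # passed through, e.g. mock.call(prompt="prompt1", metadata=None).
    return await agent_executor.arun(metadata=metadata, **values)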
@@ -143,3 +153,114 @@ def test_execute_error(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMoc

    node = LangChainAgentNode(agent_executor=agent)
    assert isinstance(execute_node(node, input="input1"), RuntimeError)
+
+
+ class MetadataSaverTool(BaseTool):
+     # The base class defines *args and **kwargs in the signature for _run and _arun, requiring the arguments-differ
+     # pylint: disable=arguments-differ
+     name: str = "MetadataSaverTool"
+     description: str = "useful for when you need to know the name of a reptile"
+
+     saved_metadata: list[dict] = []
+
+     def _run(
+         self,
+         query: str,
+         run_manager: typing.Optional[CallbackManagerForToolRun] = None,
+     ) -> str:
+         raise NotImplementedError("This tool only supports async")
+
+     async def _arun(
+         self,
+         query: str,
+         run_manager: typing.Optional[AsyncCallbackManagerForToolRun] = None,
+     ) -> str:
+         assert query is not None  # avoiding unused-argument
+         assert run_manager is not None
+         self.saved_metadata.append(run_manager.metadata.copy())
+         return "frog"
+
+
+ @pytest.mark.parametrize("metadata",
+                          [{
+                              "morpheus": "unittest"
+                          }, {
+                              "morpheus": ["unittest"]
+                          }, {
+                              "morpheus": [f"unittest_{i}" for i in range(3)]
+                          }],
+                          ids=["single-metadata", "single-metadata-list", "multiple-metadata-list"])
+ def test_metadata(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], metadata: dict):
+     if isinstance(metadata['morpheus'], list):
+         num_meta = len(metadata['morpheus'])
+         input_data = [f"input_{i}" for i in range(num_meta)]
+         expected_result = [f"{input_val}: Yes!" for input_val in input_data]
+         expected_saved_metadata = [{"morpheus": meta} for meta in metadata['morpheus']]
+         response_per_input_counter = {input_val: 0 for input_val in input_data}
+     else:
+         num_meta = 1
+         input_data = "input_0"
+         expected_result = "input_0: Yes!"
+         expected_saved_metadata = [metadata.copy()]
+         response_per_input_counter = {input_data: 0}
+
+     check_tool_response = 'I should check Tool1\nAction: MetadataSaverTool\nAction Input: "name a reptile"'
+     final_response = 'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: {}: Yes!'
+
+     # Tests the execute method of the LangChainAgentNode with mocked tools and a mocked chat completion
+     (_, mock_async_client) = mock_chat_completion
+
+     # Regex to find the actual prompt in the input, which includes the ReAct and tool description boilerplate
+     input_re = re.compile(r'^Question: (input_\d+)$', re.MULTILINE)
+
+     def mock_llm_chat(*_, messages, **__):
+         """
+         This method avoids a race condition when running in async mode over multiple inputs, ensuring that the final
+         response is only given for an input after the initial check-tool response.
+         """
+
+         query = None
+         for msg in messages:
+             if msg['role'] == 'user':
+                 query = msg['content']
+
+         assert query is not None
+
+         match = input_re.search(query)
+         assert match is not None
+
+         input_key = match.group(1)
+
+         call_count = response_per_input_counter[input_key]
+
+         if call_count == 0:
+             response = check_tool_response
+         else:
+             response = final_response.format(input_key)
+
+         response_per_input_counter[input_key] += 1
+
+         return mk_mock_openai_response([response])
+
+     mock_async_client.chat.completions.create.side_effect = mock_llm_chat
+
+     llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key")
+
+     metadata_saver_tool = MetadataSaverTool()
+
+     tools = [metadata_saver_tool]
+
+     agent = initialize_agent(tools,
+                              llm_chat,
+                              agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+                              verbose=True,
+                              handle_parsing_errors=True,
+                              early_stopping_method="generate",
+                              return_intermediate_steps=False)
+
+     node = LangChainAgentNode(agent_executor=agent)
+
+     assert execute_node(node, input=input_data, metadata=metadata) == expected_result
+
+     # Since we are running in async mode, we will need to sort the saved metadata
+     assert sorted(metadata_saver_tool.saved_metadata, key=itemgetter('morpheus')) == expected_saved_metadata
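A note on the plumbing the new test exercises: LangChain's BaseTool.run/arun accept a metadata keyword, and that metadata is surfaced to the tool through run_manager.metadata, which is what MetadataSaverTool captures. A minimal standalone sketch (not part of this change, and assuming current langchain_core behaviour):

import asyncio

async def _metadata_roundtrip():
    tool = MetadataSaverTool()
    # arun() forwards the metadata to the tool's callback run manager, so _arun() can read it back.
    await tool.arun("name a reptile", metadata={"morpheus": "demo"})
    assert tool.saved_metadata[-1] == {"morpheus": "demo"}

asyncio.run(_metadata_roundtrip())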