From 4ef29f262028b311c156cef79127dafcbd16fe6d Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Fri, 25 Apr 2025 16:31:05 -0700 Subject: [PATCH 1/7] feat: adding UTs for plan execute reflect agent Signed-off-by: Pavan Yekbote --- ...LPlanExecuteAndReflectAgentRunnerTest.java | 368 ++++++++++++++++++ 1 file changed, 368 insertions(+) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java new file mode 100644 index 0000000000..05b45deb43 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -0,0 +1,368 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.StepListener; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.memory.ConversationIndexMemory; +import org.opensearch.ml.engine.memory.MLMemoryManager; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.client.Client; + +public class MLPlanExecuteAndReflectAgentRunnerTest { + public static final String FIRST_TOOL = "firstTool"; + public static final String SECOND_TOOL = "secondTool"; + + @Mock + private Client client; + private Settings settings; + @Mock + private ClusterService clusterService; + @Mock + private NamedXContentRegistry xContentRegistry; + private Map toolFactories; + @Mock + private Map memoryMap; + private MLPlanExecuteAndReflectAgentRunner mlPlanExecuteAndReflectAgentRunner; + @Mock + private Tool.Factory firstToolFactory; + @Mock + private Tool.Factory secondToolFactory; + @Mock + private Tool firstTool; + @Mock + private Tool secondTool; + @Mock + private ActionListener agentActionListener; + @Mock + private ConversationIndexMemory conversationIndexMemory; + @Mock + private MLMemoryManager mlMemoryManager; + @Mock + private CreateInteractionResponse createInteractionResponse; + @Mock + private ConversationIndexMemory.Factory memoryFactory; + @Mock + private SdkClient sdkClient; + @Mock + private Encryptor encryptor; + @Mock + private UpdateResponse updateResponse; + @Mock + private MLExecuteTaskResponse mlExecuteTaskResponse; + @Mock + private MLTaskResponse mlTaskResponse; + + @Captor + private ArgumentCaptor objectCaptor; + @Captor + private ArgumentCaptor> memoryFactoryCapture; + @Captor + private ArgumentCaptor>> memoryInteractionCapture; + @Captor + private ArgumentCaptor> toolParamsCapture; + + private MLMemorySpec mlMemorySpec; + + @Before + @SuppressWarnings("unchecked") + public void setup() { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); + toolFactories = ImmutableMap.of(FIRST_TOOL, firstToolFactory, SECOND_TOOL, secondToolFactory); + + // Setup memory + mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); + when(memoryMap.get(anyString())).thenReturn(memoryFactory); + when(conversationIndexMemory.getConversationId()).thenReturn("conversation_id"); + when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); + when(createInteractionResponse.getId()).thenReturn("create_interaction_id"); + when(updateResponse.getId()).thenReturn("update_interaction_id"); + + // Setup memory factory + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(conversationIndexMemory); + return null; + }).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture()); + + // Setup conversation index memory + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(generateInteractions(2)); + return null; + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), anyInt()); + + // Setup memory manager + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onResponse(createInteractionResponse); + return null; + }).when(conversationIndexMemory).save(any(), any(), any(), any(), any()); + + mlPlanExecuteAndReflectAgentRunner = new MLPlanExecuteAndReflectAgentRunner( + client, + settings, + clusterService, + xContentRegistry, + toolFactories, + memoryMap, + sdkClient, + encryptor + ); + + // Setup tools + when(firstToolFactory.create(any())).thenReturn(firstTool); + when(secondToolFactory.create(any())).thenReturn(secondTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.getDescription()).thenReturn("First tool description"); + when(secondTool.getName()).thenReturn(SECOND_TOOL); + when(secondTool.getDescription()).thenReturn("Second tool description"); + when(firstTool.validate(any())).thenReturn(true); + when(secondTool.validate(any())).thenReturn(true); + } + + @Test + public void testBasicExecution() { + // Create MLAgent with tools and parameters + Map agentParams = new HashMap<>(); + agentParams.put("system_prompt", "You are a helpful assistant"); + agentParams.put("max_steps", "10"); + + MLAgent mlAgent = createMLAgentWithTools(agentParams); + + // Setup LLM response for planning phase + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"final result\"}")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + // Setup tool execution response + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ModelTensor modelTensor = ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "tool execution result")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlExecuteTaskResponse); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any()); + + // Setup memory manager update response + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + // Run the agent + Map params = new HashMap<>(); + params.put("question", "test question"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify the response + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + assertNotNull(modelTensorOutput); + } + + @Test + public void testExecutionWithHistory() { + // Create MLAgent with tools and parameters + Map agentParams = new HashMap<>(); + agentParams.put("system_prompt", "You are a helpful assistant"); + agentParams.put("max_steps", "10"); + + MLAgent mlAgent = createMLAgentWithTools(agentParams); + + // Setup LLM response for planning phase + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"final result\"}")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + // Setup tool execution response + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ModelTensor modelTensor = ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "tool execution result")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlExecuteTaskResponse); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any()); + + // Setup memory manager update response + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + // Run the agent with history + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("memory_id", "test_memory_id"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify the response + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + assertNotNull(modelTensorOutput); + } + + @Test + public void testExecutionWithMaxSteps() { + // Create MLAgent with tools and parameters + Map agentParams = new HashMap<>(); + agentParams.put("system_prompt", "You are a helpful assistant"); + agentParams.put("max_steps", "10"); + + MLAgent mlAgent = createMLAgentWithTools(agentParams); + + // Setup LLM response for planning phase + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\", \"step2\", \"step3\"], \"result\":\"\"}")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + // Setup tool execution response + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "tool execution result")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlExecuteTaskResponse); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any()); + + // Setup memory manager update response + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + // Run the agent with max steps + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("max_steps", "2"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify the response + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + assertNotNull(modelTensorOutput); + } + + // Helper methods + private MLAgent createMLAgentWithTools(Map parameters) { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .build(); + return MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .tools(Arrays.asList(firstToolSpec)) + .memory(mlMemorySpec) + .llm(llmSpec) + .parameters(parameters) + .build(); + } + + private List generateInteractions(int size) { + return Arrays.asList( + Interaction.builder().id("interaction-1").input("input-1").response("response-1").build(), + Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() + ); + } +} \ No newline at end of file From 305a7cd870187f547235775373084a372de1024f Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Tue, 29 Apr 2025 14:29:26 -0700 Subject: [PATCH 2/7] wip Signed-off-by: Pavan Yekbote --- ...LPlanExecuteAndReflectAgentRunnerTest.java | 107 +++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 05b45deb43..614ac443ec 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -365,4 +365,109 @@ private List generateInteractions(int size) { Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() ); } -} \ No newline at end of file + + /** + * Test the run method with an unsupported LLM interface. + * This test verifies that the method throws an MLException when an unsupported LLM interface is provided. + */ + @Test(expected = org.opensearch.ml.common.exception.MLException.class) + public void testRunWithUnsupportedLLMInterface() { + Map parameters = new HashMap<>(); + parameters.put("llm_interface", "unsupported_interface"); + MLAgent mlAgent = createMLAgentWithTools(parameters); + + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); + } + + /** + * Testcase 1 for @Override public void run(MLAgent mlAgent, Map apiParams, ActionListener listener) + * Path constraints: (Strings.isNullOrEmpty(response)), (!completedSteps.isEmpty()) + */ + @Test + public void test_run_1() { + // Create MLAgent with necessary parameters + Map parameters = new HashMap<>(); + parameters.put("memory_id", "test_memory_id"); + parameters.put("question", "test_question"); + MLAgent mlAgent = createMLAgentWithTools(parameters); + + // Mock the behavior to satisfy the path constraints + when(conversationIndexMemory.getMessages(any(), anyInt())).thenAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + List interactions = Arrays.asList( + Interaction.builder().id("interaction-1").input("input-1").response("").build(), + Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() + ); + listener.onResponse(interactions); + return null; + }); + + // Execute the method under test + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, parameters, agentActionListener); + + // Verify that the memory factory was called with the correct parameters + verify(memoryFactory).create(eq("test_question"), eq("test_memory_id"), any(), any()); + + // Verify that the conversation index memory's getMessages method was called + verify(conversationIndexMemory).getMessages(any(), eq(10)); + + // Additional verifications can be added here based on the expected behavior + // For example, you might want to verify that certain methods were called on the client + // or that the agentActionListener was invoked with the expected result + } + + /** + * Testcase 2 for @Override public void run(MLAgent mlAgent, Map apiParams, ActionListener listener) + * Path constraints: !((Strings.isNullOrEmpty(response))), (!completedSteps.isEmpty()) + */ + @Test + public void test_run_2() { + // Setup + MLAgent mlAgent = createMLAgentWithTools(new HashMap<>()); + Map apiParams = new HashMap<>(); + apiParams.put("question", "Test question"); + + // Mock behavior to satisfy path constraints + when(conversationIndexMemory.getMessages(any(), anyInt())).thenAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Arrays.asList( + Interaction.builder().id("interaction-1").input("input-1").response("response-1").build(), + Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() + )); + return null; + }); + + // Execute + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, apiParams, agentActionListener); + + // Verify + verify(conversationIndexMemory).getMessages(any(), eq(10)); + // Add more verifications as needed to ensure the correct path is taken + } + + /** + * Testcase 3 for @Override public void run(MLAgent mlAgent, Map apiParams, ActionListener listener) + * Path constraints: (Strings.isNullOrEmpty(response)), !((!completedSteps.isEmpty())) + */ + @Test + public void test_run_3() { + // Setup + MLAgent mlAgent = createMLAgentWithTools(new HashMap<>()); + Map apiParams = new HashMap<>(); + apiParams.put("QUESTION_FIELD", "Test question"); + + // Mock conversation index memory to return empty interactions + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Arrays.asList()); + return null; + }).when(conversationIndexMemory).getMessages(any(), anyInt()); + + // Execute + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, apiParams, agentActionListener); + + // Verify + verify(memoryFactory).create(any(), any(), any(), any()); + verify(conversationIndexMemory).getMessages(any(), eq(10)); + } +} From 7cf85faef84dd2903f2361b0e72ed8bd4996890e Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Wed, 4 Jun 2025 13:52:46 -0700 Subject: [PATCH 3/7] feat: add more test cases for per agent Signed-off-by: Pavan Yekbote --- .../MLPlanExecuteAndReflectAgentRunner.java | 34 +- ...LPlanExecuteAndReflectAgentRunnerTest.java | 439 ++++++++++++------ 2 files changed, 324 insertions(+), 149 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index ef12c385c7..c43bbfb08f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -66,6 +66,7 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.client.Client; +import com.google.common.annotations.VisibleForTesting; import com.jayway.jsonpath.JsonPath; import joptsimple.internal.Strings; @@ -154,7 +155,8 @@ public MLPlanExecuteAndReflectAgentRunner( this.plannerWithHistoryPromptTemplate = DEFAULT_PLANNER_WITH_HISTORY_PROMPT_TEMPLATE; } - private void setupPromptParameters(Map params) { + @VisibleForTesting + void setupPromptParameters(Map params) { // populated depending on whether LLM is asked to plan or re-evaluate // removed here, so that error is thrown in case this field is not populated params.remove(PROMPT_FIELD); @@ -203,22 +205,26 @@ private void setupPromptParameters(Map params) { } } - private void usePlannerPromptTemplate(Map params) { + @VisibleForTesting + void usePlannerPromptTemplate(Map params) { params.put(PROMPT_TEMPLATE_FIELD, this.plannerPromptTemplate); populatePrompt(params); } - private void useReflectPromptTemplate(Map params) { + @VisibleForTesting + void useReflectPromptTemplate(Map params) { params.put(PROMPT_TEMPLATE_FIELD, this.reflectPromptTemplate); populatePrompt(params); } - private void usePlannerWithHistoryPromptTemplate(Map params) { + @VisibleForTesting + void usePlannerWithHistoryPromptTemplate(Map params) { params.put(PROMPT_TEMPLATE_FIELD, this.plannerWithHistoryPromptTemplate); populatePrompt(params); } - private void populatePrompt(Map allParams) { + @VisibleForTesting + void populatePrompt(Map allParams) { String promptTemplate = allParams.get(PROMPT_TEMPLATE_FIELD); StringSubstitutor promptSubstitutor = new StringSubstitutor(allParams, "${parameters.", "}"); String prompt = promptSubstitutor.replace(promptTemplate); @@ -475,7 +481,8 @@ private void executePlanningLoop( client.execute(MLPredictionTaskAction.INSTANCE, request, planListener); } - private Map parseLLMOutput(Map allParams, ModelTensorOutput modelTensorOutput) { + @VisibleForTesting + Map parseLLMOutput(Map allParams, ModelTensorOutput modelTensorOutput) { Map modelOutput = new HashMap<>(); Map dataAsMap = modelTensorOutput.getMlModelOutputs().getFirst().getMlModelTensors().getFirst().getDataAsMap(); String llmResponse; @@ -513,7 +520,8 @@ private Map parseLLMOutput(Map allParams, ModelT return modelOutput; } - private String extractJsonFromMarkdown(String response) { + @VisibleForTesting + String extractJsonFromMarkdown(String response) { response = response.trim(); if (response.contains("```json")) { response = response.substring(response.indexOf("```json") + "```json".length()); @@ -530,7 +538,8 @@ private String extractJsonFromMarkdown(String response) { return response; } - private void addToolsToPrompt(Map tools, Map allParams) { + @VisibleForTesting + void addToolsToPrompt(Map tools, Map allParams) { StringBuilder toolsPrompt = new StringBuilder("In this environment, you have access to the below tools: \n"); for (Map.Entry entry : tools.entrySet()) { String toolName = entry.getKey(); @@ -543,11 +552,13 @@ private void addToolsToPrompt(Map tools, Map allPa cleanUpResource(tools); } - private void addSteps(List steps, Map allParams, String field) { + @VisibleForTesting + void addSteps(List steps, Map allParams, String field) { allParams.put(field, String.join(", ", steps)); } - private void saveAndReturnFinalResult( + @VisibleForTesting + void saveAndReturnFinalResult( ConversationIndexMemory memory, String parentInteractionId, String reactAgentMemoryId, @@ -586,7 +597,8 @@ private void saveAndReturnFinalResult( })); } - private static List createModelTensors( + @VisibleForTesting + static List createModelTensors( String sessionId, String parentInteractionId, String reactAgentMemoryId, diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 614ac443ec..c3c76f4972 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -7,17 +7,20 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,7 +31,6 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.action.StepListener; import org.opensearch.action.update.UpdateResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -40,6 +42,7 @@ import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -55,10 +58,11 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.client.Client; +import com.google.common.collect.ImmutableMap; + public class MLPlanExecuteAndReflectAgentRunnerTest { public static final String FIRST_TOOL = "firstTool"; public static final String SECOND_TOOL = "secondTool"; @@ -99,18 +103,18 @@ public class MLPlanExecuteAndReflectAgentRunnerTest { @Mock private UpdateResponse updateResponse; @Mock - private MLExecuteTaskResponse mlExecuteTaskResponse; - @Mock private MLTaskResponse mlTaskResponse; + @Mock + private MLExecuteTaskResponse mlExecuteTaskResponse; @Captor private ArgumentCaptor objectCaptor; + @Captor private ArgumentCaptor> memoryFactoryCapture; + @Captor private ArgumentCaptor>> memoryInteractionCapture; - @Captor - private ArgumentCaptor> toolParamsCapture; private MLMemorySpec mlMemorySpec; @@ -121,15 +125,15 @@ public void setup() { settings = Settings.builder().build(); toolFactories = ImmutableMap.of(FIRST_TOOL, firstToolFactory, SECOND_TOOL, secondToolFactory); - // Setup memory + // memory mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); when(memoryMap.get(anyString())).thenReturn(memoryFactory); - when(conversationIndexMemory.getConversationId()).thenReturn("conversation_id"); + when(conversationIndexMemory.getConversationId()).thenReturn("test_memory_id"); when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); when(createInteractionResponse.getId()).thenReturn("create_interaction_id"); when(updateResponse.getId()).thenReturn("update_interaction_id"); - // Setup memory factory + // memory factory doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(conversationIndexMemory); @@ -139,7 +143,7 @@ public void setup() { // Setup conversation index memory doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(0); - listener.onResponse(generateInteractions(2)); + listener.onResponse(generateInteractions()); return null; }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), anyInt()); @@ -174,17 +178,13 @@ public void setup() { @Test public void testBasicExecution() { - // Create MLAgent with tools and parameters - Map agentParams = new HashMap<>(); - agentParams.put("system_prompt", "You are a helpful assistant"); - agentParams.put("max_steps", "10"); - - MLAgent mlAgent = createMLAgentWithTools(agentParams); + MLAgent mlAgent = createMLAgentWithTools(); // Setup LLM response for planning phase doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); - ModelTensor modelTensor = ModelTensor.builder() + ModelTensor modelTensor = ModelTensor + .builder() .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"final result\"}")) .build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); @@ -197,9 +197,7 @@ public void testBasicExecution() { // Setup tool execution response doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - ModelTensor modelTensor = ModelTensor.builder() - .dataAsMap(ImmutableMap.of("response", "tool execution result")) - .build(); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "tool execution result")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); @@ -217,6 +215,7 @@ public void testBasicExecution() { // Run the agent Map params = new HashMap<>(); params.put("question", "test question"); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_test_id"); mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); // Verify the response @@ -224,22 +223,40 @@ public void testBasicExecution() { Object response = objectCaptor.getValue(); assertTrue(response instanceof ModelTensorOutput); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; - assertNotNull(modelTensorOutput); + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + ModelTensors firstModelTensors = mlModelOutputs.get(0); + List firstModelTensorList = firstModelTensors.getMlModelTensors(); + assertEquals(2, firstModelTensorList.size()); + + ModelTensor memoryIdTensor = firstModelTensorList.get(0); + assertEquals("memory_id", memoryIdTensor.getName()); + assertEquals("test_memory_id", memoryIdTensor.getResult()); + + ModelTensor parentInteractionModelTensor = firstModelTensorList.get(1); + assertEquals("parent_interaction_id", parentInteractionModelTensor.getName()); + assertEquals("test_parent_interaction_id", parentInteractionModelTensor.getResult()); + + ModelTensors secondModelTensors = mlModelOutputs.get(1); + List secondModelTensorList = secondModelTensors.getMlModelTensors(); + assertEquals(1, secondModelTensorList.size()); + + ModelTensor responseTensor = secondModelTensorList.get(0); + assertEquals("response", responseTensor.getName()); + assertEquals("final result", responseTensor.getDataAsMap().get("response")); } @Test public void testExecutionWithHistory() { - // Create MLAgent with tools and parameters - Map agentParams = new HashMap<>(); - agentParams.put("system_prompt", "You are a helpful assistant"); - agentParams.put("max_steps", "10"); - - MLAgent mlAgent = createMLAgentWithTools(agentParams); + MLAgent mlAgent = createMLAgentWithTools(); // Setup LLM response for planning phase doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); - ModelTensor modelTensor = ModelTensor.builder() + ModelTensor modelTensor = ModelTensor + .builder() .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"final result\"}")) .build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); @@ -252,9 +269,7 @@ public void testExecutionWithHistory() { // Setup tool execution response doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - ModelTensor modelTensor = ModelTensor.builder() - .dataAsMap(ImmutableMap.of("response", "tool execution result")) - .build(); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "tool execution result")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); @@ -273,6 +288,7 @@ public void testExecutionWithHistory() { Map params = new HashMap<>(); params.put("question", "test question"); params.put("memory_id", "test_memory_id"); + params.put("parent_interaction_id", "test_parent_interaction_id"); mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); // Verify the response @@ -280,22 +296,39 @@ public void testExecutionWithHistory() { Object response = objectCaptor.getValue(); assertTrue(response instanceof ModelTensorOutput); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; - assertNotNull(modelTensorOutput); + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + ModelTensors firstModelTensors = mlModelOutputs.get(0); + List firstModelTensorList = firstModelTensors.getMlModelTensors(); + assertEquals(2, firstModelTensorList.size()); + + ModelTensor memoryIdTensor = firstModelTensorList.get(0); + assertEquals("memory_id", memoryIdTensor.getName()); + assertEquals("test_memory_id", memoryIdTensor.getResult()); + + ModelTensor parentInteractionModelTensor = firstModelTensorList.get(1); + assertEquals("parent_interaction_id", parentInteractionModelTensor.getName()); + assertEquals("test_parent_interaction_id", parentInteractionModelTensor.getResult()); + + ModelTensors secondModelTensors = mlModelOutputs.get(1); + List secondModelTensorList = secondModelTensors.getMlModelTensors(); + assertEquals(1, secondModelTensorList.size()); + + ModelTensor responseTensor = secondModelTensorList.get(0); + assertEquals("response", responseTensor.getName()); + assertEquals("final result", responseTensor.getDataAsMap().get("response")); } @Test public void testExecutionWithMaxSteps() { - // Create MLAgent with tools and parameters - Map agentParams = new HashMap<>(); - agentParams.put("system_prompt", "You are a helpful assistant"); - agentParams.put("max_steps", "10"); - - MLAgent mlAgent = createMLAgentWithTools(agentParams); + MLAgent mlAgent = createMLAgentWithTools(); - // Setup LLM response for planning phase doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); - ModelTensor modelTensor = ModelTensor.builder() + ModelTensor modelTensor = ModelTensor + .builder() .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\", \"step2\", \"step3\"], \"result\":\"\"}")) .build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); @@ -305,12 +338,9 @@ public void testExecutionWithMaxSteps() { return null; }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); - // Setup tool execution response doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); - ModelTensor modelTensor = ModelTensor.builder() - .dataAsMap(ImmutableMap.of("response", "tool execution result")) - .build(); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "tool execution result")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); @@ -318,29 +348,53 @@ public void testExecutionWithMaxSteps() { return null; }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any()); - // Setup memory manager update response doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; }).when(mlMemoryManager).updateInteraction(any(), any(), any()); - // Run the agent with max steps Map params = new HashMap<>(); params.put("question", "test question"); params.put("max_steps", "2"); + params.put("parent_interaction_id", "test_parent_interaction_id"); mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); - // Verify the response verify(agentActionListener).onResponse(objectCaptor.capture()); Object response = objectCaptor.getValue(); assertTrue(response instanceof ModelTensorOutput); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; - assertNotNull(modelTensorOutput); + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + ModelTensors firstModelTensors = mlModelOutputs.get(0); + List firstModelTensorList = firstModelTensors.getMlModelTensors(); + assertEquals(2, firstModelTensorList.size()); + + ModelTensor memoryIdTensor = firstModelTensorList.get(0); + assertEquals("memory_id", memoryIdTensor.getName()); + assertEquals("test_memory_id", memoryIdTensor.getResult()); + + ModelTensor parentInteractionModelTensor = firstModelTensorList.get(1); + assertEquals("parent_interaction_id", parentInteractionModelTensor.getName()); + assertEquals("test_parent_interaction_id", parentInteractionModelTensor.getResult()); + + ModelTensors secondModelTensors = mlModelOutputs.get(1); + List secondModelTensorList = secondModelTensors.getMlModelTensors(); + assertEquals(1, secondModelTensorList.size()); + + ModelTensor responseTensor = secondModelTensorList.get(0); + assertEquals("response", responseTensor.getName()); + assertEquals( + "Max Steps Limit Reached. Use memory_id with same task to restart. \n" + + " Last executed step: step1, \n" + + " Last executed step result: tool execution result", + responseTensor.getDataAsMap().get("response") + ); } - // Helper methods - private MLAgent createMLAgentWithTools(Map parameters) { + private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); MLToolSpec firstToolSpec = MLToolSpec .builder() @@ -348,6 +402,7 @@ private MLAgent createMLAgentWithTools(Map parameters) { .type(FIRST_TOOL) .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) .build(); + return MLAgent .builder() .name("TestAgent") @@ -355,119 +410,227 @@ private MLAgent createMLAgentWithTools(Map parameters) { .tools(Arrays.asList(firstToolSpec)) .memory(mlMemorySpec) .llm(llmSpec) - .parameters(parameters) + .parameters(Collections.emptyMap()) .build(); } - private List generateInteractions(int size) { - return Arrays.asList( - Interaction.builder().id("interaction-1").input("input-1").response("response-1").build(), - Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() + private List generateInteractions() { + return Arrays + .asList( + Interaction.builder().id("interaction-1").input("input-1").response("response-1").build(), + Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() + ); + } + + @Test + public void testSetupPromptParameters() { + Map testParams = new HashMap<>(); + testParams.put(MLPlanExecuteAndReflectAgentRunner.QUESTION_FIELD, "test question"); + testParams.put(MLPlanExecuteAndReflectAgentRunner.SYSTEM_PROMPT_FIELD, "custom system prompt"); + + mlPlanExecuteAndReflectAgentRunner.setupPromptParameters(testParams); + + assertEquals("test question", testParams.get(MLPlanExecuteAndReflectAgentRunner.USER_PROMPT_FIELD)); + assertTrue(testParams.get(MLPlanExecuteAndReflectAgentRunner.SYSTEM_PROMPT_FIELD).contains("custom system prompt")); + assertTrue(testParams.get(MLPlanExecuteAndReflectAgentRunner.SYSTEM_PROMPT_FIELD).contains("Always respond in JSON format")); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PLANNER_PROMPT_FIELD)); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.REFLECT_PROMPT_FIELD)); + assertEquals( + PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT, + testParams.get(MLPlanExecuteAndReflectAgentRunner.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT_FIELD) ); } - /** - * Test the run method with an unsupported LLM interface. - * This test verifies that the method throws an MLException when an unsupported LLM interface is provided. - */ - @Test(expected = org.opensearch.ml.common.exception.MLException.class) - public void testRunWithUnsupportedLLMInterface() { - Map parameters = new HashMap<>(); - parameters.put("llm_interface", "unsupported_interface"); - MLAgent mlAgent = createMLAgentWithTools(parameters); - - mlPlanExecuteAndReflectAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); + @Test + public void testUsePlannerPromptTemplate() { + Map testParams = new HashMap<>(); + mlPlanExecuteAndReflectAgentRunner.usePlannerPromptTemplate(testParams); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PROMPT_TEMPLATE_FIELD)); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PROMPT_FIELD)); } - /** - * Testcase 1 for @Override public void run(MLAgent mlAgent, Map apiParams, ActionListener listener) - * Path constraints: (Strings.isNullOrEmpty(response)), (!completedSteps.isEmpty()) - */ @Test - public void test_run_1() { - // Create MLAgent with necessary parameters - Map parameters = new HashMap<>(); - parameters.put("memory_id", "test_memory_id"); - parameters.put("question", "test_question"); - MLAgent mlAgent = createMLAgentWithTools(parameters); - - // Mock the behavior to satisfy the path constraints - when(conversationIndexMemory.getMessages(any(), anyInt())).thenAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(0); - List interactions = Arrays.asList( - Interaction.builder().id("interaction-1").input("input-1").response("").build(), - Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() - ); - listener.onResponse(interactions); - return null; - }); + public void testUseReflectPromptTemplate() { + Map testParams = new HashMap<>(); + mlPlanExecuteAndReflectAgentRunner.useReflectPromptTemplate(testParams); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PROMPT_TEMPLATE_FIELD)); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PROMPT_FIELD)); + } - // Execute the method under test - mlPlanExecuteAndReflectAgentRunner.run(mlAgent, parameters, agentActionListener); + @Test + public void testUsePlannerWithHistoryPromptTemplate() { + Map testParams = new HashMap<>(); + mlPlanExecuteAndReflectAgentRunner.usePlannerWithHistoryPromptTemplate(testParams); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PROMPT_TEMPLATE_FIELD)); + assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PROMPT_FIELD)); + } - // Verify that the memory factory was called with the correct parameters - verify(memoryFactory).create(eq("test_question"), eq("test_memory_id"), any(), any()); + @Test + public void testPopulatePrompt() { + Map testParams = new HashMap<>(); + testParams.put(MLPlanExecuteAndReflectAgentRunner.PROMPT_TEMPLATE_FIELD, "Hello ${parameters.name}!"); + testParams.put("name", "World"); - // Verify that the conversation index memory's getMessages method was called - verify(conversationIndexMemory).getMessages(any(), eq(10)); + mlPlanExecuteAndReflectAgentRunner.populatePrompt(testParams); - // Additional verifications can be added here based on the expected behavior - // For example, you might want to verify that certain methods were called on the client - // or that the agentActionListener was invoked with the expected result + assertEquals("Hello World!", testParams.get(MLPlanExecuteAndReflectAgentRunner.PROMPT_FIELD)); } - /** - * Testcase 2 for @Override public void run(MLAgent mlAgent, Map apiParams, ActionListener listener) - * Path constraints: !((Strings.isNullOrEmpty(response))), (!completedSteps.isEmpty()) - */ @Test - public void test_run_2() { - // Setup - MLAgent mlAgent = createMLAgentWithTools(new HashMap<>()); - Map apiParams = new HashMap<>(); - apiParams.put("question", "Test question"); - - // Mock behavior to satisfy path constraints - when(conversationIndexMemory.getMessages(any(), anyInt())).thenAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(0); - listener.onResponse(Arrays.asList( - Interaction.builder().id("interaction-1").input("input-1").response("response-1").build(), - Interaction.builder().id("interaction-2").input("input-2").response("response-2").build() - )); + public void testParseLLMOutput() { + Map allParams = new HashMap<>(); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + Map.of(MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD, "{\"steps\":[\"step1\",\"step2\"],\"result\":\"final result\"}") + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + Map result = mlPlanExecuteAndReflectAgentRunner.parseLLMOutput(allParams, modelTensorOutput); + + assertEquals("step1, step2", result.get(MLPlanExecuteAndReflectAgentRunner.STEPS_FIELD)); + assertEquals("final result", result.get(MLPlanExecuteAndReflectAgentRunner.RESULT_FIELD)); + + modelTensor = ModelTensor.builder().dataAsMap(Map.of(MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD, "random response")).build(); + modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + final ModelTensorOutput modelTensorOutput2 = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + assertThrows(IllegalStateException.class, () -> mlPlanExecuteAndReflectAgentRunner.parseLLMOutput(allParams, modelTensorOutput2)); + + modelTensor = ModelTensor + .builder() + .dataAsMap(Map.of(MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD, "{ \"random\": \"random response\"}")) + .build(); + modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + final ModelTensorOutput modelTensorOutput3 = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + assertThrows( + IllegalArgumentException.class, + () -> mlPlanExecuteAndReflectAgentRunner.parseLLMOutput(allParams, modelTensorOutput3) + ); + } + + @Test + public void testExtractJsonFromMarkdown() { + String markdown = "```json\n{\"key\":\"value\"}\n```"; + String result = mlPlanExecuteAndReflectAgentRunner.extractJsonFromMarkdown(markdown); + assertEquals("{\"key\":\"value\"}", result); + } + + @Test + public void testAddToolsToPrompt() { + Map testParams = new HashMap<>(); + Map tools = new HashMap<>(); + Tool tool1 = mock(Tool.class); + when(tool1.getName()).thenReturn("tool1"); + when(tool1.getDescription()).thenReturn("description1"); + tools.put("tool1", tool1); + + mlPlanExecuteAndReflectAgentRunner.addToolsToPrompt(tools, testParams); + + assertEquals( + "In this environment, you have access to the below tools: \n- tool1: description1\n\n", + testParams.get(MLPlanExecuteAndReflectAgentRunner.DEFAULT_PROMPT_TOOLS_FIELD) + ); + } + + @Test + public void testAddSteps() { + Map testParams = new HashMap<>(); + List steps = Arrays.asList("step1", "step2"); + String field = "test_field"; + + mlPlanExecuteAndReflectAgentRunner.addSteps(steps, testParams, field); + + assertEquals("step1, step2", testParams.get(field)); + } + + @Test + public void testCreateModelTensors() { + String sessionId = "test_session"; + String parentInteractionId = "test_parent"; + String executorMemoryId = "test_executor_mem_id"; + String executorParentId = "test_executor_parent_id"; + + List result = MLPlanExecuteAndReflectAgentRunner.createModelTensors(sessionId, parentInteractionId, executorMemoryId, executorParentId); + + assertNotNull(result); + assertEquals(1, result.size()); + ModelTensors tensors = result.get(0); + assertEquals(4, tensors.getMlModelTensors().size()); + assertEquals(sessionId, tensors.getMlModelTensors().get(0).getResult()); + assertEquals(parentInteractionId, tensors.getMlModelTensors().get(1).getResult()); + assertEquals(executorMemoryId, tensors.getMlModelTensors().get(2).getResult()); + assertEquals(executorParentId, tensors.getMlModelTensors().get(3).getResult()); + } + + @Test + public void testSaveAndReturnFinalResult() { + String parentInteractionId = "test_parent_id"; + String finalResult = "test final result"; + String input = "test input"; + String conversationId = "test_conversation_id"; + String executorMemoryId = "test_executor_mem_id"; + String executorParentId = "test_executor_parent_id"; + + when(conversationIndexMemory.getConversationId()).thenReturn(conversationId); + when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); return null; - }); + }).when(mlMemoryManager).updateInteraction(eq(parentInteractionId), any(), any()); - // Execute - mlPlanExecuteAndReflectAgentRunner.run(mlAgent, apiParams, agentActionListener); + mlPlanExecuteAndReflectAgentRunner + .saveAndReturnFinalResult(conversationIndexMemory, parentInteractionId, executorMemoryId, executorParentId, finalResult, input, agentActionListener); - // Verify - verify(conversationIndexMemory).getMessages(any(), eq(10)); - // Add more verifications as needed to ensure the correct path is taken + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + ModelTensors firstModelTensors = mlModelOutputs.get(0); + List firstModelTensorList = firstModelTensors.getMlModelTensors(); + assertEquals(4, firstModelTensorList.size()); + assertEquals(conversationId, firstModelTensorList.get(0).getResult()); + assertEquals(parentInteractionId, firstModelTensorList.get(1).getResult()); + assertEquals(executorMemoryId, firstModelTensorList.get(2).getResult()); + assertEquals(executorParentId, firstModelTensorList.get(3).getResult()); + + ModelTensors secondModelTensors = mlModelOutputs.get(1); + List secondModelTensorList = secondModelTensors.getMlModelTensors(); + assertEquals(1, secondModelTensorList.size()); + assertEquals(finalResult, secondModelTensorList.get(0).getDataAsMap().get("response")); } - /** - * Testcase 3 for @Override public void run(MLAgent mlAgent, Map apiParams, ActionListener listener) - * Path constraints: (Strings.isNullOrEmpty(response)), !((!completedSteps.isEmpty())) - */ @Test - public void test_run_3() { - // Setup - MLAgent mlAgent = createMLAgentWithTools(new HashMap<>()); - Map apiParams = new HashMap<>(); - apiParams.put("QUESTION_FIELD", "Test question"); + public void testSaveAndReturnFinalResultWithError() { + String parentInteractionId = "test_parent_id"; + String finalResult = "test final result"; + String input = "test input"; + String conversationId = "test_conversation_id"; + String executorMemoryId = "test_executor_mem_id"; + String executorParentId = "test_executor_parent_id"; + Exception expectedException = new MLException("Test error"); + + when(conversationIndexMemory.getConversationId()).thenReturn(conversationId); + when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); - // Mock conversation index memory to return empty interactions doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(0); - listener.onResponse(Arrays.asList()); + ActionListener listener = invocation.getArgument(2); + listener.onFailure(expectedException); return null; - }).when(conversationIndexMemory).getMessages(any(), anyInt()); + }).when(mlMemoryManager).updateInteraction(eq(parentInteractionId), any(), any()); - // Execute - mlPlanExecuteAndReflectAgentRunner.run(mlAgent, apiParams, agentActionListener); + mlPlanExecuteAndReflectAgentRunner + .saveAndReturnFinalResult(conversationIndexMemory, executorMemoryId, executorParentId, parentInteractionId, finalResult, input, agentActionListener); - // Verify - verify(memoryFactory).create(any(), any(), any(), any()); - verify(conversationIndexMemory).getMessages(any(), eq(10)); + verify(agentActionListener).onFailure(expectedException); } } From ae6269a6236816e76df5628765d84788e4b15df2 Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Wed, 4 Jun 2025 14:24:25 -0700 Subject: [PATCH 4/7] fix: saveAndReturnFinalResult testcase Signed-off-by: Pavan Yekbote --- ...LPlanExecuteAndReflectAgentRunnerTest.java | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index c3c76f4972..d27cc025fb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -608,29 +608,4 @@ public void testSaveAndReturnFinalResult() { assertEquals(1, secondModelTensorList.size()); assertEquals(finalResult, secondModelTensorList.get(0).getDataAsMap().get("response")); } - - @Test - public void testSaveAndReturnFinalResultWithError() { - String parentInteractionId = "test_parent_id"; - String finalResult = "test final result"; - String input = "test input"; - String conversationId = "test_conversation_id"; - String executorMemoryId = "test_executor_mem_id"; - String executorParentId = "test_executor_parent_id"; - Exception expectedException = new MLException("Test error"); - - when(conversationIndexMemory.getConversationId()).thenReturn(conversationId); - when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(expectedException); - return null; - }).when(mlMemoryManager).updateInteraction(eq(parentInteractionId), any(), any()); - - mlPlanExecuteAndReflectAgentRunner - .saveAndReturnFinalResult(conversationIndexMemory, executorMemoryId, executorParentId, parentInteractionId, finalResult, input, agentActionListener); - - verify(agentActionListener).onFailure(expectedException); - } } From 43c9c2f06cdebda30d99d208bae4307ce2b9036a Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Wed, 4 Jun 2025 17:49:45 -0700 Subject: [PATCH 5/7] feat: adding test cases for agentutils, connectorutils, transportregisteragent Signed-off-by: Pavan Yekbote --- .../algorithms/agent/AgentUtilsTest.java | 243 ++++++++++++++++++ ...LPlanExecuteAndReflectAgentRunnerTest.java | 14 +- .../algorithms/remote/ConnectorUtilsTest.java | 49 ++++ .../RegisterAgentTransportActionTests.java | 88 +++++++ 4 files changed, 391 insertions(+), 3 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 0a39383e17..8cc04cbc79 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -9,24 +9,29 @@ import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOLS; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_FILTERS_FIELD; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_TEMPLATE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; @@ -81,6 +86,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; +import com.google.gson.JsonSyntaxException; + public class AgentUtilsTest extends MLStaticMockBase { @Mock @@ -1197,6 +1204,48 @@ public void testParseLLMOutputWithDeepseekFormat() { Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response")); } + @Test + public void testAddToolsToFunctionCalling() { + Map tools = new HashMap<>(); + tools.put("Tool1", tool1); + tools.put("Tool2", tool2); + + when(tool1.getName()).thenReturn("Tool1"); + when(tool1.getDescription()).thenReturn("Description of Tool1"); + when(tool1.getAttributes()).thenReturn(Map.of("param1", "value1")); + + when(tool2.getName()).thenReturn("Tool2"); + when(tool2.getDescription()).thenReturn("Description of Tool2"); + when(tool2.getAttributes()).thenReturn(Map.of("param2", "value2")); + + Map parameters = new HashMap<>(); + String toolTemplate = "{\"name\": \"${tool.name}\", \"description\": \"${tool.description}\"}"; + parameters.put(TOOL_TEMPLATE, toolTemplate); + + List inputTools = Arrays.asList("Tool1", "Tool2"); + String prompt = "test prompt"; + + String expectedTool1 = "{\"name\": \"Tool1\", \"description\": \"Description of Tool1\"}"; + String expectedTool2 = "{\"name\": \"Tool2\", \"description\": \"Description of Tool2\"}"; + String expectedTools = expectedTool1 + ", " + expectedTool2; + + AgentUtils.addToolsToFunctionCalling(tools, parameters, inputTools, prompt); + + assertEquals(expectedTools, parameters.get(TOOLS)); + } + + @Test + public void testAddToolsToFunctionCalling_ToolNotRegistered() { + Map tools = new HashMap<>(); + tools.put("Tool1", tool1); + Map parameters = new HashMap<>(); + parameters.put(TOOL_TEMPLATE, "template"); + List inputTools = Arrays.asList("Tool1", "UnregisteredTool"); + String prompt = "test prompt"; + + assertThrows(IllegalArgumentException.class, () -> AgentUtils.addToolsToFunctionCalling(tools, parameters, inputTools, prompt)); + } + private static MLToolSpec buildTool(String name) { return MLToolSpec.builder().type(McpSseTool.TYPE).name(name).description("mock").build(); } @@ -1362,4 +1411,198 @@ private void verifyConstructToolParams(String question, String actionInput, Cons Map toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); verify.accept(toolParams); } + + @Test + public void testParseLLMOutput_WithExcludePath() { + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_EXCLUDE_PATH, "[\"$.exclude_field\"]"); + + Map dataAsMap = new HashMap<>(); + dataAsMap.put("exclude_field", "should be excluded"); + dataAsMap.put("keep_field", "should be kept"); + + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs( + List + .of( + ModelTensors + .builder() + .mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build())) + .build() + ) + ) + .build(); + + Map output = AgentUtils.parseLLMOutput(parameters, modelTensorOutput, null, Set.of(), new ArrayList<>()); + + Assert.assertTrue(output.containsKey(THOUGHT_RESPONSE)); + Assert.assertFalse(output.get(THOUGHT_RESPONSE).contains("exclude_field")); + Assert.assertTrue(output.get(THOUGHT_RESPONSE).contains("keep_field")); + } + + @Test + public void testParseLLMOutput_EmptyDataAsMap() { + Map dataAsMap = new HashMap<>(); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs( + List + .of( + ModelTensors + .builder() + .mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build())) + .build() + ) + ) + .build(); + + Map output = AgentUtils.parseLLMOutput(new HashMap<>(), modelTensorOutput, null, Set.of(), new ArrayList<>()); + + Assert.assertTrue(output.containsKey(THOUGHT_RESPONSE)); + Assert.assertEquals("{}", output.get(THOUGHT_RESPONSE)); + } + + @Test + public void testParseLLMOutput_ToolUse() { + Map parameters = new HashMap<>(); + parameters.put(TOOL_CALLS_PATH, "$.tool_calls"); + parameters.put(TOOL_CALLS_TOOL_NAME, "name"); + parameters.put(TOOL_CALLS_TOOL_INPUT, "input"); + parameters.put(TOOL_CALL_ID_PATH, "id"); + parameters.put(LLM_RESPONSE_FILTER, "$.response"); + parameters.put(LLM_FINISH_REASON_PATH, "$.finish_reason"); + parameters.put(LLM_FINISH_REASON_TOOL_USE, "tool_use"); + + Map dataAsMap = new HashMap<>(); + dataAsMap.put("tool_calls", List.of(Map.of("name", "test_tool", "input", "test_input", "id", "test_id"))); + dataAsMap.put("response", "test response"); + dataAsMap.put("finish_reason", "tool_use"); + + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs( + List + .of( + ModelTensors + .builder() + .mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build())) + .build() + ) + ) + .build(); + + Map output = AgentUtils.parseLLMOutput(parameters, modelTensorOutput, null, Set.of("test_tool"), new ArrayList<>()); + + Assert.assertEquals("test_tool", output.get(ACTION)); + Assert.assertEquals("test_input", output.get(ACTION_INPUT)); + Assert.assertEquals("test_id", output.get(TOOL_CALL_ID)); + } + + @Test + public void testRemoveJsonPath_WithStringPaths() { + Map json = new HashMap<>(); + json.put("field1", "value1"); + json.put("field2", "value2"); + json.put("nested", Map.of("field3", "value3")); + String excludePaths = "[\"$.field1\", \"$.nested.field3\"]"; + Map result = AgentUtils.removeJsonPath(json, excludePaths, false); + Assert.assertFalse(result.containsKey("field1")); + Assert.assertTrue(result.containsKey("field2")); + Assert.assertTrue(result.containsKey("nested")); + Assert.assertFalse(((Map) result.get("nested")).containsKey("field3")); + } + + @Test + public void testRemoveJsonPath_WithListPaths() { + Map json = new HashMap<>(); + json.put("field1", "value1"); + json.put("field2", "value2"); + json.put("nested", Map.of("field3", "value3")); + List excludePaths = java.util.Arrays.asList("$.field1", "$.nested.field3"); + Map result = AgentUtils.removeJsonPath(json, excludePaths, false); + Assert.assertFalse(result.containsKey("field1")); + Assert.assertTrue(result.containsKey("field2")); + Assert.assertTrue(result.containsKey("nested")); + Assert.assertFalse(((Map) result.get("nested")).containsKey("field3")); + } + + @Test + public void testRemoveJsonPath_InPlace() { + Map json = new HashMap<>(); + json.put("field1", "value1"); + json.put("field2", "value2"); + json.put("nested", new HashMap<>(Map.of("field3", "value3"))); + List excludePaths = java.util.Arrays.asList("$.field1", "$.nested.field3"); + Map result = AgentUtils.removeJsonPath(json, excludePaths, true); + Assert.assertFalse(json.containsKey("field1")); + Assert.assertTrue(json.containsKey("field2")); + Assert.assertTrue(json.containsKey("nested")); + Assert.assertFalse(((Map) json.get("nested")).containsKey("field3")); + Assert.assertSame(json, result); + } + + @Test + public void testRemoveJsonPath_WithInvalidJsonPaths() { + Map json = new HashMap<>(); + json.put("field1", "value1"); + String invalidJsonPaths = "invalid json"; + Assert.assertThrows(JsonSyntaxException.class, () -> AgentUtils.removeJsonPath(json, invalidJsonPaths, false)); + } + + @Test + public void testSubstitute() { + String template = "Hello ${parameters.name}! Welcome to ${parameters.place}."; + Map params = new HashMap<>(); + params.put("name", "AI"); + params.put("place", "OpenSearch"); + String prefix = "${parameters."; + + String result = AgentUtils.substitute(template, params, prefix); + + Assert.assertEquals("Hello AI! Welcome to OpenSearch.", result); + } + + @Test + public void testCreateTool_Success() { + Map toolFactories = new HashMap<>(); + Tool.Factory factory = mock(Tool.Factory.class); + Tool mockTool = mock(Tool.class); + when(factory.create(any())).thenReturn(mockTool); + toolFactories.put("test_tool", factory); + + MLToolSpec toolSpec = MLToolSpec + .builder() + .type("test_tool") + .name("TestTool") + .description("Original description") + .parameters(Map.of("param1", "value1")) + .runtimeResources(Map.of("resource1", "value2")) + .build(); + + Map params = new HashMap<>(); + params.put("TestTool.param2", "value3"); + params.put("TestTool.description", "Custom description"); + + AgentUtils.createTool(toolFactories, params, toolSpec, "test_tenant"); + + verify(factory).create(argThat(toolParamsMap -> { + Map toolParams = (Map) toolParamsMap; + return toolParams.get("param1").equals("value1") + && toolParams.get("param2").equals("value3") + && toolParams.get("resource1").equals("value2") + && toolParams.get(TENANT_ID_FIELD).equals("test_tenant"); + })); + + verify(mockTool).setName("TestTool"); + verify(mockTool).setDescription("Custom description"); + } + + @Test + public void testCreateTool_ToolNotFound() { + Map toolFactories = new HashMap<>(); + MLToolSpec toolSpec = MLToolSpec.builder().type("non_existent_tool").name("TestTool").build(); + + assertThrows(IllegalArgumentException.class, () -> AgentUtils.createTool(toolFactories, new HashMap<>(), toolSpec, "test_tenant")); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index d27cc025fb..4c00bc8b0b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -42,7 +42,6 @@ import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.Interaction; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -554,7 +553,8 @@ public void testCreateModelTensors() { String executorMemoryId = "test_executor_mem_id"; String executorParentId = "test_executor_parent_id"; - List result = MLPlanExecuteAndReflectAgentRunner.createModelTensors(sessionId, parentInteractionId, executorMemoryId, executorParentId); + List result = MLPlanExecuteAndReflectAgentRunner + .createModelTensors(sessionId, parentInteractionId, executorMemoryId, executorParentId); assertNotNull(result); assertEquals(1, result.size()); @@ -585,7 +585,15 @@ public void testSaveAndReturnFinalResult() { }).when(mlMemoryManager).updateInteraction(eq(parentInteractionId), any(), any()); mlPlanExecuteAndReflectAgentRunner - .saveAndReturnFinalResult(conversationIndexMemory, parentInteractionId, executorMemoryId, executorParentId, finalResult, input, agentActionListener); + .saveAndReturnFinalResult( + conversationIndexMemory, + parentInteractionId, + executorMemoryId, + executorParentId, + finalResult, + input, + agentActionListener + ); verify(agentActionListener).onResponse(objectCaptor.capture()); Object response = objectCaptor.getValue(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 335dc95245..6b80c19300 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -386,4 +386,53 @@ public void testGetTask_createCancelBatchActionForBedrock() { ); assertNull(result.getRequestBody()); } + + @Test + public void testEscapeRemoteInferenceInputData_WithSpecialCharacters() { + Map params = new HashMap<>(); + params.put("key1", "hello \"world\" \n \t"); + params.put("key2", "test value"); + + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + assertEquals("hello \\\"world\\\" \\n \\t", inputData.getParameters().get("key1")); + assertEquals("test value", inputData.getParameters().get("key2")); + } + + @Test + public void testEscapeRemoteInferenceInputData_WithJsonValues() { + Map params = new HashMap<>(); + params.put("key1", "{\"name\": \"test\", \"value\": 123}"); + params.put("key2", "[\"item1\", \"item2\"]"); + + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + assertEquals("{\"name\": \"test\", \"value\": 123}", inputData.getParameters().get("key1")); + assertEquals("[\"item1\", \"item2\"]", inputData.getParameters().get("key2")); + } + + @Test + public void testEscapeRemoteInferenceInputData_WithNoEscapeParams() { + Map params = new HashMap<>(); + String inputKey1 = "hello \"world\""; + String inputKey3 = "special \"chars\""; + params.put("key1", inputKey1); + params.put("key2", "test value"); + params.put("key3", inputKey3); + params.put("NO_ESCAPE_PARAMS", "key1,key3"); + + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + String expectedKey1 = "hello \\\"world\\\""; + String expectedKey3 = "special \\\"chars\\\""; + assertEquals(expectedKey1, inputData.getParameters().get("key1")); + assertEquals("test value", inputData.getParameters().get("key2")); + assertEquals(expectedKey3, inputData.getParameters().get("key3")); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java index 84264e960e..fe1320c6d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java @@ -8,14 +8,17 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED; +import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.EXECUTOR_AGENT_ID_FIELD; import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.Map; import java.util.Set; import org.junit.Before; @@ -278,4 +281,89 @@ public void test_execute_registerAgent_Othertype() { assertNotNull(argumentCaptor.getValue()); } + @Test + public void test_execute_registerAgent_PlanExecuteAndReflect_WithoutExecutorAgentId() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + Map parameters = new HashMap<>(); + parameters.put("tools", "[]"); + parameters.put("memory", "{}"); + + LLMSpec llmSpec = new LLMSpec("test-model-id", new HashMap<>()); + + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.PLAN_EXECUTE_AND_REFLECT.name()) + .description("Test agent for plan-execute-and-reflect") + .parameters(parameters) + .llm(llmSpec) + .build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + + MLRegisterAgentResponse response = argumentCaptor.getValue(); + assertNotNull(response); + assertEquals("AGENT_ID", response.getAgentId()); + verify(client, times(2)).index(any(), any()); + } + + @Test + public void test_execute_registerAgent_PlanExecuteAndReflect_WithExecutorAgentId() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + Map parameters = new HashMap<>(); + parameters.put("tools", "[]"); + parameters.put("memory", "{}"); + parameters.put(EXECUTOR_AGENT_ID_FIELD, "existing-executor-id"); + + LLMSpec llmSpec = new LLMSpec("test-model-id", new HashMap<>()); + + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.PLAN_EXECUTE_AND_REFLECT.name()) + .description("Test agent for plan-execute-and-reflect") + .parameters(parameters) + .llm(llmSpec) + .build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + + MLRegisterAgentResponse response = argumentCaptor.getValue(); + assertNotNull(response); + assertEquals("AGENT_ID", response.getAgentId()); + + verify(client, times(1)).index(any(), any()); + } } From f30417fe6959e484a2e2878a6f26dff058f0666f Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Wed, 4 Jun 2025 18:24:07 -0700 Subject: [PATCH 6/7] chore: remove max steps test post rebase Signed-off-by: Pavan Yekbote --- ...LPlanExecuteAndReflectAgentRunnerTest.java | 80 +------------------ 1 file changed, 3 insertions(+), 77 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 4c00bc8b0b..6b8accaedc 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -214,7 +214,7 @@ public void testBasicExecution() { // Run the agent Map params = new HashMap<>(); params.put("question", "test question"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_test_id"); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test_parent_interaction_id"); mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); // Verify the response @@ -228,7 +228,7 @@ public void testBasicExecution() { ModelTensors firstModelTensors = mlModelOutputs.get(0); List firstModelTensorList = firstModelTensors.getMlModelTensors(); - assertEquals(2, firstModelTensorList.size()); + assertEquals(4, firstModelTensorList.size()); ModelTensor memoryIdTensor = firstModelTensorList.get(0); assertEquals("memory_id", memoryIdTensor.getName()); @@ -301,7 +301,7 @@ public void testExecutionWithHistory() { ModelTensors firstModelTensors = mlModelOutputs.get(0); List firstModelTensorList = firstModelTensors.getMlModelTensors(); - assertEquals(2, firstModelTensorList.size()); + assertEquals(4, firstModelTensorList.size()); ModelTensor memoryIdTensor = firstModelTensorList.get(0); assertEquals("memory_id", memoryIdTensor.getName()); @@ -320,79 +320,6 @@ public void testExecutionWithHistory() { assertEquals("final result", responseTensor.getDataAsMap().get("response")); } - @Test - public void testExecutionWithMaxSteps() { - MLAgent mlAgent = createMLAgentWithTools(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - ModelTensor modelTensor = ModelTensor - .builder() - .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\", \"step2\", \"step3\"], \"result\":\"\"}")) - .build(); - ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); - ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); - when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); - listener.onResponse(mlTaskResponse); - return null; - }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "tool execution result")).build(); - ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); - ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); - when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); - listener.onResponse(mlExecuteTaskResponse); - return null; - }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(updateResponse); - return null; - }).when(mlMemoryManager).updateInteraction(any(), any(), any()); - - Map params = new HashMap<>(); - params.put("question", "test question"); - params.put("max_steps", "2"); - params.put("parent_interaction_id", "test_parent_interaction_id"); - mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); - - verify(agentActionListener).onResponse(objectCaptor.capture()); - Object response = objectCaptor.getValue(); - assertTrue(response instanceof ModelTensorOutput); - ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; - - List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); - assertEquals(2, mlModelOutputs.size()); - - ModelTensors firstModelTensors = mlModelOutputs.get(0); - List firstModelTensorList = firstModelTensors.getMlModelTensors(); - assertEquals(2, firstModelTensorList.size()); - - ModelTensor memoryIdTensor = firstModelTensorList.get(0); - assertEquals("memory_id", memoryIdTensor.getName()); - assertEquals("test_memory_id", memoryIdTensor.getResult()); - - ModelTensor parentInteractionModelTensor = firstModelTensorList.get(1); - assertEquals("parent_interaction_id", parentInteractionModelTensor.getName()); - assertEquals("test_parent_interaction_id", parentInteractionModelTensor.getResult()); - - ModelTensors secondModelTensors = mlModelOutputs.get(1); - List secondModelTensorList = secondModelTensors.getMlModelTensors(); - assertEquals(1, secondModelTensorList.size()); - - ModelTensor responseTensor = secondModelTensorList.get(0); - assertEquals("response", responseTensor.getName()); - assertEquals( - "Max Steps Limit Reached. Use memory_id with same task to restart. \n" - + " Last executed step: step1, \n" - + " Last executed step result: tool execution result", - responseTensor.getDataAsMap().get("response") - ); - } - private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); MLToolSpec firstToolSpec = MLToolSpec @@ -431,7 +358,6 @@ public void testSetupPromptParameters() { assertEquals("test question", testParams.get(MLPlanExecuteAndReflectAgentRunner.USER_PROMPT_FIELD)); assertTrue(testParams.get(MLPlanExecuteAndReflectAgentRunner.SYSTEM_PROMPT_FIELD).contains("custom system prompt")); - assertTrue(testParams.get(MLPlanExecuteAndReflectAgentRunner.SYSTEM_PROMPT_FIELD).contains("Always respond in JSON format")); assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.PLANNER_PROMPT_FIELD)); assertNotNull(testParams.get(MLPlanExecuteAndReflectAgentRunner.REFLECT_PROMPT_FIELD)); assertEquals( From fa8bb9c9076a872c942b7c57c64c5dcd387f24fa Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Wed, 4 Jun 2025 18:27:18 -0700 Subject: [PATCH 7/7] chore: add todo for max_steps reached Signed-off-by: Pavan Yekbote --- .../agent/MLPlanExecuteAndReflectAgentRunnerTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 6b8accaedc..e608283ed3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -320,6 +320,8 @@ public void testExecutionWithHistory() { assertEquals("final result", responseTensor.getDataAsMap().get("response")); } + // ToDo: add test case for when max steps is reached + private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); MLToolSpec firstToolSpec = MLToolSpec