diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 301304b556..513d177558 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -452,4 +453,30 @@ public static List getToolNames(Map tools) { } return inputTools; } + + public static Map constructToolParams( + Map tools, + Map toolSpecMap, + String question, + AtomicReference lastActionInput, + String action, + String actionInput + ) { + Map toolParams = new HashMap<>(); + Map toolSpecParams = toolSpecMap.get(action).getParameters(); + if (toolSpecParams != null) { + toolParams.putAll(toolSpecParams); + } + if (tools.get(action).useOriginalInput()) { + toolParams.put("input", question); + lastActionInput.set(question); + } else { + toolParams.put("input", actionInput); + if (isJson(actionInput)) { + Map params = getParameterMap(gson.fromJson(actionInput, Map.class)); + toolParams.putAll(params); + } + } + return toolParams; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index bd64c36828..4740565dd1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -15,6 +15,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; @@ -484,28 +485,6 @@ private static void runTool( } } - private static Map constructToolParams( - Map tools, - Map toolSpecMap, - String question, - AtomicReference lastActionInput, - String action, - String actionInput - ) { - Map toolParams = new HashMap<>(); - Map toolSpecParams = toolSpecMap.get(action).getParameters(); - if (toolSpecParams != null) { - toolParams.putAll(toolSpecParams); - } - if (tools.get(action).useOriginalInput()) { - toolParams.put("input", question); - lastActionInput.set(question); - } else { - toolParams.put("input", actionInput); - } - return toolParams; - } - private static void saveTraceData( ConversationIndexMemory conversationIndexMemory, String memory, diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 0a0af3f60c..0fea6062e4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -25,12 +25,15 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -569,4 +572,38 @@ public void testExtractThought_InvalidResult() { Assert.assertEquals("Let me search our index to find population projections", result); } + @Test + public void testConstructToolParams() { + String question = "dummy question"; + String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; + verifyConstructToolParams(question, actionInput, (toolParams) -> { + Assert.assertEquals(4, toolParams.size()); + Assert.assertEquals(actionInput, toolParams.get("input")); + Assert.assertEquals("abc", toolParams.get("detectorName")); + Assert.assertEquals("sample-data", toolParams.get("indices")); + Assert.assertEquals("value1", toolParams.get("key1")); + }); + } + + @Test + public void testConstructToolParams_UseOriginalInput() { + String question = "dummy question"; + String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }"; + when(tool1.useOriginalInput()).thenReturn(true); + verifyConstructToolParams(question, actionInput, (toolParams) -> { + Assert.assertEquals(2, toolParams.size()); + Assert.assertEquals(question, toolParams.get("input")); + Assert.assertEquals("value1", toolParams.get("key1")); + }); + } + + private void verifyConstructToolParams(String question, String actionInput, Consumer> verify) { + Map tools = Map.of("tool1", tool1); + Map toolSpecMap = Map + .of("tool1", MLToolSpec.builder().type("tool1").parameters(Map.of("key1", "value1")).build()); + AtomicReference lastActionInput = new AtomicReference<>(); + String action = "tool1"; + Map toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); + verify.accept(toolParams); + } }