Skip to content

Commit 235fb9c

Browse files
authored
parse tool input to map (#2131)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 896caac commit 235fb9c

File tree

3 files changed

+65
-22
lines changed

3 files changed

+65
-22
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.util.Map;
3535
import java.util.Optional;
3636
import java.util.Set;
37+
import java.util.concurrent.atomic.AtomicReference;
3738
import java.util.regex.Matcher;
3839
import java.util.regex.Pattern;
3940

@@ -452,4 +453,30 @@ public static List<String> getToolNames(Map<String, Tool> tools) {
452453
}
453454
return inputTools;
454455
}
456+
457+
public static Map<String, String> constructToolParams(
458+
Map<String, Tool> tools,
459+
Map<String, MLToolSpec> toolSpecMap,
460+
String question,
461+
AtomicReference<String> lastActionInput,
462+
String action,
463+
String actionInput
464+
) {
465+
Map<String, String> toolParams = new HashMap<>();
466+
Map<String, String> toolSpecParams = toolSpecMap.get(action).getParameters();
467+
if (toolSpecParams != null) {
468+
toolParams.putAll(toolSpecParams);
469+
}
470+
if (tools.get(action).useOriginalInput()) {
471+
toolParams.put("input", question);
472+
lastActionInput.set(question);
473+
} else {
474+
toolParams.put("input", actionInput);
475+
if (isJson(actionInput)) {
476+
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
477+
toolParams.putAll(params);
478+
}
479+
}
480+
return toolParams;
481+
}
455482
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
1616
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
1717
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
18+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams;
1819
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
1920
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
2021
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
@@ -484,28 +485,6 @@ private static void runTool(
484485
}
485486
}
486487

487-
private static Map<String, String> constructToolParams(
488-
Map<String, Tool> tools,
489-
Map<String, MLToolSpec> toolSpecMap,
490-
String question,
491-
AtomicReference<String> lastActionInput,
492-
String action,
493-
String actionInput
494-
) {
495-
Map<String, String> toolParams = new HashMap<>();
496-
Map<String, String> toolSpecParams = toolSpecMap.get(action).getParameters();
497-
if (toolSpecParams != null) {
498-
toolParams.putAll(toolSpecParams);
499-
}
500-
if (tools.get(action).useOriginalInput()) {
501-
toolParams.put("input", question);
502-
lastActionInput.set(question);
503-
} else {
504-
toolParams.put("input", actionInput);
505-
}
506-
return toolParams;
507-
}
508-
509488
private static void saveTraceData(
510489
ConversationIndexMemory conversationIndexMemory,
511490
String memory,

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@
2525
import java.util.List;
2626
import java.util.Map;
2727
import java.util.Set;
28+
import java.util.concurrent.atomic.AtomicReference;
29+
import java.util.function.Consumer;
2830

2931
import org.junit.Assert;
3032
import org.junit.Before;
3133
import org.junit.Test;
3234
import org.mockito.Mock;
3335
import org.mockito.MockitoAnnotations;
36+
import org.opensearch.ml.common.agent.MLToolSpec;
3437
import org.opensearch.ml.common.output.model.ModelTensor;
3538
import org.opensearch.ml.common.output.model.ModelTensorOutput;
3639
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -569,4 +572,38 @@ public void testExtractThought_InvalidResult() {
569572
Assert.assertEquals("Let me search our index to find population projections", result);
570573
}
571574

575+
@Test
576+
public void testConstructToolParams() {
577+
String question = "dummy question";
578+
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
579+
verifyConstructToolParams(question, actionInput, (toolParams) -> {
580+
Assert.assertEquals(4, toolParams.size());
581+
Assert.assertEquals(actionInput, toolParams.get("input"));
582+
Assert.assertEquals("abc", toolParams.get("detectorName"));
583+
Assert.assertEquals("sample-data", toolParams.get("indices"));
584+
Assert.assertEquals("value1", toolParams.get("key1"));
585+
});
586+
}
587+
588+
@Test
589+
public void testConstructToolParams_UseOriginalInput() {
590+
String question = "dummy question";
591+
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
592+
when(tool1.useOriginalInput()).thenReturn(true);
593+
verifyConstructToolParams(question, actionInput, (toolParams) -> {
594+
Assert.assertEquals(2, toolParams.size());
595+
Assert.assertEquals(question, toolParams.get("input"));
596+
Assert.assertEquals("value1", toolParams.get("key1"));
597+
});
598+
}
599+
600+
private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
601+
Map<String, Tool> tools = Map.of("tool1", tool1);
602+
Map<String, MLToolSpec> toolSpecMap = Map
603+
.of("tool1", MLToolSpec.builder().type("tool1").parameters(Map.of("key1", "value1")).build());
604+
AtomicReference<String> lastActionInput = new AtomicReference<>();
605+
String action = "tool1";
606+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
607+
verify.accept(toolParams);
608+
}
572609
}

0 commit comments

Comments
 (0)