Skip to content

Commit c649905

Browse files
authored
openai v1 chat completions function calling (#3681)
Signed-off-by: Jing Zhang <[email protected]>
1 parent 8fe7efe commit c649905

11 files changed

+273
-100
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

+110-54
Large diffs are not rendered by default.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+30-13
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
1212
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
1313
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX;
14-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
15-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.NO_ESCAPE_PARAMS;
1614
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
1715
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
1816
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
@@ -193,14 +191,14 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
193191
List<String> chatHistory = new ArrayList<>();
194192
for (Message message : messageList) {
195193
Map<String, String> messageParams = new HashMap<>();
196-
messageParams.put("question", processTextDoc(((ConversationIndexMessage)message).getQuestion()));
194+
messageParams.put("question", processTextDoc(((ConversationIndexMessage) message).getQuestion()));
197195

198196
StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}");
199197
String chatQuestionMessage = substitutor.replace(chatHistoryQuestionTemplate);
200198
chatHistory.add(chatQuestionMessage);
201199

202200
messageParams.clear();
203-
messageParams.put("response", processTextDoc(((ConversationIndexMessage)message).getResponse()));
201+
messageParams.put("response", processTextDoc(((ConversationIndexMessage) message).getResponse()));
204202
substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}");
205203
String chatResponseMessage = substitutor.replace(chatHistoryResponseTemplate);
206204
chatHistory.add(chatResponseMessage);
@@ -283,7 +281,13 @@ private void runReAct(
283281
MLTaskResponse llmResponse = (MLTaskResponse) output;
284282
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
285283
List<String> llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class);
286-
Map<String, String> modelOutput = parseLLMOutput(parameters, tmpModelTensorOutput, llmResponsePatterns, tools.keySet(), interactions);
284+
Map<String, String> modelOutput = parseLLMOutput(
285+
parameters,
286+
tmpModelTensorOutput,
287+
llmResponsePatterns,
288+
tools.keySet(),
289+
interactions
290+
);
287291

288292
String thought = String.valueOf(modelOutput.get(THOUGHT));
289293
String toolCallId = String.valueOf(modelOutput.get("tool_call_id"));
@@ -354,7 +358,8 @@ private void runReAct(
354358
action,
355359
actionInput,
356360
toolParams,
357-
interactions, toolCallId
361+
interactions,
362+
toolCallId
358363
);
359364
} else {
360365
String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action);
@@ -460,8 +465,8 @@ private void runReAct(
460465

461466
private void cleanUpResource(Map<String, Tool> tools) {
462467
for (String key : tools.keySet()) {
463-
if (tools.get(key) instanceof McpSseTool) {//TODO: make this more general, avoid checking specific tool type
464-
((McpSseTool)tools.get(key)).getMcpSyncClient().closeGracefully();
468+
if (tools.get(key) instanceof McpSseTool) {// TODO: make this more general, avoid checking specific tool type
469+
((McpSseTool) tools.get(key)).getMcpSyncClient().closeGracefully();
465470
}
466471
}
467472
}
@@ -533,12 +538,24 @@ private static void runTool(
533538
try {
534539
String finalAction = action;
535540
ActionListener<Object> toolListener = ActionListener.wrap(r -> {
536-
interactions.add(substitute(tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
537-
Map.of("tool_call_id", toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))), INTERACTIONS_PREFIX));
541+
interactions
542+
.add(
543+
substitute(
544+
tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
545+
Map.of("tool_call_id", toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))),
546+
INTERACTIONS_PREFIX
547+
)
548+
);
538549
nextStepListener.onResponse(r);
539550
}, e -> {
540-
interactions.add(substitute(tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
541-
Map.of("tool_call_id", toolCallId, "tool_response", "Tool " + action + " failed: " + e.getMessage()), INTERACTIONS_PREFIX));
551+
interactions
552+
.add(
553+
substitute(
554+
tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
555+
Map.of("tool_call_id", toolCallId, "tool_response", "Tool " + action + " failed: " + e.getMessage()),
556+
INTERACTIONS_PREFIX
557+
)
558+
);
542559
nextStepListener
543560
.onResponse(
544561
String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage())
@@ -560,7 +577,7 @@ private static void runTool(
560577
nextStepListener
561578
.onResponse(String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", action, e.getMessage()));
562579
}
563-
} else { //TODO: add failure to interaction to let LLM regenerate ?
580+
} else { // TODO: add failure to interaction to let LLM regenerate ?
564581
String res = String.format(Locale.ROOT, "Failed to run the tool %s due to wrong input %s.", action, actionInput);
565582
nextStepListener.onResponse(res);
566583
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet in
182182
} else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
183183
// no need to escape if it's already valid json
184184
newParameters.put(key, value);
185-
} else if (!noEscapParamSet.contains(key)){
185+
} else if (!noEscapParamSet.contains(key)) {
186186
newParameters.put(key, escapeJson(value));
187187
} else {
188188
newParameters.put(key, value);

ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseFunctionCalling.java

+14-10
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
package org.opensearch.ml.engine.function_calling;
22

3-
import com.jayway.jsonpath.JsonPath;
4-
import lombok.Data;
5-
import org.opensearch.core.common.util.CollectionUtils;
6-
import org.opensearch.ml.common.output.model.ModelTensorOutput;
7-
import org.opensearch.ml.common.utils.StringUtils;
3+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH;
4+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.removeJsonPath;
85

96
import java.util.ArrayList;
107
import java.util.List;
118
import java.util.Map;
129

13-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH;
14-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
15-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.removeJsonPath;
10+
import org.opensearch.core.common.util.CollectionUtils;
11+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
12+
import org.opensearch.ml.common.utils.StringUtils;
13+
14+
import com.jayway.jsonpath.JsonPath;
15+
16+
import lombok.Data;
1617

1718
public class BedrockConverseFunctionCalling implements FunctionCalling {
1819
private static final String FINISH_REASON_PATH = "$.stopReason";
@@ -26,7 +27,11 @@ public class BedrockConverseFunctionCalling implements FunctionCalling {
2627

2728
@Override
2829
public void configure(Map<String, String> params) {
29-
params.put("tool_template", "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}");
30+
params
31+
.put(
32+
"tool_template",
33+
"{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}"
34+
);
3035
params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");
3136
}
3237

@@ -38,7 +43,6 @@ public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput,
3843
if (llmResponseExcludePath != null) {
3944
dataAsMap = removeJsonPath(dataAsMap, llmResponseExcludePath, true);
4045
}
41-
Object response = JsonPath.read(dataAsMap, parameters.get(LLM_RESPONSE_FILTER));
4246
String llmFinishReason = JsonPath.read(dataAsMap, FINISH_REASON_PATH);
4347
if (!llmFinishReason.contentEquals(FINISH_REASON)) {
4448
return output;

ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCalling.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package org.opensearch.ml.engine.function_calling;
22

3-
import org.opensearch.ml.common.output.model.ModelTensorOutput;
4-
53
import java.util.List;
64
import java.util.Map;
75

6+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
7+
88
/**
99
* A general LLM function calling interface.
1010
*/

ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCallingFactory.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package org.opensearch.ml.engine.function_calling;
22

3-
import org.apache.commons.lang3.StringUtils;
4-
import org.opensearch.ml.common.exception.MLException;
3+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE;
4+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS;
55

66
import java.util.Locale;
77

8-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE;
8+
import org.apache.commons.lang3.StringUtils;
9+
import org.opensearch.ml.common.exception.MLException;
910

1011
public class FunctionCallingFactory {
1112
public FunctionCalling create(String llmInterface) {
@@ -16,6 +17,8 @@ public FunctionCalling create(String llmInterface) {
1617
switch (llmInterface.trim().toLowerCase(Locale.ROOT)) {
1718
case LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE:
1819
return new BedrockConverseFunctionCalling();
20+
case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS:
21+
return new OpenaiV1ChatCompletionsFunctionCalling();
1922
default:
2023
throw new MLException(String.format("Unsupported llm interface: {}.", llmInterface));
2124
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/LLMMessage.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package org.opensearch.ml.engine.function_calling;
22

3-
import lombok.Data;
4-
53
import java.util.ArrayList;
64
import java.util.List;
75

6+
import lombok.Data;
7+
88
@Data
99
public class LLMMessage {
1010

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package org.opensearch.ml.engine.function_calling;
2+
3+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH;
4+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.removeJsonPath;
5+
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
import java.util.Map;
9+
10+
import org.opensearch.core.common.util.CollectionUtils;
11+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
12+
import org.opensearch.ml.common.utils.StringUtils;
13+
14+
import com.jayway.jsonpath.JsonPath;
15+
16+
import lombok.Data;
17+
18+
public class OpenaiV1ChatCompletionsFunctionCalling implements FunctionCalling {
19+
private static final String FINISH_REASON_PATH = "$.choices[0].finish_reason";
20+
private static final String FINISH_REASON = "tool_calls";
21+
private static final String CALL_PATH = "$.choices[0].message.tool_calls";
22+
private static final String NAME = "function.name";
23+
private static final String INPUT = "function.arguments";
24+
private static final String ID_PATH = "id";
25+
private static final String TOOL_ERROR = "tool_error";
26+
private static final String TOOL_RESULT = "tool_result";
27+
28+
@Override
29+
public void configure(Map<String, String> params) {
30+
params
31+
.put(
32+
"tool_template",
33+
"{\"type\": \"function\", \"function\": { \"name\": \"${tool.name}\", \"description\": \"${tool.description}\", \"parameters\": ${tool.attributes.input_schema}, \"strict\": ${tool.attributes.strict:-false} } }"
34+
);
35+
}
36+
37+
@Override
38+
public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput, Map<String, String> parameters) {
39+
List<Map<String, String>> output = new ArrayList<>();
40+
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
41+
String llmResponseExcludePath = parameters.get(LLM_RESPONSE_EXCLUDE_PATH);
42+
if (llmResponseExcludePath != null) {
43+
dataAsMap = removeJsonPath(dataAsMap, llmResponseExcludePath, true);
44+
}
45+
String llmFinishReason = JsonPath.read(dataAsMap, FINISH_REASON_PATH);
46+
if (!llmFinishReason.contentEquals(FINISH_REASON)) {
47+
return output;
48+
}
49+
List toolCalls = JsonPath.read(dataAsMap, CALL_PATH);
50+
if (CollectionUtils.isEmpty(toolCalls)) {
51+
return output;
52+
}
53+
for (Object call : toolCalls) {
54+
String toolName = JsonPath.read(call, parameters.get(NAME));
55+
String toolInput = StringUtils.toJson(JsonPath.read(call, parameters.get(INPUT)));
56+
String toolCallId = JsonPath.read(call, parameters.get(ID_PATH));
57+
output.add(Map.of("tool_name", toolName, "tool_input", toolInput, "tool_call_id", toolCallId));
58+
}
59+
return output;
60+
}
61+
62+
@Override
63+
public LLMMessage supply(List<Map<String, String>> toolResults) {
64+
LLMMessage toolMessage = new LLMMessage();
65+
for (Map toolResult : toolResults) {
66+
String toolUseId = (String) toolResult.get(ID_PATH);
67+
if (toolUseId == null) {
68+
continue;
69+
}
70+
ToolResult result = new ToolResult();
71+
result.setRole("tool");
72+
result.setToolUseId(toolUseId);
73+
result.setContent((String) toolResult.get(TOOL_RESULT));
74+
toolMessage.getContent().add(result);
75+
}
76+
77+
return toolMessage;
78+
}
79+
80+
@Data
81+
public static class ToolResult {
82+
private String role;
83+
private String toolUseId;
84+
private String content;
85+
}
86+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ public IndexMappingTool(Client client) {
6767
this.client = client;
6868

6969
this.attributes = new HashMap<>();
70-
attributes.put("input_schema", "{\"type\":\"object\",\"properties\":{\"index\":{\"type\":\"string\",\"description\":\"OpenSearch index name\"}},\"required\":[\"index\"],\"additionalProperties\":false}");
70+
attributes
71+
.put(
72+
"input_schema",
73+
"{\"type\":\"object\",\"properties\":{\"index\":{\"type\":\"string\",\"description\":\"OpenSearch index name\"}},\"required\":[\"index\"],\"additionalProperties\":false}"
74+
);
7175
attributes.put("strict", true);
7276

7377
outputParser = new Parser<>() {

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,11 @@ public Object parse(Object o) {
111111
};
112112

113113
this.attributes = new HashMap<>();
114-
attributes.put("input_schema", "{\"type\":\"object\",\"properties\":{\"indices\":{\"type\":\"string\",\"description\":\"OpenSearch index name list, separated by comma. for example: index1, index2\"}},\"additionalProperties\":false}");
114+
attributes
115+
.put(
116+
"input_schema",
117+
"{\"type\":\"object\",\"properties\":{\"indices\":{\"type\":\"string\",\"description\":\"OpenSearch index name list, separated by comma. for example: index1, index2\"}},\"additionalProperties\":false}"
118+
);
115119
attributes.put("strict", false);
116120
}
117121

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/McpSseTool.java

+12-13
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,21 @@
55

66
package org.opensearch.ml.engine.tools;
77

8-
import io.modelcontextprotocol.client.McpSyncClient;
9-
import io.modelcontextprotocol.spec.McpSchema;
10-
import lombok.Getter;
11-
import lombok.Setter;
12-
import lombok.extern.log4j.Log4j2;
8+
import java.util.List;
9+
import java.util.Map;
10+
1311
import org.opensearch.core.action.ActionListener;
1412
import org.opensearch.ml.common.spi.tools.Parser;
1513
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
1614
import org.opensearch.ml.common.spi.tools.WithModelTool;
1715
import org.opensearch.ml.common.utils.StringUtils;
1816
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
1917

20-
import java.util.List;
21-
import java.util.Map;
18+
import io.modelcontextprotocol.client.McpSyncClient;
19+
import io.modelcontextprotocol.spec.McpSchema;
20+
import lombok.Getter;
21+
import lombok.Setter;
22+
import lombok.extern.log4j.Log4j2;
2223

2324
/**
2425
* This tool supports running any ml-commons model.
@@ -40,7 +41,7 @@ public class McpSseTool implements WithModelTool {
4041
@Setter
4142
private String description = DEFAULT_DESCRIPTION;
4243
@Getter
43-
private McpSyncClient mcpSyncClient; //TODO:// close client when agent run finish
44+
private McpSyncClient mcpSyncClient; // TODO:// close client when agent run finish
4445
@Setter
4546
private Parser inputParser;
4647
@Setter
@@ -57,11 +58,9 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
5758
try {
5859
String input = parameters.get("input");
5960
Map<String, Object> inputArgs = StringUtils.fromJson(input, "input");
60-
McpSchema.CallToolResult result = mcpSyncClient.callTool(
61-
new McpSchema.CallToolRequest(this.name, inputArgs)
62-
);
61+
McpSchema.CallToolResult result = mcpSyncClient.callTool(new McpSchema.CallToolRequest(this.name, inputArgs));
6362
String resultJson = StringUtils.toJson(result.content());
64-
listener.onResponse((T)resultJson);
63+
listener.onResponse((T) resultJson);
6564
} catch (Exception e) {
6665
log.error("Failed to call MCP tool: {}", this.getName(), e);
6766
listener.onFailure(e);
@@ -116,7 +115,7 @@ public void init() {}
116115

117116
@Override
118117
public McpSseTool create(Map<String, Object> map) {
119-
return new McpSseTool((McpSyncClient)map.get("mcp_client"));
118+
return new McpSseTool((McpSyncClient) map.get("mcp_client"));
120119
}
121120

122121
@Override

0 commit comments

Comments
 (0)