Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.DEFAULT_NO_ESCAPE_PARAMS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
Expand Down Expand Up @@ -95,6 +96,7 @@ public class AgentUtils {
public static final String LLM_FINISH_REASON_PATH = "llm_finish_reason_path";
public static final String LLM_FINISH_REASON_TOOL_USE = "llm_finish_reason_tool_use";
public static final String LLM_RESPONSE_EXCLUDE_PATH = "llm_response_exclude_path";
public static final String LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE = "bedrock/converse/claude";

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
Expand Down Expand Up @@ -371,7 +373,7 @@ private static String postFilterFinalAnswer(Map<String, String> parameters, Map<
return StringUtils.toJson(llmResponse);
}

private static Map<String, ?> removeJsonPath(Map<String, ?> json, String excludePaths, boolean inPlace) {
public static Map<String, ?> removeJsonPath(Map<String, ?> json, String excludePaths, boolean inPlace) {
Type listType = new TypeToken<List<String>>(){}.getType();
List<String> excludedPath = gson.fromJson(excludePaths, listType);
return removeJsonPath(json, excludedPath, inPlace);
Expand Down Expand Up @@ -784,4 +786,80 @@ public static Map<String, String> constructToolParams(
}
return toolParams;
}

/**
 * Populates {@code params} in place with the prompt/connector configuration required by
 * the given LLM interface: tool-spec templates, JSON paths used to extract tool calls,
 * interaction and chat-history templates, and finish-reason paths.
 * Supported interfaces: "openai/v1/chat/completions", "bedrock/converse/claude",
 * "bedrock/converse/deepseek_r1". A blank or unrecognized interface leaves params untouched.
 *
 * @param llmInterface case-insensitive LLM interface identifier; may be null or blank
 * @param params       mutable parameter map that is enriched in place
 */
public static void constructLLMInterfaceParams(String llmInterface, Map<String, String> params) {
    if (org.apache.commons.lang3.StringUtils.isBlank(llmInterface)) {
        log.debug("no llm interface");
        return;
    }

    if ("openai/v1/chat/completions".equalsIgnoreCase(llmInterface)) {
        constructOpenAIChatCompletionsParams(params);
    } else if (LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE.equalsIgnoreCase(llmInterface)) {
        constructBedrockConverseClaudeParams(params);
    } else if ("bedrock/converse/deepseek_r1".equalsIgnoreCase(llmInterface)) {
        constructBedrockConverseDeepseekR1Params(params);
    }
}

/** Configures function-calling parameters for the OpenAI chat-completions API. */
private static void constructOpenAIChatCompletionsParams(Map<String, String> params) {
    // A caller-supplied no-escape list takes precedence over the default.
    if (!params.containsKey(NO_ESCAPE_PARAMS)) {
        params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
    }
    params.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content");

    params.put("tool_template", "{\"type\": \"function\", \"function\": { \"name\": \"${tool.name}\", \"description\": \"${tool.description}\", \"parameters\": ${tool.attributes.input_schema}, \"strict\": ${tool.attributes.strict:-false} } }");
    params.put("tool_calls_path", "$.choices[0].message.tool_calls");
    params.put("tool_calls.tool_name", "function.name");
    params.put("tool_calls.tool_input", "function.arguments");
    params.put("tool_calls.id_path", "id");

    params.put("tool_choice", "auto");
    params.put("parallel_tool_calls", "false");

    params.put("interaction_template.assistant_tool_calls_path", "$.choices[0].message");
    params.put("interaction_template.tool_response", "{ \"role\": \"tool\", \"tool_call_id\": \"${_interactions.tool_call_id}\", \"content\": \"${_interactions.tool_response}\" }");

    params.put("chat_history_template.user_question", "{\"role\": \"user\",\"content\": \"${_chat_history.message.question}\"}");
    params.put("chat_history_template.ai_response", "{\"role\": \"assistant\",\"content\": \"${_chat_history.message.response}\"}");

    params.put(LLM_FINISH_REASON_PATH, "$.choices[0].finish_reason");
    params.put(LLM_FINISH_REASON_TOOL_USE, "tool_calls");
    // NOTE: the previous version also re-put "llm_response_filter" here with the same
    // value already set via LLM_RESPONSE_FILTER above; the redundant put was removed.
}

/** Configures function-calling parameters for the Bedrock Converse API (Claude models). */
private static void constructBedrockConverseClaudeParams(Map<String, String> params) {
    // tool_configs is appended to the no-escape list so the raw JSON fragment survives.
    if (!params.containsKey(NO_ESCAPE_PARAMS)) {
        params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS + ",tool_configs");
    }
    params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");

    params.put("tool_template", "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}");
    params.put("tool_calls_path", "$.output.message.content[*].toolUse");
    params.put("tool_calls.tool_name", "name");
    params.put("tool_calls.tool_input", "input");
    params.put("tool_calls.id_path", "toolUseId");
    params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");

    params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
    params.put("interaction_template.tool_response", "{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"${_interactions.tool_call_id}\",\"content\":[{\"text\":\"${_interactions.tool_response}\"}]}}]}");

    params.put("chat_history_template.user_question", "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}");
    params.put("chat_history_template.ai_response", "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}");

    params.put(LLM_FINISH_REASON_PATH, "$.stopReason");
    params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
}

/** Configures function-calling parameters for the Bedrock Converse API (DeepSeek-R1). */
private static void constructBedrockConverseDeepseekR1Params(Map<String, String> params) {
    if (!params.containsKey(NO_ESCAPE_PARAMS)) {
        params.put(NO_ESCAPE_PARAMS, "_chat_history,_interactions");
    }
    params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
    params.put("llm_final_response_post_filter", "$.message.content[0].text");

    params.put("tool_template", "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}");
    // DeepSeek-R1 emits tool calls inside the parsed LLM response rather than a JsonPath
    // over the raw payload, hence the "_llm_response." prefix instead of "$.".
    params.put("tool_calls_path", "_llm_response.tool_calls");
    params.put("tool_calls.tool_name", "tool_name");
    params.put("tool_calls.tool_input", "input");
    params.put("tool_calls.id_path", "id");

    params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
    // Strip chain-of-thought (reasoningContent) from the assistant message before replay.
    params.put("interaction_template.assistant_tool_calls_exclude_path", "[ \"$.output.message.content[?(@.reasoningContent)]\" ]");
    params.put("interaction_template.tool_response", "{\"role\":\"user\",\"content\":[ {\"text\":\"{\\\"tool_call_id\\\":\\\"${_interactions.tool_call_id}\\\",\\\"tool_result\\\": \\\"${_interactions.tool_response}\\\"\"} ]}");

    params.put("chat_history_template.user_question", "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}");
    params.put("chat_history_template.ai_response", "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}");

    params.put(LLM_FINISH_REASON_PATH, "_llm_response.stop_reason");
    params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructLLMInterfaceParams;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
Expand Down Expand Up @@ -147,74 +148,7 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
}

String llmInterface = params.get(LLM_INTERFACE);
if ("openai/v1/chat/completions".equalsIgnoreCase(llmInterface)) {
if (!params.containsKey(NO_ESCAPE_PARAMS)) {
params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
}
params.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content");

params.put("tool_template", "{\"type\": \"function\", \"function\": { \"name\": \"${tool.name}\", \"description\": \"${tool.description}\", \"parameters\": ${tool.attributes.input_schema}, \"strict\": ${tool.attributes.strict:-false} } }");
params.put("tool_calls_path", "$.choices[0].message.tool_calls");
params.put("tool_calls.tool_name", "function.name");
params.put("tool_calls.tool_input", "function.arguments");
params.put("tool_calls.id_path", "id");

params.put("tool_choice", "auto");
params.put("parallel_tool_calls", "false");

params.put("interaction_template.assistant_tool_calls_path", "$.choices[0].message");
params.put("interaction_template.tool_response", "{ \"role\": \"tool\", \"tool_call_id\": \"${_interactions.tool_call_id}\", \"content\": \"${_interactions.tool_response}\" }");

params.put("chat_history_template.user_question", "{\"role\": \"user\",\"content\": \"${_chat_history.message.question}\"}");
params.put("chat_history_template.ai_response", "{\"role\": \"assistant\",\"content\": \"${_chat_history.message.response}\"}");

params.put("llm_finish_reason_path", "$.choices[0].finish_reason");
params.put("llm_finish_reason_tool_use", "tool_calls");
params.put("llm_response_filter", "$.choices[0].message.content");
} else if ("bedrock/converse/claude".equalsIgnoreCase(llmInterface)) {
if (!params.containsKey(NO_ESCAPE_PARAMS)) {
params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS + ",tool_configs");
}
params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");

params.put("tool_template", "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}");
params.put("tool_calls_path", "$.output.message.content[*].toolUse");
params.put("tool_calls.tool_name", "name");
params.put("tool_calls.tool_input", "input");
params.put("tool_calls.id_path", "toolUseId");
params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");

params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
params.put("interaction_template.tool_response", "{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"${_interactions.tool_call_id}\",\"content\":[{\"text\":\"${_interactions.tool_response}\"}]}}]}");

params.put("chat_history_template.user_question", "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}");
params.put("chat_history_template.ai_response", "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}");

params.put("llm_finish_reason_path", "$.stopReason");
params.put("llm_finish_reason_tool_use", "tool_use");
} else if ("bedrock/converse/deepseek_r1".equalsIgnoreCase(llmInterface)) {
if (!params.containsKey(NO_ESCAPE_PARAMS)) {
params.put(NO_ESCAPE_PARAMS, "_chat_history,_interactions");
}
params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
params.put("llm_final_response_post_filter", "$.message.content[0].text");

params.put("tool_template", "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}");
params.put("tool_calls_path", "_llm_response.tool_calls");
params.put("tool_calls.tool_name", "tool_name");
params.put("tool_calls.tool_input", "input");
params.put("tool_calls.id_path", "id");

params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
params.put("interaction_template.assistant_tool_calls_exclude_path", "[ \"$.output.message.content[?(@.reasoningContent)]\" ]");
params.put("interaction_template.tool_response", "{\"role\":\"user\",\"content\":[ {\"text\":\"{\\\"tool_call_id\\\":\\\"${_interactions.tool_call_id}\\\",\\\"tool_result\\\": \\\"${_interactions.tool_response}\\\"\"} ]}");

params.put("chat_history_template.user_question", "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}");
params.put("chat_history_template.ai_response", "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}");

params.put("llm_finish_reason_path", "_llm_response.stop_reason");
params.put("llm_finish_reason_tool_use", "tool_use");
}
constructLLMInterfaceParams(llmInterface, params);
String memoryType = mlAgent.getMemory().getType();
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
String appType = mlAgent.getAppType();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package org.opensearch.ml.engine.function_calling;

import com.jayway.jsonpath.JsonPath;
import lombok.Data;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.utils.StringUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.removeJsonPath;

/**
 * {@link FunctionCalling} implementation for the Amazon Bedrock Converse API
 * (Claude-style responses): configures tool-spec templates, extracts tool calls
 * from a Converse response, and renders tool results back into an LLM message.
 */
public class BedrockConverseFunctionCalling implements FunctionCalling {
    // JsonPath expressions / keys in the Bedrock Converse response schema.
    private static final String FINISH_REASON_PATH = "$.stopReason";
    private static final String FINISH_REASON = "tool_use";
    private static final String CALL_PATH = "$.output.message.content[*].toolUse";
    private static final String NAME = "name";
    private static final String INPUT = "input";
    private static final String ID_PATH = "toolUseId";
    // Keys expected in the tool-result maps passed to supply().
    private static final String TOOL_ERROR = "tool_error";
    private static final String TOOL_RESULT = "tool_result";

    @Override
    public void configure(Map<String, String> params) {
        // Templates used by the connector to serialize tool specs into a Converse request.
        params.put("tool_template", "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}");
        params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");
    }

    @Override
    public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput, Map<String, String> parameters) {
        List<Map<String, String>> output = new ArrayList<>();
        Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
        // Optionally prune configured paths (e.g. reasoning content) before extraction.
        String llmResponseExcludePath = parameters.get(LLM_RESPONSE_EXCLUDE_PATH);
        if (llmResponseExcludePath != null) {
            dataAsMap = removeJsonPath(dataAsMap, llmResponseExcludePath, true);
        }
        // NOTE(review): the extracted response text is unused; retained so a missing
        // response-filter path still fails fast — confirm whether it should be surfaced.
        Object response = JsonPath.read(dataAsMap, parameters.get(LLM_RESPONSE_FILTER));
        String llmFinishReason = JsonPath.read(dataAsMap, FINISH_REASON_PATH);
        if (!llmFinishReason.contentEquals(FINISH_REASON)) {
            // The model finished without requesting a tool — nothing to extract.
            return output;
        }
        List<?> toolCalls = JsonPath.read(dataAsMap, CALL_PATH);
        if (CollectionUtils.isEmpty(toolCalls)) {
            return output;
        }
        for (Object call : toolCalls) {
            // Bug fix: previously these used parameters.get(NAME)/get(INPUT)/get(ID_PATH),
            // but the parameter map holds keys like "tool_calls.tool_name" — never "name" —
            // so the lookup returned null and JsonPath.read failed. The schema constants
            // themselves are the correct paths into each toolUse element.
            String toolName = JsonPath.read(call, NAME);
            String toolInput = StringUtils.toJson(JsonPath.read(call, INPUT));
            String toolCallId = JsonPath.read(call, ID_PATH);
            output.add(Map.of("tool_name", toolName, "tool_input", toolInput, "tool_call_id", toolCallId));
        }
        return output;
    }

    @Override
    public LLMMessage supply(List<Map<String, String>> toolResults) {
        LLMMessage toolMessage = new LLMMessage();
        for (Map<String, String> toolResult : toolResults) {
            // NOTE(review): handle() emits results keyed by "tool_call_id" while this reads
            // "toolUseId" — confirm which key the agent runner actually passes here.
            String toolUseId = toolResult.get(ID_PATH);
            if (toolUseId == null) {
                continue;
            }
            ToolResult result = new ToolResult();
            result.setToolUseId(toolUseId);
            result.getContent().add(Map.of("text", toolResult.get(TOOL_RESULT)));
            if (toolResult.containsKey(TOOL_ERROR)) {
                result.setStatus("error");
            }
            toolMessage.getContent().add(result);
        }

        // Bug fix: previously returned null, silently discarding the assembled message.
        return toolMessage;
    }

    /** Bedrock Converse "toolResult" payload: id, content blocks, and optional error status. */
    @Data
    public static class ToolResult {
        private String toolUseId;
        private List<Map<String, Object>> content = new ArrayList<>();
        private String status;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opensearch.ml.engine.function_calling;

import org.opensearch.ml.common.output.model.ModelTensorOutput;

import java.util.List;
import java.util.Map;

/**
 * A general interface for LLM function calling (tool use): configuring the request,
 * extracting tool calls from the model response, and feeding tool results back to the LLM.
 */
public interface FunctionCalling {

    /**
     * Configures all parameters related to function calling (e.g. tool templates)
     * on the request parameter map, in place.
     * @param params the parameters used to configure a request to the LLM
     */
    void configure(Map<String, String> params);

    /**
     * Parses the LLM response and extracts the requested tool calls.
     * @param modelTensorOutput the raw response from the LLM
     * @param parameters request parameters (e.g. response filter / exclude paths)
     * @return a list of tool-call descriptors (tool name, input, call id, etc.);
     *         empty when the model did not request any tool
     */
    List<Map<String, String>> handle(ModelTensorOutput modelTensorOutput, Map<String, String> parameters);

    /**
     * Renders the results of executed tools into an {@link LLMMessage} to send back to the LLM.
     * @param toolResults results produced by the tools
     * @return an LLMMessage containing the tool results
     */
    LLMMessage supply(List<Map<String, String>> toolResults);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.opensearch.ml.engine.function_calling;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.ml.common.exception.MLException;

import java.util.Locale;

import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE;

/**
 * Factory that maps an LLM interface identifier to its {@link FunctionCalling} implementation.
 */
public class FunctionCallingFactory {
    /**
     * Creates the {@link FunctionCalling} implementation for the given LLM interface.
     *
     * @param llmInterface the LLM interface identifier (case-insensitive, trimmed)
     * @return the matching implementation, or {@code null} when llmInterface is blank
     * @throws MLException when the interface is not supported
     */
    public FunctionCalling create(String llmInterface) {
        if (StringUtils.isBlank(llmInterface)) {
            return null;
        }

        switch (llmInterface.trim().toLowerCase(Locale.ROOT)) {
            case LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE:
                return new BedrockConverseFunctionCalling();
            default:
                // Bug fix: "%s" is the String.format placeholder; the previous "{}" is
                // SLF4J-only and would have printed literally, dropping the interface name.
                throw new MLException(String.format("Unsupported llm interface: %s.", llmInterface));
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.opensearch.ml.engine.function_calling;

import lombok.Data;

import java.util.ArrayList;
import java.util.List;

@Data
public class LLMMessage {

    private String role;
    // Defaults to an empty, mutable content list; never null.
    private List<Object> content = new ArrayList<>();

    /** Creates a "user" message with an empty content list. */
    LLMMessage() {
        this("user", null);
    }

    /** Creates a message with the given role and an empty content list. */
    LLMMessage(String role) {
        this(role, null);
    }

    /**
     * Creates a message with the given role; a null content keeps the
     * default empty list instead of overwriting it.
     */
    LLMMessage(String role, List<Object> content) {
        this.role = role;
        if (content != null) {
            this.content = content;
        }
    }
}
Loading