Skip to content

Commit a58facf

Browse files
ylwu-amzn authored and github-actions[bot] committed
enhance parsing model response function for more edge cases (#2122)
* enhance parsing model response function for more edge cases
  Signed-off-by: Yaliang Wu <[email protected]>
* add more unit test
  Signed-off-by: Yaliang Wu <[email protected]>
* fine tune code; fix some bug
  Signed-off-by: Yaliang Wu <[email protected]>
* add more unit test
  Signed-off-by: Yaliang Wu <[email protected]>
* fix tool name bug
  Signed-off-by: Yaliang Wu <[email protected]>
---------
Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit 311b971)
1 parent 70f1b8d commit a58facf

File tree

5 files changed

+473
-120
lines changed

5 files changed

+473
-120
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

+169-11
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,20 @@
55

66
package org.opensearch.ml.engine.algorithms.agent;
77

8+
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
89
import static org.opensearch.ml.common.utils.StringUtils.gson;
10+
import static org.opensearch.ml.common.utils.StringUtils.isJson;
11+
import static org.opensearch.ml.common.utils.StringUtils.toJson;
912
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MESSAGE_HISTORY_LIMIT;
13+
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
14+
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
1015
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
1116
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
1217
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
18+
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER;
1319
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
20+
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT;
21+
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE;
1422
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
1523
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
1624
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
@@ -19,10 +27,13 @@
1927
import java.security.PrivilegedActionException;
2028
import java.security.PrivilegedExceptionAction;
2129
import java.util.ArrayList;
30+
import java.util.Collection;
2231
import java.util.HashMap;
2332
import java.util.List;
33+
import java.util.Locale;
2434
import java.util.Map;
2535
import java.util.Optional;
36+
import java.util.Set;
2637
import java.util.regex.Matcher;
2738
import java.util.regex.Pattern;
2839

@@ -33,7 +44,11 @@
3344
import org.opensearch.ml.common.output.model.ModelTensor;
3445
import org.opensearch.ml.common.output.model.ModelTensorOutput;
3546
import org.opensearch.ml.common.spi.tools.Tool;
47+
import org.opensearch.ml.common.utils.StringUtils;
3648

49+
import lombok.extern.log4j.Log4j2;
50+
51+
@Log4j2
3752
public class AgentUtils {
3853

3954
public static final String SELECTED_TOOLS = "selected_tools";
@@ -167,23 +182,166 @@ public static String extractModelResponseJson(String text) {
167182
return extractModelResponseJson(text, null);
168183
}
169184

170-
public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
171-
Pattern jsonBlockPattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
172-
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);
173-
174-
if (jsonBlockMatcher.find()) {
175-
return jsonBlockMatcher.group(1);
185+
public static Map<String, String> parseLLMOutput(
186+
ModelTensorOutput tmpModelTensorOutput,
187+
List<String> llmResponsePatterns,
188+
Set<String> inputTools
189+
) {
190+
Map<String, String> modelOutput = new HashMap<>();
191+
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
192+
if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
193+
String llmReasoningResponse = (String) dataAsMap.get("response");
194+
String thoughtResponse = null;
195+
try {
196+
thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns);
197+
modelOutput.put(THOUGHT_RESPONSE, thoughtResponse);
198+
} catch (IllegalArgumentException e) {
199+
modelOutput.put(THOUGHT_RESPONSE, llmReasoningResponse);
200+
thoughtResponse = llmReasoningResponse;
201+
}
202+
parseThoughtResponse(modelOutput, thoughtResponse);
176203
} else {
177-
String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
178-
if (matchedPart == null && llmResponsePatterns != null) {
179-
// If no match is found, try additional patterns if provided
180-
matchedPart = findMatchedPart(text, llmResponsePatterns);
204+
extractParams(modelOutput, dataAsMap, THOUGHT);
205+
extractParams(modelOutput, dataAsMap, ACTION);
206+
extractParams(modelOutput, dataAsMap, ACTION_INPUT);
207+
extractParams(modelOutput, dataAsMap, FINAL_ANSWER);
208+
try {
209+
modelOutput.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap));
210+
} catch (Exception e) {
211+
log.warn("Failed to parse model response", e);
212+
}
213+
}
214+
String action = modelOutput.get(ACTION);
215+
if (action != null) {
216+
String matchedTool = getMatchedTool(inputTools, action);
217+
if (matchedTool != null) {
218+
modelOutput.put(ACTION, matchedTool);
219+
} else {
220+
modelOutput.remove(ACTION);
221+
}
222+
}
223+
if (!modelOutput.containsKey(ACTION) && !modelOutput.containsKey(FINAL_ANSWER)) {
224+
modelOutput.put(FINAL_ANSWER, modelOutput.get(THOUGHT_RESPONSE));
225+
}
226+
return modelOutput;
227+
}
228+
229+
public static String getMatchedTool(Collection<String> tools, String action) {
230+
for (String tool : tools) {
231+
if (action.toLowerCase(Locale.ROOT).contains(tool.toLowerCase(Locale.ROOT))) {
232+
return tool;
181233
}
234+
}
235+
return null;
236+
}
237+
238+
public static void extractParams(Map<String, String> modelOutput, Map<String, ?> dataAsMap, String paramName) {
239+
if (dataAsMap.containsKey(paramName)) {
240+
modelOutput.put(paramName, toJson(dataAsMap.get(paramName)));
241+
}
242+
}
243+
244+
public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
245+
if (text.contains("```json")) {
246+
text = text.substring(text.indexOf("```json") + "```json".length());
247+
if (text.contains("```")) {
248+
text = text.substring(0, text.lastIndexOf("```"));
249+
}
250+
}
251+
text = text.trim();
252+
if (isJson(text)) {
253+
return text;
254+
}
255+
String matchedPart = null;
256+
if (llmResponsePatterns != null) {
257+
matchedPart = findMatchedPart(text, llmResponsePatterns);
182258
if (matchedPart != null) {
183259
return matchedPart;
184260
}
185-
throw new IllegalArgumentException("Model output is invalid");
186261
}
262+
matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
263+
if (matchedPart != null) {
264+
return matchedPart;
265+
}
266+
throw new IllegalArgumentException("Model output is invalid");
267+
}
268+
269+
public static void parseThoughtResponse(Map<String, String> modelOutput, String thoughtResponse) {
270+
if (thoughtResponse != null) {
271+
if (isJson(thoughtResponse)) {
272+
modelOutput.putAll(getParameterMap(gson.fromJson(thoughtResponse, Map.class)));
273+
} else {// sometimes LLM return invalid json response
274+
String thought = extractThought(thoughtResponse);
275+
String action = extractAction(thoughtResponse);
276+
String actionInput = extractActionInput(thoughtResponse);
277+
String finalAnswer = extractFinalAnswer(thoughtResponse);
278+
if (thought != null) {
279+
modelOutput.put(THOUGHT, thought);
280+
}
281+
if (action != null) {
282+
modelOutput.put(ACTION, action);
283+
}
284+
if (actionInput != null) {
285+
modelOutput.put(ACTION_INPUT, actionInput);
286+
}
287+
if (finalAnswer != null) {
288+
modelOutput.put(FINAL_ANSWER, finalAnswer);
289+
}
290+
}
291+
}
292+
}
293+
294+
public static String extractFinalAnswer(String text) {
295+
String result = null;
296+
if (text.contains("\"final_answer\"")) {
297+
String pattern = "\"final_answer\"\\s*:\\s*\"(.*?)$";
298+
Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL);
299+
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);
300+
if (jsonBlockMatcher.find()) {
301+
result = jsonBlockMatcher.group(1);
302+
}
303+
}
304+
return result;
305+
}
306+
307+
public static String extractThought(String text) {
308+
String result = null;
309+
if (text.contains("\"thought\"")) {
310+
String pattern = "\"thought\"\\s*:\\s*\"(.*?)\"\\s*,\\s*[\"final_answer\"|\"action\"]";
311+
Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL);
312+
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);
313+
if (jsonBlockMatcher.find()) {
314+
result = jsonBlockMatcher.group(1);
315+
}
316+
}
317+
return result;
318+
}
319+
320+
public static String extractAction(String text) {
321+
String result = null;
322+
if (text.contains("\"action\"")) {
323+
String pattern = "\"action\"\\s*:\\s*\"(.*?)(?:\"action_input\"|$)";
324+
Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL);
325+
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);
326+
if (jsonBlockMatcher.find()) {
327+
result = jsonBlockMatcher.group(1);
328+
}
329+
}
330+
return result;
331+
}
332+
333+
public static String extractActionInput(String text) {
334+
String result = null;
335+
if (text.contains("\"action_input\"")) {
336+
String pattern = "\"action_input\"\\s*:\\s*\"((?:[^\\\"]|\\\")*)\"";
337+
Pattern jsonBlockPattern = Pattern.compile(pattern, Pattern.DOTALL); // Add Pattern.DOTALL to match across newlines
338+
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);
339+
if (jsonBlockMatcher.find()) {
340+
result = jsonBlockMatcher.group(1);
341+
result = result.replace("\\\"", "\"");
342+
}
343+
}
344+
return result;
187345
}
188346

189347
public static String findMatchedPart(String text, List<String> patternList) {

0 commit comments

Comments
 (0)