diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 3992b9f341..5809101be4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -36,6 +36,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; import org.opensearch.action.StepListener; @@ -121,12 +122,17 @@ public MLChatAgentRunner( @Override public void run(MLAgent mlAgent, Map params, ActionListener listener) { - String memoryType = mlAgent.getMemory().getType(); + String memoryType = mlAgent.getMemory() == null ? null : mlAgent.getMemory().getType(); String memoryId = params.get(MLAgentExecutor.MEMORY_ID); String appType = mlAgent.getAppType(); String title = params.get(MLAgentExecutor.QUESTION); int messageHistoryLimit = getMessageHistoryLimit(params); + if (StringUtils.isEmpty(memoryType)) { + runAgent(mlAgent, params, listener, null, null); + return; + } + ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { // TODO: call runAgent directly if messageHistoryLimit == 0 @@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map params, ActionListener newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); - List traceTensors = createModelTensors(sessionId, parentInteractionId); + List traceTensors = (conversationIndexMemory == null) + ? new ArrayList<>() + : createModelTensors(sessionId, parentInteractionId); int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2; for (int i = 0; i < maxIterations; i++) { int finalI = i; @@ -401,8 +409,8 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } - private static List createFinalAnswerTensors(List sessionId, List lastThought) { - List finalModelTensors = sessionId; + private static List createFinalAnswerTensors(List modelTensorsList, List lastThought) { + List finalModelTensors = modelTensorsList; finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build()); return finalModelTensors; } @@ -572,19 +580,21 @@ private void sendFinalAnswer( private static List createModelTensors(String sessionId, String parentInteractionId) { List cotModelTensors = new ArrayList<>(); - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() - ) - ) - .build() - ); + if (!StringUtils.isEmpty(sessionId)) { + cotModelTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + List + .of( + ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), + ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() + ) + ) + .build() + ); + } return cotModelTensors; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 0fb416f0bf..9f9d881019 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -879,6 +879,42 @@ public void testToolExecutionWithChatHistoryParameter() { Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY)); } + @Test + public void testParsingJsonBlockFromResponseNoMemory() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentNoMemory(); + Map params = new HashMap<>(); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor modelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0); + + assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); + assertEquals("parsed final answer", modelTensor.getResult()); + } + // Helper methods to create MLAgent and parameters private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -917,6 +953,23 @@ private MLAgent createMLAgentWithToolsConfig(Map configMap) { .build(); } + private MLAgent createMLAgentNoMemory() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .build(); + return MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .tools(Arrays.asList(firstToolSpec)) + .llm(llmSpec) + .build(); + } + private Map createAgentParamsWithAction(String action, String actionInput) { Map params = new HashMap<>(); params.put("action", action);