diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 71fe662628..5de35b1962 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -168,8 +168,7 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener if (functionCalling != null) { functionCalling.configure(params); } - - String memoryType = mlAgent.getMemory().getType(); + String memoryType = mlAgent.getMemory() == null ? null : mlAgent.getMemory().getType(); String memoryId = params.get(MLAgentExecutor.MEMORY_ID); String appType = mlAgent.getAppType(); String title = params.get(MLAgentExecutor.QUESTION); @@ -178,6 +177,11 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); + if (memoryType == null || memoryType.isBlank()) { + runAgent(mlAgent, params, listener, null, null, functionCalling); + return; + } + ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { // TODO: call runAgent directly if messageHistoryLimit == 0 @@ -317,7 +321,9 @@ private void runReAct( AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); - List traceTensors = createModelTensors(sessionId, parentInteractionId); + List traceTensors = (conversationIndexMemory == null) + ? new ArrayList<>() + : createModelTensors(sessionId, parentInteractionId); int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, DEFAULT_MAX_ITERATIONS)); for (int i = 0; i < maxIterations; i++) { int finalI = i; @@ -513,8 +519,8 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } - private static List createFinalAnswerTensors(List sessionId, List lastThought) { - List finalModelTensors = sessionId; + private static List createFinalAnswerTensors(List modelTensorsList, List lastThought) { + List finalModelTensors = modelTensorsList; finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build()); return finalModelTensors; } @@ -719,19 +725,21 @@ private void sendFinalAnswer( public static List createModelTensors(String sessionId, String parentInteractionId) { List cotModelTensors = new ArrayList<>(); - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() - ) - ) - .build() - ); + if (sessionId != null && !sessionId.isBlank()) { + cotModelTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + List + .of( + ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), + ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() + ) + ) + .build() + ); + } return cotModelTensors; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 037e9e1cb4..ddda98bb0a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -880,6 +880,42 @@ public void testToolExecutionWithChatHistoryParameter() { Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY)); } + @Test + public void testParsingJsonBlockFromResponseNoMemory() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentNoMemory(); + Map params = new HashMap<>(); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor modelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0); + + assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); + assertEquals("parsed final answer", modelTensor.getResult()); + } + // Helper methods to create MLAgent and parameters private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -918,6 +954,23 @@ private MLAgent createMLAgentWithToolsConfig(Map configMap) { .build(); } + private MLAgent createMLAgentNoMemory() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .build(); + return MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .tools(Arrays.asList(firstToolSpec)) + .llm(llmSpec) + .build(); + } + private Map createAgentParamsWithAction(String action, String actionInput) { Map params = new HashMap<>(); params.put("action", action);