make memory optional in conversational agent #3626

Open: wants to merge 1 commit into main
@@ -36,6 +36,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

+import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.StepListener;
@@ -121,12 +122,17 @@ public MLChatAgentRunner(

@Override
public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
-String memoryType = mlAgent.getMemory().getType();
+String memoryType = mlAgent.getMemory() == null ? null : mlAgent.getMemory().getType();
Contributor: does the current behavior throw an NPE here? Wondering if there is some validation in the REST/Transport layer that checks for memory; I didn't see any, though.
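
For illustration, a minimal sketch of the failure mode being asked about; the agent setup here is hypothetical and mirrors the createMLAgentNoMemory() helper added in the tests below:

```java
// Hypothetical: an agent registered without a memory spec.
MLAgent agent = MLAgent
    .builder()
    .name("TestAgent")
    .type(MLAgentType.CONVERSATIONAL.name())
    .llm(LLMSpec.builder().modelId("MODEL_ID").build())
    .build(); // getMemory() returns null

// Before this change: throws NullPointerException when memory is absent.
// String memoryType = agent.getMemory().getType();

// With this change: memoryType stays null and run() falls through to runAgent(...).
String memoryType = agent.getMemory() == null ? null : agent.getMemory().getType();
```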

String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
String appType = mlAgent.getAppType();
String title = params.get(MLAgentExecutor.QUESTION);
int messageHistoryLimit = getMessageHistoryLimit(params);

+if (StringUtils.isEmpty(memoryType)) {
+    runAgent(mlAgent, params, listener, null, null);
+    return;
+}

ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> {
// TODO: call runAgent directly if messageHistoryLimit == 0
@@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
);
}

-StringBuilder chatHistoryBuilder = new StringBuilder();
 if (!messageList.isEmpty()) {
+    StringBuilder chatHistoryBuilder = new StringBuilder();
String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX);
chatHistoryBuilder.append(chatHistoryPrefix);
for (Message message : messageList) {
@@ -220,7 +226,9 @@ private void runReAct(
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
tmpParameters.put(PROMPT, newPrompt.get());

-List<ModelTensors> traceTensors = createModelTensors(sessionId, parentInteractionId);
+List<ModelTensors> traceTensors = (conversationIndexMemory == null)
+    ? new ArrayList<>()
+    : createModelTensors(sessionId, parentInteractionId);
Comment on lines +229 to +231 (@pyek-bot, Contributor, Mar 17, 2025):

what about passing conversationIndexMemory to this method and returning an empty list rather than using a ternary operator here?

*ignore, seems like it will affect a lot of other methods using this

int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2;
for (int i = 0; i < maxIterations; i++) {
int finalI = i;
@@ -401,8 +409,8 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

-private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> sessionId, List<ModelTensor> lastThought) {
-    List<ModelTensors> finalModelTensors = sessionId;
+private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> modelTensorsList, List<ModelTensor> lastThought) {
+    List<ModelTensors> finalModelTensors = modelTensorsList;
finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build());
return finalModelTensors;
}
@@ -572,19 +580,21 @@ private void sendFinalAnswer(
private static List<ModelTensors> createModelTensors(String sessionId, String parentInteractionId) {
List<ModelTensors> cotModelTensors = new ArrayList<>();

-cotModelTensors
-    .add(
-        ModelTensors
-            .builder()
-            .mlModelTensors(
-                List
-                    .of(
-                        ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
-                        ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
-                    )
-            )
-            .build()
-    );
+if (!StringUtils.isEmpty(sessionId)) {

Contributor: let's invert this? if (StringUtils.isEmpty()) { return emptyList; } (sketched after the method below)

+    cotModelTensors
+        .add(
+            ModelTensors
+                .builder()
+                .mlModelTensors(
+                    List
+                        .of(
+                            ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
+                            ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
+                        )
+                )
+                .build()
+        );
+}
return cotModelTensors;
}
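
A minimal sketch of that suggested inversion, assuming it should behave exactly like the guarded block added above:

```java
private static List<ModelTensors> createModelTensors(String sessionId, String parentInteractionId) {
    List<ModelTensors> cotModelTensors = new ArrayList<>();
    // Early return: without a session there are no IDs worth recording.
    if (StringUtils.isEmpty(sessionId)) {
        return cotModelTensors;
    }
    cotModelTensors
        .add(
            ModelTensors
                .builder()
                .mlModelTensors(
                    List
                        .of(
                            ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
                            ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
                        )
                )
                .build()
        );
    return cotModelTensors;
}
```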

@@ -879,6 +879,42 @@ public void testToolExecutionWithChatHistoryParameter() {
Assert.assertTrue(toolParamsCapture.getValue().containsKey(MLChatAgentRunner.CHAT_HISTORY));
}

+@Test
+public void testParsingJsonBlockFromResponseNoMemory() {
+    // Prepare the response with JSON block
+    String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
+        + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
+    String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";
+
+    // Mock LLM response to not contain "thought" but contain "response" with JSON block
+    Map<String, String> llmResponse = new HashMap<>();
+    llmResponse.put("response", responseWithJsonBlock);
+    doAnswer(getLLMAnswer(llmResponse))
+        .when(client)
+        .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
+
+    // Create an MLAgent without memory and run the MLChatAgentRunner
+    MLAgent mlAgent = createMLAgentNoMemory();
+    Map<String, String> params = new HashMap<>();
+    params.put("verbose", "true");
+    mlChatAgentRunner.run(mlAgent, params, agentActionListener);
+
+    // Capture the response passed to the listener
+    ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
+    verify(agentActionListener).onResponse(responseCaptor.capture());
+
+    // Extract the captured response
+    Object capturedResponse = responseCaptor.getValue();
+    assertTrue(capturedResponse instanceof ModelTensorOutput);
+    ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
+
+    ModelTensor modelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0);
+
+    assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
+    assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
+    assertEquals("parsed final answer", modelTensor.getResult());
+}

// Helper methods to create MLAgent and parameters
private MLAgent createMLAgentWithTools() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
@@ -917,6 +953,23 @@ private MLAgent createMLAgentWithToolsConfig(Map<String, String> configMap) {
.build();
}

+private MLAgent createMLAgentNoMemory() {
+    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
+    MLToolSpec firstToolSpec = MLToolSpec
+        .builder()
+        .name(FIRST_TOOL)
+        .type(FIRST_TOOL)
+        .parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
+        .build();
+    return MLAgent
+        .builder()
+        .name("TestAgent")
+        .type(MLAgentType.CONVERSATIONAL.name())
+        .tools(Arrays.asList(firstToolSpec))
+        .llm(llmSpec)
+        .build();
+}

private Map<String, String> createAgentParamsWithAction(String action, String actionInput) {
Map<String, String> params = new HashMap<>();
params.put("action", action);