Skip to content

Commit e7a045d

Browse files
committed
extending memory refactoring to PER agent
Signed-off-by: Dhrubo Saha <[email protected]>
1 parent 94d97e7 commit e7a045d

File tree

11 files changed

+91
-55
lines changed

11 files changed

+91
-55
lines changed

common/src/main/java/org/opensearch/ml/common/MLAgentType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static MLAgentType from(String value) {
2020
try {
2121
return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT));
2222
} catch (Exception e) {
23-
throw new IllegalArgumentException("Wrong Agent type");
23+
throw new IllegalArgumentException(value + " is not a valid Agent Type");
2424
}
2525
}
2626
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common;
7+
8+
import java.util.Locale;
9+
10+
public enum MLMemoryType {
11+
CONVERSATION_INDEX,
12+
AGENTIC_MEMORY;
13+
14+
public static MLMemoryType from(String value) {
15+
if (value != null) {
16+
try {
17+
return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT));
18+
} catch (Exception e) {
19+
throw new IllegalArgumentException("Wrong Memory type");
20+
}
21+
}
22+
return null;
23+
}
24+
}

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ private void validate() {
113113
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
114114
);
115115
}
116-
validateMLAgentType(type);
116+
MLAgentType.from(type);
117117
if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) {
118118
throw new IllegalArgumentException("We need model information for the conversational agent type");
119119
}

common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.opensearch.core.xcontent.ToXContentObject;
2727
import org.opensearch.core.xcontent.XContentBuilder;
2828
import org.opensearch.core.xcontent.XContentParser;
29+
import org.opensearch.ml.common.MLMemoryType;
2930
import org.opensearch.ml.common.agent.LLMSpec;
3031
import org.opensearch.ml.common.agent.MLAgent;
3132
import org.opensearch.ml.common.agent.MLMemorySpec;
@@ -383,9 +384,7 @@ private void validate() {
383384
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
384385
);
385386
}
386-
if (memoryType != null && !memoryType.equals("conversation_index")) {
387-
throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType));
388-
}
387+
MLMemoryType.from(memoryType);
389388
if (tools != null) {
390389
Set<String> toolNames = new HashSet<>();
391390
for (MLToolSpec toolSpec : tools) {

common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ public void testFromWithMixedCase() {
4444
public void testFromWithInvalidType() {
4545
// This should throw an IllegalArgumentException
4646
exceptionRule.expect(IllegalArgumentException.class);
47-
exceptionRule.expectMessage("Wrong Agent type");
47+
exceptionRule.expectMessage(" is not a valid Agent Type");
4848
MLAgentType.from("INVALID_TYPE");
4949
}
5050

5151
@Test
5252
public void testFromWithEmptyString() {
5353
exceptionRule.expect(IllegalArgumentException.class);
54-
exceptionRule.expectMessage("Wrong Agent type");
54+
exceptionRule.expectMessage(" is not a valid Agent Type");
5555
// This should also throw an IllegalArgumentException
5656
MLAgentType.from("");
5757
}

common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public void testValidationWithInvalidMemoryType() {
9494
IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> {
9595
MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memoryType("invalid_type").build();
9696
});
97-
assertEquals("Invalid memory type: invalid_type", e.getMessage());
97+
assertEquals("Wrong Memory type", e.getMessage());
9898
}
9999

100100
@Test

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
1818
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
1919
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams;
20-
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE;
2120

2221
import java.security.AccessController;
2322
import java.security.PrivilegedActionException;
@@ -48,6 +47,7 @@
4847
import org.opensearch.index.IndexNotFoundException;
4948
import org.opensearch.ml.common.FunctionName;
5049
import org.opensearch.ml.common.MLAgentType;
50+
import org.opensearch.ml.common.MLMemoryType;
5151
import org.opensearch.ml.common.MLTask;
5252
import org.opensearch.ml.common.MLTaskState;
5353
import org.opensearch.ml.common.MLTaskType;
@@ -245,9 +245,10 @@ public void execute(Input input, ActionListener<Output> listener, TransportChann
245245
}
246246
if (memorySpec != null
247247
&& memorySpec.getType() != null
248-
&& memoryFactoryMap.containsKey(memorySpec.getType())
248+
&& memoryFactoryMap.containsKey(MLMemoryType.from(memorySpec.getType()).name())
249249
&& (memoryId == null || parentInteractionId == null)) {
250-
Memory.Factory<Memory<?, ?, ?>> memoryFactory = memoryFactoryMap.get(memorySpec.getType());
250+
Memory.Factory<Memory<?, ?, ?>> memoryFactory = memoryFactoryMap
251+
.get(MLMemoryType.from(memorySpec.getType()).name());
251252

252253
Map<String, Object> memoryParams = createMemoryParams(question, memoryId, appType, mlAgent);
253254
memoryFactory.create(memoryParams, ActionListener.wrap(memory -> {
@@ -299,14 +300,24 @@ public void execute(Input input, ActionListener<Output> listener, TransportChann
299300
} else {
300301
// For existing conversations, create memory instance using factory
301302
if (memorySpec != null && memorySpec.getType() != null) {
303+
String memoryType = MLMemoryType.from(memorySpec.getType()).name();
304+
Memory.Factory<Memory<?, ?, ?>> memoryFactory = memoryFactoryMap.get(memoryType);
305+
302306
ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap
303307
.get(memorySpec.getType());
304-
if (factory != null) {
308+
if (memoryFactory != null) {
305309
// memoryId exists, so create returns an object with existing memory, therefore name can
306310
// be null
307-
factory
311+
Map<String, Object> memoryParams = createMemoryParams(
312+
question,
313+
memoryId,
314+
appType,
315+
mlAgent
316+
);
317+
318+
memoryFactory
308319
.create(
309-
Map.of(MEMORY_ID, memoryId, APP_TYPE, appType),
320+
memoryParams,
310321
ActionListener
311322
.wrap(
312323
createdMemory -> executeAgent(

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.opensearch.core.action.ActionListener;
5959
import org.opensearch.core.common.Strings;
6060
import org.opensearch.core.xcontent.NamedXContentRegistry;
61+
import org.opensearch.ml.common.MLMemoryType;
6162
import org.opensearch.ml.common.agent.LLMSpec;
6263
import org.opensearch.ml.common.agent.MLAgent;
6364
import org.opensearch.ml.common.agent.MLToolSpec;
@@ -177,7 +178,7 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
177178
functionCalling.configure(params);
178179
}
179180

180-
String memoryType = mlAgent.getMemory().getType();
181+
String memoryType = MLMemoryType.from(mlAgent.getMemory().getType()).name();
181182
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
182183
String appType = mlAgent.getAppType();
183184
String title = params.get(MLAgentExecutor.QUESTION);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS;
1717
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
1818
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource;
19+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams;
1920
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
2021
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime;
2122
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs;
@@ -32,9 +33,6 @@
3233
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.FINAL_RESULT_RESPONSE_INSTRUCTIONS;
3334
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_RESPONSIBILITY;
3435
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT;
35-
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE;
36-
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID;
37-
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME;
3836

3937
import java.util.ArrayList;
4038
import java.util.HashMap;
@@ -51,6 +49,7 @@
5149
import org.opensearch.core.action.ActionListener;
5250
import org.opensearch.core.xcontent.NamedXContentRegistry;
5351
import org.opensearch.ml.common.FunctionName;
52+
import org.opensearch.ml.common.MLMemoryType;
5453
import org.opensearch.ml.common.MLTaskState;
5554
import org.opensearch.ml.common.agent.LLMSpec;
5655
import org.opensearch.ml.common.agent.MLAgent;
@@ -285,42 +284,42 @@ public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<O
285284
usePlannerPromptTemplate(allParams);
286285

287286
String memoryId = allParams.get(MEMORY_ID_FIELD);
288-
String memoryType = mlAgent.getMemory().getType();
287+
String memoryType = MLMemoryType.from(mlAgent.getMemory().getType()).name();
289288
String appType = mlAgent.getAppType();
290289
int messageHistoryLimit = Integer.parseInt(allParams.getOrDefault(PLANNER_MESSAGE_HISTORY_LIMIT, DEFAULT_MESSAGE_HISTORY_LIMIT));
291290

292291
// todo: use chat history instead of completed steps
293-
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
294-
conversationIndexMemoryFactory
295-
.create(
296-
Map.of(MEMORY_ID, memoryId, MEMORY_NAME, apiParams.get(USER_PROMPT_FIELD), APP_TYPE, appType),
297-
ActionListener.<ConversationIndexMemory>wrap(memory -> {
298-
memory.getMessages(messageHistoryLimit, ActionListener.<List<Interaction>>wrap(interactions -> {
299-
List<String> completedSteps = new ArrayList<>();
300-
for (Interaction interaction : interactions) {
301-
String question = interaction.getInput();
302-
String response = interaction.getResponse();
303-
304-
if (Strings.isNullOrEmpty(response)) {
305-
continue;
306-
}
307-
308-
completedSteps.add(question);
309-
completedSteps.add(response);
310-
}
292+
// ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory)
293+
// memoryFactoryMap.get(memoryType);
294+
295+
Memory.Factory<Memory<Interaction, ?, ?>> memoryFactory = memoryFactoryMap.get(memoryType);
296+
Map<String, Object> memoryParams = createMemoryParams(apiParams.get(USER_PROMPT_FIELD), memoryId, appType, mlAgent);
297+
memoryFactory.create(memoryParams, ActionListener.wrap(memory -> {
298+
memory.getMessages(messageHistoryLimit, ActionListener.<List<Interaction>>wrap(interactions -> {
299+
List<String> completedSteps = new ArrayList<>();
300+
for (Interaction interaction : interactions) {
301+
String question = interaction.getInput();
302+
String response = interaction.getResponse();
303+
304+
if (Strings.isNullOrEmpty(response)) {
305+
continue;
306+
}
311307

312-
if (!completedSteps.isEmpty()) {
313-
addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD);
314-
usePlannerWithHistoryPromptTemplate(allParams);
315-
}
308+
completedSteps.add(question);
309+
completedSteps.add(response);
310+
}
316311

317-
setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getConversationId(), listener);
318-
}, e -> {
319-
log.error("Failed to get chat history", e);
320-
listener.onFailure(e);
321-
}));
322-
}, listener::onFailure)
323-
);
312+
if (!completedSteps.isEmpty()) {
313+
addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD);
314+
usePlannerWithHistoryPromptTemplate(allParams);
315+
}
316+
317+
setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getId(), listener);
318+
}, e -> {
319+
log.error("Failed to get chat history", e);
320+
listener.onFailure(e);
321+
}));
322+
}, listener::onFailure));
324323
}
325324

326325
private void setToolsAndRunAgent(
@@ -412,7 +411,7 @@ private void executePlanningLoop(
412411
if (parseLLMOutput.get(RESULT_FIELD) != null) {
413412
String finalResult = (String) parseLLMOutput.get(RESULT_FIELD);
414413
saveAndReturnFinalResult(
415-
(ConversationIndexMemory) memory,
414+
memory,
416415
parentInteractionId,
417416
allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD),
418417
allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD),
@@ -512,7 +511,7 @@ private void executePlanningLoop(
512511
completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD)));
513512

514513
saveTraceData(
515-
(ConversationIndexMemory) memory,
514+
memory,
516515
memory.getType(),
517516
stepToExecute,
518517
results.get(STEP_RESULT_FIELD),
@@ -636,7 +635,7 @@ void addSteps(List<String> steps, Map<String, String> allParams, String field) {
636635

637636
@VisibleForTesting
638637
void saveAndReturnFinalResult(
639-
ConversationIndexMemory memory,
638+
Memory memory,
640639
String parentInteractionId,
641640
String reactAgentMemoryId,
642641
String reactParentInteractionId,
@@ -651,9 +650,9 @@ void saveAndReturnFinalResult(
651650
updateContent.put(INTERACTIONS_INPUT_FIELD, input);
652651
}
653652

654-
memory.getMemoryManager().updateInteraction(parentInteractionId, updateContent, ActionListener.wrap(res -> {
653+
memory.update(parentInteractionId, updateContent, ActionListener.wrap(res -> {
655654
List<ModelTensors> finalModelTensors = createModelTensors(
656-
memory.getConversationId(),
655+
memory.getId(),
657656
parentInteractionId,
658657
reactAgentMemoryId,
659658
reactParentInteractionId

ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.core.common.Strings;
2121
import org.opensearch.index.query.BoolQueryBuilder;
2222
import org.opensearch.index.query.QueryBuilders;
23+
import org.opensearch.ml.common.MLMemoryType;
2324
import org.opensearch.ml.common.conversation.Interaction;
2425
import org.opensearch.ml.common.memory.Memory;
2526
import org.opensearch.ml.common.memory.Message;
@@ -56,7 +57,7 @@
5657
@Getter
5758
public class AgenticConversationMemory implements Memory<Message, CreateInteractionResponse, UpdateResponse> {
5859

59-
public static final String TYPE = "agentic_memory";
60+
public static final String TYPE = MLMemoryType.AGENTIC_MEMORY.name();
6061
private static final String SESSION_ID_FIELD = "session_id";
6162
private static final String CREATED_TIME_FIELD = "created_time";
6263

0 commit comments

Comments
 (0)