1616import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS ;
1717import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .LLM_RESPONSE_FILTER ;
1818import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .cleanUpResource ;
19+ import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .createMemoryParams ;
1920import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .createTools ;
2021import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .getCurrentDateTime ;
2122import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .getMcpToolSpecs ;
3233import static org .opensearch .ml .engine .algorithms .agent .PromptTemplate .FINAL_RESULT_RESPONSE_INSTRUCTIONS ;
3334import static org .opensearch .ml .engine .algorithms .agent .PromptTemplate .PLANNER_RESPONSIBILITY ;
3435import static org .opensearch .ml .engine .algorithms .agent .PromptTemplate .PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT ;
35- import static org .opensearch .ml .engine .memory .ConversationIndexMemory .APP_TYPE ;
36- import static org .opensearch .ml .engine .memory .ConversationIndexMemory .MEMORY_ID ;
37- import static org .opensearch .ml .engine .memory .ConversationIndexMemory .MEMORY_NAME ;
3836
3937import java .util .ArrayList ;
4038import java .util .HashMap ;
5149import org .opensearch .core .action .ActionListener ;
5250import org .opensearch .core .xcontent .NamedXContentRegistry ;
5351import org .opensearch .ml .common .FunctionName ;
52+ import org .opensearch .ml .common .MLMemoryType ;
5453import org .opensearch .ml .common .MLTaskState ;
5554import org .opensearch .ml .common .agent .LLMSpec ;
5655import org .opensearch .ml .common .agent .MLAgent ;
@@ -285,42 +284,42 @@ public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<O
285284 usePlannerPromptTemplate (allParams );
286285
287286 String memoryId = allParams .get (MEMORY_ID_FIELD );
288- String memoryType = mlAgent .getMemory ().getType ();
287+ String memoryType = MLMemoryType . from ( mlAgent .getMemory ().getType ()). name ();
289288 String appType = mlAgent .getAppType ();
290289 int messageHistoryLimit = Integer .parseInt (allParams .getOrDefault (PLANNER_MESSAGE_HISTORY_LIMIT , DEFAULT_MESSAGE_HISTORY_LIMIT ));
291290
292291 // todo: use chat history instead of completed steps
293- ConversationIndexMemory .Factory conversationIndexMemoryFactory = (ConversationIndexMemory .Factory ) memoryFactoryMap .get (memoryType );
294- conversationIndexMemoryFactory
295- .create (
296- Map .of (MEMORY_ID , memoryId , MEMORY_NAME , apiParams .get (USER_PROMPT_FIELD ), APP_TYPE , appType ),
297- ActionListener .<ConversationIndexMemory >wrap (memory -> {
298- memory .getMessages (messageHistoryLimit , ActionListener .<List <Interaction >>wrap (interactions -> {
299- List <String > completedSteps = new ArrayList <>();
300- for (Interaction interaction : interactions ) {
301- String question = interaction .getInput ();
302- String response = interaction .getResponse ();
303-
304- if (Strings .isNullOrEmpty (response )) {
305- continue ;
306- }
307-
308- completedSteps .add (question );
309- completedSteps .add (response );
310- }
292+ // ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory)
293+ // memoryFactoryMap.get(memoryType);
294+
295+ Memory .Factory <Memory <Interaction , ?, ?>> memoryFactory = memoryFactoryMap .get (memoryType );
296+ Map <String , Object > memoryParams = createMemoryParams (apiParams .get (USER_PROMPT_FIELD ), memoryId , appType , mlAgent );
297+ memoryFactory .create (memoryParams , ActionListener .wrap (memory -> {
298+ memory .getMessages (messageHistoryLimit , ActionListener .<List <Interaction >>wrap (interactions -> {
299+ List <String > completedSteps = new ArrayList <>();
300+ for (Interaction interaction : interactions ) {
301+ String question = interaction .getInput ();
302+ String response = interaction .getResponse ();
303+
304+ if (Strings .isNullOrEmpty (response )) {
305+ continue ;
306+ }
311307
312- if (!completedSteps .isEmpty ()) {
313- addSteps (completedSteps , allParams , COMPLETED_STEPS_FIELD );
314- usePlannerWithHistoryPromptTemplate (allParams );
315- }
308+ completedSteps .add (question );
309+ completedSteps .add (response );
310+ }
316311
317- setToolsAndRunAgent (mlAgent , allParams , completedSteps , memory , memory .getConversationId (), listener );
318- }, e -> {
319- log .error ("Failed to get chat history" , e );
320- listener .onFailure (e );
321- }));
322- }, listener ::onFailure )
323- );
312+ if (!completedSteps .isEmpty ()) {
313+ addSteps (completedSteps , allParams , COMPLETED_STEPS_FIELD );
314+ usePlannerWithHistoryPromptTemplate (allParams );
315+ }
316+
317+ setToolsAndRunAgent (mlAgent , allParams , completedSteps , memory , memory .getId (), listener );
318+ }, e -> {
319+ log .error ("Failed to get chat history" , e );
320+ listener .onFailure (e );
321+ }));
322+ }, listener ::onFailure ));
324323 }
325324
326325 private void setToolsAndRunAgent (
@@ -412,7 +411,7 @@ private void executePlanningLoop(
412411 if (parseLLMOutput .get (RESULT_FIELD ) != null ) {
413412 String finalResult = (String ) parseLLMOutput .get (RESULT_FIELD );
414413 saveAndReturnFinalResult (
415- ( ConversationIndexMemory ) memory ,
414+ memory ,
416415 parentInteractionId ,
417416 allParams .get (EXECUTOR_AGENT_MEMORY_ID_FIELD ),
418417 allParams .get (EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD ),
@@ -512,7 +511,7 @@ private void executePlanningLoop(
512511 completedSteps .add (String .format ("\n Step %d Result: %s\n " , stepsExecuted + 1 , results .get (STEP_RESULT_FIELD )));
513512
514513 saveTraceData (
515- ( ConversationIndexMemory ) memory ,
514+ memory ,
516515 memory .getType (),
517516 stepToExecute ,
518517 results .get (STEP_RESULT_FIELD ),
@@ -636,7 +635,7 @@ void addSteps(List<String> steps, Map<String, String> allParams, String field) {
636635
637636 @ VisibleForTesting
638637 void saveAndReturnFinalResult (
639- ConversationIndexMemory memory ,
638+ Memory memory ,
640639 String parentInteractionId ,
641640 String reactAgentMemoryId ,
642641 String reactParentInteractionId ,
@@ -651,9 +650,9 @@ void saveAndReturnFinalResult(
651650 updateContent .put (INTERACTIONS_INPUT_FIELD , input );
652651 }
653652
654- memory .getMemoryManager (). updateInteraction (parentInteractionId , updateContent , ActionListener .wrap (res -> {
653+ memory .update (parentInteractionId , updateContent , ActionListener .wrap (res -> {
655654 List <ModelTensors > finalModelTensors = createModelTensors (
656- memory .getConversationId (),
655+ memory .getId (),
657656 parentInteractionId ,
658657 reactAgentMemoryId ,
659658 reactParentInteractionId
0 commit comments