-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
make memory optional in conversational agent #3626
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
import java.util.concurrent.atomic.AtomicInteger; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.apache.commons.text.StringSubstitutor; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.StepListener; | ||
|
@@ -121,12 +122,17 @@ public MLChatAgentRunner( | |
|
||
@Override | ||
public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) { | ||
String memoryType = mlAgent.getMemory().getType(); | ||
String memoryType = mlAgent.getMemory() == null ? null : mlAgent.getMemory().getType(); | ||
String memoryId = params.get(MLAgentExecutor.MEMORY_ID); | ||
String appType = mlAgent.getAppType(); | ||
String title = params.get(MLAgentExecutor.QUESTION); | ||
int messageHistoryLimit = getMessageHistoryLimit(params); | ||
|
||
if (StringUtils.isEmpty(memoryType)) { | ||
runAgent(mlAgent, params, listener, null, null); | ||
return; | ||
} | ||
|
||
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); | ||
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> { | ||
// TODO: call runAgent directly if messageHistoryLimit == 0 | ||
|
@@ -151,8 +157,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje | |
); | ||
} | ||
|
||
StringBuilder chatHistoryBuilder = new StringBuilder(); | ||
if (!messageList.isEmpty()) { | ||
StringBuilder chatHistoryBuilder = new StringBuilder(); | ||
String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX); | ||
chatHistoryBuilder.append(chatHistoryPrefix); | ||
for (Message message : messageList) { | ||
|
@@ -220,7 +226,9 @@ private void runReAct( | |
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); | ||
tmpParameters.put(PROMPT, newPrompt.get()); | ||
|
||
List<ModelTensors> traceTensors = createModelTensors(sessionId, parentInteractionId); | ||
List<ModelTensors> traceTensors = (conversationIndexMemory == null) | ||
? new ArrayList<>() | ||
: createModelTensors(sessionId, parentInteractionId); | ||
Comment on lines
+229
to
+231
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What about passing conversationIndexMemory into this method and returning an empty list there, rather than using a ternary operator here? *Ignore — it seems this would affect a lot of other methods that use it. |
||
int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, "3")) * 2; | ||
for (int i = 0; i < maxIterations; i++) { | ||
int finalI = i; | ||
|
@@ -401,8 +409,8 @@ private void runReAct( | |
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); | ||
} | ||
|
||
private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> sessionId, List<ModelTensor> lastThought) { | ||
List<ModelTensors> finalModelTensors = sessionId; | ||
private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> modelTensorsList, List<ModelTensor> lastThought) { | ||
List<ModelTensors> finalModelTensors = modelTensorsList; | ||
finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build()); | ||
return finalModelTensors; | ||
} | ||
|
@@ -572,19 +580,21 @@ private void sendFinalAnswer( | |
private static List<ModelTensors> createModelTensors(String sessionId, String parentInteractionId) { | ||
List<ModelTensors> cotModelTensors = new ArrayList<>(); | ||
|
||
cotModelTensors | ||
.add( | ||
ModelTensors | ||
.builder() | ||
.mlModelTensors( | ||
List | ||
.of( | ||
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), | ||
ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() | ||
) | ||
) | ||
.build() | ||
); | ||
if (!StringUtils.isEmpty(sessionId)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's inverse this? if (StringUtils.isEmpty()) { return emptyList; } |
||
cotModelTensors | ||
.add( | ||
ModelTensors | ||
.builder() | ||
.mlModelTensors( | ||
List | ||
.of( | ||
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), | ||
ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() | ||
) | ||
) | ||
.build() | ||
); | ||
} | ||
return cotModelTensors; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the current behavior throw an NPE here? I'm wondering whether there is any validation in the REST/Transport layer that checks for memory — I didn't see any, though.