Skip to content

Commit b9b5687

Browse files
[FEATURE] Agent Execute Stream (opensearch-project#4212)
* Initial commit for agent streaming Signed-off-by: Nathalie Jonathan <[email protected]> * Create streaming handler factory, address comments Signed-off-by: Nathalie Jonathan <[email protected]> * Address more comments Signed-off-by: Nathalie Jonathan <[email protected]> * Fix failing tests Signed-off-by: Nathalie Jonathan <[email protected]> * Address comments, add some tests Signed-off-by: Nathalie Jonathan <[email protected]> * clean up agent runner Signed-off-by: Nathalie Jonathan <[email protected]> * Address comments Signed-off-by: Nathalie Jonathan <[email protected]> * Address comments, add more tests Signed-off-by: Nathalie Jonathan <[email protected]> * Fix test after rebase Signed-off-by: Nathalie Jonathan <[email protected]> * Fix failing tests Signed-off-by: Nathalie Jonathan <[email protected]> --------- Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent 080f974 commit b9b5687

File tree

50 files changed

+2590
-475
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+2590
-475
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.regex.Matcher;
2727
import java.util.regex.Pattern;
2828

29+
import org.apache.commons.lang3.StringUtils;
2930
import org.apache.commons.text.StringEscapeUtils;
3031
import org.apache.commons.text.StringSubstitutor;
3132
import org.opensearch.Version;
@@ -56,6 +57,8 @@ public class HttpConnector extends AbstractConnector {
5657
public static final String PARAMETERS_FIELD = "parameters";
5758
public static final String SERVICE_NAME_FIELD = "service_name";
5859
public static final String REGION_FIELD = "region";
60+
// TODO: move the AgentUtils class from algorithm module to common module
61+
public static final String LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS = "openai/v1/chat/completions";
5962

6063
// TODO: add RequestConfig like request time out,
6164

@@ -377,14 +380,14 @@ private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
377380
}
378381

379382
String llmInterface = parameters.get("_llm_interface");
380-
if (llmInterface.isBlank()) {
383+
if (StringUtils.isBlank(llmInterface)) {
381384
return false;
382385
}
383386

384387
llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT);
385388
llmInterface = StringEscapeUtils.unescapeJava(llmInterface);
386389
switch (llmInterface) {
387-
case "openai/v1/chat/completions":
390+
case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS:
388391
return true;
389392
default:
390393
return false;
common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteStreamTaskAction.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.execute;
7+
8+
import org.opensearch.action.ActionType;
9+
10+
/**
 * Action type registered under {@link #NAME} whose responses are read with
 * {@code MLExecuteTaskResponse::new}. The constructor is private, so
 * {@link #INSTANCE} is the only instance.
 */
public class MLExecuteStreamTaskAction extends ActionType<MLExecuteTaskResponse> {
11+
// Declared before NAME, but NAME is a compile-time String constant, so the
// constructor below still sees its value rather than null.
public static final MLExecuteStreamTaskAction INSTANCE = new MLExecuteStreamTaskAction();
12+
// Action name passed to the ActionType constructor.
public static final String NAME = "cluster:admin/opensearch/ml/execute/stream";
13+
14+
// Private: instances are obtained through INSTANCE only.
private MLExecuteStreamTaskAction() {
15+
super(NAME, MLExecuteTaskResponse::new);
16+
}
17+
}

common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,27 @@
2222
import org.opensearch.ml.common.MLCommonsClassLoader;
2323
import org.opensearch.ml.common.input.Input;
2424
import org.opensearch.ml.common.transport.MLTaskRequest;
25+
import org.opensearch.transport.TransportChannel;
2526

2627
import lombok.AccessLevel;
2728
import lombok.Builder;
2829
import lombok.Getter;
2930
import lombok.NonNull;
31+
import lombok.Setter;
3032
import lombok.ToString;
3133
import lombok.experimental.FieldDefaults;
34+
import lombok.experimental.NonFinal;
3235

3336
@Getter
3437
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
3538
@ToString
3639
public class MLExecuteTaskRequest extends MLTaskRequest {
3740

41+
@Getter
42+
@Setter
43+
@NonFinal
44+
private transient TransportChannel streamingChannel;
45+
3846
FunctionName functionName;
3947
Input input;
4048

ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,24 @@
99
import org.opensearch.ml.common.exception.ExecuteException;
1010
import org.opensearch.ml.common.input.Input;
1111
import org.opensearch.ml.common.output.Output;
12+
import org.opensearch.transport.TransportChannel;
1213

1314
public interface Executable {
1415

1516
/**
16-
* Execute algorithm with given input data.
17+
* Execute algorithm with given input data (non-streaming).
* Delegates to the three-argument overload, passing {@code null} as the channel.
1718
* @param input input data
19+
* @param listener action listener
1820
*/
19-
void execute(Input input, ActionListener<Output> listener) throws ExecuteException;
21+
default void execute(Input input, ActionListener<Output> listener) throws ExecuteException {
22+
execute(input, listener, null);
23+
}
24+
25+
/**
26+
* Execute algorithm with given input data (streaming).
* The default implementation is empty: it does nothing and never invokes the listener.
* NOTE(review): an implementation that overrides neither overload will silently
* leave the listener uncompleted; confirm this is intended.
27+
* @param input input data
28+
* @param listener action listener
29+
* @param channel transport channel; {@code null} when called through the two-argument overload
30+
*/
31+
default void execute(Input input, ActionListener<Output> listener, TransportChannel channel) {}
2032
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.opensearch.ml.common.output.MLOutput;
2727
import org.opensearch.ml.common.output.Output;
2828
import org.opensearch.ml.engine.encryptor.Encryptor;
29+
import org.opensearch.transport.TransportChannel;
2930

3031
import lombok.Getter;
3132
import lombok.extern.log4j.Log4j2;
@@ -186,7 +187,7 @@ public MLOutput trainAndPredict(Input input) {
186187
return trainAndPredictable.trainAndPredict(mlInput);
187188
}
188189

189-
public void execute(Input input, ActionListener<Output> listener) throws Exception {
190+
public void execute(Input input, ActionListener<Output> listener, TransportChannel channel) throws Exception {
190191
validateInput(input);
191192
if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
192193
MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
@@ -199,6 +200,10 @@ public void execute(Input input, ActionListener<Output> listener) throws Excepti
199200
if (executable == null) {
200201
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
201202
}
203+
if (channel != null) {
204+
executable.execute(input, listener, channel);
205+
return;
206+
}
202207
executable.execute(input, listener);
203208
}
204209
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
import org.opensearch.remote.metadata.client.SdkClient;
8181
import org.opensearch.remote.metadata.common.SdkClientUtils;
8282
import org.opensearch.search.fetch.subphase.FetchSourceContext;
83+
import org.opensearch.transport.TransportChannel;
8384
import org.opensearch.transport.client.Client;
8485

8586
import com.google.common.annotations.VisibleForTesting;
@@ -143,7 +144,7 @@ public void onMultiTenancyEnabledChanged(boolean isEnabled) {
143144
}
144145

145146
@Override
146-
public void execute(Input input, ActionListener<Output> listener) {
147+
public void execute(Input input, ActionListener<Output> listener, TransportChannel channel) {
147148
if (!(input instanceof AgentMLInput)) {
148149
throw new IllegalArgumentException("wrong input");
149150
}
@@ -271,7 +272,8 @@ public void execute(Input input, ActionListener<Output> listener) {
271272
isAsync,
272273
outputs,
273274
modelTensors,
274-
mlAgent
275+
mlAgent,
276+
channel
275277
);
276278
}, e -> {
277279
log.error("Failed to get existing interaction for regeneration", e);
@@ -287,7 +289,8 @@ public void execute(Input input, ActionListener<Output> listener) {
287289
isAsync,
288290
outputs,
289291
modelTensors,
290-
mlAgent
292+
mlAgent,
293+
channel
291294
);
292295
}
293296
}, ex -> {
@@ -318,7 +321,8 @@ public void execute(Input input, ActionListener<Output> listener) {
318321
outputs,
319322
modelTensors,
320323
listener,
321-
createdMemory
324+
createdMemory,
325+
channel
322326
),
323327
ex -> {
324328
log.error("Failed to find memory with memory_id: {}", memoryId, ex);
@@ -329,7 +333,6 @@ public void execute(Input input, ActionListener<Output> listener) {
329333
return;
330334
}
331335
}
332-
333336
executeAgent(
334337
inputDataSet,
335338
mlTask,
@@ -339,7 +342,8 @@ public void execute(Input input, ActionListener<Output> listener) {
339342
outputs,
340343
modelTensors,
341344
listener,
342-
null
345+
null,
346+
channel
343347
);
344348
}
345349
} catch (Exception e) {
@@ -382,7 +386,8 @@ private void saveRootInteractionAndExecute(
382386
boolean isAsync,
383387
List<ModelTensors> outputs,
384388
List<ModelTensor> modelTensors,
385-
MLAgent mlAgent
389+
MLAgent mlAgent,
390+
TransportChannel channel
386391
) {
387392
String appType = mlAgent.getAppType();
388393
String question = inputDataSet.getParameters().get(QUESTION);
@@ -416,7 +421,8 @@ private void saveRootInteractionAndExecute(
416421
outputs,
417422
modelTensors,
418423
listener,
419-
memory
424+
memory,
425+
channel
420426
),
421427
e -> {
422428
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
@@ -425,7 +431,18 @@ private void saveRootInteractionAndExecute(
425431
)
426432
);
427433
} else {
428-
executeAgent(inputDataSet, mlTask, isAsync, memory.getConversationId(), mlAgent, outputs, modelTensors, listener, memory);
434+
executeAgent(
435+
inputDataSet,
436+
mlTask,
437+
isAsync,
438+
memory.getConversationId(),
439+
mlAgent,
440+
outputs,
441+
modelTensors,
442+
listener,
443+
memory,
444+
channel
445+
);
429446
}
430447
}, ex -> {
431448
log.error("Failed to create parent interaction", ex);
@@ -442,7 +459,8 @@ private void executeAgent(
442459
List<ModelTensors> outputs,
443460
List<ModelTensor> modelTensors,
444461
ActionListener<Output> listener,
445-
ConversationIndexMemory memory
462+
ConversationIndexMemory memory,
463+
TransportChannel channel
446464
) {
447465
String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null;
448466
if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) {
@@ -494,7 +512,7 @@ private void executeAgent(
494512
memory
495513
);
496514
inputDataSet.getParameters().put(TASK_ID_FIELD, taskId);
497-
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
515+
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel);
498516
}, e -> {
499517
log.error("Failed to create task for agent async execution", e);
500518
listener.onFailure(e);
@@ -508,7 +526,7 @@ private void executeAgent(
508526
parentInteractionId,
509527
memory
510528
);
511-
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
529+
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel);
512530
}
513531
}
514532

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentRunner.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,29 @@
99

1010
import org.opensearch.core.action.ActionListener;
1111
import org.opensearch.ml.common.agent.MLAgent;
12+
import org.opensearch.transport.TransportChannel;
1213

1314
/**
1415
* Agent executor interface definition. Agent executor will be used by {@link MLAgentExecutor} to invoke agents.
1516
*/
1617
public interface MLAgentRunner {
1718

1819
/**
19-
* Function interface to execute agent.
20+
* Function interface to execute agent (non-streaming)
* Delegates to the four-argument overload, passing {@code null} as the channel.
2021
* @param mlAgent the agent to run
2122
* @param params string key/value parameters for this run
2223
* @param listener action listener for the run
2324
*/
24-
void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener);
25+
default void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
26+
run(mlAgent, params, listener, null);
27+
}
28+
29+
/**
30+
* Function interface to execute agent (streaming)
* This is the only abstract method; every implementation must provide it.
31+
* @param mlAgent the agent to run
32+
* @param params string key/value parameters for this run
33+
* @param listener action listener for the run
34+
* @param channel transport channel; {@code null} when called through the three-argument overload
35+
*/
36+
void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, TransportChannel channel);
2537
}

0 commit comments

Comments
 (0)