Skip to content

Commit 7a9c717

Browse files
opensearch-trigger-bot[bot]github-actions[bot]dbwiddis
authored
[Backport 2.x] feat: parse connector id from tool parameters map (#848)
* feat: parse connector id from tool parameters map (#846) * feat: parse connector id from tool parameters map Signed-off-by: yuye-aws <[email protected]> * update changelog Signed-off-by: yuye-aws <[email protected]> * implement unit test for connector, model and agent id Signed-off-by: yuye-aws <[email protected]> * tool step id: make node id unique Signed-off-by: yuye-aws <[email protected]> * integration test: create agent with connector tool Signed-off-by: yuye-aws <[email protected]> * integration test: update with get agent and get workflow Signed-off-by: yuye-aws <[email protected]> * optimize: iterate through connector_id model_id and agent_id Signed-off-by: yuye-aws <[email protected]> * update changelog Signed-off-by: yuye-aws <[email protected]> --------- Signed-off-by: yuye-aws <[email protected]> (cherry picked from commit b3f9d65) Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * JDK21 minimum is only for 3.x Signed-off-by: Daniel Widdis <[email protected]> --------- Signed-off-by: yuye-aws <[email protected]> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Signed-off-by: Daniel Widdis <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Daniel Widdis <[email protected]>
1 parent dad8d31 commit 7a9c717

File tree

7 files changed

+329
-99
lines changed

7 files changed

+329
-99
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
1616
### Features
1717
- Adds reprovision API to support updating search pipelines, ingest pipelines index settings ([#804](https://github.com/opensearch-project/flow-framework/pull/804))
1818
- Adds user level access control based on backend roles ([#838](https://github.com/opensearch-project/flow-framework/pull/838))
19+
- Support parsing connector_id when creating tools ([#846](https://github.com/opensearch-project/flow-framework/pull/846))
1920

2021
### Enhancements
2122
### Bug Fixes

src/main/java/org/opensearch/flowframework/workflow/ToolStep.java

+26-27
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
2929
import static org.opensearch.flowframework.common.CommonValue.TYPE;
3030
import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
31+
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
3132
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
3233

3334
/**
@@ -64,7 +65,15 @@ public PlainActionFuture<WorkflowData> execute(
6465
String name = (String) inputs.get(NAME_FIELD);
6566
String description = (String) inputs.get(DESCRIPTION_FIELD);
6667
Boolean includeOutputInAgentResponse = ParseUtils.parseIfExists(inputs, INCLUDE_OUTPUT_IN_AGENT_RESPONSE, Boolean.class);
67-
Map<String, String> parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs);
68+
69+
// parse connector_id, model_id and agent_id from previous node inputs
70+
Set<String> toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID);
71+
Map<String, String> parameters = getToolsParametersMap(
72+
inputs.get(PARAMETERS_FIELD),
73+
previousNodeInputs,
74+
outputs,
75+
toolParameterKeys
76+
);
6877

6978
MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();
7079

@@ -110,39 +119,29 @@ public String getName() {
110119
private Map<String, String> getToolsParametersMap(
111120
Object parameters,
112121
Map<String, String> previousNodeInputs,
113-
Map<String, WorkflowData> outputs
122+
Map<String, WorkflowData> outputs,
123+
Set<String> toolParameterKeys
114124
) {
115125
@SuppressWarnings("unchecked")
116126
Map<String, String> parametersMap = (Map<String, String>) parameters;
117-
Optional<String> previousNodeModel = previousNodeInputs.entrySet()
118-
.stream()
119-
.filter(e -> MODEL_ID.equals(e.getValue()))
120-
.map(Map.Entry::getKey)
121-
.findFirst();
122-
123-
Optional<String> previousNodeAgent = previousNodeInputs.entrySet()
124-
.stream()
125-
.filter(e -> AGENT_ID.equals(e.getValue()))
126-
.map(Map.Entry::getKey)
127-
.findFirst();
128-
129-
// Case when modelId is passed through previousSteps and not present already in parameters
130-
if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) {
131-
WorkflowData previousNodeOutput = outputs.get(previousNodeModel.get());
132-
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) {
133-
parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString());
134-
}
135-
}
136127

137-
// Case when agentId is passed through previousSteps and not present already in parameters
138-
if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) {
139-
WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get());
140-
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) {
141-
parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString());
128+
for (String toolParameterKey : toolParameterKeys) {
129+
Optional<String> previousNodeParameter = previousNodeInputs.entrySet()
130+
.stream()
131+
.filter(e -> toolParameterKey.equals(e.getValue()))
132+
.map(Map.Entry::getKey)
133+
.findFirst();
134+
135+
// Case when toolParameterKey is passed through previousSteps and not present already in parameters
136+
if (previousNodeParameter.isPresent() && !parametersMap.containsKey(toolParameterKey)) {
137+
WorkflowData previousNodeOutput = outputs.get(previousNodeParameter.get());
138+
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(toolParameterKey)) {
139+
parametersMap.put(toolParameterKey, previousNodeOutput.getContent().get(toolParameterKey).toString());
140+
}
142141
}
143142
}
144143

145-
// For other cases where modelId is already present in the parameters or not return the parametersMap
144+
// For other cases where toolParameterKey is already present in the parameters or not return the parametersMap
146145
return parametersMap;
147146
}
148147

src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,23 @@ protected Response getWorkflowStep(RestClient client) throws Exception {
651651
);
652652
}
653653

654+
/**
655+
* Helper method to invoke the Get Agent Rest Action
656+
* @param client the rest client
657+
* @return rest response
658+
* @throws Exception
659+
*/
660+
protected Response getAgent(RestClient client, String agentId) throws Exception {
661+
return TestHelpers.makeRequest(
662+
client,
663+
"GET",
664+
String.format(Locale.ROOT, "/_plugins/_ml/agents/%s", agentId),
665+
Collections.emptyMap(),
666+
"",
667+
null
668+
);
669+
}
670+
654671
/**
655672
* Helper method to invoke the Search Workflow Rest Action with the given query
656673
* @param client the rest client
@@ -659,7 +676,6 @@ protected Response getWorkflowStep(RestClient client) throws Exception {
659676
* @throws Exception if the request fails
660677
*/
661678
protected SearchResponse searchWorkflows(RestClient client, String query) throws Exception {
662-
663679
// Execute search
664680
Response restSearchResponse = TestHelpers.makeRequest(
665681
client,

src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java

+74-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import java.nio.charset.StandardCharsets;
3232
import java.time.Instant;
33+
import java.util.ArrayList;
3334
import java.util.Collections;
3435
import java.util.EnumSet;
3536
import java.util.HashMap;
@@ -56,7 +57,6 @@ public void waitToStart() throws Exception {
5657
}
5758

5859
public void testSearchWorkflows() throws Exception {
59-
6060
// Create a Workflow that has a credential 12345
6161
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");
6262
Response response = createWorkflow(client(), template);
@@ -228,7 +228,6 @@ public void testCreateAndProvisionCyclicalTemplate() throws Exception {
228228
}
229229

230230
public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {
231-
232231
// Using a 3 step template to create a connector, register remote model and deploy model
233232
Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json");
234233

@@ -331,6 +330,79 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception {
331330
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
332331
}
333332

333+
public void testCreateAndProvisionConnectorToolAgentFrameworkWorkflow() throws Exception {
334+
// Create a Workflow that has a credential 12345
335+
Template template = TestHelpers.createTemplateFromFile("createconnector-createconnectortool-createflowagent.json");
336+
337+
// Hit Create Workflow API to create agent-framework template, with template validation check and provision parameter
338+
Response response = createWorkflowWithProvision(client(), template);
339+
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
340+
Map<String, Object> responseMap = entityAsMap(response);
341+
String workflowId = (String) responseMap.get(WORKFLOW_ID);
342+
// wait and ensure state is completed/done
343+
assertBusy(
344+
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); },
345+
120,
346+
TimeUnit.SECONDS
347+
);
348+
349+
// Assert based on the agent-framework template
350+
List<ResourceCreated> resourcesCreated = getResourcesCreated(client(), workflowId, 120);
351+
Map<String, ResourceCreated> resourceMap = resourcesCreated.stream()
352+
.collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r));
353+
assertEquals(2, resourceMap.size());
354+
assertTrue(resourceMap.containsKey("create_connector"));
355+
assertTrue(resourceMap.containsKey("register_agent"));
356+
String connectorId = resourceMap.get("create_connector").resourceId();
357+
String agentId = resourceMap.get("register_agent").resourceId();
358+
assertNotNull(connectorId);
359+
assertNotNull(agentId);
360+
361+
// Assert that the agent contains the correct connector_id
362+
response = getAgent(client(), agentId);
363+
Map<String, Object> agentResponse = entityAsMap(response);
364+
assertTrue(agentResponse.containsKey("tools"));
365+
@SuppressWarnings("unchecked")
366+
ArrayList<Map<String, Object>> tools = (ArrayList<Map<String, Object>>) agentResponse.get("tools");
367+
assertEquals(1, tools.size());
368+
Map<String, Object> tool = tools.get(0);
369+
assertTrue(tool.containsKey("parameters"));
370+
@SuppressWarnings("unchecked")
371+
Map<String, String> toolParameters = (Map<String, String>) tool.get("parameters");
372+
assertEquals(toolParameters, Map.of("connector_id", connectorId));
373+
374+
// Hit Deprovision API
375+
// By design, this may not completely deprovision the first time if it takes >2s to process removals
376+
Response deprovisionResponse = deprovisionWorkflow(client(), workflowId);
377+
try {
378+
assertBusy(
379+
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
380+
30,
381+
TimeUnit.SECONDS
382+
);
383+
} catch (ComparisonFailure e) {
384+
// 202 return if still processing
385+
assertEquals(RestStatus.ACCEPTED, TestHelpers.restStatus(deprovisionResponse));
386+
}
387+
if (TestHelpers.restStatus(deprovisionResponse) == RestStatus.ACCEPTED) {
388+
// Short wait before we try again
389+
Thread.sleep(10000);
390+
deprovisionResponse = deprovisionWorkflow(client(), workflowId);
391+
assertBusy(
392+
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
393+
30,
394+
TimeUnit.SECONDS
395+
);
396+
}
397+
assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse));
398+
// Hit Delete API
399+
Response deleteResponse = deleteWorkflow(client(), workflowId);
400+
assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse));
401+
402+
// Verify state doc is deleted
403+
assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS);
404+
}
405+
334406
public void testReprovisionWorkflow() throws Exception {
335407
// Begin with a template to register a local pretrained model
336408
Template template = TestHelpers.createTemplateFromFile("registerremotemodel.json");
@@ -650,7 +722,6 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
650722
}
651723

652724
public void testDefaultCohereUseCase() throws Exception {
653-
654725
// Hit Create Workflow API with original template
655726
Response response = createWorkflowWithUseCaseWithNoValidation(
656727
client(),

src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java

+71-4
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,26 @@
1414
import org.opensearch.ml.common.agent.MLToolSpec;
1515
import org.opensearch.test.OpenSearchTestCase;
1616

17-
import java.io.IOException;
1817
import java.util.Collections;
1918
import java.util.Map;
2019
import java.util.concurrent.ExecutionException;
2120

21+
import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
22+
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
23+
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
24+
2225
public class ToolStepTests extends OpenSearchTestCase {
2326
private WorkflowData inputData;
27+
private WorkflowData inputDataWithConnectorId;
28+
private WorkflowData inputDataWithModelId;
29+
private WorkflowData inputDataWithAgentId;
30+
private static final String mockedConnectorId = "mocked-connector-id";
31+
private static final String mockedModelId = "mocked-model-id";
32+
private static final String mockedAgentId = "mocked-agent-id";
33+
private static final String createConnectorNodeId = "create_connector_node_id";
34+
private static final String createModelNodeId = "create_model_node_id";
35+
private static final String createAgentNodeId = "create_agent_node_id";
36+
2437
private WorkflowData boolStringInputData;
2538
private WorkflowData badBoolInputData;
2639

@@ -39,6 +52,9 @@ public void setUp() throws Exception {
3952
"test-id",
4053
"test-node-id"
4154
);
55+
inputDataWithConnectorId = new WorkflowData(Map.of(CONNECTOR_ID, mockedConnectorId), "test-id", createConnectorNodeId);
56+
inputDataWithModelId = new WorkflowData(Map.of(MODEL_ID, mockedModelId), "test-id", createModelNodeId);
57+
inputDataWithAgentId = new WorkflowData(Map.of(AGENT_ID, mockedAgentId), "test-id", createAgentNodeId);
4258
boolStringInputData = new WorkflowData(
4359
Map.ofEntries(
4460
Map.entry("type", "type"),
@@ -63,7 +79,7 @@ public void setUp() throws Exception {
6379
);
6480
}
6581

66-
public void testTool() throws IOException, ExecutionException, InterruptedException {
82+
public void testTool() throws ExecutionException, InterruptedException {
6783
ToolStep toolStep = new ToolStep();
6884

6985
PlainActionFuture<WorkflowData> future = toolStep.execute(
@@ -88,7 +104,7 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept
88104
assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());
89105
}
90106

91-
public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException {
107+
public void testBoolParseFail() {
92108
ToolStep toolStep = new ToolStep();
93109

94110
PlainActionFuture<WorkflowData> future = toolStep.execute(
@@ -100,10 +116,61 @@ public void testBoolParseFail() throws IOException, ExecutionException, Interrup
100116
);
101117

102118
assertTrue(future.isDone());
103-
ExecutionException e = assertThrows(ExecutionException.class, () -> future.get());
119+
ExecutionException e = assertThrows(ExecutionException.class, future::get);
104120
assertEquals(WorkflowStepException.class, e.getCause().getClass());
105121
WorkflowStepException w = (WorkflowStepException) e.getCause();
106122
assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage());
107123
assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus());
108124
}
125+
126+
public void testToolWithConnectorId() throws ExecutionException, InterruptedException {
127+
ToolStep toolStep = new ToolStep();
128+
129+
PlainActionFuture<WorkflowData> future = toolStep.execute(
130+
inputData.getNodeId(),
131+
inputData,
132+
Map.of(createConnectorNodeId, inputDataWithConnectorId),
133+
Map.of(createConnectorNodeId, CONNECTOR_ID),
134+
Collections.emptyMap()
135+
);
136+
assertTrue(future.isDone());
137+
Object tools = future.get().getContent().get("tools");
138+
assertEquals(MLToolSpec.class, tools.getClass());
139+
MLToolSpec mlToolSpec = (MLToolSpec) tools;
140+
assertEquals(mlToolSpec.getParameters(), Map.of(CONNECTOR_ID, mockedConnectorId));
141+
}
142+
143+
public void testToolWithModelId() throws ExecutionException, InterruptedException {
144+
ToolStep toolStep = new ToolStep();
145+
146+
PlainActionFuture<WorkflowData> future = toolStep.execute(
147+
inputData.getNodeId(),
148+
inputData,
149+
Map.of(createModelNodeId, inputDataWithModelId),
150+
Map.of(createModelNodeId, MODEL_ID),
151+
Collections.emptyMap()
152+
);
153+
assertTrue(future.isDone());
154+
Object tools = future.get().getContent().get("tools");
155+
assertEquals(MLToolSpec.class, tools.getClass());
156+
MLToolSpec mlToolSpec = (MLToolSpec) tools;
157+
assertEquals(mlToolSpec.getParameters(), Map.of(MODEL_ID, mockedModelId));
158+
}
159+
160+
public void testToolWithAgentId() throws ExecutionException, InterruptedException {
161+
ToolStep toolStep = new ToolStep();
162+
163+
PlainActionFuture<WorkflowData> future = toolStep.execute(
164+
inputData.getNodeId(),
165+
inputData,
166+
Map.of(createAgentNodeId, inputDataWithAgentId),
167+
Map.of(createAgentNodeId, AGENT_ID),
168+
Collections.emptyMap()
169+
);
170+
assertTrue(future.isDone());
171+
Object tools = future.get().getContent().get("tools");
172+
assertEquals(MLToolSpec.class, tools.getClass());
173+
MLToolSpec mlToolSpec = (MLToolSpec) tools;
174+
assertEquals(mlToolSpec.getParameters(), Map.of(AGENT_ID, mockedAgentId));
175+
}
109176
}

0 commit comments

Comments
 (0)