14
14
import org .opensearch .ml .common .agent .MLToolSpec ;
15
15
import org .opensearch .test .OpenSearchTestCase ;
16
16
17
- import java .io .IOException ;
18
17
import java .util .Collections ;
19
18
import java .util .Map ;
20
19
import java .util .concurrent .ExecutionException ;
21
20
21
+ import static org .opensearch .flowframework .common .WorkflowResources .AGENT_ID ;
22
+ import static org .opensearch .flowframework .common .WorkflowResources .CONNECTOR_ID ;
23
+ import static org .opensearch .flowframework .common .WorkflowResources .MODEL_ID ;
24
+
22
25
public class ToolStepTests extends OpenSearchTestCase {
23
26
private WorkflowData inputData ;
27
+ private WorkflowData inputDataWithConnectorId ;
28
+ private WorkflowData inputDataWithModelId ;
29
+ private WorkflowData inputDataWithAgentId ;
30
+ private static final String mockedConnectorId = "mocked-connector-id" ;
31
+ private static final String mockedModelId = "mocked-model-id" ;
32
+ private static final String mockedAgentId = "mocked-agent-id" ;
33
+ private static final String createConnectorNodeId = "create_connector_node_id" ;
34
+ private static final String createModelNodeId = "create_model_node_id" ;
35
+ private static final String createAgentNodeId = "create_agent_node_id" ;
36
+
24
37
private WorkflowData boolStringInputData ;
25
38
private WorkflowData badBoolInputData ;
26
39
@@ -39,6 +52,9 @@ public void setUp() throws Exception {
39
52
"test-id" ,
40
53
"test-node-id"
41
54
);
55
+ inputDataWithConnectorId = new WorkflowData (Map .of (CONNECTOR_ID , mockedConnectorId ), "test-id" , createConnectorNodeId );
56
+ inputDataWithModelId = new WorkflowData (Map .of (MODEL_ID , mockedModelId ), "test-id" , createModelNodeId );
57
+ inputDataWithAgentId = new WorkflowData (Map .of (AGENT_ID , mockedAgentId ), "test-id" , createAgentNodeId );
42
58
boolStringInputData = new WorkflowData (
43
59
Map .ofEntries (
44
60
Map .entry ("type" , "type" ),
@@ -63,7 +79,7 @@ public void setUp() throws Exception {
63
79
);
64
80
}
65
81
66
- public void testTool () throws IOException , ExecutionException , InterruptedException {
82
+ public void testTool () throws ExecutionException , InterruptedException {
67
83
ToolStep toolStep = new ToolStep ();
68
84
69
85
PlainActionFuture <WorkflowData > future = toolStep .execute (
@@ -88,7 +104,7 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept
88
104
assertEquals (MLToolSpec .class , future .get ().getContent ().get ("tools" ).getClass ());
89
105
}
90
106
91
- public void testBoolParseFail () throws IOException , ExecutionException , InterruptedException {
107
+ public void testBoolParseFail () {
92
108
ToolStep toolStep = new ToolStep ();
93
109
94
110
PlainActionFuture <WorkflowData > future = toolStep .execute (
@@ -100,10 +116,61 @@ public void testBoolParseFail() throws IOException, ExecutionException, Interrup
100
116
);
101
117
102
118
assertTrue (future .isDone ());
103
- ExecutionException e = assertThrows (ExecutionException .class , () -> future . get () );
119
+ ExecutionException e = assertThrows (ExecutionException .class , future :: get );
104
120
assertEquals (WorkflowStepException .class , e .getCause ().getClass ());
105
121
WorkflowStepException w = (WorkflowStepException ) e .getCause ();
106
122
assertEquals ("Failed to parse value [yes] as only [true] or [false] are allowed." , w .getMessage ());
107
123
assertEquals (RestStatus .BAD_REQUEST , w .getRestStatus ());
108
124
}
125
+
126
+ public void testToolWithConnectorId () throws ExecutionException , InterruptedException {
127
+ ToolStep toolStep = new ToolStep ();
128
+
129
+ PlainActionFuture <WorkflowData > future = toolStep .execute (
130
+ inputData .getNodeId (),
131
+ inputData ,
132
+ Map .of (createConnectorNodeId , inputDataWithConnectorId ),
133
+ Map .of (createConnectorNodeId , CONNECTOR_ID ),
134
+ Collections .emptyMap ()
135
+ );
136
+ assertTrue (future .isDone ());
137
+ Object tools = future .get ().getContent ().get ("tools" );
138
+ assertEquals (MLToolSpec .class , tools .getClass ());
139
+ MLToolSpec mlToolSpec = (MLToolSpec ) tools ;
140
+ assertEquals (mlToolSpec .getParameters (), Map .of (CONNECTOR_ID , mockedConnectorId ));
141
+ }
142
+
143
+ public void testToolWithModelId () throws ExecutionException , InterruptedException {
144
+ ToolStep toolStep = new ToolStep ();
145
+
146
+ PlainActionFuture <WorkflowData > future = toolStep .execute (
147
+ inputData .getNodeId (),
148
+ inputData ,
149
+ Map .of (createModelNodeId , inputDataWithModelId ),
150
+ Map .of (createModelNodeId , MODEL_ID ),
151
+ Collections .emptyMap ()
152
+ );
153
+ assertTrue (future .isDone ());
154
+ Object tools = future .get ().getContent ().get ("tools" );
155
+ assertEquals (MLToolSpec .class , tools .getClass ());
156
+ MLToolSpec mlToolSpec = (MLToolSpec ) tools ;
157
+ assertEquals (mlToolSpec .getParameters (), Map .of (MODEL_ID , mockedModelId ));
158
+ }
159
+
160
+ public void testToolWithAgentId () throws ExecutionException , InterruptedException {
161
+ ToolStep toolStep = new ToolStep ();
162
+
163
+ PlainActionFuture <WorkflowData > future = toolStep .execute (
164
+ inputData .getNodeId (),
165
+ inputData ,
166
+ Map .of (createAgentNodeId , inputDataWithAgentId ),
167
+ Map .of (createAgentNodeId , AGENT_ID ),
168
+ Collections .emptyMap ()
169
+ );
170
+ assertTrue (future .isDone ());
171
+ Object tools = future .get ().getContent ().get ("tools" );
172
+ assertEquals (MLToolSpec .class , tools .getClass ());
173
+ MLToolSpec mlToolSpec = (MLToolSpec ) tools ;
174
+ assertEquals (mlToolSpec .getParameters (), Map .of (AGENT_ID , mockedAgentId ));
175
+ }
109
176
}
0 commit comments