54
54
import static org .opensearch .flowframework .common .WorkflowResources .MODEL_ID ;
55
55
import static org .mockito .ArgumentMatchers .any ;
56
56
import static org .mockito .ArgumentMatchers .anyString ;
57
+ import static org .mockito .ArgumentMatchers .nullable ;
57
58
import static org .mockito .Mockito .doAnswer ;
58
59
import static org .mockito .Mockito .mock ;
59
60
import static org .mockito .Mockito .times ;
@@ -147,11 +148,11 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
147
148
148
149
// Stub getTask for success case
149
150
doAnswer (invocation -> {
150
- ActionListener <MLTask > actionListener = invocation .getArgument (1 );
151
+ ActionListener <MLTask > actionListener = invocation .getArgument (2 );
151
152
MLTask output = MLTask .builder ().taskId (taskId ).modelId (modelId ).state (MLTaskState .COMPLETED ).async (false ).build ();
152
153
actionListener .onResponse (output );
153
154
return null ;
154
- }).when (machineLearningNodeClient ).getTask (any (), any ());
155
+ }).when (machineLearningNodeClient ).getTask (any (), nullable ( String . class ), any ());
155
156
156
157
doAnswer (invocation -> {
157
158
ActionListener <WorkflowData > updateResponseListener = invocation .getArgument (5 );
@@ -172,7 +173,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
172
173
future .actionGet ();
173
174
174
175
verify (machineLearningNodeClient , times (1 )).register (any (MLRegisterModelInput .class ), any ());
175
- verify (machineLearningNodeClient , times (1 )).getTask (any (), any ());
176
+ verify (machineLearningNodeClient , times (1 )).getTask (any (), nullable ( String . class ), any ());
176
177
177
178
assertEquals (modelId , future .get ().getContent ().get (MODEL_ID ));
178
179
assertEquals (status , future .get ().getContent ().get (REGISTER_MODEL_STATUS ));
@@ -208,7 +209,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
208
209
future .actionGet ();
209
210
210
211
verify (machineLearningNodeClient , times (2 )).register (any (MLRegisterModelInput .class ), any ());
211
- verify (machineLearningNodeClient , times (2 )).getTask (any (), any ());
212
+ verify (machineLearningNodeClient , times (2 )).getTask (any (), nullable ( String . class ), any ());
212
213
213
214
assertEquals (modelId , future .get ().getContent ().get (MODEL_ID ));
214
215
assertEquals (status , future .get ().getContent ().get (REGISTER_MODEL_STATUS ));
@@ -230,11 +231,11 @@ public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception
230
231
231
232
// Stub getTask for success case
232
233
doAnswer (invocation -> {
233
- ActionListener <MLTask > actionListener = invocation .getArgument (1 );
234
+ ActionListener <MLTask > actionListener = invocation .getArgument (2 );
234
235
MLTask output = MLTask .builder ().taskId (taskId ).modelId (modelId ).state (MLTaskState .COMPLETED ).async (false ).build ();
235
236
actionListener .onResponse (output );
236
237
return null ;
237
- }).when (machineLearningNodeClient ).getTask (any (), any ());
238
+ }).when (machineLearningNodeClient ).getTask (any (), nullable ( String . class ), any ());
238
239
239
240
AtomicInteger invocationCount = new AtomicInteger (0 );
240
241
doAnswer (invocation -> {
@@ -321,7 +322,7 @@ public void testRegisterLocalCustomModelTaskFailure() {
321
322
322
323
// Stub get ml task for failure case
323
324
doAnswer (invocation -> {
324
- ActionListener <MLTask > actionListener = invocation .getArgument (1 );
325
+ ActionListener <MLTask > actionListener = invocation .getArgument (2 );
325
326
MLTask output = MLTask .builder ()
326
327
.taskId (taskId )
327
328
.modelId (modelId )
@@ -331,7 +332,7 @@ public void testRegisterLocalCustomModelTaskFailure() {
331
332
.build ();
332
333
actionListener .onResponse (output );
333
334
return null ;
334
- }).when (machineLearningNodeClient ).getTask (any (), any ());
335
+ }).when (machineLearningNodeClient ).getTask (any (), nullable ( String . class ), any ());
335
336
336
337
PlainActionFuture <WorkflowData > future = this .registerLocalModelStep .execute (
337
338
workflowData .getNodeId (),
0 commit comments