53
53
import static org .opensearch .flowframework .common .WorkflowResources .MODEL_ID ;
54
54
import static org .mockito .ArgumentMatchers .any ;
55
55
import static org .mockito .ArgumentMatchers .anyString ;
56
+ import static org .mockito .ArgumentMatchers .nullable ;
56
57
import static org .mockito .Mockito .doAnswer ;
57
58
import static org .mockito .Mockito .mock ;
58
59
import static org .mockito .Mockito .times ;
@@ -147,7 +148,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
147
148
MLTask output = MLTask .builder ().taskId (taskId ).modelId (modelId ).state (MLTaskState .COMPLETED ).async (false ).build ();
148
149
actionListener .onResponse (output );
149
150
return null ;
150
- }).when (machineLearningNodeClient ).getTask (any (), any ());
151
+ }).when (machineLearningNodeClient ).getTask (any (), nullable ( String . class ), any ());
151
152
152
153
doAnswer (invocation -> {
153
154
ActionListener <WorkflowData > updateResponseListener = invocation .getArgument (5 );
@@ -168,7 +169,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
168
169
future .actionGet ();
169
170
170
171
verify (machineLearningNodeClient , times (1 )).register (any (MLRegisterModelInput .class ), any ());
171
- verify (machineLearningNodeClient , times (1 )).getTask (any (), any ());
172
+ verify (machineLearningNodeClient , times (1 )).getTask (any (), nullable ( String . class ), any ());
172
173
173
174
assertEquals (modelId , future .get ().getContent ().get (MODEL_ID ));
174
175
assertEquals (status , future .get ().getContent ().get (REGISTER_MODEL_STATUS ));
0 commit comments