Skip to content

Commit 33003ea

Browse files
opensearch-trigger-bot[bot]github-actions[bot]Siddhartha Bingidbwiddis
authored
[Backport 2.19] Passing tenantId to mlclient getTask for multiTenancy feature (#1044)
* Passing tenantId to mlclient getTask for multiTenancy feature (#1041) Signed-off-by: Siddhartha Bingi <[email protected]> Co-authored-by: Siddhartha Bingi <[email protected]> (cherry picked from commit ada9b06) Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Update getTask mocks to add new parameter Signed-off-by: Daniel Widdis <[email protected]> --------- Signed-off-by: Siddhartha Bingi <[email protected]> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Signed-off-by: Daniel Widdis <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Siddhartha Bingi <[email protected]> Co-authored-by: Daniel Widdis <[email protected]>
1 parent 557a751 commit 33003ea

File tree

5 files changed

+29
-26
lines changed

5 files changed

+29
-26
lines changed

Diff for: src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected void retryableGetMlTask(
8282
) {
8383
CompletableFuture.runAsync(() -> {
8484
do {
85-
mlClient.getTask(taskId, ActionListener.wrap(response -> {
85+
mlClient.getTask(taskId, tenantId, ActionListener.wrap(response -> {
8686
String resourceName = getResourceByWorkflowStep(getName());
8787
String id = getResourceId(response);
8888
switch (response.getState()) {

Diff for: src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,11 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I
127127

128128
// Stub getTask for success case
129129
doAnswer(invocation -> {
130-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
130+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
131131
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
132132
actionListener.onResponse(output);
133133
return null;
134-
}).when(machineLearningNodeClient).getTask(any(), any());
134+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
135135

136136
doAnswer(invocation -> {
137137
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -152,7 +152,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I
152152
future.actionGet();
153153

154154
verify(machineLearningNodeClient, times(1)).deploy(any(String.class), nullable(String.class), any());
155-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
155+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
156156

157157
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
158158
}
@@ -203,11 +203,11 @@ public void testDeployModelTaskFailure() throws IOException, InterruptedExceptio
203203

204204
// Stub getTask for success case
205205
doAnswer(invocation -> {
206-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
206+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
207207
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.FAILED).async(false).error("error").build();
208208
actionListener.onResponse(output);
209209
return null;
210-
}).when(machineLearningNodeClient).getTask(any(), any());
210+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
211211

212212
PlainActionFuture<WorkflowData> future = this.deployModel.execute(
213213
inputData.getNodeId(),

Diff for: src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java

+9-8
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
5555
import static org.mockito.ArgumentMatchers.any;
5656
import static org.mockito.ArgumentMatchers.anyString;
57+
import static org.mockito.ArgumentMatchers.nullable;
5758
import static org.mockito.Mockito.doAnswer;
5859
import static org.mockito.Mockito.mock;
5960
import static org.mockito.Mockito.times;
@@ -147,11 +148,11 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
147148

148149
// Stub getTask for success case
149150
doAnswer(invocation -> {
150-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
151+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
151152
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
152153
actionListener.onResponse(output);
153154
return null;
154-
}).when(machineLearningNodeClient).getTask(any(), any());
155+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
155156

156157
doAnswer(invocation -> {
157158
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -172,7 +173,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
172173
future.actionGet();
173174

174175
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
175-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
176+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
176177

177178
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
178179
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
@@ -208,7 +209,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
208209
future.actionGet();
209210

210211
verify(machineLearningNodeClient, times(2)).register(any(MLRegisterModelInput.class), any());
211-
verify(machineLearningNodeClient, times(2)).getTask(any(), any());
212+
verify(machineLearningNodeClient, times(2)).getTask(any(), nullable(String.class), any());
212213

213214
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
214215
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
@@ -230,11 +231,11 @@ public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception
230231

231232
// Stub getTask for success case
232233
doAnswer(invocation -> {
233-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
234+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
234235
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
235236
actionListener.onResponse(output);
236237
return null;
237-
}).when(machineLearningNodeClient).getTask(any(), any());
238+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
238239

239240
AtomicInteger invocationCount = new AtomicInteger(0);
240241
doAnswer(invocation -> {
@@ -321,7 +322,7 @@ public void testRegisterLocalCustomModelTaskFailure() {
321322

322323
// Stub get ml task for failure case
323324
doAnswer(invocation -> {
324-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
325+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
325326
MLTask output = MLTask.builder()
326327
.taskId(taskId)
327328
.modelId(modelId)
@@ -331,7 +332,7 @@ public void testRegisterLocalCustomModelTaskFailure() {
331332
.build();
332333
actionListener.onResponse(output);
333334
return null;
334-
}).when(machineLearningNodeClient).getTask(any(), any());
335+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
335336

336337
PlainActionFuture<WorkflowData> future = this.registerLocalModelStep.execute(
337338
workflowData.getNodeId(),

Diff for: src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
5454
import static org.mockito.ArgumentMatchers.any;
5555
import static org.mockito.ArgumentMatchers.anyString;
56+
import static org.mockito.ArgumentMatchers.nullable;
5657
import static org.mockito.Mockito.doAnswer;
5758
import static org.mockito.Mockito.mock;
5859
import static org.mockito.Mockito.times;
@@ -140,11 +141,11 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception {
140141

141142
// Stub getTask for success case
142143
doAnswer(invocation -> {
143-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
144+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
144145
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
145146
actionListener.onResponse(output);
146147
return null;
147-
}).when(machineLearningNodeClient).getTask(any(), any());
148+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
148149

149150
doAnswer(invocation -> {
150151
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -165,7 +166,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception {
165166
future.actionGet();
166167

167168
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
168-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
169+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
169170

170171
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
171172
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
@@ -195,7 +196,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception {
195196
future.actionGet();
196197

197198
verify(machineLearningNodeClient, times(2)).register(any(MLRegisterModelInput.class), any());
198-
verify(machineLearningNodeClient, times(2)).getTask(any(), any());
199+
verify(machineLearningNodeClient, times(2)).getTask(any(), nullable(String.class), any());
199200

200201
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
201202
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
@@ -239,7 +240,7 @@ public void testRegisterLocalPretrainedModelTaskFailure() {
239240

240241
// Stub get ml task for failure case
241242
doAnswer(invocation -> {
242-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
243+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
243244
MLTask output = MLTask.builder()
244245
.taskId(taskId)
245246
.modelId(modelId)
@@ -249,7 +250,7 @@ public void testRegisterLocalPretrainedModelTaskFailure() {
249250
.build();
250251
actionListener.onResponse(output);
251252
return null;
252-
}).when(machineLearningNodeClient).getTask(any(), any());
253+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
253254

254255
PlainActionFuture<WorkflowData> future = this.registerLocalPretrainedModelStep.execute(
255256
workflowData.getNodeId(),

Diff for: src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
5454
import static org.mockito.ArgumentMatchers.any;
5555
import static org.mockito.ArgumentMatchers.anyString;
56+
import static org.mockito.ArgumentMatchers.nullable;
5657
import static org.mockito.Mockito.doAnswer;
5758
import static org.mockito.Mockito.mock;
5859
import static org.mockito.Mockito.times;
@@ -143,11 +144,11 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
143144

144145
// Stub getTask for success case
145146
doAnswer(invocation -> {
146-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
147+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
147148
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
148149
actionListener.onResponse(output);
149150
return null;
150-
}).when(machineLearningNodeClient).getTask(any(), any());
151+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
151152

152153
doAnswer(invocation -> {
153154
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -168,7 +169,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
168169
future.actionGet();
169170

170171
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
171-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
172+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
172173

173174
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
174175
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
@@ -200,7 +201,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
200201
future.actionGet();
201202

202203
verify(machineLearningNodeClient, times(2)).register(any(MLRegisterModelInput.class), any());
203-
verify(machineLearningNodeClient, times(2)).getTask(any(), any());
204+
verify(machineLearningNodeClient, times(2)).getTask(any(), nullable(String.class), any());
204205

205206
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
206207
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
@@ -244,7 +245,7 @@ public void testRegisterLocalSparseEncodingModelTaskFailure() {
244245

245246
// Stub get ml task for failure case
246247
doAnswer(invocation -> {
247-
ActionListener<MLTask> actionListener = invocation.getArgument(1);
248+
ActionListener<MLTask> actionListener = invocation.getArgument(2);
248249
MLTask output = MLTask.builder()
249250
.taskId(taskId)
250251
.modelId(modelId)
@@ -254,7 +255,7 @@ public void testRegisterLocalSparseEncodingModelTaskFailure() {
254255
.build();
255256
actionListener.onResponse(output);
256257
return null;
257-
}).when(machineLearningNodeClient).getTask(any(), any());
258+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
258259

259260
PlainActionFuture<WorkflowData> future = this.registerLocalSparseEncodingModelStep.execute(
260261
workflowData.getNodeId(),

0 commit comments

Comments
 (0)