Skip to content

Commit ada9b06

Browse files
siddharthabingiSiddhartha Bingi
and
Siddhartha Bingi
authored
Passing tenantId to mlclient getTask for multiTenancy feature (#1041)
Signed-off-by: Siddhartha Bingi <[email protected]> Co-authored-by: Siddhartha Bingi <[email protected]>
1 parent 20a1d40 commit ada9b06

File tree

5 files changed

+12
-9
lines changed

5 files changed

+12
-9
lines changed

src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected void retryableGetMlTask(
8282
) {
8383
CompletableFuture.runAsync(() -> {
8484
do {
85-
mlClient.getTask(taskId, ActionListener.wrap(response -> {
85+
mlClient.getTask(taskId, tenantId, ActionListener.wrap(response -> {
8686
String resourceName = getResourceByWorkflowStep(getName());
8787
String id = getResourceId(response);
8888
switch (response.getState()) {

src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I
131131
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
132132
actionListener.onResponse(output);
133133
return null;
134-
}).when(machineLearningNodeClient).getTask(any(), any());
134+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
135135

136136
doAnswer(invocation -> {
137137
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -152,7 +152,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I
152152
future.actionGet();
153153

154154
verify(machineLearningNodeClient, times(1)).deploy(any(String.class), nullable(String.class), any());
155-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
155+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
156156

157157
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
158158
}

src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
5555
import static org.mockito.ArgumentMatchers.any;
5656
import static org.mockito.ArgumentMatchers.anyString;
57+
import static org.mockito.ArgumentMatchers.nullable;
5758
import static org.mockito.Mockito.doAnswer;
5859
import static org.mockito.Mockito.mock;
5960
import static org.mockito.Mockito.times;
@@ -151,7 +152,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
151152
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
152153
actionListener.onResponse(output);
153154
return null;
154-
}).when(machineLearningNodeClient).getTask(any(), any());
155+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
155156

156157
doAnswer(invocation -> {
157158
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -172,7 +173,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
172173
future.actionGet();
173174

174175
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
175-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
176+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
176177

177178
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
178179
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));

src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
5454
import static org.mockito.ArgumentMatchers.any;
5555
import static org.mockito.ArgumentMatchers.anyString;
56+
import static org.mockito.ArgumentMatchers.nullable;
5657
import static org.mockito.Mockito.doAnswer;
5758
import static org.mockito.Mockito.mock;
5859
import static org.mockito.Mockito.times;
@@ -144,7 +145,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception {
144145
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
145146
actionListener.onResponse(output);
146147
return null;
147-
}).when(machineLearningNodeClient).getTask(any(), any());
148+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
148149

149150
doAnswer(invocation -> {
150151
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -165,7 +166,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception {
165166
future.actionGet();
166167

167168
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
168-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
169+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
169170

170171
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
171172
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));

src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
5454
import static org.mockito.ArgumentMatchers.any;
5555
import static org.mockito.ArgumentMatchers.anyString;
56+
import static org.mockito.ArgumentMatchers.nullable;
5657
import static org.mockito.Mockito.doAnswer;
5758
import static org.mockito.Mockito.mock;
5859
import static org.mockito.Mockito.times;
@@ -147,7 +148,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
147148
MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build();
148149
actionListener.onResponse(output);
149150
return null;
150-
}).when(machineLearningNodeClient).getTask(any(), any());
151+
}).when(machineLearningNodeClient).getTask(any(), nullable(String.class), any());
151152

152153
doAnswer(invocation -> {
153154
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(5);
@@ -168,7 +169,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception {
168169
future.actionGet();
169170

170171
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
171-
verify(machineLearningNodeClient, times(1)).getTask(any(), any());
172+
verify(machineLearningNodeClient, times(1)).getTask(any(), nullable(String.class), any());
172173

173174
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
174175
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));

0 commit comments

Comments
 (0)