Skip to content

Commit d1dcd95

Browse files
committed
Replace null GetResponse with valid response and not exists
Signed-off-by: Daniel Widdis <[email protected]>
1 parent 088c1a5 commit d1dcd95

File tree

9 files changed

+133
-68
lines changed

9 files changed

+133
-68
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import static org.mockito.Mockito.doThrow;
66
import static org.mockito.Mockito.mock;
77
import static org.mockito.Mockito.when;
8+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
9+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
810
import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
911
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
1012
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
@@ -160,9 +162,10 @@ public void encrypt_NonExistingMasterKey() {
160162
}).when(mlIndicesHandler).initMLConfigIndex(any());
161163
IndexResponse indexResponse = prepareIndexResponse();
162164

165+
GetResponse getResponse = prepareNotExistsGetResponse();
163166
doAnswer(invocation -> {
164167
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
165-
actionListener.onResponse(null);
168+
actionListener.onResponse(getResponse);
166169
return null;
167170
}).when(client).get(any(), any());
168171

@@ -191,7 +194,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey() {
191194
}).when(mlIndicesHandler).initMLConfigIndex(any());
192195
doAnswer(invocation -> {
193196
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
194-
actionListener.onResponse(null);
197+
GetResponse getResponse = prepareNotExistsGetResponse();
198+
actionListener.onResponse(getResponse);
195199
return null;
196200
}).when(client).get(any(), any());
197201
doAnswer(invocation -> {
@@ -216,7 +220,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_NonRuntimeExceptio
216220
}).when(mlIndicesHandler).initMLConfigIndex(any());
217221
doAnswer(invocation -> {
218222
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
219-
actionListener.onResponse(null);
223+
GetResponse getResponse = prepareNotExistsGetResponse();
224+
actionListener.onResponse(getResponse);
220225
return null;
221226
}).when(client).get(any(), any());
222227
doAnswer(invocation -> {
@@ -245,7 +250,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict()
245250
}).when(mlIndicesHandler).initMLConfigIndex(any());
246251
doAnswer(invocation -> {
247252
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
248-
actionListener.onResponse(null);
253+
GetResponse getResponse = prepareNotExistsGetResponse();
254+
actionListener.onResponse(getResponse);
249255
return null;
250256
}).doAnswer(invocation -> {
251257
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
@@ -500,7 +506,8 @@ public void encrypt_SdkClientPutDataObjectFailure() {
500506

501507
doAnswer(invocation -> {
502508
ActionListener<GetResponse> listener = invocation.getArgument(1);
503-
listener.onResponse(null);
509+
GetResponse getResponse = prepareNotExistsGetResponse();
510+
listener.onResponse(getResponse);
504511
return null;
505512
}).when(client).get(any(), any());
506513

@@ -777,4 +784,20 @@ private IndexResponse prepareIndexResponse() {
777784
ShardId shardId = new ShardId(ML_CONFIG_INDEX, "index_uuid", 0);
778785
return new IndexResponse(shardId, MASTER_KEY, 1L, 1L, 1L, true);
779786
}
787+
788+
// Helper method to prepare a valid GetResponse
789+
private GetResponse prepareNotExistsGetResponse() {
790+
GetResult getResult = new GetResult(
791+
ML_CONFIG_INDEX,
792+
"fake_id",
793+
UNASSIGNED_SEQ_NO,
794+
UNASSIGNED_PRIMARY_TERM,
795+
-1L,
796+
false,
797+
null,
798+
null,
799+
null
800+
);
801+
return new GetResponse(getResult);
802+
}
780803
}

plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import static org.mockito.ArgumentMatchers.any;
88
import static org.mockito.Mockito.*;
99
import static org.mockito.Mockito.verify;
10+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
11+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
12+
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
1013

1114
import java.io.IOException;
1215
import java.time.Instant;
@@ -205,14 +208,25 @@ public void testDoExecute_RuntimeException() {
205208
}
206209

207210
@Test
208-
public void testGetTask_NullResponse() {
211+
public void testGetTask_NotFoundResponse() {
209212
String agentId = "test-agent-id-NullResponse";
210213
Task task = mock(Task.class);
211214
ActionListener<MLAgentGetResponse> actionListener = mock(ActionListener.class);
212215
MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true, null);
216+
GetResult getResult = new GetResult(
217+
ML_AGENT_INDEX,
218+
"fake_id",
219+
UNASSIGNED_SEQ_NO,
220+
UNASSIGNED_PRIMARY_TERM,
221+
-1L,
222+
false,
223+
null,
224+
null,
225+
null
226+
);
213227
doAnswer(invocation -> {
214228
ActionListener<GetResponse> listener = invocation.getArgument(1);
215-
listener.onResponse(null);
229+
listener.onResponse(new GetResponse(getResult));
216230
return null;
217231
}).when(client).get(any(), any());
218232
getAgentTransportAction.doExecute(task, getRequest, actionListener);

plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
import static org.mockito.Mockito.spy;
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
13+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
15+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
1316

1417
import java.io.IOException;
1518
import java.util.Collections;
@@ -167,10 +170,21 @@ public void testGetModel_ValidateAccessFailed() throws IOException {
167170
assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage());
168171
}
169172

170-
public void testGetModel_NullResponse() {
173+
public void testGetModel_NotExistsResponse() {
174+
GetResult getResult = new GetResult(
175+
ML_MODEL_GROUP_INDEX,
176+
"fake_id",
177+
UNASSIGNED_SEQ_NO,
178+
UNASSIGNED_PRIMARY_TERM,
179+
-1L,
180+
false,
181+
null,
182+
null,
183+
null
184+
);
171185
doAnswer(invocation -> {
172186
ActionListener<GetResponse> listener = invocation.getArgument(1);
173-
listener.onResponse(null);
187+
listener.onResponse(new GetResponse(getResult));
174188
return null;
175189
}).when(client).get(any(), any());
176190
getModelGroupTransportAction.doExecute(null, mlModelGroupGetRequest, actionListener);

plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
import static org.mockito.Mockito.spy;
1212
import static org.mockito.Mockito.verify;
1313
import static org.mockito.Mockito.when;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
15+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
16+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
1417

1518
import java.io.IOException;
1619
import java.util.Collections;
@@ -212,10 +215,21 @@ public void testGetModel_ValidateAccessFailed() throws IOException, InterruptedE
212215
assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage());
213216
}
214217

215-
public void testGetModel_NullResponse() {
218+
public void testGetModel_NotExistsResponse() {
219+
GetResult getResult = new GetResult(
220+
ML_MODEL_INDEX,
221+
"fake_id",
222+
UNASSIGNED_SEQ_NO,
223+
UNASSIGNED_PRIMARY_TERM,
224+
-1L,
225+
false,
226+
null,
227+
null,
228+
null
229+
);
216230
doAnswer(invocation -> {
217231
ActionListener<GetResponse> listener = invocation.getArgument(1);
218-
listener.onResponse(null);
232+
listener.onResponse(new GetResponse(getResult));
219233
return null;
220234
}).when(client).get(any(), any());
221235

plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
1313
import static org.opensearch.action.DocWriteResponse.Result.DELETED;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
15+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
1416
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
1517

1618
import java.io.IOException;
@@ -144,10 +146,21 @@ public void testDeleteTask_ResourceNotFoundException() throws IOException {
144146
assertEquals("Failed to get data object from index .plugins-ml-task", argumentCaptor.getValue().getMessage());
145147
}
146148

147-
public void testDeleteTask_GetResponseNullException() {
149+
public void testDeleteTask_GetResponseNotExistsException() {
150+
GetResult getResult = new GetResult(
151+
ML_TASK_INDEX,
152+
"fake_id",
153+
UNASSIGNED_SEQ_NO,
154+
UNASSIGNED_PRIMARY_TERM,
155+
-1L,
156+
false,
157+
null,
158+
null,
159+
null
160+
);
148161
doAnswer(invocation -> {
149162
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
150-
actionListener.onResponse(null);
163+
actionListener.onResponse(new GetResponse(getResult));
151164
return null;
152165
}).when(client).get(any(), any());
153166

plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import static org.mockito.Mockito.spy;
1616
import static org.mockito.Mockito.verify;
1717
import static org.mockito.Mockito.when;
18+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
19+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
20+
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
1821
import static org.opensearch.ml.common.connector.AbstractConnector.*;
1922
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
2023
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
@@ -255,9 +258,20 @@ public void setup() throws IOException {
255258
}
256259

257260
public void testGetTask_NullResponse() {
261+
GetResult getResult = new GetResult(
262+
ML_TASK_INDEX,
263+
"fake_id",
264+
UNASSIGNED_SEQ_NO,
265+
UNASSIGNED_PRIMARY_TERM,
266+
-1L,
267+
false,
268+
null,
269+
null,
270+
null
271+
);
258272
doAnswer(invocation -> {
259273
ActionListener<GetResponse> listener = invocation.getArgument(1);
260-
listener.onResponse(null);
274+
listener.onResponse(new GetResponse(getResult));
261275
return null;
262276
}).when(client).get(any(), any());
263277
getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);

plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import static org.mockito.Mockito.doAnswer;
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
13+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
1315
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
1416

1517
import java.io.IOException;
@@ -398,9 +400,20 @@ public void test_OtherExceptionGetModelGroup() throws IOException {
398400
}
399401

400402
public void test_NotFoundGetModelGroup() throws IOException {
403+
GetResult getResult = new GetResult(
404+
ML_MODEL_GROUP_INDEX,
405+
"fake_id",
406+
UNASSIGNED_SEQ_NO,
407+
UNASSIGNED_PRIMARY_TERM,
408+
-1L,
409+
false,
410+
null,
411+
null,
412+
null
413+
);
401414
doAnswer(invocation -> {
402415
ActionListener<GetResponse> listener = invocation.getArgument(1);
403-
listener.onResponse(null);
416+
listener.onResponse(new GetResponse(getResult));
404417
return null;
405418
}).when(client).get(any(GetRequest.class), isA(ActionListener.class));
406419

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import static org.opensearch.ml.utils.MockHelper.mock_client_ThreadContext;
4141
import static org.opensearch.ml.utils.MockHelper.mock_client_ThreadContext_Exception;
4242
import static org.opensearch.ml.utils.MockHelper.mock_client_get_NotExist;
43-
import static org.opensearch.ml.utils.MockHelper.mock_client_get_NullResponse;
4443
import static org.opensearch.ml.utils.MockHelper.mock_client_get_failure;
4544
import static org.opensearch.ml.utils.MockHelper.mock_client_index;
4645
import static org.opensearch.ml.utils.MockHelper.mock_client_index_failure;
@@ -725,47 +724,6 @@ public void testDeployModel_FailedToGetModel() {
725724
);
726725
}
727726

728-
public void testDeployModel_NullGetModelResponse() {
729-
MLModelConfig modelConfig = TextEmbeddingModelConfig
730-
.builder()
731-
.modelType("bert")
732-
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
733-
.embeddingDimension(384)
734-
.build();
735-
model = MLModel
736-
.builder()
737-
.modelId(modelId)
738-
.modelState(MLModelState.DEPLOYING)
739-
.algorithm(FunctionName.TEXT_EMBEDDING)
740-
.name(modelName)
741-
.version(version)
742-
.totalChunks(2)
743-
.modelFormat(MLModelFormat.TORCH_SCRIPT)
744-
.modelConfig(modelConfig)
745-
.modelContentHash(modelContentHashValue)
746-
.modelContentSizeInBytes(modelContentSize)
747-
.build();
748-
String[] nodes = new String[] { "node1", "node2" };
749-
mlTask.setWorkerNodes(List.of(nodes));
750-
ActionListener<String> listener = mock(ActionListener.class);
751-
when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false);
752-
when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {});
753-
when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {});
754-
mock_threadpool(threadPool, taskExecutorService);
755-
mock_client_get_NullResponse(client);
756-
modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener);
757-
assertFalse(modelManager.isModelRunningOnNode(modelId));
758-
ArgumentCaptor<Exception> exception = ArgumentCaptor.forClass(Exception.class);
759-
verify(listener).onFailure(exception.capture());
760-
assertEquals("Failed to find model", exception.getValue().getMessage());
761-
verify(mlStats)
762-
.createCounterStatIfAbsent(
763-
eq(FunctionName.TEXT_EMBEDDING),
764-
eq(ActionName.DEPLOY),
765-
eq(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)
766-
);
767-
}
768-
769727
public void testDeployModel_GetModelResponse_NotExist() {
770728
MLModelConfig modelConfig = TextEmbeddingModelConfig
771729
.builder()

plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.mockito.Mockito.doThrow;
1212
import static org.mockito.Mockito.mock;
1313
import static org.mockito.Mockito.when;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
1415
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
1516
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
1617

@@ -34,19 +35,20 @@
3435
public class MockHelper {
3536

3637
public static void mock_client_get_NotExist(Client client) {
38+
GetResult getResult = new GetResult(
39+
"fake_index",
40+
"fake_id",
41+
UNASSIGNED_SEQ_NO,
42+
UNASSIGNED_PRIMARY_TERM,
43+
-1L,
44+
false,
45+
null,
46+
null,
47+
null
48+
);
3749
doAnswer(invocation -> {
3850
ActionListener<GetResponse> listener = invocation.getArgument(1);
39-
GetResponse response = mock(GetResponse.class);
40-
when(response.isExists()).thenReturn(false);
41-
listener.onResponse(null);
42-
return null;
43-
}).when(client).get(any(), any());
44-
}
45-
46-
public static void mock_client_get_NullResponse(Client client) {
47-
doAnswer(invocation -> {
48-
ActionListener<GetResponse> listener = invocation.getArgument(1);
49-
listener.onResponse(null);
51+
listener.onResponse(new GetResponse(getResult));
5052
return null;
5153
}).when(client).get(any(), any());
5254
}

0 commit comments

Comments
 (0)