|
75 | 75 | import org.mockito.ArgumentCaptor;
|
76 | 76 | import org.mockito.Mock;
|
77 | 77 | import org.mockito.MockitoAnnotations;
|
| 78 | +import org.opensearch.OpenSearchStatusException; |
78 | 79 | import org.opensearch.action.get.GetRequest;
|
79 | 80 | import org.opensearch.action.get.GetResponse;
|
80 | 81 | import org.opensearch.action.index.IndexResponse;
|
|
92 | 93 | import org.opensearch.core.common.breaker.CircuitBreakingException;
|
93 | 94 | import org.opensearch.core.common.bytes.BytesReference;
|
94 | 95 | import org.opensearch.core.index.shard.ShardId;
|
| 96 | +import org.opensearch.core.rest.RestStatus; |
95 | 97 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
96 | 98 | import org.opensearch.core.xcontent.ToXContent;
|
97 | 99 | import org.opensearch.core.xcontent.XContentBuilder;
|
| 100 | +import org.opensearch.index.IndexNotFoundException; |
98 | 101 | import org.opensearch.index.get.GetResult;
|
99 | 102 | import org.opensearch.ml.breaker.MLCircuitBreakerService;
|
100 | 103 | import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
|
@@ -492,6 +495,46 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException, IOExce
|
492 | 495 | verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
|
493 | 496 | }
|
494 | 497 |
|
| 498 | + @Test |
| 499 | + public void testRegisterMLRemoteModelModelGroupNotFoundException() throws PrivilegedActionException, IOException { |
| 500 | + // Create listener and capture the failure |
| 501 | + ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class); |
| 502 | + ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class); |
| 503 | + |
| 504 | + // Setup mocks |
| 505 | + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); |
| 506 | + when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); |
| 507 | + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); |
| 508 | + when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo")); |
| 509 | + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); |
| 510 | + |
| 511 | + // Create test inputs |
| 512 | + MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); |
| 513 | + MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); |
| 514 | + |
| 515 | + // Mock index handler |
| 516 | + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); |
| 517 | + |
| 518 | + // Mock client.get() to throw IndexNotFoundException |
| 519 | + doAnswer(invocation -> { |
| 520 | + ActionListener<GetResponse> getModelGroupListener = invocation.getArgument(1); |
| 521 | + getModelGroupListener.onFailure(new IndexNotFoundException("Test", "test")); |
| 522 | + return null; |
| 523 | + }).when(client).get(any(), any()); |
| 524 | + |
| 525 | + // Execute method under test |
| 526 | + modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener); |
| 527 | + |
| 528 | + // Verify the listener's onFailure was called with correct exception |
| 529 | + verify(listener).onFailure(exceptionCaptor.capture()); |
| 530 | + Exception exception = exceptionCaptor.getValue(); |
| 531 | + |
| 532 | + // Verify exception type and message |
| 533 | + assertTrue(exception instanceof OpenSearchStatusException); |
| 534 | + assertEquals("Model group not found", exception.getMessage()); |
| 535 | + assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) exception).status()); |
| 536 | + } |
| 537 | + |
495 | 538 | public void testRegisterMLRemoteModel_SkipMemoryCBOpen() throws IOException {
|
496 | 539 | ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
|
497 | 540 | doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
|
|
0 commit comments