Skip to content

Commit 68ceab3

Browse files
gracefully handles model group index not found exception (#3488) (#3494)
Signed-off-by: Dhrubo Saha <[email protected]> (cherry picked from commit fd7776e) Co-authored-by: Dhrubo Saha <[email protected]>
1 parent 73129c6 commit 68ceab3

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ public void registerMLRemoteModel(
464464
mlRegisterModelInput.getTenantId(),
465465
new MLResourceNotFoundException("Failed to get model group due to index missing")
466466
);
467-
listener.onFailure(e);
467+
listener.onFailure(new OpenSearchStatusException("Model group not found", RestStatus.NOT_FOUND));
468468
} else {
469469
log.error("Failed to get model group", e);
470470
handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), mlRegisterModelInput.getTenantId(), e);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+43
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
import org.mockito.ArgumentCaptor;
7676
import org.mockito.Mock;
7777
import org.mockito.MockitoAnnotations;
78+
import org.opensearch.OpenSearchStatusException;
7879
import org.opensearch.action.get.GetRequest;
7980
import org.opensearch.action.get.GetResponse;
8081
import org.opensearch.action.index.IndexResponse;
@@ -92,9 +93,11 @@
9293
import org.opensearch.core.common.breaker.CircuitBreakingException;
9394
import org.opensearch.core.common.bytes.BytesReference;
9495
import org.opensearch.core.index.shard.ShardId;
96+
import org.opensearch.core.rest.RestStatus;
9597
import org.opensearch.core.xcontent.NamedXContentRegistry;
9698
import org.opensearch.core.xcontent.ToXContent;
9799
import org.opensearch.core.xcontent.XContentBuilder;
100+
import org.opensearch.index.IndexNotFoundException;
98101
import org.opensearch.index.get.GetResult;
99102
import org.opensearch.ml.breaker.MLCircuitBreakerService;
100103
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
@@ -492,6 +495,46 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException, IOExce
492495
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
493496
}
494497

498+
@Test
499+
public void testRegisterMLRemoteModelModelGroupNotFoundException() throws PrivilegedActionException, IOException {
500+
// Create listener and capture the failure
501+
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
502+
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
503+
504+
// Setup mocks
505+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
506+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
507+
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
508+
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
509+
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
510+
511+
// Create test inputs
512+
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
513+
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
514+
515+
// Mock index handler
516+
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
517+
518+
// Mock client.get() to throw IndexNotFoundException
519+
doAnswer(invocation -> {
520+
ActionListener<GetResponse> getModelGroupListener = invocation.getArgument(1);
521+
getModelGroupListener.onFailure(new IndexNotFoundException("Test", "test"));
522+
return null;
523+
}).when(client).get(any(), any());
524+
525+
// Execute method under test
526+
modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener);
527+
528+
// Verify the listener's onFailure was called with correct exception
529+
verify(listener).onFailure(exceptionCaptor.capture());
530+
Exception exception = exceptionCaptor.getValue();
531+
532+
// Verify exception type and message
533+
assertTrue(exception instanceof OpenSearchStatusException);
534+
assertEquals("Model group not found", exception.getMessage());
535+
assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) exception).status());
536+
}
537+
495538
public void testRegisterMLRemoteModel_SkipMemoryCBOpen() throws IOException {
496539
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
497540
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());

0 commit comments

Comments
 (0)