Skip to content

Commit 6f29afc

Browse files
fixing the circuit breaker issue for remote model (#3652) (#3654)
Signed-off-by: Dhrubo Saha <[email protected]> (cherry picked from commit 70391fc) Co-authored-by: Dhrubo Saha <[email protected]>
1 parent 116546c commit 6f29afc

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
8787
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
8888
if (!request.isDispatchTask()) {
8989
log.debug("Run ML request {} locally", request.getRequestID());
90-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
9190
checkCBAndExecute(functionName, request, listener);
9291
return;
9392
}

plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java

+13-1
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,20 @@ public void testRun_CircuitBreakerOpen() {
139139
TransportService transportService = mock(TransportService.class);
140140
ActionListener listener = mock(ActionListener.class);
141141
MLTaskRequest request = new MLTaskRequest(false);
142-
expectThrows(CircuitBreakingException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener));
142+
expectThrows(CircuitBreakingException.class, () -> mlTaskRunner.run(FunctionName.BATCH_RCF, request, transportService, listener));
143143
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
144144
assertEquals(1L, value.longValue());
145145
}
146+
147+
public void testRun_NoCircuitbreakerforRemote() {
148+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
149+
when(thresholdCircuitBreaker.getName()).thenReturn("Memory Circuit Breaker");
150+
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
151+
TransportService transportService = mock(TransportService.class);
152+
ActionListener listener = mock(ActionListener.class);
153+
MLTaskRequest request = new MLTaskRequest(false);
154+
mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener);
155+
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
156+
assertEquals(0L, value.longValue());
157+
}
146158
}

0 commit comments

Comments
 (0)