|
15 | 15 | import static org.mockito.Mockito.times;
|
16 | 16 | import static org.mockito.Mockito.verify;
|
17 | 17 | import static org.mockito.Mockito.when;
|
| 18 | +import static org.opensearch.ml.utils.TestHelper.getAnomalyLocalizationRestRequest; |
18 | 19 | import static org.opensearch.ml.utils.TestHelper.getExecuteAgentRestRequest;
|
19 | 20 | import static org.opensearch.ml.utils.TestHelper.getLocalSampleCalculatorRestRequest;
|
20 | 21 | import static org.opensearch.ml.utils.TestHelper.getMetricsCorrelationRestRequest;
|
21 | 22 |
|
22 | 23 | import java.io.IOException;
|
| 24 | +import java.util.Arrays; |
| 25 | +import java.util.HashMap; |
23 | 26 | import java.util.List;
|
| 27 | +import java.util.Map; |
24 | 28 |
|
25 | 29 | import org.junit.Before;
|
26 | 30 | import org.mockito.ArgumentCaptor;
|
|
32 | 36 | import org.opensearch.core.action.ActionListener;
|
33 | 37 | import org.opensearch.core.common.Strings;
|
34 | 38 | import org.opensearch.core.rest.RestStatus;
|
| 39 | +import org.opensearch.core.xcontent.XContentBuilder; |
35 | 40 | import org.opensearch.ml.common.FunctionName;
|
36 | 41 | import org.opensearch.ml.common.input.Input;
|
| 42 | +import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput; |
| 43 | +import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput.Bucket; |
| 44 | +import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput.Result; |
| 45 | +import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput; |
37 | 46 | import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
|
38 | 47 | import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
|
39 | 48 | import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
|
@@ -337,4 +346,79 @@ public void testAgentExecutionResponsePlainText() throws Exception {
|
337 | 346 | "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
|
338 | 347 | assertEquals(expectedError, response.content().utf8ToString());
|
339 | 348 | }
|
| 349 | + |
| 350 | + public void testLocalSampleCalculatorExecutionResponse() throws Exception { |
| 351 | + RestRequest request = getLocalSampleCalculatorRestRequest(); |
| 352 | + XContentBuilder builder = XContentFactory.jsonBuilder(); |
| 353 | + when(channel.newBuilder()).thenReturn(builder); |
| 354 | + doAnswer(invocation -> { |
| 355 | + ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2); |
| 356 | + LocalSampleCalculatorOutput output = LocalSampleCalculatorOutput.builder().totalSum(3.0).build(); |
| 357 | + MLExecuteTaskResponse response = MLExecuteTaskResponse |
| 358 | + .builder() |
| 359 | + .output(output) |
| 360 | + .functionName(FunctionName.LOCAL_SAMPLE_CALCULATOR) |
| 361 | + .build(); |
| 362 | + actionListener.onResponse(response); |
| 363 | + return null; |
| 364 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); |
| 365 | + doNothing().when(channel).sendResponse(any()); |
| 366 | + restMLExecuteAction.handleRequest(request, channel, client); |
| 367 | + |
| 368 | + ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class); |
| 369 | + verify(channel).sendResponse(responseCaptor.capture()); |
| 370 | + BytesRestResponse response = (BytesRestResponse) responseCaptor.getValue(); |
| 371 | + assertEquals(RestStatus.OK, response.status()); |
| 372 | + assertEquals("{\"result\":3.0}", response.content().utf8ToString()); |
| 373 | + } |
| 374 | + |
| 375 | + public void testAnomalyLocalizationExecutionResponse() throws Exception { |
| 376 | + RestRequest request = getAnomalyLocalizationRestRequest(); |
| 377 | + XContentBuilder builder = XContentFactory.jsonBuilder(); |
| 378 | + when(channel.newBuilder()).thenReturn(builder); |
| 379 | + doAnswer(invocation -> { |
| 380 | + ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2); |
| 381 | + |
| 382 | + Bucket bucket1 = new Bucket(); |
| 383 | + bucket1.setStartTime(1620630000000L); |
| 384 | + bucket1.setEndTime(1620716400000L); |
| 385 | + bucket1.setOverallAggValue(65.0); |
| 386 | + |
| 387 | + Result result = new Result(); |
| 388 | + result.setBuckets(Arrays.asList(bucket1)); |
| 389 | + |
| 390 | + AnomalyLocalizationOutput output = new AnomalyLocalizationOutput(); |
| 391 | + Map<String, Result> results = new HashMap<>(); |
| 392 | + results.put("sum", result); |
| 393 | + output.setResults(results); |
| 394 | + |
| 395 | + MLExecuteTaskResponse response = MLExecuteTaskResponse |
| 396 | + .builder() |
| 397 | + .output(output) |
| 398 | + .functionName(FunctionName.ANOMALY_LOCALIZATION) |
| 399 | + .build(); |
| 400 | + actionListener.onResponse(response); |
| 401 | + return null; |
| 402 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); |
| 403 | + doNothing().when(channel).sendResponse(any()); |
| 404 | + restMLExecuteAction.handleRequest(request, channel, client); |
| 405 | + |
| 406 | + ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class); |
| 407 | + verify(channel).sendResponse(responseCaptor.capture()); |
| 408 | + BytesRestResponse response = (BytesRestResponse) responseCaptor.getValue(); |
| 409 | + assertEquals(RestStatus.OK, response.status()); |
| 410 | + String expectedJson = "{\"results\":[{" |
| 411 | + + "\"name\":\"sum\"," |
| 412 | + + "\"result\":{" |
| 413 | + + "\"buckets\":[" |
| 414 | + + "{" |
| 415 | + + "\"start_time\":1620630000000," |
| 416 | + + "\"end_time\":1620716400000," |
| 417 | + + "\"overall_aggregate_value\":65.0" |
| 418 | + + "}" |
| 419 | + + "]" |
| 420 | + + "}" |
| 421 | + + "}]}"; |
| 422 | + assertEquals(expectedJson, response.content().utf8ToString()); |
| 423 | + } |
340 | 424 | }
|
0 commit comments