Skip to content

Commit ae58bfd

Browse files
authored
[FEATURE] Improve test coverage for MLEngine Class (#3675)
Improved the test coverage by adding missed use cases. There is no zero argument constructor, constructor with Input as paremeter and constructor with MLAlgoParams.class as paremeter in the class class org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation. So it is impossible to cover remianing potions, it always throws java.lang.NoSuchMethodException: org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.<init>(). Resolves #1376 Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
1 parent 0448890 commit ae58bfd

File tree

1 file changed

+80
-4
lines changed

1 file changed

+80
-4
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java

+80-4
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
import java.io.IOException;
1717
import java.nio.file.Path;
18+
import java.util.ArrayList;
1819
import java.util.Arrays;
1920
import java.util.Collections;
21+
import java.util.List;
2022
import java.util.Map;
2123
import java.util.UUID;
2224

@@ -43,6 +45,7 @@
4345
import org.opensearch.ml.common.dataset.MLInputDataset;
4446
import org.opensearch.ml.common.input.Input;
4547
import org.opensearch.ml.common.input.MLInput;
48+
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
4649
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
4750
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
4851
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
@@ -254,6 +257,17 @@ public void trainAndPredictWithKmeans() {
254257
assertEquals(dataSize, output.getPredictionResult().size());
255258
}
256259

260+
@Test
261+
public void trainAndPredictWithMetricsCorrelationThrowsException() {
262+
exceptionRule.expect(IllegalArgumentException.class);
263+
exceptionRule.expectMessage("Unsupported algorithm: METRICS_CORRELATION");
264+
int dataSize = 100;
265+
DataFrame dataFrame = constructTestDataFrame(dataSize);
266+
MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
267+
Input input = new MLInput(FunctionName.METRICS_CORRELATION, null, inputData);
268+
mlEngine.trainAndPredict(input);
269+
}
270+
257271
@Test
258272
public void trainAndPredictWithInvalidInput() {
259273
exceptionRule.expect(IllegalArgumentException.class);
@@ -272,6 +286,18 @@ public void executeLocalSampleCalculator() throws Exception {
272286
mlEngine.execute(input, listener);
273287
}
274288

289+
@Test
290+
public void executeWithMetricsCorrelationThrowsException() throws Exception {
291+
exceptionRule.expect(IllegalArgumentException.class);
292+
exceptionRule.expectMessage("Unsupported executable function: METRICS_CORRELATION");
293+
List<float[]> inputData = new ArrayList<>();
294+
inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f });
295+
inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f });
296+
inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f });
297+
Input input = MetricsCorrelationInput.builder().inputData(inputData).build();
298+
mlEngine.execute(input, null);
299+
}
300+
275301
@Test
276302
public void executeWithInvalidInput() throws Exception {
277303
exceptionRule.expect(IllegalArgumentException.class);
@@ -355,15 +381,36 @@ public void getRegisterModelPath_ReturnsCorrectPath() {
355381
}
356382

357383
@Test
358-
public void getDeployModelPath_ReturnsCorrectPath() {
384+
public void getPathAPIs_ReturnsCorrectPath() {
359385
String modelId = "deployedModel";
360386

361387
// Use the actual base path from the mlEngine instance
362388
Path basePath = mlEngine.getMlCachePath().getParent();
363-
Path expectedPath = basePath.resolve("ml_cache").resolve("models_cache").resolve(MLEngine.DEPLOY_MODEL_FOLDER).resolve(modelId);
364-
Path actualPath = mlEngine.getDeployModelPath(modelId);
389+
Path modelsCachePath = basePath.resolve("ml_cache").resolve("models_cache");
390+
Path expectedDeployModelRootPath = modelsCachePath.resolve(MLEngine.DEPLOY_MODEL_FOLDER);
391+
assertEquals(expectedDeployModelRootPath.toString(), mlEngine.getDeployModelRootPath().toString());
392+
Path expectedDeployModelPath = expectedDeployModelRootPath.resolve(modelId);
393+
assertEquals(expectedDeployModelPath.toString(), mlEngine.getDeployModelPath(modelId).toString());
365394

366-
assertEquals(expectedPath.toString(), actualPath.toString());
395+
String expectedDeployModelZipPath = expectedDeployModelRootPath.resolve(modelId).resolve("myModel") + ".zip";
396+
assertEquals(expectedDeployModelZipPath, mlEngine.getDeployModelZipPath(modelId, "myModel"));
397+
Path expectedDeployModelChunkPath = expectedDeployModelRootPath.resolve(modelId).resolve("chunks").resolve("1");
398+
assertEquals(expectedDeployModelChunkPath.toString(), mlEngine.getDeployModelChunkPath(modelId, 1).toString());
399+
400+
assertEquals(
401+
"https://artifacts.opensearch.org/models/ml-models/model_listing/pre_trained_models.json",
402+
mlEngine.getPrebuiltModelMetaListPath()
403+
);
404+
405+
Path expectedRegisterRootPath = modelsCachePath.resolve(MLEngine.REGISTER_MODEL_FOLDER);
406+
assertEquals(expectedRegisterRootPath.toString(), mlEngine.getRegisterModelRootPath().toString());
407+
Path expectedRegisterModelPath = expectedRegisterRootPath.resolve(modelId);
408+
assertEquals(expectedRegisterModelPath.toString(), mlEngine.getRegisterModelPath(modelId).toString());
409+
410+
Path expectedMdelCacheRootPath = modelsCachePath.resolve("models");
411+
assertEquals(expectedMdelCacheRootPath.toString(), mlEngine.getModelCacheRootPath().toString());
412+
Path expectedMdelCachePath = expectedMdelCacheRootPath.resolve(modelId);
413+
assertEquals(expectedMdelCachePath.toString(), mlEngine.getModelCachePath(modelId).toString());
367414
}
368415

369416
@Test
@@ -444,4 +491,33 @@ public void testGetConnectorCredential() throws IOException {
444491
assertEquals(decryptedCredential.get("key"), "test_key_value");
445492
assertEquals(decryptedCredential.get("region"), "test region");
446493
}
494+
495+
@Test
496+
public void testGetConnectorCredentialWithoutRegion() throws IOException {
497+
String encryptedValue = mlEngine.encrypt("test_key_value", null);
498+
String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\","
499+
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
500+
+ "\"parameters\":{},\"credential\":{\"key\":\""
501+
+ encryptedValue
502+
+ "\"},"
503+
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
504+
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
505+
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}],"
506+
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";
507+
508+
XContentParser parser = XContentType.JSON
509+
.xContent()
510+
.createParser(
511+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
512+
null,
513+
test_connector_string
514+
);
515+
parser.nextToken();
516+
517+
HttpConnector connector = new HttpConnector("http", parser);
518+
Map<String, String> decryptedCredential = mlEngine.getConnectorCredential(connector);
519+
assertNotNull(decryptedCredential);
520+
assertEquals("test_key_value", decryptedCredential.get("key"));
521+
assertEquals(null, decryptedCredential.get("region"));
522+
}
447523
}

0 commit comments

Comments
 (0)