From 3648777e13d2b4feccd3c899c6f38788372672e0 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Wed, 19 Mar 2025 15:43:07 +0530 Subject: [PATCH] [FEATURE] Improve test coverage for MLEngine Class Improved the test coverage by adding missed use cases. There is no zero argument constructor, constructor with Input as paremeter and constructor with MLAlgoParams.class as paremeter in the class class org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation. So it is impossible to cover remianing potions, it always throws java.lang.NoSuchMethodException: org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.(). Resolves #1376 Signed-off-by: Abdul Muneer Kolarkunnu --- .../opensearch/ml/engine/MLEngineTest.java | 84 ++++++++++++++++++- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 30d4902f1b..adc83d9d3d 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -15,8 +15,10 @@ import java.io.IOException; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.UUID; @@ -43,6 +45,7 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; @@ -254,6 +257,17 @@ public void trainAndPredictWithKmeans() { assertEquals(dataSize, output.getPredictionResult().size()); } + @Test + public void trainAndPredictWithMetricsCorrelationThrowsException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Unsupported algorithm: METRICS_CORRELATION"); + int dataSize = 100; + DataFrame dataFrame = constructTestDataFrame(dataSize); + MLInputDataset inputData = new DataFrameInputDataset(dataFrame); + Input input = new MLInput(FunctionName.METRICS_CORRELATION, null, inputData); + mlEngine.trainAndPredict(input); + } + @Test public void trainAndPredictWithInvalidInput() { exceptionRule.expect(IllegalArgumentException.class); @@ -272,6 +286,18 @@ public void executeLocalSampleCalculator() throws Exception { mlEngine.execute(input, listener); } + @Test + public void executeWithMetricsCorrelationThrowsException() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Unsupported executable function: METRICS_CORRELATION"); + List inputData = new ArrayList<>(); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + Input input = MetricsCorrelationInput.builder().inputData(inputData).build(); + mlEngine.execute(input, null); + } + @Test public void executeWithInvalidInput() throws Exception { exceptionRule.expect(IllegalArgumentException.class); @@ -355,15 +381,36 @@ public void getRegisterModelPath_ReturnsCorrectPath() { } @Test - public void getDeployModelPath_ReturnsCorrectPath() { + public void getPathAPIs_ReturnsCorrectPath() { String modelId = "deployedModel"; // Use the actual base path from the mlEngine instance Path basePath = mlEngine.getMlCachePath().getParent(); - Path expectedPath = basePath.resolve("ml_cache").resolve("models_cache").resolve(MLEngine.DEPLOY_MODEL_FOLDER).resolve(modelId); - Path actualPath = mlEngine.getDeployModelPath(modelId); + Path modelsCachePath = basePath.resolve("ml_cache").resolve("models_cache"); + Path expectedDeployModelRootPath = modelsCachePath.resolve(MLEngine.DEPLOY_MODEL_FOLDER); + assertEquals(expectedDeployModelRootPath.toString(), mlEngine.getDeployModelRootPath().toString()); + Path expectedDeployModelPath = expectedDeployModelRootPath.resolve(modelId); + assertEquals(expectedDeployModelPath.toString(), mlEngine.getDeployModelPath(modelId).toString()); - assertEquals(expectedPath.toString(), actualPath.toString()); + String expectedDeployModelZipPath = expectedDeployModelRootPath.resolve(modelId).resolve("myModel") + ".zip"; + assertEquals(expectedDeployModelZipPath, mlEngine.getDeployModelZipPath(modelId, "myModel")); + Path expectedDeployModelChunkPath = expectedDeployModelRootPath.resolve(modelId).resolve("chunks").resolve("1"); + assertEquals(expectedDeployModelChunkPath.toString(), mlEngine.getDeployModelChunkPath(modelId, 1).toString()); + + assertEquals( + "https://artifacts.opensearch.org/models/ml-models/model_listing/pre_trained_models.json", + mlEngine.getPrebuiltModelMetaListPath() + ); + + Path expectedRegisterRootPath = modelsCachePath.resolve(MLEngine.REGISTER_MODEL_FOLDER); + assertEquals(expectedRegisterRootPath.toString(), mlEngine.getRegisterModelRootPath().toString()); + Path expectedRegisterModelPath = expectedRegisterRootPath.resolve(modelId); + assertEquals(expectedRegisterModelPath.toString(), mlEngine.getRegisterModelPath(modelId).toString()); + + Path expectedMdelCacheRootPath = modelsCachePath.resolve("models"); + assertEquals(expectedMdelCacheRootPath.toString(), mlEngine.getModelCacheRootPath().toString()); + Path expectedMdelCachePath = expectedMdelCacheRootPath.resolve(modelId); + assertEquals(expectedMdelCachePath.toString(), mlEngine.getModelCachePath(modelId).toString()); } @Test @@ -444,4 +491,33 @@ public void testGetConnectorCredential() throws IOException { assertEquals(decryptedCredential.get("key"), "test_key_value"); assertEquals(decryptedCredential.get("region"), "test region"); } + + @Test + public void testGetConnectorCredentialWithoutRegion() throws IOException { + String encryptedValue = mlEngine.encrypt("test_key_value", null); + String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{},\"credential\":{\"key\":\"" + + encryptedValue + + "\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}]," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + test_connector_string + ); + parser.nextToken(); + + HttpConnector connector = new HttpConnector("http", parser); + Map decryptedCredential = mlEngine.getConnectorCredential(connector); + assertNotNull(decryptedCredential); + assertEquals("test_key_value", decryptedCredential.get("key")); + assertEquals(null, decryptedCredential.get("region")); + } }