diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 30d4902f1b..adc83d9d3d 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -15,8 +15,10 @@ import java.io.IOException; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.UUID; @@ -43,6 +45,7 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; @@ -254,6 +257,17 @@ public void trainAndPredictWithKmeans() { assertEquals(dataSize, output.getPredictionResult().size()); } + @Test + public void trainAndPredictWithMetricsCorrelationThrowsException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Unsupported algorithm: METRICS_CORRELATION"); + int dataSize = 100; + DataFrame dataFrame = constructTestDataFrame(dataSize); + MLInputDataset inputData = new DataFrameInputDataset(dataFrame); + Input input = new MLInput(FunctionName.METRICS_CORRELATION, null, inputData); + mlEngine.trainAndPredict(input); + } + @Test public void trainAndPredictWithInvalidInput() { exceptionRule.expect(IllegalArgumentException.class); @@ -272,6 +286,18 @@ public void executeLocalSampleCalculator() throws Exception { mlEngine.execute(input, listener); } + @Test + public void executeWithMetricsCorrelationThrowsException() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Unsupported executable function: METRICS_CORRELATION"); + List inputData = new ArrayList<>(); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + Input input = MetricsCorrelationInput.builder().inputData(inputData).build(); + mlEngine.execute(input, null); + } + @Test public void executeWithInvalidInput() throws Exception { exceptionRule.expect(IllegalArgumentException.class); @@ -355,15 +381,36 @@ public void getRegisterModelPath_ReturnsCorrectPath() { } @Test - public void getDeployModelPath_ReturnsCorrectPath() { + public void getPathAPIs_ReturnsCorrectPath() { String modelId = "deployedModel"; // Use the actual base path from the mlEngine instance Path basePath = mlEngine.getMlCachePath().getParent(); - Path expectedPath = basePath.resolve("ml_cache").resolve("models_cache").resolve(MLEngine.DEPLOY_MODEL_FOLDER).resolve(modelId); - Path actualPath = mlEngine.getDeployModelPath(modelId); + Path modelsCachePath = basePath.resolve("ml_cache").resolve("models_cache"); + Path expectedDeployModelRootPath = modelsCachePath.resolve(MLEngine.DEPLOY_MODEL_FOLDER); + assertEquals(expectedDeployModelRootPath.toString(), mlEngine.getDeployModelRootPath().toString()); + Path expectedDeployModelPath = expectedDeployModelRootPath.resolve(modelId); + assertEquals(expectedDeployModelPath.toString(), mlEngine.getDeployModelPath(modelId).toString()); - assertEquals(expectedPath.toString(), actualPath.toString()); + String expectedDeployModelZipPath = expectedDeployModelRootPath.resolve(modelId).resolve("myModel") + ".zip"; + assertEquals(expectedDeployModelZipPath, mlEngine.getDeployModelZipPath(modelId, "myModel")); + Path expectedDeployModelChunkPath = expectedDeployModelRootPath.resolve(modelId).resolve("chunks").resolve("1"); + assertEquals(expectedDeployModelChunkPath.toString(), mlEngine.getDeployModelChunkPath(modelId, 1).toString()); + + assertEquals( + "https://artifacts.opensearch.org/models/ml-models/model_listing/pre_trained_models.json", + mlEngine.getPrebuiltModelMetaListPath() + ); + + Path expectedRegisterRootPath = modelsCachePath.resolve(MLEngine.REGISTER_MODEL_FOLDER); + assertEquals(expectedRegisterRootPath.toString(), mlEngine.getRegisterModelRootPath().toString()); + Path expectedRegisterModelPath = expectedRegisterRootPath.resolve(modelId); + assertEquals(expectedRegisterModelPath.toString(), mlEngine.getRegisterModelPath(modelId).toString()); + + Path expectedMdelCacheRootPath = modelsCachePath.resolve("models"); + assertEquals(expectedMdelCacheRootPath.toString(), mlEngine.getModelCacheRootPath().toString()); + Path expectedMdelCachePath = expectedMdelCacheRootPath.resolve(modelId); + assertEquals(expectedMdelCachePath.toString(), mlEngine.getModelCachePath(modelId).toString()); } @Test @@ -444,4 +491,33 @@ public void testGetConnectorCredential() throws IOException { assertEquals(decryptedCredential.get("key"), "test_key_value"); assertEquals(decryptedCredential.get("region"), "test region"); } + + @Test + public void testGetConnectorCredentialWithoutRegion() throws IOException { + String encryptedValue = mlEngine.encrypt("test_key_value", null); + String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{},\"credential\":{\"key\":\"" + + encryptedValue + + "\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}]," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + test_connector_string + ); + parser.nextToken(); + + HttpConnector connector = new HttpConnector("http", parser); + Map decryptedCredential = mlEngine.getConnectorCredential(connector); + assertNotNull(decryptedCredential); + assertEquals("test_key_value", decryptedCredential.get("key")); + assertEquals(null, decryptedCredential.get("region")); + } }