15
15
16
16
import java .io .IOException ;
17
17
import java .nio .file .Path ;
18
+ import java .util .ArrayList ;
18
19
import java .util .Arrays ;
19
20
import java .util .Collections ;
21
+ import java .util .List ;
20
22
import java .util .Map ;
21
23
import java .util .UUID ;
22
24
43
45
import org .opensearch .ml .common .dataset .MLInputDataset ;
44
46
import org .opensearch .ml .common .input .Input ;
45
47
import org .opensearch .ml .common .input .MLInput ;
48
+ import org .opensearch .ml .common .input .execute .metricscorrelation .MetricsCorrelationInput ;
46
49
import org .opensearch .ml .common .input .execute .samplecalculator .LocalSampleCalculatorInput ;
47
50
import org .opensearch .ml .common .input .parameter .MLAlgoParams ;
48
51
import org .opensearch .ml .common .input .parameter .clustering .KMeansParams ;
@@ -254,6 +257,17 @@ public void trainAndPredictWithKmeans() {
254
257
assertEquals (dataSize , output .getPredictionResult ().size ());
255
258
}
256
259
260
+ @ Test
261
+ public void trainAndPredictWithMetricsCorrelationThrowsException () {
262
+ exceptionRule .expect (IllegalArgumentException .class );
263
+ exceptionRule .expectMessage ("Unsupported algorithm: METRICS_CORRELATION" );
264
+ int dataSize = 100 ;
265
+ DataFrame dataFrame = constructTestDataFrame (dataSize );
266
+ MLInputDataset inputData = new DataFrameInputDataset (dataFrame );
267
+ Input input = new MLInput (FunctionName .METRICS_CORRELATION , null , inputData );
268
+ mlEngine .trainAndPredict (input );
269
+ }
270
+
257
271
@ Test
258
272
public void trainAndPredictWithInvalidInput () {
259
273
exceptionRule .expect (IllegalArgumentException .class );
@@ -272,6 +286,18 @@ public void executeLocalSampleCalculator() throws Exception {
272
286
mlEngine .execute (input , listener );
273
287
}
274
288
289
+ @ Test
290
+ public void executeWithMetricsCorrelationThrowsException () throws Exception {
291
+ exceptionRule .expect (IllegalArgumentException .class );
292
+ exceptionRule .expectMessage ("Unsupported executable function: METRICS_CORRELATION" );
293
+ List <float []> inputData = new ArrayList <>();
294
+ inputData .add (new float [] { 1.0f , 2.0f , 3.0f , 4.0f });
295
+ inputData .add (new float [] { 1.0f , 2.0f , 3.0f , 4.0f });
296
+ inputData .add (new float [] { 1.0f , 2.0f , 3.0f , 4.0f });
297
+ Input input = MetricsCorrelationInput .builder ().inputData (inputData ).build ();
298
+ mlEngine .execute (input , null );
299
+ }
300
+
275
301
@ Test
276
302
public void executeWithInvalidInput () throws Exception {
277
303
exceptionRule .expect (IllegalArgumentException .class );
@@ -355,15 +381,36 @@ public void getRegisterModelPath_ReturnsCorrectPath() {
355
381
}
356
382
357
383
@ Test
358
- public void getDeployModelPath_ReturnsCorrectPath () {
384
+ public void getPathAPIs_ReturnsCorrectPath () {
359
385
String modelId = "deployedModel" ;
360
386
361
387
// Use the actual base path from the mlEngine instance
362
388
Path basePath = mlEngine .getMlCachePath ().getParent ();
363
- Path expectedPath = basePath .resolve ("ml_cache" ).resolve ("models_cache" ).resolve (MLEngine .DEPLOY_MODEL_FOLDER ).resolve (modelId );
364
- Path actualPath = mlEngine .getDeployModelPath (modelId );
389
+ Path modelsCachePath = basePath .resolve ("ml_cache" ).resolve ("models_cache" );
390
+ Path expectedDeployModelRootPath = modelsCachePath .resolve (MLEngine .DEPLOY_MODEL_FOLDER );
391
+ assertEquals (expectedDeployModelRootPath .toString (), mlEngine .getDeployModelRootPath ().toString ());
392
+ Path expectedDeployModelPath = expectedDeployModelRootPath .resolve (modelId );
393
+ assertEquals (expectedDeployModelPath .toString (), mlEngine .getDeployModelPath (modelId ).toString ());
365
394
366
- assertEquals (expectedPath .toString (), actualPath .toString ());
395
+ String expectedDeployModelZipPath = expectedDeployModelRootPath .resolve (modelId ).resolve ("myModel" ) + ".zip" ;
396
+ assertEquals (expectedDeployModelZipPath , mlEngine .getDeployModelZipPath (modelId , "myModel" ));
397
+ Path expectedDeployModelChunkPath = expectedDeployModelRootPath .resolve (modelId ).resolve ("chunks" ).resolve ("1" );
398
+ assertEquals (expectedDeployModelChunkPath .toString (), mlEngine .getDeployModelChunkPath (modelId , 1 ).toString ());
399
+
400
+ assertEquals (
401
+ "https://artifacts.opensearch.org/models/ml-models/model_listing/pre_trained_models.json" ,
402
+ mlEngine .getPrebuiltModelMetaListPath ()
403
+ );
404
+
405
+ Path expectedRegisterRootPath = modelsCachePath .resolve (MLEngine .REGISTER_MODEL_FOLDER );
406
+ assertEquals (expectedRegisterRootPath .toString (), mlEngine .getRegisterModelRootPath ().toString ());
407
+ Path expectedRegisterModelPath = expectedRegisterRootPath .resolve (modelId );
408
+ assertEquals (expectedRegisterModelPath .toString (), mlEngine .getRegisterModelPath (modelId ).toString ());
409
+
410
+ Path expectedMdelCacheRootPath = modelsCachePath .resolve ("models" );
411
+ assertEquals (expectedMdelCacheRootPath .toString (), mlEngine .getModelCacheRootPath ().toString ());
412
+ Path expectedMdelCachePath = expectedMdelCacheRootPath .resolve (modelId );
413
+ assertEquals (expectedMdelCachePath .toString (), mlEngine .getModelCachePath (modelId ).toString ());
367
414
}
368
415
369
416
@ Test
@@ -444,4 +491,33 @@ public void testGetConnectorCredential() throws IOException {
444
491
assertEquals (decryptedCredential .get ("key" ), "test_key_value" );
445
492
assertEquals (decryptedCredential .get ("region" ), "test region" );
446
493
}
494
+
495
+ @ Test
496
+ public void testGetConnectorCredentialWithoutRegion () throws IOException {
497
+ String encryptedValue = mlEngine .encrypt ("test_key_value" , null );
498
+ String test_connector_string = "{\" name\" :\" test_connector_name\" ,\" version\" :\" 1\" ,"
499
+ + "\" description\" :\" this is a test connector\" ,\" protocol\" :\" http\" ,"
500
+ + "\" parameters\" :{},\" credential\" :{\" key\" :\" "
501
+ + encryptedValue
502
+ + "\" },"
503
+ + "\" actions\" :[{\" action_type\" :\" PREDICT\" ,\" method\" :\" POST\" ,\" url\" :\" https://test.com\" ,"
504
+ + "\" headers\" :{\" api_key\" :\" ${credential.key}\" },"
505
+ + "\" request_body\" :\" {\\ \" input\\ \" : \\ \" ${parameters.input}\\ \" }\" }],"
506
+ + "\" retry_backoff_millis\" :10,\" retry_timeout_seconds\" :10,\" max_retry_times\" :-1,\" retry_backoff_policy\" :\" constant\" }}" ;
507
+
508
+ XContentParser parser = XContentType .JSON
509
+ .xContent ()
510
+ .createParser (
511
+ new NamedXContentRegistry (new SearchModule (Settings .EMPTY , Collections .emptyList ()).getNamedXContents ()),
512
+ null ,
513
+ test_connector_string
514
+ );
515
+ parser .nextToken ();
516
+
517
+ HttpConnector connector = new HttpConnector ("http" , parser );
518
+ Map <String , String > decryptedCredential = mlEngine .getConnectorCredential (connector );
519
+ assertNotNull (decryptedCredential );
520
+ assertEquals ("test_key_value" , decryptedCredential .get ("key" ));
521
+ assertEquals (null , decryptedCredential .get ("region" ));
522
+ }
447
523
}
0 commit comments