5
5
6
6
package org .opensearch .ml .engine ;
7
7
8
+ import org .opensearch .ml .common .MLModel ;
8
9
import org .opensearch .ml .common .dataframe .DataFrame ;
10
+ import org .opensearch .ml .common .dataset .DataFrameInputDataset ;
11
+ import org .opensearch .ml .common .dataset .MLInputDataset ;
9
12
import org .opensearch .ml .common .input .Input ;
10
13
import org .opensearch .ml .common .input .parameter .MLAlgoParams ;
11
14
import org .opensearch .ml .common .input .MLInput ;
12
15
import org .opensearch .ml .common .output .MLOutput ;
13
- import org .opensearch .ml .common .Model ;
14
16
import org .opensearch .ml .common .output .Output ;
15
17
18
+ import java .util .Map ;
19
+
16
20
/**
17
21
* This is the interface to all ml algorithms.
18
22
*/
19
23
public class MLEngine {
20
24
21
- public static Model train (Input input ) {
25
+ public static MLModel train (Input input ) {
22
26
validateMLInput (input );
23
27
MLInput mlInput = (MLInput ) input ;
24
28
Trainable trainable = MLEngineClassLoader .initInstance (mlInput .getAlgorithm (), mlInput .getParameters (), MLAlgoParams .class );
25
29
if (trainable == null ) {
26
30
throw new IllegalArgumentException ("Unsupported algorithm: " + mlInput .getAlgorithm ());
27
31
}
28
- return trainable .train (mlInput .getDataFrame ());
32
+ return trainable .train (mlInput .getInputDataset ());
33
+ }
34
+
35
+ public static Predictable load (MLModel mlModel , Map <String , Object > params ) {
36
+ Predictable predictable = MLEngineClassLoader .initInstance (mlModel .getAlgorithm (), null , MLAlgoParams .class );
37
+ predictable .initModel (mlModel , params );
38
+ return predictable ;
29
39
}
30
40
31
- public static MLOutput predict (Input input , Model model ) {
41
+ public static MLOutput predict (Input input , MLModel model ) {
32
42
validateMLInput (input );
33
43
MLInput mlInput = (MLInput ) input ;
34
44
Predictable predictable = MLEngineClassLoader .initInstance (mlInput .getAlgorithm (), mlInput .getParameters (), MLAlgoParams .class );
35
45
if (predictable == null ) {
36
46
throw new IllegalArgumentException ("Unsupported algorithm: " + mlInput .getAlgorithm ());
37
47
}
38
- return predictable .predict (mlInput .getDataFrame (), model );
48
+ return predictable .predict (mlInput .getInputDataset (), model );
39
49
}
40
50
41
51
public static MLOutput trainAndPredict (Input input ) {
@@ -45,7 +55,7 @@ public static MLOutput trainAndPredict(Input input) {
45
55
if (trainAndPredictable == null ) {
46
56
throw new IllegalArgumentException ("Unsupported algorithm: " + mlInput .getAlgorithm ());
47
57
}
48
- return trainAndPredictable .trainAndPredict (mlInput .getDataFrame ());
58
+ return trainAndPredictable .trainAndPredict (mlInput .getInputDataset ());
49
59
}
50
60
51
61
public static Output execute (Input input ) {
@@ -63,9 +73,15 @@ private static void validateMLInput(Input input) {
63
73
throw new IllegalArgumentException ("Input should be MLInput" );
64
74
}
65
75
MLInput mlInput = (MLInput ) input ;
66
- DataFrame dataFrame = mlInput .getDataFrame ();
67
- if (dataFrame == null || dataFrame .size () == 0 ) {
68
- throw new IllegalArgumentException ("Input data frame should not be null or empty" );
76
+ MLInputDataset inputDataset = mlInput .getInputDataset ();
77
+ if (inputDataset == null ) {
78
+ throw new IllegalArgumentException ("Input data set should not be null" );
79
+ }
80
+ if (inputDataset instanceof DataFrameInputDataset ) {
81
+ DataFrame dataFrame = ((DataFrameInputDataset )inputDataset ).getDataFrame ();
82
+ if (dataFrame == null || dataFrame .size () == 0 ) {
83
+ throw new IllegalArgumentException ("Input data frame should not be null or empty" );
84
+ }
69
85
}
70
86
}
71
87
0 commit comments