14
14
import org .opensearch .ml .common .annotation .MLAlgoParameter ;
15
15
import org .opensearch .ml .common .dataset .MLInputDataType ;
16
16
import org .opensearch .ml .common .exception .MLException ;
17
- import org .opensearch .ml .common .parameter .FunctionName ;
18
- import org .opensearch .ml .common .parameter .MLOutputType ;
17
+ import org .opensearch .ml .common .output .MLOutputType ;
19
18
import org .reflections .Reflections ;
20
19
21
20
import java .lang .reflect .Constructor ;
@@ -47,15 +46,17 @@ public class MLCommonsClassLoader {
47
46
48
47
public static void loadClassMapping () {
49
48
loadMLAlgoParameterClassMapping ();
49
+ loadMLOutputClassMapping ();
50
50
loadMLInputDataSetClassMapping ();
51
- loadExecuteInputOutputClassMapping ();
51
+ loadExecuteInputClassMapping ();
52
+ loadExecuteOutputClassMapping ();
52
53
}
53
54
54
55
/**
55
56
* Load ML algorithm parameter and ML output class.
56
57
*/
57
58
private static void loadMLAlgoParameterClassMapping () {
58
- Reflections reflections = new Reflections ("org.opensearch.ml.common.parameter" );
59
+ Reflections reflections = new Reflections ("org.opensearch.ml.common.input. parameter" );
59
60
60
61
Set <Class <?>> classes = reflections .getTypesAnnotatedWith (MLAlgoParameter .class );
61
62
// Load ML algorithm parameter class
@@ -80,6 +81,22 @@ private static void loadMLAlgoParameterClassMapping() {
80
81
}
81
82
}
82
83
84
+ /**
85
+ * Load ML algorithm parameter and ML output class.
86
+ */
87
+ private static void loadMLOutputClassMapping () {
88
+ Reflections reflections = new Reflections ("org.opensearch.ml.common.output" );
89
+
90
+ Set <Class <?>> classes = reflections .getTypesAnnotatedWith (MLAlgoOutput .class );
91
+ for (Class <?> clazz : classes ) {
92
+ MLAlgoOutput mlAlgoOutput = clazz .getAnnotation (MLAlgoOutput .class );
93
+ MLOutputType mlOutputType = mlAlgoOutput .value ();
94
+ if (mlOutputType != null ) {
95
+ parameterClassMap .put (mlOutputType , clazz );
96
+ }
97
+ }
98
+ }
99
+
83
100
/**
84
101
* Load ML input data set class
85
102
*/
@@ -98,11 +115,9 @@ private static void loadMLInputDataSetClassMapping() {
98
115
/**
99
116
* Load execute input output class.
100
117
*/
101
- private static void loadExecuteInputOutputClassMapping () {
102
- Reflections reflections = new Reflections ("org.opensearch.ml.common.parameter" );
103
-
118
+ private static void loadExecuteInputClassMapping () {
119
+ Reflections reflections = new Reflections ("org.opensearch.ml.common.input.execute" );
104
120
Set <Class <?>> classes = reflections .getTypesAnnotatedWith (ExecuteInput .class );
105
- // Load execute input class
106
121
for (Class <?> clazz : classes ) {
107
122
ExecuteInput executeInput = clazz .getAnnotation (ExecuteInput .class );
108
123
FunctionName [] algorithms = executeInput .algorithms ();
@@ -112,9 +127,14 @@ private static void loadExecuteInputOutputClassMapping() {
112
127
}
113
128
}
114
129
}
130
+ }
115
131
116
- // Load execute output class
117
- classes = reflections .getTypesAnnotatedWith (ExecuteOutput .class );
132
+ /**
133
+ * Load execute input output class.
134
+ */
135
+ private static void loadExecuteOutputClassMapping () {
136
+ Reflections reflections = new Reflections ("org.opensearch.ml.common.output.execute" );
137
+ Set <Class <?>> classes = reflections .getTypesAnnotatedWith (ExecuteOutput .class );
118
138
for (Class <?> clazz : classes ) {
119
139
ExecuteOutput executeOutput = clazz .getAnnotation (ExecuteOutput .class );
120
140
FunctionName [] algorithms = executeOutput .algorithms ();
0 commit comments