Skip to content

Commit 0aa2105

Browse files
authored
restructure ML input, parameter and output (opensearch-project#245)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent cad8451 commit 0aa2105

File tree

142 files changed

+446
-448
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

142 files changed

+446
-448
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
import org.opensearch.action.search.SearchRequest;
1313
import org.opensearch.action.search.SearchResponse;
1414
import org.opensearch.action.support.PlainActionFuture;
15-
import org.opensearch.ml.common.parameter.*;
15+
import org.opensearch.ml.common.input.MLInput;
16+
import org.opensearch.ml.common.MLModel;
17+
import org.opensearch.ml.common.output.MLOutput;
18+
import org.opensearch.ml.common.MLTask;
1619

1720
/**
1821
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
import org.opensearch.action.search.SearchRequest;
1515
import org.opensearch.action.search.SearchResponse;
1616
import org.opensearch.client.node.NodeClient;
17-
import org.opensearch.ml.common.parameter.*;
17+
import org.opensearch.ml.common.input.MLInput;
18+
import org.opensearch.ml.common.MLModel;
19+
import org.opensearch.ml.common.output.MLOutput;
20+
import org.opensearch.ml.common.MLTask;
1821
import org.opensearch.ml.common.transport.MLTaskResponse;
1922
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
2023
import org.opensearch.ml.common.transport.model.MLModelGetResponse;

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+7-5
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
import org.opensearch.action.delete.DeleteResponse;
1616
import org.opensearch.action.search.SearchRequest;
1717
import org.opensearch.action.search.SearchResponse;
18-
import org.opensearch.common.io.stream.StreamOutput;
19-
import org.opensearch.common.xcontent.XContentBuilder;
2018
import org.opensearch.ml.common.dataframe.DataFrame;
2119
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
22-
import org.opensearch.ml.common.parameter.*;
23-
24-
import java.io.IOException;
20+
import org.opensearch.ml.common.input.MLInput;
21+
import org.opensearch.ml.common.FunctionName;
22+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
23+
import org.opensearch.ml.common.MLModel;
24+
import org.opensearch.ml.common.output.MLOutput;
25+
import org.opensearch.ml.common.MLTask;
26+
import org.opensearch.ml.common.output.MLTrainingOutput;
2527

2628
import static org.junit.Assert.assertEquals;
2729
import static org.mockito.Mockito.verify;

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929
import org.opensearch.index.shard.ShardId;
3030
import org.opensearch.ml.common.dataframe.DataFrame;
3131
import org.opensearch.ml.common.dataset.MLInputDataset;
32-
import org.opensearch.ml.common.parameter.FunctionName;
33-
import org.opensearch.ml.common.parameter.MLInput;
34-
import org.opensearch.ml.common.parameter.MLModel;
35-
import org.opensearch.ml.common.parameter.MLOutput;
36-
import org.opensearch.ml.common.parameter.MLPredictionOutput;
37-
import org.opensearch.ml.common.parameter.MLTask;
38-
import org.opensearch.ml.common.parameter.MLTaskState;
39-
import org.opensearch.ml.common.parameter.MLTrainingOutput;
32+
import org.opensearch.ml.common.FunctionName;
33+
import org.opensearch.ml.common.input.MLInput;
34+
import org.opensearch.ml.common.MLModel;
35+
import org.opensearch.ml.common.output.MLOutput;
36+
import org.opensearch.ml.common.output.MLPredictionOutput;
37+
import org.opensearch.ml.common.MLTask;
38+
import org.opensearch.ml.common.MLTaskState;
39+
import org.opensearch.ml.common.output.MLTrainingOutput;
4040
import org.opensearch.ml.common.transport.MLTaskResponse;
4141
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
4242
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;

common/src/main/java/org/opensearch/ml/common/parameter/FunctionName.java common/src/main/java/org/opensearch/ml/common/FunctionName.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common;
77

88
public enum FunctionName {
99
LINEAR_REGRESSION,

common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java

+30-10
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
import org.opensearch.ml.common.annotation.MLAlgoParameter;
1515
import org.opensearch.ml.common.dataset.MLInputDataType;
1616
import org.opensearch.ml.common.exception.MLException;
17-
import org.opensearch.ml.common.parameter.FunctionName;
18-
import org.opensearch.ml.common.parameter.MLOutputType;
17+
import org.opensearch.ml.common.output.MLOutputType;
1918
import org.reflections.Reflections;
2019

2120
import java.lang.reflect.Constructor;
@@ -47,15 +46,17 @@ public class MLCommonsClassLoader {
4746

4847
public static void loadClassMapping() {
4948
loadMLAlgoParameterClassMapping();
49+
loadMLOutputClassMapping();
5050
loadMLInputDataSetClassMapping();
51-
loadExecuteInputOutputClassMapping();
51+
loadExecuteInputClassMapping();
52+
loadExecuteOutputClassMapping();
5253
}
5354

5455
/**
5556
* Load ML algorithm parameter and ML output class.
5657
*/
5758
private static void loadMLAlgoParameterClassMapping() {
58-
Reflections reflections = new Reflections("org.opensearch.ml.common.parameter");
59+
Reflections reflections = new Reflections("org.opensearch.ml.common.input.parameter");
5960

6061
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(MLAlgoParameter.class);
6162
// Load ML algorithm parameter class
@@ -80,6 +81,22 @@ private static void loadMLAlgoParameterClassMapping() {
8081
}
8182
}
8283

84+
/**
85+
* Load ML algorithm parameter and ML output class.
86+
*/
87+
private static void loadMLOutputClassMapping() {
88+
Reflections reflections = new Reflections("org.opensearch.ml.common.output");
89+
90+
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(MLAlgoOutput.class);
91+
for (Class<?> clazz : classes) {
92+
MLAlgoOutput mlAlgoOutput = clazz.getAnnotation(MLAlgoOutput.class);
93+
MLOutputType mlOutputType = mlAlgoOutput.value();
94+
if (mlOutputType != null) {
95+
parameterClassMap.put(mlOutputType, clazz);
96+
}
97+
}
98+
}
99+
83100
/**
84101
* Load ML input data set class
85102
*/
@@ -98,11 +115,9 @@ private static void loadMLInputDataSetClassMapping() {
98115
/**
99116
* Load execute input output class.
100117
*/
101-
private static void loadExecuteInputOutputClassMapping() {
102-
Reflections reflections = new Reflections("org.opensearch.ml.common.parameter");
103-
118+
private static void loadExecuteInputClassMapping() {
119+
Reflections reflections = new Reflections("org.opensearch.ml.common.input.execute");
104120
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(ExecuteInput.class);
105-
// Load execute input class
106121
for (Class<?> clazz : classes) {
107122
ExecuteInput executeInput = clazz.getAnnotation(ExecuteInput.class);
108123
FunctionName[] algorithms = executeInput.algorithms();
@@ -112,9 +127,14 @@ private static void loadExecuteInputOutputClassMapping() {
112127
}
113128
}
114129
}
130+
}
115131

116-
// Load execute output class
117-
classes = reflections.getTypesAnnotatedWith(ExecuteOutput.class);
132+
/**
133+
* Load execute input output class.
134+
*/
135+
private static void loadExecuteOutputClassMapping() {
136+
Reflections reflections = new Reflections("org.opensearch.ml.common.output.execute");
137+
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(ExecuteOutput.class);
118138
for (Class<?> clazz : classes) {
119139
ExecuteOutput executeOutput = clazz.getAnnotation(ExecuteOutput.class);
120140
FunctionName[] algorithms = executeOutput.algorithms();

common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java common/src/main/java/org/opensearch/ml/common/MLModel.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common;
77

88
import lombok.Builder;
99
import lombok.Getter;

common/src/main/java/org/opensearch/ml/common/parameter/MLTask.java common/src/main/java/org/opensearch/ml/common/MLTask.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common;
77

88
import lombok.Builder;
99
import lombok.EqualsAndHashCode;

common/src/main/java/org/opensearch/ml/common/parameter/MLTaskState.java common/src/main/java/org/opensearch/ml/common/MLTaskState.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common;
77

88
/**
99
* ML task states.

common/src/main/java/org/opensearch/ml/common/parameter/MLTaskType.java common/src/main/java/org/opensearch/ml/common/MLTaskType.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common;
77

88
public enum MLTaskType {
99
TRAINING,

common/src/main/java/org/opensearch/ml/common/parameter/Model.java common/src/main/java/org/opensearch/ml/common/Model.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common;
77

88
import lombok.Data;
99

common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
package org.opensearch.ml.common.annotation;
77

8-
import org.opensearch.ml.common.parameter.FunctionName;
8+
import org.opensearch.ml.common.FunctionName;
99

1010
import java.lang.annotation.ElementType;
1111
import java.lang.annotation.Retention;

common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
package org.opensearch.ml.common.annotation;
77

8-
import org.opensearch.ml.common.parameter.FunctionName;
8+
import org.opensearch.ml.common.FunctionName;
99

1010
import java.lang.annotation.ElementType;
1111
import java.lang.annotation.Retention;

common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
package org.opensearch.ml.common.annotation;
77

8-
import org.opensearch.ml.common.parameter.MLOutputType;
8+
import org.opensearch.ml.common.output.MLOutputType;
99

1010
import java.lang.annotation.ElementType;
1111
import java.lang.annotation.Retention;

common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
package org.opensearch.ml.common.annotation;
77

8-
import org.opensearch.ml.common.parameter.FunctionName;
8+
import org.opensearch.ml.common.FunctionName;
99

1010
import java.lang.annotation.ElementType;
1111
import java.lang.annotation.Retention;

common/src/main/java/org/opensearch/ml/common/parameter/Input.java common/src/main/java/org/opensearch/ml/common/input/Input.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common.input;
77

88
import org.opensearch.common.io.stream.Writeable;
99
import org.opensearch.common.xcontent.ToXContentObject;
10+
import org.opensearch.ml.common.FunctionName;
1011

1112
public interface Input extends ToXContentObject, Writeable {
1213

common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java common/src/main/java/org/opensearch/ml/common/input/MLInput.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common.input;
77

88
import lombok.Builder;
99
import lombok.Data;
@@ -18,6 +18,8 @@
1818
import org.opensearch.ml.common.dataset.MLInputDataType;
1919
import org.opensearch.ml.common.dataset.MLInputDataset;
2020
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
21+
import org.opensearch.ml.common.FunctionName;
22+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
2123
import org.opensearch.search.builder.SearchSourceBuilder;
2224

2325
import java.io.IOException;

common/src/main/java/org/opensearch/ml/common/parameter/AnomalyLocalizationInput.java common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common.input.execute.anomalylocalization;
77

88
import java.io.IOException;
99
import java.util.ArrayList;
@@ -19,6 +19,8 @@
1919
import org.opensearch.common.xcontent.XContentParser;
2020
import org.opensearch.index.query.QueryBuilder;
2121
import org.opensearch.ml.common.annotation.ExecuteInput;
22+
import org.opensearch.ml.common.FunctionName;
23+
import org.opensearch.ml.common.input.Input;
2224
import org.opensearch.search.aggregations.AggregationBuilder;
2325
import org.opensearch.search.aggregations.AggregatorFactories;
2426

@@ -47,7 +49,7 @@ public class AnomalyLocalizationInput implements Input {
4749
public static final String FIELD_ANOMALY_START_TIME = "anomaly_start_time";
4850
public static final String FIELD_FILTER_QUERY = "filter_query";
4951
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_ENTRY = new NamedXContentRegistry.Entry(
50-
org.opensearch.ml.common.parameter.Input.class,
52+
Input.class,
5153
new ParseField(FunctionName.ANOMALY_LOCALIZATION.name()),
5254
parser -> parse(parser)
5355
);

common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common.input.execute.samplecalculator;
77

88
import lombok.Builder;
99
import lombok.Data;
@@ -14,6 +14,8 @@
1414
import org.opensearch.common.xcontent.XContentBuilder;
1515
import org.opensearch.common.xcontent.XContentParser;
1616
import org.opensearch.ml.common.annotation.ExecuteInput;
17+
import org.opensearch.ml.common.FunctionName;
18+
import org.opensearch.ml.common.input.Input;
1719

1820
import java.io.IOException;
1921
import java.util.ArrayList;

common/src/main/java/org/opensearch/ml/common/parameter/MLAlgoParams.java common/src/main/java/org/opensearch/ml/common/input/parameter/MLAlgoParams.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common.input.parameter;
77

88
import org.opensearch.common.io.stream.NamedWriteable;
99
import org.opensearch.common.xcontent.ToXContentObject;

common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java common/src/main/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParams.java

+7-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.parameter;
6+
package org.opensearch.ml.common.input.parameter.ad;
77

88
import lombok.Builder;
99
import lombok.Data;
@@ -14,6 +14,8 @@
1414
import org.opensearch.common.xcontent.XContentBuilder;
1515
import org.opensearch.common.xcontent.XContentParser;
1616
import org.opensearch.ml.common.annotation.MLAlgoParameter;
17+
import org.opensearch.ml.common.FunctionName;
18+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
1719

1820
import java.io.IOException;
1921
import java.util.Locale;
@@ -22,7 +24,7 @@
2224

2325
@Data
2426
@MLAlgoParameter(algorithms={FunctionName.AD_LIBSVM})
25-
public class AnomalyDetectionParams implements MLAlgoParams {
27+
public class AnomalyDetectionLibSVMParams implements MLAlgoParams {
2628
public static final String PARSE_FIELD_NAME = FunctionName.AD_LIBSVM.name();
2729
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
2830
MLAlgoParams.class,
@@ -47,7 +49,7 @@ public class AnomalyDetectionParams implements MLAlgoParams {
4749

4850

4951
@Builder(toBuilder = true)
50-
public AnomalyDetectionParams(ADKernelType kernelType, Double gamma, Double nu, Double cost, Double coeff, Double epsilon, Integer degree) {
52+
public AnomalyDetectionLibSVMParams(ADKernelType kernelType, Double gamma, Double nu, Double cost, Double coeff, Double epsilon, Integer degree) {
5153
this.kernelType = kernelType;
5254
this.gamma = gamma;
5355
this.nu = nu;
@@ -57,7 +59,7 @@ public AnomalyDetectionParams(ADKernelType kernelType, Double gamma, Double nu,
5759
this.degree = degree;
5860
}
5961

60-
public AnomalyDetectionParams(StreamInput in) throws IOException {
62+
public AnomalyDetectionLibSVMParams(StreamInput in) throws IOException {
6163
if (in.readBoolean()) {
6264
this.kernelType = in.readEnum(ADKernelType.class);
6365
}
@@ -110,7 +112,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
110112
break;
111113
}
112114
}
113-
return new AnomalyDetectionParams(kernelType, gamma, nu, cost, coeff, epsilon, degree);
115+
return new AnomalyDetectionLibSVMParams(kernelType, gamma, nu, cost, coeff, epsilon, degree);
114116
}
115117

116118
@Override

0 commit comments

Comments
 (0)