Skip to content

Commit 161ccb2

Browse files
ylwu-amznb4sjoo
authored and committed
refactor ML algorithm package for supporting custom model (opensearch-project#474)
Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: Sicheng Song <[email protected]>
1 parent a2b9cb1 commit 161ccb2

40 files changed

+689
-349
lines changed

build-tools/repositories.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ repositories {
77
mavenLocal()
88
maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" }
99
mavenCentral()
10+
maven {url 'https://oss.sonatype.org/content/repositories/snapshots/'}
1011
maven { url "https://d1nvenhzbhpy0q.cloudfront.net/snapshots/lucene/" }
1112
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+17-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import org.opensearch.action.delete.DeleteResponse;
1414
import org.opensearch.action.search.SearchRequest;
1515
import org.opensearch.action.search.SearchResponse;
16-
import org.opensearch.client.node.NodeClient;
16+
import org.opensearch.client.Client;
1717
import org.opensearch.ml.common.input.MLInput;
1818
import org.opensearch.ml.common.MLModel;
1919
import org.opensearch.ml.common.output.MLOutput;
@@ -38,17 +38,16 @@
3838
@RequiredArgsConstructor
3939
public class MachineLearningNodeClient implements MachineLearningClient {
4040

41-
NodeClient client;
41+
Client client;
4242

4343
@Override
4444
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
4545
validateMLInput(mlInput, true);
4646

4747
MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
48-
.mlInput(mlInput)
49-
.modelId(modelId)
50-
.build();
51-
48+
.mlInput(mlInput)
49+
.modelId(modelId)
50+
.build();
5251
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
5352
}
5453

@@ -80,9 +79,18 @@ public void getModel(String modelId, ActionListener<MLModel> listener) {
8079
.modelId(modelId)
8180
.build();
8281

83-
client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, ActionListener.wrap(response -> {
84-
listener.onResponse(MLModelGetResponse.fromActionResponse(response).getMlModel());
85-
}, listener::onFailure));
82+
client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener));
83+
}
84+
85+
private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(ActionListener<MLModel> listener) {
86+
ActionListener<MLModelGetResponse> internalListener = ActionListener.wrap(predictionResponse -> {
87+
listener.onResponse(predictionResponse.getMlModel());
88+
}, listener::onFailure);
89+
ActionListener<MLModelGetResponse> actionListener = wrapActionListener(internalListener, res -> {
90+
MLModelGetResponse getResponse = MLModelGetResponse.fromActionResponse(res);
91+
return getResponse;
92+
});
93+
return actionListener;
8694
}
8795

8896
@Override

common/src/main/java/org/opensearch/ml/common/MLModel.java

-7
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,6 @@ public MLModel(String name, FunctionName algorithm, Integer version, String cont
9595
this.totalChunks = totalChunks;
9696
}
9797

98-
public MLModel(FunctionName algorithm, Model model) {
99-
this.name = model.getName();
100-
this.algorithm = algorithm;
101-
this.version = model.getVersion();
102-
this.content = Base64.getEncoder().encodeToString(model.getContent());
103-
}
104-
10598
public MLModel(StreamInput input) throws IOException{
10699
name = input.readOptionalString();
107100
algorithm = input.readEnum(FunctionName.class);

common/src/main/java/org/opensearch/ml/common/exception/MLException.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public MLException(Throwable cause) {
3333
}
3434

3535
/**
36-
* Constructor with specified error message adn cause.
36+
* Constructor with specified error message and cause.
3737
* @param message error message
3838
* @param cause exception cause
3939
*/

common/src/main/java/org/opensearch/ml/common/input/MLInput.java

+6
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,17 @@ public class MLInput implements Input {
4242
public static final String INPUT_INDEX_FIELD = "input_index";
4343
public static final String INPUT_QUERY_FIELD = "input_query";
4444
public static final String INPUT_DATA_FIELD = "input_data";
45+
4546
// For trained model
47+
// Return bytes in model output
4648
public static final String RETURN_BYTES_FIELD = "return_bytes";
49+
// Return number in model output. This can be used together with return_bytes.
4750
public static final String RETURN_NUMBER_FIELD = "return_number";
51+
// Filter target response with name in model output
4852
public static final String TARGET_RESPONSE_FIELD = "target_response";
53+
// Filter target response with position in model output
4954
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
55+
// Input text sentences for text embedding model
5056
public static final String TEXT_DOCS_FIELD = "text_docs";
5157

5258
// Algorithm name

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -66,23 +66,23 @@ public void writeTo(StreamOutput out) throws IOException {
6666

6767
public void filter(ModelResultFilter resultFilter) {
6868
boolean returnBytes = resultFilter.isReturnBytes();
69-
boolean returnNUmber = resultFilter.isReturnNumber();
69+
boolean returnNumber = resultFilter.isReturnNumber();
7070
List<String> targetResponse = resultFilter.getTargetResponse();
7171
List<Integer> targetResponsePositions = resultFilter.getTargetResponsePositions();
7272
if ((targetResponse == null || targetResponse.size() == 0)
7373
&& (targetResponsePositions == null || targetResponsePositions.size() == 0)) {
74-
mlModelTensors.forEach(output -> filter(output, returnBytes, returnNUmber));
74+
mlModelTensors.forEach(output -> filter(output, returnBytes, returnNumber));
7575
return;
7676
}
7777
List<ModelTensor> targetOutput = new ArrayList<>();
7878
if (mlModelTensors != null) {
7979
for (int i = 0 ; i<mlModelTensors.size(); i++) {
8080
ModelTensor output = mlModelTensors.get(i);
8181
if (targetResponse != null && targetResponse.contains(output.getName())) {
82-
filter(output, returnBytes, returnNUmber);
82+
filter(output, returnBytes, returnNumber);
8383
targetOutput.add(output);
8484
} else if (targetResponsePositions != null && targetResponsePositions.contains(i)) {
85-
filter(output, returnBytes, returnNUmber);
85+
filter(output, returnBytes, returnNumber);
8686
targetOutput.add(output);
8787
}
8888
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

+25-9
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,47 @@
55

66
package org.opensearch.ml.engine;
77

8+
import org.opensearch.ml.common.MLModel;
89
import org.opensearch.ml.common.dataframe.DataFrame;
10+
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
11+
import org.opensearch.ml.common.dataset.MLInputDataset;
912
import org.opensearch.ml.common.input.Input;
1013
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
1114
import org.opensearch.ml.common.input.MLInput;
1215
import org.opensearch.ml.common.output.MLOutput;
13-
import org.opensearch.ml.common.Model;
1416
import org.opensearch.ml.common.output.Output;
1517

18+
import java.util.Map;
19+
1620
/**
1721
* This is the interface to all ml algorithms.
1822
*/
1923
public class MLEngine {
2024

21-
public static Model train(Input input) {
25+
public static MLModel train(Input input) {
2226
validateMLInput(input);
2327
MLInput mlInput = (MLInput) input;
2428
Trainable trainable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
2529
if (trainable == null) {
2630
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
2731
}
28-
return trainable.train(mlInput.getDataFrame());
32+
return trainable.train(mlInput.getInputDataset());
33+
}
34+
35+
public static Predictable load(MLModel mlModel, Map<String, Object> params) {
36+
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
37+
predictable.initModel(mlModel, params);
38+
return predictable;
2939
}
3040

31-
public static MLOutput predict(Input input, Model model) {
41+
public static MLOutput predict(Input input, MLModel model) {
3242
validateMLInput(input);
3343
MLInput mlInput = (MLInput) input;
3444
Predictable predictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
3545
if (predictable == null) {
3646
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
3747
}
38-
return predictable.predict(mlInput.getDataFrame(), model);
48+
return predictable.predict(mlInput.getInputDataset(), model);
3949
}
4050

4151
public static MLOutput trainAndPredict(Input input) {
@@ -45,7 +55,7 @@ public static MLOutput trainAndPredict(Input input) {
4555
if (trainAndPredictable == null) {
4656
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
4757
}
48-
return trainAndPredictable.trainAndPredict(mlInput.getDataFrame());
58+
return trainAndPredictable.trainAndPredict(mlInput.getInputDataset());
4959
}
5060

5161
public static Output execute(Input input) {
@@ -63,9 +73,15 @@ private static void validateMLInput(Input input) {
6373
throw new IllegalArgumentException("Input should be MLInput");
6474
}
6575
MLInput mlInput = (MLInput) input;
66-
DataFrame dataFrame = mlInput.getDataFrame();
67-
if (dataFrame == null || dataFrame.size() == 0) {
68-
throw new IllegalArgumentException("Input data frame should not be null or empty");
76+
MLInputDataset inputDataset = mlInput.getInputDataset();
77+
if (inputDataset == null) {
78+
throw new IllegalArgumentException("Input data set should not be null");
79+
}
80+
if (inputDataset instanceof DataFrameInputDataset) {
81+
DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame();
82+
if (dataFrame == null || dataFrame.size() == 0) {
83+
throw new IllegalArgumentException("Input data frame should not be null or empty");
84+
}
6985
}
7086
}
7187

ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java

+26-5
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,42 @@
55

66
package org.opensearch.ml.engine;
77

8-
import org.opensearch.ml.common.dataframe.DataFrame;
8+
import org.opensearch.ml.common.MLModel;
9+
import org.opensearch.ml.common.dataset.MLInputDataset;
910
import org.opensearch.ml.common.output.MLOutput;
10-
import org.opensearch.ml.common.Model;
11+
12+
import java.util.Map;
1113

1214
/**
1315
* This is machine learning algorithms predict interface.
1416
*/
1517
public interface Predictable {
1618

1719
/**
18-
* Predict with given features and model (optional).
19-
* @param dataFrame features data
20+
* Predict with given input data and model (optional).
21+
* Will reload model into memory with model content.
22+
* @param inputDataset input data set
2023
* @param model the java serialized model
2124
* @return predicted results
2225
*/
23-
MLOutput predict(DataFrame dataFrame, Model model);
26+
MLOutput predict(MLInputDataset inputDataset, MLModel model);
27+
28+
/**
29+
* Predict with given input data with loaded model.
30+
* @param inputDataset input data set
31+
* @return predicted results
32+
*/
33+
MLOutput predict(MLInputDataset inputDataset);
2434

35+
/**
36+
* Init model (load model into memory) with ML model content and params.
37+
* @param model ML model
38+
* @param params other parameters
39+
*/
40+
void initModel(MLModel model, Map<String, Object> params);
41+
42+
/**
43+
* Close resources like loaded model.
44+
*/
45+
void close();
2546
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/TrainAndPredictable.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,20 @@
55

66
package org.opensearch.ml.engine;
77

8-
import org.opensearch.ml.common.dataframe.DataFrame;
8+
import org.opensearch.ml.common.dataset.MLInputDataset;
99
import org.opensearch.ml.common.output.MLOutput;
1010

1111

1212
/**
13-
* This is machine learning algorithms train interface.
13+
* This is machine learning algorithms train and predict interface.
1414
*/
1515
public interface TrainAndPredictable extends Trainable, Predictable {
1616

1717
/**
18-
* Train model with given features. Then predict with the same data.
19-
* @param dataFrame training data
20-
* @return the java serialized model
18+
* Train model with given input data. Then predict with the same data.
19+
* @param inputDataset training data
20+
* @return predicted results
2121
*/
22-
MLOutput trainAndPredict(DataFrame dataFrame);
22+
MLOutput trainAndPredict(MLInputDataset inputDataset);
2323

2424
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
package org.opensearch.ml.engine;
77

8-
import org.opensearch.ml.common.dataframe.DataFrame;
9-
import org.opensearch.ml.common.Model;
8+
import org.opensearch.ml.common.MLModel;
9+
import org.opensearch.ml.common.dataset.MLInputDataset;
1010

1111
/**
1212
* This is machine learning algorithms train interface.
@@ -15,9 +15,9 @@ public interface Trainable {
1515

1616
/**
1717
* Train model with given features.
18-
* @param dataFrame training data
19-
* @return the java serialized model
18+
* @param inputDataset training data
19+
* @return ML model with serialized model content
2020
*/
21-
Model train(DataFrame dataFrame);
21+
MLModel train(MLInputDataset inputDataset);
2222

2323
}

0 commit comments

Comments
 (0)