Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.16] Use model type to check local or remote model #3623

Draft
wants to merge 3 commits into
base: 2.16
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ jobs:
flags: ml-commons
token: ${{ secrets.CODECOV_TOKEN }}

- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4
with:
name: ml-plugin-linux-${{ matrix.java }}
path: ${{ steps.step-build-test-linux.outputs.build-test-linux }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

import org.opensearch.client.node.NodeClient;
Expand Down Expand Up @@ -81,27 +82,30 @@ public List<Route> routes() {

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String algorithm = request.param(PARAMETER_ALGORITHM);
String userAlgorithm = request.param(PARAMETER_ALGORITHM);
String modelId = getParameterId(request, PARAMETER_MODEL_ID);
Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);

if (algorithm == null && functionName.isPresent()) {
algorithm = functionName.get().name();
}

if (algorithm != null) {
MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request);
return channel -> client
.execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
// check if the model is in cache
if (functionName.isPresent()) {
MLPredictionTaskRequest predictionRequest = getRequest(
modelId,
functionName.get().name(),
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
request
);
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
}

// If the model isn't in cache
return channel -> {
ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
String algoName = mlModel.getAlgorithm().name();
String modelType = mlModel.getAlgorithm().name();
String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
client
.execute(
MLPredictionTaskAction.INSTANCE,
getRequest(modelId, algoName, request),
getRequest(modelId, modelType, modelAlgorithm, request),
new RestToXContentListener<>(channel)
);
}, e -> {
Expand All @@ -119,25 +123,30 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
}

/**
* Creates a MLPredictionTaskRequest from a RestRequest
* Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
* enabled features and model types, and parses the input data for prediction.
*
* @param request RestRequest
* @return MLPredictionTaskRequest
* @param modelId The ID of the ML model to use for prediction
* @param modelType The type of the ML model, extracted from model cache to specify if it's a remote model or a local model
* @param userAlgorithm The algorithm specified by the user for prediction, this is used to determine the interface of the model
* @param request The REST request containing prediction input data
* @return MLPredictionTaskRequest configured with the model and input parameters
*/
@VisibleForTesting
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
throw new IllegalArgumentException("Wrong action type in the rest request path!");
}

XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLInput mlInput = MLInput.parse(parser, algorithm, actionType);
MLInput mlInput = MLInput.parse(parser, userAlgorithm, actionType);
return new MLPredictionTaskRequest(modelId, mlInput, null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE));
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);
restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting);
Expand Down Expand Up @@ -121,7 +121,8 @@ public void testRoutes_Batch() {

public void testGetRequest() throws IOException {
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);

MLInput mlInput = mlPredictionTaskRequest.getMlInput();
verifyParsedKMeansMLInput(mlInput);
Expand All @@ -133,7 +134,8 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException {

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request);
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
}

public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
Expand All @@ -143,7 +145,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request);
.getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
}

public void testPrepareRequest() throws Exception {
Expand All @@ -169,7 +171,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception {
thrown.expectMessage("Wrong Action Type");

RestRequest request = getBatchRestRequest_WrongActionType();
restMLPredictionAction.getRequest("model id", "remote", request);
restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request);
}

@Ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,121 +651,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
assertFalse(((String) responseMap.get("text")).isEmpty());
}

@Ignore
public void testCohereClassifyModel() throws IOException, InterruptedException {
// Skip test if key is null
if (COHERE_KEY == null) {
return;
}
String entity = "{\n"
+ " \"name\": \"Cohere classify model Connector\",\n"
+ " \"description\": \"The connector to public Cohere classify model service\",\n"
+ " \"version\": 1,\n"
+ " \"client_config\": {\n"
+ " \"max_connection\": 20,\n"
+ " \"connection_timeout\": 50000,\n"
+ " \"read_timeout\": 50000\n"
+ " },\n"
+ " \"protocol\": \"http\",\n"
+ " \"parameters\": {\n"
+ " \"endpoint\": \"api.cohere.ai\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": \"20\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"cohere_key\": \""
+ COHERE_KEY
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {\n"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n"
+ " \"headers\": { \n"
+ " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n"
+ " }\n"
+ " ]\n"
+ "}";
Response response = createConnector(entity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("cohere classify model", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n"
+ " \"parameters\": {\n"
+ " \"inputs\": [\n"
+ " \"Confirm your email address\",\n"
+ " \"hey i need u to send some $\"\n"
+ " ],\n"
+ " \"examples\": [\n"
+ " {\n"
+ " \"text\": \"Dermatologists don't like her!\",\n"
+ " \"label\": \"Spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Hello, open to this?\",\n"
+ " \"label\": \"Spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"I need help please wire me $1000 right now\",\n"
+ " \"label\": \"Spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Nice to know you ;)\",\n"
+ " \"label\": \"Spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Please help me?\",\n"
+ " \"label\": \"Spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Your parcel will be delivered today\",\n"
+ " \"label\": \"Not spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Review changes to our Terms and Conditions\",\n"
+ " \"label\": \"Not spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Weekly sync notes\",\n"
+ " \"label\": \"Not spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Re: Follow up from todays meeting\",\n"
+ " \"label\": \"Not spam\"\n"
+ " },\n"
+ " {\n"
+ " \"text\": \"Pre-read for tomorrow\",\n"
+ " \"label\": \"Not spam\"\n"
+ " }\n"
+ " ]\n"
+ " }\n"
+ "}";

response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("classifications");
assertFalse(responseList.isEmpty());
}

public static Response createConnector(String input) throws IOException {
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);
}
Expand Down
Loading