From 60d540a0548be51c6b007cf04fe0a7107d30223b Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Mon, 3 Mar 2025 12:42:13 -0800 Subject: [PATCH 1/3] Use model type to check local or remote model (#3597) * use model type to check local or remote model Signed-off-by: Sicheng Song * spotless Signed-off-by: Sicheng Song * Ignore test resource Signed-off-by: Sicheng Song * Add java doc Signed-off-by: Sicheng Song * Handle when model not in cache Signed-off-by: Sicheng Song * Handle when model not in cache Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song (cherry picked from commit 696b1e1739cbbb5ef138255339acd02375719570) --- .../ml/rest/RestMLPredictionAction.java | 45 +++++++++++-------- .../ml/rest/RestMLPredictionActionTests.java | 12 ++--- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index 68c0146ab2..d4e4ee8525 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Objects; import java.util.Optional; import org.opensearch.client.node.NodeClient; @@ -82,27 +83,30 @@ public List routes() { @Override public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - String algorithm = request.param(PARAMETER_ALGORITHM); + String userAlgorithm = request.param(PARAMETER_ALGORITHM); String modelId = getParameterId(request, PARAMETER_MODEL_ID); Optional functionName = modelManager.getOptionalModelFunctionName(modelId); - if (algorithm == null && functionName.isPresent()) { - algorithm = functionName.get().name(); - } - - if (algorithm != null) { - MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request); - return channel -> client - .execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel)); + // check if the model is in cache + if (functionName.isPresent()) { + MLPredictionTaskRequest predictionRequest = getRequest( + modelId, + functionName.get().name(), + Objects.requireNonNullElse(userAlgorithm, functionName.get().name()), + request + ); + return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel)); } + // If the model isn't in cache return channel -> { ActionListener listener = ActionListener.wrap(mlModel -> { - String algoName = mlModel.getAlgorithm().name(); + String modelType = mlModel.getAlgorithm().name(); + String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name()); client .execute( MLPredictionTaskAction.INSTANCE, - getRequest(modelId, algoName, request), + getRequest(modelId, modelType, modelAlgorithm, request), new RestToXContentListener<>(channel) ); }, e -> { @@ -120,17 +124,22 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client } /** - * Creates a MLPredictionTaskRequest from a RestRequest + * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on + * enabled features and model types, and parses the input data for prediction. * - * @param request RestRequest - * @return MLPredictionTaskRequest + * @param modelId The ID of the ML model to use for prediction + * @param modelType The type of the ML model, extracted from model cache to specify if its a remote model or a local model + * @param userAlgorithm The algorithm specified by the user for prediction, this is used todetermine the interface of the model + * @param request The REST request containing prediction input data + * @return MLPredictionTaskRequest configured with the model and input parameters */ @VisibleForTesting - MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException { + MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException { ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request)); - if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { + if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); - } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { + } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT))) + && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) { throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG); @@ -140,7 +149,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLInput mlInput = MLInput.parse(parser, algorithm, actionType); + MLInput mlInput = MLInput.parse(parser, userAlgorithm, actionType); return new MLPredictionTaskRequest(modelId, mlInput, null); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java index c90f765ed0..e2bbfad79e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java @@ -69,7 +69,7 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty()); + when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE)); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true); restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting); @@ -121,7 +121,8 @@ public void testRoutes_Batch() { public void testGetRequest() throws IOException { RestRequest request = getRestRequest_PredictModel(); - MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request); + MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction + .getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request); MLInput mlInput = mlPredictionTaskRequest.getMlInput(); verifyParsedKMeansMLInput(mlInput); @@ -133,7 +134,8 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException { when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); RestRequest request = getRestRequest_PredictModel(); - MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request); + MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction + .getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request); } public void testGetRequest_LocalModelInferenceDisabled() throws IOException { @@ -143,7 +145,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException { when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); RestRequest request = getRestRequest_PredictModel(); MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction - .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request); + .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request); } public void testPrepareRequest() throws Exception { @@ -182,7 +184,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception { thrown.expectMessage("Wrong Action Type"); RestRequest request = getBatchRestRequest_WrongActionType(); - restMLPredictionAction.getRequest("model id", "remote", request); + restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request); } @Ignore From c507d39efcfddc7bab38da9003af6d9a3dc4178c Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Tue, 4 Mar 2025 21:53:56 -0800 Subject: [PATCH 2/3] Updated to latest versions of actions to resolve CI issues Signed-off-by: rithin-pullela-aws --- .github/workflows/CI-workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 1757134442..4d79a898d8 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -79,7 +79,7 @@ jobs: flags: ml-commons token: ${{ secrets.CODECOV_TOKEN }} - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: ml-plugin-linux-${{ matrix.java }} path: ${{ steps.step-build-test-linux.outputs.build-test-linux }} From 01298edc61e7f0ee0b603d594d3f0c8b158352d2 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Wed, 5 Mar 2025 16:42:21 -0800 Subject: [PATCH 3/3] Removed test case which is no langer valid Signed-off-by: rithin-pullela-aws --- .../ml/rest/RestMLRemoteInferenceIT.java | 115 ------------------ 1 file changed, 115 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 05cb4898bf..795efa62a9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -714,121 +714,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti assertFalse(((String) responseMap.get("text")).isEmpty()); } - @Ignore - public void testCohereClassifyModel() throws IOException, InterruptedException { - // Skip test if key is null - if (COHERE_KEY == null) { - return; - } - String entity = "{\n" - + " \"name\": \"Cohere classify model Connector\",\n" - + " \"description\": \"The connector to public Cohere classify model service\",\n" - + " \"version\": 1,\n" - + " \"client_config\": {\n" - + " \"max_connection\": 20,\n" - + " \"connection_timeout\": 50000,\n" - + " \"read_timeout\": 50000\n" - + " },\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" - + COHERE_KEY - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; - Response response = createConnector(entity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("cohere classify model", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"inputs\": [\n" - + " \"Confirm your email address\",\n" - + " \"hey i need u to send some $\"\n" - + " ],\n" - + " \"examples\": [\n" - + " {\n" - + " \"text\": \"Dermatologists don't like her!\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Hello, open to this?\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"I need help please wire me $1000 right now\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Nice to know you ;)\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Please help me?\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Your parcel will be delivered today\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Review changes to our Terms and Conditions\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Weekly sync notes\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Re: Follow up from todays meeting\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Pre-read for tomorrow\",\n" - + " \"label\": \"Not spam\"\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - - response = predictRemoteModel(modelId, predictInput); - responseMap = parseResponseToMap(response); - List responseList = (List) responseMap.get("inference_results"); - responseMap = (Map) responseList.get(0); - responseList = (List) responseMap.get("output"); - responseMap = (Map) responseList.get(0); - responseMap = (Map) responseMap.get("dataAsMap"); - responseList = (List) responseMap.get("classifications"); - assertFalse(responseList.isEmpty()); - } - public static Response createConnector(String input) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); }