From d6bdf80f718c34acaf181b09cd31179f64434874 Mon Sep 17 00:00:00 2001
From: Sicheng Song
Date: Mon, 3 Mar 2025 12:42:13 -0800
Subject: [PATCH 1/3] Use model type to check local or remote model (#3597)

* use model type to check local or remote model

Signed-off-by: Sicheng Song

* spotless

Signed-off-by: Sicheng Song

* Ignore test resource

Signed-off-by: Sicheng Song

* Add java doc

Signed-off-by: Sicheng Song

* Handle when model not in cache

Signed-off-by: Sicheng Song

* Handle when model not in cache

Signed-off-by: Sicheng Song

---------

Signed-off-by: Sicheng Song
(cherry picked from commit 696b1e1739cbbb5ef138255339acd02375719570)
---
 .../ml/rest/RestMLPredictionAction.java       | 44 +++++++++++--------
 .../ml/rest/RestMLPredictionActionTests.java  | 10 +++--
 2 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
index 5af116eb6f..eb51bb4c61 100644
--- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
@@ -16,6 +16,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.Locale;
+import java.util.Objects;
 import java.util.Optional;
 
 import org.opensearch.client.node.NodeClient;
@@ -75,27 +76,30 @@ public List<Route> routes() {
 
     @Override
     public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
-        String algorithm = request.param(PARAMETER_ALGORITHM);
+        String userAlgorithm = request.param(PARAMETER_ALGORITHM);
         String modelId = getParameterId(request, PARAMETER_MODEL_ID);
         Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
 
-        if (algorithm == null && functionName.isPresent()) {
-            algorithm = functionName.get().name();
-        }
-
-        if (algorithm != null) {
-            MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request);
-            return channel -> client
-                .execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
+        // check if the model is in cache
+        if (functionName.isPresent()) {
+            MLPredictionTaskRequest predictionRequest = getRequest(
+                modelId,
+                functionName.get().name(),
+                Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
+                request
+            );
+            return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
         }
 
+        // If the model isn't in cache
         return channel -> {
             ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
-                String algoName = mlModel.getAlgorithm().name();
+                String modelType = mlModel.getAlgorithm().name();
+                String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
                 client
                     .execute(
                         MLPredictionTaskAction.INSTANCE,
-                        getRequest(modelId, algoName, request),
+                        getRequest(modelId, modelType, modelAlgorithm, request),
                         new RestToXContentListener<>(channel)
                     );
             }, e -> {
@@ -113,21 +117,25 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
     }
 
     /**
-     * Creates a MLPredictionTaskRequest from a RestRequest
+     * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
+     * enabled features and model types, and parses the input data for prediction.
      *
-     * @param request RestRequest
-     * @return MLPredictionTaskRequest
+     * @param modelId The ID of the ML model to use for prediction
+     * @param modelType The type of the ML model, extracted from the model cache to specify whether it is a remote model or a local model
+     * @param userAlgorithm The algorithm specified by the user for prediction; this is used to determine the interface of the model
+     * @param request The REST request containing prediction input data
+     * @return MLPredictionTaskRequest configured with the model and input parameters
      */
     @VisibleForTesting
-    MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
-        if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
+    MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
+        if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
             throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-        } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
+        } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
             throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
         }
         XContentParser parser = request.contentParser();
         ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-        MLInput mlInput = MLInput.parse(parser, algorithm);
+        MLInput mlInput = MLInput.parse(parser, userAlgorithm);
         return new MLPredictionTaskRequest(modelId, mlInput, null);
     }
 
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
index d34e0fd00e..d33dfa822d 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
@@ -66,7 +66,7 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase {
     @Before
     public void setup() {
         MockitoAnnotations.openMocks(this);
-        when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
+        when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE));
         when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
         when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);
         restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting);
@@ -109,7 +109,8 @@ public void testRoutes() {
 
     public void testGetRequest() throws IOException {
         RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
+        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
+            .getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);
 
         MLInput mlInput = mlPredictionTaskRequest.getMlInput();
         verifyParsedKMeansMLInput(mlInput);
@@ -121,7 +122,8 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException {
         when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
 
         RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request);
+        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
+            .getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
     }
 
     public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
@@ -131,7 +133,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
         when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
         RestRequest request = getRestRequest_PredictModel();
         MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
-            .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request);
+            .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
     }
 
     public void testPrepareRequest() throws Exception {

From 00dd076330c0171953ed775f1c6de89f9c9813bd Mon Sep 17 00:00:00 2001
From: rithin-pullela-aws
Date: Tue, 4 Mar 2025 21:49:54 -0800
Subject: [PATCH 2/3] Updated to latest versions of actions to resolve CI issues

Signed-off-by: rithin-pullela-aws
---
 .github/workflows/CI-workflow.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml
index 1757134442..4d79a898d8 100644
--- a/.github/workflows/CI-workflow.yml
+++ b/.github/workflows/CI-workflow.yml
@@ -79,7 +79,7 @@ jobs:
           flags: ml-commons
           token: ${{ secrets.CODECOV_TOKEN }}
 
-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
           name: ml-plugin-linux-${{ matrix.java }}
           path: ${{ steps.step-build-test-linux.outputs.build-test-linux }}

From 41022d30f36212ebf72012cad3fdbd684479dc4b Mon Sep 17 00:00:00 2001
From: rithin-pullela-aws
Date: Wed, 5 Mar 2025 16:40:36 -0800
Subject: [PATCH 3/3] Removed test case which is no longer valid

Signed-off-by: rithin-pullela-aws
---
 .../ml/rest/RestMLRemoteInferenceIT.java | 114 ------------------
 1 file changed, 114 deletions(-)

diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
index 59c2cc5fdd..82525adff0 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
@@ -615,120 +615,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
         assertFalse(((String) responseMap.get("text")).isEmpty());
     }
 
-    public void testCohereClassifyModel() throws IOException, InterruptedException {
-        // Skip test if key is null
-        if (COHERE_KEY == null) {
-            return;
-        }
-        String entity = "{\n"
-            + " \"name\": \"Cohere classify model Connector\",\n"
-            + " \"description\": \"The connector to public Cohere classify model service\",\n"
-            + " \"version\": 1,\n"
-            + " \"client_config\": {\n"
-            + " \"max_connection\": 20,\n"
-            + " \"connection_timeout\": 50000,\n"
-            + " \"read_timeout\": 50000\n"
-            + " },\n"
-            + " \"protocol\": \"http\",\n"
-            + " \"parameters\": {\n"
-            + " \"endpoint\": \"api.cohere.ai\",\n"
-            + " \"auth\": \"API_Key\",\n"
-            + " \"content_type\": \"application/json\",\n"
-            + " \"max_tokens\": \"20\"\n"
-            + " },\n"
-            + " \"credential\": {\n"
-            + " \"cohere_key\": \""
-            + COHERE_KEY
-            + "\"\n"
-            + " },\n"
-            + " \"actions\": [\n"
-            + " {\n"
-            + " \"action_type\": \"predict\",\n"
-            + " \"method\": \"POST\",\n"
-            + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n"
-            + " \"headers\": { \n"
-            + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n"
-            + " },\n"
-            + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n"
-            + " }\n"
-            + " ]\n"
-            + "}";
-        Response response = createConnector(entity);
-        Map responseMap = parseResponseToMap(response);
-        String connectorId = (String) responseMap.get("connector_id");
-        response = registerRemoteModel("cohere classify model", connectorId);
-        responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
-        String modelId = (String) responseMap.get("model_id");
-        response = deployRemoteModel(modelId);
-        responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        String predictInput = "{\n"
-            + " \"parameters\": {\n"
-            + " \"inputs\": [\n"
-            + " \"Confirm your email address\",\n"
-            + " \"hey i need u to send some $\"\n"
-            + " ],\n"
-            + " \"examples\": [\n"
-            + " {\n"
-            + " \"text\": \"Dermatologists don't like her!\",\n"
-            + " \"label\": \"Spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Hello, open to this?\",\n"
-            + " \"label\": \"Spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"I need help please wire me $1000 right now\",\n"
-            + " \"label\": \"Spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Nice to know you ;)\",\n"
-            + " \"label\": \"Spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Please help me?\",\n"
-            + " \"label\": \"Spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Your parcel will be delivered today\",\n"
-            + " \"label\": \"Not spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Review changes to our Terms and Conditions\",\n"
-            + " \"label\": \"Not spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Weekly sync notes\",\n"
-            + " \"label\": \"Not spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Re: Follow up from todays meeting\",\n"
-            + " \"label\": \"Not spam\"\n"
-            + " },\n"
-            + " {\n"
-            + " \"text\": \"Pre-read for tomorrow\",\n"
-            + " \"label\": \"Not spam\"\n"
-            + " }\n"
-            + " ]\n"
-            + " }\n"
-            + "}";
-
-        response = predictRemoteModel(modelId, predictInput);
-        responseMap = parseResponseToMap(response);
-        List responseList = (List) responseMap.get("inference_results");
-        responseMap = (Map) responseList.get(0);
-        responseList = (List) responseMap.get("output");
-        responseMap = (Map) responseList.get(0);
-        responseMap = (Map) responseMap.get("dataAsMap");
-        responseList = (List) responseMap.get("classifications");
-        assertFalse(responseList.isEmpty());
-    }
-
     protected Response createConnector(String input) throws IOException {
         return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);
     }