diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index e0e028d9f0..c3fb5aa864 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -88,6 +88,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client String modelId = getParameterId(request, PARAMETER_MODEL_ID); Optional functionName = modelManager.getOptionalModelFunctionName(modelId); + if (userAlgorithm != null) { + MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, null, userAlgorithm, request); + return channel -> client + .execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel)); + } + // check if the model is in cache if (functionName.isPresent()) { MLPredictionTaskRequest predictionRequest = getRequest( @@ -143,9 +149,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException { String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request)); - if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { + if (modelType != null && FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); - } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT))) + } else if (modelType != null + && FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT))) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {