diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
index fc87906493..af636c6dbd 100644
--- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
@@ -17,6 +17,8 @@
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import org.opensearch.core.action.ActionListener;
@@ -100,14 +102,26 @@ public void inferenceSentences(
         @NonNull final TextInferenceRequest inferenceRequest,
         @NonNull final ActionListener<List<List<Number>>> listener
     ) {
-        retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
+        retryableInference(
+            inferenceRequest,
+            0,
+            () -> createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts()),
+            this::buildVectorFromResponse,
+            listener
+        );
     }
 
     public void inferenceSentencesWithMapResult(
         @NonNull final TextInferenceRequest inferenceRequest,
         @NonNull final ActionListener<List<Map<String, ?>>> listener
     ) {
-        retryableInferenceSentencesWithMapResult(inferenceRequest, 0, listener);
+        retryableInference(
+            inferenceRequest,
+            0,
+            () -> createMLTextInput(null, inferenceRequest.getInputTexts()),
+            this::buildMapResultFromResponse,
+            listener
+        );
     }
 
     /**
@@ -119,7 +133,13 @@ public void inferenceSentencesWithMapResult(
      * @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
      */
     public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Number>> listener) {
-        retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
+        retryableInference(
+            inferenceRequest,
+            0,
+            () -> createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects()),
+            this::buildSingleVectorFromResponse,
+            listener
+        );
     }
 
     /**
@@ -134,63 +154,42 @@ public void inferenceSimilarity(
         @NonNull SimilarityInferenceRequest inferenceRequest,
         @NonNull final ActionListener<List<Float>> listener
     ) {
-        retryableInferenceSimilarityWithVectorResult(inferenceRequest, 0, listener);
+        retryableInference(
+            inferenceRequest,
+            0,
+            () -> createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts()),
+            (mlOutput) -> buildVectorFromResponse(mlOutput).stream().map(v -> v.getFirst().floatValue()).collect(Collectors.toList()),
+            listener
+        );
     }
 
-    private void retryableInferenceSentencesWithMapResult(
-        final TextInferenceRequest inferenceRequest,
+    /**
+     * A generic method to make a retryable inference request.
+     * It allows the caller to specify functions to vend the MLInput and process the MLOutput.
+     *
+     * @param inferenceRequest inference request
+     * @param retryTime retry time
+     * @param mlInputSupplier a supplier to vend the MLInput
+     * @param mlOutputBuilder a function to convert the MLOutput into the processed output format
+     * @param listener a callback to handle the result or failures
+     * @param <T> type of the processed MLOutput format
+     */
+    private <T> void retryableInference(
+        final InferenceRequest inferenceRequest,
         final int retryTime,
-        final ActionListener<List<Map<String, ?>>> listener
+        final Supplier<MLInput> mlInputSupplier,
+        final Function<MLOutput, T> mlOutputBuilder,
+        final ActionListener<T> listener
     ) {
-        MLInput mlInput = createMLTextInput(null, inferenceRequest.getInputTexts());
+        MLInput mlInput = mlInputSupplier.get();
         mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
-            final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
+            final T result = mlOutputBuilder.apply(mlOutput);
             listener.onResponse(result);
         },
            e -> RetryUtil.handleRetryOrFailure(
                e,
                retryTime,
-                () -> retryableInferenceSentencesWithMapResult(inferenceRequest, retryTime + 1, listener),
-                listener
-            )
-        ));
-    }
-
-    private void retryableInferenceSentencesWithVectorResult(
-        final TextInferenceRequest inferenceRequest,
-        final int retryTime,
-        final ActionListener<List<List<Number>>> listener
-    ) {
-        MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
-        mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
-            final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
-            listener.onResponse(vector);
-        },
-            e -> RetryUtil.handleRetryOrFailure(
-                e,
-                retryTime,
-                () -> retryableInferenceSentencesWithVectorResult(inferenceRequest, retryTime + 1, listener),
-                listener
-            )
-        ));
-    }
-
-    private void retryableInferenceSimilarityWithVectorResult(
-        final SimilarityInferenceRequest inferenceRequest,
-        final int retryTime,
-        final ActionListener<List<Float>> listener
-    ) {
-        MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
-        mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
-            final List<Float> scores = buildVectorFromResponse(mlOutput).stream()
-                .map(v -> v.getFirst().floatValue())
-                .collect(Collectors.toList());
-            listener.onResponse(scores);
-        },
-            e -> RetryUtil.handleRetryOrFailure(
-                e,
-                retryTime,
-                () -> retryableInferenceSimilarityWithVectorResult(inferenceRequest, retryTime + 1, listener),
+                () -> retryableInference(inferenceRequest, retryTime + 1, mlInputSupplier, mlOutputBuilder, listener),
                listener
            )
        ));
@@ -270,26 +269,6 @@ private List<Number> buildSingleVectorFromResponse(final MLOutput
         return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
     }
 
-    private void retryableInferenceSentencesWithSingleVectorResult(
-        final MapInferenceRequest inferenceRequest,
-        final int retryTime,
-        final ActionListener<List<Number>> listener
-    ) {
-        MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
-        mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
-            final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
-            log.debug("Inference Response for input sentence is : {} ", vector);
-            listener.onResponse(vector);
-        },
-            e -> RetryUtil.handleRetryOrFailure(
-                e,
-                retryTime,
-                () -> retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, retryTime + 1, listener),
-                listener
-            )
-        ));
-    }
-
     /**
      * Process the highlighting output from ML model response.
      * Converts the model output into a list of maps containing highlighting information.
@@ -471,7 +450,10 @@ public void inferenceSentenceHighlighting(
         @NonNull final SentenceHighlightingRequest inferenceRequest,
         @NonNull final ActionListener<List<Map<String, Object>>> listener
     ) {
-        retryableInferenceSentenceHighlighting(inferenceRequest, 0, listener);
+        retryableInference(inferenceRequest, 0, () -> {
+            MLInputDataset inputDataset = new QuestionAnsweringInputDataSet(inferenceRequest.getQuestion(), inferenceRequest.getContext());
+            return new MLInput(FunctionName.QUESTION_ANSWERING, null, inputDataset);
+        }, (mlOutput) -> processHighlightingOutput((ModelTensorOutput) mlOutput), listener);
     }
 
     /**
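The patch collapses four copies of the same retry loop (map result, vector result, similarity, single vector) into one generic `retryableInference`, parameterized by a `Supplier<MLInput>` that builds the request and a `Function<MLOutput, T>` that post-processes the response. The sketch below is a standalone, synchronous illustration of that pattern, not the plugin code: names such as `RetryableInferenceSketch`, `retryableCall`, `MAX_RETRIES` and the toy `fakePredict` lambda are invented for the example, and the real implementation is asynchronous (callback-based via `ActionListener`) with the retry decision delegated to `RetryUtil.handleRetryOrFailure`.

```java
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

public class RetryableInferenceSketch {

    // Hypothetical cap for the sketch; the plugin delegates the retry policy to RetryUtil.
    private static final int MAX_RETRIES = 3;

    /**
     * Generic retry loop: the caller supplies how to build the input (inputSupplier),
     * how to run the prediction (predict), and how to map the raw output (outputBuilder).
     */
    static <I, O, T> void retryableCall(
        Supplier<I> inputSupplier,
        Function<I, O> predict,
        Function<O, T> outputBuilder,
        int retryTime,
        Consumer<T> onResponse,
        Consumer<Exception> onFailure
    ) {
        try {
            O rawOutput = predict.apply(inputSupplier.get());
            onResponse.accept(outputBuilder.apply(rawOutput));
        } catch (Exception e) {
            if (retryTime < MAX_RETRIES) {
                // Same suppliers, incremented counter -- mirrors the
                // () -> retryableInference(..., retryTime + 1, ...) callback in the patch.
                retryableCall(inputSupplier, predict, outputBuilder, retryTime + 1, onResponse, onFailure);
            } else {
                onFailure.accept(e);
            }
        }
    }

    public static void main(String[] args) {
        Supplier<String> input = () -> "hello world";                              // vend the input (createMLTextInput in the patch)
        Function<String, float[]> fakePredict = text -> new float[] { text.length() }; // stand-in for mlClient.predict
        Function<float[], Float> toScore = raw -> raw[0];                          // stand-in for buildVectorFromResponse & co.

        retryableCall(
            input,
            fakePredict,
            toScore,
            0,
            score -> System.out.println("score = " + score),
            err -> System.err.println("failed after retries: " + err)
        );
    }
}
```

Because the input supplier and the output mapper are captured in the retry callback, a retried attempt rebuilds the same request and reuses the same post-processing, which is what the `() -> retryableInference(inferenceRequest, retryTime + 1, mlInputSupplier, mlOutputBuilder, listener)` lambda in the patch relies on.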