Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -100,14 +102,26 @@ public void inferenceSentences(
@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<List<Number>>> listener
) {
retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
retryableInference(
inferenceRequest,
0,
() -> createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts()),
this::buildVectorFromResponse,
listener
);
}

public void inferenceSentencesWithMapResult(
@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<Map<String, ?>>> listener
) {
retryableInferenceSentencesWithMapResult(inferenceRequest, 0, listener);
retryableInference(
inferenceRequest,
0,
() -> createMLTextInput(null, inferenceRequest.getInputTexts()),
this::buildMapResultFromResponse,
listener
);
}

/**
Expand All @@ -119,7 +133,13 @@ public void inferenceSentencesWithMapResult(
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Number>> listener) {
retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
retryableInference(
inferenceRequest,
0,
() -> createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects()),
this::buildSingleVectorFromResponse,
listener
);
}

/**
Expand All @@ -134,63 +154,42 @@ public void inferenceSimilarity(
@NonNull SimilarityInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<Float>> listener
) {
retryableInferenceSimilarityWithVectorResult(inferenceRequest, 0, listener);
retryableInference(
inferenceRequest,
0,
() -> createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts()),
(mlOutput) -> buildVectorFromResponse(mlOutput).stream().map(v -> v.getFirst().floatValue()).collect(Collectors.toList()),
listener
);
}

private void retryableInferenceSentencesWithMapResult(
final TextInferenceRequest inferenceRequest,
/**
* A generic function to make retryable inference request.
* It allows caller to specify functions to vend their MLInput and process MLOutput.
*
* @param inferenceRequest inference request
* @param retryTime retry time
* @param mlInputSupplier a supplier to vend MLInput
* @param mlOutputBuilder a consumer to consume MLOutput and provide processed output format.
* @param listener a callback to handle result or failures.
* @param <T> type of processed MLOutput format.
*/
private <T> void retryableInference(
final InferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Map<String, ?>>> listener
final Supplier<MLInput> mlInputSupplier,
final Function<MLOutput, T> mlOutputBuilder,
final ActionListener<T> listener
) {
MLInput mlInput = createMLTextInput(null, inferenceRequest.getInputTexts());
MLInput mlInput = mlInputSupplier.get();
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
final T result = mlOutputBuilder.apply(mlOutput);
listener.onResponse(result);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithMapResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSentencesWithVectorResult(
final TextInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<List<Number>>> listener
) {
MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSimilarityWithVectorResult(
final SimilarityInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream()
.map(v -> v.getFirst().floatValue())
.collect(Collectors.toList());
listener.onResponse(scores);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSimilarityWithVectorResult(inferenceRequest, retryTime + 1, listener),
() -> retryableInference(inferenceRequest, retryTime + 1, mlInputSupplier, mlOutputBuilder, listener),
listener
)
));
Expand Down Expand Up @@ -270,26 +269,6 @@ private <T extends Number> List<T> buildSingleVectorFromResponse(final MLOutput
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
}

private void retryableInferenceSentencesWithSingleVectorResult(
final MapInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Number>> listener
) {
MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
}

/**
* Process the highlighting output from ML model response.
* Converts the model output into a list of maps containing highlighting information.
Expand Down Expand Up @@ -471,7 +450,10 @@ public void inferenceSentenceHighlighting(
@NonNull final SentenceHighlightingRequest inferenceRequest,
@NonNull final ActionListener<List<Map<String, Object>>> listener
) {
retryableInferenceSentenceHighlighting(inferenceRequest, 0, listener);
retryableInference(inferenceRequest, 0, () -> {
MLInputDataset inputDataset = new QuestionAnsweringInputDataSet(inferenceRequest.getQuestion(), inferenceRequest.getContext());
return new MLInput(FunctionName.QUESTION_ANSWERING, null, inputDataset);
}, (mlOutput) -> processHighlightingOutput((ModelTensorOutput) mlOutput), listener);
}

/**
Expand Down
Loading