Skip to content

Commit 080f974

Browse files
authored
Add output transformation support with mean pooling for ML inference processors (opensearch-project#4236)
* add meanPooling to Output Transformations Signed-off-by: Mingshi Liu <[email protected]> * use float data type same as knn vector Signed-off-by: Mingshi Liu <[email protected]> * removed duplicated license header Signed-off-by: Mingshi Liu <[email protected]> * add toDo to javadoc Signed-off-by: Mingshi Liu <[email protected]> * add helpful method Signed-off-by: Mingshi Liu <[email protected]> * clean up Signed-off-by: Mingshi Liu <[email protected]> --------- Signed-off-by: Mingshi Liu <[email protected]>
1 parent c05261d commit 080f974

File tree

6 files changed

+528
-3
lines changed

6 files changed

+528
-3
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,14 @@ private void appendFieldValue(
395395
throw new RuntimeException("model inference output is null");
396396
}
397397

398-
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
398+
// Check if transformation is needed
399+
String baseFieldName = OutputTransformations.getBaseFieldName(modelOutputFieldName);
400+
Object modelOutputValue = getModelOutputValue(mlOutput, baseFieldName, ignoreMissing, fullResponsePath);
401+
402+
// Apply transformation if specified
403+
if (OutputTransformations.hasTransformation(modelOutputFieldName)) {
404+
modelOutputValue = OutputTransformations.applyTransformation(modelOutputFieldName, modelOutputValue);
405+
}
399406

400407
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
401408
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,16 @@ private void updateIncomeQueryObject(
334334
try {
335335
newQueryField = outputMapEntry.getKey();
336336
String modelOutputFieldName = outputMapEntry.getValue();
337-
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
337+
338+
// Check if transformation is needed
339+
String baseFieldName = OutputTransformations.getBaseFieldName(modelOutputFieldName);
340+
Object modelOutputValue = getModelOutputValue(mlOutput, baseFieldName, ignoreMissing, fullResponsePath);
341+
342+
// Apply transformation if specified
343+
if (OutputTransformations.hasTransformation(modelOutputFieldName)) {
344+
modelOutputValue = OutputTransformations.applyTransformation(modelOutputFieldName, modelOutputValue);
345+
}
346+
338347
requestContext.setAttribute(newQueryField, modelOutputValue);
339348

340349
// if output mapping is using jsonpath starts with $. or use dot path starts with ext.
@@ -359,7 +368,16 @@ private String updateQueryTemplate(String queryTemplate, Map<String, String> out
359368
for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
360369
String newQueryField = outputMapEntry.getKey();
361370
String modelOutputFieldName = outputMapEntry.getValue();
362-
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
371+
372+
// Check if transformation is needed
373+
String baseFieldName = OutputTransformations.getBaseFieldName(modelOutputFieldName);
374+
Object modelOutputValue = getModelOutputValue(mlOutput, baseFieldName, ignoreMissing, fullResponsePath);
375+
376+
// Apply transformation if specified
377+
if (OutputTransformations.hasTransformation(modelOutputFieldName)) {
378+
modelOutputValue = OutputTransformations.applyTransformation(modelOutputFieldName, modelOutputValue);
379+
}
380+
363381
if (modelOutputValue instanceof Map) {
364382
modelOutputValue = toJson(modelOutputValue);
365383
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.processor;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
11+
/**
12+
* Utility class for output transformations in ML inference processor
13+
*
14+
* TODO: Support additional pooling methods like max pooling, min pooling, etc.
15+
*/
16+
public class OutputTransformations {
17+
18+
private static final String MEAN_POOLING_SUFFIX = ".meanPooling()";
19+
private static final String MAX_POOLING_SUFFIX = ".maxPooling()";
20+
21+
/**
22+
* Checks if the output field name contains a transformation function
23+
*/
24+
public static boolean hasTransformation(String outputFieldName) {
25+
return outputFieldName != null && (outputFieldName.endsWith(MEAN_POOLING_SUFFIX) || outputFieldName.endsWith(MAX_POOLING_SUFFIX));
26+
}
27+
28+
/**
29+
* Extracts the base field name without transformation function
30+
*/
31+
public static String getBaseFieldName(String outputFieldName) {
32+
if (outputFieldName != null) {
33+
if (outputFieldName.endsWith(MEAN_POOLING_SUFFIX)) {
34+
return outputFieldName.substring(0, outputFieldName.length() - MEAN_POOLING_SUFFIX.length());
35+
} else if (outputFieldName.endsWith(MAX_POOLING_SUFFIX)) {
36+
return outputFieldName.substring(0, outputFieldName.length() - MAX_POOLING_SUFFIX.length());
37+
}
38+
}
39+
return outputFieldName;
40+
}
41+
42+
/**
43+
* Applies mean pooling transformation to a nested array of floats
44+
*/
45+
public static Object applyMeanPooling(Object value) {
46+
if (!(value instanceof List)) {
47+
throw new IllegalArgumentException("Mean pooling requires a list input");
48+
}
49+
50+
List<?> outerList = (List<?>) value;
51+
if (outerList.isEmpty() || !(outerList.getFirst() instanceof List)) {
52+
throw new IllegalArgumentException("Mean pooling requires nested array structure");
53+
}
54+
55+
List<?> firstVector = (List<?>) outerList.getFirst();
56+
int dimensions = firstVector.size();
57+
float[] meanVector = new float[dimensions];
58+
59+
// Sum all vectors
60+
for (Object vectorObj : outerList) {
61+
if (!(vectorObj instanceof List)) {
62+
throw new IllegalArgumentException("All elements must be vectors (lists)");
63+
}
64+
List<?> vector = (List<?>) vectorObj;
65+
if (vector.size() != dimensions) {
66+
throw new IllegalArgumentException("All vectors must have the same dimension");
67+
}
68+
69+
for (int i = 0; i < dimensions; i++) {
70+
Object element = vector.get(i);
71+
float val = element instanceof Number ? ((Number) element).floatValue() : 0.0f;
72+
meanVector[i] += val;
73+
}
74+
}
75+
76+
// Calculate mean and convert to List
77+
List<Float> result = new ArrayList<>();
78+
for (int i = 0; i < dimensions; i++) {
79+
result.add(meanVector[i] / outerList.size());
80+
}
81+
82+
return result;
83+
}
84+
85+
/**
86+
* Applies max pooling transformation to a nested array of floats
87+
*/
88+
public static Object applyMaxPooling(Object value) {
89+
if (!(value instanceof List)) {
90+
throw new IllegalArgumentException("Max pooling requires a list input");
91+
}
92+
93+
List<?> outerList = (List<?>) value;
94+
if (outerList.isEmpty() || !(outerList.getFirst() instanceof List)) {
95+
throw new IllegalArgumentException("Max pooling requires nested array structure");
96+
}
97+
98+
List<?> firstVector = (List<?>) outerList.getFirst();
99+
int dimensions = firstVector.size();
100+
float[] maxVector = new float[dimensions];
101+
102+
// Initialize with first vector
103+
List<?> firstVectorList = (List<?>) outerList.getFirst();
104+
for (int i = 0; i < dimensions; i++) {
105+
Object element = firstVectorList.get(i);
106+
maxVector[i] = element instanceof Number ? ((Number) element).floatValue() : Float.NEGATIVE_INFINITY;
107+
}
108+
109+
// Find max across all vectors
110+
for (Object vectorObj : outerList) {
111+
if (!(vectorObj instanceof List)) {
112+
throw new IllegalArgumentException("All elements must be vectors (lists)");
113+
}
114+
List<?> vector = (List<?>) vectorObj;
115+
if (vector.size() != dimensions) {
116+
throw new IllegalArgumentException("All vectors must have the same dimension");
117+
}
118+
119+
for (int i = 0; i < dimensions; i++) {
120+
Object element = vector.get(i);
121+
float val = element instanceof Number ? ((Number) element).floatValue() : Float.NEGATIVE_INFINITY;
122+
maxVector[i] = Math.max(maxVector[i], val);
123+
}
124+
}
125+
126+
// Convert to List
127+
List<Float> result = new ArrayList<>();
128+
for (int i = 0; i < dimensions; i++) {
129+
result.add(maxVector[i]);
130+
}
131+
132+
return result;
133+
}
134+
135+
/**
136+
* Applies the appropriate transformation based on the field name suffix
137+
*/
138+
public static Object applyTransformation(String outputFieldName, Object value) {
139+
if (outputFieldName != null) {
140+
if (outputFieldName.endsWith(MEAN_POOLING_SUFFIX)) {
141+
return applyMeanPooling(value);
142+
} else if (outputFieldName.endsWith(MAX_POOLING_SUFFIX)) {
143+
return applyMaxPooling(value);
144+
}
145+
}
146+
return value;
147+
}
148+
}

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,6 +2084,67 @@ private static List<Map<String, String>> getOutputMapsForNestedObjectChunks() {
20842084
return outputMap;
20852085
}
20862086

2087+
public void testExecute_MeanPoolingTransformation() {
2088+
List<Map<String, String>> outputMap = new ArrayList<>();
2089+
Map<String, String> output = new HashMap<>();
2090+
output.put("multi_vectors", "image_embeddings[0]");
2091+
output.put("knn_vector", "image_embeddings[0].meanPooling()");
2092+
outputMap.add(output);
2093+
2094+
MLInferenceIngestProcessor processor = createMLInferenceProcessor(
2095+
"model1",
2096+
null,
2097+
outputMap,
2098+
null,
2099+
false,
2100+
"REMOTE",
2101+
false,
2102+
false,
2103+
false,
2104+
null
2105+
);
2106+
2107+
Map<String, Object> sourceAndMetadata = new HashMap<>();
2108+
sourceAndMetadata.put("key1", "value1");
2109+
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
2110+
2111+
// Mock nested array structure for image embeddings
2112+
List<List<Double>> imageEmbeddings = Arrays
2113+
.asList(Arrays.asList(1.0, 2.0, 3.0), Arrays.asList(4.0, 5.0, 6.0), Arrays.asList(7.0, 8.0, 9.0));
2114+
2115+
Map<String, Object> dataAsMap = new HashMap<>();
2116+
dataAsMap.put("image_embeddings", Arrays.asList(imageEmbeddings));
2117+
2118+
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(dataAsMap).build();
2119+
2120+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
2121+
2122+
ModelTensorOutput mlOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
2123+
2124+
MLTaskResponse response = MLTaskResponse.builder().output(mlOutput).build();
2125+
2126+
doAnswer(invocation -> {
2127+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
2128+
actionListener.onResponse(response);
2129+
return null;
2130+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());
2131+
2132+
BiConsumer<IngestDocument, Exception> handler = mock(BiConsumer.class);
2133+
processor.execute(ingestDocument, handler);
2134+
2135+
verify(handler).accept(eq(ingestDocument), isNull());
2136+
2137+
// Verify multi_vectors contains the original nested array
2138+
assertEquals(imageEmbeddings, ingestDocument.getFieldValue("multi_vectors", Object.class));
2139+
2140+
// Verify knn_vector contains the mean pooled result
2141+
List<Float> meanPooled = (List<Float>) ingestDocument.getFieldValue("knn_vector", Object.class);
2142+
assertEquals(3, meanPooled.size());
2143+
assertEquals(4.0, meanPooled.get(0), 0.001); // (1+4+7)/3
2144+
assertEquals(5.0, meanPooled.get(1), 0.001); // (2+5+8)/3
2145+
assertEquals(6.0, meanPooled.get(2), 0.001); // (3+6+9)/3
2146+
}
2147+
20872148
private static List<Map<String, String>> getInputMapsForNestedObjectChunks(String documentFieldPath) {
20882149
List<Map<String, String>> inputMap = new ArrayList<>();
20892150
Map<String, String> input = new HashMap<>();

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.mockito.ArgumentMatchers.any;
99
import static org.mockito.ArgumentMatchers.eq;
1010
import static org.mockito.Mockito.doAnswer;
11+
import static org.mockito.Mockito.mock;
1112
import static org.mockito.Mockito.times;
1213
import static org.mockito.Mockito.verify;
1314
import static org.opensearch.ml.common.utils.StringUtils.toJson;
@@ -2567,4 +2568,92 @@ public void testCreateNoOutputMapFields() throws Exception {
25672568
);
25682569
}
25692570

2571+
/**
2572+
* Tests mean pooling transformation in search request processor
2573+
* @throws Exception if an error occurs during the test
2574+
*/
2575+
public void testExecute_MeanPoolingTransformation() throws Exception {
2576+
String modelInputField = "inputs";
2577+
String originalQueryField = "query.term.text.value";
2578+
String multiVectorField = "ext.ml_inference.multi_vectors";
2579+
String knnVectorField = "ext.ml_inference.knn_vector";
2580+
2581+
List<Map<String, String>> inputMap = new ArrayList<>();
2582+
Map<String, String> input = new HashMap<>();
2583+
input.put(modelInputField, originalQueryField);
2584+
inputMap.add(input);
2585+
2586+
List<Map<String, String>> outputMap = new ArrayList<>();
2587+
Map<String, String> output = new HashMap<>();
2588+
output.put(multiVectorField, "image_embeddings[0]");
2589+
output.put(knnVectorField, "image_embeddings[0].meanPooling()");
2590+
outputMap.add(output);
2591+
2592+
MLInferenceSearchRequestProcessor processor = new MLInferenceSearchRequestProcessor(
2593+
"model1",
2594+
null,
2595+
inputMap,
2596+
outputMap,
2597+
null,
2598+
null,
2599+
null,
2600+
1,
2601+
"tag",
2602+
"description",
2603+
false,
2604+
"REMOTE",
2605+
false,
2606+
false,
2607+
"{ \"parameters\": ${ml_inference.parameters} }",
2608+
client,
2609+
TEST_XCONTENT_REGISTRY_FOR_QUERY
2610+
);
2611+
2612+
String inputQuery = "{\"query\":{\"term\":{\"text\":{\"value\":\"foo\",\"boost\":1.0}}}}";
2613+
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
2614+
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
2615+
SearchRequest request = new SearchRequest().source(source);
2616+
2617+
// Mock nested array structure for image embeddings
2618+
List<List<Double>> imageEmbeddings = Arrays
2619+
.asList(Arrays.asList(1.0, 2.0, 3.0), Arrays.asList(4.0, 5.0, 6.0), Arrays.asList(7.0, 8.0, 9.0));
2620+
2621+
Map<String, Object> dataAsMap = new HashMap<>();
2622+
dataAsMap.put("image_embeddings", Arrays.asList(imageEmbeddings));
2623+
2624+
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(dataAsMap).build();
2625+
2626+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
2627+
2628+
ModelTensorOutput mlOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
2629+
2630+
MLTaskResponse response = MLTaskResponse.builder().output(mlOutput).build();
2631+
2632+
doAnswer(invocation -> {
2633+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
2634+
actionListener.onResponse(response);
2635+
return null;
2636+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());
2637+
2638+
PipelineProcessingContext requestContext = new PipelineProcessingContext();
2639+
ActionListener<SearchRequest> listener = mock(ActionListener.class);
2640+
2641+
processor.processRequestAsync(request, requestContext, listener);
2642+
2643+
ArgumentCaptor<SearchRequest> argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class);
2644+
verify(listener).onResponse(argumentCaptor.capture());
2645+
2646+
SearchRequest capturedRequest = argumentCaptor.getValue();
2647+
2648+
// Verify multi_vectors contains the original nested array
2649+
assertEquals(imageEmbeddings, requestContext.getAttribute(multiVectorField));
2650+
2651+
// Verify knn_vector contains the mean pooled result
2652+
List<Float> meanPooled = (List<Float>) requestContext.getAttribute(knnVectorField);
2653+
assertEquals(3, meanPooled.size());
2654+
assertEquals(4.0, meanPooled.get(0), 0.001); // (1+4+7)/3
2655+
assertEquals(5.0, meanPooled.get(1), 0.001); // (2+5+8)/3
2656+
assertEquals(6.0, meanPooled.get(2), 0.001); // (3+6+9)/3
2657+
}
2658+
25702659
}

0 commit comments

Comments
 (0)