Skip to content

Commit 3dc9d7c

Browse files
committed
implement batch document update scenario for text embedding processor (#1217)
Signed-off-by: Will Hwang <[email protected]>
1 parent d4b46c8 commit 3dc9d7c

10 files changed

+566
-112
lines changed

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

+212-38
Large diffs are not rendered by default.

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

+56-25
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66

77
import java.util.List;
88
import java.util.Map;
9-
import java.util.Objects;
109
import java.util.function.BiConsumer;
1110
import java.util.function.Consumer;
12-
import java.util.stream.Collectors;
1311

1412
import org.opensearch.action.get.GetAction;
1513
import org.opensearch.action.get.GetRequest;
14+
import org.opensearch.action.get.MultiGetAction;
1615
import org.opensearch.cluster.service.ClusterService;
1716
import org.opensearch.core.action.ActionListener;
17+
import org.opensearch.core.common.util.CollectionUtils;
1818
import org.opensearch.env.Environment;
1919
import org.opensearch.ingest.IngestDocument;
20+
import org.opensearch.ingest.IngestDocumentWrapper;
2021
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
2122

2223
import lombok.extern.log4j.Log4j2;
@@ -74,29 +75,16 @@ public void doExecute(
7475
// if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied
7576
String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
7677
String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
77-
openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
78-
final Map<String, Object> existingDocument = response.getSourceAsMap();
79-
if (existingDocument == null || existingDocument.isEmpty()) {
80-
makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
81-
return;
82-
}
83-
// filter given ProcessMap by comparing existing document with ingestDocument
84-
Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(
85-
existingDocument,
86-
ingestDocument.getSourceAndMetadata(),
87-
processMap
88-
);
89-
// create inference list based on filtered ProcessMap
90-
List<String> filteredInferenceList = createInferenceList(filteredProcessMap).stream()
91-
.filter(Objects::nonNull)
92-
.collect(Collectors.toList());
93-
if (filteredInferenceList.isEmpty()) {
94-
handler.accept(ingestDocument, null);
95-
} else {
96-
makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
97-
}
98-
99-
}, e -> { handler.accept(null, e); }));
78+
openSearchClient.execute(
79+
GetAction.INSTANCE,
80+
new GetRequest(index, id),
81+
ActionListener.wrap(
82+
response -> getResponseHandler(response, ingestDocument, processMap, inferenceList, handler, textEmbeddingInferenceFilter),
83+
e -> {
84+
handler.accept(null, e);
85+
}
86+
)
87+
);
10088
}
10189

10290
@Override
@@ -106,4 +94,47 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
10694
ActionListener.wrap(handler::accept, onException)
10795
);
10896
}
97+
98+
@Override
99+
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
100+
try {
101+
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
102+
handler.accept(ingestDocumentWrappers);
103+
return;
104+
}
105+
List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
106+
List<String> inferenceList = constructInferenceTexts(dataForInferences);
107+
if (inferenceList.isEmpty()) {
108+
handler.accept(ingestDocumentWrappers);
109+
return;
110+
}
111+
// skipExisting flag is turned off: call doSubBatchExecute without filtering
112+
if (skipExisting == false) {
113+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
114+
return;
115+
}
116+
// skipExisting flag is turned on, eligible inference texts in dataForInferences will be compared and filtered after embeddings
117+
// are copied
118+
openSearchClient.execute(
119+
MultiGetAction.INSTANCE,
120+
buildMultiGetRequest(dataForInferences),
121+
ActionListener.wrap(
122+
response -> multiGetResponseHandler(
123+
response,
124+
ingestDocumentWrappers,
125+
inferenceList,
126+
dataForInferences,
127+
handler,
128+
textEmbeddingInferenceFilter
129+
),
130+
e -> {
131+
// When MultiGetAction fails with an exception, set that exception on all ingestDocumentWrappers
132+
updateWithExceptions(getIngestDocumentWrappers(dataForInferences), handler, e);
133+
}
134+
)
135+
);
136+
} catch (Exception e) {
137+
updateWithExceptions(ingestDocumentWrappers, handler, e);
138+
}
139+
}
109140
}

src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java

+22-18
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.List;
1414
import java.util.ListIterator;
1515
import java.util.Map;
16+
import java.util.Objects;
1617
import java.util.Optional;
1718

1819
/**
@@ -59,8 +60,8 @@ public abstract Object filterInferenceValue(
5960
);
6061

6162
/**
62-
* Abstract helper method to filter individual values based on the existing and new metadata maps.
63-
* Implementations should provide logic to compare values and determine if embeddings can be reused.
63+
* Abstract helper method to filter individual objects based on the existing and new metadata maps.
64+
* Implementations should provide logic to compare objects and determine if embeddings can be reused.
6465
*
6566
* @param embeddingKey The dot-notation path for the embedding field
6667
* @param processValue The value to be checked for potential embedding reuse
@@ -71,7 +72,7 @@ public abstract Object filterInferenceValue(
7172
* @return The processed value or null if embeddings are reused
7273
*/
7374

74-
public abstract Object copyEmbeddingForSingleValue(
75+
public abstract Object copyEmbeddingForSingleObject(
7576
String embeddingKey,
7677
Object processValue,
7778
Object existingValue,
@@ -81,9 +82,9 @@ public abstract Object copyEmbeddingForSingleValue(
8182
);
8283

8384
/**
84-
* Abstract method to filter and compare lists of values.
85-
* If all elements in the list are identical between the new and existing metadata maps, embeddings are copied,
86-
* and an empty list is returned to indicate no further processing is required.
85+
* Abstract method to filter and compare lists of objects.
86+
* If all objects in the list are identical between the new and existing metadata maps, embeddings are copied,
87+
* and null is returned to indicate no further processing is required.
8788
*
8889
* @param embeddingKey The dot-notation path for the embedding field
8990
* @param processList The list of values to be checked for potential embedding reuse
@@ -93,7 +94,7 @@ public abstract Object copyEmbeddingForSingleValue(
9394
* @return A processed list, or null if embeddings are reused.
9495
*/
9596

96-
public abstract List<Object> copyEmbeddingForMultipleValues(
97+
public abstract List<Object> copyEmbeddingForListObject(
9798
String embeddingKey,
9899
List<Object> processList,
99100
List<Object> existingList,
@@ -194,7 +195,7 @@ protected List<Object> filterListValue(
194195
List<Object> existingListValue = ProcessorUtils.unsafeCastToObjectList(existingListOptional.get());
195196
if (existingListValue.getFirst() instanceof List) {
196197
// in case of nested list, compare and copy by list comparison
197-
return copyEmbeddingForMultipleValues(
198+
return copyEmbeddingForListObject(
198199
embeddingKey,
199200
processList,
200201
ProcessorUtils.unsafeCastToObjectList(existingListValue.getFirst()),
@@ -235,16 +236,19 @@ public List<Object> filterMapValuesInList(
235236
ListIterator<Object> existingListIterator = existingList.listIterator();
236237
ListIterator<Object> embeddingListIterator = embeddingList.listIterator();
237238
int index = 0;
238-
while (processListIterator.hasNext() && existingListIterator.hasNext() && embeddingListIterator.hasNext()) {
239-
Object processedItem = copyEmbeddingForSingleValue(
240-
embeddingKey,
241-
processListIterator.next(),
242-
existingListIterator.next(),
243-
embeddingListIterator.next(),
244-
sourceAndMetadataMap,
245-
index++
246-
);
247-
filteredList.add(processedItem);
239+
for (Object processValue : processList) {
240+
if (Objects.nonNull(processValue) && existingListIterator.hasNext() && embeddingListIterator.hasNext()) {
241+
Object processedItem = copyEmbeddingForSingleObject(
242+
embeddingKey,
243+
processValue,
244+
existingListIterator.next(),
245+
embeddingListIterator.next(),
246+
sourceAndMetadataMap,
247+
index
248+
);
249+
filteredList.add(processedItem);
250+
}
251+
index++;
248252
}
249253
return filteredList;
250254
}

src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java

+5-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import lombok.extern.log4j.Log4j2;
88
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;
99

10-
import java.util.Collections;
1110
import java.util.List;
1211
import java.util.Map;
1312
import java.util.Objects;
@@ -47,7 +46,7 @@ public Object filterInferenceValue(
4746
Optional<Object> existingValueOptional = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textPath);
4847
Optional<Object> embeddingValueOptional = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, embeddingKey);
4948
if (existingValueOptional.isPresent() && embeddingValueOptional.isPresent()) {
50-
return copyEmbeddingForSingleValue(
49+
return copyEmbeddingForSingleObject(
5150
embeddingKey,
5251
processValue,
5352
existingValueOptional.get(),
@@ -67,7 +66,7 @@ public Object filterInferenceValue(
6766
* @return null if embeddings are reused; the processValue otherwise.
6867
*/
6968
@Override
70-
public Object copyEmbeddingForSingleValue(
69+
public Object copyEmbeddingForSingleObject(
7170
String embeddingKey,
7271
Object processValue,
7372
Object existingValue,
@@ -90,7 +89,7 @@ public Object copyEmbeddingForSingleValue(
9089
* @return null if embeddings are reused; processList otherwise.
9190
*/
9291
@Override
93-
public List<Object> copyEmbeddingForMultipleValues(
92+
public List<Object> copyEmbeddingForListObject(
9493
String embeddingKey,
9594
List<Object> processList,
9695
List<Object> existingList,
@@ -99,8 +98,8 @@ public List<Object> copyEmbeddingForMultipleValues(
9998
) {
10099
if (Objects.equals(processList, existingList)) {
101100
ProcessorUtils.setValueToSource(sourceAndMetadataMap, embeddingKey, embeddingList);
102-
// if successfully copied, return empty list to be filtered out from process map
103-
return Collections.emptyList();
101+
// if successfully copied, return null to be filtered out from process map
102+
return null;
104103
}
105104
// source list and existing list are different, return processList to be included in process map
106105
return processList;

src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,12 @@ public static boolean isNumeric(Object value) {
306306
*/
307307
public static int getNumOfSubqueries(final List<CompoundTopDocs> queryTopDocs) {
308308
return queryTopDocs.stream()
309-
.filter(Objects::nonNull)
310-
.filter(topDocs -> !topDocs.getTopDocs().isEmpty())
311-
.findAny()
312-
.get()
313-
.getTopDocs()
314-
.size();
309+
.filter(Objects::nonNull)
310+
.filter(topDocs -> !topDocs.getTopDocs().isEmpty())
311+
.findAny()
312+
.get()
313+
.getTopDocs()
314+
.size();
315315
}
316316

317317
// This method should be used only when you are certain the object is a `Map<String, Object>`.

src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java

+22-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import com.google.common.collect.ImmutableList;
88
import org.apache.commons.lang.math.RandomUtils;
99
import org.opensearch.action.get.GetResponse;
10+
import org.opensearch.action.get.MultiGetItemResponse;
11+
import org.opensearch.action.get.MultiGetResponse;
1012
import org.opensearch.common.xcontent.XContentFactory;
1113
import org.opensearch.core.common.bytes.BytesReference;
1214
import org.opensearch.core.xcontent.XContentBuilder;
@@ -24,18 +26,22 @@
2426
import java.util.Map;
2527

2628
public class InferenceProcessorTestCase extends OpenSearchTestCase {
27-
28-
protected List<IngestDocumentWrapper> createIngestDocumentWrappers(int count) {
29+
protected List<IngestDocumentWrapper> createIngestDocumentWrappers(int count, String value) {
2930
List<IngestDocumentWrapper> wrapperList = new ArrayList<>();
30-
for (int i = 0; i < count; ++i) {
31+
for (int i = 1; i <= count; ++i) {
3132
Map<String, Object> sourceAndMetadata = new HashMap<>();
32-
sourceAndMetadata.put("key1", "value1");
33+
sourceAndMetadata.put("key1", value);
3334
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
35+
sourceAndMetadata.put("_id", String.valueOf(i));
3436
wrapperList.add(new IngestDocumentWrapper(i, new IngestDocument(sourceAndMetadata, new HashMap<>()), null));
3537
}
3638
return wrapperList;
3739
}
3840

41+
protected List<IngestDocumentWrapper> createIngestDocumentWrappers(int count) {
42+
return createIngestDocumentWrappers(count, "value1");
43+
}
44+
3945
protected List<List<Float>> createMockVectorWithLength(int size) {
4046
float suffix = .234f;
4147
List<List<Float>> result = new ArrayList<>();
@@ -120,6 +126,10 @@ protected GetResponse mockEmptyGetResponse() throws IOException {
120126
return GetResponse.fromXContent(contentParser);
121127
}
122128

129+
protected MultiGetResponse mockEmptyMultiGetItemResponse() throws IOException {
130+
return new MultiGetResponse(new MultiGetItemResponse[0]);
131+
}
132+
123133
protected GetResponse convertToGetResponse(IngestDocument ingestDocument) throws IOException {
124134
String index = ingestDocument.getSourceAndMetadata().get("_index").toString();
125135
String id = ingestDocument.getSourceAndMetadata().get("_id").toString();
@@ -130,4 +140,12 @@ protected GetResponse convertToGetResponse(IngestDocument ingestDocument) throws
130140
GetResult result = new GetResult(index, id, 0, 1, 1, true, bytes, null, null);
131141
return new GetResponse(result);
132142
}
143+
144+
protected MultiGetResponse convertToMultiGetItemResponse(List<IngestDocumentWrapper> ingestDocuments) throws IOException {
145+
MultiGetItemResponse[] multiGetItems = new MultiGetItemResponse[ingestDocuments.size()];
146+
for (int i = 0; i < ingestDocuments.size(); i++) {
147+
multiGetItems[i] = new MultiGetItemResponse(convertToGetResponse(ingestDocuments.get(i).getIngestDocument()), null);
148+
}
149+
return new MultiGetResponse(multiGetItems);
150+
}
133151
}

0 commit comments

Comments
 (0)