
Commit 48d242a

implement batch document update scenario in text embedding processor (#1217)
Signed-off-by: will-hwang <[email protected]>
1 parent d4b46c8 commit 48d242a

10 files changed: +496 -75 lines changed

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

+112 -26 lines changed
@@ -7,7 +7,6 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -26,6 +25,8 @@
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.action.get.MultiGetItemResponse;
+import org.opensearch.action.get.MultiGetRequest;
 import org.opensearch.common.collect.Tuple;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.util.CollectionUtils;
@@ -54,6 +55,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
 
     public static final String MODEL_ID_FIELD = "model_id";
     public static final String FIELD_MAP_FIELD = "field_map";
+    public static final String INDEX_FIELD = "_index";
+    public static final String ID_FIELD = "_id";
     private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
         if (v1 instanceof Collection && v2 instanceof Collection) {
             ((Collection) v1).addAll((Collection) v2);
@@ -169,23 +172,71 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
      */
     abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
 
+    /**
+     * This is the function which does actual inference work for subBatchExecute interface.
+     * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
+     * @param handler a callback handler to handle inference results which is a list of objects.
+     */
     @Override
     public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
-        if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
-            handler.accept(Collections.emptyList());
-            return;
+        try {
+            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
+                handler.accept(ingestDocumentWrappers);
+                return;
+            }
+
+            List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
+            List<String> inferenceList = constructInferenceTexts(dataForInferences);
+            if (inferenceList.isEmpty()) {
+                handler.accept(ingestDocumentWrappers);
+                return;
+            }
+            doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
+            handler.accept(ingestDocumentWrappers);
         }
+    }
 
-        List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
-        List<String> inferenceList = constructInferenceTexts(dataForInferences);
-        if (inferenceList.isEmpty()) {
+    /**
+     * This is a helper function for subBatchExecute, which invokes doBatchExecute for given inference list.
+     * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
+     * @param inferenceList a list of String for inference.
+     * @param dataForInferences a list of data for inference, which includes ingestDocumentWrapper, processMap, inferenceList.
+     * @param handler a callback handler to handle inference results which is a list of objects.
+     */
+    protected void doSubBatchExecute(
+        List<IngestDocumentWrapper> ingestDocumentWrappers,
+        List<String> inferenceList,
+        List<DataForInference> dataForInferences,
+        Consumer<List<IngestDocumentWrapper>> handler
+    ) {
+        try {
+            Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
+            inferenceList = sortedResult.v1();
+            Map<Integer, Integer> originalOrder = sortedResult.v2();
+            doBatchExecute(
+                inferenceList,
+                results -> batchExecuteHandler(results, ingestDocumentWrappers, dataForInferences, originalOrder, handler),
+                exception -> {
+                    updateWithExceptions(ingestDocumentWrappers, exception);
+                    handler.accept(ingestDocumentWrappers);
+                }
+            );
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
             handler.accept(ingestDocumentWrappers);
-            return;
         }
-        Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
-        inferenceList = sortedResult.v1();
-        Map<Integer, Integer> originalOrder = sortedResult.v2();
-        doBatchExecute(inferenceList, results -> {
+    }
+
+    private void batchExecuteHandler(
+        List<?> results,
+        List<IngestDocumentWrapper> ingestDocumentWrappers,
+        List<DataForInference> dataForInferences,
+        Map<Integer, Integer> originalOrder,
+        Consumer<List<IngestDocumentWrapper>> handler
+    ) {
+        try {
            int startIndex = 0;
            results = restoreToOriginalOrder(results, originalOrder);
            for (DataForInference dataForInference : dataForInferences) {
@@ -201,17 +252,11 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
                    inferenceResults
                );
            }
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
+        } finally {
            handler.accept(ingestDocumentWrappers);
-        }, exception -> {
-            for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
-                // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
-                // set exception to IngestDocumentWrapper which doesn't have exception before.
-                if (ingestDocumentWrapper.getException() == null) {
-                    ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
-                }
-            }
-            handler.accept(ingestDocumentWrappers);
-        });
+        }
     }
 
     private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
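Note on the hunks above: the refactor keeps the existing optimization where inference texts are sorted by length before doBatchExecute (so similarly sized texts batch together) and results are restored to the original document order afterward. A minimal standalone sketch of that sort-and-restore pattern, with illustrative names rather than the processor's actual code:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class SortAndRestoreSketch {
    // Sorts texts by length and records, for each position in the sorted list,
    // which original position it came from.
    static Map<Integer, Integer> sortByLength(List<String> texts, List<String> sortedOut) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < texts.size(); i++) {
            indices.add(i);
        }
        indices.sort(Comparator.comparingInt(i -> texts.get(i).length()));
        Map<Integer, Integer> originalOrder = new HashMap<>();
        for (int sortedPos = 0; sortedPos < indices.size(); sortedPos++) {
            sortedOut.add(texts.get(indices.get(sortedPos)));
            originalOrder.put(sortedPos, indices.get(sortedPos)); // sorted position -> original position
        }
        return originalOrder;
    }

    // Restores per-text results to the original submission order.
    static <T> List<T> restoreToOriginalOrder(List<T> results, Map<Integer, Integer> originalOrder) {
        List<T> restored = new ArrayList<>(results);
        for (int sortedPos = 0; sortedPos < results.size(); sortedPos++) {
            restored.set(originalOrder.get(sortedPos), results.get(sortedPos));
        }
        return restored;
    }

    public static void main(String[] args) {
        List<String> texts = List.of("a much longer text", "short", "mid length");
        List<String> sorted = new ArrayList<>();
        Map<Integer, Integer> order = sortByLength(texts, sorted);
        // pretend inference returns one "embedding" label per sorted text
        List<String> results = sorted.stream().map(t -> "emb(" + t + ")").toList();
        System.out.println(restoreToOriginalOrder(results, order));
        // -> [emb(a much longer text), emb(short), emb(mid length)]
    }
}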
@@ -238,7 +283,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
         return sortedResults;
     }
 
-    private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
+    protected List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
         List<String> inferenceTexts = new ArrayList<>();
         for (DataForInference dataForInference : dataForInferences) {
             if (dataForInference.getIngestDocumentWrapper().getException() != null
@@ -250,7 +295,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
         return inferenceTexts;
     }
 
-    private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
+    protected List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
         List<DataForInference> dataForInferences = new ArrayList<>();
         for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
             Map<String, Object> processMap = null;
@@ -272,7 +317,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
 
     @Getter
     @AllArgsConstructor
-    private static class DataForInference {
+    protected static class DataForInference {
         private final IngestDocumentWrapper ingestDocumentWrapper;
         private final Map<String, Object> processMap;
         private final List<String> inferenceList;
@@ -415,6 +460,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
         nlpResult.forEach(ingestDocument::setFieldValue);
     }
 
+    /**
+     * This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
+     * @param ingestDocumentWrappers, list of ingest documents
+     * */
+    protected MultiGetRequest buildMultiGetRequest(List<IngestDocumentWrapper> ingestDocumentWrappers) {
+        MultiGetRequest multiGetRequest = new MultiGetRequest();
+        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+            Object index = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(INDEX_FIELD);
+            Object id = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(ID_FIELD);
+            if (Objects.nonNull(index) && Objects.nonNull(id)) {
+                multiGetRequest.add(index.toString(), id.toString());
+            }
+        }
+        return multiGetRequest;
+    }
+
+    /**
+     * This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
+     * @param multiGetItemResponses, array of responses from Multi Get Request
+     * */
+    protected Map<String, Map<String, Object>> createDocumentMap(MultiGetItemResponse[] multiGetItemResponses) {
+        Map<String, Map<String, Object>> existingDocuments = new HashMap<>();
+        for (MultiGetItemResponse item : multiGetItemResponses) {
+            String id = item.getId();
+            Map<String, Object> existingDocument = item.getResponse().getSourceAsMap();
+            existingDocuments.put(id, existingDocument);
+        }
+        return existingDocuments;
+    }
+
     @SuppressWarnings({ "unchecked" })
     @VisibleForTesting
     Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
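The two helpers added above are consumed by TextEmbeddingProcessor further down in this commit. Condensed from that diff, the wiring looks roughly like the following fragment (not standalone code; openSearchClient, handler, and the filtering step live in the processor):

// Inside a subclass's subBatchExecute (condensed; see TextEmbeddingProcessor below):
openSearchClient.execute(MultiGetAction.INSTANCE, buildMultiGetRequest(ingestDocumentWrappers), ActionListener.wrap(response -> {
    // key: doc _id, value: existing _source for that document
    Map<String, Map<String, Object>> existingDocuments = createDocumentMap(response.getResponses());
    // ... compare incoming documents against existingDocuments and drop unchanged texts ...
}, e -> {
    updateWithExceptions(ingestDocumentWrappers, e); // fail the whole batch on lookup errors
    handler.accept(ingestDocumentWrappers);
}));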
@@ -504,6 +579,17 @@ private void processMapEntryValue(
         }
     }
 
+    // This method updates each ingestDocument with exceptions
+    protected void updateWithExceptions(List<IngestDocumentWrapper> ingestDocumentWrappers, Exception e) {
+        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+            // The IngestDocumentWrapper might have already run into exception. So here we only
+            // set exception to IngestDocumentWrapper which doesn't have exception before.
+            if (ingestDocumentWrapper.getException() == null) {
+                ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
+            }
+        }
+    }
+
     private void processMapEntryValue(
         List<?> results,
         IndexWrapper indexWrapper,
@@ -582,7 +668,7 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
         List<Map<String, Object>> keyToResult = new ArrayList<>();
         sourceValue.stream()
             .filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where
-            // sourceValue has been filtered
+                                      // sourceValue has been filtered
             .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
         return keyToResult;
     }

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

+88 lines changed
@@ -4,6 +4,8 @@
  */
 package org.opensearch.neuralsearch.processor;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -13,10 +15,14 @@
 
 import org.opensearch.action.get.GetAction;
 import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.MultiGetAction;
+import org.opensearch.action.get.MultiGetItemResponse;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.env.Environment;
 import org.opensearch.ingest.IngestDocument;
+import org.opensearch.ingest.IngestDocumentWrapper;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 
 import lombok.extern.log4j.Log4j2;
@@ -106,4 +112,86 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
             ActionListener.wrap(handler::accept, onException)
         );
     }
+
+    @Override
+    public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
+        try {
+            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
+                handler.accept(Collections.emptyList());
+                return;
+            }
+            List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
+            List<String> inferenceList = constructInferenceTexts(dataForInferences);
+            if (inferenceList.isEmpty()) {
+                handler.accept(ingestDocumentWrappers);
+                return;
+            }
+            // skip existing flag is turned off. Call doSubBatchExecute without filtering
+            if (skipExisting == false) {
+                doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
+                return;
+            }
+            // skipExisting flag is turned on, eligible inference texts in dataForInferences will be compared and filtered after embeddings
+            // are copied
+            openSearchClient.execute(
+                MultiGetAction.INSTANCE,
+                buildMultiGetRequest(ingestDocumentWrappers),
+                ActionListener.wrap(response -> {
+                    try {
+                        MultiGetItemResponse[] multiGetItemResponses = response.getResponses();
+                        if (multiGetItemResponses == null || multiGetItemResponses.length == 0) {
+                            doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
+                            return;
+                        }
+                        // create a map of documents with key: doc_id and value: doc
+                        Map<String, Map<String, Object>> existingDocuments = createDocumentMap(multiGetItemResponses);
+                        List<DataForInference> filteredDataForInference = filterDataForInference(dataForInferences, existingDocuments);
+                        List<String> filteredInferenceList = constructInferenceTexts(filteredDataForInference);
+                        if (filteredInferenceList.isEmpty()) {
+                            handler.accept(ingestDocumentWrappers);
+                        } else {
+                            doSubBatchExecute(ingestDocumentWrappers, filteredInferenceList, filteredDataForInference, handler);
+                        }
+                    } catch (Exception e) {
+                        updateWithExceptions(ingestDocumentWrappers, e);
+                        handler.accept(ingestDocumentWrappers);
+                    }
+                }, e -> {
+                    // When exception is thrown for MultiGetAction, set exception to all ingestDocumentWrappers
+                    updateWithExceptions(ingestDocumentWrappers, e);
+                    handler.accept(ingestDocumentWrappers);
+                })
+            );
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
+            handler.accept(ingestDocumentWrappers);
+        }
+    }
+
+    // This is a helper method to filter the given list of dataForInferences by comparing its documents with existingDocuments with
+    // textEmbeddingInferenceFilter
+    private List<DataForInference> filterDataForInference(
+        List<DataForInference> dataForInferences,
+        Map<String, Map<String, Object>> existingDocuments
+    ) {
+        List<DataForInference> filteredDataForInference = new ArrayList<>();
+        for (DataForInference dataForInference : dataForInferences) {
+            IngestDocumentWrapper ingestDocumentWrapper = dataForInference.getIngestDocumentWrapper();
+            Map<String, Object> processMap = dataForInference.getProcessMap();
+            Map<String, Object> document = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata();
+            Object id = document.get(ID_FIELD);
+            // insert non-filtered dataForInference if existing document does not exist
+            if (Objects.isNull(id) || existingDocuments.containsKey(id.toString()) == false) {
+                filteredDataForInference.add(dataForInference);
+                continue;
+            }
+            // filter dataForInference when existing document exists
+            String docId = id.toString();
+            Map<String, Object> existingDocument = existingDocuments.get(docId);
+            Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(existingDocument, document, processMap);
+            List<String> filteredInferenceList = createInferenceList(filteredProcessMap);
+            filteredDataForInference.add(new DataForInference(ingestDocumentWrapper, filteredProcessMap, filteredInferenceList));
+        }
+        return filteredDataForInference;
+    }
 }
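The textEmbeddingInferenceFilter referenced above is defined elsewhere in the plugin; it compares the incoming document against the previously indexed one field by field. A simplified standalone sketch of the idea, covering flat top-level fields only (the real filter also walks nested fields and, per the comment in the diff, copies existing embeddings for unchanged text so those fields keep their vectors without re-running inference):

import java.util.HashMap;
import java.util.Map;

class InferenceFilterSketch {
    // Keeps only the processMap entries whose source text changed relative to
    // the previously indexed document. Flat fields only; illustrative, not the
    // plugin's actual implementation.
    static Map<String, Object> filter(
        Map<String, Object> existingDocument,
        Map<String, Object> newDocument,
        Map<String, Object> processMap
    ) {
        Map<String, Object> filtered = new HashMap<>();
        for (Map.Entry<String, Object> entry : processMap.entrySet()) {
            String sourceField = entry.getKey();
            Object oldValue = existingDocument.get(sourceField);
            Object newValue = newDocument.get(sourceField);
            if (oldValue == null || !oldValue.equals(newValue)) {
                filtered.put(sourceField, entry.getValue()); // text changed: still needs inference
            }
            // unchanged text: drop the entry here; the real filter also copies
            // the existing embedding into the new document
        }
        return filtered;
    }
}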
