7
7
import java .util .ArrayList ;
8
8
import java .util .Arrays ;
9
9
import java .util .Collection ;
10
- import java .util .Collections ;
11
10
import java .util .Comparator ;
12
11
import java .util .HashMap ;
13
12
import java .util .Iterator ;
26
25
import org .apache .commons .lang3 .StringUtils ;
27
26
import org .apache .commons .lang3 .tuple .ImmutablePair ;
28
27
import org .apache .commons .lang3 .tuple .Pair ;
28
+ import org .opensearch .action .get .MultiGetItemResponse ;
29
+ import org .opensearch .action .get .MultiGetRequest ;
29
30
import org .opensearch .common .collect .Tuple ;
30
31
import org .opensearch .core .action .ActionListener ;
31
32
import org .opensearch .core .common .util .CollectionUtils ;
@@ -54,6 +55,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
54
55
55
56
public static final String MODEL_ID_FIELD = "model_id" ;
56
57
public static final String FIELD_MAP_FIELD = "field_map" ;
58
+ public static final String INDEX_FIELD = "_index" ;
59
+ public static final String ID_FIELD = "_id" ;
57
60
private static final BiFunction <Object , Object , Object > REMAPPING_FUNCTION = (v1 , v2 ) -> {
58
61
if (v1 instanceof Collection && v2 instanceof Collection ) {
59
62
((Collection ) v1 ).addAll ((Collection ) v2 );
@@ -169,23 +172,71 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
169
172
*/
170
173
abstract void doBatchExecute (List <String > inferenceList , Consumer <List <?>> handler , Consumer <Exception > onException );
171
174
175
+ /**
176
+ * This is the function which does actual inference work for subBatchExecute interface.
177
+ * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
178
+ * @param handler a callback handler to handle inference results which is a list of objects.
179
+ */
172
180
@ Override
173
181
public void subBatchExecute (List <IngestDocumentWrapper > ingestDocumentWrappers , Consumer <List <IngestDocumentWrapper >> handler ) {
174
- if (CollectionUtils .isEmpty (ingestDocumentWrappers )) {
175
- handler .accept (Collections .emptyList ());
176
- return ;
182
+ try {
183
+ if (CollectionUtils .isEmpty (ingestDocumentWrappers )) {
184
+ handler .accept (ingestDocumentWrappers );
185
+ return ;
186
+ }
187
+
188
+ List <DataForInference > dataForInferences = getDataForInference (ingestDocumentWrappers );
189
+ List <String > inferenceList = constructInferenceTexts (dataForInferences );
190
+ if (inferenceList .isEmpty ()) {
191
+ handler .accept (ingestDocumentWrappers );
192
+ return ;
193
+ }
194
+ doSubBatchExecute (ingestDocumentWrappers , inferenceList , dataForInferences , handler );
195
+ } catch (Exception e ) {
196
+ updateWithExceptions (ingestDocumentWrappers , e );
197
+ handler .accept (ingestDocumentWrappers );
177
198
}
199
+ }
178
200
179
- List <DataForInference > dataForInferences = getDataForInference (ingestDocumentWrappers );
180
- List <String > inferenceList = constructInferenceTexts (dataForInferences );
181
- if (inferenceList .isEmpty ()) {
201
+ /**
202
+ * This is a helper function for subBatchExecute, which invokes doBatchExecute for given inference list.
203
+ * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
204
+ * @param inferenceList a list of String for inference.
205
+ * @param dataForInferences a list of data for inference, which includes ingestDocumentWrapper, processMap, inferenceList.
206
+ * @param handler a callback handler to handle inference results which is a list of objects.
207
+ */
208
+ protected void doSubBatchExecute (
209
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
210
+ List <String > inferenceList ,
211
+ List <DataForInference > dataForInferences ,
212
+ Consumer <List <IngestDocumentWrapper >> handler
213
+ ) {
214
+ try {
215
+ Tuple <List <String >, Map <Integer , Integer >> sortedResult = sortByLengthAndReturnOriginalOrder (inferenceList );
216
+ inferenceList = sortedResult .v1 ();
217
+ Map <Integer , Integer > originalOrder = sortedResult .v2 ();
218
+ doBatchExecute (
219
+ inferenceList ,
220
+ results -> batchExecuteHandler (results , ingestDocumentWrappers , dataForInferences , originalOrder , handler ),
221
+ exception -> {
222
+ updateWithExceptions (ingestDocumentWrappers , exception );
223
+ handler .accept (ingestDocumentWrappers );
224
+ }
225
+ );
226
+ } catch (Exception e ) {
227
+ updateWithExceptions (ingestDocumentWrappers , e );
182
228
handler .accept (ingestDocumentWrappers );
183
- return ;
184
229
}
185
- Tuple <List <String >, Map <Integer , Integer >> sortedResult = sortByLengthAndReturnOriginalOrder (inferenceList );
186
- inferenceList = sortedResult .v1 ();
187
- Map <Integer , Integer > originalOrder = sortedResult .v2 ();
188
- doBatchExecute (inferenceList , results -> {
230
+ }
231
+
232
+ private void batchExecuteHandler (
233
+ List <?> results ,
234
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
235
+ List <DataForInference > dataForInferences ,
236
+ Map <Integer , Integer > originalOrder ,
237
+ Consumer <List <IngestDocumentWrapper >> handler
238
+ ) {
239
+ try {
189
240
int startIndex = 0 ;
190
241
results = restoreToOriginalOrder (results , originalOrder );
191
242
for (DataForInference dataForInference : dataForInferences ) {
@@ -201,17 +252,11 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
201
252
inferenceResults
202
253
);
203
254
}
255
+ } catch (Exception e ) {
256
+ updateWithExceptions (ingestDocumentWrappers , e );
257
+ } finally {
204
258
handler .accept (ingestDocumentWrappers );
205
- }, exception -> {
206
- for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
207
- // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
208
- // set exception to IngestDocumentWrapper which doesn't have exception before.
209
- if (ingestDocumentWrapper .getException () == null ) {
210
- ingestDocumentWrapper .update (ingestDocumentWrapper .getIngestDocument (), exception );
211
- }
212
- }
213
- handler .accept (ingestDocumentWrappers );
214
- });
259
+ }
215
260
}
216
261
217
262
private Tuple <List <String >, Map <Integer , Integer >> sortByLengthAndReturnOriginalOrder (List <String > inferenceList ) {
@@ -238,7 +283,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
238
283
return sortedResults ;
239
284
}
240
285
241
- private List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
286
+ protected List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
242
287
List <String > inferenceTexts = new ArrayList <>();
243
288
for (DataForInference dataForInference : dataForInferences ) {
244
289
if (dataForInference .getIngestDocumentWrapper ().getException () != null
@@ -250,7 +295,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
250
295
return inferenceTexts ;
251
296
}
252
297
253
- private List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
298
+ protected List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
254
299
List <DataForInference > dataForInferences = new ArrayList <>();
255
300
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
256
301
Map <String , Object > processMap = null ;
@@ -272,7 +317,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
272
317
273
318
@ Getter
274
319
@ AllArgsConstructor
275
- private static class DataForInference {
320
+ protected static class DataForInference {
276
321
private final IngestDocumentWrapper ingestDocumentWrapper ;
277
322
private final Map <String , Object > processMap ;
278
323
private final List <String > inferenceList ;
@@ -415,6 +460,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
415
460
nlpResult .forEach (ingestDocument ::setFieldValue );
416
461
}
417
462
463
+ /**
464
+ * This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
465
+ * @param ingestDocumentWrappers, list of ingest documents
466
+ * */
467
+ protected MultiGetRequest buildMultiGetRequest (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
468
+ MultiGetRequest multiGetRequest = new MultiGetRequest ();
469
+ for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
470
+ Object index = ingestDocumentWrapper .getIngestDocument ().getSourceAndMetadata ().get (INDEX_FIELD );
471
+ Object id = ingestDocumentWrapper .getIngestDocument ().getSourceAndMetadata ().get (ID_FIELD );
472
+ if (Objects .nonNull (index ) && Objects .nonNull (id )) {
473
+ multiGetRequest .add (index .toString (), id .toString ());
474
+ }
475
+ }
476
+ return multiGetRequest ;
477
+ }
478
+
479
+ /**
480
+ * This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
481
+ * @param multiGetItemResponses, array of responses from Multi Get Request
482
+ * */
483
+ protected Map <String , Map <String , Object >> createDocumentMap (MultiGetItemResponse [] multiGetItemResponses ) {
484
+ Map <String , Map <String , Object >> existingDocuments = new HashMap <>();
485
+ for (MultiGetItemResponse item : multiGetItemResponses ) {
486
+ String id = item .getId ();
487
+ Map <String , Object > existingDocument = item .getResponse ().getSourceAsMap ();
488
+ existingDocuments .put (id , existingDocument );
489
+ }
490
+ return existingDocuments ;
491
+ }
492
+
418
493
@ SuppressWarnings ({ "unchecked" })
419
494
@ VisibleForTesting
420
495
Map <String , Object > buildNLPResult (Map <String , Object > processorMap , List <?> results , Map <String , Object > sourceAndMetadataMap ) {
@@ -504,6 +579,17 @@ private void processMapEntryValue(
504
579
}
505
580
}
506
581
582
+ // This method updates each ingestDocument with exceptions
583
+ protected void updateWithExceptions (List <IngestDocumentWrapper > ingestDocumentWrappers , Exception e ) {
584
+ for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
585
+ // The IngestDocumentWrapper might have already run into exception. So here we only
586
+ // set exception to IngestDocumentWrapper which doesn't have exception before.
587
+ if (ingestDocumentWrapper .getException () == null ) {
588
+ ingestDocumentWrapper .update (ingestDocumentWrapper .getIngestDocument (), e );
589
+ }
590
+ }
591
+ }
592
+
507
593
private void processMapEntryValue (
508
594
List <?> results ,
509
595
IndexWrapper indexWrapper ,
@@ -582,7 +668,7 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
582
668
List <Map <String , Object >> keyToResult = new ArrayList <>();
583
669
sourceValue .stream ()
584
670
.filter (Objects ::nonNull ) // explicit null check is required since sourceValue can contain null values in cases where
585
- // sourceValue has been filtered
671
+ // sourceValue has been filtered
586
672
.forEachOrdered (x -> keyToResult .add (ImmutableMap .of (listTypeNestedMapKey , results .get (indexWrapper .index ++))));
587
673
return keyToResult ;
588
674
}
0 commit comments