14
14
import java .util .function .BiConsumer ;
15
15
16
16
import org .apache .commons .lang3 .StringUtils ;
17
+ import org .opensearch .action .get .GetAction ;
18
+ import org .opensearch .action .get .GetRequest ;
19
+ import org .opensearch .action .get .GetResponse ;
17
20
import org .opensearch .cluster .service .ClusterService ;
18
21
import org .opensearch .core .action .ActionListener ;
19
22
import org .opensearch .env .Environment ;
24
27
import com .google .common .annotations .VisibleForTesting ;
25
28
26
29
import lombok .extern .log4j .Log4j2 ;
30
+ import org .opensearch .neuralsearch .processor .optimization .TextImageEmbeddingInferenceFilter ;
31
+ import org .opensearch .transport .client .OpenSearchClient ;
27
32
28
33
/**
29
34
* This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use,
@@ -35,19 +40,24 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor {
35
40
public static final String TYPE = "text_image_embedding" ;
36
41
public static final String MODEL_ID_FIELD = "model_id" ;
37
42
public static final String EMBEDDING_FIELD = "embedding" ;
43
+ public static final boolean DEFAULT_SKIP_EXISTING = false ;
44
+ public static final String SKIP_EXISTING = "skip_existing" ;
38
45
public static final String FIELD_MAP_FIELD = "field_map" ;
39
46
public static final String TEXT_FIELD_NAME = "text" ;
40
47
public static final String IMAGE_FIELD_NAME = "image" ;
41
48
public static final String INPUT_TEXT = "inputText" ;
42
49
public static final String INPUT_IMAGE = "inputImage" ;
50
+ private static final String INDEX_FIELD = "_index" ;
51
+ private static final String ID_FIELD = "_id" ;
43
52
private static final Set <String > VALID_FIELD_NAMES = Set .of (TEXT_FIELD_NAME , IMAGE_FIELD_NAME );
44
53
45
54
private final String modelId ;
46
55
private final String embedding ;
47
56
private final Map <String , String > fieldMap ;
48
-
57
+ private final boolean skipExisting ;
58
+ private final OpenSearchClient openSearchClient ;
49
59
private final MLCommonsClientAccessor mlCommonsClientAccessor ;
50
-
60
+ private final TextImageEmbeddingInferenceFilter inferenceFilter ;
51
61
private final Environment environment ;
52
62
private final ClusterService clusterService ;
53
63
@@ -57,6 +67,9 @@ public TextImageEmbeddingProcessor(
57
67
final String modelId ,
58
68
final String embedding ,
59
69
final Map <String , String > fieldMap ,
70
+ final boolean skipExisting ,
71
+ final TextImageEmbeddingInferenceFilter inferenceFilter ,
72
+ final OpenSearchClient openSearchClient ,
60
73
final MLCommonsClientAccessor clientAccessor ,
61
74
final Environment environment ,
62
75
final ClusterService clusterService
@@ -71,6 +84,9 @@ public TextImageEmbeddingProcessor(
71
84
this .mlCommonsClientAccessor = clientAccessor ;
72
85
this .environment = environment ;
73
86
this .clusterService = clusterService ;
87
+ this .skipExisting = skipExisting ;
88
+ this .inferenceFilter = inferenceFilter ;
89
+ this .openSearchClient = openSearchClient ;
74
90
}
75
91
76
92
private void validateEmbeddingConfiguration (final Map <String , String > fieldMap ) {
@@ -107,17 +123,30 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
107
123
try {
108
124
Map <String , String > knnMap = buildMapWithKnnKeyAndOriginalValue (ingestDocument );
109
125
Map <String , String > inferenceMap = createInferences (knnMap );
110
- if (inferenceMap .isEmpty ()) {
111
- handler .accept (ingestDocument , null );
112
- } else {
113
- mlCommonsClientAccessor .inferenceSentencesMap (
114
- MapInferenceRequest .builder ().modelId (this .modelId ).inputObjects (inferenceMap ).build (),
115
- ActionListener .wrap (vectors -> {
116
- setVectorFieldsToDocument (ingestDocument , vectors );
117
- handler .accept (ingestDocument , null );
118
- }, e -> { handler .accept (null , e ); })
119
- );
126
+ if (skipExisting == false ) {
127
+ if (inferenceMap .isEmpty ()) {
128
+ handler .accept (ingestDocument , null );
129
+ } else {
130
+ generateAndSetInference (ingestDocument , inferenceMap , handler );
131
+ }
132
+ return ;
120
133
}
134
+ // if skipExisting flag is turned on, eligible inference text and images will be compared and filtered after embeddings are
135
+ // copied
136
+ Object index = ingestDocument .getSourceAndMetadata ().get (INDEX_FIELD );
137
+ Object id = ingestDocument .getSourceAndMetadata ().get (ID_FIELD );
138
+ if (Objects .isNull (index ) || Objects .isNull (id )) {
139
+ generateAndSetInference (ingestDocument , inferenceMap , handler );
140
+ return ;
141
+ }
142
+ openSearchClient .execute (
143
+ GetAction .INSTANCE ,
144
+ new GetRequest (index .toString (), id .toString ()),
145
+ ActionListener .wrap (
146
+ response -> reuseOrGenerateEmbedding (response , ingestDocument , knnMap , inferenceMap , handler ),
147
+ e -> handler .accept (null , e )
148
+ )
149
+ );
121
150
} catch (Exception e ) {
122
151
handler .accept (null , e );
123
152
}
@@ -174,4 +203,55 @@ Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> m
174
203
public String getType () {
175
204
return TYPE ;
176
205
}
206
+
207
+ /**
208
+ * This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument
209
+ *
210
+ * @param ingestDocument ingestDocument to populate embeddings to
211
+ * @param inferenceMap map indicating the path in ingestDocument to populate embeddings
212
+ * @param handler SourceAndMetadataMap of ingestDocument Document
213
+ *
214
+ */
215
+ private void generateAndSetInference (
216
+ IngestDocument ingestDocument ,
217
+ Map <String , String > inferenceMap ,
218
+ BiConsumer <IngestDocument , Exception > handler
219
+ ) {
220
+ mlCommonsClientAccessor .inferenceSentencesMap (
221
+ MapInferenceRequest .builder ().modelId (this .modelId ).inputObjects (inferenceMap ).build (),
222
+ ActionListener .wrap (vectors -> {
223
+ setVectorFieldsToDocument (ingestDocument , vectors );
224
+ handler .accept (ingestDocument , null );
225
+ }, e -> { handler .accept (null , e ); })
226
+ );
227
+ }
228
+
229
+ // This method validates and filters given knnMap and inferenceMap after response is successfully retrieved from get operation.
230
+ private void reuseOrGenerateEmbedding (
231
+ GetResponse response ,
232
+ IngestDocument ingestDocument ,
233
+ Map <String , String > knnMap ,
234
+ Map <String , String > inferenceMap ,
235
+ BiConsumer <IngestDocument , Exception > handler
236
+ ) {
237
+ final Map <String , Object > existingDocument = response .getSourceAsMap ();
238
+ if (existingDocument == null || existingDocument .isEmpty ()) {
239
+ generateAndSetInference (ingestDocument , inferenceMap , handler );
240
+ return ;
241
+ }
242
+ // filter given knnMap by comparing existing document with ingestDocument
243
+ Map <String , String > filteredKnnMap = inferenceFilter .filterAndCopyExistingEmbeddings (
244
+ ingestDocument ,
245
+ existingDocument ,
246
+ knnMap ,
247
+ embedding
248
+ );
249
+ // create inference map based on filtered knnMap
250
+ Map <String , String > filteredInferenceMap = createInferences (filteredKnnMap );
251
+ if (filteredInferenceMap .isEmpty ()) {
252
+ handler .accept (ingestDocument , null );
253
+ } else {
254
+ generateAndSetInference (ingestDocument , filteredInferenceMap , handler );
255
+ }
256
+ }
177
257
}
0 commit comments