Skip to content

Commit b9b21d0

Browse files
committed
Implement Optimized embedding generation in text and image embedding processor
Signed-off-by: will-hwang <[email protected]>
1 parent 8deb63c commit b9b21d0

11 files changed

+530
-37
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1212
- Support semantic sentence highlighter ([#1193](https://github.com/opensearch-project/neural-search/pull/1193))
1313
- Optimize embedding generation in Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
1414
- Optimize embedding generation in Sparse Encoding Processor ([#1246](https://github.com/opensearch-project/neural-search/pull/1246))
15+
- Optimize embedding generation in Text/Image Embedding Processor ([#1249](https://github.com/opensearch-project/neural-search/pull/1249))
1516

1617
### Enhancements
1718

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,12 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
154154
parameters.ingestService.getClusterService()
155155
),
156156
TextImageEmbeddingProcessor.TYPE,
157-
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
157+
new TextImageEmbeddingProcessorFactory(
158+
parameters.client,
159+
clientAccessor,
160+
parameters.env,
161+
parameters.ingestService.getClusterService()
162+
),
158163
TextChunkingProcessor.TYPE,
159164
new TextChunkingProcessorFactory(parameters.env, parameters.ingestService.getClusterService(), parameters.analysisRegistry)
160165
);

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

+92-12
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import java.util.function.BiConsumer;
1515

1616
import org.apache.commons.lang3.StringUtils;
17+
import org.opensearch.action.get.GetAction;
18+
import org.opensearch.action.get.GetRequest;
19+
import org.opensearch.action.get.GetResponse;
1720
import org.opensearch.cluster.service.ClusterService;
1821
import org.opensearch.core.action.ActionListener;
1922
import org.opensearch.env.Environment;
@@ -24,6 +27,8 @@
2427
import com.google.common.annotations.VisibleForTesting;
2528

2629
import lombok.extern.log4j.Log4j2;
30+
import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter;
31+
import org.opensearch.transport.client.OpenSearchClient;
2732

2833
/**
2934
* This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use,
@@ -35,19 +40,24 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor {
3540
public static final String TYPE = "text_image_embedding";
3641
public static final String MODEL_ID_FIELD = "model_id";
3742
public static final String EMBEDDING_FIELD = "embedding";
43+
public static final boolean DEFAULT_SKIP_EXISTING = false;
44+
public static final String SKIP_EXISTING = "skip_existing";
3845
public static final String FIELD_MAP_FIELD = "field_map";
3946
public static final String TEXT_FIELD_NAME = "text";
4047
public static final String IMAGE_FIELD_NAME = "image";
4148
public static final String INPUT_TEXT = "inputText";
4249
public static final String INPUT_IMAGE = "inputImage";
50+
private static final String INDEX_FIELD = "_index";
51+
private static final String ID_FIELD = "_id";
4352
private static final Set<String> VALID_FIELD_NAMES = Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME);
4453

4554
private final String modelId;
4655
private final String embedding;
4756
private final Map<String, String> fieldMap;
48-
57+
private final boolean skipExisting;
58+
private final OpenSearchClient openSearchClient;
4959
private final MLCommonsClientAccessor mlCommonsClientAccessor;
50-
60+
private final TextImageEmbeddingInferenceFilter inferenceFilter;
5161
private final Environment environment;
5262
private final ClusterService clusterService;
5363

@@ -57,6 +67,9 @@ public TextImageEmbeddingProcessor(
5767
final String modelId,
5868
final String embedding,
5969
final Map<String, String> fieldMap,
70+
final boolean skipExisting,
71+
final TextImageEmbeddingInferenceFilter inferenceFilter,
72+
final OpenSearchClient openSearchClient,
6073
final MLCommonsClientAccessor clientAccessor,
6174
final Environment environment,
6275
final ClusterService clusterService
@@ -71,6 +84,9 @@ public TextImageEmbeddingProcessor(
7184
this.mlCommonsClientAccessor = clientAccessor;
7285
this.environment = environment;
7386
this.clusterService = clusterService;
87+
this.skipExisting = skipExisting;
88+
this.inferenceFilter = inferenceFilter;
89+
this.openSearchClient = openSearchClient;
7490
}
7591

7692
private void validateEmbeddingConfiguration(final Map<String, String> fieldMap) {
@@ -107,17 +123,30 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
107123
try {
108124
Map<String, String> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
109125
Map<String, String> inferenceMap = createInferences(knnMap);
110-
if (inferenceMap.isEmpty()) {
111-
handler.accept(ingestDocument, null);
112-
} else {
113-
mlCommonsClientAccessor.inferenceSentencesMap(
114-
MapInferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
115-
ActionListener.wrap(vectors -> {
116-
setVectorFieldsToDocument(ingestDocument, vectors);
117-
handler.accept(ingestDocument, null);
118-
}, e -> { handler.accept(null, e); })
119-
);
126+
if (skipExisting == false) {
127+
if (inferenceMap.isEmpty()) {
128+
handler.accept(ingestDocument, null);
129+
} else {
130+
generateAndSetInference(ingestDocument, inferenceMap, handler);
131+
}
132+
return;
120133
}
134+
// if skipExisting flag is turned on, eligible inference text and images will be compared and filtered after embeddings are
135+
// copied
136+
Object index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD);
137+
Object id = ingestDocument.getSourceAndMetadata().get(ID_FIELD);
138+
if (Objects.isNull(index) || Objects.isNull(id)) {
139+
generateAndSetInference(ingestDocument, inferenceMap, handler);
140+
return;
141+
}
142+
openSearchClient.execute(
143+
GetAction.INSTANCE,
144+
new GetRequest(index.toString(), id.toString()),
145+
ActionListener.wrap(
146+
response -> reuseOrGenerateEmbedding(response, ingestDocument, knnMap, inferenceMap, handler),
147+
e -> handler.accept(null, e)
148+
)
149+
);
121150
} catch (Exception e) {
122151
handler.accept(null, e);
123152
}
@@ -174,4 +203,55 @@ Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> m
174203
public String getType() {
175204
return TYPE;
176205
}
206+
207+
/**
208+
* This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument
209+
*
210+
* @param ingestDocument ingestDocument to populate embeddings to
211+
* @param inferenceMap map indicating the path in ingestDocument to populate embeddings
212+
* @param handler SourceAndMetadataMap of ingestDocument Document
213+
*
214+
*/
215+
private void generateAndSetInference(
216+
IngestDocument ingestDocument,
217+
Map<String, String> inferenceMap,
218+
BiConsumer<IngestDocument, Exception> handler
219+
) {
220+
mlCommonsClientAccessor.inferenceSentencesMap(
221+
MapInferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
222+
ActionListener.wrap(vectors -> {
223+
setVectorFieldsToDocument(ingestDocument, vectors);
224+
handler.accept(ingestDocument, null);
225+
}, e -> { handler.accept(null, e); })
226+
);
227+
}
228+
229+
// This method validates and filters given knnMap and inferenceMap after response is successfully retrieved from get operation.
230+
private void reuseOrGenerateEmbedding(
231+
GetResponse response,
232+
IngestDocument ingestDocument,
233+
Map<String, String> knnMap,
234+
Map<String, String> inferenceMap,
235+
BiConsumer<IngestDocument, Exception> handler
236+
) {
237+
final Map<String, Object> existingDocument = response.getSourceAsMap();
238+
if (existingDocument == null || existingDocument.isEmpty()) {
239+
generateAndSetInference(ingestDocument, inferenceMap, handler);
240+
return;
241+
}
242+
// filter given knnMap by comparing existing document with ingestDocument
243+
Map<String, String> filteredKnnMap = inferenceFilter.filterAndCopyExistingEmbeddings(
244+
ingestDocument,
245+
existingDocument,
246+
knnMap,
247+
embedding
248+
);
249+
// create inference map based on filtered knnMap
250+
Map<String, String> filteredInferenceMap = createInferences(filteredKnnMap);
251+
if (filteredInferenceMap.isEmpty()) {
252+
handler.accept(ingestDocument, null);
253+
} else {
254+
generateAndSetInference(ingestDocument, filteredInferenceMap, handler);
255+
}
256+
}
177257
}

src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java

+21-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
*/
55
package org.opensearch.neuralsearch.processor.factory;
66

7+
import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty;
78
import static org.opensearch.ingest.ConfigurationUtils.readMap;
89
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
10+
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.DEFAULT_SKIP_EXISTING;
11+
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.SKIP_EXISTING;
912
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD;
1013
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD;
1114
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD;
@@ -20,13 +23,16 @@
2023
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
2124

2225
import lombok.AllArgsConstructor;
26+
import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter;
27+
import org.opensearch.transport.client.OpenSearchClient;
2328

2429
/**
2530
* Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
2631
*/
2732
@AllArgsConstructor
2833
public class TextImageEmbeddingProcessorFactory implements Processor.Factory {
2934

35+
private final OpenSearchClient openSearchClient;
3036
private final MLCommonsClientAccessor clientAccessor;
3137
private final Environment environment;
3238
private final ClusterService clusterService;
@@ -36,7 +42,20 @@ public Processor create(Map<String, Processor.Factory> processorFactories, Strin
3642
throws Exception {
3743
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
3844
String embedding = readStringProperty(TYPE, tag, config, EMBEDDING_FIELD);
39-
Map<String, String> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
40-
return new TextImageEmbeddingProcessor(tag, description, modelId, embedding, filedMap, clientAccessor, environment, clusterService);
45+
Map<String, String> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
46+
boolean skipExisting = readBooleanProperty(TextImageEmbeddingProcessor.TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING);
47+
return new TextImageEmbeddingProcessor(
48+
tag,
49+
description,
50+
modelId,
51+
embedding,
52+
fieldMap,
53+
skipExisting,
54+
skipExisting ? new TextImageEmbeddingInferenceFilter() : null,
55+
openSearchClient,
56+
clientAccessor,
57+
environment,
58+
clusterService
59+
);
4160
}
4261
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor.optimization;
6+
7+
import lombok.extern.log4j.Log4j2;
8+
import org.opensearch.ingest.IngestDocument;
9+
10+
import java.util.Collections;
11+
import java.util.Map;
12+
import java.util.Objects;
13+
14+
/**
15+
* TextImageEmbeddingInferenceFilter optimizes text/image embedding inference by selectively processing text/image data.
16+
* This class provides efficient text/image embedding processing by comparing text/image between existing and new documents.
17+
* If both text and image are identical, the corresponding embeddings are copied over, avoiding redundant inference calls and improving performance.
18+
*/
19+
@Log4j2
20+
public class TextImageEmbeddingInferenceFilter {
21+
22+
public TextImageEmbeddingInferenceFilter() {}
23+
24+
/**
25+
* Filters the given knnMap by checking if the values for both text and image are identical in the existing and new document.
26+
* If both values for text and image match, the corresponding embedding is copied, and empty map is returned, indicating no further
27+
* processing is required. If any of the two do not match or embedding does not exist, the given knnMap is returned to be processed
28+
*
29+
* @return empty Map if embeddings are reused; the original knnMap otherwise.
30+
*/
31+
public Map<String, String> filterAndCopyExistingEmbeddings(
32+
IngestDocument ingestDocument,
33+
Map<String, Object> existingDocument,
34+
Map<String, String> knnMap,
35+
String embeddingField
36+
) {
37+
for (Map.Entry<String, String> entry : knnMap.entrySet()) {
38+
String key = entry.getKey();
39+
String value = entry.getValue();
40+
if (existingDocument.containsKey(key) == false || existingDocument.get(key).equals(value) == false) {
41+
return knnMap;
42+
}
43+
}
44+
Object embeddings = existingDocument.get(embeddingField);
45+
if (Objects.isNull(embeddings)) {
46+
return knnMap;
47+
}
48+
ingestDocument.setFieldValue(embeddingField, existingDocument.get(embeddingField));
49+
return Collections.emptyMap();
50+
}
51+
}

src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java

+20
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,24 @@ public void testEmbeddingProcessor_whenReindexingDocument_thenSuccessful() throw
7979
reindex(fromIndexName, toIndexName);
8080
assertEquals(1, getDocCount(toIndexName));
8181
}
82+
83+
public void testEmbeddingProcessor_whenSkipExisting_updateWithNoChange_thenSuccessful() throws Exception {
84+
String modelId = uploadModel();
85+
loadModel(modelId);
86+
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING);
87+
createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME);
88+
ingestDocument(INDEX_NAME, INGEST_DOCUMENT, "1");
89+
updateDocument(INDEX_NAME, INGEST_DOCUMENT, "1");
90+
assertEquals(1, getDocCount(INDEX_NAME));
91+
}
92+
93+
public void testEmbeddingProcessor_whenSkipExisting_updateWithChange_thenSuccessful() throws Exception {
94+
String modelId = uploadModel();
95+
loadModel(modelId);
96+
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING);
97+
createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME);
98+
ingestDocument(INDEX_NAME, INGEST_DOCUMENT, "1");
99+
updateDocument(INDEX_NAME, INGEST_DOCUMENT.replace("\"This is a good day\"", "\"This is a great day\""), "1");
100+
assertEquals(1, getDocCount(INDEX_NAME));
101+
}
82102
}

0 commit comments

Comments
 (0)