Implement optimized embedding generation in text and image embedding processor #1249

Merged
CHANGELOG.md (1 addition, 0 deletions)
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Support semantic sentence highlighter ([#1193](https://github.com/opensearch-project/neural-search/pull/1193))
- Optimize embedding generation in Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
- Optimize embedding generation in Sparse Encoding Processor ([#1246](https://github.com/opensearch-project/neural-search/pull/1246))
- Optimize embedding generation in Text/Image Embedding Processor ([#1249](https://github.com/opensearch-project/neural-search/pull/1249))

### Enhancements

@@ -154,7 +154,12 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
parameters.ingestService.getClusterService()
),
TextImageEmbeddingProcessor.TYPE,
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
new TextImageEmbeddingProcessorFactory(
parameters.client,
clientAccessor,
parameters.env,
parameters.ingestService.getClusterService()
),
TextChunkingProcessor.TYPE,
new TextChunkingProcessorFactory(parameters.env, parameters.ingestService.getClusterService(), parameters.analysisRegistry)
);
TextImageEmbeddingProcessor.java
@@ -14,6 +14,9 @@
import java.util.function.BiConsumer;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
@@ -24,6 +27,8 @@
import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter;
import org.opensearch.transport.client.OpenSearchClient;

/**
* This processor is used for text and image embedding of user input data; model_id indicates which model the user uses,
@@ -35,19 +40,24 @@ public class TextImageEmbeddingProcessor
public static final String TYPE = "text_image_embedding";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding";
public static final boolean DEFAULT_SKIP_EXISTING = false;
public static final String SKIP_EXISTING = "skip_existing";
public static final String FIELD_MAP_FIELD = "field_map";
public static final String TEXT_FIELD_NAME = "text";
public static final String IMAGE_FIELD_NAME = "image";
public static final String INPUT_TEXT = "inputText";
public static final String INPUT_IMAGE = "inputImage";
private static final String INDEX_FIELD = "_index";
private static final String ID_FIELD = "_id";
private static final Set<String> VALID_FIELD_NAMES = Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME);

private final String modelId;
private final String embedding;
private final Map<String, String> fieldMap;

private final boolean skipExisting;
private final OpenSearchClient openSearchClient;
private final MLCommonsClientAccessor mlCommonsClientAccessor;

private final TextImageEmbeddingInferenceFilter inferenceFilter;
private final Environment environment;
private final ClusterService clusterService;

@@ -57,6 +67,9 @@ public TextImageEmbeddingProcessor(
final String modelId,
final String embedding,
final Map<String, String> fieldMap,
final boolean skipExisting,
final TextImageEmbeddingInferenceFilter inferenceFilter,
final OpenSearchClient openSearchClient,
final MLCommonsClientAccessor clientAccessor,
final Environment environment,
final ClusterService clusterService
@@ -71,6 +84,9 @@
this.mlCommonsClientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
this.skipExisting = skipExisting;
this.inferenceFilter = inferenceFilter;
this.openSearchClient = openSearchClient;
}

private void validateEmbeddingConfiguration(final Map<String, String> fieldMap) {
@@ -109,15 +125,28 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
Map<String, String> inferenceMap = createInferences(knnMap);
if (inferenceMap.isEmpty()) {
handler.accept(ingestDocument, null);
} else {
mlCommonsClientAccessor.inferenceSentencesMap(
MapInferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
return;
}
if (skipExisting == false) {
generateAndSetInference(ingestDocument, inferenceMap, handler);
return;
}
// if the skip_existing flag is turned on, eligible inference text and images are compared with the
// existing document, and embeddings for unchanged fields are copied over instead of recomputed
Object index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD);
Object id = ingestDocument.getSourceAndMetadata().get(ID_FIELD);
if (Objects.isNull(index) || Objects.isNull(id)) {
generateAndSetInference(ingestDocument, inferenceMap, handler);
return;
}
openSearchClient.execute(
GetAction.INSTANCE,
new GetRequest(index.toString(), id.toString()),
ActionListener.wrap(
response -> reuseOrGenerateEmbedding(response, ingestDocument, knnMap, inferenceMap, handler),
e -> handler.accept(null, e)
)
);
} catch (Exception e) {
handler.accept(null, e);
}
@@ -174,4 +203,55 @@ Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> m
public String getType() {
return TYPE;
}

/**
* Invokes an inference call through mlCommonsClientAccessor and populates the retrieved embeddings into ingestDocument
*
* @param ingestDocument ingestDocument to populate embeddings into
* @param inferenceMap map indicating the path in ingestDocument at which to populate embeddings
* @param handler callback invoked with the updated document on success, or with the exception on failure
*/
private void generateAndSetInference(
IngestDocument ingestDocument,
Map<String, String> inferenceMap,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesMap(
MapInferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}

// This method validates and filters given knnMap and inferenceMap after response is successfully retrieved from get operation.
private void reuseOrGenerateEmbedding(
GetResponse response,
IngestDocument ingestDocument,
Map<String, String> knnMap,
Map<String, String> inferenceMap,
BiConsumer<IngestDocument, Exception> handler
) {
final Map<String, Object> existingDocument = response.getSourceAsMap();
if (existingDocument == null || existingDocument.isEmpty()) {
generateAndSetInference(ingestDocument, inferenceMap, handler);
return;
}
// filter given knnMap by comparing existing document with ingestDocument
Map<String, String> filteredKnnMap = inferenceFilter.filterAndCopyExistingEmbeddings(
ingestDocument,
existingDocument,
knnMap,
embedding
);
// create inference map based on filtered knnMap
Map<String, String> filteredInferenceMap = createInferences(filteredKnnMap);
if (filteredInferenceMap.isEmpty()) {
handler.accept(ingestDocument, null);
} else {
generateAndSetInference(ingestDocument, filteredInferenceMap, handler);
}
}
}
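
For orientation, here is a minimal, hypothetical sketch of driving the processor's asynchronous contract as changed above. The processor instance, field names, and values are illustrative and not part of this PR; only the execute(IngestDocument, BiConsumer) signature and the _index/_id lookup come from the diff.

import java.util.HashMap;
import java.util.Map;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;

class ExecuteSketch {
    static void run(TextImageEmbeddingProcessor processor) {
        Map<String, Object> source = new HashMap<>();
        source.put("name", "This is a good day");        // illustrative text field
        source.put("image_data", "base64-image-bytes");  // illustrative image field
        source.put("_index", "my-index");                // required for the skip_existing lookup
        source.put("_id", "1");                          // without _index/_id the processor always infers
        IngestDocument document = new IngestDocument(source, new HashMap<>());
        processor.execute(document, (doc, e) -> {
            if (e != null) {
                // inference or the GET of the stored document failed
            } else {
                // doc now carries the embedding: freshly generated, or copied from
                // the stored document when neither text nor image changed
            }
        });
    }
}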
TextImageEmbeddingProcessorFactory.java
@@ -4,8 +4,11 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty;
import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.DEFAULT_SKIP_EXISTING;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.SKIP_EXISTING;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD;
@@ -20,13 +23,16 @@
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;

import lombok.AllArgsConstructor;
import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter;
import org.opensearch.transport.client.OpenSearchClient;

/**
* Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
@AllArgsConstructor
public class TextImageEmbeddingProcessorFactory implements Processor.Factory {

private final OpenSearchClient openSearchClient;
private final MLCommonsClientAccessor clientAccessor;
private final Environment environment;
private final ClusterService clusterService;
@@ -36,7 +42,20 @@ public Processor create(Map<String, Processor.Factory> processorFactories, Strin
throws Exception {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
String embedding = readStringProperty(TYPE, tag, config, EMBEDDING_FIELD);
Map<String, String> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
return new TextImageEmbeddingProcessor(tag, description, modelId, embedding, filedMap, clientAccessor, environment, clusterService);
Map<String, String> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
boolean skipExisting = readBooleanProperty(TextImageEmbeddingProcessor.TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING);
return new TextImageEmbeddingProcessor(
tag,
description,
modelId,
embedding,
fieldMap,
skipExisting,
skipExisting ? new TextImageEmbeddingInferenceFilter() : null,
openSearchClient,
clientAccessor,
environment,
clusterService
);
}
}
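
A hedged sketch of what the factory now wires together. The config keys mirror the constants read above; the concrete values and the surrounding class are illustrative, and the factory's dependencies are assumed to come from the plugin's getProcessors registration shown earlier.

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;

class FactorySketch {
    static Processor build(TextImageEmbeddingProcessorFactory factory) throws Exception {
        Map<String, Object> config = new HashMap<>();
        config.put("model_id", "<model-id>");        // placeholder id
        config.put("embedding", "vector_embedding"); // destination field (illustrative)
        Map<String, String> fieldMap = new HashMap<>();
        fieldMap.put("text", "name");                // "text" and "image" are the only valid keys
        fieldMap.put("image", "image_data");
        config.put("field_map", fieldMap);
        config.put("skip_existing", true);           // defaults to false when omitted
        return factory.create(Collections.emptyMap(), "tag", "a text/image embedding processor", config);
    }
}

Note the design choice visible in the diff: the inference filter is only instantiated when skip_existing is true, so the default path carries no extra object or comparison cost.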
TextImageEmbeddingInferenceFilter.java (new file)
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.optimization;

import lombok.extern.log4j.Log4j2;
import org.opensearch.ingest.IngestDocument;

import java.util.Collections;
import java.util.Map;
import java.util.Objects;

/**
* TextImageEmbeddingInferenceFilter optimizes text/image embedding inference by selectively processing text/image data.
* This class provides efficient text/image embedding processing by comparing text/image between existing and new documents.
* If both text and image are identical, the corresponding embeddings are copied over, avoiding redundant inference calls and improving performance.
*/
@Log4j2
public class TextImageEmbeddingInferenceFilter {

public TextImageEmbeddingInferenceFilter() {}

/**
* Filters the given knnMap by checking whether the values for both text and image are identical in the existing and new document.
* If both values match, the corresponding embedding is copied and an empty map is returned, indicating that no further
* processing is required. If either value does not match, or no embedding exists, the given knnMap is returned for processing.
*
* @return empty Map if embeddings are reused; the original knnMap otherwise.
*/
public Map<String, String> filterAndCopyExistingEmbeddings(
IngestDocument ingestDocument,
Map<String, Object> existingDocument,
Map<String, String> knnMap,
String embeddingField
) {
// knnMap can only contain two keys: one for text field and another for image field.
// If either of the two does not match, knnMap cannot be filtered
for (Map.Entry<String, String> entry : knnMap.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
if (existingDocument.containsKey(key) == false || existingDocument.get(key).equals(value) == false) {
[Review thread on the line above]

Contributor: I think we need to compare both the text and image here; can this be done by just checking one key?

Contributor (author): knnMap contains two keys: one for text and one for image. Each entry is compared with the text and image values of the existing document.

Contributor: Synced offline that this code works because we currently only allow the user to define one text field and one image field, so the knnMap contains only those two fields and both must match to reuse the existing embedding. We should add a comment to call that out. We may also want to allow users to define multiple text and image fields in the processor; probably worth an RFC to see if there is user need.

heemin32 (Collaborator, Mar 31, 2025): Tracking the multiple text/image fields idea in #476.

[End of review thread]

return knnMap;
}
}
Object embeddings = existingDocument.get(embeddingField);
if (Objects.isNull(embeddings)) {
return knnMap;
}
ingestDocument.setFieldValue(embeddingField, existingDocument.get(embeddingField));
return Collections.emptyMap();
}
}
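
To make the reuse semantics concrete, a small hypothetical example against the filter as written above; all field names and values are invented for illustration.

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter;

class FilterSketch {
    static void run() {
        TextImageEmbeddingInferenceFilter filter = new TextImageEmbeddingInferenceFilter();

        // Previously indexed version of the document, including its embedding.
        Map<String, Object> existingDocument = new HashMap<>();
        existingDocument.put("name", "This is a good day");
        existingDocument.put("image_data", "base64-image-bytes");
        existingDocument.put("vector_embedding", List.of(0.1f, 0.2f));

        // Incoming update with identical text and image.
        Map<String, Object> source = new HashMap<>();
        source.put("name", "This is a good day");
        source.put("image_data", "base64-image-bytes");
        IngestDocument ingestDocument = new IngestDocument(source, new HashMap<>());

        Map<String, String> knnMap = Map.of("name", "This is a good day", "image_data", "base64-image-bytes");
        Map<String, String> filtered = filter.filterAndCopyExistingEmbeddings(
            ingestDocument, existingDocument, knnMap, "vector_embedding");
        // filtered is empty: the stored embedding was copied onto ingestDocument and
        // inference is skipped. If either value differed, knnMap would be returned unchanged.
    }
}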
@@ -79,4 +79,24 @@ public void testEmbeddingProcessor_whenReindexingDocument_thenSuccessful() throw
reindex(fromIndexName, toIndexName);
assertEquals(1, getDocCount(toIndexName));
}

public void testEmbeddingProcessor_whenSkipExisting_updateWithNoChange_thenSuccessful() throws Exception {
String modelId = uploadModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING);
createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME);
ingestDocument(INDEX_NAME, INGEST_DOCUMENT, "1");
updateDocument(INDEX_NAME, INGEST_DOCUMENT, "1");
assertEquals(1, getDocCount(INDEX_NAME));
}

public void testEmbeddingProcessor_whenSkipExisting_updateWithChange_thenSuccessful() throws Exception {
String modelId = uploadModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING);
createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME);
ingestDocument(INDEX_NAME, INGEST_DOCUMENT, "1");
updateDocument(INDEX_NAME, INGEST_DOCUMENT.replace("\"This is a good day\"", "\"This is a great day\""), "1");
assertEquals(1, getDocCount(INDEX_NAME));
}
}
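
For reference, a hedged sketch of the pipeline body that a fixture like TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING could correspond to. The actual JSON file is not shown in this diff; the field names and model id placeholder are illustrative, with only the processor type, parameter keys, and the new skip_existing flag taken from the code above.

class PipelineSketch {
    // Hypothetical fixture content, inferred from the processor parameters above.
    static final String SKIP_EXISTING_PIPELINE = """
        {
          "description": "text/image embedding pipeline that reuses unchanged embeddings",
          "processors": [
            {
              "text_image_embedding": {
                "model_id": "<model-id>",
                "embedding": "vector_embedding",
                "field_map": { "text": "name", "image": "image_data" },
                "skip_existing": true
              }
            }
          ]
        }
        """;
}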