diff --git a/CHANGELOG.md b/CHANGELOG.md index c043d319c..fca636279 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Support semantic sentence highlighter ([#1193](https://github.com/opensearch-project/neural-search/pull/1193)) - Optimize embedding generation in Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191)) - Optimize embedding generation in Sparse Encoding Processor ([#1246](https://github.com/opensearch-project/neural-search/pull/1246)) +- Optimize embedding generation in Text/Image Embedding Processor ([#1249](https://github.com/opensearch-project/neural-search/pull/1249)) ### Enhancements diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 832ef4d16..0a8be7f83 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -154,7 +154,12 @@ public Map getProcessors(Processor.Parameters paramet parameters.ingestService.getClusterService() ), TextImageEmbeddingProcessor.TYPE, - new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), + new TextImageEmbeddingProcessorFactory( + parameters.client, + clientAccessor, + parameters.env, + parameters.ingestService.getClusterService() + ), TextChunkingProcessor.TYPE, new TextChunkingProcessorFactory(parameters.env, parameters.ingestService.getClusterService(), parameters.analysisRegistry) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index 9923ff1d4..164c8b36e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -14,6 +14,9 @@ import java.util.function.BiConsumer; import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; @@ -24,6 +27,8 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter; +import org.opensearch.transport.client.OpenSearchClient; /** * This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use, @@ -35,19 +40,24 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { public static final String TYPE = "text_image_embedding"; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding"; + public static final boolean DEFAULT_SKIP_EXISTING = false; + public static final String SKIP_EXISTING = "skip_existing"; public static final String FIELD_MAP_FIELD = "field_map"; public static final String TEXT_FIELD_NAME = "text"; public static final String IMAGE_FIELD_NAME = "image"; public static final String INPUT_TEXT = "inputText"; public static final String INPUT_IMAGE = "inputImage"; + private static final String INDEX_FIELD = 
"_index"; + private static final String ID_FIELD = "_id"; private static final Set VALID_FIELD_NAMES = Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME); private final String modelId; private final String embedding; private final Map fieldMap; - + private final boolean skipExisting; + private final OpenSearchClient openSearchClient; private final MLCommonsClientAccessor mlCommonsClientAccessor; - + private final TextImageEmbeddingInferenceFilter inferenceFilter; private final Environment environment; private final ClusterService clusterService; @@ -57,6 +67,9 @@ public TextImageEmbeddingProcessor( final String modelId, final String embedding, final Map fieldMap, + final boolean skipExisting, + final TextImageEmbeddingInferenceFilter inferenceFilter, + final OpenSearchClient openSearchClient, final MLCommonsClientAccessor clientAccessor, final Environment environment, final ClusterService clusterService @@ -71,6 +84,9 @@ public TextImageEmbeddingProcessor( this.mlCommonsClientAccessor = clientAccessor; this.environment = environment; this.clusterService = clusterService; + this.skipExisting = skipExisting; + this.inferenceFilter = inferenceFilter; + this.openSearchClient = openSearchClient; } private void validateEmbeddingConfiguration(final Map fieldMap) { @@ -109,15 +125,28 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer inferenceMap = createInferences(knnMap); if (inferenceMap.isEmpty()) { handler.accept(ingestDocument, null); - } else { - mlCommonsClientAccessor.inferenceSentencesMap( - MapInferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(), - ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); }) - ); + return; + } + if (skipExisting == false) { + generateAndSetInference(ingestDocument, inferenceMap, handler); + return; } + // if skipExisting flag is turned on, eligible inference text and images will be compared and filtered after embeddings are + // copied + Object index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD); + Object id = ingestDocument.getSourceAndMetadata().get(ID_FIELD); + if (Objects.isNull(index) || Objects.isNull(id)) { + generateAndSetInference(ingestDocument, inferenceMap, handler); + return; + } + openSearchClient.execute( + GetAction.INSTANCE, + new GetRequest(index.toString(), id.toString()), + ActionListener.wrap( + response -> reuseOrGenerateEmbedding(response, ingestDocument, knnMap, inferenceMap, handler), + e -> handler.accept(null, e) + ) + ); } catch (Exception e) { handler.accept(null, e); } @@ -174,4 +203,55 @@ Map buildTextEmbeddingResult(final String knnKey, List m public String getType() { return TYPE; } + + /** + * This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument + * + * @param ingestDocument ingestDocument to populate embeddings to + * @param inferenceMap map indicating the path in ingestDocument to populate embeddings + * @param handler SourceAndMetadataMap of ingestDocument Document + * + */ + private void generateAndSetInference( + IngestDocument ingestDocument, + Map inferenceMap, + BiConsumer handler + ) { + mlCommonsClientAccessor.inferenceSentencesMap( + MapInferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, 
e); }) + ); + } + + // This method validates and filters given knnMap and inferenceMap after response is successfully retrieved from get operation. + private void reuseOrGenerateEmbedding( + GetResponse response, + IngestDocument ingestDocument, + Map knnMap, + Map inferenceMap, + BiConsumer handler + ) { + final Map existingDocument = response.getSourceAsMap(); + if (existingDocument == null || existingDocument.isEmpty()) { + generateAndSetInference(ingestDocument, inferenceMap, handler); + return; + } + // filter given knnMap by comparing existing document with ingestDocument + Map filteredKnnMap = inferenceFilter.filterAndCopyExistingEmbeddings( + ingestDocument, + existingDocument, + knnMap, + embedding + ); + // create inference map based on filtered knnMap + Map filteredInferenceMap = createInferences(filteredKnnMap); + if (filteredInferenceMap.isEmpty()) { + handler.accept(ingestDocument, null); + } else { + generateAndSetInference(ingestDocument, filteredInferenceMap, handler); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java index 7250e9365..608bc3251 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java @@ -4,8 +4,11 @@ */ package org.opensearch.neuralsearch.processor.factory; +import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty; import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.DEFAULT_SKIP_EXISTING; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.SKIP_EXISTING; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD; @@ -20,6 +23,8 @@ import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import lombok.AllArgsConstructor; +import org.opensearch.neuralsearch.processor.optimization.TextImageEmbeddingInferenceFilter; +import org.opensearch.transport.client.OpenSearchClient; /** * Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. 
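Taken together, the processor changes above implement the following decision flow when skip_existing is enabled: look up the previously indexed copy of the document by _index/_id, compare the mapped text and image values, and either copy the stored embedding or fall back to an inference call. The standalone sketch below is illustrative only and not part of this change: the class and method names (SkipExistingFlowSketch, ExistingDocLookup, reuseOrRecompute) are hypothetical, and the asynchronous ActionListener/BiConsumer plumbing of the real processor is collapsed into a direct, synchronous lookup.

```java
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

public final class SkipExistingFlowSketch {

    /** Stand-in for the GetAction round trip performed through openSearchClient. */
    interface ExistingDocLookup {
        Map<String, Object> fetchSource(String index, String id);
    }

    /**
     * Returns the entries that still need inference; an empty result means the stored
     * embedding was copied into newSource and no model call is required.
     */
    static Map<String, String> reuseOrRecompute(
        Map<String, Object> newSource,
        Map<String, String> knnMap,      // document field name -> text or image value to embed
        String embeddingField,
        ExistingDocLookup lookup
    ) {
        Object index = newSource.get("_index");
        Object id = newSource.get("_id");
        if (Objects.isNull(index) || Objects.isNull(id)) {
            return knnMap; // first-time ingest, nothing to compare against
        }
        Map<String, Object> existing = lookup.fetchSource(index.toString(), id.toString());
        if (existing == null || existing.isEmpty()) {
            return knnMap; // no previously indexed copy of the document
        }
        // both text and image must be unchanged, and a stored embedding must exist
        for (Map.Entry<String, String> entry : knnMap.entrySet()) {
            if (Objects.equals(existing.get(entry.getKey()), entry.getValue()) == false) {
                return knnMap;
            }
        }
        Object storedEmbedding = existing.get(embeddingField);
        if (Objects.isNull(storedEmbedding)) {
            return knnMap;
        }
        newSource.put(embeddingField, storedEmbedding); // reuse instead of calling the model again
        return Collections.emptyMap();
    }
}
```

With skip_existing disabled (the default, DEFAULT_SKIP_EXISTING = false), the processor never reaches this path and always issues the inference request, which matches the pre-existing behavior.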
@@ -27,6 +32,7 @@ @AllArgsConstructor public class TextImageEmbeddingProcessorFactory implements Processor.Factory { + private final OpenSearchClient openSearchClient; private final MLCommonsClientAccessor clientAccessor; private final Environment environment; private final ClusterService clusterService; @@ -36,7 +42,20 @@ public Processor create(Map processorFactories, Strin throws Exception { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); String embedding = readStringProperty(TYPE, tag, config, EMBEDDING_FIELD); - Map filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); - return new TextImageEmbeddingProcessor(tag, description, modelId, embedding, filedMap, clientAccessor, environment, clusterService); + Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + boolean skipExisting = readBooleanProperty(TextImageEmbeddingProcessor.TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING); + return new TextImageEmbeddingProcessor( + tag, + description, + modelId, + embedding, + fieldMap, + skipExisting, + skipExisting ? new TextImageEmbeddingInferenceFilter() : null, + openSearchClient, + clientAccessor, + environment, + clusterService + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextImageEmbeddingInferenceFilter.java b/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextImageEmbeddingInferenceFilter.java new file mode 100644 index 000000000..aa00a8f68 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextImageEmbeddingInferenceFilter.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.optimization; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.ingest.IngestDocument; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +/** + * TextImageEmbeddingInferenceFilter optimizes text/image embedding inference by selectively processing text/image data. + * This class provides efficient text/image embedding processing by comparing text/image between existing and new documents. + * If both text and image are identical, the corresponding embeddings are copied over, avoiding redundant inference calls and improving performance. + */ +@Log4j2 +public class TextImageEmbeddingInferenceFilter { + + public TextImageEmbeddingInferenceFilter() {} + + /** + * Filters the given knnMap by checking if the values for both text and image are identical in the existing and new document. + * If both values for text and image match, the corresponding embedding is copied, and empty map is returned, indicating no further + * processing is required. If any of the two do not match or embedding does not exist, the given knnMap is returned to be processed + * + * @return empty Map if embeddings are reused; the original knnMap otherwise. + */ + public Map filterAndCopyExistingEmbeddings( + IngestDocument ingestDocument, + Map existingDocument, + Map knnMap, + String embeddingField + ) { + // knnMap can only contain two keys: one for text field and another for image field. 
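+        // e.g. knnMap = {"image_description": "orange desk", "image_binary": "base64_of_orange_desk_image"},
+        // compared against the same two fields of the previously indexed copy of the document
+        // (example values as used in the unit tests below).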
+ // If either of the two does not match, knnMap cannot be filtered + for (Map.Entry entry : knnMap.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + if (existingDocument.containsKey(key) == false || existingDocument.get(key).equals(value) == false) { + return knnMap; + } + } + Object embeddings = existingDocument.get(embeddingField); + if (Objects.isNull(embeddings)) { + return knnMap; + } + ingestDocument.setFieldValue(embeddingField, existingDocument.get(embeddingField)); + return Collections.emptyMap(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java index 313c5cb07..3d142a362 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java @@ -79,4 +79,24 @@ public void testEmbeddingProcessor_whenReindexingDocument_thenSuccessful() throw reindex(fromIndexName, toIndexName); assertEquals(1, getDocCount(toIndexName)); } + + public void testEmbeddingProcessor_whenSkipExisting_updateWithNoChange_thenSuccessful() throws Exception { + String modelId = uploadModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOCUMENT, "1"); + updateDocument(INDEX_NAME, INGEST_DOCUMENT, "1"); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + public void testEmbeddingProcessor_whenSkipExisting_updateWithChange_thenSuccessful() throws Exception { + String modelId = uploadModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOCUMENT, "1"); + updateDocument(INDEX_NAME, INGEST_DOCUMENT.replace("\"This is a good day\"", "\"This is a great day\""), "1"); + assertEquals(1, getDocCount(INDEX_NAME)); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index e6306523e..017e40f86 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -12,11 +12,13 @@ import static org.mockito.Mockito.isA; import static org.mockito.Mockito.isNull; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -25,17 +27,27 @@ import java.util.function.Supplier; import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.get.GetAction; +import 
org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.env.Environment; +import org.opensearch.index.get.GetResult; import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; @@ -47,9 +59,12 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.opensearch.transport.client.OpenSearchClient; public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase { + @Mock + private OpenSearchClient openSearchClient; @Mock private MLCommonsClientAccessor mlCommonsClientAccessor; @Mock @@ -63,6 +78,9 @@ public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase { @Mock private IndexMetadata indexMetadata; + @Captor + private ArgumentCaptor inferenceRequestCaptor; + @InjectMocks private TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory; private static final String PROCESSOR_TAG = "mockTag"; @@ -80,11 +98,12 @@ public void setup() { } @SneakyThrows - private TextImageEmbeddingProcessor createInstance() { + private TextImageEmbeddingProcessor createInstance(boolean skipExisting) { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put(TextImageEmbeddingProcessor.SKIP_EXISTING, skipExisting); config.put( TextImageEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") @@ -93,7 +112,7 @@ private TextImageEmbeddingProcessor createInstance() { } @SneakyThrows - public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { + public void testTextImageEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -105,7 +124,7 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA } @SneakyThrows - public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_throwIllegalArgumentException() { + public void testTextImageEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_throwIllegalArgumentException() { boolean ignoreFailure = false; String modelId = "mockModelId"; String embeddingField = "my_embedding_field"; @@ -119,6 +138,9 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t modelId, embeddingField, null, + false, + null, + openSearchClient, mlCommonsClientAccessor, env, clusterService @@ -135,6 +157,9 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t modelId, embeddingField, Map.of("", "my_field"), + false, + null, + openSearchClient, mlCommonsClientAccessor, env, clusterService @@ -155,6 +180,9 @@ public void 
testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t modelId, embeddingField, typeMapping, + false, + null, + openSearchClient, mlCommonsClientAccessor, env, clusterService @@ -164,7 +192,7 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t } @SneakyThrows - public void testTextEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArgumentException() { + public void testTextImageEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArgumentException() { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, ""); @@ -190,7 +218,7 @@ public void testExecute_successful() { sourceAndMetadata.put("key5", Map.of("inner_field", "innerValue1")); sourceAndMetadata.put("image_field", "base64_of_image_1234567890"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = createInstance(false); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -213,7 +241,9 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); + OpenSearchClient openSearchClient = mock(OpenSearchClient.class); TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + openSearchClient, accessor, env, clusterService @@ -244,7 +274,7 @@ public void testExecute_withListTypeInput_successful() { sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("another_text_field", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = createInstance(false); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -266,7 +296,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { sourceAndMetadata.put("my_text_field", ret); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = createInstance(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -278,7 +308,7 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = createInstance(false); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(1); @@ -300,7 +330,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("my_text_field", map2); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = 
createInstance(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -314,7 +344,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("my_text_field", map2); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = createInstance(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -326,7 +356,7 @@ public void testExecute_hybridTypeInput_successful() throws Exception { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key2", map1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = createInstance(false); IngestDocument document = processor.execute(ingestDocument); assert document.getSourceAndMetadata().containsKey("key2"); } @@ -337,7 +367,7 @@ public void testExecute_whenInferencesAreEmpty_thenSuccessful() { sourceAndMetadata.put("my_field", "value1"); sourceAndMetadata.put("another_text_field", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(); + TextImageEmbeddingProcessor processor = createInstance(false); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -352,15 +382,117 @@ public void testExecute_whenInferencesAreEmpty_thenSuccessful() { verify(handler).accept(any(IngestDocument.class), isNull()); } + public void testExecute_no_update_skip_existing_flag_successful() { + Map ingestSourceAndMetadata = getIngestDocument(); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + Map updateSourceAndMetadata = getIngestDocument(); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(true); + Map inferenceMap = Map.of("inputText", "value2", "inputImage", "base64_of_image_1234567890"); + MapInferenceRequest ingestRequest = MapInferenceRequest.builder().modelId("mockModelId").inputObjects(inferenceMap).build(); + + mockUpdateVectorCreation(); + mockUpdateDocument(ingestDocument); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentencesMap(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(ingestRequest.getInputObjects(), inferenceRequestCaptor.getValue().getInputObjects()); + verifyEqualEmbedding( + (List>) ingestSourceAndMetadata.get("my_embedding_field"), + (List>) updateSourceAndMetadata.get("my_embedding_field") + ); + } + + public void testExecute_with_text_update_skip_existing_flag_successful() { + Map ingestSourceAndMetadata = getIngestDocument(); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new 
HashMap<>()); + Map updateSourceAndMetadata = getIngestDocument(); + updateSourceAndMetadata.put("my_text_field", "newValue"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(true); + Map ingestInferenceMap = Map.of("inputText", "value2", "inputImage", "base64_of_image_1234567890"); + MapInferenceRequest ingestRequest = MapInferenceRequest.builder().modelId("mockModelId").inputObjects(ingestInferenceMap).build(); + Map updateInferenceMap = Map.of("inputText", "newValue", "inputImage", "base64_of_image_1234567890"); + MapInferenceRequest updateRequest = MapInferenceRequest.builder().modelId("mockModelId").inputObjects(updateInferenceMap).build(); + + mockUpdateVectorCreation(); + mockUpdateDocument(ingestDocument); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentencesMap(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputObjects(), requests.get(0).getInputObjects()); + assertEquals(updateRequest.getInputObjects(), requests.get(1).getInputObjects()); + assertEquals( + ((List) ingestSourceAndMetadata.get("my_embedding_field")).size(), + ((List) updateSourceAndMetadata.get("my_embedding_field")).size() + ); + } + + public void testExecute_with_image_update_skip_existing_flag_successful() { + Map ingestSourceAndMetadata = getIngestDocument(); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + Map updateSourceAndMetadata = getIngestDocument(); + updateSourceAndMetadata.put("image_field", "newImage"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(true); + Map ingestInferenceMap = Map.of("inputText", "value2", "inputImage", "base64_of_image_1234567890"); + MapInferenceRequest ingestRequest = MapInferenceRequest.builder().modelId("mockModelId").inputObjects(ingestInferenceMap).build(); + Map updateInferenceMap = Map.of("inputText", "value2", "inputImage", "newImage"); + MapInferenceRequest updateRequest = MapInferenceRequest.builder().modelId("mockModelId").inputObjects(updateInferenceMap).build(); + + mockUpdateVectorCreation(); + mockUpdateDocument(ingestDocument); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentencesMap(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputObjects(), requests.get(0).getInputObjects()); + assertEquals(updateRequest.getInputObjects(), requests.get(1).getInputObjects()); + assertEquals( + ((List) ingestSourceAndMetadata.get("my_embedding_field")).size(), + ((List) updateSourceAndMetadata.get("my_embedding_field")).size() + ); + } + + public 
void testExecute_OpensearchClientAccessorThrowFail_handlerFailure() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("_id", "1"); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(openSearchClient).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(RuntimeException.class)); + } + private List> createMockVectorResult() { List> modelTensorList = new ArrayList<>(); - List number1 = ImmutableList.of(1.234f, 2.354f); - List number2 = ImmutableList.of(3.234f, 4.354f); - List number3 = ImmutableList.of(5.234f, 6.354f); - List number4 = ImmutableList.of(7.234f, 8.354f); - List number5 = ImmutableList.of(9.234f, 10.354f); - List number6 = ImmutableList.of(11.234f, 12.354f); - List number7 = ImmutableList.of(13.234f, 14.354f); + List number1 = ImmutableList.of(randomFloat(), randomFloat()); + List number2 = ImmutableList.of(randomFloat(), randomFloat()); + List number3 = ImmutableList.of(randomFloat(), randomFloat()); + List number4 = ImmutableList.of(randomFloat(), randomFloat()); + List number5 = ImmutableList.of(randomFloat(), randomFloat()); + List number6 = ImmutableList.of(randomFloat(), randomFloat()); + List number7 = ImmutableList.of(randomFloat(), randomFloat()); modelTensorList.add(number1); modelTensorList.add(number2); modelTensorList.add(number3); @@ -394,4 +526,75 @@ private Map createMaxDepthLimitExceedMap(Supplier maxDe innerMap.put("hello", ret); return innerMap; } + + private Map getIngestDocument() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("_id", "1"); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("my_text_field", "value2"); + sourceAndMetadata.put("text", ""); + sourceAndMetadata.put("image", null); + sourceAndMetadata.put("key5", Map.of("inner_field", "innerValue1")); + sourceAndMetadata.put("image_field", "base64_of_image_1234567890"); + return sourceAndMetadata; + } + + private void mockUpdateVectorCreation() { + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(1); + listener.onResponse(createMockVectorResult()); + return null; + }).doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(1); + listener.onResponse(createMockVectorResult()); + return null; + }) + .when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); + } + + private void mockUpdateDocument(IngestDocument ingestDocument) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mockEmptyGetResponse()); // returns empty result for ingest action + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(convertToGetResponse(ingestDocument)); // returns previously ingested document for update action + return null; + }).when(openSearchClient).execute(isA(GetAction.class), 
isA(GetRequest.class), isA(ActionListener.class)); + } + + protected GetResponse convertToGetResponse(IngestDocument ingestDocument) throws IOException { + String index = ingestDocument.getSourceAndMetadata().get("_index").toString(); + String id = ingestDocument.getSourceAndMetadata().get("_id").toString(); + Map source = ingestDocument.getSourceAndMetadata(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.map(source); + BytesReference bytes = BytesReference.bytes(builder); + GetResult result = new GetResult(index, id, 0, 1, 1, true, bytes, null, null); + return new GetResponse(result); + } + + protected GetResponse mockEmptyGetResponse() throws IOException { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("_index", "my_index") + .field("_id", "1") + .field("found", false) + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + return GetResponse.fromXContent(contentParser); + } + + private void verifyEqualEmbedding(List> insertVectors, List> updateVectors) { + assertEquals(insertVectors.size(), updateVectors.size()); + for (int i = 0; i < insertVectors.size(); i++) { + for (int j = 0; j < insertVectors.get(i).size(); j++) { + assertEquals(insertVectors.get(i).get(j).floatValue(), updateVectors.get(i).get(j).floatValue(), 0.0000001f); + } + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java index cfb0803a6..361b88d73 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -11,6 +11,7 @@ import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.SKIP_EXISTING; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; import java.util.HashMap; @@ -23,12 +24,14 @@ import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; +import org.opensearch.transport.client.OpenSearchClient; public class TextImageEmbeddingProcessorFactoryTests extends OpenSearchTestCase { @SneakyThrows - public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { + public void testTextImageEmbeddingProcessor_whenAllParamsPassed_thenSuccessful() { TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(OpenSearchClient.class), mock(MLCommonsClientAccessor.class), mock(Environment.class), mock(ClusterService.class) @@ -42,6 +45,7 @@ public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { config.put(MODEL_ID_FIELD, "1234567678"); config.put(EMBEDDING_FIELD, "embedding_field"); config.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "my_image_field")); + config.put(SKIP_EXISTING, true); TextImageEmbeddingProcessor inferenceProcessor = (TextImageEmbeddingProcessor) textImageEmbeddingProcessorFactory.create( processorFactories, tag, @@ -53,8 +57,9 @@ public void 
testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { } @SneakyThrows - public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { + public void testTextImageEmbeddingProcessor_whenOnlyOneParamSet_thenSuccessful() { TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(OpenSearchClient.class), mock(MLCommonsClientAccessor.class), mock(Environment.class), mock(ClusterService.class) @@ -92,8 +97,9 @@ public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { } @SneakyThrows - public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() { + public void testTextImageEmbeddingProcessor_whenMixOfParamsOrEmptyParams_thenFail() { TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(OpenSearchClient.class), mock(MLCommonsClientAccessor.class), mock(Environment.class), mock(ClusterService.class) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/optimization/TextImageEmbeddingInferenceFilterTests.java b/src/test/java/org/opensearch/neuralsearch/processor/optimization/TextImageEmbeddingInferenceFilterTests.java new file mode 100644 index 000000000..fb3535c3f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/optimization/TextImageEmbeddingInferenceFilterTests.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.optimization; + +import org.junit.Before; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +public class TextImageEmbeddingInferenceFilterTests extends OpenSearchTestCase { + + private Map knnMap; + private Map existingSourceAndMetadataMap; + private IngestDocument ingestDocument; + private TextImageEmbeddingInferenceFilter textImageEmbeddingInferenceFilter; + private final String embeddingField = "vector_embedding"; + + @Before + public void setup() { + textImageEmbeddingInferenceFilter = new TextImageEmbeddingInferenceFilter(); + knnMap = new HashMap<>(); + existingSourceAndMetadataMap = new HashMap<>(); + existingSourceAndMetadataMap.put("image_description", "orange desk"); + existingSourceAndMetadataMap.put("image_binary", "base64_of_orange_desk_image"); + existingSourceAndMetadataMap.put(embeddingField, Arrays.asList(0.1, 0.2, 0.3)); + ingestDocument = new IngestDocument(new HashMap<>(), new HashMap<>()); + } + + public void test_filterAndCopyExistingEmbeddings_TextAndImageUnchanged_ShouldCopyEmbedding() { + knnMap.put("image_description", "orange desk"); + knnMap.put("image_binary", "base64_of_orange_desk_image"); + + Map result = textImageEmbeddingInferenceFilter.filterAndCopyExistingEmbeddings( + ingestDocument, + existingSourceAndMetadataMap, + knnMap, + embeddingField + ); + assertTrue(result.isEmpty()); + assertEquals(existingSourceAndMetadataMap.get(embeddingField), ingestDocument.getSourceAndMetadata().get(embeddingField)); + } + + public void test_filterAndCopyExistingEmbeddings_TextChanged_ShouldNotCopyEmbedding() { + knnMap.put("image_description", "blue desk"); + knnMap.put("image_binary", "base64_of_orange_desk_image"); + + Map result = textImageEmbeddingInferenceFilter.filterAndCopyExistingEmbeddings( + ingestDocument, + existingSourceAndMetadataMap, + knnMap, + embeddingField + ); + assertEquals(result, knnMap); + 
assertNull(ingestDocument.getSourceAndMetadata().get(embeddingField)); + } + + public void test_filterAndCopyExistingEmbeddings_ImageChanged_ShouldNotCopyEmbedding() { + knnMap.put("image_description", "orange desk"); + knnMap.put("image_binary", "base64_of_blue_desk_image"); + + Map result = textImageEmbeddingInferenceFilter.filterAndCopyExistingEmbeddings( + ingestDocument, + existingSourceAndMetadataMap, + knnMap, + embeddingField + ); + assertEquals(result, knnMap); + assertNull(ingestDocument.getSourceAndMetadata().get(embeddingField)); + } + + public void test_filterAndCopyExistingEmbeddings_EmbeddingDoesNotExist_ShouldNotCopyEmbedding() { + knnMap.put("image_description", "orange desk"); + knnMap.put("image_binary", "base64_of_blue_desk_image"); + existingSourceAndMetadataMap.remove(embeddingField); + Map result = textImageEmbeddingInferenceFilter.filterAndCopyExistingEmbeddings( + ingestDocument, + existingSourceAndMetadataMap, + knnMap, + embeddingField + ); + assertEquals(result, knnMap); + assertNull(ingestDocument.getSourceAndMetadata().get(embeddingField)); + } +} diff --git a/src/test/resources/processor/PipelineForTextImageEmbeddingWithSkipExistingProcessorConfiguration.json b/src/test/resources/processor/PipelineForTextImageEmbeddingWithSkipExistingProcessorConfiguration.json new file mode 100644 index 000000000..fa5bdcaca --- /dev/null +++ b/src/test/resources/processor/PipelineForTextImageEmbeddingWithSkipExistingProcessorConfiguration.json @@ -0,0 +1,16 @@ +{ + "description": "text image embedding pipeline", + "processors": [ + { + "text_image_embedding": { + "model_id": "%s", + "embedding": "passage_embedding", + "field_map": { + "text": "passage_text", + "image": "passage_image" + }, + "skip_existing": true + } + } + ] +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 841e8af8b..3aae14c37 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -108,6 +108,8 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { "processor/SparseEncodingPipelineConfigurationWithSkipExisting.json", ProcessorType.TEXT_IMAGE_EMBEDDING, "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json", + ProcessorType.TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING, + "processor/PipelineForTextImageEmbeddingWithSkipExistingProcessorConfiguration.json", ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, "processor/PipelineConfigurationWithNestedFieldsMapping.json", ProcessorType.TEXT_EMBEDDING_WITH_SKIP_EXISTING, @@ -1964,6 +1966,7 @@ protected enum ProcessorType { TEXT_EMBEDDING_WITH_SKIP_EXISTING, TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING, TEXT_IMAGE_EMBEDDING, + TEXT_IMAGE_EMBEDDING_WITH_SKIP_EXISTING, SPARSE_ENCODING, SPARSE_ENCODING_WITH_SKIP_EXISTING, SPARSE_ENCODING_PRUNE
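For context on how the new option is exercised outside the test fixtures, the sketch below creates a pipeline with skip_existing enabled and writes the same document twice; on the second write the processor fetches the stored copy, finds the text and image unchanged, and copies the existing embedding instead of calling the model. This is an illustrative sketch only: the endpoint, index name, and model id are placeholders, it assumes the index already exists with a knn_vector mapping for passage_embedding (as in the IT's IndexMappings.json), and it uses the OpenSearch low-level REST client rather than anything added in this change.

```java
import org.apache.http.HttpHost;
import org.opensearch.client.Request;
import org.opensearch.client.RestClient;

public final class SkipExistingPipelineExample {
    public static void main(String[] args) throws Exception {
        // placeholder endpoint; adjust host/port/scheme for your cluster
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            // 1. Create a text_image_embedding pipeline with skip_existing enabled
            Request pipeline = new Request("PUT", "/_ingest/pipeline/nlp-image-pipeline");
            pipeline.setJsonEntity("{\n"
                + "  \"processors\": [{\n"
                + "    \"text_image_embedding\": {\n"
                + "      \"model_id\": \"<your model id>\",\n"
                + "      \"embedding\": \"passage_embedding\",\n"
                + "      \"field_map\": { \"text\": \"passage_text\", \"image\": \"passage_image\" },\n"
                + "      \"skip_existing\": true\n"
                + "    }\n"
                + "  }]\n"
                + "}");
            client.performRequest(pipeline);

            // 2. First write: inference runs and passage_embedding is generated
            Request ingest = new Request("PUT", "/my-index/_doc/1?pipeline=nlp-image-pipeline");
            ingest.setJsonEntity("{ \"passage_text\": \"This is a good day\", \"passage_image\": \"<base64>\" }");
            client.performRequest(ingest);

            // 3. Re-writing the same text and image: the processor fetches the stored document,
            //    sees that neither field changed, and reuses the existing embedding instead of
            //    calling the model again
            client.performRequest(ingest);
        }
    }
}
```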