diff --git a/CHANGELOG.md b/CHANGELOG.md index 90ec7541c3..413110cb9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ All notable changes to this project are documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). See the [CONTRIBUTING guide](./CONTRIBUTING.md#Changelog) for instructions on how to add changelog entries. ## [Unreleased 3.3](https://github.com/opensearch-project/k-NN/compare/main...HEAD) +### Features +* Support native Maximal Marginal Relevance [#2868](https://github.com/opensearch-project/k-NN/pull/2868) ### Maintenance * Replace commons-lang with org.apache.commons:commons-lang3 [#2863](https://github.com/opensearch-project/k-NN/pull/2863) * Bump OpenSearch-Protobufs to 0.13.0 [#2833](https://github.com/opensearch-project/k-NN/pull/2833) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index e97ad6d224..bb71385ec3 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -191,4 +191,12 @@ public class KNNConstants { public static final Integer INDEX_THREAD_QUANTITY_DEFAULT_LARGE = 4; public static final Integer INDEX_THREAD_QUANTITY_DEFAULT_SMALL = 1; + // mmr + public static final String MMR = "mmr"; + public static final String DIVERSITY = "diversity"; + public static final String CANDIDATES = "candidates"; + public static final String VECTOR_FIELD_PATH = "vector_field_path"; + public static final String VECTOR_FIELD_DATA_TYPE = "vector_field_data_type"; + public static final String VECTOR_FIELD_SPACE_TYPE = "vector_field_space_type"; + public static final String MMR_RERANK_CONTEXT = "mmr.rerank_context"; } diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 
67c1f27ecb..9aa426cf5f 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -188,7 +188,7 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { public static SpaceType DEFAULT = L2; public static SpaceType DEFAULT_BINARY = HAMMING; - private static final String[] VALID_VALUES = Arrays.stream(SpaceType.values()) + public static final String[] VALID_VALUES = Arrays.stream(SpaceType.values()) .filter(space -> space != SpaceType.UNDEFINED) .map(SpaceType::getValue) .collect(Collectors.toList()) diff --git a/src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java b/src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java index 31dfb18b51..d1269d62f7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java +++ b/src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java @@ -6,7 +6,6 @@ package org.opensearch.knn.index.engine; import org.apache.logging.log4j.util.Strings; -import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperParsingException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -38,15 +37,11 @@ private SpaceTypeResolver() {} public SpaceType resolveSpaceType( final KNNMethodContext knnMethodContext, final String topLevelSpaceTypeString, - final Settings indexSettings, final VectorDataType vectorDataType ) { SpaceType methodSpaceType = getSpaceTypeFromMethodContext(knnMethodContext); SpaceType topLevelSpaceType = getSpaceTypeFromString(topLevelSpaceTypeString); - // If we failed to find space type from both method context and top level - // 1. We try to get it from index setting, which is a relic of legacy. - // 2. Otherwise, we return a default one. 
if (isSpaceTypeConfigured(methodSpaceType) == false && isSpaceTypeConfigured(topLevelSpaceType) == false) { return getDefaultSpaceType(vectorDataType); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 7ebc84818e..4048df2b99 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -424,7 +424,6 @@ public Mapper.Builder parse(String name, Map node, ParserCont final SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType( builder.originalParameters.getKnnMethodContext(), builder.topLevelSpaceType.get(), - parserContext.getSettings(), builder.vectorDataType.get() ); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 2ee26813ed..8c901f96aa 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -8,6 +8,7 @@ import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -96,6 +97,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder imple private final String fieldName; private final float[] vector; @Getter + @Setter private Integer k; @Getter private Float maxDistance; diff --git a/src/main/java/org/opensearch/knn/index/util/KNNClusterUtil.java b/src/main/java/org/opensearch/knn/index/util/KNNClusterUtil.java index c9c96bea50..3c9ede4bed 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNClusterUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNClusterUtil.java @@ -6,10 +6,23 @@ package org.opensearch.knn.index.util; import 
lombok.AccessLevel; +import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.opensearch.Version; +import org.opensearch.action.IndicesRequest; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.index.Index; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.search.pipeline.SearchPipelineService.ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING; /** * Class abstracts information related to underlying OpenSearch cluster @@ -20,6 +33,9 @@ public class KNNClusterUtil { private ClusterService clusterService; private static KNNClusterUtil instance; + private IndexNameExpressionResolver indexNameExpressionResolver; + /* Reassigned by the settings update consumer registered in initialize(); volatile so that threads reading it through the getter observe the most recently applied list. */ + @Getter + private volatile List enabledSystemGeneratedFactories = Collections.emptyList(); /** * Return instance of the cluster context, must be initialized first for proper usage @@ -35,9 +51,17 @@ public static synchronized KNNClusterUtil instance() { /** * Initializes instance of cluster context by injecting dependencies * @param clusterService + * @param indexNameExpressionResolver resolver used to expand the index expressions of a request into concrete indices */ - public void initialize(final ClusterService clusterService) { + public void initialize(final ClusterService clusterService, final IndexNameExpressionResolver indexNameExpressionResolver) { this.clusterService = clusterService; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.enabledSystemGeneratedFactories = clusterService.getClusterSettings().get(ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer( + ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING, + factories -> enabledSystemGeneratedFactories = factories + ); } /** @@ -55,4 +79,16 @@ public Version getClusterMinVersion() { return
Version.CURRENT; } } + + /** + * Resolves the index expressions of the given request against a single cluster state snapshot and returns the + * metadata of every concrete index they expand to. + * + * @param searchRequest request whose target indices are resolved + * @return IndexMetadata of the indices of the search request + */ + public List getIndexMetadataList(@NonNull final IndicesRequest searchRequest) { + /* Read the state once: resolving the indices and looking up their metadata against two different states could produce a null entry if an index is removed in between. */ + final org.opensearch.cluster.ClusterState clusterState = clusterService.state(); + final Index[] concreteIndices = this.indexNameExpressionResolver.concreteIndices(clusterState, searchRequest); + return Arrays.stream(concreteIndices) + .map(concreteIndex -> clusterState.metadata().index(concreteIndex)) + .collect(Collectors.toList()); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 7e84ae458c..8a3ffa165a 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -30,6 +30,7 @@ import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.index.mapper.Mapper; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.shard.IndexSettingProvider; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; @@ -88,6 +89,12 @@ import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.profile.query.KNNMetrics; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; +import org.opensearch.knn.search.extension.MMRSearchExtBuilder; + +import org.opensearch.knn.search.processor.mmr.MMRKnnQueryTransformer; +import org.opensearch.knn.search.processor.mmr.MMROverSampleProcessor; +import org.opensearch.knn.search.processor.mmr.MMRQueryTransformer; +import org.opensearch.knn.search.processor.mmr.MMRRerankProcessor; import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; import org.opensearch.knn.training.VectorReader; @@ -100,6 +107,7 @@ import org.opensearch.plugins.Plugin; import
org.opensearch.plugins.ReloadablePlugin; import org.opensearch.plugins.ScriptPlugin; +import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.plugins.SystemIndexPlugin; import org.opensearch.remoteindexbuild.client.RemoteIndexHTTPClient; @@ -109,7 +117,11 @@ import org.opensearch.script.ScriptContext; import org.opensearch.script.ScriptEngine; import org.opensearch.script.ScriptService; +import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.deciders.ConcurrentSearchRequestDecider; +import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.search.pipeline.SystemGeneratedProcessor; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; @@ -120,7 +132,9 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.concurrent.ForkJoinPool; @@ -173,7 +187,8 @@ public class KNNPlugin extends Plugin ScriptPlugin, ExtensiblePlugin, SystemIndexPlugin, - ReloadablePlugin { + ReloadablePlugin, + SearchPipelinePlugin { public static final String LEGACY_KNN_BASE_URI = "/_opendistro/_knn"; public static final String KNN_BASE_URI = "/_plugins/_knn"; @@ -181,6 +196,7 @@ public class KNNPlugin extends Plugin private KNNStats knnStats; private ClusterService clusterService; private Supplier repositoriesServiceSupplier; + private final Map> mmrQueryTransformers = new HashMap<>(); static { ForkJoinPool.commonPool().execute(() -> { @@ -237,7 +253,7 @@ public Collection createComponents( NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); KNNSettings.state().initialize(client, clusterService); - 
KNNClusterUtil.instance().initialize(clusterService); + KNNClusterUtil.instance().initialize(clusterService, indexNameExpressionResolver); ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); @@ -356,6 +372,27 @@ public void onIndexModule(IndexModule indexModule) { } } + @Override + public void loadExtensions(ExtensionLoader loader) { + // knn plugin cannot extend itself so we have to manually load the transformer implemented in knn plugin + mmrQueryTransformers.put(KNNQueryBuilder.NAME, new MMRKnnQueryTransformer()); + for (MMRQueryTransformer transformer : loader.loadExtensions(MMRQueryTransformer.class)) { + String queryName = transformer.getQueryName(); + if (mmrQueryTransformers.containsKey(queryName)) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + "Already load the MMR query transformer %s for %s query. 
Cannot load another transformer %s for it.", + mmrQueryTransformers.get(queryName).getClass().getName(), + queryName, + transformer.getClass().getName() + ) + ); + } + mmrQueryTransformers.put(queryName, transformer); + } + } + /** * Sample knn custom script * @@ -447,4 +484,26 @@ public void reload(Settings settings) { SecureString password = KNNSettings.KNN_REMOTE_BUILD_SERVER_PASSWORD_SETTING.get(settings); RemoteIndexHTTPClient.reloadAuthHeader(username, password); } + + @Override + public List> getSearchExts() { + return List.of(new SearchExtSpec(MMRSearchExtBuilder.NAME, MMRSearchExtBuilder::new, MMRSearchExtBuilder::parse)); + } + + @Override + public Map> getSystemGeneratedRequestProcessors( + Parameters parameters + ) { + return Map.of( + MMROverSampleProcessor.MMROverSampleProcessorFactory.TYPE, + new MMROverSampleProcessor.MMROverSampleProcessorFactory(parameters.client, mmrQueryTransformers) + ); + } + + @Override + public Map> getSystemGeneratedResponseProcessors( + Parameters parameters + ) { + return Map.of(MMRRerankProcessor.MMRRerankProcessorFactory.TYPE, new MMRRerankProcessor.MMRRerankProcessorFactory()); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index 3bd38f1c26..554c8e7d82 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -177,7 +177,6 @@ && ensureSpaceTypeNotSet(topLevelSpaceType)) { SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType( knnMethodContext, topLevelSpaceType.getValue(), - null, vectorDataType ); setSpaceType(knnMethodContext, resolvedSpaceType); diff --git a/src/main/java/org/opensearch/knn/search/extension/MMRSearchExtBuilder.java b/src/main/java/org/opensearch/knn/search/extension/MMRSearchExtBuilder.java new file mode 100644 index 0000000000..9cb15e61a8 --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/search/extension/MMRSearchExtBuilder.java @@ -0,0 +1,285 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.extension; + +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNClusterUtil; +import org.opensearch.knn.search.processor.mmr.MMROverSampleProcessor; +import org.opensearch.knn.search.processor.mmr.MMRRerankProcessor; +import org.opensearch.search.SearchExtBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Objects; + +import static org.opensearch.knn.common.KNNConstants.CANDIDATES; +import static org.opensearch.knn.common.KNNConstants.DIVERSITY; +import static org.opensearch.knn.common.KNNConstants.MMR; +import static org.opensearch.knn.common.KNNConstants.VECTOR_FIELD_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_FIELD_PATH; +import static org.opensearch.knn.common.KNNConstants.VECTOR_FIELD_SPACE_TYPE; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; +import static org.opensearch.search.pipeline.SearchPipelineService.ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING; + +/** + * Search extension for Maximal Marginal Relevance + */ +@AllArgsConstructor(access = AccessLevel.PRIVATE) +@Getter +public class MMRSearchExtBuilder extends SearchExtBuilder { + + public static final String NAME 
= MMR; + + // Used to control the weight of the diversity, range is from [0,1]. (diversity = 1) prioritizes maximum diversity + // which means the documents are selected just based on how different they are from already chosen documents. + public static final ParseField DIVERSITY_FIELD = new ParseField(DIVERSITY); + // Used to control how many candidates we should oversample for MMR + public static final ParseField CANDIDATES_FIELD = new ParseField(CANDIDATES); + // Path to the vector field used for MMR re-rank. Optional. If not provided we should auto resolve it based on the + // search request. + public static final ParseField VECTOR_FIELD_PATH_FIELD = new ParseField(VECTOR_FIELD_PATH); + // Data type of the vector field. Used to decide how to parse the vector field to calculate the similarity. + // Optional. If not provided we should auto resolve it from the index mapping. + public static final ParseField VECTOR_FIELD_DATA_TYPE_FIELD = new ParseField(VECTOR_FIELD_DATA_TYPE); + // Space type of the vector field which is used to decide the similarity function. Optional. If not provided we + // should auto resolve it from the index mapping. 
+ public static final ParseField VECTOR_FIELD_SPACE_TYPE_FIELD = new ParseField(VECTOR_FIELD_SPACE_TYPE); + + private Float diversity; + private Integer candidates; + private String vectorFieldPath; + private VectorDataType vectorFieldDataType; + private SpaceType spaceType; + + /** + * Collects the mmr search extension parameters, applies defaults and validates them before creating a + * {@link MMRSearchExtBuilder}. + */ + public static class Builder { + private Float diversity; + private Integer candidates; + private String vectorFieldPath; + private VectorDataType vectorFieldDataType; + private SpaceType spaceType; + + public Builder() {} + + public Builder diversity(Float diversity) { + this.diversity = diversity; + return this; + } + + public Builder candidates(Integer candidates) { + this.candidates = candidates; + return this; + } + + public Builder vectorFieldPath(String vectorFieldPath) { + this.vectorFieldPath = vectorFieldPath; + return this; + } + + /** + * Sets the vector data type from its name; the name is upper-cased before it is matched against the + * {@link VectorDataType} constants. + * + * @param vectorFieldDataType name of the vector data type + * @return this builder + * @throws IllegalArgumentException if the name is null or does not match any constant + */ + public Builder vectorFieldDataType(String vectorFieldDataType) { + try { + this.vectorFieldDataType = VectorDataType.valueOf(vectorFieldDataType.toUpperCase(Locale.ROOT)); + return this; + } catch (Exception e) { + /* Keep the original exception as the cause so the root problem is not lost. */ + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s in mmr query extension is not valid, valid values are %s.", + VECTOR_FIELD_DATA_TYPE_FIELD.getPreferredName(), + SUPPORTED_VECTOR_DATA_TYPES + ), + e + ); + } + } + + /** + * Sets the space type from its string value. + * + * @param spaceType one of the values listed in {@code SpaceType.VALID_VALUES} + * @return this builder + * @throws IllegalArgumentException if the value is not contained in {@code SpaceType.VALID_VALUES} + */ + public Builder spaceType(String spaceType) { + if (!Arrays.asList(SpaceType.VALID_VALUES).contains(spaceType)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s in mmr query extension is not valid, valid values are %s.", + VECTOR_FIELD_SPACE_TYPE_FIELD.getPreferredName(), + String.join(",", SpaceType.VALID_VALUES) + ) + ); + } + this.spaceType = SpaceType.getSpace(spaceType); + return this; + } + + /** + * Applies defaults, validates the collected values and creates the extension. + * + * @return the validated {@link MMRSearchExtBuilder} + * @throws IllegalArgumentException if any collected value is invalid + */ + public MMRSearchExtBuilder build() { + setDefault(); + validate(); + return new MMRSearchExtBuilder(diversity, candidates, vectorFieldPath, vectorFieldDataType, spaceType); + } + + /* diversity falls back to 0.5 when the caller did not provide one. */ + private void setDefault() { + if (diversity == null) { + diversity = 0.5f; + } + } + + /** + * Verifies that diversity lies within [0.0, 1.0], that candidates (when set) is positive and that + * vectorFieldPath (when set) is not empty. + * + * @throws IllegalArgumentException if any of the checks fails + */ + private void
validate() { + /* Written as a negated range check so that NaN is rejected as well. */ + if (!(diversity >= 0.0 && diversity <= 1.0)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s in mmr query extension must be between 0.0 and 1.0", DIVERSITY_FIELD.getPreferredName()) + ); + } + + /* Zero is rejected too, in line with the error message below. */ + if (candidates != null && candidates <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s in mmr query extension must be larger than 0.", CANDIDATES_FIELD.getPreferredName()) + ); + } + + if (vectorFieldPath != null && vectorFieldPath.isEmpty()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s in mmr query extension should not be an empty string.", + VECTOR_FIELD_PATH_FIELD.getPreferredName() + ) + ); + } + } + } + + public MMRSearchExtBuilder(StreamInput in) throws IOException { + diversity = in.readOptionalFloat(); + candidates = in.readOptionalVInt(); + vectorFieldPath = in.readOptionalString(); + String vectorFieldDataTypeStr = in.readOptionalString(); + if (vectorFieldDataTypeStr != null) { + vectorFieldDataType = VectorDataType.get(vectorFieldDataTypeStr); + } + String spaceTypeStr = in.readOptionalString(); + if (spaceTypeStr != null) { + spaceType = SpaceType.getSpace(spaceTypeStr); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalFloat(diversity); + out.writeOptionalVInt(candidates); + out.writeOptionalString(vectorFieldPath); + out.writeOptionalString(vectorFieldDataType == null ? null : vectorFieldDataType.getValue()); + out.writeOptionalString(spaceType == null ? 
null : spaceType.getValue()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + if (diversity != null) { + builder.field(DIVERSITY_FIELD.getPreferredName(), diversity); + } + if (candidates != null) { + builder.field(CANDIDATES_FIELD.getPreferredName(), candidates); + } + if (vectorFieldPath != null) { + builder.field(VECTOR_FIELD_PATH_FIELD.getPreferredName(), vectorFieldPath); + } + if (vectorFieldDataType != null) { + builder.field(VECTOR_FIELD_DATA_TYPE_FIELD.getPreferredName(), vectorFieldDataType.getValue()); + } + if (spaceType != null) { + builder.field(VECTOR_FIELD_SPACE_TYPE_FIELD.getPreferredName(), spaceType.getValue()); + } + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(diversity, candidates, vectorFieldPath, vectorFieldDataType, spaceType); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + EqualsBuilder equalsBuilder = new EqualsBuilder(); + MMRSearchExtBuilder other = (MMRSearchExtBuilder) obj; + equalsBuilder.append(diversity, other.diversity); + equalsBuilder.append(candidates, other.candidates); + equalsBuilder.append(vectorFieldPath, other.vectorFieldPath); + equalsBuilder.append(vectorFieldDataType, other.vectorFieldDataType); + equalsBuilder.append(spaceType, other.spaceType); + return equalsBuilder.isEquals(); + } + + public static MMRSearchExtBuilder parse(XContentParser parser) throws IOException { + ensureMMRProcessorsEnabled(); + XContentParser.Token token; + String currentFieldName = ""; + Builder builder = new Builder(); + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if 
(token.isValue()) { + if (DIVERSITY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + builder.diversity(parser.floatValue()); + } else if (CANDIDATES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + builder.candidates(parser.intValue()); + } else if (VECTOR_FIELD_PATH_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + builder.vectorFieldPath(parser.text()); + } else if (VECTOR_FIELD_DATA_TYPE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + builder.vectorFieldDataType(parser.text()); + } else if (VECTOR_FIELD_SPACE_TYPE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + builder.spaceType(parser.text()); + } else { + throw unsupportedField(parser, currentFieldName); + } + } else { + throw unsupportedField(parser, currentFieldName); + } + } + return builder.build(); + } + + private static void ensureMMRProcessorsEnabled() { + List enabledFactories = KNNClusterUtil.instance().getEnabledSystemGeneratedFactories(); + boolean isMMRProcessorsEnabled = enabledFactories.contains("*") + || (enabledFactories.contains(MMROverSampleProcessor.MMROverSampleProcessorFactory.TYPE) + && enabledFactories.contains(MMRRerankProcessor.MMRRerankProcessorFactory.TYPE)); + if (isMMRProcessorsEnabled == false) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "We need to enable [%s, %s] in the" + " cluster setting [%s] to support the mmr search extension.", + MMROverSampleProcessor.MMROverSampleProcessorFactory.TYPE, + MMRRerankProcessor.MMRRerankProcessorFactory.TYPE, + ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING.getKey() + ) + ); + } + } + + private static ParsingException unsupportedField(XContentParser parser, String fieldName) { + return new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "[%s] query extension does not support [%s]", NAME, fieldName) + ); + } + +} diff --git 
a/src/main/java/org/opensearch/knn/search/processor/mmr/MMRKnnQueryTransformer.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRKnnQueryTransformer.java new file mode 100644 index 0000000000..8ad4e1a2b1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRKnnQueryTransformer.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import lombok.NoArgsConstructor; +import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.index.query.KNNQueryBuilder; + +import static org.opensearch.knn.search.processor.mmr.MMRUtil.resolveKnnVectorFieldInfo; + +/** + * A transformer to transform the knn query for MMR + */ +@NoArgsConstructor +public class MMRKnnQueryTransformer implements MMRQueryTransformer { + + @Override + public void transform(KNNQueryBuilder queryBuilder, ActionListener listener, MMRTransformContext mmrTransformContext) { + try { + if (queryBuilder.getMaxDistance() == null && queryBuilder.getMinScore() == null) { + queryBuilder.setK(mmrTransformContext.getCandidates()); + } + + if (mmrTransformContext.isVectorFieldInfoResolved()) { + listener.onResponse(null); + return; + } + + MMRRerankContext mmrRerankContext = mmrTransformContext.getMmrRerankContext(); + String knnVectorFieldPath = queryBuilder.fieldName(); + if (knnVectorFieldPath == null) { + throw new IllegalArgumentException( + "Failed to transform the knn query for MMR. Field name of the knn query should not be null." 
+ ); + } + mmrRerankContext.setVectorFieldPath(knnVectorFieldPath); + + resolveKnnVectorFieldInfo( + knnVectorFieldPath, + mmrTransformContext.getUserProvidedSpaceType(), + mmrTransformContext.getUserProvidedVectorDataType(), + mmrTransformContext.getLocalIndexMetadataList(), + mmrTransformContext.getClient(), + ActionListener.wrap(vectorFieldInfo -> { + mmrRerankContext.setVectorDataType(vectorFieldInfo.getVectorDataType()); + mmrRerankContext.setSpaceType(vectorFieldInfo.getSpaceType()); + listener.onResponse(null); + }, listener::onFailure) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @Override + public String getQueryName() { + return KNNQueryBuilder.NAME; + } +} diff --git a/src/main/java/org/opensearch/knn/search/processor/mmr/MMROverSampleProcessor.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMROverSampleProcessor.java new file mode 100644 index 0000000000..f98497279b --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMROverSampleProcessor.java @@ -0,0 +1,402 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import lombok.AllArgsConstructor; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNClusterUtil; +import org.opensearch.knn.search.extension.MMRSearchExtBuilder; +import org.opensearch.search.fetch.StoredFieldsContext; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.ProcessorGenerationContext; +import 
org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.SystemGeneratedProcessor; +import org.opensearch.transport.RemoteClusterService; +import org.opensearch.transport.client.Client; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.MMR_RERANK_CONTEXT; +import static org.opensearch.knn.search.processor.mmr.MMRUtil.resolveKnnVectorFieldInfo; +import static org.opensearch.knn.search.processor.mmr.MMRUtil.shouldGenerateMMRProcessor; + +/** + * A system generated search request processor for MMR. It will transform the search request to oversample and also + * collect and store the info in the PipelineProcessingContext for MMR later in a search response processor. + */ +public class MMROverSampleProcessor implements SearchRequestProcessor, SystemGeneratedProcessor { + public static final String TYPE = "mmr_over_sample"; + /* The two literals are concatenated, so the second one must start with a space to keep the words apart. */ + public static final String DESCRIPTION = "This is a system generated processor that will modify the query size and" + " k of the knn query to oversample for Maximal Marginal Relevance rerank."; + private static final int DEFAULT_QUERY_SIZE_INDICATOR = -1; + private static final int DEFAULT_QUERY_SIZE = 10; + private static final int DEFAULT_OVERSAMPLE_SCALE = 3; + private final String tag; + private final boolean ignoreFailure; + private final Client client; + private final Map> mmrQueryTransformers; + + /** + * Creates the processor; every argument is stored as provided. + * + * @param tag tag of this processor instance + * @param ignoreFailure ignore-failure flag of this processor instance + * @param client client handed to the vector field info resolution + * @param mmrQueryTransformers available MMR query transformers (presumably keyed by query name - confirm against the factory) + */ + public MMROverSampleProcessor( + String tag, + boolean ignoreFailure, + Client client, + Map> mmrQueryTransformers + ) { + this.tag = tag; + this.ignoreFailure = ignoreFailure; + this.client = client; + this.mmrQueryTransformers = mmrQueryTransformers; + } + + @Override + public SearchRequest processRequest(SearchRequest searchRequest) { + throw new UnsupportedOperationException( + String.format(Locale.ROOT, "Should not try to use %s to process a search request
synchronously.", TYPE) + ); + } + + @Override + public SearchRequest processRequest(SearchRequest request, PipelineProcessingContext requestContext) { + throw new UnsupportedOperationException( + String.format(Locale.ROOT, "Should not try to use %s to process a search request synchronously.", TYPE) + ); + } + + @Override + public void processRequestAsync( + SearchRequest request, + PipelineProcessingContext requestContext, + ActionListener requestListener + ) { + try { + if (request == null || request.source() == null || request.source().ext() == null) { + throw new IllegalStateException( + String.format(Locale.ROOT, "Search request passed to %s search request processor must have mmr search extension.", TYPE) + ); + } + + // Find the MMRSearchExtBuilder. We must have one. + MMRSearchExtBuilder mmrSearchExtBuilder = extractMMRExtension(request); + + String[] allTargetIndices = request.indices(); + String remoteSeparator = String.valueOf(RemoteClusterService.REMOTE_CLUSTER_INDEX_SEPARATOR); + List remoteIndices = splitIndices(allTargetIndices, remoteSeparator, true); + List localIndices = splitIndices(allTargetIndices, remoteSeparator, false); + + MMRRerankContext mmrRerankContext = new MMRRerankContext(); + mmrRerankContext.setDiversity(mmrSearchExtBuilder.getDiversity()); + + validateForRemoteIndices(mmrSearchExtBuilder, remoteIndices); + + int candidates = computeCandidatesAndSetRequestSize(mmrRerankContext, request, mmrSearchExtBuilder); + // ensure we have the vector in the _source so that the MMRRerankProcessor can use it for mmr rerank + preserveAndEnableFullSource(request, mmrRerankContext); + + OriginalIndices localIndicesSearchRequest = new OriginalIndices(localIndices.toArray(String[]::new), request.indicesOptions()); + List localIndexMetadataList = getLocalIndexMetadata(localIndicesSearchRequest); + String userProvidedVectorFieldPath = mmrSearchExtBuilder.getVectorFieldPath(); + VectorDataType userProvidedVectorDataType = 
mmrSearchExtBuilder.getVectorFieldDataType(); + SpaceType userProvidedSpaceType = mmrSearchExtBuilder.getSpaceType(); + MMRTransformContext mmrTransformContext = new MMRTransformContext( + candidates, + mmrRerankContext, + localIndexMetadataList, + remoteIndices, + userProvidedSpaceType, + userProvidedVectorFieldPath, + userProvidedVectorDataType, + client, + false + ); + + if (userProvidedVectorFieldPath != null) { + processWithUserProvidedVectorFieldPath(request, requestContext, requestListener, mmrTransformContext); + return; + } + transformQueryForMMR(request, requestListener, mmrTransformContext, requestContext); + } catch (Exception e) { + requestListener.onFailure(e); + } + } + + private void processWithUserProvidedVectorFieldPath( + SearchRequest request, + PipelineProcessingContext requestContext, + ActionListener requestListener, + MMRTransformContext mmrTransformContext + ) { + try { + String userProvidedVectorFieldPath = mmrTransformContext.getUserProvidedVectorFieldPath(); + SpaceType userProvidedSpaceType = mmrTransformContext.getUserProvidedSpaceType(); + VectorDataType userProvidedVectorDataType = mmrTransformContext.getUserProvidedVectorDataType(); + List localIndexMetadataList = mmrTransformContext.getLocalIndexMetadataList(); + MMRRerankContext mmrRerankContext = mmrTransformContext.getMmrRerankContext(); + + mmrRerankContext.setVectorFieldPath(userProvidedVectorFieldPath); + + resolveKnnVectorFieldInfo( + userProvidedVectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + localIndexMetadataList, + client, + ActionListener.wrap(vectorFieldInfo -> { + mmrRerankContext.setVectorDataType(vectorFieldInfo.getVectorDataType()); + mmrRerankContext.setSpaceType(vectorFieldInfo.getSpaceType()); + mmrTransformContext.setVectorFieldInfoResolved(true); + transformQueryForMMR(request, requestListener, mmrTransformContext, requestContext); + }, requestListener::onFailure) + ); + } catch (Exception e) { + requestListener.onFailure(e); + } + } 
+ + private MMRSearchExtBuilder extractMMRExtension(SearchRequest request) { + return request.source() + .ext() + .stream() + .filter(MMRSearchExtBuilder.class::isInstance) + .map(MMRSearchExtBuilder.class::cast) + .findFirst() + .orElseThrow( + () -> new IllegalStateException( + String.format(Locale.ROOT, "SearchRequest passed to %s processor must have an MMRSearchExtBuilder", TYPE) + ) + ); + } + + private List splitIndices(String[] indices, String separator, boolean remote) { + return Arrays.stream(indices).filter(index -> (index.contains(separator)) == remote).toList(); + } + + // For remote indices it is not cheap to pull the info from the remote cluster to resolve the space type and the + // vector data type so we require users to provide this info. + private void validateForRemoteIndices(MMRSearchExtBuilder mmrSearchExtBuilder, List remoteIndices) { + if (remoteIndices.isEmpty()) { + return; + } + + String indicesString = String.join(",", remoteIndices); + + SpaceType spaceType = mmrSearchExtBuilder.getSpaceType(); + if (spaceType == null) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s is required in the MMR query extension when querying remote indices [%s].", + MMRSearchExtBuilder.VECTOR_FIELD_SPACE_TYPE_FIELD.getPreferredName(), + indicesString + ) + ); + } + + VectorDataType vectorDataType = mmrSearchExtBuilder.getVectorFieldDataType(); + if (vectorDataType == null) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s is required in the MMR query extension when querying remote indices [%s].", + MMRSearchExtBuilder.VECTOR_FIELD_DATA_TYPE_FIELD.getPreferredName(), + indicesString + ) + ); + } + } + + private List getLocalIndexMetadata(OriginalIndices localIndicesSearchRequest) { + return KNNClusterUtil.instance().getIndexMetadataList(localIndicesSearchRequest); + } + + private int computeCandidatesAndSetRequestSize( + MMRRerankContext mmrRerankContext, + SearchRequest request, + MMRSearchExtBuilder 
mmrSearchExtBuilder + ) { + int originalQuerySize = request.source().size(); + // If the query size is not set by the user then it will be -1, and later we will use the default value 10 + if (originalQuerySize == DEFAULT_QUERY_SIZE_INDICATOR) { + originalQuerySize = DEFAULT_QUERY_SIZE; + } + mmrRerankContext.setOriginalQuerySize(originalQuerySize); + + Integer candidates = mmrSearchExtBuilder.getCandidates(); + if (candidates == null) { + candidates = DEFAULT_OVERSAMPLE_SCALE * originalQuerySize; // default candidates + } + + request.source().size(candidates); + return candidates; + } + + private void preserveAndEnableFullSource(SearchRequest request, MMRRerankContext mmrContext) { + FetchSourceContext currentSourceContext = request.source().fetchSource(); + StoredFieldsContext storedFieldsContext = request.source().storedFields(); + + if (storedFieldsContext != null) { + if (isStoredFieldsDisabled(storedFieldsContext)) { + handleDisabledStoredFields(request, mmrContext, currentSourceContext); + return; + } + + if (isSourceNotExplicitlySet(currentSourceContext)) { + // when stored_fields is defined and _source is not defined we will not fetch _source so need to + // temporarily enable it for mmr. 
+ enableSourceTemporarily(request, mmrContext); + return; + } + } + + if (isAlreadyFetchingFullSource(currentSourceContext)) { + return; + } + + preserveAndEnableFullSourceFetch(request, mmrContext, currentSourceContext); + } + + private boolean isStoredFieldsDisabled(StoredFieldsContext context) { + return context.fetchFields() == false; + } + + private boolean isSourceNotExplicitlySet(FetchSourceContext sourceContext) { + return sourceContext == null; + } + + private boolean isAlreadyFetchingFullSource(FetchSourceContext sourceContext) { + if (sourceContext == null) { + return true; + } + boolean fetchingAll = sourceContext.fetchSource(); + boolean noIncludes = sourceContext.includes().length == 0; + boolean noExcludes = sourceContext.excludes().length == 0; + return fetchingAll && noIncludes && noExcludes; + } + + private void handleDisabledStoredFields(SearchRequest request, MMRRerankContext mmrContext, FetchSourceContext currentSourceContext) { + if (currentSourceContext != null) { + // stored_fields = _none_ + explicit _source → invalid + throw new IllegalArgumentException("[stored_fields] cannot be disabled if [_source] is requested"); + } + // stored_fields = _none_ + no _source defined → temporarily enable _source + mmrContext.setOriginalFetchSourceContext(new FetchSourceContext(false)); + request.source().storedFields(StoredFieldsContext.fromList(Collections.emptyList())); + request.source().fetchSource(new FetchSourceContext(true)); + } + + private void enableSourceTemporarily(SearchRequest request, MMRRerankContext mmrContext) { + mmrContext.setOriginalFetchSourceContext(new FetchSourceContext(false)); + request.source().fetchSource(new FetchSourceContext(true)); + } + + private void preserveAndEnableFullSourceFetch( + SearchRequest request, + MMRRerankContext mmrContext, + FetchSourceContext currentSourceContext + ) { + mmrContext.setOriginalFetchSourceContext(currentSourceContext); + request.source().fetchSource(new FetchSourceContext(true)); + } + + 
private void transformQueryForMMR( + SearchRequest request, + ActionListener requestListener, + MMRTransformContext mmrTransformationContext, + PipelineProcessingContext requestContext + ) { + QueryBuilder queryBuilder = request.source().query(); + if (queryBuilder == null) { + throw new IllegalArgumentException("Query builder must not be null to do Maximal Marginal Relevance rerank."); + } + + @SuppressWarnings("unchecked") + MMRQueryTransformer transformer = (MMRQueryTransformer) mmrQueryTransformers.get( + queryBuilder.getWriteableName() + ); + if (transformer == null) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Maximal Marginal Relevance rerank doesn't support the query type [%s]", + queryBuilder.getWriteableName() + ) + ); + } + + transformer.transform(queryBuilder, new ActionListener<>() { + @Override + public void onResponse(Void unused) { + requestContext.setAttribute(MMR_RERANK_CONTEXT, mmrTransformationContext.getMmrRerankContext()); + requestListener.onResponse(request); + } + + @Override + public void onFailure(Exception e) { + requestListener.onFailure(e); + } + }, mmrTransformationContext); + } + + // This processor will be executed post the user defined search request processor if there is any. 
+ @Override + public ExecutionStage getExecutionStage() { + return ExecutionStage.POST_USER_DEFINED; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return DESCRIPTION; + } + + @Override + public boolean isIgnoreFailure() { + return ignoreFailure; + } + + @AllArgsConstructor + public static class MMROverSampleProcessorFactory implements SystemGeneratedFactory { + public static final String TYPE = "mmr_over_sample_factory"; + private final Client client; + private final Map> mmrQueryTransformers; + + @Override + public boolean shouldGenerate(ProcessorGenerationContext processorGenerationContext) { + return shouldGenerateMMRProcessor(processorGenerationContext); + } + + @Override + public SearchRequestProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) throws Exception { + return new MMROverSampleProcessor(tag, ignoreFailure, client, mmrQueryTransformers); + } + } +} diff --git a/src/main/java/org/opensearch/knn/search/processor/mmr/MMRQueryTransformer.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRQueryTransformer.java new file mode 100644 index 0000000000..41ab2b000c --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRQueryTransformer.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.QueryBuilder; + +public interface MMRQueryTransformer { + /** + * Transform the queryBuilder to oversample for MMR. + * Also need to figure out the vector field path and the space type and set them in the MMRProcessingContext for + * response processor to consume. 
+ * @param queryBuilder + * @param listener + * @param mmrTransformContext {@link MMRTransformContext} + */ + void transform(T queryBuilder, ActionListener listener, MMRTransformContext mmrTransformContext); + + /** + * @return The name of the query which will be used to find the transformer. + */ + String getQueryName(); +} diff --git a/src/main/java/org/opensearch/knn/search/processor/mmr/MMRRerankContext.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRRerankContext.java new file mode 100644 index 0000000000..53f5e5fce4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRRerankContext.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import lombok.Data; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.search.fetch.subphase.FetchSourceContext; + +import java.util.Map; + +/** + * A DTO to hold the context for MMR rerank + */ +@Data +@NoArgsConstructor +public class MMRRerankContext { + private Integer originalQuerySize; + private Float diversity; + private FetchSourceContext originalFetchSourceContext; + private SpaceType spaceType; + // The default path if we cannot find the path based on the index + private String vectorFieldPath; + private VectorDataType vectorDataType; + // To support the case that we have different vector field paths in different indices + private Map indexToVectorFieldPathMap; +} diff --git a/src/main/java/org/opensearch/knn/search/processor/mmr/MMRRerankProcessor.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRRerankProcessor.java new file mode 100644 index 0000000000..9922ffd327 --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRRerankProcessor.java @@ -0,0 +1,307 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 
+ */ + +package org.opensearch.knn.search.processor.mmr; + +import lombok.AllArgsConstructor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.knn.index.KNNVectorSimilarityFunction; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.ProcessorGenerationContext; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.search.pipeline.SystemGeneratedProcessor; +import org.opensearch.search.profile.SearchProfileShardResults; + +import java.io.IOException; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import static org.opensearch.knn.common.KNNConstants.MMR_RERANK_CONTEXT; +import static org.opensearch.knn.search.processor.mmr.MMRUtil.extractVectorFromHit; +import static org.opensearch.knn.search.processor.mmr.MMRUtil.shouldGenerateMMRProcessor; + +/** + * A system generated search response processor that rerank the response based on the Maximal Marginal Relevance + */ +@AllArgsConstructor +public class MMRRerankProcessor implements SearchResponseProcessor, SystemGeneratedProcessor { + public static final String TYPE = "mmr_rerank"; + public static final String DESCRIPTION = "This is a system generated processor that will rerank the response based" + + "on Maximal Marginal Relevance."; + private final String 
tag; + private final boolean ignoreFailure; + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) { + throw new UnsupportedOperationException( + String.format(Locale.ROOT, "Should not try to use %s to process a search response without PipelineProcessingContext.", TYPE) + ); + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse searchResponse, PipelineProcessingContext requestContext) + throws IOException { + long startNanos = System.nanoTime(); + + if (isEmptyResponse(searchResponse)) { + return searchResponse; + } + + final MMRRerankContext mmrContext = requireMMRContext(requestContext); + final KNNVectorSimilarityFunction similarityFunction = mmrContext.getSpaceType().getKnnVectorSimilarityFunction(); + final int originalQuerySize = mmrContext.getOriginalQuerySize(); + final float diversity = mmrContext.getDiversity(); + final boolean isFloatVector = VectorDataType.FLOAT.equals(mmrContext.getVectorDataType()); + + final List candidates = new ArrayList<>(List.of(searchResponse.getHits().getHits())); + final Map docVectors = extractVectors( + candidates, + mmrContext.getVectorFieldPath(), + mmrContext.getIndexToVectorFieldPathMap(), + isFloatVector + ); + + final List selected = selectHitsWithMMR( + candidates, + docVectors, + similarityFunction, + diversity, + originalQuerySize, + isFloatVector + ); + + applyFetchSourceFilterIfNeeded(selected, mmrContext); + + final float maxSelectedScore = selected.stream().map(SearchHit::getScore).max(Float::compare).orElse(Float.NEGATIVE_INFINITY); + + final SearchHits newHits = new SearchHits( + selected.toArray(new SearchHit[0]), + searchResponse.getHits().getTotalHits(), + maxSelectedScore, + searchResponse.getHits().getSortFields(), + searchResponse.getHits().getCollapseField(), + searchResponse.getHits().getCollapseValues() + ); + + final SearchResponseSections newSections = new SearchResponseSections( + newHits, + 
searchResponse.getAggregations(), + searchResponse.getSuggest(), + searchResponse.isTimedOut(), + searchResponse.isTerminatedEarly(), + new SearchProfileShardResults(searchResponse.getProfileResults()), + searchResponse.getNumReducePhases(), + searchResponse.getInternalResponse().getSearchExtBuilders() + ); + + long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos); + long newTookMillis = searchResponse.getTook().millis() + elapsedMillis; + + return new SearchResponse( + newSections, + searchResponse.getScrollId(), + searchResponse.getTotalShards(), + searchResponse.getSuccessfulShards(), + searchResponse.getSkippedShards(), + newTookMillis, + searchResponse.getPhaseTook(), + searchResponse.getShardFailures(), + searchResponse.getClusters(), + searchResponse.pointInTimeId() + ); + } + + private boolean isEmptyResponse(SearchResponse response) { + return response == null + || response.getHits() == null + || response.getHits().getHits() == null + || response.getHits().getHits().length == 0; + } + + private MMRRerankContext requireMMRContext(PipelineProcessingContext requestContext) { + Object attr = requestContext.getAttribute(MMR_RERANK_CONTEXT); + if (attr == null) { + throw new IllegalStateException("MMR rerank context cannot be null"); + } + + final MMRRerankContext ctx = (MMRRerankContext) attr; + + if (ctx.getSpaceType() == null) { + throw new IllegalStateException("Space type in MMR rerank context cannot be null"); + } + if (ctx.getOriginalQuerySize() == null) { + throw new IllegalStateException("Original query size in MMR rerank context cannot be null"); + } + if (ctx.getDiversity() == null) { + throw new IllegalStateException("Diversity in MMR rerank context cannot be null"); + } + if (ctx.getVectorDataType() == null) { + throw new IllegalStateException("Vector data type in MMR rerank context cannot be null"); + } + + return ctx; + } + + private Map extractVectors( + List hits, + String defaultVectorFieldPath, + Map 
indexToVectorFieldPathMap, + boolean isFloatVector + ) { + Map vectors = new ConcurrentHashMap<>(); + + hits.parallelStream().forEach(hit -> { + String vectorPath = defaultVectorFieldPath; + + if (indexToVectorFieldPathMap != null) { + String overridePath = indexToVectorFieldPathMap.get(hit.getIndex()); + if (overridePath != null && !overridePath.isBlank()) { + vectorPath = overridePath; + } + } + + Object embedding = extractVectorFromHit(hit.getSourceAsMap(), vectorPath, hit.getId(), isFloatVector); + vectors.put(hit.docId(), embedding); + }); + + return vectors; + } + + private List selectHitsWithMMR( + List candidates, + Map docVectors, + KNNVectorSimilarityFunction similarityFunction, + float diversity, + int targetSize, + boolean isFloatVector + ) { + List selected = new ArrayList<>(); + Map simCache = new ConcurrentHashMap<>(); + + while (selected.size() < targetSize && !candidates.isEmpty()) { + + Optional bestCandidateOpt = candidates.parallelStream().max(Comparator.comparingDouble(candidate -> { + int candidateId = candidate.docId(); + float maxSimToSelected = 0.0f; + + for (SearchHit sel : selected) { + int selId = sel.docId(); + long key = cacheKey(candidateId, selId); + long symKey = cacheKey(selId, candidateId); + + float sim = simCache.computeIfAbsent(key, k -> { + if (isFloatVector) { + return similarityFunction.compare((float[]) docVectors.get(candidateId), (float[]) docVectors.get(selId)); + } else { + return similarityFunction.compare((byte[]) docVectors.get(candidateId), (byte[]) docVectors.get(selId)); + } + }); + + simCache.putIfAbsent(symKey, sim); + maxSimToSelected = Math.max(maxSimToSelected, sim); + } + + return (1 - diversity) * candidate.getScore() - diversity * maxSimToSelected; + })); + + if (bestCandidateOpt.isPresent()) { + SearchHit bestHit = bestCandidateOpt.get(); + selected.add(bestHit); + candidates.remove(bestHit); + } + } + + return selected; + } + + private void applyFetchSourceFilterIfNeeded(List hits, MMRRerankContext 
mmrContext) throws IOException { + final FetchSourceContext fetchSourceContext = mmrContext.getOriginalFetchSourceContext(); + if (fetchSourceContext == null) { + return; + } + // if fetch source is false we directly remove the whole _source + if (fetchSourceContext.fetchSource() == false) { + for (SearchHit hit : hits) { + hit.sourceRef(null); + } + return; + } + + final Function, Map> filter = fetchSourceContext.getFilter(); + for (SearchHit hit : hits) { + Map filtered = filter.apply(hit.getSourceAsMap()); + hit.sourceRef(BytesReference.bytes(XContentFactory.jsonBuilder().map(filtered))); + } + } + + private long cacheKey(int id1, int id2) { + return ((long) id1 << 32) | (id2 & 0xffffffffL); + } + + // This processor will be executed pre the user defined search request processor if there is any. Since + // we oversample before so it is better to execute this processor to rerank and reduce the response to the + // original query size before executing other user defined search response processors. 
+ @Override + public ExecutionStage getExecutionStage() { + return ExecutionStage.PRE_USER_DEFINED; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return DESCRIPTION; + } + + @Override + public boolean isIgnoreFailure() { + return ignoreFailure; + } + + public static class MMRRerankProcessorFactory implements SystemGeneratedFactory { + public static final String TYPE = "mmr_rerank_factory"; + + @Override + public boolean shouldGenerate(ProcessorGenerationContext context) { + return shouldGenerateMMRProcessor(context); + } + + @Override + public SearchResponseProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) throws Exception { + return new MMRRerankProcessor(tag, ignoreFailure); + } + } +} diff --git a/src/main/java/org/opensearch/knn/search/processor/mmr/MMRTransformContext.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRTransformContext.java new file mode 100644 index 0000000000..979be41fcf --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRTransformContext.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NonNull; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.transport.client.Client; + +import java.util.List; + +/** + * A DTO to hold the context for MMR query transformer to transform the request + */ +@Data +@AllArgsConstructor +public class MMRTransformContext { + @NonNull + private final Integer candidates; + // During transform, we will also collect some info for MMR rerank processor 
to use later + @NonNull + private final MMRRerankContext mmrRerankContext; + // During transform, we need to figure out the knn_vector space type based on the index metadata + @NonNull + private final List localIndexMetadataList; + @NonNull + private final List remoteIndices; + private final SpaceType userProvidedSpaceType; + private final String userProvidedVectorFieldPath; + private final VectorDataType userProvidedVectorDataType; + private final Client client; + private boolean isVectorFieldInfoResolved; +} diff --git a/src/main/java/org/opensearch/knn/search/processor/mmr/MMRUtil.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRUtil.java new file mode 100644 index 0000000000..e4d828b406 --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRUtil.java @@ -0,0 +1,549 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.mapper.ObjectMapper; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.plugin.transport.GetModelAction; +import org.opensearch.knn.plugin.transport.GetModelRequest; +import org.opensearch.knn.plugin.transport.GetModelResponse; +import org.opensearch.knn.search.extension.MMRSearchExtBuilder; +import org.opensearch.search.pipeline.ProcessorGenerationContext; +import org.opensearch.transport.client.Client; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import 
java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.index.engine.SpaceTypeResolver.getDefaultSpaceType; + +/** + * A util for MMR related functions + */ +public class MMRUtil { + + /** + * Collect the knn vector field info based on the query field path or the user provided vector field path and + * the index metadata. + * @param path The path of the query field or user provided vector field path + * @param indexMetadataList List of index metadata of the query target indices + * @return List of knn vector field info + */ + private static List collectKnnVectorFieldInfos( + @NonNull final String path, + @NonNull final List indexMetadataList + ) { + List vectorFieldInfos = new ArrayList<>(); + + for (IndexMetadata indexMetadata : indexMetadataList) { + vectorFieldInfos.add(collectKnnVectorFieldInfo(indexMetadata, path)); + } + + return vectorFieldInfos; + } + + private static MMRVectorFieldInfo collectKnnVectorFieldInfo(IndexMetadata indexMetadata, String path) { + final MMRVectorFieldInfo vectorFieldInfo = new MMRVectorFieldInfo(); + vectorFieldInfo.setIndexNameByIndexMetadata(indexMetadata); + + MappingMetadata mappingMetadata = indexMetadata.mapping(); + if (mappingMetadata == null) { + vectorFieldInfo.setUnmapped(true); + return vectorFieldInfo; + } + + Map mapping = mappingMetadata.sourceAsMap(); + Map config = getMMRFieldMappingByPath(mapping, path); + if (config == null) { + vectorFieldInfo.setUnmapped(true); + return vectorFieldInfo; + } + + vectorFieldInfo.setUnmapped(false); + vectorFieldInfo.setFieldPath(path); + + String fieldType = (String) config.get(TYPE); + vectorFieldInfo.setFieldType(fieldType); + + if (KNNVectorFieldMapper.CONTENT_TYPE.equals(fieldType) == false) { + return 
vectorFieldInfo; // Not a knn_vector field skip further processing + } + + vectorFieldInfo.setKnnConfig(config); + + return vectorFieldInfo; + } + + private static MMRVectorFieldInfo resolveKnnVectorFieldInfo( + SpaceType userProvidedSpaceType, + VectorDataType userProvidedVectorDataType, + List MMRVectorFieldInfoList + ) throws IllegalArgumentException { + boolean allUnmapped = true; + List nonKnnFields = new ArrayList<>(); + SpaceType resolvedSpaceType = null; + VectorDataType resolvedVectorDataType = null; + + for (MMRVectorFieldInfo info : MMRVectorFieldInfoList) { + if (info.isUnmapped()) { + continue; + } + + allUnmapped = false; + + if (!info.isKNNVectorField()) { + nonKnnFields.add(info); + continue; + } + + // ensure we have the same space type and vector data type if we have multiple target indices + resolvedSpaceType = resolveConsistentValue( + resolvedSpaceType, + info.getSpaceType(), + SpaceType::getValue, + "space type", + info.getFieldPath() + ); + + resolvedVectorDataType = resolveConsistentValue( + resolvedVectorDataType, + info.getVectorDataType(), + VectorDataType::getValue, + "vector data type", + info.getFieldPath() + ); + } + + if (allUnmapped) { + resolvedSpaceType = userProvidedSpaceType != null ? userProvidedSpaceType : getDefaultSpaceType(VectorDataType.DEFAULT); + resolvedVectorDataType = userProvidedVectorDataType != null ? 
userProvidedVectorDataType : VectorDataType.DEFAULT; + return new MMRVectorFieldInfo(resolvedSpaceType, resolvedVectorDataType); + } + + if (!nonKnnFields.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "MMR query extension cannot support non knn_vector field [%s].", + nonKnnFields.stream() + .map(info -> String.format(Locale.ROOT, "%s:%s", info.getIndexName(), info.getFieldPath())) + .collect(Collectors.joining(",")) + ) + ); + } + + return resolveFinalKnnVectorFieldInfo(userProvidedSpaceType, resolvedSpaceType, userProvidedVectorDataType, resolvedVectorDataType); + } + + private static T resolveConsistentValue( + T current, + T next, + Function valueFormatter, + String fieldDescription, + String fieldPath + ) { + if (next == null) { + return current; + } + if (current == null) { + return next; + } + if (!current.equals(next)) { + throw new IllegalArgumentException( + String.format( + "MMR query extension cannot support different %s [%s, %s] for the knn_vector field at path %s.", + fieldDescription, + valueFormatter.apply(current), + valueFormatter.apply(next), + fieldPath + ) + ); + } + return current; + } + + private static MMRVectorFieldInfo resolveFinalKnnVectorFieldInfo( + SpaceType userProvidedSpaceType, + SpaceType resolvedSpaceType, + VectorDataType userProvidedVectorDataType, + VectorDataType resolvedVectorDataType + ) throws IllegalArgumentException { + SpaceType finalSpaceType = resolveFinalValue( + userProvidedSpaceType, + resolvedSpaceType, + () -> getDefaultSpaceType(VectorDataType.DEFAULT), + SpaceType::getValue, + "space type" + ); + + VectorDataType finalVectorDataType = resolveFinalValue( + userProvidedVectorDataType, + resolvedVectorDataType, + () -> VectorDataType.DEFAULT, + VectorDataType::getValue, + "vector data type" + ); + + return new MMRVectorFieldInfo(finalSpaceType, finalVectorDataType); + } + + private static T resolveFinalValue( + T userProvided, + T resolved, + Supplier defaultSupplier, + Function 
valueFormatter, + String fieldDescription + ) { + if (userProvided != null && resolved != null && !userProvided.equals(resolved)) { + throw new IllegalArgumentException( + String.format( + "The %s [%s] provided in the MMR query extension does not match the %s [%s] in target indices.", + fieldDescription, + valueFormatter.apply(userProvided), + fieldDescription, + valueFormatter.apply(resolved) + ) + ); + } + + if (userProvided != null) { + return userProvided; + } else if (resolved != null) { + return resolved; + } else { + return defaultSupplier.get(); + } + } + + private static MMRVectorFieldInfo resolveVectorFieldInfoFromModel( + VectorDataType userProvidedVectorDataType, + SpaceType userProvidedSpaceType, + List MMRVectorFieldInfoList, + Map modelIdToVectorFieldInfo + ) throws IllegalArgumentException { + SpaceType resolvedSpaceType = null; + VectorDataType resolvedVectorDataType = null; + for (MMRVectorFieldInfo info : MMRVectorFieldInfoList) { + SpaceType spaceType; + VectorDataType vectorDataType; + + // Resolve from model if modelId is present, else from field config + if (info.getModelId() != null) { + MMRVectorFieldInfo infoFromModel = modelIdToVectorFieldInfo.get(info.getModelId()); + if (infoFromModel == null) { + throw new IllegalStateException( + String.format( + "Unexpected null when try to resolve the info of the vector field at path [%s] based on its model [%s].", + info.getModelId(), + info.getFieldPath() + ) + ); + } + vectorDataType = infoFromModel.getVectorDataType() != null ? infoFromModel.getVectorDataType() : VectorDataType.DEFAULT; + spaceType = infoFromModel.getSpaceType() != null ? 
infoFromModel.getSpaceType() : getDefaultSpaceType(vectorDataType); + } else { + spaceType = info.getSpaceType(); + vectorDataType = info.getVectorDataType(); + } + + resolvedSpaceType = resolveConsistentValue( + resolvedSpaceType, + spaceType, + SpaceType::getValue, + "space type", + info.getFieldPath() + ); + + resolvedVectorDataType = resolveConsistentValue( + resolvedVectorDataType, + vectorDataType, + VectorDataType::getValue, + "vector data type", + info.getFieldPath() + ); + } + + return resolveFinalKnnVectorFieldInfo(userProvidedSpaceType, resolvedSpaceType, userProvidedVectorDataType, resolvedVectorDataType); + } + + private static void retrieveFieldInfoFromModel( + @NonNull final Set modelIds, + @NonNull final Client client, + @NonNull final ActionListener> listener + ) { + Map modelIdToVectorFieldInfo = new ConcurrentHashMap<>(); + List errors = Collections.synchronizedList(new ArrayList<>()); + AtomicInteger counter = new AtomicInteger(modelIds.size()); + + for (String modelId : modelIds) { + client.execute(GetModelAction.INSTANCE, new GetModelRequest(modelId), ActionListener.wrap((GetModelResponse response) -> { + SpaceType spaceTypeFromModel = null; + VectorDataType vectorDataTypeFromModel = null; + if (response != null && response.getModel() != null && response.getModel().getModelMetadata() != null) { + spaceTypeFromModel = response.getModel().getModelMetadata().getSpaceType(); + vectorDataTypeFromModel = response.getModel().getModelMetadata().getVectorDataType(); + } + modelIdToVectorFieldInfo.put(modelId, new MMRVectorFieldInfo(spaceTypeFromModel, vectorDataTypeFromModel)); + if (counter.decrementAndGet() == 0) { + listener.onResponse(modelIdToVectorFieldInfo); + } + }, (Exception e) -> { + errors.add(e.getMessage()); + if (counter.decrementAndGet() == 0) { + listener.onFailure( + new RuntimeException( + String.format( + Locale.ROOT, + "Failed to retrieve model(s) to resolve the space type and vector data type for the MMR query extension. 
Errors: %s.", + String.join(", ", errors) + ) + ) + ); + } + })); + } + } + + /** + * Resolves the space type and data type for a vector field, optionally using model metadata if model IDs exist. + * It will collect the info from the localIndexMetadataList. + * + * @param path the path of the query field or the user provided vector field path + * @param userProvidedSpaceType Optional space type provided by the user + * @param userProvidedVectorDataType Optional vector data type provided by the user + * @param localIndexMetadataList List of local index metadata to inspect + * @param client OpenSearch client to fetch models + * @param continuation ActionListener callback to receive the resolved MMRVectorFieldInfo + */ + public static void resolveKnnVectorFieldInfo( + @NonNull String path, + @Nullable SpaceType userProvidedSpaceType, + @Nullable VectorDataType userProvidedVectorDataType, + @NonNull List localIndexMetadataList, + @NonNull Client client, + @NonNull ActionListener continuation + ) { + try { + List knnVectorFieldInfos = collectKnnVectorFieldInfos(path, localIndexMetadataList); + + resolveKnnVectorFieldInfo(knnVectorFieldInfos, userProvidedSpaceType, userProvidedVectorDataType, client, continuation); + } catch (Exception e) { + continuation.onFailure(e); + } + } + + /** + * Resolves the space type and data type for a knn vector field based on field info collected from its index + * mapping config from multiple indices, optionally using model metadata if model IDs exist. 
+ * + * @param MMRVectorFieldInfoList A list of knn vector info collected from multiple target indices + * @param userProvidedSpaceType Optional space type provided by the user + * @param userProvidedVectorDataType Optional vector data type provided by the user + * @param client OpenSearch client to fetch models + * @param continuation callback to execute once the final space type is resolved + */ + public static void resolveKnnVectorFieldInfo( + @NonNull List MMRVectorFieldInfoList, + @Nullable SpaceType userProvidedSpaceType, + @Nullable VectorDataType userProvidedVectorDataType, + @NonNull Client client, + @NonNull ActionListener continuation + ) { + try { + // Resolve field info based on the field config in index mapping + MMRVectorFieldInfo resolvedVectorFieldInfo = resolveKnnVectorFieldInfo( + userProvidedSpaceType, + userProvidedVectorDataType, + MMRVectorFieldInfoList + ); + + // Collect model IDs + Set modelIds = MMRVectorFieldInfoList.stream() + .map(MMRVectorFieldInfo::getModelId) + .filter(Objects::nonNull) + .collect(Collectors.toSet()); + + if (modelIds.isEmpty()) { + continuation.onResponse(resolvedVectorFieldInfo); + } else { + // Retrieve the field info from the model metadata asynchronously + retrieveFieldInfoFromModel(modelIds, client, ActionListener.wrap(modelIdToVectorFieldInfo -> { + MMRVectorFieldInfo resolvedVectorFieldInfoFromModel = resolveVectorFieldInfoFromModel( + userProvidedVectorDataType, + userProvidedSpaceType, + MMRVectorFieldInfoList, + modelIdToVectorFieldInfo + ); + continuation.onResponse(resolvedVectorFieldInfoFromModel); + }, continuation::onFailure)); + } + } catch (Exception e) { + continuation.onFailure(e); + } + } + + /** + * Extracts a dense vector ({@code float[]} or {@code byte[]}) from a document source map given a dot-delimited + * field path. + * + * This utility is designed for KNN / MMR use cases where the vector is expected to be stored + * as a top-level or single field inside the document. 
Nested object structures containing + * vectors are not supported, and will cause an exception. + * + * Example: + * source = Map.of("embedding", List.of(0.1, 0.2, 0.3)); + * float[] vector = VectorUtils.extractVectorFromHit(source, "embedding", "doc-123"); + * vector = [0.1f, 0.2f, 0.3f] + * + * + * @param sourceAsMap The document source returned from {@code hit.getSourceAsMap()}. + * @param fieldPath The dot-delimited field path to the vector field (e.g. "embedding" or "nested.field.vector"). + * @param docId The document ID, used in error messages to help identify problematic documents. + * @param isFloatVector If the vector is float or byte + * @return A primitive float/byte array representing the extracted vector. + */ + @SuppressWarnings("unchecked") + public static Object extractVectorFromHit(Map sourceAsMap, String fieldPath, String docId, boolean isFloatVector) + throws IllegalArgumentException { + String baseError = String.format(Locale.ROOT, "Failed to extract the vector from the doc [%s] for MMR rerank", docId); + if (sourceAsMap == null || fieldPath == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s: source map and fieldPath must not be null.", baseError)); + } + + String[] pathParts = fieldPath.split("\\."); + Object current = sourceAsMap; + + if (pathParts.length == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s: fieldPath must not be an empty string.", baseError)); + } + + for (int i = 0; i < pathParts.length; i++) { + String part = pathParts[i]; + if (!(current instanceof Map map)) { + throw new IllegalArgumentException( + String.format("%s: expected object at [%s], but found [%s]", baseError, part, current.getClass().getName()) + ); + } + + current = map.get(part); + if (current == null) { + throw new IllegalArgumentException( + String.format("%s: field path [%s] not found in document source.", baseError, fieldPath) + ); + } + + // Final part → should resolve to a vector + if (i == 
pathParts.length - 1) { + if (current instanceof List list) { + float[] floatVector = null; + byte[] byteVector = null; + if (isFloatVector) { + floatVector = new float[list.size()]; + } else { + byteVector = new byte[list.size()]; + } + try { + for (int j = 0; j < list.size(); j++) { + if (isFloatVector) { + floatVector[j] = (float) (double) list.get(j); + } else { + byteVector[j] = (byte) (double) list.get(j); + } + } + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("%s: unexpected value at the vector field [%s]. error: %s", baseError, fieldPath, e.getMessage()) + ); + } + if (isFloatVector) { + return floatVector; + } else return byteVector; + } + throw new IllegalArgumentException( + String.format( + "%s: expected vector (list of numbers) at field path [%s], but found type [%s]", + baseError, + fieldPath, + current.getClass().getName() + ) + ); + } + } + + // Should never reach here + throw new IllegalStateException(String.format("%s: unexpected error resolving field path [%s].", baseError, fieldPath)); + } + + /** + * @param processorGenerationContext The context to evaluate if we should generate the MMR processor. + * @return If the MMR processor should be generated. + */ + public static boolean shouldGenerateMMRProcessor(ProcessorGenerationContext processorGenerationContext) { + SearchRequest request = processorGenerationContext.searchRequest(); + if (request == null || request.source() == null || request.source().ext() == null) { + return false; + } + return request.source().ext().stream().anyMatch(MMRSearchExtBuilder.class::isInstance); + } + + /** + * Get the field mapping config for a dot-separated path like "user.profile.age" for MMR. The fields on the path + * should not contain "nested" field type since it means the doc source can have multiple vectors which we cannot + * support. 
+ * + * @param mappings Index mappings + * @param fieldPath Dot-separated path to the field + * @return The mapping config map for the field, or null if not found + */ + @SuppressWarnings("unchecked") + public static Map getMMRFieldMappingByPath(Map mappings, @NonNull String fieldPath) { + if (mappings == null) { + return null; + } + String[] parts = fieldPath.split("\\."); + Map current = mappings; + + for (int i = 0; i < parts.length; i++) { + String part = parts[i]; + Object propertiesObj = current.get("properties"); + if (!(propertiesObj instanceof Map)) { + return null; // no deeper properties + } + + Map properties = (Map) propertiesObj; + Object fieldConfig = properties.get(part); + if (!(fieldConfig instanceof Map)) { + return null; // field not found + } + current = (Map) fieldConfig; + + String fieldType = (String) current.get(TYPE); + if (ObjectMapper.NESTED_CONTENT_TYPE.equals(fieldType)) { + throw new IllegalArgumentException( + String.format("MMR search extension cannot support the field %s because it is in the nested field %s.", fieldPath, part) + ); + } + } + return current; + } +} diff --git a/src/main/java/org/opensearch/knn/search/processor/mmr/MMRVectorFieldInfo.java b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRVectorFieldInfo.java new file mode 100644 index 0000000000..8036719bf4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/search/processor/mmr/MMRVectorFieldInfo.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static 
org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.engine.SpaceTypeResolver.getDefaultSpaceType; + +/** + * A DTO to hold the info of the vector field used for MMR + */ +@Data +@NoArgsConstructor +public class MMRVectorFieldInfo { + private String indexName; + private String fieldPath; + private VectorDataType vectorDataType; + private String modelId; + private SpaceType spaceType; + private boolean unmapped; + private String fieldType; + + public MMRVectorFieldInfo(SpaceType spaceType, VectorDataType vectorDataType) { + this.spaceType = spaceType; + this.vectorDataType = vectorDataType; + } + + public boolean isKNNVectorField() { + return KNNVectorFieldMapper.CONTENT_TYPE.equals(fieldType); + } + + public void setKnnConfig(@NonNull final Map knnConfig) { + setVectorDataTypeByConfig(knnConfig); + if (setModelIdIfPresent(knnConfig)) { + return; + } + if (setSpaceTypeIfPresent(knnConfig)) { + return; + } + this.spaceType = getDefaultSpaceType(vectorDataType); + } + + public void setIndexNameByIndexMetadata(@NonNull final IndexMetadata indexMetadata) { + this.indexName = indexMetadata.getIndex().getName(); + } + + private void setVectorDataTypeByConfig(Map knnConfig) { + String dataType = (String) knnConfig.get(VECTOR_DATA_TYPE_FIELD); + this.vectorDataType = (dataType == null) ? 
VectorDataType.DEFAULT : VectorDataType.get(dataType); + } + + private boolean setModelIdIfPresent(Map knnConfig) { + String modelId = (String) knnConfig.get(MODEL_ID); + if (modelId != null) { + this.modelId = modelId; + return true; + } + return false; + } + + private boolean setSpaceTypeIfPresent(Map knnConfig) { + String topLevelSpaceType = (String) knnConfig.get(TOP_LEVEL_PARAMETER_SPACE_TYPE); + if (topLevelSpaceType != null) { + this.spaceType = SpaceType.getSpace(topLevelSpaceType); + return true; + } + @SuppressWarnings("unchecked") + Map knnMethod = (Map) knnConfig.get(KNN_METHOD); + if (knnMethod != null) { + String spaceType = (String) knnMethod.get(METHOD_PARAMETER_SPACE_TYPE); + if (spaceType != null) { + this.spaceType = SpaceType.getSpace(spaceType); + return true; + } + } + return false; + } + +} diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 376692f269..6be294ed40 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -28,11 +28,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.Collections; -import java.util.HashSet; -import java.util.Map; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import static org.mockito.Mockito.when; diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java index 6ded05d176..034ac68e49 100644 --- a/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java +++ b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java @@ -9,6 +9,8 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; 
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -30,6 +32,9 @@ public static ClusterService mockClusterService(final Version version) { DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); when(clusterState.getNodes()).thenReturn(discoveryNodes); when(discoveryNodes.getMinNodeVersion()).thenReturn(version); + when(clusterService.getClusterSettings()).thenReturn( + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); return clusterService; } } diff --git a/src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java index b29857e76c..e1cbbecdda 100644 --- a/src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java @@ -29,10 +29,7 @@ private void assertResolveSpaceType( VectorDataType vectorDataType, SpaceType expectedSpaceType ) { - assertEquals( - expectedSpaceType, - SPACE_TYPE_RESOLVER.resolveSpaceType(knnMethodContext, topLevelSpaceTypeString, indexSettings, vectorDataType) - ); + assertEquals(expectedSpaceType, SPACE_TYPE_RESOLVER.resolveSpaceType(knnMethodContext, topLevelSpaceTypeString, vectorDataType)); } public void testResolveSpaceType_whenNoConfigProvided_thenFallbackToVectorDataType() { @@ -194,7 +191,6 @@ public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThr () -> SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.L2, MethodComponentContext.EMPTY), SpaceType.INNER_PRODUCT.getValue(), - DONT_CARE_SETTINGS, DONT_CARE_VECTOR_DATA ) ); @@ -203,7 +199,6 @@ public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThr SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY), SpaceType.DEFAULT.getValue(), - DONT_CARE_SETTINGS, DONT_CARE_VECTOR_DATA ) ); @@ 
-212,7 +207,6 @@ public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThr SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY), SpaceType.UNDEFINED.getValue(), - DONT_CARE_SETTINGS, DONT_CARE_VECTOR_DATA ) ); @@ -221,7 +215,6 @@ public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThr SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), SpaceType.DEFAULT.getValue(), - DONT_CARE_SETTINGS, DONT_CARE_VECTOR_DATA ) ); @@ -234,7 +227,6 @@ public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThr SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), SpaceType.UNDEFINED.getValue(), - settings, VectorDataType.BYTE ) ); @@ -244,7 +236,6 @@ public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThr SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), SpaceType.UNDEFINED.getValue(), - settings, VectorDataType.FLOAT ) ); @@ -254,7 +245,6 @@ public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThr SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), SpaceType.UNDEFINED.getValue(), - settings, VectorDataType.BINARY ) ); @@ -267,10 +257,9 @@ public void testResolveSpaceType_whenSpaceTypeSpecifiedOnce_thenReturnValue() { SPACE_TYPE_RESOLVER.resolveSpaceType( new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.L1, MethodComponentContext.EMPTY), "", - null, null ) ); - assertEquals(SpaceType.INNER_PRODUCT, SPACE_TYPE_RESOLVER.resolveSpaceType(null, SpaceType.INNER_PRODUCT.getValue(), null, null)); + assertEquals(SpaceType.INNER_PRODUCT, SPACE_TYPE_RESOLVER.resolveSpaceType(null, 
SpaceType.INNER_PRODUCT.getValue(), null)); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 95b4e84d48..8d64678126 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -14,6 +14,7 @@ import org.mockito.MockedStatic; import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; @@ -970,7 +971,7 @@ private void assertSerialization( final ClusterService clusterService = mockClusterService(version); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); try (BytesStreamOutput output = new BytesStreamOutput()) { output.setVersion(version); output.writeNamedWriteable(knnQueryBuilder); diff --git a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java index 54d495b08b..392932a668 100644 --- a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java +++ b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java @@ -7,6 +7,7 @@ import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.ParsingException; @@ -27,6 +28,7 @@ import java.util.List; import java.util.Map; +import 
static org.mockito.Mockito.mock; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; @@ -138,7 +140,7 @@ public void testFromXContent_withFilter() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -191,7 +193,7 @@ public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_whenFilter final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -219,7 +221,7 @@ public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_whenFilter_th final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -246,7 +248,7 @@ public void testFromXContent_InvalidQueryVectorType() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + 
knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); List invalidTypeQueryVector = new ArrayList<>(); invalidTypeQueryVector.add(1.5); @@ -274,7 +276,7 @@ public void testFromXContent_whenDoRadiusSearch_whenInputInvalidQueryVectorType_ final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); List invalidTypeQueryVector = new ArrayList<>(); invalidTypeQueryVector.add(1.5); @@ -302,7 +304,7 @@ public void testFromXContent_missingQueryVector() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); // Test without vector field XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder(); diff --git a/src/test/java/org/opensearch/knn/index/util/KNNClusterUtilTests.java b/src/test/java/org/opensearch/knn/index/util/KNNClusterUtilTests.java index f04db5c70b..cc3a3320cf 100644 --- a/src/test/java/org/opensearch/knn/index/util/KNNClusterUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/util/KNNClusterUtilTests.java @@ -6,7 +6,10 @@ package org.opensearch.knn.index.util; import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNTestCase; import static org.mockito.Mockito.mock; @@ -19,7 +22,7 @@ public void testSingleNodeCluster() { ClusterService clusterService = mockClusterService(Version.V_2_4_0); final KNNClusterUtil knnClusterUtil = 
KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); final Version minVersion = knnClusterUtil.getClusterMinVersion(); @@ -30,7 +33,7 @@ public void testMultipleNodesCluster() { ClusterService clusterService = mockClusterService(Version.V_2_3_0); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); final Version minVersion = knnClusterUtil.getClusterMinVersion(); @@ -39,10 +42,13 @@ public void testMultipleNodesCluster() { public void testWhenErrorOnClusterStateDiscover() { ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn( + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); - knnClusterUtil.initialize(clusterService); + knnClusterUtil.initialize(clusterService, mock(IndexNameExpressionResolver.class)); final Version minVersion = knnClusterUtil.getClusterMinVersion(); diff --git a/src/test/java/org/opensearch/knn/plugin/KNNPluginTests.java b/src/test/java/org/opensearch/knn/plugin/KNNPluginTests.java new file mode 100644 index 0000000000..79c6b32fb9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/KNNPluginTests.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.search.processor.mmr.MMRKnnQueryTransformer; +import org.opensearch.knn.search.processor.mmr.MMRQueryTransformer; +import org.opensearch.plugins.ExtensiblePlugin; + +import 
java.lang.reflect.Field; +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KNNPluginTests extends KNNTestCase { + private KNNPlugin knnPlugin; + + @Override + public void setUp() throws Exception { + super.setUp(); + knnPlugin = new KNNPlugin(); + } + + public void testLoadExtensions_whenSuccess() throws Exception { + MMRQueryTransformer transformer = mock(MMRQueryTransformer.class); + when(transformer.getQueryName()).thenReturn("test_query"); + + ExtensiblePlugin.ExtensionLoader loader = new ExtensiblePlugin.ExtensionLoader() { + @SuppressWarnings("unchecked") + @Override + public List loadExtensions(Class extensionPointType) { + if (extensionPointType.equals(MMRQueryTransformer.class)) { + return (List) List.of(transformer); + } + return List.of(); + } + }; + + knnPlugin.loadExtensions(loader); + + Map> map = getMmrQueryTransformers(); + assertEquals(transformer, map.get("test_query")); + assertTrue(map.get(KNNQueryBuilder.NAME) instanceof MMRKnnQueryTransformer); + } + + public void testLoadExtensions_whenDuplicatedThenException() throws Exception { + MMRQueryTransformer transformerA = mock(MMRQueryTransformer.class); + when(transformerA.getQueryName()).thenReturn("test_query"); + MMRQueryTransformer transformerB = mock(MMRQueryTransformer.class); + when(transformerB.getQueryName()).thenReturn("test_query"); + + ExtensiblePlugin.ExtensionLoader loader = new ExtensiblePlugin.ExtensionLoader() { + @SuppressWarnings("unchecked") + @Override + public List loadExtensions(Class extensionPointType) { + if (extensionPointType.equals(MMRQueryTransformer.class)) { + return (List) List.of(transformerA, transformerB); + } + return List.of(); + } + }; + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> knnPlugin.loadExtensions(loader)); + + String expectedError = "Already load the MMR query transformer"; + 
assertTrue(exception.getMessage().contains(expectedError)); + } + + @SuppressWarnings("unchecked") + private Map> getMmrQueryTransformers() throws Exception { + Field field = KNNPlugin.class.getDeclaredField("mmrQueryTransformers"); + field.setAccessible(true); + return (Map>) field.get(knnPlugin); + } +} diff --git a/src/test/java/org/opensearch/knn/search/extension/MMRSearchExtBuilderIT.java b/src/test/java/org/opensearch/knn/search/extension/MMRSearchExtBuilderIT.java new file mode 100644 index 0000000000..243bada9d7 --- /dev/null +++ b/src/test/java/org/opensearch/knn/search/extension/MMRSearchExtBuilderIT.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.extension; + +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.search.processor.mmr.MMROverSampleProcessor; +import org.opensearch.knn.search.processor.mmr.MMRRerankProcessor; + +import java.io.IOException; +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.search.pipeline.SearchPipelineService.ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING; + +public class MMRSearchExtBuilderIT extends KNNRestTestCase { + private static final int DIMENSION_NUM = 2; + private static final String FIELD_NAME = "vector_field"; + private static final String INDEX_NAME = "test_index"; + private static final int QUERY_SIZE = 3; + private static final float[] queryVector = new float[] { 1f, 1f }; + + @Before + public void setUpForMMR() { + 
enableMMRProcessors(); + createTestIndexAndDocs(); + } + + @After + public void cleanUpForMMR() throws IOException { + deleteKNNIndex(INDEX_NAME); + disableMMRProcessors(); + } + + @SneakyThrows + public void testMMR_whenRerankWithVectors_thenSelectTop3() { + XContentBuilder queryBuilder = buildMMRQuery(queryVector, QUERY_SIZE, false, false); + + Response response = searchKNNIndex(INDEX_NAME, queryBuilder.toString(), QUERY_SIZE); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + verifyResults(results, false); + } + + @SneakyThrows + public void testMMR_whenSourceExcludesVector_thenVectorExcluded() { + XContentBuilder queryBuilder = buildMMRQuery(queryVector, QUERY_SIZE, true, false); + + Response response = searchKNNIndex(INDEX_NAME, queryBuilder.toString(), QUERY_SIZE); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + verifyResults(results, true); + } + + @SneakyThrows + public void testMMR_whenDisabledStoredFields_thenVectorExcluded() { + XContentBuilder queryBuilderDisabledStoredFields = buildMMRQuery(queryVector, QUERY_SIZE, false, false, "_none_"); + + Response response = searchKNNIndex(INDEX_NAME, queryBuilderDisabledStoredFields.toString(), QUERY_SIZE); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + verifyResults(results, true); + } + + @SneakyThrows + public void testMMR_whenEmptyStoredFieldsAndExplicitlyEnableSource_thenVectorIncluded() { + XContentBuilder queryBuilderDisabledStoredFields = buildMMRQuery(queryVector, QUERY_SIZE, true, false, "empty"); + + Response response = searchKNNIndex(INDEX_NAME, queryBuilderDisabledStoredFields.toString(), QUERY_SIZE); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + verifyResults(results, true); + } + + @SneakyThrows + public void testMMR_whenUserProvidedVectorPath_thenVectorIncluded() { + XContentBuilder queryBuilder = 
buildMMRQuery(queryVector, QUERY_SIZE, false, true); + + Response response = searchKNNIndex(INDEX_NAME, queryBuilder.toString(), QUERY_SIZE); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + verifyResults(results, false); + } + + @SneakyThrows + private void enableMMRProcessors() { + updateClusterSettings( + ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING.getKey(), + new String[] { MMROverSampleProcessor.MMROverSampleProcessorFactory.TYPE, MMRRerankProcessor.MMRRerankProcessorFactory.TYPE } + ); + } + + @SneakyThrows + private void disableMMRProcessors() { + updateClusterSettings(ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING.getKey(), ""); + } + + @SneakyThrows + private void createTestIndexAndDocs() { + XContentBuilder mappingBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, DIMENSION_NUM) + .startObject(KNN_METHOD) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, METHOD_HNSW) + .endObject() + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, mappingBuilder.toString()); + + float[] similarVector = new float[] { 1f, 1f }; + for (int i = 0; i < 8; i++) + addKnnDoc(INDEX_NAME, String.valueOf(i), FIELD_NAME, similarVector); + + float[][] diverseVectors = new float[][] { { 1f, 2f }, { 2f, 1f } }; + for (int i = 8; i < 10; i++) + addKnnDoc(INDEX_NAME, String.valueOf(i), FIELD_NAME, diverseVectors[i - 8]); + } + + @SneakyThrows + private XContentBuilder buildMMRQuery(float[] queryVector, int k, boolean excludeVector, boolean userProvidedVector) { + return buildMMRQuery(queryVector, k, excludeVector, userProvidedVector, null); + } + + @SneakyThrows + private XContentBuilder buildMMRQuery( + float[] queryVector, + int k, + boolean excludeVector, + boolean userProvidedVector, + String storedFields + ) { + 
XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + if ("_none_".equals(storedFields)) { + builder.field("stored_fields", "_none_"); + } else if ("empty".equals(storedFields)) { + builder.array("stored_fields", ""); + } + if (excludeVector) { + builder.startObject("_source").array("excludes", FIELD_NAME).endObject(); + } else if ("empty".equals(storedFields)) { + builder.field("_source", true); + } + + builder.startObject("query") + .startObject(KNN) + .startObject(FIELD_NAME) + .array(VECTOR, queryVector) + .field(K, k) + .endObject() + .endObject() + .endObject(); + + builder.startObject("ext").startObject(MMR).field(CANDIDATES, 9).field(DIVERSITY, 0.9); + if (userProvidedVector) { + builder.field(VECTOR_FIELD_PATH, FIELD_NAME).field(VECTOR_FIELD_SPACE_TYPE, SpaceType.L2.getValue()); + } + builder.endObject().endObject().endObject(); + + return builder; + } + + private void verifyResults(List results, boolean excludeVector) { + if (excludeVector) { + results.forEach(r -> assertNull("Vector should be excluded", r.getVector())); + } else { + results.forEach(r -> assertNotNull("Vector should be included", r.getVector())); + } + assertEquals(QUERY_SIZE, results.size()); + assertEquals("0", results.get(0).getDocId()); + assertEquals("Should pick up the hit with diversity.", "8", results.get(1).getDocId()); + assertEquals("1", results.get(2).getDocId()); + } +} diff --git a/src/test/java/org/opensearch/knn/search/extension/MMRSearchExtBuilderTests.java b/src/test/java/org/opensearch/knn/search/extension/MMRSearchExtBuilderTests.java new file mode 100644 index 0000000000..07dfb1e519 --- /dev/null +++ b/src/test/java/org/opensearch/knn/search/extension/MMRSearchExtBuilderTests.java @@ -0,0 +1,195 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.extension; + +import org.junit.Before; +import org.opensearch.cluster.service.ClusterService; +import 
org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNClusterUtil; +import org.opensearch.knn.search.processor.mmr.MMRTestCase; + +import java.io.IOException; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class MMRSearchExtBuilderTests extends MMRTestCase { + private float DELTA = 1e-6F; + private float DEFAULT_DIVERSITY = 0.5f; + private ClusterService clusterService = mock(ClusterService.class); + + @Before + public void setUp() throws Exception { + super.setUp(); + when(clusterService.getClusterSettings()).thenReturn(clusterSettingsWithSystemGeneratedFactoriesEnabled); + KNNClusterUtil.instance().initialize(clusterService, null); + } + + public void testBuilderDefaultsAndValues() { + MMRSearchExtBuilder builder = new MMRSearchExtBuilder.Builder().candidates(10) + .vectorFieldPath("vec") + .spaceType("l2") + .vectorFieldDataType("float") + .build(); + + assertEquals(DEFAULT_DIVERSITY, builder.getDiversity(), DELTA); + assertEquals(10, (int) builder.getCandidates()); + assertEquals("vec", builder.getVectorFieldPath()); + assertEquals(SpaceType.L2, builder.getSpaceType()); + assertEquals(VectorDataType.FLOAT, builder.getVectorFieldDataType()); + } + + public void testBuilder_whenNegativeDiversity_thenException() { + MMRSearchExtBuilder.Builder builder = new MMRSearchExtBuilder.Builder(); + builder.diversity(-0.1f); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, builder::build); + String 
expectedError = "diversity in mmr query extension must be between 0.0 and 1.0"; + assertEquals(expectedError, ex.getMessage()); + } + + public void testBuilder_whenLargeDiversity_thenException() { + MMRSearchExtBuilder.Builder builder = new MMRSearchExtBuilder.Builder(); + builder.diversity(1.1f); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, builder::build); + String expectedError = "diversity in mmr query extension must be between 0.0 and 1.0"; + assertEquals(expectedError, ex.getMessage()); + } + + public void testBuilder_whenNegativeCandidates_thenException() { + MMRSearchExtBuilder.Builder builder = new MMRSearchExtBuilder.Builder(); + builder.candidates(-1); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, builder::build); + String expectedError = "candidates in mmr query extension must be larger than 0."; + assertEquals(expectedError, ex.getMessage()); + } + + public void testBuilder_whenInvalidSpaceType_thenException() { + MMRSearchExtBuilder.Builder builder = new MMRSearchExtBuilder.Builder(); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> builder.spaceType("invalid")); + String expectedError = "vector_field_space_type in mmr query extension is not valid"; + assertTrue(ex.getMessage().contains(expectedError)); + } + + public void testBuilder_whenInvalidVectorDataType_thenException() { + MMRSearchExtBuilder.Builder builder = new MMRSearchExtBuilder.Builder(); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> builder.vectorFieldDataType("invalid")); + String expectedError = "vector_field_data_type in mmr query extension is not valid"; + assertTrue(ex.getMessage().contains(expectedError)); + } + + public void testBuilder_whenEmptyVectorPath_thenException() { + MMRSearchExtBuilder.Builder builder = new MMRSearchExtBuilder.Builder(); + builder.vectorFieldPath(""); + IllegalArgumentException ex = 
assertThrows(IllegalArgumentException.class, builder::build); + String expectedError = "vector_field_path in mmr query extension should not be an empty string."; + assertEquals(expectedError, ex.getMessage()); + } + + public void testEqualsAndHashCode() { + MMRSearchExtBuilder builder1 = new MMRSearchExtBuilder.Builder().diversity(0.7f) + .candidates(3) + .vectorFieldPath("path") + .spaceType("l2") + .build(); + MMRSearchExtBuilder builder2 = new MMRSearchExtBuilder.Builder().diversity(0.7f) + .candidates(3) + .vectorFieldPath("path") + .spaceType("l2") + .build(); + MMRSearchExtBuilder builder3 = new MMRSearchExtBuilder.Builder().diversity(0.9f).build(); + + assertEquals(builder1, builder2); + assertEquals(builder1.hashCode(), builder2.hashCode()); + assertNotEquals(builder1, builder3); + } + + public void testSerializationRoundTrip() throws IOException { + MMRSearchExtBuilder original = new MMRSearchExtBuilder.Builder().diversity(0.6f) + .candidates(20) + .vectorFieldPath("vec") + .spaceType("l2") + .vectorFieldDataType("byte") + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + original.writeTo(out); + + MMRSearchExtBuilder deserialized = new MMRSearchExtBuilder(out.bytes().streamInput()); + + assertEquals(original, deserialized); + } + + public void testToXContentAndParse() throws IOException { + MMRSearchExtBuilder original = new MMRSearchExtBuilder.Builder().diversity(0.9f) + .candidates(15) + .vectorFieldPath("vector") + .spaceType(SpaceType.COSINESIMIL.getValue()) + .build(); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + original.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + xContentBuilder.endObject(); + + try (XContentParser parser = createParser(xContentBuilder)) { + parser.nextToken(); // start object + parser.nextToken(); // field name "mmr" + parser.nextToken(); // start object + MMRSearchExtBuilder parsed = MMRSearchExtBuilder.parse(parser); + + assertEquals(original, 
parsed); + } + } + + public void testParse_whenUnsupportedField_thenException() throws IOException { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + xContentBuilder.startObject("mmr"); + xContentBuilder.field("unsupported_field", "value"); + xContentBuilder.endObject(); + xContentBuilder.endObject(); + + try (XContentParser parser = createParser(xContentBuilder)) { + parser.nextToken(); // start object + parser.nextToken(); // field name "mmr" + parser.nextToken(); // start object + ParsingException ex = assertThrows(ParsingException.class, () -> MMRSearchExtBuilder.parse(parser)); + String expectedError = "[mmr] query extension does not support [unsupported_field]"; + assertEquals(expectedError, ex.getMessage()); + } + } + + public void testParse_whenMMRProcessorsNotEnabled_thenException() throws IOException { + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + KNNClusterUtil.instance().initialize(clusterService, null); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + xContentBuilder.startObject("mmr"); + xContentBuilder.endObject(); + xContentBuilder.endObject(); + + try (XContentParser parser = createParser(xContentBuilder)) { + parser.nextToken(); // start object + parser.nextToken(); // field name "mmr" + parser.nextToken(); // start object + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> MMRSearchExtBuilder.parse(parser)); + String expectedError = + "We need to enable [mmr_over_sample_factory, mmr_rerank_factory] in the cluster setting [cluster.search.enabled_system_generated_factories] to support the mmr search extension."; + assertEquals(expectedError, ex.getMessage()); + } + } +} diff --git a/src/test/java/org/opensearch/knn/search/processor/mmr/MMRKnnQueryTransformerTests.java 
b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRKnnQueryTransformerTests.java new file mode 100644 index 0000000000..09d7a38123 --- /dev/null +++ b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRKnnQueryTransformerTests.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.transport.client.Client; + +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.*; +import static org.opensearch.knn.common.KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE; + +public class MMRKnnQueryTransformerTests extends MMRTestCase { + private Client client; + private MMRKnnQueryTransformer transformer; + private KNNQueryBuilder queryBuilder; + private ActionListener listener; + private MMRTransformContext transformContext; + private MMRRerankContext processingContext; + + @Override + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + transformer = new MMRKnnQueryTransformer(); + queryBuilder = mock(KNNQueryBuilder.class); + listener = mock(ActionListener.class); + processingContext = new MMRRerankContext(); + transformContext = new MMRTransformContext(10, processingContext, List.of(), List.of(), null, null, null, client, false); + } + + public void testTransform_whenNoMaxDistanceOrMinScore_thenSetsK() { + when(queryBuilder.getMaxDistance()).thenReturn(null); + when(queryBuilder.getMinScore()).thenReturn(null); + + transformer.transform(queryBuilder, 
listener, transformContext); + + verify(queryBuilder).setK(10); + } + + public void testTransform_whenMinScore_thenNotSetsK() { + when(queryBuilder.getMaxDistance()).thenReturn(null); + when(queryBuilder.getMinScore()).thenReturn(0.5f); // non-null minScore + + transformer.transform(queryBuilder, listener, transformContext); + + verify(queryBuilder, never()).setK(anyInt()); + } + + public void testTransform_whenVectorFieldInfoAlreadyResolved_thenEarlyExits() { + transformContext = new MMRTransformContext( + 10, + processingContext, + List.of(), + List.of(), + null, + "vector.field.path", + null, + client, + true + ); + + transformer.transform(queryBuilder, listener, transformContext); + + verify(listener).onResponse(null); + verifyNoMoreInteractions(client); + } + + public void testTransform_whenNoUserProvidedVectorFieldPath_thenResolveSpaceType() { + String indexName = "test-index"; + String vectorFieldName = "vectorField"; + IndexMetadata indexMetadata = mock(IndexMetadata.class); + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + Map mapping = Map.of( + indexName, + Map.of( + "properties", + Map.of( + vectorFieldName, + Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, TOP_LEVEL_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + ) + ) + ); + when(indexMetadata.getIndex()).thenReturn(new Index(indexName, "uuid")); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + when(mappingMetadata.sourceAsMap()).thenReturn(mapping); + when(queryBuilder.fieldName()).thenReturn(vectorFieldName); + + transformContext = new MMRTransformContext( + 10, + processingContext, + List.of(indexMetadata), + List.of(), + null, + null, + null, + client, + false + ); + + transformer.transform(queryBuilder, listener, transformContext); + + verify(listener).onResponse(null); + assertEquals(vectorFieldName, processingContext.getVectorFieldPath()); + assertEquals(SpaceType.L2, processingContext.getSpaceType()); + } +} diff --git 
a/src/test/java/org/opensearch/knn/search/processor/mmr/MMROverSampleProcessorTests.java b/src/test/java/org/opensearch/knn/search/processor/mmr/MMROverSampleProcessorTests.java new file mode 100644 index 0000000000..51bf53f4e3 --- /dev/null +++ b/src/test/java/org/opensearch/knn/search/processor/mmr/MMROverSampleProcessorTests.java @@ -0,0 +1,278 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.search.extension.MMRSearchExtBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.SystemGeneratedProcessor; +import org.opensearch.transport.client.Client; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.opensearch.knn.common.KNNConstants.*; + +public class MMROverSampleProcessorTests extends MMRTestCase { + private Client mockClient; + private MMROverSampleProcessor processor; + + @Override + public void setUp() throws Exception { + super.setUp(); + mockClient = mock(Client.class); + processor = new MMROverSampleProcessor("testTag", true, mockClient, getMockMMRQueryTransformers()); + } + + public void testMetadata() { + assertEquals(MMROverSampleProcessor.TYPE, processor.getType()); + 
assertTrue(processor.getDescription().contains("system generated processor")); + assertEquals("testTag", processor.getTag()); + assertTrue(processor.isIgnoreFailure()); + assertEquals(SystemGeneratedProcessor.ExecutionStage.POST_USER_DEFINED, processor.getExecutionStage()); + } + + public void testSynchronousProcessRequestThrows() { + UnsupportedOperationException exception = assertThrows( + UnsupportedOperationException.class, + () -> processor.processRequest(new SearchRequest()) + ); + String expectedError = "Should not try to use mmr_over_sample to process a search request synchronously."; + assertEquals(expectedError, exception.getMessage()); + } + + public void testProcessRequestAsync_nullRequest_callsOnFailure() { + ActionListener listener = mock(ActionListener.class); + + processor.processRequestAsync(null, new PipelineProcessingContext(), listener); + + String expectedError = "Search request passed to mmr_over_sample search request processor must have mmr search extension."; + verifyException(listener, IllegalStateException.class, expectedError); + } + + public void testExtractMMRExtension_whenMissing_thenException() { + ActionListener listener = mock(ActionListener.class); + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder().ext(Collections.emptyList())); // no extensions + + processor.processRequestAsync(request, new PipelineProcessingContext(), listener); + + String expectedError = "SearchRequest passed to mmr_over_sample processor must have an MMRSearchExtBuilder"; + verifyException(listener, IllegalStateException.class, expectedError); + } + + public void testProcessRequestAsync_whenHappyCase() { + String indexName = "test-index"; + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + ActionListener listener = mock(ActionListener.class); + + SearchRequest request = buildSearchRequest(new String[] { indexName }, new MMRSearchExtBuilder.Builder().build()); + + 
mockClusterIndexMetadata(Map.of(indexName, Collections.emptyMap())); + + processor.processRequestAsync(request, pipelineProcessingContext, listener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchRequest.class); + verify(listener).onResponse(captor.capture()); + SearchRequest searchRequest = captor.getValue(); + assertEquals(30, searchRequest.source().size()); + MMRRerankContext mmrRerankContext = (MMRRerankContext) pipelineProcessingContext.getAttribute(MMR_RERANK_CONTEXT); + assertEquals(10, (int) mmrRerankContext.getOriginalQuerySize()); + assertEquals(0.5f, mmrRerankContext.getDiversity(), DELTA); + } + + public void testProcessRequestAsync_whenNullQueryBuilder_thenException() { + String indexName = "test-index"; + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + ActionListener listener = mock(ActionListener.class); + + SearchRequest request = buildSearchRequest(new String[] { indexName }, new MMRSearchExtBuilder.Builder().build()); + request.source().query(null); + + mockClusterIndexMetadata(Map.of(indexName, Collections.emptyMap())); + + processor.processRequestAsync(request, pipelineProcessingContext, listener); + + String expectedError = "Query builder must not be null to do Maximal Marginal Relevance rerank."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testProcessRequestAsync_whenUnsupportedQueryBuilder_thenException() { + String indexName = "test-index"; + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + ActionListener listener = mock(ActionListener.class); + + SearchRequest request = buildSearchRequest(new String[] { indexName }, new MMRSearchExtBuilder.Builder().build()); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + request.source().query(boolQueryBuilder); + + mockClusterIndexMetadata(Map.of(indexName, Collections.emptyMap())); + + processor.processRequestAsync(request, 
pipelineProcessingContext, listener); + + String expectedError = "Maximal Marginal Relevance rerank doesn't support the query type [bool]"; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testProcessRequestAsync_whenHappyCaseWithRemoteIndex() { + String indexName = "test-index"; + String remoteIndexName = "remote:test-index"; + String vectorFieldName = "vectorField"; + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + ActionListener listener = mock(ActionListener.class); + + MMRSearchExtBuilder mmrSearchExtBuilder = new MMRSearchExtBuilder.Builder().vectorFieldPath(vectorFieldName) + .spaceType(SpaceType.L2.getValue()) + .vectorFieldDataType(VectorDataType.FLOAT.getValue()) + .build(); + SearchRequest request = buildSearchRequest(new String[] { indexName, remoteIndexName }, mmrSearchExtBuilder); + + mockClusterIndexMetadata( + Map.of( + indexName, + Map.of( + "properties", + Map.of( + vectorFieldName, + Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, TOP_LEVEL_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + ) + ) + ) + ); + + processor.processRequestAsync(request, pipelineProcessingContext, listener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchRequest.class); + verify(listener).onResponse(captor.capture()); + SearchRequest searchRequest = captor.getValue(); + assertEquals(30, searchRequest.source().size()); + MMRRerankContext mmrRerankContext = (MMRRerankContext) pipelineProcessingContext.getAttribute(MMR_RERANK_CONTEXT); + assertEquals(10, (int) mmrRerankContext.getOriginalQuerySize()); + assertEquals(0.5f, mmrRerankContext.getDiversity(), DELTA); + assertEquals(vectorFieldName, mmrRerankContext.getVectorFieldPath()); + assertEquals(SpaceType.L2, mmrRerankContext.getSpaceType()); + assertEquals(VectorDataType.FLOAT, mmrRerankContext.getVectorDataType()); + } + + public void testProcessRequestAsync_whenRemoteIndexWithoutSpaceType_thenException() { + String 
remoteIndexName = "remote:test-index"; + String vectorFieldName = "vectorField"; + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + ActionListener listener = mock(ActionListener.class); + + MMRSearchExtBuilder mmrSearchExtBuilder = new MMRSearchExtBuilder.Builder().vectorFieldPath(vectorFieldName).build(); + SearchRequest request = buildSearchRequest(new String[] { remoteIndexName }, mmrSearchExtBuilder); + + mockClusterIndexMetadata(Collections.emptyMap()); + + processor.processRequestAsync(request, pipelineProcessingContext, listener); + + String expectedError = + "vector_field_space_type is required in the MMR query extension when querying remote indices [remote:test-index]."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testProcessRequestAsync_whenRemoteIndexWithoutVectorDataType_thenException() { + String remoteIndexName = "remote:test-index"; + String vectorFieldName = "vectorField"; + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + ActionListener listener = mock(ActionListener.class); + + MMRSearchExtBuilder mmrSearchExtBuilder = new MMRSearchExtBuilder.Builder().vectorFieldPath(vectorFieldName) + .spaceType(SpaceType.L2.getValue()) + .build(); + SearchRequest request = buildSearchRequest(new String[] { remoteIndexName }, mmrSearchExtBuilder); + + mockClusterIndexMetadata(Collections.emptyMap()); + + processor.processRequestAsync(request, pipelineProcessingContext, listener); + + String expectedError = + "vector_field_data_type is required in the MMR query extension when querying remote indices [remote:test-index]."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testProcessRequestAsync_whenKnnFieldWithModelId() { + String indexName = "test-index"; + String vectorFieldName = "vectorField"; + String modelId = "modelId"; + PipelineProcessingContext pipelineProcessingContext = new 
PipelineProcessingContext(); + ActionListener listener = mock(ActionListener.class); + + MMRSearchExtBuilder mmrSearchExtBuilder = new MMRSearchExtBuilder.Builder().vectorFieldPath(vectorFieldName) + .spaceType(SpaceType.L2.getValue()) + .build(); + SearchRequest request = buildSearchRequest(new String[] { indexName }, mmrSearchExtBuilder); + FetchSourceContext fetchSourceContext = new FetchSourceContext(true, new String[] {}, new String[] { vectorFieldName }); + request.source().fetchSource(fetchSourceContext); + + mockClusterIndexMetadata( + Map.of( + indexName, + Map.of("properties", Map.of(vectorFieldName, Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId))) + ) + ); + MMRVectorFieldInfo vectorFieldInfo = new MMRVectorFieldInfo(); + vectorFieldInfo.setVectorDataType(VectorDataType.FLOAT); + vectorFieldInfo.setSpaceType(SpaceType.L2); + mockModelMetadata(mockClient, Map.of(modelId, vectorFieldInfo)); + + processor.processRequestAsync(request, pipelineProcessingContext, listener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchRequest.class); + verify(listener).onResponse(captor.capture()); + SearchRequest searchRequest = captor.getValue(); + assertEquals(30, searchRequest.source().size()); + assertEquals("Fetch source should be set to fetch all fields.", 0, searchRequest.source().fetchSource().excludes().length); + MMRRerankContext mmrRerankContext = (MMRRerankContext) pipelineProcessingContext.getAttribute(MMR_RERANK_CONTEXT); + assertEquals(10, (int) mmrRerankContext.getOriginalQuerySize()); + assertEquals(0.5f, mmrRerankContext.getDiversity(), DELTA); + assertEquals(vectorFieldName, mmrRerankContext.getVectorFieldPath()); + assertEquals(SpaceType.L2, mmrRerankContext.getSpaceType()); + assertEquals(VectorDataType.FLOAT, mmrRerankContext.getVectorDataType()); + assertEquals(fetchSourceContext, mmrRerankContext.getOriginalFetchSourceContext()); + } + + private Map> getMockMMRQueryTransformers() { + MMRQueryTransformer transformer = 
mock(MMRKnnQueryTransformer.class); + // mock a no-op knn query transformer here + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + // Simulate success + listener.onResponse(null); + return null; // void method must return null + }).when(transformer).transform(any(KNNQueryBuilder.class), any(ActionListener.class), any(MMRTransformContext.class)); + return Map.of(KNNQueryBuilder.NAME, transformer); + } + + private SearchRequest buildSearchRequest(String[] indices, MMRSearchExtBuilder mmrSearchExtBuilder) { + KNNQueryBuilder queryBuilder = mock(KNNQueryBuilder.class); + when(queryBuilder.getWriteableName()).thenReturn(KNNQueryBuilder.NAME); + + SearchRequest request = new SearchRequest(); + request.indices(indices); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.ext(List.of(mmrSearchExtBuilder)); + request.source(searchSourceBuilder); + + return request; + } +} diff --git a/src/test/java/org/opensearch/knn/search/processor/mmr/MMRRerankProcessorTests.java b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRRerankProcessorTests.java new file mode 100644 index 0000000000..5ea08b144d --- /dev/null +++ b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRRerankProcessorTests.java @@ -0,0 +1,220 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import 
org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; + +import static org.mockito.Mockito.mock; +import static org.opensearch.knn.common.KNNConstants.MMR_RERANK_CONTEXT; + +public class MMRRerankProcessorTests extends KNNTestCase { + private MMRRerankProcessor processor; + private SearchRequest searchRequest; + + @Override + public void setUp() throws Exception { + super.setUp(); + processor = new MMRRerankProcessor("test-tag", false); + searchRequest = new SearchRequest(); + } + + public void testProcessResponse_withoutContext_thenThrowsUnsupportedOperationException() { + UnsupportedOperationException ex = assertThrows( + UnsupportedOperationException.class, + () -> processor.processResponse(searchRequest, mock(SearchResponse.class)) + ); + + assertEquals("Should not try to use mmr_rerank to process a search response without PipelineProcessingContext.", ex.getMessage()); + } + + public void testProcessResponse_whenEmptyHits_thenReturnOriginalResponse() throws IOException { + SearchResponse emptyResponse = createSearchResponse(new SearchHit[] {}); + + PipelineProcessingContext ctx = mock(PipelineProcessingContext.class); + + SearchResponse result = processor.processResponse(searchRequest, emptyResponse, ctx); + + assertEquals("Processor should return the same response when there are no hits.", emptyResponse, result); + } + + public void testProcessResponse_whenHappyCaseFloatWithL2_thenRerank() throws IOException { + runProcessResponseRerankHappyCase(SpaceType.L2, VectorDataType.FLOAT); + } + + public void testProcessResponse_whenHappyCaseBinaryWithHammingSpaceType_thenRerank() throws IOException { + 
runProcessResponseRerankHappyCase(SpaceType.HAMMING, VectorDataType.BINARY); + } + + private void runProcessResponseRerankHappyCase(SpaceType spaceType, VectorDataType vectorDataType) throws IOException { + SearchResponse searchResponse = createSearchResponse(); + assertEquals(10, searchResponse.getInternalResponse().hits().getHits().length); + assertEquals(0, searchResponse.getInternalResponse().hits().getHits()[0].docId()); + assertNotNull( + "Should have the knn_vector in the source.", + searchResponse.getInternalResponse().hits().getHits()[0].getSourceAsMap().get("knn_vector") + ); + assertEquals(1, searchResponse.getInternalResponse().hits().getHits()[1].docId()); + assertEquals(2, searchResponse.getInternalResponse().hits().getHits()[2].docId()); + + MMRRerankContext mmrRerankContext = new MMRRerankContext(); + mmrRerankContext.setDiversity(0.5f); + mmrRerankContext.setOriginalQuerySize(3); + mmrRerankContext.setSpaceType(spaceType); + mmrRerankContext.setVectorDataType(vectorDataType); + mmrRerankContext.setVectorFieldPath("knn_vector"); + mmrRerankContext.setOriginalFetchSourceContext(new FetchSourceContext(true, new String[] {}, new String[] { "knn_vector" })); + PipelineProcessingContext ctx = new PipelineProcessingContext(); + ctx.setAttribute(MMR_RERANK_CONTEXT, mmrRerankContext); + + SearchResponse result = processor.processResponse(searchRequest, searchResponse, ctx); + + assertEquals("Should reduce the hits to the original query size.", 3, result.getInternalResponse().hits().getHits().length); + assertEquals(0, result.getInternalResponse().hits().getHits()[0].docId()); + assertNull( + "Should exclude the knn_vector from the source.", + result.getInternalResponse().hits().getHits()[0].getSourceAsMap().get("knn_vector") + ); + assertEquals("Should pick the hit with diversity.", 8, result.getInternalResponse().hits().getHits()[1].docId()); + assertEquals("Should pick the hit with diversity.", 9, result.getInternalResponse().hits().getHits()[2].docId()); 
+ } + + public void testProcessResponse_whenMissingRerankContext_thenException() throws IOException { + SearchResponse searchResponse = createSearchResponse(); + + PipelineProcessingContext ctx = new PipelineProcessingContext(); + + IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> processor.processResponse(searchRequest, searchResponse, ctx) + ); + String expectedMessage = "MMR rerank context cannot be null"; + assertEquals(expectedMessage, exception.getMessage()); + } + + public void testProcessResponse_whenMissingSpaceType_thenException() throws IOException { + SearchResponse searchResponse = createSearchResponse(); + + MMRRerankContext mmrRerankContext = new MMRRerankContext(); + PipelineProcessingContext ctx = new PipelineProcessingContext(); + ctx.setAttribute(MMR_RERANK_CONTEXT, mmrRerankContext); + + IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> processor.processResponse(searchRequest, searchResponse, ctx) + ); + String expectedMessage = "Space type in MMR rerank context cannot be null"; + assertEquals(expectedMessage, exception.getMessage()); + } + + public void testProcessResponse_whenMissingOriginalQuerySize_thenException() throws IOException { + SearchResponse searchResponse = createSearchResponse(); + + MMRRerankContext mmrRerankContext = new MMRRerankContext(); + mmrRerankContext.setSpaceType(SpaceType.L2); + PipelineProcessingContext ctx = new PipelineProcessingContext(); + ctx.setAttribute(MMR_RERANK_CONTEXT, mmrRerankContext); + + IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> processor.processResponse(searchRequest, searchResponse, ctx) + ); + String expectedMessage = "Original query size in MMR rerank context cannot be null"; + assertEquals(expectedMessage, exception.getMessage()); + } + + public void testProcessResponse_whenMissingDiversity_thenException() throws IOException { + SearchResponse searchResponse = 
createSearchResponse(); + + MMRRerankContext mmrRerankContext = new MMRRerankContext(); + mmrRerankContext.setSpaceType(SpaceType.L2); + mmrRerankContext.setOriginalQuerySize(3); + PipelineProcessingContext ctx = new PipelineProcessingContext(); + ctx.setAttribute(MMR_RERANK_CONTEXT, mmrRerankContext); + + IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> processor.processResponse(searchRequest, searchResponse, ctx) + ); + String expectedMessage = "Diversity in MMR rerank context cannot be null"; + assertEquals(expectedMessage, exception.getMessage()); + } + + private SearchResponse createSearchResponse() throws IOException { + SearchHit[] hits = new SearchHit[10]; + + // 8 similar hits, high score + float[] similarVector = new float[] { 1f, 1f }; + for (int i = 0; i < 8; i++) { + XContentBuilder sourceBuilder = JsonXContent.contentBuilder().startObject().array("knn_vector", similarVector).endObject(); + + SearchHit hit = new SearchHit(i, String.valueOf(i), Map.of(), Map.of()); + hit.sourceRef(BytesReference.bytes(sourceBuilder)); + hit.score(1f); + hits[i] = hit; + } + + // 2 diverse hits, slightly lower score + float[][] diverseVectors = new float[][] { { 1f, 2f }, { 2f, 1f } }; + for (int i = 0; i < 2; i++) { + int idx = i + 8; + XContentBuilder sourceBuilder = JsonXContent.contentBuilder().startObject().array("knn_vector", diverseVectors[i]).endObject(); + + SearchHit hit = new SearchHit(idx, String.valueOf(idx), Map.of(), Map.of()); + hit.sourceRef(BytesReference.bytes(sourceBuilder)); + hit.score(0.8f); // slightly lower than top similar hits + hits[idx] = hit; + } + return createSearchResponse(hits); + } + + private SearchResponse createSearchResponse(SearchHit... 
hits) { + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + float maxScore = Arrays.stream(hits).map(SearchHit::getScore).max(Float::compare).orElse(Float.NEGATIVE_INFINITY); + + SearchHits searchHits = new SearchHits(hits, totalHits, maxScore); + + SearchResponseSections sections = new SearchResponseSections( + searchHits, + null, // aggregations + null, // suggest + false, // timedOut + false, // terminatedEarly + null, // profileShardResults + 0 // numReducePhases + ); + + return new SearchResponse( + sections, + null, // scrollId + 1, // totalShards + 1, // successfulShards + 0, // skippedShards + 1, // tookInMillis + new ShardSearchFailure[0], + new SearchResponse.Clusters(1, 1, 0), + null // pitId + ); + } +} diff --git a/src/test/java/org/opensearch/knn/search/processor/mmr/MMRTestCase.java b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRTestCase.java new file mode 100644 index 0000000000..824abde5ba --- /dev/null +++ b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRTestCase.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import org.mockito.ArgumentCaptor; +import org.opensearch.action.IndicesRequest; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.util.KNNClusterUtil; +import org.opensearch.knn.indices.Model; +import 
org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.plugin.transport.GetModelAction; +import org.opensearch.knn.plugin.transport.GetModelRequest; +import org.opensearch.knn.plugin.transport.GetModelResponse; +import org.opensearch.transport.client.Client; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.search.pipeline.SearchPipelineService.ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING; + +public class MMRTestCase extends KNNTestCase { + float DELTA = 1e-6F; + protected ClusterSettings clusterSettingsWithSystemGeneratedFactoriesEnabled = new ClusterSettings( + Settings.builder().putList(ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING.getKey(), "*").build(), + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS + ); + + void mockClusterIndexMetadata(final Map> indexToMappingMap) { + ClusterService clusterService = mock(ClusterService.class); + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterService.getClusterSettings()).thenReturn(clusterSettingsWithSystemGeneratedFactoriesEnabled); + + final Set indices = indexToMappingMap.keySet(); + Map indexNameToIndexMap = new HashMap<>(); + for (String indexName : indices) { + indexNameToIndexMap.put(indexName, new Index(indexName, "uuid")); + } + for (Map.Entry> entry : indexToMappingMap.entrySet()) { + final Index index = indexNameToIndexMap.get(entry.getKey()); + final IndexMetadata indexMetadata = mock(IndexMetadata.class); + final MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(metadata.index(index)).thenReturn(indexMetadata); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + 
when(indexMetadata.getIndex()).thenReturn(index); + when(mappingMetadata.sourceAsMap()).thenReturn(entry.getValue()); + } + + IndexNameExpressionResolver resolver = mock(IndexNameExpressionResolver.class); + // simply return the indices of the request + when(resolver.concreteIndices(any(ClusterState.class), any(IndicesRequest.class))).thenAnswer(invocation -> { + IndicesRequest indicesRequest = (IndicesRequest) invocation.getArguments()[1]; + return Arrays.stream(indicesRequest.indices()).map(indexNameToIndexMap::get).toArray(Index[]::new); + }); + + KNNClusterUtil clusterUtil = KNNClusterUtil.instance(); + clusterUtil.initialize(clusterService, resolver); + } + + void mockModelMetadata(Client mockClient, Map modelIdToFieldInfoMap) { + doAnswer(invocation -> { + GetModelRequest request = (GetModelRequest) invocation.getArguments()[1]; + String modelId = request.getModelID(); + ActionListener getModelListener = invocation.getArgument(2); + if (modelIdToFieldInfoMap != null && modelIdToFieldInfoMap.containsKey(modelId)) { + getModelListener.onResponse(createMockGetModelResponse(modelIdToFieldInfoMap.get(modelId))); + } else { + getModelListener.onFailure(new Exception("Model ID " + modelId + " not found")); + } + return null; + }).when(mockClient).execute(eq(GetModelAction.INSTANCE), any(GetModelRequest.class), any(ActionListener.class)); + } + + private GetModelResponse createMockGetModelResponse(MMRVectorFieldInfo mmrVectorFieldInfo) { + GetModelResponse mockResponse = mock(GetModelResponse.class); + Model mockModel = mock(Model.class); + ModelMetadata mockModelMetadata = mock(ModelMetadata.class); + when(mockResponse.getModel()).thenReturn(mockModel); + when(mockModel.getModelMetadata()).thenReturn(mockModelMetadata); + when(mockModelMetadata.getSpaceType()).thenReturn(mmrVectorFieldInfo.getSpaceType()); + when(mockModelMetadata.getVectorDataType()).thenReturn(mmrVectorFieldInfo.getVectorDataType()); + return mockResponse; + } + + void 
verifyException(ActionListener listener, Class expectedType, String expectedMessage) { + + ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(captor.capture()); + + Exception exception = captor.getValue(); + assertTrue("Expected " + expectedType.getSimpleName() + " but got " + exception.getClass(), expectedType.isInstance(exception)); + assertEquals(expectedMessage, exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/knn/search/processor/mmr/MMRUtilTests.java b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRUtilTests.java new file mode 100644 index 0000000000..3b3e788e83 --- /dev/null +++ b/src/test/java/org/opensearch/knn/search/processor/mmr/MMRUtilTests.java @@ -0,0 +1,495 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.search.processor.mmr; + +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.search.extension.MMRSearchExtBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.ProcessorGenerationContext; +import org.opensearch.transport.client.Client; + +import java.util.*; + +import static org.mockito.Mockito.*; +import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.search.processor.mmr.MMRUtil.getMMRFieldMappingByPath; + +public class MMRUtilTests extends MMRTestCase { + private Client mockClient; + private ActionListener listener; + + @Override + public void setUp() throws Exception { + super.setUp(); + mockClient 
= mock(Client.class); + listener = mock(ActionListener.class); + } + + public void testExtractVectorFromHit_whenValidList_thenReturnFloatArray() { + Map source = new HashMap<>(); + source.put("embedding", Arrays.asList(0.1, 0.2, 0.3)); + + float[] result = (float[]) MMRUtil.extractVectorFromHit(source, "embedding", "doc1", true); + + assertArrayEquals(new float[] { 0.1f, 0.2f, 0.3f }, result, 0.0001f); + } + + public void testExtractVectorFromHit_whenInvalidElementType_thenThrow() { + Map source = new HashMap<>(); + source.put("embedding", Arrays.asList(1.0, "bad")); + + IllegalArgumentException ex = assertThrows( + IllegalArgumentException.class, + () -> MMRUtil.extractVectorFromHit(source, "embedding", "doc1", true) + ); + assertTrue(ex.getMessage().contains("unexpected value at the vector field")); + } + + public void testExtractVectorFromHit_whenFieldNotFound_thenThrow() { + Map source = new HashMap<>(); + + IllegalArgumentException ex = assertThrows( + IllegalArgumentException.class, + () -> MMRUtil.extractVectorFromHit(source, "missing", "doc1", true) + ); + assertTrue(ex.getMessage().contains("not found")); + } + + public void testResolveKnnVectorFieldInfo_whenAllUnmappedField_thenDefaultFieldInfo() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(Collections.emptyMap())), + mockClient, + listener + ); + + verifyVectorFieldInfo(listener, new MMRVectorFieldInfo(SpaceType.L2, VectorDataType.DEFAULT)); + } + + public void testResolveKnnVectorFieldInfo_whenAllUnmappedField_thenUserProvidedFieldInfo() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = SpaceType.COSINESIMIL; + VectorDataType userProvidedVectorDataType = VectorDataType.FLOAT; + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + 
userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(Collections.emptyMap())), + mockClient, + listener + ); + + verifyVectorFieldInfo(listener, new MMRVectorFieldInfo(SpaceType.COSINESIMIL, VectorDataType.FLOAT)); + } + + public void testResolveKnnVectorFieldInfo_whenNonKnnField_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + Map mapping = Map.of("properties", Map.of("field", Map.of(TYPE, "keyword"))); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + String expectedError = "MMR query extension cannot support non knn_vector field [index:field]."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenDifferentSpaceTypes_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, TOP_LEVEL_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())) + ); + Map mapping1 = Map.of( + "properties", + Map.of( + "field", + Map.of( + TYPE, + KNNVectorFieldMapper.CONTENT_TYPE, + KNN_METHOD, + Map.of(METHOD_PARAMETER_SPACE_TYPE, SpaceType.COSINESIMIL.getValue()) + ) + ) + ); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping), createMockIndexMetadata(mapping1)), + mockClient, + listener + ); + + String expectedError = + "MMR query extension cannot support different space type [l2, cosinesimil] for the knn_vector field at path field."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void 
testResolveKnnVectorFieldInfo_whenDifferentVectorDataTypes_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + Map mapping = Map.of( + "properties", + Map.of( + "field", + Map.of( + TYPE, + KNNVectorFieldMapper.CONTENT_TYPE, + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue(), + TOP_LEVEL_PARAMETER_SPACE_TYPE, + SpaceType.L2.getValue() + ) + ) + ); + Map mapping1 = Map.of( + "properties", + Map.of( + "field", + Map.of( + TYPE, + KNNVectorFieldMapper.CONTENT_TYPE, + VECTOR_DATA_TYPE_FIELD, + VectorDataType.FLOAT.getValue(), + KNN_METHOD, + Map.of(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + ) + ) + ); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping), createMockIndexMetadata(mapping1)), + mockClient, + listener + ); + + String expectedError = + "MMR query extension cannot support different vector data type [binary, float] for the knn_vector field at path field."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenDifferentUserProvidedSpaceTypes_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = SpaceType.COSINESIMIL; + VectorDataType userProvidedVectorDataType = null; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, TOP_LEVEL_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())) + ); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + String expectedError = + "The space type [cosinesimil] provided in the MMR query extension does not match the space type [l2] in target indices."; + verifyException(listener, IllegalArgumentException.class, 
expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenDifferentUserProvidedVectorDataTypes_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = VectorDataType.FLOAT; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue())) + ); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + String expectedError = + "The vector data type [float] provided in the MMR query extension does not match the vector data type [byte] in target indices."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenMappedFieldNoInfo_thenDefaultFieldInfo() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + Map mapping = Map.of("properties", Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE))); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + verifyVectorFieldInfo(listener, new MMRVectorFieldInfo(SpaceType.L2, VectorDataType.DEFAULT)); + } + + public void testResolveKnnVectorFieldInfo_whenMappedFieldWithModelId_thenFieldInfoFromModel() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + String modelId = "modelId"; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId)) + ); + mockModelMetadata(mockClient, Map.of(modelId, new MMRVectorFieldInfo(SpaceType.HAMMING, VectorDataType.BINARY))); + + 
MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + verifyVectorFieldInfo(listener, new MMRVectorFieldInfo(SpaceType.HAMMING, VectorDataType.BINARY)); + } + + public void testResolveKnnVectorFieldInfo_whenDifferentModelSpaceTypes_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + String modelId1 = "model1"; + String modelId2 = "model2"; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId1)) + ); + Map mapping1 = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId2)) + ); + mockModelMetadata( + mockClient, + Map.of( + modelId1, + new MMRVectorFieldInfo(SpaceType.L2, VectorDataType.FLOAT), + modelId2, + new MMRVectorFieldInfo(SpaceType.COSINESIMIL, VectorDataType.FLOAT) + ) + ); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping), createMockIndexMetadata(mapping1)), + mockClient, + listener + ); + + String expectedError = + "MMR query extension cannot support different space type [l2, cosinesimil] for the knn_vector field at path field."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenDifferentModelVectorDataTypes_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + String modelId1 = "model1"; + String modelId2 = "model2"; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId1)) + ); + Map mapping1 = Map.of( + "properties", + Map.of("field", Map.of(TYPE, 
KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId2)) + ); + mockModelMetadata( + mockClient, + Map.of( + modelId1, + new MMRVectorFieldInfo(SpaceType.L2, VectorDataType.FLOAT), + modelId2, + new MMRVectorFieldInfo(SpaceType.L2, VectorDataType.BINARY) + ) + ); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping), createMockIndexMetadata(mapping1)), + mockClient, + listener + ); + + String expectedError = + "MMR query extension cannot support different vector data type [float, binary] for the knn_vector field at path field."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenDifferentSpaceTypeFromModelAndUser_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = SpaceType.COSINESIMIL; + VectorDataType userProvidedVectorDataType = null; + String modelId1 = "model1"; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId1)) + ); + mockModelMetadata(mockClient, Map.of(modelId1, new MMRVectorFieldInfo(SpaceType.L2, VectorDataType.FLOAT))); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + String expectedError = + "The space type [cosinesimil] provided in the MMR query extension does not match the space type [l2] in target indices."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenDifferentVectorDataTypeFromModelAndUser_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = VectorDataType.BINARY; + String modelId1 = "model1"; + Map mapping = Map.of( + "properties", + 
Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId1)) + ); + mockModelMetadata(mockClient, Map.of(modelId1, new MMRVectorFieldInfo(SpaceType.L2, VectorDataType.FLOAT))); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + String expectedError = + "The vector data type [binary] provided in the MMR query extension does not match the vector data type [float] in target indices."; + verifyException(listener, IllegalArgumentException.class, expectedError); + } + + public void testResolveKnnVectorFieldInfo_whenModelNotFount_thenException() { + String vectorFieldPath = "field"; + SpaceType userProvidedSpaceType = null; + VectorDataType userProvidedVectorDataType = null; + String modelId1 = "model1"; + Map mapping = Map.of( + "properties", + Map.of("field", Map.of(TYPE, KNNVectorFieldMapper.CONTENT_TYPE, MODEL_ID, modelId1)) + ); + mockModelMetadata(mockClient, Collections.emptyMap()); + + MMRUtil.resolveKnnVectorFieldInfo( + vectorFieldPath, + userProvidedSpaceType, + userProvidedVectorDataType, + List.of(createMockIndexMetadata(mapping)), + mockClient, + listener + ); + + String expectedError = + "Failed to retrieve model(s) to resolve the space type and vector data type for the MMR query extension. 
Errors: Model ID model1 not found."; + verifyException(listener, RuntimeException.class, expectedError); + } + + private IndexMetadata createMockIndexMetadata(Map mappings) { + IndexMetadata indexMetadata = mock(IndexMetadata.class); + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index("index", "uuid")); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + when(mappingMetadata.sourceAsMap()).thenReturn(mappings); + return indexMetadata; + } + + private void verifyVectorFieldInfo(ActionListener listener, MMRVectorFieldInfo vectorFieldInfo) { + ArgumentCaptor captor = ArgumentCaptor.forClass(MMRVectorFieldInfo.class); + verify(listener).onResponse(captor.capture()); + SpaceType capturedSpaceType = captor.getValue().getSpaceType(); + VectorDataType capturedVectorDataType = captor.getValue().getVectorDataType(); + assertEquals(vectorFieldInfo.getSpaceType(), capturedSpaceType); + assertEquals(vectorFieldInfo.getVectorDataType(), capturedVectorDataType); + } + + public void testShouldGenerateMMRProcessor_whenExtContainsBuilder_thenReturnTrue() { + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.ext(Collections.singletonList(new MMRSearchExtBuilder.Builder().build())); + searchRequest.source(searchSourceBuilder); + + ProcessorGenerationContext ctx = new ProcessorGenerationContext(searchRequest); + + assertTrue(MMRUtil.shouldGenerateMMRProcessor(ctx)); + } + + public void testShouldGenerateMMRProcessor_whenNoExt_thenReturnFalse() { + SearchRequest searchRequest = new SearchRequest(); + + ProcessorGenerationContext ctx = new ProcessorGenerationContext(searchRequest); + + assertFalse(MMRUtil.shouldGenerateMMRProcessor(ctx)); + } + + public void testGetMMRFieldMappingByPath_whenInNestedField_thenException() { + Map mappings = new HashMap<>(); + Map userMapping = new HashMap<>(); + userMapping.put("type", 
"nested"); + + Map properties = new HashMap<>(); + properties.put("user", userMapping); + mappings.put("properties", properties); + + String fieldPath = "user.profile.age"; + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> getMMRFieldMappingByPath(mappings, fieldPath)); + + String expectedError = "MMR search extension cannot support the field user.profile.age because it is in the nested field user."; + assertEquals(expectedError, ex.getMessage()); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index abc7cde648..8ffa0f3488 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -413,6 +413,7 @@ protected List parseSearchResponseHits(String responseBody) throws IOExc /** * Parse the response of KNN search into a List of KNNResults */ + @SuppressWarnings("unchecked") protected List parseSearchResponse(String responseBody, String fieldName) throws IOException { @SuppressWarnings("unchecked") List hits = (List) ((Map) createParser( @@ -420,22 +421,26 @@ protected List parseSearchResponse(String responseBody, String fieldN responseBody ).map().get("hits")).get("hits"); - @SuppressWarnings("unchecked") - List knnSearchResponses = hits.stream().map(hit -> { - @SuppressWarnings("unchecked") - final float[] vector = Floats.toArray( - Arrays.stream( - ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() - ).map(Object::toString).map(Float::valueOf).collect(Collectors.toList()) - ); + return hits.stream().map(hit -> { + Object sourceObj = ((Map) hit).get("_source"); + float[] vector = null; + if (sourceObj != null) { + Object vectorObj = ((Map) ((Map) hit).get("_source")).get(fieldName); + vector = vectorObj == null + ? 
null + : Floats.toArray( + Arrays.stream(((ArrayList) vectorObj).toArray()) + .map(Object::toString) + .map(Float::valueOf) + .collect(Collectors.toList()) + ); + } return new KNNResult( (String) ((Map) hit).get("_id"), vector, ((Double) ((Map) hit).get("_score")).floatValue() ); }).collect(Collectors.toList()); - - return knnSearchResponses; } protected List parseSearchResponseScore(String responseBody, String fieldName) throws IOException { @@ -1076,12 +1081,9 @@ protected Map getKnnDoc(final String index, final String docId) * Utility to update settings */ protected void updateClusterSettings(String settingKey, Object value) throws Exception { - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("persistent") - .field(settingKey, value) - .endObject() - .endObject(); + XContentBuilder builder = value instanceof List + ? XContentFactory.jsonBuilder().startObject().startObject("persistent").array(settingKey, value).endObject().endObject() + : XContentFactory.jsonBuilder().startObject().startObject("persistent").field(settingKey, value).endObject().endObject(); Request request = new Request("PUT", "_cluster/settings"); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request);