diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index c33f3ea63c..cc694f6c2e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -5,12 +5,14 @@ package org.opensearch.knn.index; +import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchParseException; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; +import org.opensearch.knn.index.engine.MemoryOptimizedSearchSupportSpec; import org.opensearch.transport.client.Client; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; @@ -96,6 +98,11 @@ public class KNNSettings { public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled"; public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo"; + /** + * For more details on supported engines, refer to {@link MemoryOptimizedSearchSupportSpec} + */ + public static final String MEMORY_OPTIMIZED_KNN_SEARCH_MODE = "index.knn.memory_optimized_search"; + public static final boolean DEFAULT_MEMORY_OPTIMIZED_KNN_SEARCH_MODE = false; /** * Default setting values @@ -121,9 +128,9 @@ public class KNNSettings { public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = -1; public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 5; // By default, set aside 5% of the JVM for - // the limit + // the limit public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot 
exceed - // 10% of the JVM heap + // 10% of the JVM heap public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = false; @@ -285,6 +292,12 @@ public class KNNSettings { */ public static final Setting IS_KNN_INDEX_SETTING = Setting.boolSetting(KNN_INDEX, false, IndexScope, Final); + public static final Setting MEMORY_OPTIMIZED_KNN_SEARCH_MODE_SETTING = Setting.boolSetting( + MEMORY_OPTIMIZED_KNN_SEARCH_MODE, + false, + IndexScope + ); + /** * index_thread_quantity - the parameter specifies how many threads the nms library should use to create the graph. * By default, the nms library sets this value to NUM_CORES. However, because ES can spawn NUM_CORES threads for @@ -577,7 +590,8 @@ public List> getSettings() { KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, KNN_DERIVED_SOURCE_ENABLED_SETTING, KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING, - KNN_REMOTE_VECTOR_REPO_SETTING + KNN_REMOTE_VECTOR_REPO_SETTING, + MEMORY_OPTIMIZED_KNN_SEARCH_MODE_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -734,12 +748,25 @@ public static int getEfSearchParam(String index) { ); } + /** + * Return whether memory optimized search enabled for the given index. + * + * @param indexName The name of target index to test whether if it is on. + * @return True if memory optimized search is enabled, otherwise False. 
+ */ + public static boolean isMemoryOptimizedKnnSearchModeEnabled(@NonNull final String indexName) { + return KNNSettings.state().clusterService.state() + .getMetadata() + .index(indexName) + .getSettings() + .getAsBoolean(MEMORY_OPTIMIZED_KNN_SEARCH_MODE, DEFAULT_MEMORY_OPTIMIZED_KNN_SEARCH_MODE); + } + public void setClusterService(ClusterService clusterService) { this.clusterService = clusterService; } static class SpaceTypeValidator implements Setting.Validator { - @Override public void validate(String value) { try { diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 07a5c34e4b..813e959183 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -155,11 +155,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { // mapperService is already checked for null or valid instance type at caller, hence we don't need // addition isPresent check here. 
- int approximateThreshold = getApproximateThresholdValue(); + final int approximateThreshold = getApproximateThresholdValue(); return new NativeEngines990KnnVectorsFormat( new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()), approximateThreshold, - nativeIndexBuildStrategyFactory + nativeIndexBuildStrategyFactory, + mapperService ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index f4e76f3caa..6cc7979ba7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -11,6 +11,7 @@ package org.opensearch.knn.index.codec.KNN990Codec; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -19,22 +20,31 @@ import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory; import org.opensearch.knn.index.engine.KNNEngine; import java.io.IOException; +import java.util.Optional; + +import static org.opensearch.knn.index.KNNSettings.DEFAULT_MEMORY_OPTIMIZED_KNN_SEARCH_MODE; +import static org.opensearch.knn.index.KNNSettings.MEMORY_OPTIMIZED_KNN_SEARCH_MODE; /** * This is a Vector format that will be used for Native engines like Faiss and Nmslib for reading and writing vector * related data structures. 
*/ +@Log4j2 public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, merging vectors on disk */ private static FlatVectorsFormat flatVectorsFormat; private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; private static int approximateThreshold; private final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory; + private final boolean memoryOptimizedSearchEnabled; public NativeEngines990KnnVectorsFormat() { this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer())); @@ -49,18 +59,20 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma } public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold) { - this(flatVectorsFormat, approximateThreshold, new NativeIndexBuildStrategyFactory()); + this(flatVectorsFormat, approximateThreshold, new NativeIndexBuildStrategyFactory(), Optional.empty()); } public NativeEngines990KnnVectorsFormat( final FlatVectorsFormat flatVectorsFormat, int approximateThreshold, - final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory + final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory, + final Optional mapperService ) { super(FORMAT_NAME); NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat; NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold; this.nativeIndexBuildStrategyFactory = nativeIndexBuildStrategyFactory; + this.memoryOptimizedSearchEnabled = isMemoryOptimizedSearchSupported(mapperService); } /** @@ -85,7 +97,7 @@ public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOExc */ @Override public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOException { - return new NativeEngines990KnnVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + return new NativeEngines990KnnVectorsReader(state, flatVectorsFormat.fieldsReader(state), 
memoryOptimizedSearchEnabled); } /** @@ -107,4 +119,21 @@ public String toString() { + approximateThreshold + ")"; } + + private static boolean isMemoryOptimizedSearchSupported(final Optional mapperService) { + if (mapperService.isPresent()) { + final IndexSettings indexSettings = mapperService.get().getIndexSettings(); + if (indexSettings != null) { + final Settings settings = indexSettings.getSettings(); + if (settings != null) { + try { + return settings.getAsBoolean(MEMORY_OPTIMIZED_KNN_SEARCH_MODE, DEFAULT_MEMORY_OPTIMIZED_KNN_SEARCH_MODE); + } catch (Throwable th) { + log.error("Failed to get a bool flag of [{}] from settings.", MEMORY_OPTIMIZED_KNN_SEARCH_MODE); + } + } + } + } + return false; + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index bda235e6fb..1b0e2a8397 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -27,6 +27,7 @@ import org.apache.lucene.util.IOSupplier; import org.apache.lucene.util.IOUtils; import org.opensearch.common.UUIDs; +import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper; import org.opensearch.knn.index.engine.KNNEngine; @@ -45,7 +46,6 @@ import java.util.List; import java.util.Map; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapper.KNN_FIELD; /** @@ -59,85 +59,27 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private final FlatVectorsReader flatVectorsReader; private Map quantizationStateCacheKeyPerField; - private SegmentReadState segmentReadState; + private final 
SegmentReadState segmentReadState; private final List cacheKeys; private volatile Map vectorSearchers; public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) { + this(state, flatVectorsReader, false); + } + + public NativeEngines990KnnVectorsReader( + final SegmentReadState state, + final FlatVectorsReader flatVectorsReader, + final boolean memoryOptimizedSearchEnabled + ) { this.flatVectorsReader = flatVectorsReader; this.segmentReadState = state; this.cacheKeys = getVectorCacheKeysFromSegmentReaderState(state); - this.vectorSearchers = null; loadCacheKeyMap(); - // Memory optimized searcher will be ONLY used for searching, we don't need this during merging. - // TODO(KDY) : Enable this based on setting flag value. e.g. index.knn.memory_optimized_search = True - if (state.context.context() != IOContext.Context.MERGE) { - loadMemoryOptimizedSearcher(); - } - } - - private IOSupplier getIndexFileNameIfMemoryOptimizedSearchSupported(final FieldInfo fieldInfo) { - // Skip non-knn fields. 
- final Map attributes = fieldInfo.attributes(); - if (attributes == null || attributes.containsKey(KNN_FIELD) == false) { - return null; - } - - // Get engine - final String engineName = attributes.getOrDefault(KNN_ENGINE, KNNEngine.DEFAULT.getName()); - final KNNEngine knnEngine = KNNEngine.getEngine(engineName); - - // Get memory optimized searcher from engine - final VectorSearcherFactory searcherFactory = knnEngine.getVectorSearcherFactory(); - if (searcherFactory == null) { - // It's not supported - return null; - } - - // Start creating searcher - final String fileName = KNNCodecUtil.getNativeEngineFileFromFieldInfo(fieldInfo, segmentReadState.segmentInfo); - if (fileName != null) { - return () -> searcherFactory.createVectorSearcher(segmentReadState.directory, fileName); - } - - // Not supported - return null; - } - - private synchronized void loadMemoryOptimizedSearcher() { - if (vectorSearchers != null) { - return; - } - - final Map vectorSearcherPerField = new HashMap<>( - RESERVE_TWICE_SPACE * segmentReadState.fieldInfos.size(), - SUFFICIENT_LOAD_FACTOR - ); - - try { - for (FieldInfo fieldInfo : segmentReadState.fieldInfos) { - final IOSupplier searcherSupplier = getIndexFileNameIfMemoryOptimizedSearchSupported(fieldInfo); - if (searcherSupplier != null) { - final VectorSearcher searcher = searcherSupplier.get(); - if (searcher != null) { - // It's supported. There can be a case where a certain index type underlying is not yet supported while KNNEngine - // itself supports memory optimized searching. 
- vectorSearcherPerField.put(fieldInfo.getName(), searcher); - } - } - } - - vectorSearchers = vectorSearcherPerField; - } catch (Exception e) { - // Close opened searchers first, then suppress - try { - IOUtils.closeWhileHandlingException(vectorSearcherPerField.values()); - } catch (Exception closeException) { - log.error(closeException.getMessage(), closeException); - } - throw new RuntimeException(e); + if (memoryOptimizedSearchEnabled && state.context.context() != IOContext.Context.MERGE) { + loadMemoryOptimizedSearcherIfRequired(); } } @@ -308,12 +250,7 @@ private boolean trySearchWithMemoryOptimizedSearch( Bits acceptDocs, boolean isFloatVector ) throws IOException { - if (vectorSearchers == null) { - // Null `vectorSearchers` indicates that this reader was instantiated during merge. - // In this case, we load the searcher on demand for searching. - // This will not likely happen, by the time searching, this old segment will be merged away anyway. - loadMemoryOptimizedSearcher(); - } + loadMemoryOptimizedSearcherIfRequired(); // Try with memory optimized searcher final VectorSearcher memoryOptimizedSearcher = vectorSearchers.get(field); @@ -350,4 +287,80 @@ private static List getVectorCacheKeysFromSegmentReaderState(SegmentRead return cacheKeys; } + + private void loadMemoryOptimizedSearcherIfRequired() { + if (vectorSearchers != null) { + return; + } + + synchronized (this) { + if (vectorSearchers != null) { + return; + } + + // We need sufficient memory space for this table as it will be queried for every single search. + // Hence, having larger space to approximate a perfect hash here. 
+ final Map vectorSearcherPerField = new HashMap<>( + RESERVE_TWICE_SPACE * segmentReadState.fieldInfos.size(), + SUFFICIENT_LOAD_FACTOR + ); + + try { + for (FieldInfo fieldInfo : segmentReadState.fieldInfos) { + final IOSupplier searcherSupplier = getVectorSearcherSupplier(fieldInfo); + if (searcherSupplier != null) { + final VectorSearcher searcher = searcherSupplier.get(); + if (searcher != null) { + // It's supported. There can be a case where a certain index type underlying is not yet supported while + // KNNEngine + // itself supports memory optimized searching. + vectorSearcherPerField.put(fieldInfo.getName(), searcher); + } + } + } + + vectorSearchers = vectorSearcherPerField; + } catch (Exception e) { + // Close opened searchers first, then suppress + try { + IOUtils.closeWhileHandlingException(vectorSearcherPerField.values()); + } catch (Exception closeException) { + log.error(closeException.getMessage(), closeException); + } + throw new RuntimeException(e); + } + } + } + + private IOSupplier getVectorSearcherSupplier(final FieldInfo fieldInfo) { + // Skip non-knn fields. + final Map attributes = fieldInfo.attributes(); + if (attributes == null || attributes.containsKey(KNN_FIELD) == false) { + return null; + } + + // Try to get KNN engine from fieldInfo. 
+ final KNNEngine knnEngine = FieldInfoExtractor.extractKNNEngine(fieldInfo); + + // No KNNEngine is available + if (knnEngine == null) { + return null; + } + + // Get memory optimized searcher from engine + final VectorSearcherFactory searcherFactory = knnEngine.getVectorSearcherFactory(); + if (searcherFactory == null) { + // It's not supported + return null; + } + + // Start creating searcher + final String fileName = KNNCodecUtil.getNativeEngineFileFromFieldInfo(fieldInfo, segmentReadState.segmentInfo); + if (fileName != null) { + return () -> searcherFactory.createVectorSearcher(segmentReadState.directory, fileName); + } + + // Not supported + return null; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/MemoryOptimizedSearchSupportSpec.java b/src/main/java/org/opensearch/knn/index/engine/MemoryOptimizedSearchSupportSpec.java new file mode 100644 index 0000000000..fa542d73e5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/MemoryOptimizedSearchSupportSpec.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.memoryoptsearch.VectorSearcher; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +/** + * This class encapsulates a determination logic for memory optimized search. 
+ * Memory-optimized-search may not be applied to a certain type of index even {@link KNNEngine} returns a non-null + * {@link org.opensearch.knn.memoryoptsearch.VectorSearcherFactory}. + * The overall logic will be made based on the given method context and quantization configuration. + */ +public class MemoryOptimizedSearchSupportSpec { + private static final Set SUPPORTED_HNSW_ENCODING = Set.of(ENCODER_FLAT, ENCODER_SQ); + + /** + * Determine whether if a KNN field supports memory-optimized-search. + * If it is supported, then the field can perform memory-optimized search via {@link VectorSearcher}. + * Which can be obtained from a factory acquired from {@link KNNEngine#getVectorSearcherFactory()}. + * + * @param methodContextOpt Optional method context. + * @param quantizationConfig Quantization configuration. + * @param vectorDataType Vector data type. + * @return True if memory-optimized-search is supported, otherwise false. + */ + public static boolean supported( + final Optional methodContextOpt, + final QuantizationConfig quantizationConfig, + final VectorDataType vectorDataType + ) { + if (methodContextOpt.isPresent()) { + final KNNMethodContext methodContext = methodContextOpt.get(); + final KNNEngine engine = methodContext.getKnnEngine(); + + // We support Lucene engine + if (engine == KNNEngine.LUCENE) { + return true; + } + + // We don't support non-FAISS engine + if (engine != KNNEngine.FAISS) { + return false; + } + + // We only support HNSW method. + final MethodComponentContext methodComponentContext = methodContext.getMethodComponentContext(); + if (methodComponentContext.getName().equals(METHOD_HNSW) == false) { + return false; + } + + // We don't support quantization yet. + if (quantizationConfig != null && quantizationConfig.getQuantizationType() != null) { + return false; + } + + // Only support FLOAT/BYTE index. 
+ if (vectorDataType != VectorDataType.FLOAT && vectorDataType != VectorDataType.BYTE) { + return false; + } + + // L2 or Inner product are supported. + if (methodContext.getSpaceType() != SpaceType.L2 && methodContext.getSpaceType() != SpaceType.INNER_PRODUCT) { + return false; + } + + // We only support Flat and SQ encoder for HNSW. + final Map parameters = methodComponentContext.getParameters(); + final Object methodComponentContextObj = parameters.get(METHOD_ENCODER_PARAMETER); + if ((methodComponentContextObj instanceof MethodComponentContext) == false) { + return false; + } + + if (SUPPORTED_HNSW_ENCODING.contains(((MethodComponentContext) methodComponentContextObj).getName()) == false) { + return false; + } + + return true; + } + + return false; + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index b0bead693c..653933a670 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -24,6 +24,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.index.engine.MemoryOptimizedSearchSupportSpec; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; @@ -44,6 +45,7 @@ public class KNNVectorFieldType extends MappedFieldType { private static final Logger logger = LogManager.getLogger(KNNVectorFieldType.class); KNNMappingConfig knnMappingConfig; VectorDataType vectorDataType; + boolean memoryOptimizedSearchSupported; /** * Constructor for KNNVectorFieldType. 
@@ -57,6 +59,11 @@ public KNNVectorFieldType(String name, Map metadata, VectorDataT super(name, false, false, true, TextSearchInfo.NONE, metadata); this.vectorDataType = vectorDataType; this.knnMappingConfig = annConfig; + this.memoryOptimizedSearchSupported = MemoryOptimizedSearchSupportSpec.supported( + knnMappingConfig.getKnnMethodContext(), + knnMappingConfig.getQuantizationConfig(), + vectorDataType + ); } @Override @@ -151,6 +158,5 @@ public void transformQueryVector(float[] vector) { return; } throw new IllegalStateException("Either KNN method context or Model Id should be configured"); - } } diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index 6c5eea08fc..4db11996c6 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -51,6 +51,7 @@ public static class CreateQueryRequest { private QueryShardContext context; private RescoreContext rescoreContext; private Boolean expandNested; + private boolean memoryOptimizedSearchSupported; public Optional getFilter() { return Optional.ofNullable(filter); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index f032210aab..1c3be50b16 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -23,22 +23,23 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.WithFieldName; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.VectorQueryType; +import org.opensearch.knn.index.engine.KNNEngine; +import 
org.opensearch.knn.index.engine.KNNLibrarySearchContext; import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.model.QueryContext; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; import org.opensearch.knn.index.query.parser.RescoreParser; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.engine.MethodComponentContext; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.VectorQueryType; -import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; -import org.opensearch.knn.index.engine.KNNLibrarySearchContext; -import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -56,9 +57,9 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; import static org.opensearch.knn.index.engine.validation.ParameterValidator.validateParameters; +import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_OVERSAMPLE_PARAMETER; import static 
org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; @@ -404,6 +405,9 @@ protected Query doToQuery(QueryShardContext context) { knnVectorFieldType.transformQueryVector(vector); VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); + final String indexName = context.index().getName(); + final boolean memoryOptimizedSearchSupported = knnVectorFieldType.isMemoryOptimizedSearchSupported() + && KNNSettings.isMemoryOptimizedKnnSearchModeEnabled(indexName); updateQueryStats(vectorQueryType); // This could be null in the case of when a model did not have serialized methodComponent information @@ -488,7 +492,7 @@ protected Query doToQuery(QueryShardContext context) { spaceType.validateVector(byteVector); break; case BYTE: - if (KNNEngine.LUCENE == knnEngine) { + if (isUsingLuceneQuery(knnEngine, memoryOptimizedSearchSupported)) { byteVector = new byte[vector.length]; for (int i = 0; i < vector.length; i++) { validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType()); @@ -512,15 +516,13 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine)); } - String indexName = context.index().getName(); - if (k != 0) { KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) - .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) + .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector, memoryOptimizedSearchSupported)) .vectorDataType(vectorDataType) .k(this.k) .methodParameters(this.methodParameters) @@ -528,6 +530,7 @@ protected Query doToQuery(QueryShardContext context) { .context(context) .rescoreContext(processedRescoreContext) .expandNested(expandNested) + 
.memoryOptimizedSearchSupported(memoryOptimizedSearchSupported) .build(); return KNNQueryFactory.create(createQueryRequest); } @@ -543,6 +546,7 @@ protected Query doToQuery(QueryShardContext context) { .methodParameters(this.methodParameters) .filter(this.filter) .context(context) + .memoryOptimizedSearchSupported(memoryOptimizedSearchSupported) .build(); return RNNQueryFactory.create(createQueryRequest); } @@ -574,6 +578,19 @@ private QueryConfigFromMapping getQueryConfig(final KNNMappingConfig knnMappingC throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName)); } + /** + * Determine whether the query will be using Lucene query to perform vector search. + * Currently, if memory optimized search is enabled, it fallbacks to Lucene and delegate its HNSW graph searcher to perform ANN search + * on FAISS index. Hence, if it is true, then we need to use Lucene query. + * + * @param engine Engine type + * @param memoryOptimizedSearchSupported A bool flag whether memory optimized search is enabled. + * @return True when it should use Lucene query False otherwise. 
+ */ + private static boolean isUsingLuceneQuery(final KNNEngine engine, final boolean memoryOptimizedSearchSupported) { + return memoryOptimizedSearchSupported || engine == KNNEngine.LUCENE; + } + private ModelMetadata getModelMetadataForField(String modelId) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { @@ -621,8 +638,15 @@ private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, return null; } - private byte[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, byte[] byteVector) { - if (VectorDataType.BINARY == vectorDataType || (VectorDataType.BYTE == vectorDataType && KNNEngine.LUCENE == knnEngine)) { + private byte[] getVectorForCreatingQueryRequest( + VectorDataType vectorDataType, + KNNEngine knnEngine, + byte[] byteVector, + boolean memoryOptimizedSearchSupported + ) { + + if (VectorDataType.BINARY == vectorDataType + || (VectorDataType.BYTE == vectorDataType && isUsingLuceneQuery(knnEngine, memoryOptimizedSearchSupported))) { return byteVector; } return null; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index b6770553b4..de426aaf58 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -49,8 +49,9 @@ public static Query create(CreateQueryRequest createQueryRequest) { final Query filterQuery = getFilterQuery(createQueryRequest); final Map methodParameters = createQueryRequest.getMethodParameters(); final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null); - final KNNEngine knnEngine = createQueryRequest.getKnnEngine(); final boolean expandNested = createQueryRequest.getExpandNested().orElse(false); + final boolean memoryOptimizedSearchSupported = createQueryRequest.isMemoryOptimizedSearchSupported(); + 
BitSetProducer parentFilter = null; int shardId = -1; if (createQueryRequest.getContext().isPresent()) { @@ -70,7 +71,8 @@ public static Query create(CreateQueryRequest createQueryRequest) { ); } - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + if (memoryOptimizedSearchSupported == false + && KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine()); log.debug( diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReaderTests.java index f6268af303..11b808aaf0 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReaderTests.java @@ -42,7 +42,7 @@ public void testWhenMemoryOptimizedSearchIsEnabled_emptyCase() { final FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); // Load vector searchers - final Map vectorSearchers = loadSearchers(fieldInfos, Collections.emptySet()); + final Map vectorSearchers = loadSearchers(fieldInfos, Collections.emptySet(), true); assertTrue(vectorSearchers.isEmpty()); } @@ -87,13 +87,72 @@ public void testWhenMemoryOptimizedSearchIsEnabled_mixedCase() { mockedStatic.when(KNNEngine::getEnginesThatCreateCustomSegmentFiles).thenReturn(ImmutableSet.of(mockFaiss)); // Load vector searchers - final Map vectorSearchers = loadSearchers(fieldInfos, filesInSegment); + final Map vectorSearchers = loadSearchers(fieldInfos, filesInSegment, true); // Validate #searchers assertEquals(2, vectorSearchers.size()); } } + @SneakyThrows + public void testWhenMemoryOptimizedSearchIsNotEnabled() { + // Prepare field infos + final FieldInfo[] fieldInfoArray = new 
FieldInfo[] {}; + final FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + + // Load vector searchers + final Map vectorSearchers = loadSearchers(fieldInfos, Collections.emptySet(), false); + assertNull(vectorSearchers); + } + + @SneakyThrows + public void testWhenMemoryOptimizedSearchIsNotEnabled_mixedCase() { + // Prepare field infos + // - field1: Non KNN field + // - field2: KNN field, but using Lucene engine + // - field3: KNN field, FAISS + // - field4: KNN field, FAISS + // - field5: KNN field, FAISS, but it does not have file for some reason. + + // Mocking FAISS engine to return a dummy vector searcher + KNNEngine mockFaiss = spy(KNNEngine.FAISS); + VectorSearcherFactory mockFactory = mock(VectorSearcherFactory.class); + when(mockFactory.createVectorSearcher(any(), any())).thenReturn(mock(VectorSearcher.class)); + when(mockFaiss.getVectorSearcherFactory()).thenReturn(mockFactory); + + try (MockedStatic mockedStatic = mockStatic(KNNEngine.class)) { + // Prepare field infos + final FieldInfo[] fieldInfoArray = new FieldInfo[] { + createFieldInfo("field1", null, 0), + createFieldInfo("field2", KNNEngine.LUCENE, 1), + createFieldInfo("field3", mockFaiss, 2), + createFieldInfo("field4", mockFaiss, 3), + createFieldInfo("field5", mockFaiss, 4) }; + final FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + final Set filesInSegment = Set.of("_0_165_field3.faiss", "_0_165_field4.faiss"); + + // Replace static 'getEngine' to return mockFaiss + mockedStatic.when(() -> KNNEngine.getEngine(any())).thenAnswer(invocation -> { + final String strArg = invocation.getArgument(0); + // Intercept FAISS engine to return mock + if (strArg.equals(KNNEngine.FAISS.getName())) { + return mockFaiss; + } + + // Otherwise return Lucene, as field2 is using Lucene. 
+ return KNNEngine.LUCENE; + }); + + mockedStatic.when(KNNEngine::getEnginesThatCreateCustomSegmentFiles).thenReturn(ImmutableSet.of(mockFaiss)); + + // Load vector searchers + final Map vectorSearchers = loadSearchers(fieldInfos, filesInSegment, false); + + // The table should be null even though we had FAISS fields. + assertNull(vectorSearchers); + } + } + private static FieldInfo createFieldInfo(final String fieldName, final KNNEngine engine, final int fieldNo) { final KNNCodecTestUtil.FieldInfoBuilder builder = KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName); builder.fieldNumber(fieldNo); @@ -105,7 +164,11 @@ private static FieldInfo createFieldInfo(final String fieldName, final KNNEngine } @SneakyThrows - private static Map loadSearchers(final FieldInfos fieldInfos, final Set filesInSegment) { + private static Map loadSearchers( + final FieldInfos fieldInfos, + final Set filesInSegment, + final boolean memoryOptimizedSearchEnabled + ) { // Prepare infra final IndexInput mockIndexInput = mock(IndexInput.class); final Directory mockDirectory = mock(Directory.class); @@ -116,7 +179,7 @@ private static Map loadSearchers(final FieldInfos fieldI final SegmentReadState readState = new SegmentReadState(mockDirectory, segmentInfo, fieldInfos, IOContext.DEFAULT); // Create reader - final NativeEngines990KnnVectorsReader reader = new NativeEngines990KnnVectorsReader(readState, null); + final NativeEngines990KnnVectorsReader reader = new NativeEngines990KnnVectorsReader(readState, null, memoryOptimizedSearchEnabled); final Class clazz = NativeEngines990KnnVectorsReader.class; // Get searcher table diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemoryOptimizedSearchFlagInFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemoryOptimizedSearchFlagInFormatTests.java new file mode 100644 index 0000000000..3b870e2c63 --- /dev/null +++ 
b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemoryOptimizedSearchFlagInFormatTests.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.SneakyThrows; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; + +import java.lang.reflect.Field; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.DEFAULT_MEMORY_OPTIMIZED_KNN_SEARCH_MODE; +import static org.opensearch.knn.index.KNNSettings.MEMORY_OPTIMIZED_KNN_SEARCH_MODE; + +public class MemoryOptimizedSearchFlagInFormatTests extends KNNTestCase { + private static final int DONT_CARE = -1; + + @SneakyThrows + public void testWhenNullSettings() { + // Create format with null index settings + final IndexSettings indexSettings = mock(IndexSettings.class); + when(indexSettings.getSettings()).thenReturn(null); + + // Mock MapperService + final MapperService mapperService = mock(MapperService.class); + when(mapperService.getIndexSettings()).thenReturn(indexSettings); + + final NativeEngines990KnnVectorsFormat format = new NativeEngines990KnnVectorsFormat( + null, // Don't care + -1, // Don't care + null, // Don't care + Optional.of(mapperService) + ); + + doTest(format, false); + } + + @SneakyThrows + public void testWhenSettingsDontHaveTheFlag() { + // Mock settings + final Settings settings = mock(Settings.class); + when(settings.getAsBoolean(any(), any())).thenReturn(false); + + // Mock index settings + final IndexSettings indexSettings = mock(IndexSettings.class); + 
when(indexSettings.getSettings()).thenReturn(settings); + + // Create format without a MapperService (the mocked settings above are never passed to it) + final NativeEngines990KnnVectorsFormat format = new NativeEngines990KnnVectorsFormat( + null, // Don't care + -1, // Don't care + null, // Don't care + Optional.empty() + ); + + doTest(format, false); + } + + @SneakyThrows + public void testWhenSettingsHaveTheFlag() { + // Mock settings + final Settings settings = mock(Settings.class); + when(settings.getAsBoolean(MEMORY_OPTIMIZED_KNN_SEARCH_MODE, DEFAULT_MEMORY_OPTIMIZED_KNN_SEARCH_MODE)).thenReturn(true); + + // Mock index settings + final IndexSettings indexSettings = mock(IndexSettings.class); + when(indexSettings.getSettings()).thenReturn(settings); + + // Mock MapperService + final MapperService mapperService = mock(MapperService.class); + when(mapperService.getIndexSettings()).thenReturn(indexSettings); + + // Create format with a MapperService whose index settings enable the flag + final NativeEngines990KnnVectorsFormat format = new NativeEngines990KnnVectorsFormat( + null, // Don't care + -1, // Don't care + null, // Don't care + Optional.of(mapperService) + ); + + doTest(format, true); + } + + @SneakyThrows + private void doTest(final NativeEngines990KnnVectorsFormat format, final boolean expected) { + // Get field + final Field field = NativeEngines990KnnVectorsFormat.class.getDeclaredField("memoryOptimizedSearchEnabled"); + field.setAccessible(true); // Bypass private access + + // Test whether it is supported + final boolean result = (boolean) field.get(format); + assertEquals(expected, result); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 1333d616e6..43eeb9e6d3 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -8,9 
+8,12 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; 
import org.apache.lucene.search.FloatVectorSimilarityQuery; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.junit.Before; +import org.mockito.MockedStatic; import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; import org.opensearch.cluster.service.ClusterService; @@ -57,8 +60,10 @@ import static java.util.Collections.emptyMap; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; @@ -609,6 +614,7 @@ public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParamete QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + when(mockQueryShardContext.index()).thenReturn(new Index("dummy", "dummy")); KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.LUCENE, SpaceType.COSINESIMIL, @@ -645,6 +651,78 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + public void testDoToQuery_whenMemoryOptimizedSearchIsEnabled() { + do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(true, true, VectorDataType.FLOAT); + do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(true, true, VectorDataType.BYTE); + do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(true, false, VectorDataType.FLOAT); + 
do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(true, false, VectorDataType.BYTE); + + do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(false, true, VectorDataType.FLOAT); + do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(false, true, VectorDataType.BYTE); + do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(false, false, VectorDataType.FLOAT); + do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(false, false, VectorDataType.BYTE); + } + + private void do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled( + boolean memoryOptimizedSearchEnabled, + boolean memoryOptimizedSearchSupportedInField, + VectorDataType vectorDataType + ) { + + try (MockedStatic knnSettingsMockedStatic = mockStatic(KNNSettings.class)) { + // Index setting mocking + knnSettingsMockedStatic.when(() -> KNNSettings.isMemoryOptimizedKnnSearchModeEnabled(any())) + .thenReturn(memoryOptimizedSearchEnabled); + + // Query vector + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + + // Query shard context + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + // Field type + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + when(mockKNNVectorField.isMemoryOptimizedSearchSupported()).thenReturn(memoryOptimizedSearchSupportedInField); + when(mockKNNVectorField.getVectorDataType()).thenReturn(vectorDataType); + + // Method context + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + + // KNN mapping config + 
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); + + // Execute `doToQuery` + final Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + final boolean isEnabled = memoryOptimizedSearchEnabled && memoryOptimizedSearchSupportedInField; + if (isEnabled) { + // If memory optimized search is on then, use Lucene query + assertFalse(query instanceof NativeEngineKnnVectorQuery); + assertTrue(query instanceof LuceneEngineKnnVectorQuery); + final LuceneEngineKnnVectorQuery luceneQuery = (LuceneEngineKnnVectorQuery) query; + + if (vectorDataType == VectorDataType.FLOAT) { + assert (luceneQuery.getLuceneQuery() instanceof KnnFloatVectorQuery); + assertEquals(queryVector.length, ((KnnFloatVectorQuery) luceneQuery.getLuceneQuery()).getTargetCopy().length); + } else if (vectorDataType == VectorDataType.BYTE) { + assert (luceneQuery.getLuceneQuery() instanceof KnnByteVectorQuery); + assertEquals(queryVector.length, ((KnnByteVectorQuery) luceneQuery.getLuceneQuery()).getTargetCopy().length); + } + } else { + // If memory optimized search is turned off then, use Native query + assertTrue(query instanceof NativeEngineKnnVectorQuery); + assertFalse(query instanceof LuceneEngineKnnVectorQuery); + } + } + } + @SneakyThrows public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/MemoryOptimizedSearchSupportSpecTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/MemoryOptimizedSearchSupportSpecTests.java new file mode 100644 index 0000000000..87b70ea67d --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/MemoryOptimizedSearchSupportSpecTests.java @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import lombok.Builder; +import lombok.RequiredArgsConstructor; +import 
org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MemoryOptimizedSearchSupportSpec; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +public class MemoryOptimizedSearchSupportSpecTests extends KNNTestCase { + public void testLuceneEngineIsSupported() { + // Lucene with any configuration must be supported. + final TestingSpec testingSpec = new TestingSpec( + KNNEngine.LUCENE, + Arrays.asList(SpaceType.values()), + Arrays.asList(VectorDataType.values()), + // MethodComponentContext is irrelevant for Lucene. NOTE(review): an empty list means doTest's innermost loop never runs its assertion. + Collections.emptyList() + ); + + mustSupported(testingSpec); + } + + public void testFaissSupportedCases() { + // HNSW,float, L2|IP, Flat + // HNSW,float, L2|IP, SQ + // Note that we do support byte index. And it is VectorDataType.FLOAT for the byte index, not VectorDataType.BYTE. 
+ final TestingSpec testingSpec = new TestingSpec( + KNNEngine.FAISS, + Arrays.asList(SpaceType.L2, SpaceType.INNER_PRODUCT), + Arrays.asList(VectorDataType.FLOAT), + Arrays.asList( + new MethodComponentContext( + METHOD_HNSW, + Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_FLAT, Collections.emptyMap())) + ), + new MethodComponentContext( + METHOD_HNSW, + Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_SQ, Collections.emptyMap())) + ) + ) + ); + + mustSupported(testingSpec); + } + + public void testFaissUnsupportedCases() { + // HNSW with every space type and every vector data type, but with an invalid encoder: + // - an encoder named "DUMMY_KEY" instead of ENCODER_FLAT / ENCODER_SQ + // - an encoder parameter that is a plain Object rather than a MethodComponentContext + final TestingSpec testingSpec = new TestingSpec( + KNNEngine.FAISS, + Arrays.asList(SpaceType.values()), + Arrays.asList(VectorDataType.values()), + Arrays.asList( + new MethodComponentContext( + METHOD_HNSW, + Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext("DUMMY_KEY", Collections.emptyMap())) + ), + new MethodComponentContext(METHOD_HNSW, Map.of(METHOD_ENCODER_PARAMETER, new Object())) + ) + ); + + mustNotSupported(testingSpec); + } + + public void testBinaryFiassNotSupportedCases() { + // HNSW with binary vectors and every space type. + // The encoders below are the same invalid ones used in testFaissUnsupportedCases. + // NOTE(review): binary with a valid Flat/SQ encoder is not exercised here - confirm that is intended. 
+ final TestingSpec testingSpec = new TestingSpec( + KNNEngine.FAISS, + Arrays.asList(SpaceType.values()), + Arrays.asList(VectorDataType.BINARY), + Arrays.asList( + new MethodComponentContext( + METHOD_HNSW, + Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext("DUMMY_KEY", Collections.emptyMap())) + ), + new MethodComponentContext(METHOD_HNSW, Map.of(METHOD_ENCODER_PARAMETER, new Object())) + ) + ); + + mustNotSupported(testingSpec); + } + + private void mustSupported(final TestingSpec testingSpec) { + doTest(testingSpec, true); + } + + private void mustNotSupported(final TestingSpec testingSpec) { + doTest(testingSpec, false); + } + + private void doTest(final TestingSpec testingSpec, final boolean expected) { + for (final SpaceType spaceType : testingSpec.spaceTypes) { + for (final VectorDataType vectorDataType : testingSpec.vectorDataTypes) { + for (final MethodComponentContext methodComponentContext : testingSpec.methodComponentContexts) { + final Params params = buildParameters(testingSpec.knnEngine, spaceType, vectorDataType, methodComponentContext); + + final boolean isSupported = MemoryOptimizedSearchSupportSpec.supported( + params.methodContextOpt, + params.quantizationConfig, + params.vectorDataType + ); + + assertEquals(expected, isSupported); + } + } + } + } + + private Params buildParameters( + final KNNEngine knnEngine, + final SpaceType spaceType, + final VectorDataType vectorDataType, + final MethodComponentContext methodComponentContext + ) { + + final Params.ParamsBuilder builder = Params.builder(); + builder.vectorDataType(vectorDataType); + + final KNNMethodContext methodContext = new KNNMethodContext(knnEngine, spaceType, methodComponentContext); + builder.methodContextOpt = Optional.of(methodContext); + + return builder.build(); + } + + @Builder + private static class Params { + Optional methodContextOpt; + QuantizationConfig quantizationConfig; + VectorDataType vectorDataType; + } + + @RequiredArgsConstructor + private static class 
TestingSpec { + final KNNEngine knnEngine; + final List spaceTypes; + final List vectorDataTypes; + final List methodComponentContexts; + } +}