diff --git a/build.gradle b/build.gradle index fd79065ea4..9bd256cd6a 100644 --- a/build.gradle +++ b/build.gradle @@ -289,10 +289,23 @@ publishing { compileJava { options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) + + // MemorySegment is not available until JDK 22, so exclude the JDK22-specific source from compilation unless building for Java 22+. + def javaExt = project.extensions.getByType(JavaPluginExtension) + if (javaExt.sourceCompatibility <= JavaVersion.VERSION_21 || javaExt.targetCompatibility <= JavaVersion.VERSION_21) { + exclude("org/opensearch/knn/memoryoptsearch/MemorySegmentAddressExtractorJDK22.java") + } } compileTestJava { options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) } +javadoc { + // Skip Javadoc generation for this source, since Javadoc complains that MemorySegment is a preview API on Java 21. + def javaExt = project.extensions.getByType(JavaPluginExtension) + if (javaExt.sourceCompatibility <= JavaVersion.VERSION_21 || javaExt.targetCompatibility <= JavaVersion.VERSION_21) { + exclude("org/opensearch/knn/memoryoptsearch/MemorySegmentAddressExtractorJDK22.java") + } +} compileTestFixturesJava { options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) } @@ -482,6 +495,7 @@ def commonIntegTest(RestIntegTestTask task, project, integTestDependOnJniLib, op task.systemProperty 'cluster.debug', isDebuggingCluster // Set number of nodes system property to be used in tests task.systemProperty 'cluster.number_of_nodes', "${_numNodes}" + // There seems to be an issue when running multi node run or integ tasks with unicast_hosts // not being written, the waitForAllConditions ensures it's written task.getClusters().forEach { cluster -> @@ -541,6 +555,7 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){ debugPort += 1 } } + cluster.systemProperty("java.library.path", "$rootDir/jni/build/release") final testSnapshotFolder = file("${buildDir}/testSnapshotFolder") testSnapshotFolder.mkdirs() @@ -550,6 +565,8 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){ testClusters.integTest { commonIntegTestClusters(it, _numNodes) + // Force optimistic search for testing + systemProperty 'mem_opt_srch.force_reenter', 'true' } testClusters.integTestRemoteIndexBuild { @@ -562,6 +579,8 @@ testClusters.integTestRemoteIndexBuild { keystore 's3.client.default.access_key', "${System.getProperty("access_key")}" keystore 's3.client.default.secret_key', "${System.getProperty("secret_key")}" keystore 's3.client.default.session_token', "${System.getProperty("session_token")}" + // Force optimistic search for testing + systemProperty 'mem_opt_srch.force_reenter', 'true' } task integTestRemote(type: RestIntegTestTask) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 1b7e0f9e5f..d284ea1d48 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -117,7 +117,8 @@ public static Query create(CreateQueryRequest createQueryRequest) { .build(); } - if (createQueryRequest.getRescoreContext().isPresent() + if (memoryOptimizedSearchEnabled + || createQueryRequest.getRescoreContext().isPresent() || (ENGINES_SUPPORTING_NESTED_FIELDS.contains(createQueryRequest.getKnnEngine()) && expandNested)) { return new NativeEngineKnnVectorQuery(knnQuery, 
QueryUtils.getInstance(), expandNested); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index e997441538..0471e09638 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -24,7 +24,6 @@ import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; -import org.opensearch.common.Nullable; import org.opensearch.common.StopWatch; import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.common.FieldInfoExtractor; @@ -52,6 +51,9 @@ import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.profile.StopWatchUtils.startStopWatch; +import static org.opensearch.knn.profile.StopWatchUtils.stopStopWatchAndLog; + /** * {@link KNNWeight} serves as a template for implementing approximate nearest neighbor (ANN) * and radius search over a native index type, such as Faiss. @@ -298,12 +300,13 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep final SegmentReader reader = Lucene.segmentReader(context.reader()); final String segmentName = reader.getSegmentName(); - StopWatch stopWatch = startStopWatch(); + final StopWatch stopWatch = startStopWatch(log); final BitSet filterBitSet = getFilteredDocsBitSet(context); - stopStopWatchAndLog(stopWatch, "FilterBitSet creation", segmentName); + stopStopWatchAndLog(log, stopWatch, "FilterBitSet creation", knnQuery.getShardId(), segmentName, knnQuery.getField()); + + // Save its cardinality, as the cardinality calculation is expensive. + final int filterCardinality = filterBitSet.cardinality(); - final int maxDoc = context.reader().maxDoc(); - int filterCardinality = filterBitSet.cardinality(); // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good // place, @@ -320,19 +323,19 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep * This improves the recall. */ if (isFilteredExactSearchPreferred(filterCardinality)) { - TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k); - return new PerLeafResult(filterWeight == null ? null : filterBitSet, result); + final TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k); + return new PerLeafResult( + filterWeight == null ? null : filterBitSet, + filterCardinality, + result, + PerLeafResult.SearchMode.EXACT_SEARCH + ); } - /* - * If filters match all docs in this segment, then null should be passed as filterBitSet - * so that it will not do a bitset look up in bottom search layer. - */ - final BitSet annFilter = (filterWeight != null && filterCardinality == maxDoc) ? 
null : filterBitSet; + final StopWatch annStopWatch = startStopWatch(log); + final TopDocs topDocs = approximateSearch(context, filterBitSet, filterCardinality, k); + stopStopWatchAndLog(log, annStopWatch, "ANN search", knnQuery.getShardId(), segmentName, knnQuery.getField()); - StopWatch annStopWatch = startStopWatch(); - final TopDocs topDocs = approximateSearch(context, annFilter, filterCardinality, k); - stopStopWatchAndLog(annStopWatch, "ANN search", segmentName); if (knnQuery.isExplain()) { knnExplanation.addLeafResult(context.id(), topDocs.scoreDocs.length); } @@ -341,18 +344,21 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep // results less than K, though we have more than k filtered docs if (isExactSearchRequire(context, filterCardinality, topDocs.scoreDocs.length)) { final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, filterCardinality) : null; - TopDocs result = doExactSearch(context, docs, filterCardinality, k); - return new PerLeafResult(filterWeight == null ? null : filterBitSet, result); + final TopDocs result = doExactSearch(context, docs, filterCardinality, k); + return new PerLeafResult( + filterWeight == null ? null : filterBitSet, + filterCardinality, + result, + PerLeafResult.SearchMode.EXACT_SEARCH + ); } - return new PerLeafResult(filterWeight == null ? null : filterBitSet, topDocs); - } - private void stopStopWatchAndLog(@Nullable final StopWatch stopWatch, final String prefixMessage, String segmentName) { - if (log.isDebugEnabled() && stopWatch != null) { - stopWatch.stop(); - final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] "; - log.debug(logMessage, knnQuery.getShardId(), segmentName, knnQuery.getField(), stopWatch.totalTime().nanos()); - } + return new PerLeafResult( + filterWeight == null ? null : filterBitSet, + filterCardinality, + topDocs, + PerLeafResult.SearchMode.APPROXIMATE_SEARCH + ); } protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { @@ -413,9 +419,33 @@ private TopDocs doExactSearch( return exactSearch(context, exactSearcherContextBuilder.build()); } - protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k) - throws IOException { + /** + * Performs an approximate nearest neighbor (ANN) search on the provided index segment. + *
+ * This method prepares all necessary query metadata before triggering the actual ANN search. + * It extracts the {@code model_id} from field-level attributes if required, retrieves any + * quantization or auxiliary metadata associated with the vector field, and applies quantization + * to the query vector when applicable. After these preprocessing steps, it invokes + * {@code doANNSearch(LeafReaderContext, BitSet, int, int)} to execute the approximate search + * and obtain the top results. + * + * @param context the {@link LeafReaderContext} representing the current index segment + * @param filterIdsBitSet an optional {@link BitSet} indicating document IDs to include in the search; + * may be {@code null} if no filtering is required + * @param filterCardinality the number of documents included in {@code filterIdsBitSet}; + * used to optimize search filtering + * @param k the number of nearest neighbors to retrieve + * @return a {@link TopDocs} object containing the top {@code k} approximate search results + * @throws IOException if an error occurs while reading index data or accessing vector fields + */ + public TopDocs approximateSearch( + final LeafReaderContext context, + final BitSet filterIdsBitSet, + final int filterCardinality, + final int k + ) throws IOException { final SegmentReader reader = Lucene.segmentReader(context.reader()); + FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); if (fieldInfo == null) { @@ -465,6 +495,11 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe // TODO: Change type of vector once more quantization methods are supported byte[] quantizedVector = maybeQuantizeVector(segmentLevelQuantizationInfo); float[] transformedVector = maybeTransformVector(segmentLevelQuantizationInfo, spaceType); + /* + * If filters match all docs in this segment, then null should be passed as filterBitSet + * so that it will not do a bitset look up in bottom search layer. + */ + final BitSet annFilter = filterCardinality == context.reader().maxDoc() ? 
null : filterIdsBitSet; KNNCounter.GRAPH_QUERY_REQUESTS.increment(); final TopDocs results = doANNSearch( @@ -477,8 +512,8 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe quantizedVector, transformedVector, modelId, - filterIdsBitSet, - cardinality, + annFilter, + filterCardinality, k ); @@ -553,10 +588,10 @@ protected void addExplainIfRequired(final TopDocs results, final KNNEngine knnEn */ public TopDocs exactSearch(final LeafReaderContext leafReaderContext, final ExactSearcher.ExactSearcherContext exactSearcherContext) throws IOException { - StopWatch stopWatch = startStopWatch(); + final StopWatch stopWatch = startStopWatch(log); TopDocs exactSearchResults = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext); final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); - stopStopWatchAndLog(stopWatch, "Exact search", reader.getSegmentName()); + stopStopWatchAndLog(log, stopWatch, "Exact search", knnQuery.getShardId(), reader.getSegmentName(), knnQuery.getField()); return exactSearchResults; } @@ -673,13 +708,6 @@ private boolean isMissingNativeEngineFiles(LeafReaderContext context) { return engineFiles.isEmpty(); } - private StopWatch startStopWatch() { - if (log.isDebugEnabled()) { - return new StopWatch().start(); - } - return null; - } - protected int[] getParentIdsArray(final LeafReaderContext context) throws IOException { if (knnQuery.getParentsFilter() == null) { return null; diff --git a/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java b/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java index b14f9085c5..70d1937ac6 100644 --- a/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java +++ b/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java @@ -7,21 +7,189 @@ import lombok.Getter; import lombok.Setter; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; -import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BitSet; import org.opensearch.common.Nullable; +/** + * Represents the per-segment (leaf-level) result of a vector search operation. + *
+ * This class encapsulates the intermediate search state and results produced from + * a single {@link LeafReaderContext} during approximate or exact vector search. + * It stores the active filter bitset (if any), its cardinality, the top document + * results for that segment, and the search mode used. + *
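+ * <p>Illustrative construction of an approximate-search result (variable names are hypothetical): + * <pre>{@code + * PerLeafResult leafResult = new PerLeafResult( + *     filterBitSet,               // or null when no filter applies + *     filterBitSet.cardinality(), + *     topDocs, + *     PerLeafResult.SearchMode.APPROXIMATE_SEARCH + * ); + * }</pre> + *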
+ * Instances of this class are typically aggregated at a higher level to produce + * the global {@code TopDocs} result set. + */ @Getter public class PerLeafResult { - public static final PerLeafResult EMPTY_RESULT = new PerLeafResult(new Bits.MatchNoBits(0), TopDocsCollector.EMPTY_TOPDOCS); + /** + * An immutable {@link BitSet} that reports every document as set. It represents the + * "no filter" (match-all) case without incurring null checks or allocations. + */ + public static final BitSet MATCH_ALL_BIT_SET = new BitSet() { + @Override + public void set(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getAndSet(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear(int startIndex, int endIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public int cardinality() { + throw new UnsupportedOperationException(); + } + + @Override + public int approximateCardinality() { + throw new UnsupportedOperationException(); + } + + @Override + public int prevSetBit(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public int nextSetBit(int start, int end) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean get(int i) { + return true; + } + + @Override + public int length() { + throw new UnsupportedOperationException(); + } + }; + + /** + * An immutable {@link BitSet} that reports no document as set, used by {@link #EMPTY_RESULT} + * to represent a filter matching nothing without incurring null checks or allocations. + */ + private static final BitSet MATCH_NO_BIT_SET = new BitSet() { + @Override + public void set(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getAndSet(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear(int startIndex, int endIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public int cardinality() { + throw new UnsupportedOperationException(); + } + + @Override + public int approximateCardinality() { + throw new UnsupportedOperationException(); + } + + @Override + public int prevSetBit(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public int nextSetBit(int start, int end) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean get(int i) { + return false; + } + + @Override + public int length() { + throw new UnsupportedOperationException(); + } + }; + + // A statically defined empty {@code PerLeafResult} used as a lightweight placeholder when a segment produces no hits. + public static final PerLeafResult EMPTY_RESULT = new PerLeafResult( + MATCH_NO_BIT_SET, + 0, + TopDocsCollector.EMPTY_TOPDOCS, + SearchMode.EXACT_SEARCH + ); + + /** + * Indicates the search mode applied within a segment: either exact or approximate nearest neighbor (ANN) search. + */ + public enum SearchMode { + EXACT_SEARCH, + APPROXIMATE_SEARCH, + } + + // Active filter bitset limiting document candidates in this leaf (may be empty). 
@Nullable - private final Bits filterBits; + private final BitSet filterBits; + + // Cardinality of {@link #filterBits}, used for filtering optimizations. + private final int filterBitsCardinality; + + // Top document results for this leaf segment. @Setter private TopDocs result; - public PerLeafResult(final Bits filterBits, final TopDocs result) { - this.filterBits = filterBits == null ? new Bits.MatchAllBits(0) : filterBits; + // Indicates whether this result was produced via exact or approximate search. + private final SearchMode searchMode; + + /** + * Constructs a new {@code PerLeafResult}. + * + * @param filterBits the document filter bitset for this leaf, or {@code null} if none + * @param filterBitsCardinality the number of bits set in {@code filterBits} + * @param result the top document results for this leaf + * @param searchMode the search mode (exact or approximate) used + */ + public PerLeafResult(final BitSet filterBits, final int filterBitsCardinality, final TopDocs result, final SearchMode searchMode) { + this.filterBits = filterBits == null ? MATCH_ALL_BIT_SET : filterBits; + this.filterBitsCardinality = filterBitsCardinality; this.result = result; + this.searchMode = searchMode; } } diff --git a/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java index 223854e963..89f68fa4c1 100644 --- a/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query.memoryoptsearch; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -29,6 +30,8 @@ import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.MemoryOptimizedSearchScoreConverter; +import org.opensearch.lucene.OptimisticKnnCollectorManager; +import org.opensearch.lucene.ReentrantKnnCollectorManager; import java.io.IOException; @@ -47,6 +50,8 @@ public class MemoryOptimizedKNNWeight extends KNNWeight { private static final KnnSearchStrategy.Hnsw DEFAULT_HNSW_SEARCH_STRATEGY = new KnnSearchStrategy.Hnsw(60); private final KnnCollectorManager knnCollectorManager; + @Setter + private ReentrantKnnCollectorManager optimistic2ndKnnCollectorManager; public MemoryOptimizedKNNWeight(KNNQuery query, float boost, final Weight filterWeight, IndexSearcher searcher, Integer k) { super(query, boost, filterWeight); @@ -55,7 +60,7 @@ public MemoryOptimizedKNNWeight(KNNQuery query, float boost, final Weight filter // ANN Search if (query.getParentsFilter() == null) { // Non-nested case - this.knnCollectorManager = new TopKnnCollectorManager(k, searcher); + this.knnCollectorManager = new OptimisticKnnCollectorManager(k, new TopKnnCollectorManager(k, searcher)); } else { // Nested case this.knnCollectorManager = new DiversifyingNearestChildrenKnnCollectorManager(k, query.getParentsFilter(), searcher); @@ -184,7 +189,10 @@ private TopDocs queryIndex( } // Create a collector + bitset - final KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, DEFAULT_HNSW_SEARCH_STRATEGY, context); + final KnnCollectorManager collectorManager = optimistic2ndKnnCollectorManager != null + ? 
optimistic2ndKnnCollectorManager + : knnCollectorManager; + final KnnCollector knnCollector = collectorManager.newCollector(visitedLimit, DEFAULT_HNSW_SEARCH_STRATEGY, context); final AcceptDocs acceptDocs = getAcceptedDocs(reader, cardinality, filterIdsBitSet); // Start searching index @@ -217,22 +225,21 @@ private AcceptDocs getAcceptedDocs(SegmentReader reader, int cardinality, BitSet } else { acceptDocs = new AcceptDocs() { @Override - public Bits bits() throws IOException { + public Bits bits() { return filterIdsBitSet; } @Override - public DocIdSetIterator iterator() throws IOException { + public DocIdSetIterator iterator() { return new BitSetIterator(filterIdsBitSet, cardinality); } @Override - public int cost() throws IOException { + public int cost() { return cardinality; } }; } return acceptDocs; } - } diff --git a/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/optimistic/OptimisticSearchStrategyUtils.java b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/optimistic/OptimisticSearchStrategyUtils.java new file mode 100644 index 0000000000..d0533897aa --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/optimistic/OptimisticSearchStrategyUtils.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.memoryoptsearch.optimistic; + +import lombok.experimental.UtilityClass; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.hnsw.FloatHeap; +import org.opensearch.knn.index.query.PerLeafResult; + +import java.util.List; + +/** + * Utility class providing helper methods for the optimistic search strategy used in KNN search. + * + *
<p>The optimistic search strategy executes a two-phase KNN search across multiple index segments: + * <ol> + *   <li>Phase 1 – Shallow search: Runs an approximate KNN search independently on each + * segment with an adjusted k value based on segment size. The results are merged + * into a single candidate list across all segments.</li> + *   <li>Phase 2 – Deep search: Selects only the segments whose minimum score is greater than + * or equal to the k-th largest score from the merged results of Phase 1, and + * re-runs a refined KNN search using the collected candidates as seeds.</li> + * </ol> + * + * <p>This class provides utility functions that assist in merging results, computing score thresholds, + * and managing per-segment results between the two phases. + * + * <p>All methods are static and stateless. + */ +@UtilityClass +public class OptimisticSearchStrategyUtils { + /** + * Returns the k-th largest score across a collection of per-leaf search results, + * as if all scores were merged and globally sorted in descending order. + *
+ * This utility is typically used to determine the global score threshold + * corresponding to the top-k results when combining partial {@code TopDocs} + * from multiple segments or shards. + *
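+ * <p>A small worked example (scores are illustrative): with two leaves whose scores are + * {@code [0.9, 0.7]} and {@code [0.8, 0.6]}, merging yields {@code [0.9, 0.8, 0.7, 0.6]}, so + * <pre>{@code + * float kth = OptimisticSearchStrategyUtils.findKthLargestScore(results, 3, 4); // returns 0.7f + * }</pre> + *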
+ * The method does not perform a full global sort of all scores; it only identifies + * the score value that would occupy the k-th position in the merged ranking. + * + * @param results a list of {@link PerLeafResult} objects, each containing scores + * collected from an individual segment or shard + * @param k the rank (1-based) of the desired score, e.g., {@code k = 10} + * returns the 10th highest score overall + * @param totalResults the total number of results across all {@code results}; + * used for boundary checks or optimizations + * @return the score value that would appear at position {@code k} if all scores + * were globally sorted in descending order + * @throws IllegalArgumentException if {@code k} is less than 1 or greater than {@code totalResults} + */ + public static float findKthLargestScore(final List<PerLeafResult> results, final int k, final int totalResults) { + if (totalResults <= 0) { + throw new IllegalArgumentException("Total results must be greater than zero, got=" + totalResults); + } + if (k <= 0) { + throw new IllegalArgumentException("K must be greater than zero, got=" + k); + } + if (k > totalResults) { + throw new IllegalArgumentException("K must be less than or equal to total results, got=" + k + ", totalResults=" + totalResults); + } + + // If k equals the total number of scores (k > totalResults already threw), the k-th largest is the global minimum + if (totalResults <= k) { + float min = Float.MAX_VALUE; + for (final PerLeafResult result : results) { + for (final ScoreDoc scoreDoc : result.getResult().scoreDocs) { + if (scoreDoc.score < min) { + min = scoreDoc.score; + } + } + } + return min; + } + + // Use a min-heap to track the top-k largest values. + // Since each PerLeafResult is already sorted in descending order by score, we push larger values first, + // allowing the heap to fill quickly and skip most of the remaining elements. + // This keeps the practical cost close to O(N), as most offers are rejected in O(1) once the heap is saturated. + final FloatHeap floatHeap = new FloatHeap(k); + final int[] indices = new int[results.size()]; + // Maximum loop count is totalResults * #segments. A segment's result size is at most totalResults, so the upper bound (maxI) + // of totalResults * #segments guarantees termination. 
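+ // For example (illustrative scores): with per-segment results [0.9, 0.7] and [0.8], the round-robin below visits + // 0.9, 0.8, then 0.7, offering each score to the heap until all totalResults scores have been seen.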
+ for (int i = 0, visited = 0, maxI = totalResults * results.size(); visited < totalResults && i < maxI; ++i) { + final int resultIndex = i % indices.length; + final int scoreIndex = indices[resultIndex]; + final ScoreDoc[] scoreDocs = results.get(resultIndex).getResult().scoreDocs; + if (scoreIndex < scoreDocs.length) { + floatHeap.offer(scoreDocs[scoreIndex].score); + ++visited; + indices[resultIndex] = scoreIndex + 1; + } + } + + return floatHeap.peek(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index d8a0951fef..bd80f1d81a 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -20,6 +20,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.knn.TopKnnCollectorManager; import org.apache.lucene.util.Bits; import org.opensearch.common.StopWatch; import org.opensearch.knn.index.KNNSettings; @@ -30,6 +31,9 @@ import org.opensearch.knn.index.query.PerLeafResult; import org.opensearch.knn.index.query.ResultUtil; import org.opensearch.knn.index.query.common.QueryUtils; +import org.opensearch.knn.index.query.memoryoptsearch.MemoryOptimizedKNNWeight; +import org.opensearch.knn.index.query.memoryoptsearch.optimistic.OptimisticSearchStrategyUtils; +import org.opensearch.lucene.ReentrantKnnCollectorManager; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.profile.KNNProfileUtil; import org.opensearch.knn.profile.LongMetric; @@ -41,12 +45,17 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.concurrent.Callable; import java.util.stream.Collectors; +import static org.opensearch.knn.profile.StopWatchUtils.startStopWatch; +import static org.opensearch.knn.profile.StopWatchUtils.stopStopWatchAndLog; + /** * {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level. * {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives @@ -58,6 +67,17 @@ @Getter @RequiredArgsConstructor public class NativeEngineKnnVectorQuery extends Query { + /** + * A special flag used for testing purposes that forces execution of the second (deep-dive) search + * in optimistic search mode, regardless of the results returned by the first approximate search. + *
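+ * <p>Integration tests enable it through the {@code mem_opt_srch.force_reenter} system property, e.g. + * {@code systemProperty 'mem_opt_srch.force_reenter', 'true'}, as shown in the build.gradle change above. + *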
+ * This flag should never be enabled in production; it is intended for testing and debugging only. + */ + private static final boolean FORCE_REENTER_TESTING; + + static { + FORCE_REENTER_TESTING = Boolean.parseBoolean(System.getProperty("mem_opt_srch.force_reenter", "false")); + } private final KNNQuery knnQuery; private final QueryUtils queryUtils; @@ -214,6 +234,7 @@ private PerLeafResult retrieveLeafResult( perLeafResult.getFilterBits() ); + // Build exact search context final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() .matchedDocsIterator(allSiblings) .numberOfMatchedDocs(allSiblings.cost()) @@ -226,8 +247,17 @@ .byteQueryVector(knnQuery.getByteQueryVector()) .isMemoryOptimizedSearchEnabled(knnQuery.isMemoryOptimizedSearch()) .build(); + + // Run exact search TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext); - return new PerLeafResult(perLeafResult.getFilterBits(), rescoreResult); + + // Pack it as a result and return + return new PerLeafResult( + perLeafResult.getFilterBits(), + perLeafResult.getFilterBitsCardinality(), + rescoreResult, + PerLeafResult.SearchMode.EXACT_SEARCH + ); } private List<PerLeafResult> doSearch( @@ -236,11 +266,116 @@ KNNWeight knnWeight, int k ) throws IOException { + // Collect search tasks List<Callable<PerLeafResult>> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext leafReaderContext : leafReaderContexts) { tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k)); } - return indexSearcher.getTaskExecutor().invokeAll(tasks); + + // Execute search tasks + final List<PerLeafResult> perLeafResults = indexSearcher.getTaskExecutor().invokeAll(tasks); + + // For memory optimized search, kick off the 2nd (optimistic deep-dive) search + if (knnQuery.isMemoryOptimizedSearch() && perLeafResults.size() > 1) { + log.debug( + "Running second deep-dive search for the optimistic strategy while memory optimized search is enabled. perLeafResults.size()={}", + perLeafResults.size() + ); + final StopWatch stopWatch = startStopWatch(log); + run2ndOptimisticSearch(perLeafResults, knnWeight, leafReaderContexts, k, indexSearcher); + stopStopWatchAndLog(log, stopWatch, "2ndOptimisticSearch", knnQuery.getShardId(), "All Segments", knnQuery.getField()); + } + + return perLeafResults; + } + + private void run2ndOptimisticSearch( + final List<PerLeafResult> perLeafResults, + final KNNWeight knnWeight, + final List<LeafReaderContext> leafReaderContexts, + final int k, + final IndexSearcher indexSearcher + ) throws IOException { + if ((knnWeight instanceof MemoryOptimizedKNNWeight) == false) { + log.error( + "Memory optimized search was enabled, but got [" + + (knnWeight == null ? "null" : knnWeight.getClass().getSimpleName()) + + "], expected=" + + MemoryOptimizedKNNWeight.class.getSimpleName() + ); + return; + } + + assert (perLeafResults.size() == leafReaderContexts.size()); + + // Cast so we can later install the re-entrant collector manager on the weight + final MemoryOptimizedKNNWeight memoryOptKNNWeight = (MemoryOptimizedKNNWeight) knnWeight; + + // How many results have we collected? + int totalResults = 0; + for (PerLeafResult perLeafResult : perLeafResults) { + totalResults += perLeafResult.getResult().scoreDocs.length; + } + + // If we got empty results, then return immediately + if (totalResults == 0) { + return; + } + + // Start the 2nd deep dive: compute the k-th largest merged score as the admission bar. 
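+ // For example (illustrative): with k=3, per-segment scores [0.9, 0.8] and [0.7, 0.4] merge to [0.9, 0.8, 0.7, 0.4], + // so the bar is 0.7; the first segment (lowest hit 0.8 >= 0.7) re-enters, the second (0.4 < 0.7) does not.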
+ final float minTopKScore = OptimisticSearchStrategyUtils.findKthLargestScore(perLeafResults, knnQuery.getK(), totalResults); + + // Select candidate segments for the 2nd search: pick each segment whose returned vectors all score at or above the `kth` + // value in the merged results. + final List<Callable<TopDocs>> secondDeepDiveTasks = new ArrayList<>(); + final List<Integer> contextIndices = new ArrayList<>(); + final Map<Integer, TopDocs> segmentOrdToResults = new HashMap<>(); + + for (int i = 0; i < leafReaderContexts.size(); ++i) { + final LeafReaderContext leafReaderContext = leafReaderContexts.get(i); + final PerLeafResult perLeafResult = perLeafResults.get(i); + final TopDocs perLeaf = perLeafResults.get(i).getResult(); + if (perLeaf.scoreDocs.length > 0 && perLeafResult.getSearchMode() == PerLeafResult.SearchMode.APPROXIMATE_SEARCH) { + if (FORCE_REENTER_TESTING || perLeaf.scoreDocs[perLeaf.scoreDocs.length - 1].score >= minTopKScore) { + // For the target segment, save the top results; they will be used as seeds. + segmentOrdToResults.put(leafReaderContext.ord, perLeaf); + + // All this leaf's hits are at or above the global topK min score; explore it further + secondDeepDiveTasks.add( + () -> knnWeight.approximateSearch( + leafReaderContext, + perLeafResult.getFilterBits(), + perLeafResult.getFilterBitsCardinality(), + knnQuery.getK() + ) + ); + contextIndices.add(i); + } + } + } + + // Kick off 2nd search tasks + if (secondDeepDiveTasks.isEmpty() == false) { + final ReentrantKnnCollectorManager knnCollectorManagerPhase2 = new ReentrantKnnCollectorManager( + new TopKnnCollectorManager(k, indexSearcher), + segmentOrdToResults, + knnQuery.getQueryVector(), + knnQuery.getField() + ); + + // Make weight use reentrant collector manager + memoryOptKNNWeight.setOptimistic2ndKnnCollectorManager(knnCollectorManagerPhase2); + + final List<TopDocs> deepDiveTopDocs = indexSearcher.getTaskExecutor().invokeAll(secondDeepDiveTasks); + + // Override results for target context + for (int i = 0; i < deepDiveTopDocs.size(); ++i) { + // Override with the new results + final TopDocs resultsFrom2ndDeepDive = deepDiveTopDocs.get(i); + final PerLeafResult perLeafResult = perLeafResults.get(contextIndices.get(i)); + perLeafResult.setResult(resultsFrom2ndDeepDive); + } + } } private List<PerLeafResult> doRescore( @@ -287,7 +422,12 @@ private List<PerLeafResult> doRescore( .parentsFilter(knnQuery.getParentsFilter()) .build(); TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext); - return new PerLeafResult(perLeafeResult.getFilterBits(), rescoreResult); + return new PerLeafResult( + perLeafeResult.getFilterBits(), + perLeafeResult.getFilterBitsCardinality(), + rescoreResult, + PerLeafResult.SearchMode.EXACT_SEARCH + ); }); } return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); } @@ -297,7 +437,6 @@ private PerLeafResult searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, i final PerLeafResult perLeafResult = queryWeight.searchLeaf(ctx, k); final Bits liveDocs = ctx.reader().getLiveDocs(); if (liveDocs != null) { - List<ScoreDoc> list = new ArrayList<>(); for (ScoreDoc scoreDoc : perLeafResult.getResult().scoreDocs) { if (liveDocs.get(scoreDoc.doc)) { @@ -306,7 +445,7 @@ private PerLeafResult searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, i } ScoreDoc[] filteredScoreDoc = list.toArray(new ScoreDoc[0]); TotalHits totalHits = new TotalHits(filteredScoreDoc.length, TotalHits.Relation.EQUAL_TO); - return new PerLeafResult(perLeafResult.getFilterBits(), new TopDocs(totalHits, filteredScoreDoc)); + perLeafResult.setResult(new 
TopDocs(totalHits, filteredScoreDoc)); } return perLeafResult; } diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java index 782ca2b2b7..87b9e41852 100644 --- a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java @@ -7,6 +7,7 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; @@ -25,6 +26,7 @@ * For example, the quantization type `QT_8BIT` indicates that each element in a vector is quantized into 8bits. Therefore, each element * will occupy exactly one byte, a vector would occupy exactly the size of dimensions. */ +@Log4j2 @Getter public class FaissIndexScalarQuantizedFlat extends FaissIndex { private static EnumMap VECTOR_DATA_TYPES = new EnumMap<>( diff --git a/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java b/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java index 7cc7bc0ef7..ca2a89cfd6 100644 --- a/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java +++ b/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java @@ -59,7 +59,7 @@ protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOExc } @Override - protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k) + public TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k) throws IOException { return (TopDocs) KNNProfileUtil.profileBreakdown( profile, diff --git a/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java b/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java index e532322f5d..05d6126938 100644 --- a/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java +++ b/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java @@ -62,7 +62,7 @@ protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOExc } @Override - protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k) + public TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k) throws IOException { return (TopDocs) KNNProfileUtil.profileBreakdown( profile, diff --git a/src/main/java/org/opensearch/knn/profile/StopWatchUtils.java b/src/main/java/org/opensearch/knn/profile/StopWatchUtils.java new file mode 100644 index 0000000000..08531d4f35 --- /dev/null +++ b/src/main/java/org/opensearch/knn/profile/StopWatchUtils.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profile; + +import lombok.experimental.UtilityClass; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.Nullable; +import org.opensearch.common.StopWatch; + +@UtilityClass +public final class StopWatchUtils { + public static StopWatch startStopWatch(final Logger log) { + if (log.isDebugEnabled()) { + return new StopWatch().start(); + } + return null; + } + + public static void 
stopStopWatchAndLog( + final Logger log, + @Nullable final StopWatch stopWatch, + final String prefixMessage, + final int shardId, + final String segmentName, + final String field + ) { + + if (stopWatch != null && log.isDebugEnabled()) { + stopWatch.stop(); + final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] "; + log.debug(logMessage, shardId, segmentName, field, stopWatch.totalTime().nanos()); + } + } +} diff --git a/src/main/java/org/opensearch/lucene/OptimisticKnnCollectorManager.java b/src/main/java/org/opensearch/lucene/OptimisticKnnCollectorManager.java new file mode 100644 index 0000000000..c2b545ad60 --- /dev/null +++ b/src/main/java/org/opensearch/lucene/OptimisticKnnCollectorManager.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.lucene; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.KnnSearchStrategy; + +import java.io.IOException; + +/** + * Collector manager responsible for constructing the appropriate KNN collector + * depending on whether the optimistic search strategy is enabled. + * + *
<p>This class wraps an underlying {@code KnnCollectorManager} delegate and + * decides at runtime which collector implementation to use: + * <ul> + *   <li>an optimistic collector with a reduced per-leaf top-k, when the delegate supports optimistic collection</li> + *   <li>the delegate's regular collector otherwise</li> + * </ul> + * + * <p>The optimistic search strategy operates in two phases: + * <ol> + *   <li>Phase 1 – Executes KNN searches independently per segment + * with adjusted {@code k} values based on segment size and merges the results.</li> + *   <li>Phase 2 – Deep search: Re-runs searches only on segments that have + * promising results (based on a global score threshold) to refine recall efficiently.</li> + * </ol> + * + * <p>Example usage: + * <pre>{@code + * KnnCollectorManager delegate = new TopKnnCollectorManager(k, searcher); + * OptimisticKnnCollectorManager manager = new OptimisticKnnCollectorManager(k, delegate); + * KnnCollector collector = manager.newCollector(visitedLimit, searchStrategy, leafReaderContext); + * }</pre>
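+ * + * <p>For instance (illustrative numbers): with {@code k = 100} and a leaf holding 10% of all documents, + * {@code perLeafTopKCalculation} below yields {@code 100 * 0.1 + 16 * sqrt(100 * 0.1 * 0.9) = 10 + 16 * 3 = 58}, + * so Phase 1 searches that leaf with top-58 instead of top-100.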
+ * + * Ported from ... + */ +public class OptimisticKnnCollectorManager implements KnnCollectorManager { + // Constant controlling the degree of additional result exploration done during + // pro-rata search of segments. + private static final int LAMBDA = 16; + + private final int k; + private final KnnCollectorManager delegate; + + public OptimisticKnnCollectorManager(int k, KnnCollectorManager delegate) { + this.k = k; + this.delegate = delegate; + } + + @Override + public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException { + // The delegate supports optimistic collection + if (delegate.isOptimistic()) { + @SuppressWarnings("resource") + float leafProportion = context.reader().maxDoc() / (float) context.parent.reader().maxDoc(); + int perLeafTopK = perLeafTopKCalculation(k, leafProportion); + // if we divided by zero above, leafProportion can be NaN and then this would be 0 + assert perLeafTopK > 0; + return delegate.newOptimisticCollector(visitedLimit, searchStrategy, context, perLeafTopK); + } + // We don't support optimistic collection, so just do regular execution path + return delegate.newCollector(visitedLimit, searchStrategy, context); + } + + /** + * Returns perLeafTopK, the expected number (K * leafProportion) of hits in a leaf with the given + * proportion of the entire index, plus three standard deviations of a binomial distribution. Math + * says there is a 95% probability that this segment's contribution to the global top K hits are + * <= perLeafTopK. + */ + private static int perLeafTopKCalculation(int k, float leafProportion) { + return (int) Math.max(1, k * leafProportion + LAMBDA * Math.sqrt(k * leafProportion * (1 - leafProportion))); + } +} diff --git a/src/main/java/org/opensearch/lucene/ReentrantKnnCollectorManager.java b/src/main/java/org/opensearch/lucene/ReentrantKnnCollectorManager.java new file mode 100644 index 0000000000..87d94796ab --- /dev/null +++ b/src/main/java/org/opensearch/lucene/ReentrantKnnCollectorManager.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.lucene; + +import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.KnnSearchStrategy; + +import java.io.IOException; +import java.util.Map; + +/** + * A {@link KnnCollectorManager} that enables re-entrant (multi-phase) KNN vector search + * by seeding the HNSW graph search with document IDs collected from a prior search phase. + *
+ * This implementation reuses top-ranked results from the 1st-phase search + * as entry points for a 2nd-phase vector search. It converts previously collected + * {@link TopDocs} into corresponding vector entry points using + * {@link SeededTopDocsDISI} and {@link SeededMappedDISI}, enabling the internal searcher to start from these known points + * instead of beginning from random or default graph entry nodes. + *
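+ * <p>A minimal sketch of the seeding flow (variable names are illustrative): + * <pre>{@code + * Map<Integer, TopDocs> phase1Results = ...; // keyed by LeafReaderContext.ord + * KnnCollectorManager seeded = new ReentrantKnnCollectorManager( + *     new TopKnnCollectorManager(k, searcher), phase1Results, queryVector, field); + * KnnCollector collector = seeded.newCollector(visitLimit, strategy, leafCtx); + * }</pre> + *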
+ * See ... + */ +@Log4j2 +@RequiredArgsConstructor +public class ReentrantKnnCollectorManager implements KnnCollectorManager { + + // The underlying (delegate) KNN collector manager used to create collectors. + private final KnnCollectorManager knnCollectorManager; + + // Mapping from segment ordinal to previously collected {@link TopDocs}. + private final Map<Integer, TopDocs> segmentOrdToResults; + + // Query vector used for scoring during vector similarity search. + private final float[] query; + + // Name of the vector field being searched. + private final String field; + + /** + * Creates a new {@link KnnCollector} for the given segment. + *
+ * If 1st-phase results are available for the segment, this collector + * will seed the vector search with those document IDs. The document IDs + * are mapped to vector indices using {@link SeededMappedDISI}, which enables + * the HNSW search to begin from those known entry points. + *
+ * If no prior results exist or no vector scorer is available, the method + * falls back to a delegate collector. + * + * @param visitLimit the maximum number of graph nodes that can be visited + * @param searchStrategy the search strategy to use (e.g., HNSW or brute-force) + * @param ctx the leaf reader context for the current segment + * @return a seeded {@link KnnCollector} that reuses prior phase entry points, + * or a delegate collector if no seeding is possible + * @throws IOException if an I/O error occurs during setup + */ + @Override + public KnnCollector newCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) throws IOException { + // Get delegate collector for fallback or empty cases + final KnnCollector delegateCollector = knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx); + final TopDocs seedTopDocs = segmentOrdToResults.get(ctx.ord); + + if (seedTopDocs == null || seedTopDocs.totalHits.value() == 0) { + log.warn("Seed top docs was empty, expected non-empty top results to be given."); + // Normally shouldn't happen — indicates missing or empty seed results + assert false; + return delegateCollector; + } + + // Obtain the per-segment vector values + final LeafReader reader = ctx.reader(); + final FloatVectorValues vectorValues = reader.getFloatVectorValues(field); + if (vectorValues == null) { + log.error("Acquired null {} for field [{}]", FloatVectorValues.class.getSimpleName(), field); + // Validates the field exists, otherwise throws informative exception + FloatVectorValues.checkField(reader, field); + return null; + } + + // Create a vector scorer for the query vector + final VectorScorer scorer = vectorValues.scorer(query); + + if (scorer == null) { + log.error("Acquired null {} for field [{}]", VectorScorer.class.getSimpleName(), field); + // Normally shouldn't happen + assert false; + return delegateCollector; + } + + // Get DocIdSetIterator from scorer + DocIdSetIterator vectorIterator = scorer.iterator(); + + // Convert to an indexed iterator if possible (for sparse vectors) + // Note that we're extracting DISI from Lucene's flat vector. + if (vectorIterator instanceof IndexedDISI indexedDISI) { + vectorIterator = IndexedDISI.asDocIndexIterator(indexedDISI); + } + + // Map seed document IDs to vector indices to use as HNSW entry points + if (vectorIterator instanceof KnnVectorValues.DocIndexIterator indexIterator) { + DocIdSetIterator seedDocs = new SeededMappedDISI(indexIterator, new SeededTopDocsDISI(seedTopDocs)); + return knnCollectorManager.newCollector( + visitLimit, + new KnnSearchStrategy.Seeded(seedDocs, seedTopDocs.scoreDocs.length, searchStrategy), + ctx + ); + } + + log.error( + "`vectorIterator` was not one of [{}, {}] and was {}", + IndexedDISI.class.getSimpleName(), + KnnVectorValues.DocIndexIterator.class.getSimpleName(), + vectorIterator == null ? 
"null" : vectorIterator.getClass().getSimpleName() + ); + + // This should not occur; fallback to delegate to prevent infinite loops + assert false; + return delegateCollector; + } +} diff --git a/src/main/java/org/opensearch/lucene/SeededMappedDISI.java b/src/main/java/org/opensearch/lucene/SeededMappedDISI.java new file mode 100644 index 0000000000..7060ec96ba --- /dev/null +++ b/src/main/java/org/opensearch/lucene/SeededMappedDISI.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.lucene; + +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.search.DocIdSetIterator; + +import java.io.IOException; + +/** + * A {@link DocIdSetIterator} that maps document IDs from a source iterator + * to their corresponding vector indices in a {@link KnnVectorValues.DocIndexIterator}. + *
+ * This class advances the {@code indexedDISI} (which provides access to vector indices) + * in sync with the {@code sourceDISI} (which provides document IDs). For each document ID + * emitted by the source iterator, it advances the index iterator to the same document ID + * and returns the associated vector index. + *
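+ * <p>For example (illustrative): if the source iterator yields doc IDs {@code 3} and {@code 7}, and those + * documents hold vector ordinals {@code 1} and {@code 4}, {@code nextDoc()} returns {@code 1} and then {@code 4}. + *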
+ * Typical usage is when document-level matches (from a collector or filter) + * need to be mapped back to the vector index space for further vector-based operations. + */ +public class SeededMappedDISI extends DocIdSetIterator { + + // Iterator over vector values that exposes both doc IDs and their corresponding indices. + private final KnnVectorValues.DocIndexIterator indexedDISI; + + // Source iterator that provides the sequence of document IDs to be mapped. + private final DocIdSetIterator sourceDISI; + + /** + * Constructs a {@code SeededMappedDISI} that synchronizes a source document iterator + * with a vector index iterator. + * + * @param indexedDISI the {@link KnnVectorValues.DocIndexIterator} used to map + * document IDs to vector indices + * @param sourceDISI the {@link DocIdSetIterator} providing the source document IDs + */ + public SeededMappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) { + this.indexedDISI = indexedDISI; + this.sourceDISI = sourceDISI; + } + + /** + * Advances the source iterator to the first document ID that is greater than or equal + * to the specified target, then advances the index iterator to the same document ID. + *
+ * The returned value is the vector index corresponding to that document. + * + * @param target the target document ID + * @return the corresponding vector index, or {@link #NO_MORE_DOCS} if the end is reached + * @throws IOException if an I/O error occurs + */ + @Override + public int advance(int target) throws IOException { + int newTarget = sourceDISI.advance(target); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + + /** + * Returns an estimate of the cost (number of documents) of iterating. + * + * @return the cost estimate from the source iterator + */ + @Override + public long cost() { + return sourceDISI.cost(); + } + + /** + * Returns the current vector index corresponding to the current document position. + * + * @return the current vector index, or {@link #NO_MORE_DOCS} if iteration has completed + */ + @Override + public int docID() { + if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return indexedDISI.index(); + } + + /** + * Advances to the next document in the source iterator and updates the index iterator + * to the same document. Returns the corresponding vector index. + * + * @return the next vector index, or {@link #NO_MORE_DOCS} if there are no more documents + * @throws IOException if an I/O error occurs + */ + @Override + public int nextDoc() throws IOException { + int newTarget = sourceDISI.nextDoc(); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } +} diff --git a/src/main/java/org/opensearch/lucene/SeededTopDocsDISI.java b/src/main/java/org/opensearch/lucene/SeededTopDocsDISI.java new file mode 100644 index 0000000000..8a5bd0b92a --- /dev/null +++ b/src/main/java/org/opensearch/lucene/SeededTopDocsDISI.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.lucene; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.TopDocs; + +/** + * A {@link DocIdSetIterator} implementation that iterates over the document IDs + * collected in a {@link TopDocs} object. + *
+ * This class extracts document IDs from the given {@link TopDocs}, sorts them + * in ascending order, and then provides sequential access to those IDs. + *
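+ * <p>For example (illustrative), given {@code scoreDocs} with doc IDs {@code [7, 2, 9]}, iteration + * yields {@code 2, 7, 9} in ascending order. + *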
+ * It can be used to re-iterate over the documents returned by a collector, + * ensuring deterministic iteration order regardless of the original collection sequence. + */ +public class SeededTopDocsDISI extends DocIdSetIterator { + + /** Sorted array of document IDs extracted from {@link TopDocs}. */ + private final int[] sortedDocIds; + + /** Current index in {@link #sortedDocIds}. Starts at -1 before iteration. */ + private int idx = -1; + + /** + * Constructs a {@code SeededTopDocsDISI} from the given {@link TopDocs}. + *
+ * The document IDs are extracted from {@link org.apache.lucene.search.ScoreDoc#doc} + * and sorted in ascending order. The given doc IDs are expected to be segment-relative, + * with any collector doc base already removed. + * + * @param topDocs the {@link TopDocs} containing the collected document IDs + */ + public SeededTopDocsDISI(final TopDocs topDocs) { + sortedDocIds = new int[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + // Doc IDs are already segment-relative; no doc base needs to be subtracted here + sortedDocIds[i] = topDocs.scoreDocs[i].doc; + } + Arrays.sort(sortedDocIds); + } + + /** + * Advances to the first document whose ID is greater than or equal to the given target. + *
+ * This implementation delegates to {@link #slowAdvance(int)} for simplicity. + * + * @param target the target document ID + * @return the next matching document ID, or {@link #NO_MORE_DOCS} if none remain + * @throws IOException never thrown (declared for interface compatibility) + */ + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + /** + * Returns an estimate of the number of documents this iterator will traverse. + * + * @return the number of document IDs available + */ + @Override + public long cost() { + return sortedDocIds.length; + } + + /** + * Returns the current document ID. + * + * @return the current doc ID, {@code -1} if not yet started, + * or {@link #NO_MORE_DOCS} if iteration is complete + */ + @Override + public int docID() { + if (idx == -1) { + // Not advanced + return -1; + } else if (idx >= sortedDocIds.length) { + // Exhausted doc ids + return DocIdSetIterator.NO_MORE_DOCS; + } else { + return sortedDocIds[idx]; + } + } + + /** + * Advances to the next document ID in sorted order. + * + * @return the next document ID, or {@link #NO_MORE_DOCS} if there are no more + */ + @Override + public int nextDoc() { + idx += 1; + return docID(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 9a7cf6b9da..9fb16f800a 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -742,15 +742,19 @@ private void do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled( final Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); // If memory optimized search is on then, use Lucene query final KNNQuery knnQuery; - if (doRescore) { - assertTrue(query instanceof NativeEngineKnnVectorQuery); + final boolean memoryOptimizedEnabled = memoryOptimizedSearchEnabled && memoryOptimizedSearchEnabledInField; + if (memoryOptimizedEnabled) { + // Regardless of rescoring, once memory optimized search is enabled, it always uses NativeEngineKnnVectorQuery knnQuery = ((NativeEngineKnnVectorQuery) query).getKnnQuery(); } else { - assertFalse(query instanceof NativeEngineKnnVectorQuery); - knnQuery = (KNNQuery) query; + // When memory optimized search is turned off, we use NativeEngineKnnVectorQuery only when rescoring. 
+ if (doRescore) { + knnQuery = ((NativeEngineKnnVectorQuery) query).getKnnQuery(); + } else { + knnQuery = (KNNQuery) query; + } } - final boolean memoryOptimizedEnabled = memoryOptimizedSearchEnabled && memoryOptimizedSearchEnabledInField; if (memoryOptimizedEnabled) { if (vectorDataType == VectorDataType.FLOAT) { assertEquals(queryVector.length, knnQuery.getQueryVector().length); diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index 90b1e101bd..0615a1ff73 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -29,7 +29,7 @@ public void testReduceToTopK() { List> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); List perLeafLeafResults = initialLeafResults.stream() - .map(result -> new PerLeafResult(null, buildTopDocs(result))) + .map(result -> new PerLeafResult(null, 0, buildTopDocs(result), PerLeafResult.SearchMode.EXACT_SEARCH)) .collect(Collectors.toList()); ResultUtil.reduceToTopK(perLeafLeafResults, finalK); List> reducedLeafResults = perLeafLeafResults.stream() @@ -43,7 +43,7 @@ public void testReduceToTopK() { initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); perLeafLeafResults = initialLeafResults.stream() - .map(result -> new PerLeafResult(null, buildTopDocs(result))) + .map(result -> new PerLeafResult(null, 0, buildTopDocs(result), PerLeafResult.SearchMode.EXACT_SEARCH)) .collect(Collectors.toList()); ResultUtil.reduceToTopK(perLeafLeafResults, finalK); reducedLeafResults = perLeafLeafResults.stream().map(leaf -> convertTopDocsToMap(leaf.getResult())).collect(Collectors.toList()); diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 2f145fe41c..49e351a6fa 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -164,8 +164,18 @@ public int length() { leafReader1 = leaf1.reader(); leafReader2 = leaf2.reader(); // Given - PerLeafResult leaf1Result = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f)))); - PerLeafResult leaf2Result = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(4, 3.4f, 3, 5.1f)))); + PerLeafResult leaf1Result = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); + PerLeafResult leaf2Result = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(leaf1Result); when(knnWeight.searchLeaf(leaf2, 4)).thenReturn(leaf2Result); @@ -195,7 +205,12 @@ public void testExplain() { leaf1 = leaves.get(0); leafReader1 = leaf1.reader(); - PerLeafResult leafResult = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(4, 3.4f, 3, 5.1f)))); + PerLeafResult leafResult = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(leafResult); @@ -270,8 +285,18 @@ public void testRescoreWhenShardLevelRescoringEnabled() { leafReader2 = leaf2.reader(); int k = 2; - 
PerLeafResult initialLeaf1Results = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f)))); - PerLeafResult initialLeaf2Results = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f)))); + PerLeafResult initialLeaf1Results = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); + PerLeafResult initialLeaf2Results = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); Map rescoredLeaf1Results = new HashMap<>(Map.of(0, 18f, 1, 20f)); Map rescoredLeaf2Results = new HashMap<>(Map.of(0, 21f)); @@ -313,7 +338,12 @@ public void testSingleLeaf() { // Given int k = 4; float boost = 1; - PerLeafResult leaf1Result = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f)))); + PerLeafResult leaf1Result = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); List leaves = reader.leaves(); leaf1 = leaves.get(0); when(knnWeight.searchLeaf(leaf1, k)).thenReturn(leaf1Result); @@ -389,8 +419,18 @@ public void testRescore() { int k = 2; int firstPassK = 100; - PerLeafResult initialLeaf1Results = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f, 3, 15f)))); - PerLeafResult initialLeaf2Results = new PerLeafResult(null, buildTopDocs(new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f, 3, 14f)))); + PerLeafResult initialLeaf1Results = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f, 3, 15f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); + PerLeafResult initialLeaf2Results = new PerLeafResult( + null, + 0, + buildTopDocs(new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f, 3, 14f))), + PerLeafResult.SearchMode.EXACT_SEARCH + ); TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(Map.of(0, 18f, 1, 20f), 0); TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(Map.of(0, 21f), 4); when(knnQuery.getRescoreContext()).thenReturn(RescoreContext.builder().oversampleFactor(1.5f).build()); @@ -467,11 +507,20 @@ public void testExpandNestedDocs() { // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) leafReader1 = leaf1.reader(); leafReader2 = leaf2.reader(); - Bits queryFilterBits = mock(Bits.class); HashMap leaf1Result = new HashMap<>(Map.of(0, 19f, 1, 20f, 2, 17f, 3, 15f)); - PerLeafResult initialLeaf1Results = new PerLeafResult(queryFilterBits, buildTopDocs(leaf1Result)); + PerLeafResult initialLeaf1Results = new PerLeafResult( + PerLeafResult.MATCH_ALL_BIT_SET, + 0, + buildTopDocs(leaf1Result), + PerLeafResult.SearchMode.EXACT_SEARCH + ); HashMap leaf2Result = new HashMap<>(Map.of(0, 21f, 1, 18f, 2, 16f, 3, 14f)); - PerLeafResult initialLeaf2Results = new PerLeafResult(queryFilterBits, buildTopDocs(leaf2Result)); + PerLeafResult initialLeaf2Results = new PerLeafResult( + PerLeafResult.MATCH_ALL_BIT_SET, + 0, + buildTopDocs(leaf2Result), + PerLeafResult.SearchMode.EXACT_SEARCH + ); Map exactSearchLeaf1Result = new HashMap<>(Map.of(1, 20f)); Map exactSearchLeaf2Result = new HashMap<>(Map.of(0, 21f)); @@ -512,8 +561,8 @@ public void testExpandNestedDocs() { // Verify assertEquals(expectedWeight, finalWeigh); - verify(queryUtils).getAllSiblings(leaf1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); - verify(queryUtils).getAllSiblings(leaf2, perLeafResults.get(1).keySet(), parentFilter, 
queryFilterBits); + verify(queryUtils).getAllSiblings(leaf1, perLeafResults.get(0).keySet(), parentFilter, PerLeafResult.MATCH_ALL_BIT_SET); + verify(queryUtils).getAllSiblings(leaf2, perLeafResults.get(1).keySet(), parentFilter, PerLeafResult.MATCH_ALL_BIT_SET); ArgumentCaptor topDocsCaptor = ArgumentCaptor.forClass(TopDocs.class); verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture(), eq(knnWeight)); TopDocs capturedTopDocs = topDocsCaptor.getValue(); diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/OptimisticSearchStrategyUtilsTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/OptimisticSearchStrategyUtilsTests.java new file mode 100644 index 0000000000..2128103aac --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/OptimisticSearchStrategyUtilsTests.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Test; +import org.opensearch.knn.index.query.PerLeafResult; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.opensearch.knn.index.query.memoryoptsearch.optimistic.OptimisticSearchStrategyUtils.findKthLargestScore; + +public class OptimisticSearchStrategyUtilsTests { + /** + * Helper method to create a {@link PerLeafResult} with given scores. + */ + private static PerLeafResult perLeaf(float... scores) { + ScoreDoc[] scoreDocs = new ScoreDoc[scores.length]; + for (int i = 0; i < scores.length; i++) { + scoreDocs[i] = new ScoreDoc(i, scores[i]); + } + TopDocs topDocs = new TopDocs(new TotalHits(scores.length, TotalHits.Relation.EQUAL_TO), scoreDocs); + return new PerLeafResult(null, 0, topDocs, PerLeafResult.SearchMode.APPROXIMATE_SEARCH); + } + + @Test + public void testSingleSegmentSimple() { + List results = List.of(perLeaf(9.5f, 8.2f, 7.1f, 5.0f)); + float score = findKthLargestScore(results, 2, 4); + assertEquals(8.2f, score, 1e-6); + } + + @Test + public void testMultiSegmentMerge() { + List results = List.of(perLeaf(9.0f, 3.0f), perLeaf(8.5f, 7.2f, 4.4f), perLeaf(6.8f)); + // All scores combined: [9.0, 8.5, 7.2, 6.8, 4.4, 3.0] + // .................................^-------- This is what we're looking for + float score = findKthLargestScore(results, 3, 6); + assertEquals(7.2f, score, 1e-6); + } + + @Test + public void testTiedScores() { + List results = List.of(perLeaf(9.0f, 9.0f, 8.0f), perLeaf(8.0f, 7.5f)); + // Combined sorted: [9.0, 9.0, 8.0, 8.0, 7.5] + assertEquals(9.0f, findKthLargestScore(results, 1, 5), 1e-6); + assertEquals(9.0f, findKthLargestScore(results, 2, 5), 1e-6); + assertEquals(8.0f, findKthLargestScore(results, 3, 5), 1e-6); + } + + @Test + public void testKEqualsTotalResults() { + List results = List.of(perLeaf(5.0f, 6.0f), perLeaf(7.0f)); + // Combined: [7.0, 6.0, 5.0] + float score = findKthLargestScore(results, 3, 3); + assertEquals(5.0f, score, 1e-6); + } + + @Test + public void testInvalidK() { + List results = List.of(perLeaf(1.0f, 2.0f)); + assertThrows(IllegalArgumentException.class, () -> findKthLargestScore(results, 0, 2)); + assertThrows(IllegalArgumentException.class, () -> findKthLargestScore(results, 3, 2)); + } + + @Test + public void testEmptyResults() { + List results = new ArrayList<>(); + 
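+        // With no per-leaf results there is no kth score to pick, so the call below must fail fast.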
assertThrows(IllegalArgumentException.class, () -> findKthLargestScore(results, 1, 0)); + } + + @Test + public void testAllSameScores() { + List results = List.of(perLeaf(5.0f, 5.0f), perLeaf(5.0f)); + float score = findKthLargestScore(results, 2, 3); + assertEquals(5.0f, score, 1e-6); + } + + @Test + public void testLargeKMultiSegment() { + List results = List.of(perLeaf(10f, 9f, 8f), perLeaf(7f, 6f), perLeaf(5f, 4f)); + // Combined sorted: [10, 9, 8, 7, 6, 5, 4] + assertEquals(4f, findKthLargestScore(results, 7, 7), 1e-6); + } +} diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/OptimisticSearchTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/OptimisticSearchTests.java new file mode 100644 index 0000000000..fc2761d008 --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/OptimisticSearchTests.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import lombok.SneakyThrows; +import org.apache.lucene.index.CompositeReaderContext; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.knn.index.query.KNNQuery; +import org.opensearch.knn.index.query.PerLeafResult; +import org.opensearch.knn.index.query.common.QueryUtils; +import org.opensearch.knn.index.query.memoryoptsearch.MemoryOptimizedKNNWeight; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyFloat; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class OptimisticSearchTests { + private static final int DEFAULT_K = 50; + + private IndexSearcher searcher; + private LeafReader parentIndexReader; + private MemoryOptimizedKNNWeight knnWeight; + private KNNQuery knnQuery; + + @Before + @SneakyThrows + public void setup() { + parentIndexReader = mock(LeafReader.class); + final LeafReaderContext indexReaderContext = mock(LeafReaderContext.class); + when(indexReaderContext.id()).thenReturn(this); + when(parentIndexReader.getContext()).thenReturn(indexReaderContext); + + searcher = mock(IndexSearcher.class); + when(searcher.getIndexReader()).thenReturn(parentIndexReader); + final TaskExecutor executor = new TaskExecutor(Executors.newSingleThreadExecutor()); + when(searcher.getTaskExecutor()).thenReturn(executor); + + knnWeight = mock(MemoryOptimizedKNNWeight.class); + + knnQuery = mock(KNNQuery.class); + when(knnQuery.createWeight(any(), any(), anyFloat())).thenReturn(knnWeight); + when(knnQuery.getK()).thenReturn(DEFAULT_K); + 
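+        // Optimistic search only engages on the memory optimized search path, so force it on here.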
+        when(knnQuery.isMemoryOptimizedSearch()).thenReturn(true);
+    }
+
+    @Test
+    @SneakyThrows
+    public void testOptimisticSearchWith5Segments2Reentering() {
+        // 5 segments, only 2 segments returned results whose min score >= kth largest score
+        testOptimisticSearch(5, 2, true);
+        testOptimisticSearch(5, 2, false);
+    }
+
+    @Test
+    @SneakyThrows
+    public void testOptimisticSearchWith5Segments0Reentering() {
+        // 5 segments, none returns results whose min score >= kth largest score
+        testOptimisticSearch(5, 0, true);
+    }
+
+    @Test
+    @SneakyThrows
+    public void testOptimisticSearchWith0Segments() {
+        // Empty case
+        testOptimisticSearch(0, 0, true);
+    }
+
+    @Test
+    @SneakyThrows
+    public void testOptimisticSearchWith1Segments() {
+        // With only a single segment, optimistic search should be disabled.
+        testOptimisticSearch(1, 0, true);
+    }
+
+    @SneakyThrows
+    private void testOptimisticSearch(final int numSegments, final int numSegmentsForReentering, final boolean isApproximateSearch) {
+        // Create a query
+        final NativeEngineKnnVectorQuery query = new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.getInstance(), false);
+
+        // Create answer sets for 1st phase search
+        final List<List<ScoreDoc>> searchResults = new ArrayList<>();
+        for (int i = 0; i < numSegments; i++) {
+            searchResults.add(new ArrayList<>());
+        }
+
+        // Score distribution = 0.123, 1.123, ..., (#segments * k - 1) + 0.123
+        final float kthLargestScore = (numSegments * DEFAULT_K - DEFAULT_K) + 0.123F;
+        for (int i = 0, j = 0; i < numSegments * DEFAULT_K; j = (j + 1) % numSegments) {
+            final float score = i + 0.123F;
+            final List<ScoreDoc> scoreDocs = searchResults.get(j);
+            int prevDocId = -1;
+            if (scoreDocs.isEmpty() == false) {
+                prevDocId = scoreDocs.get(scoreDocs.size() - 1).doc;
+            }
+            if (j >= numSegmentsForReentering || score >= kthLargestScore) {
+                scoreDocs.add(new ScoreDoc(prevDocId + 1, score));
+                ++i;
+            }
+        }
+
+        // Sort by score in descending order
+        for (List<ScoreDoc> scoreDocs : searchResults) {
+            scoreDocs.sort((a, b) -> Float.compare(b.score, a.score));
+        }
+
+        // Wrap results with PerLeafResult
+        final List<PerLeafResult> perLeafResults = new ArrayList<>();
+        for (int i = 0; i < numSegments; i++) {
+            perLeafResults.add(
+                new PerLeafResult(
+                    null,
+                    0,
+                    new TopDocs(
+                        new TotalHits(searchResults.get(i).size(), TotalHits.Relation.EQUAL_TO),
+                        searchResults.get(i).toArray(new ScoreDoc[0])
+                    ),
+                    isApproximateSearch ? PerLeafResult.SearchMode.APPROXIMATE_SEARCH : PerLeafResult.SearchMode.EXACT_SEARCH
+                )
+            );
+        }
+
+        // Create segments
+        final List<LeafReaderContext> leafReaderContexts = new ArrayList<>();
+        final int numDocsInSegment = 1000;
+        for (int i = 0, j = 0, docBase = 0; i < numSegments; ++i, ++j, docBase += numDocsInSegment) {
+            // Make mock for leaf reader context
+            final SegmentReader mockSegmentReader = mock(SegmentReader.class);
+            when(mockSegmentReader.getSegmentName()).thenReturn("_" + i + "_165_target_field.faiss");
+            when(mockSegmentReader.maxDoc()).thenReturn(numDocsInSegment);
+
+            final LeafReaderContext leafReaderContext = createLeafReaderContext(i, docBase, mockSegmentReader);
+            when(mockSegmentReader.getContext()).thenReturn(leafReaderContext);
+
+            leafReaderContexts.add(leafReaderContext);
+
+            // Return answer set per this segment
+            when(knnWeight.searchLeaf(eq(leafReaderContext), anyInt())).thenReturn(perLeafResults.get(i));
+            when(knnWeight.approximateSearch(eq(leafReaderContext), any(), anyInt(), anyInt())).thenReturn(
+                perLeafResults.get(i).getResult()
+            );
+        }
+
+        when(parentIndexReader.leaves()).thenReturn(leafReaderContexts);
+
+        // Create a weight and do search
+        final Weight weight = query.createWeight(searcher, ScoreMode.TOP_DOCS_WITH_SCORES, 1.0f);
+
+        // Validate reentering
+        for (int i = 0; i < leafReaderContexts.size(); ++i) {
+            // Look up the mocked leaf reader context
+            final LeafReaderContext mockLeafReaderContext = leafReaderContexts.get(i);
+
+            verify(knnWeight, times(1)).searchLeaf(eq(mockLeafReaderContext), anyInt());
+
+            if (i < numSegmentsForReentering) {
+                // Even if a segment is competitive, results obtained via exact search must not trigger reentering.
+                final int expectedInvocations = isApproximateSearch ? 1 : 0;
+
+                // Competitive segments searched approximately should be revisited.
+                verify(knnWeight, times(expectedInvocations)).approximateSearch(eq(mockLeafReaderContext), any(), anyInt(), anyInt());
+            }
+        }
+
+        // Validate results
+        // Take top-k for answer set
+        List<Float> answerScores = new ArrayList<>();
+        for (List<ScoreDoc> scoreDocs : searchResults) {
+            for (ScoreDoc scoreDoc : scoreDocs) {
+                answerScores.add(scoreDoc.score);
+            }
+        }
+        answerScores.sort((a, b) -> Float.compare(b, a));
+        if (numSegments > 0) {
+            // Only take top-k when there is at least one segment.
+            answerScores = answerScores.subList(0, DEFAULT_K);
+        }
+
+        // Collect scores and sort them in descending order.
+        final List<Float> acquiredScores = new ArrayList<>();
+        for (final LeafReaderContext leafReaderContext : leafReaderContexts) {
+            final Scorer scorer = weight.scorer(leafReaderContext);
+            final DocIdSetIterator iterator = scorer.iterator();
+            while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+                final float score = scorer.score();
+                acquiredScores.add(score);
+            }
+        }
+        acquiredScores.sort((a, b) -> Float.compare(b, a));
+
+        // Scores should be the same
+        assertEquals("Invalid scores acquired. 
Answer=" + answerScores + ", got=" + acquiredScores, answerScores, acquiredScores); + } + + private static LeafReaderContext createLeafReaderContext(final int ord, final int docBase, SegmentReader mockSegmentReader) { + try { + // Get the package-private constructor + Constructor ctor = LeafReaderContext.class.getDeclaredConstructor( + CompositeReaderContext.class, + LeafReader.class, + int.class, + int.class, + int.class, + int.class + ); + ctor.setAccessible(true); + + // Call constructor with desired values + return ctor.newInstance(null, mockSegmentReader, ord, docBase, ord, docBase); + } catch (Exception e) { + throw new RuntimeException("Failed to create LeafReaderContext via reflection", e); + } + } +} diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/ReentrantKnnCollectorManagerTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/ReentrantKnnCollectorManagerTests.java new file mode 100644 index 0000000000..f925e24352 --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/ReentrantKnnCollectorManagerTests.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockedStatic; +import org.opensearch.lucene.ReentrantKnnCollectorManager; + +import java.util.Map; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ReentrantKnnCollectorManagerTests { + private KnnCollectorManager delegateManager; + private KnnCollector delegateCollector; + private LeafReader reader; + private LeafReaderContext ctx; + private FloatVectorValues vectorValues; + private VectorScorer scorer; + private KnnSearchStrategy searchStrategy; + private ReentrantKnnCollectorManager manager; + + @Before + public void setUp() throws Exception { + delegateManager = mock(KnnCollectorManager.class); + delegateCollector = mock(KnnCollector.class); + reader = mock(LeafReader.class); + vectorValues = mock(FloatVectorValues.class); + scorer = mock(VectorScorer.class); + searchStrategy = mock(KnnSearchStrategy.class); + + // Mock final class LeafReaderContext + ctx = mock(LeafReaderContext.class); + when(ctx.reader()).thenReturn(reader); + + // Mock getFieldInfos so that FloatVectorValues.checkField() or scorer() doesn't throw + FieldInfos fieldInfos = mock(FieldInfos.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + + // Stub the delegate collector + when(delegateManager.newCollector(anyInt(), any(), any())).thenReturn(delegateCollector); + + // Seed TopDocs + TopDocs 
seedTopDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 0.9f), new ScoreDoc(2, 0.8f) } + ); + + manager = new ReentrantKnnCollectorManager(delegateManager, Map.of(0, seedTopDocs), new float[] { 1.0f, 2.0f }, "vector_field"); + } + + @Test + public void testNormalCase_SeedsApplied() throws Exception { + // Given: scorer returns a DocIndexIterator + KnnVectorValues.DocIndexIterator docIndexIterator = mock(KnnVectorValues.DocIndexIterator.class); + when(reader.getFloatVectorValues("vector_field")).thenReturn(vectorValues); + when(vectorValues.scorer(any(float[].class))).thenReturn(scorer); + when(scorer.iterator()).thenReturn(docIndexIterator); + + // When + KnnCollector collector = manager.newCollector(10, searchStrategy, ctx); + + // Then + assertNotNull("Collector should not be null", collector); + verify(delegateManager).newCollector(eq(10), argThat(arg -> arg instanceof KnnSearchStrategy.Seeded), eq(ctx)); + } + + @Test + public void testNullVectorValues_TriggersCheckField() throws Exception { + when(reader.getFloatVectorValues("vector_field")).thenReturn(null); + + try (MockedStatic mocked = mockStatic(FloatVectorValues.class)) { + manager.newCollector(10, searchStrategy, ctx); + mocked.verify(() -> FloatVectorValues.checkField(reader, "vector_field")); + } + } +} diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/SeededMappedDISITests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/SeededMappedDISITests.java new file mode 100644 index 0000000000..b7855c2375 --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/SeededMappedDISITests.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.junit.Before; +import org.junit.Test; +import org.mockito.InOrder; +import org.opensearch.lucene.SeededMappedDISI; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SeededMappedDISITests { + private KnnVectorValues.DocIndexIterator indexedDISI; + private DocIdSetIterator sourceDISI; + private SeededMappedDISI mappedDISI; + + @Before + public void setup() { + indexedDISI = mock(KnnVectorValues.DocIndexIterator.class); + sourceDISI = mock(DocIdSetIterator.class); + mappedDISI = new SeededMappedDISI(indexedDISI, sourceDISI); + } + + @Test + public void testNextDocAdvancesBothIterators() throws IOException { + // Arrange + when(sourceDISI.nextDoc()).thenReturn(10); + when(indexedDISI.advance(10)).thenReturn(10); + when(indexedDISI.docID()).thenReturn(10); + when(sourceDISI.docID()).thenReturn(10); + when(indexedDISI.index()).thenReturn(42); // vector index + + // Act + int result = mappedDISI.nextDoc(); + + // Assert + verify(sourceDISI).nextDoc(); + verify(indexedDISI).advance(10); + assertEquals("Should return vector index mapped to docID", 42, result); + } + + @Test + public void testNextDocNoMoreDocs() throws IOException { + when(sourceDISI.nextDoc()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(indexedDISI.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + 
when(sourceDISI.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + int result = mappedDISI.nextDoc(); + + assertEquals("Should return NO_MORE_DOCS when exhausted", DocIdSetIterator.NO_MORE_DOCS, result); + verify(indexedDISI, never()).advance(anyInt()); + } + + @Test + public void testAdvanceSyncsIndexedDISI() throws IOException { + when(sourceDISI.advance(25)).thenReturn(25); + when(indexedDISI.advance(25)).thenReturn(25); + when(indexedDISI.docID()).thenReturn(25); + when(sourceDISI.docID()).thenReturn(25); + when(indexedDISI.index()).thenReturn(7); + + int result = mappedDISI.advance(25); + + InOrder order = inOrder(sourceDISI, indexedDISI); + order.verify(sourceDISI).advance(25); + order.verify(indexedDISI).advance(25); + + assertEquals("Should return mapped index for target doc", 7, result); + } + + @Test + public void testAdvanceNoMoreDocs() throws IOException { + when(sourceDISI.advance(100)).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(indexedDISI.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(sourceDISI.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + int result = mappedDISI.advance(100); + + assertEquals("Should return NO_MORE_DOCS when exhausted", DocIdSetIterator.NO_MORE_DOCS, result); + verify(indexedDISI, never()).advance(anyInt()); + } + + @Test + public void testDocIDReturnsVectorIndex() { + when(indexedDISI.docID()).thenReturn(10); + when(sourceDISI.docID()).thenReturn(10); + when(indexedDISI.index()).thenReturn(99); + + int docID = mappedDISI.docID(); + + assertEquals("Should return mapped vector index", 99, docID); + } + + @Test + public void testDocIDReturnsNoMoreDocsIfEitherIteratorExhausted() { + when(indexedDISI.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(sourceDISI.docID()).thenReturn(5); + assertEquals("If indexedDISI exhausted, return NO_MORE_DOCS", DocIdSetIterator.NO_MORE_DOCS, mappedDISI.docID()); + + when(indexedDISI.docID()).thenReturn(3); + when(sourceDISI.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + assertEquals("If sourceDISI exhausted, return NO_MORE_DOCS", DocIdSetIterator.NO_MORE_DOCS, mappedDISI.docID()); + } + + @Test + public void testCostDelegatesToSource() { + when(sourceDISI.cost()).thenReturn(123L); + assertEquals("Cost should delegate to sourceDISI", 123L, mappedDISI.cost()); + verify(sourceDISI).cost(); + } + + @Test + public void testSequentialNextDocCallsAdvanceInOrder() throws IOException { + when(sourceDISI.nextDoc()).thenReturn(5).thenReturn(8).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + when(indexedDISI.advance(5)).thenReturn(5); + when(indexedDISI.advance(8)).thenReturn(8); + when(indexedDISI.docID()).thenReturn(8); + when(sourceDISI.docID()).thenReturn(8); + when(indexedDISI.index()).thenReturn(33); + + mappedDISI.nextDoc(); // doc 5 + mappedDISI.nextDoc(); // doc 8 + mappedDISI.nextDoc(); // end + + InOrder inOrder = inOrder(sourceDISI, indexedDISI); + inOrder.verify(sourceDISI).nextDoc(); + inOrder.verify(indexedDISI).advance(5); + inOrder.verify(sourceDISI).nextDoc(); + inOrder.verify(indexedDISI).advance(8); + } +} diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/SeededTopDocsDISITests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/SeededTopDocsDISITests.java new file mode 100644 index 0000000000..43a5a6843c --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/SeededTopDocsDISITests.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + 
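+
+// Exercises the DocIdSetIterator contract of SeededTopDocsDISI: ascending iteration order,
+// docID() before/after exhaustion, cost(), and advance() semantics.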
+import lombok.SneakyThrows; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Test; +import org.opensearch.lucene.SeededTopDocsDISI; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SeededTopDocsDISITests { + private static TopDocs topDocs(int... docIds) { + // Make score docs having scores [1,2,3,...] + ScoreDoc[] scoreDocs = new ScoreDoc[docIds.length]; + for (int i = 0; i < docIds.length; i++) { + scoreDocs[i] = new ScoreDoc(docIds[i], i); + } + return new TopDocs(new TotalHits(docIds.length, TotalHits.Relation.EQUAL_TO), scoreDocs); + } + + @Test + public void testSortedOrderAfterConstruction() { + TopDocs unsorted = topDocs(5, 2, 9, 1); + SeededTopDocsDISI disi = new SeededTopDocsDISI(unsorted); + + // Force iteration to confirm sorting + int doc; + int last = -1; + while ((doc = disi.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + assertTrue("Doc IDs should be in ascending order", doc > last); + last = doc; + } + } + + @Test + public void testDocIDBeforeIteration() { + SeededTopDocsDISI disi = new SeededTopDocsDISI(topDocs(1, 2, 3)); + assertEquals("docID should be -1 before iteration starts", -1, disi.docID()); + } + + @Test + public void testNextDocIteratesAll() { + SeededTopDocsDISI disi = new SeededTopDocsDISI(topDocs(10, 5, 7)); + int[] expected = { 5, 7, 10 }; + int i = 0; + int doc; + while ((doc = disi.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + assertEquals(expected[i++], doc); + } + assertEquals("Should iterate over all docs", expected.length, i); + } + + @Test + public void testDocIDAfterExhaustion() { + SeededTopDocsDISI disi = new SeededTopDocsDISI(topDocs(1)); + disi.nextDoc(); // first doc + disi.nextDoc(); // should hit NO_MORE_DOCS + assertEquals(DocIdSetIterator.NO_MORE_DOCS, disi.docID()); + } + + @Test + public void testCostMatchesDocCount() { + TopDocs docs = topDocs(3, 1, 2); + SeededTopDocsDISI disi = new SeededTopDocsDISI(docs); + assertEquals(docs.scoreDocs.length, disi.cost()); + } + + @Test + @SneakyThrows + public void testAdvanceToTarget() { + SeededTopDocsDISI disi = new SeededTopDocsDISI(topDocs(2, 5, 7, 10)); + int advanced = disi.advance(6); + assertEquals(7, advanced); + assertEquals(7, disi.docID()); + } + + @Test + @SneakyThrows + public void testAdvancePastEndReturnsNoMoreDocs() { + SeededTopDocsDISI disi = new SeededTopDocsDISI(topDocs(1, 2, 3)); + int doc = disi.advance(10); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); + } + + @Test + @SneakyThrows + public void testEmptyTopDocs() { + TopDocs empty = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + SeededTopDocsDISI disi = new SeededTopDocsDISI(empty); + + assertEquals(0, disi.cost()); + assertEquals(-1, disi.docID()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, disi.nextDoc()); + assertThrows(AssertionError.class, () -> disi.advance(0)); + } + + @Test + @SneakyThrows + public void testSequentialNextDocThenAdvance() { + SeededTopDocsDISI disi = new SeededTopDocsDISI(topDocs(1, 4, 9, 15)); + + // Move to second doc + assertEquals(1, disi.nextDoc()); + assertEquals(4, disi.nextDoc()); + + // Now advance beyond 4 to 10 + assertEquals(15, disi.advance(10)); + assertEquals(15, disi.docID()); + + // Advance beyond last + assertEquals(DocIdSetIterator.NO_MORE_DOCS, disi.advance(16)); + } +} diff --git 
a/src/test/java/org/opensearch/knn/profile/StopWatchUtilsTests.java b/src/test/java/org/opensearch/knn/profile/StopWatchUtilsTests.java
new file mode 100644
index 0000000000..2a340966dd
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/profile/StopWatchUtilsTests.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.profile;
+
+import org.apache.logging.log4j.Logger;
+import org.junit.Test;
+import org.opensearch.common.StopWatch;
+
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class StopWatchUtilsTests {
+    @Test
+    public void shouldNotLogWhenDebugDisabled() {
+        // Set up the mock
+        final Logger log = mock(Logger.class);
+        when(log.isDebugEnabled()).thenReturn(false);
+
+        // We should get a null StopWatch when debug is disabled.
+        final StopWatch stopWatch = StopWatchUtils.startStopWatch(log);
+        assertNull(stopWatch);
+
+        // It's safe to call stop with a null StopWatch.
+        StopWatchUtils.stopStopWatchAndLog(log, stopWatch, "PrefixMessage", 0, "SegmentName", "FieldName");
+
+        // The logger is never called.
+        verify(log, never()).debug(anyString());
+    }
+
+    @Test
+    public void shouldLogWhenDebugEnabled() {
+        // Set up the mock
+        final Logger log = mock(Logger.class);
+        when(log.isDebugEnabled()).thenReturn(true);
+
+        // We should get a non-null StopWatch when debug is enabled.
+        final StopWatch stopWatch = StopWatchUtils.startStopWatch(log);
+        assertNotNull(stopWatch);
+
+        // Debug logging should be called.
+        StopWatchUtils.stopStopWatchAndLog(log, stopWatch, "PrefixMessage", 0, "SegmentName", "FieldName");
+        verify(log, times(1)).debug(anyString(), eq(0), eq("SegmentName"), eq("FieldName"), anyLong());
+    }
+}
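
For reference, the contract these tests pin down can be sketched as follows: startStopWatch returns null when debug logging is disabled, and stopStopWatchAndLog is a null-safe stop that emits one debug line carrying shard, segment, field, and elapsed time. This is a minimal illustrative reconstruction from the tests above, not the implementation shipped in this change; the exact log message format in particular is an assumption.

    // Minimal sketch reconstructed from StopWatchUtilsTests; not the shipped
    // implementation. The debug message format below is an assumption.
    import org.apache.logging.log4j.Logger;
    import org.opensearch.common.StopWatch;

    public final class StopWatchUtils {
        private StopWatchUtils() {}

        // Returns a started StopWatch, or null when debug logging is disabled,
        // so callers pay no timing overhead in the common case.
        public static StopWatch startStopWatch(final Logger log) {
            if (log.isDebugEnabled() == false) {
                return null;
            }
            final StopWatch stopWatch = new StopWatch();
            stopWatch.start();
            return stopWatch;
        }

        // Null-safe: does nothing when startStopWatch returned null above.
        public static void stopStopWatchAndLog(
            final Logger log,
            final StopWatch stopWatch,
            final String prefix,
            final int shardId,
            final String segmentName,
            final String fieldName
        ) {
            if (stopWatch == null) {
                return;
            }
            stopWatch.stop();
            final long durationMillis = stopWatch.totalTime().millis();
            // Mirrors the verified call shape: debug(message, shardId, segmentName, fieldName, duration)
            log.debug(prefix + ", shard=[{}], segment=[{}], field=[{}], took=[{}] ms", shardId, segmentName, fieldName, durationMillis);
        }
    }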