diff --git a/build.gradle b/build.gradle
index fd79065ea4..9bd256cd6a 100644
--- a/build.gradle
+++ b/build.gradle
@@ -289,10 +289,23 @@ publishing {
compileJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
+
+ // MemorySegment only left preview in JDK 22, so exclude the JDK22-specific source when compiling for Java 21 or lower.
+ def javaExt = project.extensions.getByType(JavaPluginExtension)
+ if (javaExt.sourceCompatibility <= JavaVersion.VERSION_21 || javaExt.targetCompatibility <= JavaVersion.VERSION_21) {
+ exclude("org/opensearch/knn/memoryoptsearch/MemorySegmentAddressExtractorJDK22.java")
+ }
}
compileTestJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
}
+javadoc {
+ // Skip Javadoc for this source too; on Java 21 the tool complains that MemorySegment is a preview API.
+ def javaExt = project.extensions.getByType(JavaPluginExtension)
+ if (javaExt.sourceCompatibility <= JavaVersion.VERSION_21 || javaExt.targetCompatibility <= JavaVersion.VERSION_21) {
+ exclude("org/opensearch/knn/memoryoptsearch/MemorySegmentAddressExtractorJDK22.java")
+ }
+}
compileTestFixturesJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
}
@@ -482,6 +495,7 @@ def commonIntegTest(RestIntegTestTask task, project, integTestDependOnJniLib, op
task.systemProperty 'cluster.debug', isDebuggingCluster
// Set number of nodes system property to be used in tests
task.systemProperty 'cluster.number_of_nodes', "${_numNodes}"
+
// There seems to be an issue when running multi node run or integ tasks with unicast_hosts
// not being written, the waitForAllConditions ensures it's written
task.getClusters().forEach { cluster ->
@@ -541,6 +555,7 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){
debugPort += 1
}
}
+
cluster.systemProperty("java.library.path", "$rootDir/jni/build/release")
final testSnapshotFolder = file("${buildDir}/testSnapshotFolder")
testSnapshotFolder.mkdirs()
@@ -550,6 +565,8 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){
testClusters.integTest {
commonIntegTestClusters(it, _numNodes)
+ // Force re-entering the optimistic 2nd-phase search for testing
+ systemProperty 'mem_opt_srch.force_reenter', 'true'
}
testClusters.integTestRemoteIndexBuild {
@@ -562,6 +579,8 @@ testClusters.integTestRemoteIndexBuild {
keystore 's3.client.default.access_key', "${System.getProperty("access_key")}"
keystore 's3.client.default.secret_key', "${System.getProperty("secret_key")}"
keystore 's3.client.default.session_token', "${System.getProperty("session_token")}"
+ // Force re-entering the optimistic 2nd-phase search for testing
+ systemProperty 'mem_opt_srch.force_reenter', 'true'
}
task integTestRemote(type: RestIntegTestTask) {
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
index 1b7e0f9e5f..d284ea1d48 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
@@ -117,7 +117,8 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.build();
}
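+ // Memory optimized search always routes through NativeEngineKnnVectorQuery so the optimistic two-phase flow can run.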
- if (createQueryRequest.getRescoreContext().isPresent()
+ if (memoryOptimizedSearchEnabled
+ || createQueryRequest.getRescoreContext().isPresent()
|| (ENGINES_SUPPORTING_NESTED_FIELDS.contains(createQueryRequest.getKnnEngine()) && expandNested)) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.getInstance(), expandNested);
}
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
index e997441538..0471e09638 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -24,7 +24,6 @@
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
-import org.opensearch.common.Nullable;
import org.opensearch.common.StopWatch;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
@@ -52,6 +51,9 @@
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
+import static org.opensearch.knn.profile.StopWatchUtils.startStopWatch;
+import static org.opensearch.knn.profile.StopWatchUtils.stopStopWatchAndLog;
+
/**
* {@link KNNWeight} serves as a template for implementing approximate nearest neighbor (ANN)
* and radius search over a native index type, such as Faiss.
@@ -298,12 +300,13 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
final SegmentReader reader = Lucene.segmentReader(context.reader());
final String segmentName = reader.getSegmentName();
- StopWatch stopWatch = startStopWatch();
+ final StopWatch stopWatch = startStopWatch(log);
final BitSet filterBitSet = getFilteredDocsBitSet(context);
- stopStopWatchAndLog(stopWatch, "FilterBitSet creation", segmentName);
+ stopStopWatchAndLog(log, stopWatch, "FilterBitSet creation", knnQuery.getShardId(), segmentName, knnQuery.getField());
+
+ // Compute the cardinality once up front and reuse it; the calculation is expensive.
+ final int filterCardinality = filterBitSet.cardinality();
- final int maxDoc = context.reader().maxDoc();
- int filterCardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
// place,
@@ -320,19 +323,19 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(filterCardinality)) {
- TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k);
- return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
+ final TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k);
+ return new PerLeafResult(
+ filterWeight == null ? null : filterBitSet,
+ filterCardinality,
+ result,
+ PerLeafResult.SearchMode.EXACT_SEARCH
+ );
}
- /*
- * If filters match all docs in this segment, then null should be passed as filterBitSet
- * so that it will not do a bitset look up in bottom search layer.
- */
- final BitSet annFilter = (filterWeight != null && filterCardinality == maxDoc) ? null : filterBitSet;
+ final StopWatch annStopWatch = startStopWatch(log);
+ final TopDocs topDocs = approximateSearch(context, filterBitSet, filterCardinality, k);
+ stopStopWatchAndLog(log, stopWatch, "ANN search", knnQuery.getShardId(), segmentName, knnQuery.getField());
- StopWatch annStopWatch = startStopWatch();
- final TopDocs topDocs = approximateSearch(context, annFilter, filterCardinality, k);
- stopStopWatchAndLog(annStopWatch, "ANN search", segmentName);
if (knnQuery.isExplain()) {
knnExplanation.addLeafResult(context.id(), topDocs.scoreDocs.length);
}
@@ -341,18 +344,21 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, filterCardinality, topDocs.scoreDocs.length)) {
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, filterCardinality) : null;
- TopDocs result = doExactSearch(context, docs, filterCardinality, k);
- return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
+ final TopDocs result = doExactSearch(context, docs, filterCardinality, k);
+ return new PerLeafResult(
+ filterWeight == null ? null : filterBitSet,
+ filterCardinality,
+ result,
+ PerLeafResult.SearchMode.EXACT_SEARCH
+ );
}
- return new PerLeafResult(filterWeight == null ? null : filterBitSet, topDocs);
- }
- private void stopStopWatchAndLog(@Nullable final StopWatch stopWatch, final String prefixMessage, String segmentName) {
- if (log.isDebugEnabled() && stopWatch != null) {
- stopWatch.stop();
- final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] ";
- log.debug(logMessage, knnQuery.getShardId(), segmentName, knnQuery.getField(), stopWatch.totalTime().nanos());
- }
+ return new PerLeafResult(
+ filterWeight == null ? null : filterBitSet,
+ filterCardinality,
+ topDocs,
+ PerLeafResult.SearchMode.APPROXIMATE_SEARCH
+ );
}
protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
@@ -413,9 +419,33 @@ private TopDocs doExactSearch(
return exactSearch(context, exactSearcherContextBuilder.build());
}
- protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k)
- throws IOException {
+ /**
+ * Performs an approximate nearest neighbor (ANN) search on the provided index segment.
+ *
+ * This method prepares all necessary query metadata before triggering the actual ANN search.
+ * It extracts the {@code model_id} from field-level attributes if required, retrieves any
+ * quantization or auxiliary metadata associated with the vector field, and applies quantization
+ * to the query vector when applicable. After these preprocessing steps, it invokes
+ * {@code doANNSearch(LeafReaderContext, BitSet, int, int)} to execute the approximate search
+ * and obtain the top results.
+ *
+ * @param context the {@link LeafReaderContext} representing the current index segment
+ * @param filterIdsBitSet an optional {@link BitSet} indicating document IDs to include in the search;
+ * may be {@code null} if no filtering is required
+ * @param filterCardinality the number of documents included in {@code filterIdsBitSet};
+ * used to optimize search filtering
+ * @param k the number of nearest neighbors to retrieve
+ * @return a {@link TopDocs} object containing the top {@code k} approximate search results
+ * @throws IOException if an error occurs while reading index data or accessing vector fields
+ */
+ public TopDocs approximateSearch(
+ final LeafReaderContext context,
+ final BitSet filterIdsBitSet,
+ final int filterCardinality,
+ final int k
+ ) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());
+
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
@@ -465,6 +495,11 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe
// TODO: Change type of vector once more quantization methods are supported
byte[] quantizedVector = maybeQuantizeVector(segmentLevelQuantizationInfo);
float[] transformedVector = maybeTransformVector(segmentLevelQuantizationInfo, spaceType);
+ /*
+ * If filters match all docs in this segment, then null should be passed as filterBitSet
+ * so that it will not do a bitset look up in bottom search layer.
+ */
+ final BitSet annFilter = filterCardinality == context.reader().maxDoc() ? null : filterIdsBitSet;
KNNCounter.GRAPH_QUERY_REQUESTS.increment();
final TopDocs results = doANNSearch(
@@ -477,8 +512,8 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe
quantizedVector,
transformedVector,
modelId,
- filterIdsBitSet,
- cardinality,
+ annFilter,
+ filterCardinality,
k
);
@@ -553,10 +588,10 @@ protected void addExplainIfRequired(final TopDocs results, final KNNEngine knnEn
*/
public TopDocs exactSearch(final LeafReaderContext leafReaderContext, final ExactSearcher.ExactSearcherContext exactSearcherContext)
throws IOException {
- StopWatch stopWatch = startStopWatch();
+ final StopWatch stopWatch = startStopWatch(log);
TopDocs exactSearchResults = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
- stopStopWatchAndLog(stopWatch, "Exact search", reader.getSegmentName());
+ stopStopWatchAndLog(log, stopWatch, "Exact search", knnQuery.getShardId(), reader.getSegmentName(), knnQuery.getField());
return exactSearchResults;
}
@@ -673,13 +708,6 @@ private boolean isMissingNativeEngineFiles(LeafReaderContext context) {
return engineFiles.isEmpty();
}
- private StopWatch startStopWatch() {
- if (log.isDebugEnabled()) {
- return new StopWatch().start();
- }
- return null;
- }
-
protected int[] getParentIdsArray(final LeafReaderContext context) throws IOException {
if (knnQuery.getParentsFilter() == null) {
return null;
diff --git a/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java b/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java
index b14f9085c5..70d1937ac6 100644
--- a/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java
+++ b/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java
@@ -7,21 +7,189 @@
import lombok.Getter;
import lombok.Setter;
+import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
-import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BitSet;
import org.opensearch.common.Nullable;
+/**
+ * Represents the per-segment (leaf-level) result of a vector search operation.
+ *
+ * This class encapsulates the intermediate search state and results produced from
+ * a single {@link LeafReaderContext} during approximate or exact vector search.
+ * It stores the active filter bitset (if any), its cardinality, the top document
+ * results for that segment, and the search mode used.
+ *
+ * Instances of this class are typically aggregated at a higher level to produce
+ * the global {@code TopDocs} result set.
+ */
@Getter
public class PerLeafResult {
- public static final PerLeafResult EMPTY_RESULT = new PerLeafResult(new Bits.MatchNoBits(0), TopDocsCollector.EMPTY_TOPDOCS);
+ /**
+ * An immutable {@link BitSet} that reports every document as set, used to represent
+ * the absence of a filter (everything matches) without incurring null checks or allocations.
+ */
+ public static final BitSet MATCH_ALL_BIT_SET = new BitSet() {
+ @Override
+ public void set(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean getAndSet(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear(int startIndex, int endIndex) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int cardinality() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int approximateCardinality() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int prevSetBit(int index) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int nextSetBit(int start, int end) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean get(int i) {
+ return true;
+ }
+
+ @Override
+ public int length() {
+ throw new UnsupportedOperationException();
+ }
+ };
+
+ /**
+ * An immutable {@link BitSet} that reports no document as set, used to represent
+ * an empty match set without incurring null checks or allocations.
+ */
+ private static final BitSet MATCH_NO_BIT_SET = new BitSet() {
+ @Override
+ public void set(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean getAndSet(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear(int startIndex, int endIndex) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int cardinality() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int approximateCardinality() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int prevSetBit(int index) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int nextSetBit(int start, int end) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean get(int i) {
+ return false;
+ }
+
+ @Override
+ public int length() {
+ throw new UnsupportedOperationException();
+ }
+ };
+
+ // A statically defined empty {@code PerLeafResult} used as a lightweight placeholder when a segment produces no hits.
+ public static final PerLeafResult EMPTY_RESULT = new PerLeafResult(
+ MATCH_NO_BIT_SET,
+ 0,
+ TopDocsCollector.EMPTY_TOPDOCS,
+ SearchMode.EXACT_SEARCH
+ );
+
+ /**
+ * Indicates the search mode applied within a segment. Either exact or approximate nearest neighbor (ANN) search.
+ */
+ public enum SearchMode {
+ EXACT_SEARCH,
+ APPROXIMATE_SEARCH,
+ }
+
+ // Active filter bitset limiting document candidates in this leaf (may be empty).
@Nullable
- private final Bits filterBits;
+ private final BitSet filterBits;
+
+ // Cardinality of {@link #filterBits}, used for filtering optimizations.
+ private final int filterBitsCardinality;
+
+ // Top document results for this leaf segment.
@Setter
private TopDocs result;
- public PerLeafResult(final Bits filterBits, final TopDocs result) {
- this.filterBits = filterBits == null ? new Bits.MatchAllBits(0) : filterBits;
+ // Indicates whether this result was produced via exact or approximate search.
+ private final SearchMode searchMode;
+
+ /**
+ * Constructs a new {@code PerLeafResult}.
+ *
+ * @param filterBits the document filter bitset for this leaf, or {@code null} if none
+ * @param filterBitsCardinality the number of bits set in {@code filterBits}
+ * @param result the top document results for this leaf
+ * @param searchMode the search mode (exact or approximate) used
+ */
+ public PerLeafResult(final BitSet filterBits, final int filterBitsCardinality, final TopDocs result, final SearchMode searchMode) {
+ this.filterBits = filterBits == null ? MATCH_ALL_BIT_SET : filterBits;
+ this.filterBitsCardinality = filterBitsCardinality;
this.result = result;
+ this.searchMode = searchMode;
}
}
diff --git a/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java
index 223854e963..89f68fa4c1 100644
--- a/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java
+++ b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/MemoryOptimizedKNNWeight.java
@@ -5,6 +5,7 @@
package org.opensearch.knn.index.query.memoryoptsearch;
+import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
@@ -29,6 +30,8 @@
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.MemoryOptimizedSearchScoreConverter;
+import org.opensearch.lucene.OptimisticKnnCollectorManager;
+import org.opensearch.lucene.ReentrantKnnCollectorManager;
import java.io.IOException;
@@ -47,6 +50,8 @@ public class MemoryOptimizedKNNWeight extends KNNWeight {
private static final KnnSearchStrategy.Hnsw DEFAULT_HNSW_SEARCH_STRATEGY = new KnnSearchStrategy.Hnsw(60);
private final KnnCollectorManager knnCollectorManager;
+ @Setter
+ private ReentrantKnnCollectorManager optimistic2ndKnnCollectorManager;
public MemoryOptimizedKNNWeight(KNNQuery query, float boost, final Weight filterWeight, IndexSearcher searcher, Integer k) {
super(query, boost, filterWeight);
@@ -55,7 +60,7 @@ public MemoryOptimizedKNNWeight(KNNQuery query, float boost, final Weight filter
// ANN Search
if (query.getParentsFilter() == null) {
// Non-nested case
- this.knnCollectorManager = new TopKnnCollectorManager(k, searcher);
+ this.knnCollectorManager = new OptimisticKnnCollectorManager(k, new TopKnnCollectorManager(k, searcher));
} else {
// Nested case
this.knnCollectorManager = new DiversifyingNearestChildrenKnnCollectorManager(k, query.getParentsFilter(), searcher);
@@ -184,7 +189,10 @@ private TopDocs queryIndex(
}
// Create a collector + bitset
- final KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, DEFAULT_HNSW_SEARCH_STRATEGY, context);
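+ // Phase 2 of the optimistic search installs a reentrant (seeded) collector manager; otherwise use the phase-1 manager.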
+ final KnnCollectorManager collectorManager = optimistic2ndKnnCollectorManager != null
+ ? optimistic2ndKnnCollectorManager
+ : knnCollectorManager;
+ final KnnCollector knnCollector = collectorManager.newCollector(visitedLimit, DEFAULT_HNSW_SEARCH_STRATEGY, context);
final AcceptDocs acceptDocs = getAcceptedDocs(reader, cardinality, filterIdsBitSet);
// Start searching index
@@ -217,22 +225,21 @@ private AcceptDocs getAcceptedDocs(SegmentReader reader, int cardinality, BitSet
} else {
acceptDocs = new AcceptDocs() {
@Override
- public Bits bits() throws IOException {
+ public Bits bits() {
return filterIdsBitSet;
}
@Override
- public DocIdSetIterator iterator() throws IOException {
+ public DocIdSetIterator iterator() {
return new BitSetIterator(filterIdsBitSet, cardinality);
}
@Override
- public int cost() throws IOException {
+ public int cost() {
return cardinality;
}
};
}
return acceptDocs;
}
-
}
diff --git a/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/optimistic/OptimisticSearchStrategyUtils.java b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/optimistic/OptimisticSearchStrategyUtils.java
new file mode 100644
index 0000000000..d0533897aa
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/query/memoryoptsearch/optimistic/OptimisticSearchStrategyUtils.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.query.memoryoptsearch.optimistic;
+
+import lombok.experimental.UtilityClass;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.util.hnsw.FloatHeap;
+import org.opensearch.knn.index.query.PerLeafResult;
+
+import java.util.List;
+
+/**
+ * Utility class providing helper methods for the optimistic search strategy used in KNN search.
+ *
+ *
+ * The optimistic search strategy executes a two-phase KNN search across multiple index segments:
+ *
+ * - Phase 1 – Shallow search: Runs an approximate KNN search independently on each
+ *   segment with an adjusted {@code k} value based on segment size. The results are merged
+ *   into a single candidate list across all segments.
+ * - Phase 2 – Deep search: Selects only the segments whose minimum score is greater than
+ *   or equal to the {@code k}-th largest score from the merged results of Phase 1, and
+ *   re-runs a refined KNN search using the collected candidates as seeds.
+ *
+ * This class provides utility functions that assist in merging results, computing score thresholds,
+ * and managing per-segment results between the two phases.
+ *
+ * All methods are static and stateless.
+ */
+@UtilityClass
+public class OptimisticSearchStrategyUtils {
+ /**
+ * Returns the k-th largest score across a collection of per-leaf search results,
+ * as if all scores were merged and globally sorted in descending order.
+ *
+ * This utility is typically used to determine the global score threshold
+ * corresponding to the top-k results when combining partial {@code TopDocs}
+ * from multiple segments or shards.
+ *
+ * The method does not perform a full global sort of all scores; it only identifies
+ * the score value that would occupy the k-th position in the merged ranking.
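+ * For example, merging per-leaf scores [0.9, 0.7] and [0.8, 0.6] with {@code k = 2} yields 0.8.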
+ *
+ * @param results a list of {@link PerLeafResult} objects, each containing scores
+ * collected from an individual segment or shard
+ * @param k the rank (1-based) of the desired score, e.g., {@code k = 10}
+ * returns the 10th highest score overall
+ * @param totalResults the total number of results across all {@code results};
+ * used for boundary checks or optimizations
+ * @return the score value that would appear at position {@code k} if all scores
+ * were globally sorted in descending order
+ * @throws IllegalArgumentException if {@code k} is less than 1 or greater than {@code totalResults}
+ */
+ public static float findKthLargestScore(final List<PerLeafResult> results, final int k, final int totalResults) {
+ if (totalResults <= 0) {
+ throw new IllegalArgumentException("Total results must be greater than zero, got=" + totalResults);
+ }
+ if (k <= 0) {
+ throw new IllegalArgumentException("K must be greater than zero, got=" + k);
+ }
+ if (k > totalResults) {
+ throw new IllegalArgumentException("K must be less than total results, got=" + k + ", totalResults=" + totalResults);
+ }
+
+ // At this point k == totalResults (k > totalResults was rejected above), so the k-th largest is the minimum score
+ if (totalResults <= k) {
+ float min = Float.MAX_VALUE;
+ for (final PerLeafResult result : results) {
+ for (final ScoreDoc scoreDoc : result.getResult().scoreDocs) {
+ if (scoreDoc.score < min) {
+ min = scoreDoc.score;
+ }
+ }
+ }
+ return min;
+ }
+
+ // Use a min-heap to track the top-k largest values.
+ // Since each PerLeafResult is already sorted in descending order by score, we push larger values first,
+ // allowing the heap to fill quickly and skip most of the remaining elements.
+ // This keeps the practical cost close to O(N), since O(log K) heap updates become infrequent once the heap saturates.
+ final FloatHeap floatHeap = new FloatHeap(k);
+ final int[] indices = new int[results.size()];
+ // Each segment contributes at most totalResults scores, so the upper bound on iterations (maxI)
+ // is totalResults * #segments. The limit guards against an infinite loop.
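+ // Round-robin across segments: the first pass offers each segment's best score, so the heap saturates early.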
+ for (int i = 0, visited = 0, maxI = totalResults * results.size(); visited < totalResults && i < maxI; ++i) {
+ final int resultIndex = i % indices.length;
+ final int scoreIndex = indices[resultIndex];
+ final ScoreDoc[] scoreDocs = results.get(resultIndex).getResult().scoreDocs;
+ if (scoreIndex < scoreDocs.length) {
+ floatHeap.offer(scoreDocs[scoreIndex].score);
+ ++visited;
+ indices[resultIndex] = scoreIndex + 1;
+ }
+ }
+
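+ // The heap retains the k largest scores seen; its minimum element is the k-th largest overall.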
+ return floatHeap.peek();
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java
index d8a0951fef..bd80f1d81a 100644
--- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java
+++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java
@@ -20,6 +20,7 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
@@ -30,6 +31,9 @@
import org.opensearch.knn.index.query.PerLeafResult;
import org.opensearch.knn.index.query.ResultUtil;
import org.opensearch.knn.index.query.common.QueryUtils;
+import org.opensearch.knn.index.query.memoryoptsearch.MemoryOptimizedKNNWeight;
+import org.opensearch.knn.index.query.memoryoptsearch.optimistic.OptimisticSearchStrategyUtils;
+import org.opensearch.lucene.ReentrantKnnCollectorManager;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.profile.KNNProfileUtil;
import org.opensearch.knn.profile.LongMetric;
@@ -41,12 +45,17 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
+import static org.opensearch.knn.profile.StopWatchUtils.startStopWatch;
+import static org.opensearch.knn.profile.StopWatchUtils.stopStopWatchAndLog;
+
/**
* {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level.
* {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives
@@ -58,6 +67,17 @@
@Getter
@RequiredArgsConstructor
public class NativeEngineKnnVectorQuery extends Query {
+ /**
+ * A special flag used for testing purposes that forces execution of the second (deep-dive) search
+ * in optimistic search mode, regardless of the results returned by the first approximate search.
+ *
+ * This flag should never be enabled in production; it is intended for testing and debugging only.
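+ * Enabled via the {@code mem_opt_srch.force_reenter} system property (set on the integ test clusters in build.gradle).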
+ */
+ private static final boolean FORCE_REENTER_TESTING;
+
+ static {
+ FORCE_REENTER_TESTING = Boolean.parseBoolean(System.getProperty("mem_opt_srch.force_reenter", "false"));
+ }
private final KNNQuery knnQuery;
private final QueryUtils queryUtils;
@@ -214,6 +234,7 @@ private PerLeafResult retrieveLeafResult(
perLeafResult.getFilterBits()
);
+ // Build exact search context
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.matchedDocsIterator(allSiblings)
.numberOfMatchedDocs(allSiblings.cost())
@@ -226,8 +247,17 @@ private PerLeafResult retrieveLeafResult(
.byteQueryVector(knnQuery.getByteQueryVector())
.isMemoryOptimizedSearchEnabled(knnQuery.isMemoryOptimizedSearch())
.build();
+
+ // Run exact search
TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
- return new PerLeafResult(perLeafResult.getFilterBits(), rescoreResult);
+
+ // Pack it as a result and return
+ return new PerLeafResult(
+ perLeafResult.getFilterBits(),
+ perLeafResult.getFilterBitsCardinality(),
+ rescoreResult,
+ PerLeafResult.SearchMode.EXACT_SEARCH
+ );
}
private List<PerLeafResult> doSearch(
@@ -236,11 +266,116 @@ private List<PerLeafResult> doSearch(
KNNWeight knnWeight,
int k
) throws IOException {
+ // Collect search tasks
List<Callable<PerLeafResult>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k));
}
- return indexSearcher.getTaskExecutor().invokeAll(tasks);
+
+ // Execute search tasks
+ final List<PerLeafResult> perLeafResults = indexSearcher.getTaskExecutor().invokeAll(tasks);
+
+ // For memory optimized search, it should kick off 2nd search if optimistic
+ if (knnQuery.isMemoryOptimizedSearch() && perLeafResults.size() > 1) {
+ log.debug(
+ "Running second deep dive search in optimistic while memory optimized search is enabled. perLeafResults.size()={}",
+ perLeafResults.size()
+ );
+ final StopWatch stopWatch = startStopWatch(log);
+ run2ndOptimisticSearch(perLeafResults, knnWeight, leafReaderContexts, k, indexSearcher);
+ stopStopWatchAndLog(log, stopWatch, "2ndOptimisticSearch", knnQuery.getShardId(), "All Shards", knnQuery.getField());
+ }
+
+ return perLeafResults;
+ }
+
+ private void run2ndOptimisticSearch(
+ final List<PerLeafResult> perLeafResults,
+ final KNNWeight knnWeight,
+ final List<LeafReaderContext> leafReaderContexts,
+ final int k,
+ final IndexSearcher indexSearcher
+ ) throws IOException {
+ if ((knnWeight instanceof MemoryOptimizedKNNWeight) == false) {
+ log.error(
+ "Memory optimized search was enabled, but got ["
+ + (knnWeight == null ? "null" : knnWeight.getClass().getSimpleName())
+ + "], expected="
+ + MemoryOptimizedKNNWeight.class.getSimpleName()
+ );
+ return;
+ }
+
+ assert (perLeafResults.size() == leafReaderContexts.size());
+
+ // Downcast to access the memory optimized weight's collector manager
+ final MemoryOptimizedKNNWeight memoryOptKNNWeight = (MemoryOptimizedKNNWeight) knnWeight;
+
+ // How many results have we collected?
+ int totalResults = 0;
+ for (PerLeafResult perLeafResult : perLeafResults) {
+ totalResults += perLeafResult.getResult().scoreDocs.length;
+ }
+
+ // If we got empty results, then return immediately
+ if (totalResults == 0) {
+ return;
+ }
+
+ // Determine the score bar for the 2nd deep dive: the k-th largest score across the merged phase-1 results.
+ final float minTopKScore = OptimisticSearchStrategyUtils.findKthLargestScore(perLeafResults, knnQuery.getK(), totalResults);
+
+ // Select candidate segments for the 2nd search: pick every segment whose returned scores all sit at or above the k-th
+ // largest value in the merged results.
+ final List<Callable<TopDocs>> secondDeepDiveTasks = new ArrayList<>();
+ final List<Integer> contextIndices = new ArrayList<>();
+ final Map<Integer, TopDocs> segmentOrdToResults = new HashMap<>();
+
+ for (int i = 0; i < leafReaderContexts.size(); ++i) {
+ final LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
+ final PerLeafResult perLeafResult = perLeafResults.get(i);
+ final TopDocs perLeaf = perLeafResults.get(i).getResult();
+ if (perLeaf.scoreDocs.length > 0 && perLeafResult.getSearchMode() == PerLeafResult.SearchMode.APPROXIMATE_SEARCH) {
+ if (FORCE_REENTER_TESTING || perLeaf.scoreDocs[perLeaf.scoreDocs.length - 1].score >= minTopKScore) {
+ // For the target segment, save top results. Which will be used as seeds.
+ segmentOrdToResults.put(leafReaderContext.ord, perLeaf);
+
+ // All this leaf's hits are at or above the global topK min score; explore it further
+ secondDeepDiveTasks.add(
+ () -> knnWeight.approximateSearch(
+ leafReaderContext,
+ perLeafResult.getFilterBits(),
+ perLeafResult.getFilterBitsCardinality(),
+ knnQuery.getK()
+ )
+ );
+ contextIndices.add(i);
+ }
+ }
+ }
+
+ // Kick off 2nd search tasks
+ if (secondDeepDiveTasks.isEmpty() == false) {
+ final ReentrantKnnCollectorManager knnCollectorManagerPhase2 = new ReentrantKnnCollectorManager(
+ new TopKnnCollectorManager(k, indexSearcher),
+ segmentOrdToResults,
+ knnQuery.getQueryVector(),
+ knnQuery.getField()
+ );
+
+ // Make weight use reentrant collector manager
+ memoryOptKNNWeight.setOptimistic2ndKnnCollectorManager(knnCollectorManagerPhase2);
+
+ final List<TopDocs> deepDiveTopDocs = indexSearcher.getTaskExecutor().invokeAll(secondDeepDiveTasks);
+
+ // Override results for target context
+ for (int i = 0; i < deepDiveTopDocs.size(); ++i) {
+ // Override with the new results
+ final TopDocs resultsFrom2ndDeepDive = deepDiveTopDocs.get(i);
+ final PerLeafResult perLeafResult = perLeafResults.get(contextIndices.get(i));
+ perLeafResult.setResult(resultsFrom2ndDeepDive);
+ }
+ }
}
private List<PerLeafResult> doRescore(
@@ -287,7 +422,12 @@ private List<PerLeafResult> doRescore(
.parentsFilter(knnQuery.getParentsFilter())
.build();
TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
- return new PerLeafResult(perLeafeResult.getFilterBits(), rescoreResult);
+ return new PerLeafResult(
+ perLeafeResult.getFilterBits(),
+ perLeafeResult.getFilterBitsCardinality(),
+ rescoreResult,
+ PerLeafResult.SearchMode.EXACT_SEARCH
+ );
});
}
return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
@@ -297,7 +437,6 @@ private PerLeafResult searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, i
final PerLeafResult perLeafResult = queryWeight.searchLeaf(ctx, k);
final Bits liveDocs = ctx.reader().getLiveDocs();
if (liveDocs != null) {
-
List<ScoreDoc> list = new ArrayList<>();
for (ScoreDoc scoreDoc : perLeafResult.getResult().scoreDocs) {
if (liveDocs.get(scoreDoc.doc)) {
@@ -306,7 +445,7 @@ private PerLeafResult searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, i
}
ScoreDoc[] filteredScoreDoc = list.toArray(new ScoreDoc[0]);
TotalHits totalHits = new TotalHits(filteredScoreDoc.length, TotalHits.Relation.EQUAL_TO);
- return new PerLeafResult(perLeafResult.getFilterBits(), new TopDocs(totalHits, filteredScoreDoc));
+ perLeafResult.setResult(new TopDocs(totalHits, filteredScoreDoc));
}
return perLeafResult;
}
diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java
index 782ca2b2b7..87b9e41852 100644
--- a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java
+++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIndexScalarQuantizedFlat.java
@@ -7,6 +7,7 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
+import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
@@ -25,6 +26,7 @@
* For example, the quantization type `QT_8BIT` indicates that each element in a vector is quantized into 8bits. Therefore, each element
* will occupy exactly one byte, a vector would occupy exactly the size of dimensions.
*/
+@Log4j2
@Getter
public class FaissIndexScalarQuantizedFlat extends FaissIndex {
private static EnumMap VECTOR_DATA_TYPES = new EnumMap<>(
diff --git a/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java b/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java
index 7cc7bc0ef7..ca2a89cfd6 100644
--- a/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java
+++ b/src/main/java/org/opensearch/knn/profile/ProfileDefaultKNNWeight.java
@@ -59,7 +59,7 @@ protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOExc
}
@Override
- protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k)
+ public TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k)
throws IOException {
return (TopDocs) KNNProfileUtil.profileBreakdown(
profile,
diff --git a/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java b/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java
index e532322f5d..05d6126938 100644
--- a/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java
+++ b/src/main/java/org/opensearch/knn/profile/ProfileMemoryOptKNNWeight.java
@@ -62,7 +62,7 @@ protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOExc
}
@Override
- protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k)
+ public TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k)
throws IOException {
return (TopDocs) KNNProfileUtil.profileBreakdown(
profile,
diff --git a/src/main/java/org/opensearch/knn/profile/StopWatchUtils.java b/src/main/java/org/opensearch/knn/profile/StopWatchUtils.java
new file mode 100644
index 0000000000..08531d4f35
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/profile/StopWatchUtils.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.profile;
+
+import lombok.experimental.UtilityClass;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.common.Nullable;
+import org.opensearch.common.StopWatch;
+
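+/**
+ * Helpers for debug-level latency logging with {@link StopWatch}. {@code startStopWatch} returns
+ * {@code null} unless debug logging is enabled, so callers pay no timing overhead otherwise.
+ */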
+@UtilityClass
+public final class StopWatchUtils {
+ public static StopWatch startStopWatch(final Logger log) {
+ if (log.isDebugEnabled()) {
+ return new StopWatch().start();
+ }
+ return null;
+ }
+
+ public static void stopStopWatchAndLog(
+ final Logger log,
+ @Nullable final StopWatch stopWatch,
+ final String prefixMessage,
+ final int shardId,
+ final String segmentName,
+ final String field
+ ) {
+
+ if (stopWatch != null && log.isDebugEnabled()) {
+ stopWatch.stop();
+ final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] ";
+ log.debug(logMessage, shardId, segmentName, field, stopWatch.totalTime().nanos());
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/lucene/OptimisticKnnCollectorManager.java b/src/main/java/org/opensearch/lucene/OptimisticKnnCollectorManager.java
new file mode 100644
index 0000000000..c2b545ad60
--- /dev/null
+++ b/src/main/java/org/opensearch/lucene/OptimisticKnnCollectorManager.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.lucene;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.knn.KnnCollectorManager;
+import org.apache.lucene.search.knn.KnnSearchStrategy;
+
+import java.io.IOException;
+
+/**
+ * Collector manager responsible for constructing the appropriate KNN collector
+ * depending on whether the optimistic search strategy is enabled.
+ *
+ * This class wraps an underlying {@code KnnCollectorManager} delegate and
+ * decides at runtime which collector implementation to use:
+ *
+ *
+ * - If the collector manager is configured to use optimistic search, it creates
+ * an optimistic collector, which performs a two-phase KNN search
+ * to optimize performance by reducing redundant segment searches.
+ * - Otherwise, it falls back to the standard {@link KnnCollectorManager}
+ * provided by the delegate, preserving default Lucene KNN search behavior.
+ *
+ *
+ * The optimistic search strategy operates in two phases:
+ *
+ * - Phase 1 – Executes KNN searches independently per segment
+ * with adjusted {@code k} values based on segment size and merges the results.
+ * - Phase 2 – Deep search: Re-runs searches only on segments that have
+ * promising results (based on a global score threshold) to refine recall efficiently.
+ *
+ *
+ * Example usage:
+ *
+ * <pre>{@code
+ * KnnCollectorManager baseManager = new TopKnnCollectorManager(k, searcher);
+ * OptimisticKnnCollectorManager manager = new OptimisticKnnCollectorManager(k, baseManager);
+ * KnnCollector collector = manager.newCollector(visitedLimit, searchStrategy, leafContext);
+ * }</pre>
+ *
+ * Ported from ...
+ */
+public class OptimisticKnnCollectorManager implements KnnCollectorManager {
+ // Constant controlling the degree of additional result exploration done during
+ // pro-rata search of segments.
+ private static final int LAMBDA = 16;
+
+ private final int k;
+ private final KnnCollectorManager delegate;
+
+ public OptimisticKnnCollectorManager(int k, KnnCollectorManager delegate) {
+ this.k = k;
+ this.delegate = delegate;
+ }
+
+ @Override
+ public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
+ // The delegate supports optimistic collection
+ if (delegate.isOptimistic()) {
+ @SuppressWarnings("resource")
+ float leafProportion = context.reader().maxDoc() / (float) context.parent.reader().maxDoc();
+ int perLeafTopK = perLeafTopKCalculation(k, leafProportion);
+ // if we divided by zero above, leafProportion can be NaN and then this would be 0
+ assert perLeafTopK > 0;
+ return delegate.newOptimisticCollector(visitedLimit, searchStrategy, context, perLeafTopK);
+ }
+ // We don't support optimistic collection, so just do regular execution path
+ return delegate.newCollector(visitedLimit, searchStrategy, context);
+ }
+
+ /**
+ * Returns perLeafTopK, the expected number (K * leafProportion) of hits in a leaf with the given
+ * proportion of the entire index, plus {@code LAMBDA} standard deviations of a binomial
+ * distribution, so that with high probability this segment's contribution to the global top K
+ * hits is <= perLeafTopK.
+ */
+ private static int perLeafTopKCalculation(int k, float leafProportion) {
+ return (int) Math.max(1, k * leafProportion + LAMBDA * Math.sqrt(k * leafProportion * (1 - leafProportion)));
+ }
+}
diff --git a/src/main/java/org/opensearch/lucene/ReentrantKnnCollectorManager.java b/src/main/java/org/opensearch/lucene/ReentrantKnnCollectorManager.java
new file mode 100644
index 0000000000..87d94796ab
--- /dev/null
+++ b/src/main/java/org/opensearch/lucene/ReentrantKnnCollectorManager.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.lucene;
+
+import lombok.RequiredArgsConstructor;
+import lombok.extern.log4j.Log4j2;
+import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.search.knn.KnnCollectorManager;
+import org.apache.lucene.search.knn.KnnSearchStrategy;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * A {@link KnnCollectorManager} that enables re-entrant (multi-phase) KNN vector search
+ * by seeding the HNSW graph search with document IDs collected from a prior search phase.
+ *
+ * This implementation reuses the top-ranked results from the 1st-phase search
+ * as entry points for a 2nd-phase vector search. It converts previously collected
+ * {@link TopDocs} into corresponding vector entry points using
+ * {@link SeededTopDocsDISI} and {@link SeededMappedDISI}, enabling the internal searcher to start from these known points
+ * instead of beginning from random or default graph entry nodes.
+ *
+ * See ...
+ */
+@Log4j2
+@RequiredArgsConstructor
+public class ReentrantKnnCollectorManager implements KnnCollectorManager {
+
+ // The underlying (delegate) KNN collector manager used to create collectors.
+ private final KnnCollectorManager knnCollectorManager;
+
+ // Mapping from segment ordinal to previously collected {@link TopDocs}.
+ private final Map<Integer, TopDocs> segmentOrdToResults;
+
+ // Query vector used for scoring during vector similarity search.
+ private final float[] query;
+
+ // Name of the vector field being searched.
+ private final String field;
+
+ /**
+ * Creates a new {@link KnnCollector} for the given segment.
+ *
+ * If 1st-phase results are available for the segment, this collector
+ * will seed the vector search with those document IDs. The document IDs
+ * are mapped to vector indices using {@link SeededMappedDISI}, which enables
+ * the HNSW search to begin from those known entry points.
+ *
+ * If no prior results exist or no vector scorer is available, the method
+ * falls back to a delegate collector.
+ *
+ * @param visitLimit the maximum number of graph nodes that can be visited
+ * @param searchStrategy the search strategy to use (e.g., HNSW or brute-force)
+ * @param ctx the leaf reader context for the current segment
+ * @return a seeded {@link KnnCollector} that reuses prior phase entry points,
+ * or a delegate collector if no seeding is possible
+ * @throws IOException if an I/O error occurs during setup
+ */
+ @Override
+ public KnnCollector newCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) throws IOException {
+ // Get delegate collector for fallback or empty cases
+ final KnnCollector delegateCollector = knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx);
+ final TopDocs seedTopDocs = segmentOrdToResults.get(ctx.ord);
+
+ if (seedTopDocs == null || seedTopDocs.totalHits.value() == 0) {
+ log.warn("Seed top docs was empty, expected non-empty top results to be given.");
+ // Normally shouldn't happen — indicates missing or empty seed results
+ assert false;
+ return delegateCollector;
+ }
+
+ // Obtain the per-segment vector values
+ final LeafReader reader = ctx.reader();
+ final FloatVectorValues vectorValues = reader.getFloatVectorValues(field);
+ if (vectorValues == null) {
+ log.error("Acquired null {} for field [{}]", FloatVectorValues.class.getSimpleName(), field);
+ // Validates the field exists, otherwise throws informative exception
+ FloatVectorValues.checkField(reader, field);
+ return null;
+ }
+
+ // Create a vector scorer for the query vector
+ final VectorScorer scorer = vectorValues.scorer(query);
+
+ if (scorer == null) {
+ log.error("Acquired null {} for field [{}]", VectorScorer.class.getSimpleName(), field);
+ // Normally shouldn't happen
+ assert false;
+ return delegateCollector;
+ }
+
+ // Get DocIdSetIterator from scorer
+ DocIdSetIterator vectorIterator = scorer.iterator();
+
+ // Convert to an indexed iterator if possible (for sparse vectors)
+ // Note that we're extracting DISI from Lucene's flat vector.
+ if (vectorIterator instanceof IndexedDISI indexedDISI) {
+ vectorIterator = IndexedDISI.asDocIndexIterator(indexedDISI);
+ }
+
+ // Map seed document IDs to vector indices to use as HNSW entry points
+ if (vectorIterator instanceof KnnVectorValues.DocIndexIterator indexIterator) {
+ DocIdSetIterator seedDocs = new SeededMappedDISI(indexIterator, new SeededTopDocsDISI(seedTopDocs));
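+ // Lucene's Seeded strategy starts the HNSW search from these phase-1 hits instead of the default entry points.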
+ return knnCollectorManager.newCollector(
+ visitLimit,
+ new KnnSearchStrategy.Seeded(seedDocs, seedTopDocs.scoreDocs.length, searchStrategy),
+ ctx
+ );
+ }
+
+ log.error(
+ "`vectorIterator` was not one of [{}, {}] and was {}",
+ IndexedDISI.class.getSimpleName(),
+ KnnVectorValues.DocIndexIterator.class.getSimpleName(),
+ vectorIterator == null ? "null" : vectorIterator.getClass().getSimpleName()
+ );
+
+ // This should not occur; fallback to delegate to prevent infinite loops
+ assert false;
+ return delegateCollector;
+ }
+}
diff --git a/src/main/java/org/opensearch/lucene/SeededMappedDISI.java b/src/main/java/org/opensearch/lucene/SeededMappedDISI.java
new file mode 100644
index 0000000000..7060ec96ba
--- /dev/null
+++ b/src/main/java/org/opensearch/lucene/SeededMappedDISI.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.lucene;
+
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
+
+import java.io.IOException;
+
+/**
+ * A {@link DocIdSetIterator} that maps document IDs from a source iterator
+ * to their corresponding vector indices in a {@link KnnVectorValues.DocIndexIterator}.
+ *
+ * This class advances the {@code indexedDISI} (which provides access to vector indices)
+ * in sync with the {@code sourceDISI} (which provides document IDs). For each document ID
+ * emitted by the source iterator, it advances the index iterator to the same document ID
+ * and returns the associated vector index.
+ *
+ * Typical usage is when document-level matches (from a collector or filter)
+ * need to be mapped back to the vector index space for further vector-based operations.
+ */
+public class SeededMappedDISI extends DocIdSetIterator {
+
+ // Iterator over vector values that exposes both doc IDs and their corresponding indices.
+ private final KnnVectorValues.DocIndexIterator indexedDISI;
+
+ // Source iterator that provides the sequence of document IDs to be mapped.
+ private final DocIdSetIterator sourceDISI;
+
+ /**
+ * Constructs a {@code SeededMappedDISI} that synchronizes a source document iterator
+ * with a vector index iterator.
+ *
+ * @param indexedDISI the {@link KnnVectorValues.DocIndexIterator} used to map
+ * document IDs to vector indices
+ * @param sourceDISI the {@link DocIdSetIterator} providing the source document IDs
+ */
+ public SeededMappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) {
+ this.indexedDISI = indexedDISI;
+ this.sourceDISI = sourceDISI;
+ }
+
+ /**
+ * Advances the source iterator to the first document ID that is greater than or equal
+ * to the specified target, then advances the index iterator to the same document ID.
+ *
+ * The returned value is the vector index corresponding to that document.
+ *
+ * @param target the target document ID
+ * @return the corresponding vector index, or {@link #NO_MORE_DOCS} if the end is reached
+ * @throws IOException if an I/O error occurs
+ */
+ @Override
+ public int advance(int target) throws IOException {
+ int newTarget = sourceDISI.advance(target);
+ if (newTarget != NO_MORE_DOCS) {
+ indexedDISI.advance(newTarget);
+ }
+ return docID();
+ }
+
+ /**
+ * Returns an estimate of the cost (number of documents) of iterating.
+ *
+ * @return the cost estimate from the source iterator
+ */
+ @Override
+ public long cost() {
+ return sourceDISI.cost();
+ }
+
+ /**
+ * Returns the current vector index corresponding to the current document position.
+ *
+ * @return the current vector index, or {@link #NO_MORE_DOCS} if iteration has completed
+ */
+ @Override
+ public int docID() {
+ if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) {
+ return NO_MORE_DOCS;
+ }
+ return indexedDISI.index();
+ }
+
+ /**
+ * Advances to the next document in the source iterator and updates the index iterator
+ * to the same document. Returns the corresponding vector index.
+ *
+ * @return the next vector index, or {@link #NO_MORE_DOCS} if there are no more documents
+ * @throws IOException if an I/O error occurs
+ */
+ @Override
+ public int nextDoc() throws IOException {
+ int newTarget = sourceDISI.nextDoc();
+ if (newTarget != NO_MORE_DOCS) {
+ indexedDISI.advance(newTarget);
+ }
+ return docID();
+ }
+}
diff --git a/src/main/java/org/opensearch/lucene/SeededTopDocsDISI.java b/src/main/java/org/opensearch/lucene/SeededTopDocsDISI.java
new file mode 100644
index 0000000000..8a5bd0b92a
--- /dev/null
+++ b/src/main/java/org/opensearch/lucene/SeededTopDocsDISI.java
@@ -0,0 +1,102 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.lucene;
+
+import java.io.IOException;
+import java.util.Arrays;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TopDocs;
+
+/**
+ * A {@link DocIdSetIterator} implementation that iterates over the document IDs
+ * collected in a {@link TopDocs} object.
+ *
+ * This class extracts document IDs from the given {@link TopDocs}, sorts them
+ * in ascending order, and then provides sequential access to those IDs.
+ *
+ * It can be used to re-iterate over the documents returned by a collector,
+ * ensuring deterministic iteration order regardless of the original collection sequence.
+ */
+public class SeededTopDocsDISI extends DocIdSetIterator {
+
+ /** Sorted array of document IDs extracted from {@link TopDocs}. */
+ private final int[] sortedDocIds;
+
+ /** Current index in {@link #sortedDocIds}. Starts at -1 before iteration. */
+ private int idx = -1;
+
+ /**
+ * Constructs a {@code SeededTopDocsDISI} from the given {@link TopDocs}.
+ *
+ * The document IDs are extracted from {@link org.apache.lucene.search.ScoreDoc#doc}
+ * and sorted in ascending order. The collector's base offset, if any, is already removed.
+ *
+ * @param topDocs the {@link TopDocs} containing the collected document IDs
+ */
+ public SeededTopDocsDISI(final TopDocs topDocs) {
+ sortedDocIds = new int[topDocs.scoreDocs.length];
+ for (int i = 0; i < topDocs.scoreDocs.length; i++) {
+ // Doc IDs are already segment-relative; any collector doc base has been removed upstream
+ sortedDocIds[i] = topDocs.scoreDocs[i].doc;
+ }
+ Arrays.sort(sortedDocIds);
+ }
+
+ /**
+ * Advances to the first document whose ID is greater than or equal to the given target.
+ *
+ * This implementation delegates to {@link #slowAdvance(int)} for simplicity.
+ *
+ * @param target the target document ID
+ * @return the next matching document ID, or {@link #NO_MORE_DOCS} if none remain
+ * @throws IOException never thrown (declared for interface compatibility)
+ */
+ @Override
+ public int advance(int target) throws IOException {
+ return slowAdvance(target);
+ }
+
+ /**
+ * Returns an estimate of the number of documents this iterator will traverse.
+ *
+ * @return the number of document IDs available
+ */
+ @Override
+ public long cost() {
+ return sortedDocIds.length;
+ }
+
+ /**
+ * Returns the current document ID.
+ *
+ * @return the current doc ID, {@code -1} if not yet started,
+ * or {@link #NO_MORE_DOCS} if iteration is complete
+ */
+ @Override
+ public int docID() {
+ if (idx == -1) {
+ // Not advanced
+ return -1;
+ } else if (idx >= sortedDocIds.length) {
+ // Exhausted doc ids
+ return DocIdSetIterator.NO_MORE_DOCS;
+ } else {
+ return sortedDocIds[idx];
+ }
+ }
+
+ /**
+ * Advances to the next document ID in sorted order.
+ *
+ * @return the next document ID, or {@link #NO_MORE_DOCS} if there are no more
+ */
+ @Override
+ public int nextDoc() {
+ idx += 1;
+ return docID();
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
index 9a7cf6b9da..9fb16f800a 100644
--- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
@@ -742,15 +742,19 @@ private void do_testDoToQuery_whenMemoryOptimizedSearchIsEnabled(
final Query query = knnQueryBuilder.doToQuery(mockQueryShardContext);
// If memory optimized search is on then, use Lucene query
final KNNQuery knnQuery;
- if (doRescore) {
- assertTrue(query instanceof NativeEngineKnnVectorQuery);
+ final boolean memoryOptimizedEnabled = memoryOptimizedSearchEnabled && memoryOptimizedSearchEnabledInField;
+ if (memoryOptimizedEnabled) {
+ // Regardless of rescoring, once memory optimized search is enabled, it always uses NativeEngineKnnVectorQuery
knnQuery = ((NativeEngineKnnVectorQuery) query).getKnnQuery();
} else {
- assertFalse(query instanceof NativeEngineKnnVectorQuery);
- knnQuery = (KNNQuery) query;
+ // When memory optimized search is off, NativeEngineKnnVectorQuery is used only for rescoring.
+ if (doRescore) {
+ knnQuery = ((NativeEngineKnnVectorQuery) query).getKnnQuery();
+ } else {
+ knnQuery = (KNNQuery) query;
+ }
}
- final boolean memoryOptimizedEnabled = memoryOptimizedSearchEnabled && memoryOptimizedSearchEnabledInField;
if (memoryOptimizedEnabled) {
if (vectorDataType == VectorDataType.FLOAT) {
assertEquals(queryVector.length, knnQuery.getQueryVector().length);
diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java
index 90b1e101bd..0615a1ff73 100644
--- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java
@@ -29,7 +29,7 @@ public void testReduceToTopK() {
List