Skip to content

Commit 25aa792

Browse files
0ctopus13primegithub-actions[bot]
authored andcommitted
Enable optimistic search to memory optimized search. (#2933)
* Added MMapByteVectorValues for FP16 native scoring in LuceneOnFaiss. (#2904) Signed-off-by: Dooyong Kim <[email protected]> * Enable optimistic search for LuceneOnFaiss. Signed-off-by: Dooyong Kim <[email protected]> * Added debug logging, measure execution times of 2nd search. Signed-off-by: Dooyong Kim <[email protected]> --------- Signed-off-by: Dooyong Kim <[email protected]> (cherry picked from commit a95a7d0)
1 parent 79a87a0 commit 25aa792

24 files changed

+1801
-77
lines changed

build.gradle

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ def commonIntegTest(RestIntegTestTask task, project, integTestDependOnJniLib, op
482482
task.systemProperty 'cluster.debug', isDebuggingCluster
483483
// Set number of nodes system property to be used in tests
484484
task.systemProperty 'cluster.number_of_nodes', "${_numNodes}"
485+
485486
// There seems to be an issue when running multi node run or integ tasks with unicast_hosts
486487
// not being written, the waitForAllConditions ensures it's written
487488
task.getClusters().forEach { cluster ->
@@ -541,6 +542,7 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){
541542
debugPort += 1
542543
}
543544
}
545+
544546
cluster.systemProperty("java.library.path", "$rootDir/jni/build/release")
545547
final testSnapshotFolder = file("${buildDir}/testSnapshotFolder")
546548
testSnapshotFolder.mkdirs()
@@ -550,6 +552,8 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){
550552

551553
testClusters.integTest {
552554
commonIntegTestClusters(it, _numNodes)
555+
// Forcing optimistic search for testing
556+
systemProperty 'mem_opt_srch.force_reenter', 'true'
553557
}
554558

555559
testClusters.integTestRemoteIndexBuild {
@@ -562,6 +566,8 @@ testClusters.integTestRemoteIndexBuild {
562566
keystore 's3.client.default.access_key', "${System.getProperty("access_key")}"
563567
keystore 's3.client.default.secret_key', "${System.getProperty("secret_key")}"
564568
keystore 's3.client.default.session_token', "${System.getProperty("session_token")}"
569+
// Forcing optimistic search for testing
570+
systemProperty 'mem_opt_srch.force_reenter', 'true'
565571
}
566572

567573
task integTestRemote(type: RestIntegTestTask) {

src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ public static Query create(CreateQueryRequest createQueryRequest) {
117117
.build();
118118
}
119119

120-
if (createQueryRequest.getRescoreContext().isPresent()
120+
if (memoryOptimizedSearchEnabled
121+
|| createQueryRequest.getRescoreContext().isPresent()
121122
|| (ENGINES_SUPPORTING_NESTED_FIELDS.contains(createQueryRequest.getKnnEngine()) && expandNested)) {
122123
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.getInstance(), expandNested);
123124
}

src/main/java/org/opensearch/knn/index/query/KNNWeight.java

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.apache.lucene.util.BitSetIterator;
2525
import org.apache.lucene.util.Bits;
2626
import org.apache.lucene.util.FixedBitSet;
27-
import org.opensearch.common.Nullable;
2827
import org.opensearch.common.StopWatch;
2928
import org.opensearch.common.lucene.Lucene;
3029
import org.opensearch.knn.common.FieldInfoExtractor;
@@ -52,6 +51,9 @@
5251
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
5352
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
5453

54+
import static org.opensearch.knn.profile.StopWatchUtils.startStopWatch;
55+
import static org.opensearch.knn.profile.StopWatchUtils.stopStopWatchAndLog;
56+
5557
/**
5658
* {@link KNNWeight} serves as a template for implementing approximate nearest neighbor (ANN)
5759
* and radius search over a native index type, such as Faiss.
@@ -298,12 +300,13 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
298300
final SegmentReader reader = Lucene.segmentReader(context.reader());
299301
final String segmentName = reader.getSegmentName();
300302

301-
StopWatch stopWatch = startStopWatch();
303+
final StopWatch stopWatch = startStopWatch(log);
302304
final BitSet filterBitSet = getFilteredDocsBitSet(context);
303-
stopStopWatchAndLog(stopWatch, "FilterBitSet creation", segmentName);
305+
stopStopWatchAndLog(log, stopWatch, "FilterBitSet creation", knnQuery.getShardId(), segmentName, knnQuery.getField());
306+
307+
// Save its cardinality, as the cardinality calculation is expensive.
308+
final int filterCardinality = filterBitSet.cardinality();
304309

305-
final int maxDoc = context.reader().maxDoc();
306-
int filterCardinality = filterBitSet.cardinality();
307310
// We don't need to go to JNI layer if no documents are found which satisfy the filters
308311
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
309312
// place,
@@ -320,19 +323,19 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
320323
* This improves the recall.
321324
*/
322325
if (isFilteredExactSearchPreferred(filterCardinality)) {
323-
TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k);
324-
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
326+
final TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k);
327+
return new PerLeafResult(
328+
filterWeight == null ? null : filterBitSet,
329+
filterCardinality,
330+
result,
331+
PerLeafResult.SearchMode.EXACT_SEARCH
332+
);
325333
}
326334

327-
/*
328-
* If filters match all docs in this segment, then null should be passed as filterBitSet
329-
* so that it will not do a bitset look up in bottom search layer.
330-
*/
331-
final BitSet annFilter = (filterWeight != null && filterCardinality == maxDoc) ? null : filterBitSet;
335+
final StopWatch annStopWatch = startStopWatch(log);
336+
final TopDocs topDocs = approximateSearch(context, filterBitSet, filterCardinality, k);
337+
stopStopWatchAndLog(log, stopWatch, "ANN search", knnQuery.getShardId(), segmentName, knnQuery.getField());
332338

333-
StopWatch annStopWatch = startStopWatch();
334-
final TopDocs topDocs = approximateSearch(context, annFilter, filterCardinality, k);
335-
stopStopWatchAndLog(annStopWatch, "ANN search", segmentName);
336339
if (knnQuery.isExplain()) {
337340
knnExplanation.addLeafResult(context.id(), topDocs.scoreDocs.length);
338341
}
@@ -341,18 +344,21 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
341344
// results less than K, though we have more than k filtered docs
342345
if (isExactSearchRequire(context, filterCardinality, topDocs.scoreDocs.length)) {
343346
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, filterCardinality) : null;
344-
TopDocs result = doExactSearch(context, docs, filterCardinality, k);
345-
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
347+
final TopDocs result = doExactSearch(context, docs, filterCardinality, k);
348+
return new PerLeafResult(
349+
filterWeight == null ? null : filterBitSet,
350+
filterCardinality,
351+
result,
352+
PerLeafResult.SearchMode.EXACT_SEARCH
353+
);
346354
}
347-
return new PerLeafResult(filterWeight == null ? null : filterBitSet, topDocs);
348-
}
349355

350-
private void stopStopWatchAndLog(@Nullable final StopWatch stopWatch, final String prefixMessage, String segmentName) {
351-
if (log.isDebugEnabled() && stopWatch != null) {
352-
stopWatch.stop();
353-
final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] ";
354-
log.debug(logMessage, knnQuery.getShardId(), segmentName, knnQuery.getField(), stopWatch.totalTime().nanos());
355-
}
356+
return new PerLeafResult(
357+
filterWeight == null ? null : filterBitSet,
358+
filterCardinality,
359+
topDocs,
360+
PerLeafResult.SearchMode.APPROXIMATE_SEARCH
361+
);
356362
}
357363

358364
protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
@@ -413,9 +419,33 @@ private TopDocs doExactSearch(
413419
return exactSearch(context, exactSearcherContextBuilder.build());
414420
}
415421

416-
protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k)
417-
throws IOException {
422+
/**
423+
* Performs an approximate nearest neighbor (ANN) search on the provided index segment.
424+
* <p>
425+
* This method prepares all necessary query metadata before triggering the actual ANN search.
426+
* It extracts the {@code model_id} from field-level attributes if required, retrieves any
427+
* quantization or auxiliary metadata associated with the vector field, and applies quantization
428+
* to the query vector when applicable. After these preprocessing steps, it invokes
429+
* {@code doANNSearch(LeafReaderContext, BitSet, int, int)} to execute the approximate search
430+
* and obtain the top results.
431+
*
432+
* @param context the {@link LeafReaderContext} representing the current index segment
433+
* @param filterIdsBitSet an optional {@link BitSet} indicating document IDs to include in the search;
434+
* may be {@code null} if no filtering is required
435+
* @param filterCardinality the number of documents included in {@code filterIdsBitSet};
436+
* used to optimize search filtering
437+
* @param k the number of nearest neighbors to retrieve
438+
* @return a {@link TopDocs} object containing the top {@code k} approximate search results
439+
* @throws IOException if an error occurs while reading index data or accessing vector fields
440+
*/
441+
public TopDocs approximateSearch(
442+
final LeafReaderContext context,
443+
final BitSet filterIdsBitSet,
444+
final int filterCardinality,
445+
final int k
446+
) throws IOException {
418447
final SegmentReader reader = Lucene.segmentReader(context.reader());
448+
419449
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
420450

421451
if (fieldInfo == null) {
@@ -465,6 +495,11 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe
465495
// TODO: Change type of vector once more quantization methods are supported
466496
byte[] quantizedVector = maybeQuantizeVector(segmentLevelQuantizationInfo);
467497
float[] transformedVector = maybeTransformVector(segmentLevelQuantizationInfo, spaceType);
498+
/*
499+
* If filters match all docs in this segment, then null should be passed as filterBitSet
500+
* so that it will not do a bitset look up in bottom search layer.
501+
*/
502+
final BitSet annFilter = filterCardinality == context.reader().maxDoc() ? null : filterIdsBitSet;
468503

469504
KNNCounter.GRAPH_QUERY_REQUESTS.increment();
470505
final TopDocs results = doANNSearch(
@@ -477,8 +512,8 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe
477512
quantizedVector,
478513
transformedVector,
479514
modelId,
480-
filterIdsBitSet,
481-
cardinality,
515+
annFilter,
516+
filterCardinality,
482517
k
483518
);
484519

@@ -553,10 +588,10 @@ protected void addExplainIfRequired(final TopDocs results, final KNNEngine knnEn
553588
*/
554589
public TopDocs exactSearch(final LeafReaderContext leafReaderContext, final ExactSearcher.ExactSearcherContext exactSearcherContext)
555590
throws IOException {
556-
StopWatch stopWatch = startStopWatch();
591+
final StopWatch stopWatch = startStopWatch(log);
557592
TopDocs exactSearchResults = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
558593
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
559-
stopStopWatchAndLog(stopWatch, "Exact search", reader.getSegmentName());
594+
stopStopWatchAndLog(log, stopWatch, "Exact search", knnQuery.getShardId(), reader.getSegmentName(), knnQuery.getField());
560595
return exactSearchResults;
561596
}
562597

@@ -673,13 +708,6 @@ private boolean isMissingNativeEngineFiles(LeafReaderContext context) {
673708
return engineFiles.isEmpty();
674709
}
675710

676-
private StopWatch startStopWatch() {
677-
if (log.isDebugEnabled()) {
678-
return new StopWatch().start();
679-
}
680-
return null;
681-
}
682-
683711
protected int[] getParentIdsArray(final LeafReaderContext context) throws IOException {
684712
if (knnQuery.getParentsFilter() == null) {
685713
return null;

0 commit comments

Comments
 (0)