19 changes: 19 additions & 0 deletions build.gradle
@@ -289,10 +289,23 @@ publishing {

compileJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])

// MemorySegment is not available until JDK 22, so exclude this source when compiling for Java 21 and below and include it only for Java 22+.
def javaExt = project.extensions.getByType(JavaPluginExtension)
if (javaExt.sourceCompatibility <= JavaVersion.VERSION_21 || javaExt.targetCompatibility <= JavaVersion.VERSION_21) {
exclude("org/opensearch/knn/memoryoptsearch/MemorySegmentAddressExtractorJDK22.java")
}
}
compileTestJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
}
javadoc {
// Skip Javadoc generation here, as Javadoc complains that MemorySegment is a preview API on Java 21.
def javaExt = project.extensions.getByType(JavaPluginExtension)
if (javaExt.sourceCompatibility <= JavaVersion.VERSION_21 || javaExt.targetCompatibility <= JavaVersion.VERSION_21) {
exclude("org/opensearch/knn/memoryoptsearch/MemorySegmentAddressExtractorJDK22.java")
}
}
compileTestFixturesJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
}
@@ -482,6 +495,7 @@ def commonIntegTest(RestIntegTestTask task, project, integTestDependOnJniLib, op
task.systemProperty 'cluster.debug', isDebuggingCluster
// Set number of nodes system property to be used in tests
task.systemProperty 'cluster.number_of_nodes', "${_numNodes}"

// There seems to be an issue with multi-node runs or integ tasks where unicast_hosts is not written;
// waitForAllConditions ensures it's written.
task.getClusters().forEach { cluster ->
@@ -541,6 +555,7 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){
debugPort += 1
}
}

cluster.systemProperty("java.library.path", "$rootDir/jni/build/release")
final testSnapshotFolder = file("${buildDir}/testSnapshotFolder")
testSnapshotFolder.mkdirs()
@@ -550,6 +565,8 @@ def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){

testClusters.integTest {
commonIntegTestClusters(it, _numNodes)
// Forcing optimistic search for testing
systemProperty 'mem_opt_srch.force_reenter', 'true'
Comment on lines +568 to +569
Collaborator: Do we need this?

Collaborator: If you want to run search for memory-optimized cases, then let's create another gradle task and also a new CI that runs that task.

Collaborator (Author): We need this to force the second search to run in the optimistic path. Since the second search kicks off only if there's a segment whose min score is greater than the min score in the merged results, it was tricky for me to construct data that triggers it.

}
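
The author's explanation above can be made concrete with a small sketch. This is purely illustrative and not code from the PR: the method name and parameters are assumptions, and only the mem_opt_srch.force_reenter property name comes from the diff.

    // Hypothetical sketch of the optimistic re-entry decision described in the review thread.
    // Only the system property name is taken from the PR; everything else is assumed.
    static boolean shouldReenterSegment(float segmentMinScore, float mergedTopKMinScore) {
        // Test hook added in this PR: force the second search unconditionally.
        if (Boolean.getBoolean("mem_opt_srch.force_reenter")) {
            return true;
        }
        // Re-enter a segment only if its min score exceeds the min score of the merged
        // top-k results, i.e. the segment may still contribute better hits.
        return segmentMinScore > mergedTopKMinScore;
    }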

testClusters.integTestRemoteIndexBuild {
@@ -562,6 +579,8 @@ testClusters.integTestRemoteIndexBuild {
keystore 's3.client.default.access_key', "${System.getProperty("access_key")}"
keystore 's3.client.default.secret_key', "${System.getProperty("secret_key")}"
keystore 's3.client.default.session_token', "${System.getProperty("session_token")}"
// Forcing optimistic search for testing
systemProperty 'mem_opt_srch.force_reenter', 'true'
}

task integTestRemote(type: RestIntegTestTask) {
@@ -117,7 +117,8 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.build();
}

if (createQueryRequest.getRescoreContext().isPresent()
if (memoryOptimizedSearchEnabled
|| createQueryRequest.getRescoreContext().isPresent()
|| (ENGINES_SUPPORTING_NESTED_FIELDS.contains(createQueryRequest.getKnnEngine()) && expandNested)) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.getInstance(), expandNested);
}
104 changes: 66 additions & 38 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -24,7 +24,6 @@
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.Nullable;
import org.opensearch.common.StopWatch;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
@@ -52,6 +51,9 @@
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

import static org.opensearch.knn.profile.StopWatchUtils.startStopWatch;
import static org.opensearch.knn.profile.StopWatchUtils.stopStopWatchAndLog;

/**
* {@link KNNWeight} serves as a template for implementing approximate nearest neighbor (ANN)
* and radius search over a native index type, such as Faiss.
@@ -298,12 +300,13 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
final SegmentReader reader = Lucene.segmentReader(context.reader());
final String segmentName = reader.getSegmentName();

StopWatch stopWatch = startStopWatch();
final StopWatch stopWatch = startStopWatch(log);
final BitSet filterBitSet = getFilteredDocsBitSet(context);
stopStopWatchAndLog(stopWatch, "FilterBitSet creation", segmentName);
stopStopWatchAndLog(log, stopWatch, "FilterBitSet creation", knnQuery.getShardId(), segmentName, knnQuery.getField());

// Save its cardinality, as the cardinality calculation is expensive.
final int filterCardinality = filterBitSet.cardinality();

final int maxDoc = context.reader().maxDoc();
int filterCardinality = filterBitSet.cardinality();
// We don't need to go to the JNI layer if no documents satisfy the filters.
// We should take a deeper look at where this condition belongs; for now, this seems like a good place.
@@ -320,19 +323,19 @@
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(filterCardinality)) {
TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
final TopDocs result = doExactSearch(context, new BitSetIterator(filterBitSet, filterCardinality), filterCardinality, k);
return new PerLeafResult(
filterWeight == null ? null : filterBitSet,
filterCardinality,
result,
PerLeafResult.SearchMode.EXACT_SEARCH
);
}

Collaborator (Author): This logic has been moved to approximateSearch.

/*
* If filters match all docs in this segment, then null should be passed as filterBitSet
* so that it will not do a bitset look up in bottom search layer.
*/
final BitSet annFilter = (filterWeight != null && filterCardinality == maxDoc) ? null : filterBitSet;
final StopWatch annStopWatch = startStopWatch(log);
final TopDocs topDocs = approximateSearch(context, filterBitSet, filterCardinality, k);
stopStopWatchAndLog(log, annStopWatch, "ANN search", knnQuery.getShardId(), segmentName, knnQuery.getField());

StopWatch annStopWatch = startStopWatch();
final TopDocs topDocs = approximateSearch(context, annFilter, filterCardinality, k);
stopStopWatchAndLog(annStopWatch, "ANN search", segmentName);
if (knnQuery.isExplain()) {
knnExplanation.addLeafResult(context.id(), topDocs.scoreDocs.length);
}
@@ -341,18 +344,21 @@
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, filterCardinality, topDocs.scoreDocs.length)) {
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, filterCardinality) : null;
TopDocs result = doExactSearch(context, docs, filterCardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
final TopDocs result = doExactSearch(context, docs, filterCardinality, k);
return new PerLeafResult(
filterWeight == null ? null : filterBitSet,
filterCardinality,
result,
PerLeafResult.SearchMode.EXACT_SEARCH
);
}
return new PerLeafResult(filterWeight == null ? null : filterBitSet, topDocs);
}

private void stopStopWatchAndLog(@Nullable final StopWatch stopWatch, final String prefixMessage, String segmentName) {
if (log.isDebugEnabled() && stopWatch != null) {
stopWatch.stop();
final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] ";
log.debug(logMessage, knnQuery.getShardId(), segmentName, knnQuery.getField(), stopWatch.totalTime().nanos());
}
return new PerLeafResult(
filterWeight == null ? null : filterBitSet,
filterCardinality,
topDocs,
PerLeafResult.SearchMode.APPROXIMATE_SEARCH
);
}
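
The constructor calls above show that PerLeafResult now records the filter cardinality and how the leaf was searched, in addition to the bitset and the top docs. A sketch of the assumed shape, for orientation only — the real class is defined elsewhere in the PR and these field names are guesses:

    // Inferred sketch of PerLeafResult; only the constructor arity and the SearchMode
    // constants are visible in this diff, the rest is assumption.
    public class PerLeafResult {
        public enum SearchMode {
            APPROXIMATE_SEARCH,
            EXACT_SEARCH
        }

        private final BitSet filterBitSet;    // null when no filter weight was supplied
        private final int filterCardinality;  // cached, since cardinality() is expensive
        private final TopDocs result;
        private final SearchMode searchMode;

        public PerLeafResult(BitSet filterBitSet, int filterCardinality, TopDocs result, SearchMode searchMode) {
            this.filterBitSet = filterBitSet;
            this.filterCardinality = filterCardinality;
            this.result = result;
            this.searchMode = searchMode;
        }
    }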

protected BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
@@ -413,9 +419,33 @@ private TopDocs doExactSearch(
return exactSearch(context, exactSearcherContextBuilder.build());
}

protected TopDocs approximateSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k)
throws IOException {
/**
* Performs an approximate nearest neighbor (ANN) search on the provided index segment.
* <p>
* This method prepares all necessary query metadata before triggering the actual ANN search.
* It extracts the {@code model_id} from field-level attributes if required, retrieves any
* quantization or auxiliary metadata associated with the vector field, and applies quantization
* to the query vector when applicable. After these preprocessing steps, it invokes
* {@code doANNSearch} to execute the approximate search
* and obtain the top results.
*
* @param context the {@link LeafReaderContext} representing the current index segment
* @param filterIdsBitSet an optional {@link BitSet} indicating document IDs to include in the search;
* may be {@code null} if no filtering is required
* @param filterCardinality the number of documents included in {@code filterIdsBitSet};
* used to optimize search filtering
* @param k the number of nearest neighbors to retrieve
* @return a {@link TopDocs} object containing the top {@code k} approximate search results
* @throws IOException if an error occurs while reading index data or accessing vector fields
*/
public TopDocs approximateSearch(
Collaborator: Why are we making this function public?

Collaborator (Author): We need this particular function in the optimistic second search. Otherwise, if we used searchLeaf, we would end up building the filter bitset twice.

final LeafReaderContext context,
final BitSet filterIdsBitSet,
final int filterCardinality,
final int k
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());

FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());

if (fieldInfo == null) {
@@ -465,6 +495,11 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe
// TODO: Change type of vector once more quantization methods are supported
byte[] quantizedVector = maybeQuantizeVector(segmentLevelQuantizationInfo);
float[] transformedVector = maybeTransformVector(segmentLevelQuantizationInfo, spaceType);
/*
* If filters match all docs in this segment, then null should be passed as filterBitSet
* so that it will not do a bitset look up in bottom search layer.
*/
final BitSet annFilter = filterCardinality == context.reader().maxDoc() ? null : filterIdsBitSet;

KNNCounter.GRAPH_QUERY_REQUESTS.increment();
final TopDocs results = doANNSearch(
@@ -477,8 +512,8 @@ protected TopDocs approximateSearch(final LeafReaderContext context, final BitSe
quantizedVector,
transformedVector,
modelId,
filterIdsBitSet,
cardinality,
annFilter,
filterCardinality,
k
);
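
Per the review thread on the method's visibility, the optimistic second pass is expected to call this now-public approximateSearch directly, reusing the bitset and cardinality already computed by searchLeaf. A hypothetical caller, with all surrounding merge logic omitted and names assumed:

    // Hypothetical second-pass caller: reuses the filter bitset and cardinality from the
    // first pass instead of rebuilding them through another searchLeaf call.
    TopDocs optimisticSecondPass(KNNWeight weight, LeafReaderContext context,
                                 BitSet filterBitSet, int filterCardinality, int k) throws IOException {
        return weight.approximateSearch(context, filterBitSet, filterCardinality, k);
    }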

@@ -553,10 +588,10 @@ protected void addExplainIfRequired(final TopDocs results, final KNNEngine knnEn
*/
public TopDocs exactSearch(final LeafReaderContext leafReaderContext, final ExactSearcher.ExactSearcherContext exactSearcherContext)
throws IOException {
StopWatch stopWatch = startStopWatch();
final StopWatch stopWatch = startStopWatch(log);
TopDocs exactSearchResults = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
stopStopWatchAndLog(stopWatch, "Exact search", reader.getSegmentName());
stopStopWatchAndLog(log, stopWatch, "Exact search", knnQuery.getShardId(), reader.getSegmentName(), knnQuery.getField());
return exactSearchResults;
}

@@ -673,13 +708,6 @@ private boolean isMissingNativeEngineFiles(LeafReaderContext context) {
return engineFiles.isEmpty();
}

private StopWatch startStopWatch() {
if (log.isDebugEnabled()) {
return new StopWatch().start();
}
return null;
}
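
Together with the earlier removal of the private stopStopWatchAndLog, this helper is replaced by the static imports from org.opensearch.knn.profile.StopWatchUtils at the top of the file. A plausible sketch of those shared helpers, with logic mirroring the removed private methods; the signatures are inferred from the call sites, and the logger and shard-id types are assumptions:

    import org.apache.logging.log4j.Logger;
    import org.opensearch.common.StopWatch;

    // Inferred sketch of StopWatchUtils; behavior mirrors the removed private helpers.
    public final class StopWatchUtils {
        private StopWatchUtils() {}

        // Start a StopWatch only when debug logging is enabled; otherwise return null.
        public static StopWatch startStopWatch(Logger log) {
            return log.isDebugEnabled() ? new StopWatch().start() : null;
        }

        // Stop the watch (if any) and log elapsed nanos with shard/segment/field context.
        public static void stopStopWatchAndLog(Logger log, StopWatch stopWatch, String prefixMessage,
                                               int shardId, String segmentName, String fieldName) {
            if (log.isDebugEnabled() && stopWatch != null) {
                stopWatch.stop();
                log.debug(prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}]",
                          shardId, segmentName, fieldName, stopWatch.totalTime().nanos());
            }
        }
    }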

protected int[] getParentIdsArray(final LeafReaderContext context) throws IOException {
if (knnQuery.getParentsFilter() == null) {
return null;