Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds profiler for knn query #2450

Open
wants to merge 4 commits into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397]
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408]
- Remove skip building graph check for quantization use case (#2430)[https://github.com/opensearch-project/k-NN/2430]
- Adds support for knn related components for profiling knn query (#2450)[https://github.com/opensearch-project/k-NN/pull/2450]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.KNNIterator;
import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
Expand Down
56 changes: 51 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.profile.ContextualProfileBreakdown;
import org.opensearch.search.profile.Timer;
import org.opensearch.search.profile.query.ProfileWeight;
import org.opensearch.search.profile.query.QueryProfiler;
import org.opensearch.search.profile.query.QueryTimingType;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -168,20 +174,60 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true");
}
final Weight filterWeight = getFilterWeight(searcher);
if (filterWeight != null) {
return new KNNWeight(this, boost, filterWeight);

// TODO: We can look into decoupling the profiler logic with main query logic
final QueryProfiler profiler = getProfiler(searcher);
final ContextualProfileBreakdown<QueryTimingType> knnProfileBreakDown = getProfileBreakdown(profiler, this);
ContextualProfileBreakdown<QueryTimingType> filterQueryBreakdown = getProfileBreakdown(profiler, filterQuery);

return KNNWeight.builder()
.knnQuery(this)
.boost(boost)
.filterWeight(getFilterWeight(searcher, filterQueryBreakdown))
.knnQueryProfiler(knnProfileBreakDown)
.build();
}

private ContextualProfileBreakdown<QueryTimingType> getProfileBreakdown(QueryProfiler profiler, Query query) {
if (profiler != null && query != null) {
return profiler.getQueryBreakdown(query);
}
return null;
}

private QueryProfiler getProfiler(IndexSearcher searcher) {
if (searcher instanceof ContextIndexSearcher) {
ContextIndexSearcher contextIndexSearcher = (ContextIndexSearcher) searcher;
if (contextIndexSearcher.getProfiler() != null) {
return contextIndexSearcher.getProfiler();
}
}
return new KNNWeight(this, boost);
return null;
}

private Weight getFilterWeight(IndexSearcher searcher) throws IOException {
private Weight getFilterWeight(IndexSearcher searcher, ContextualProfileBreakdown<QueryTimingType> profiler) throws IOException {
if (this.getFilterQuery() != null) {
// Run the filter query
final BooleanQuery booleanQuery = new BooleanQuery.Builder().add(this.getFilterQuery(), BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(this.getField()), BooleanClause.Occur.FILTER)
.build();
final Query rewritten = searcher.rewrite(booleanQuery);

if (profiler != null) {
Timer timer = profiler.getTimer(QueryTimingType.CREATE_WEIGHT);
timer.start();

try {
return new ProfileWeight(
this.getFilterQuery(),
searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f),
profiler
);
} finally {
timer.stop();
}
}

return searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
}
return null;
Expand Down
55 changes: 39 additions & 16 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.Builder;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
Expand All @@ -28,18 +29,21 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder;
import org.opensearch.knn.index.query.profile.KNNProfileTimingType;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.search.profile.ContextualProfileBreakdown;
import org.opensearch.search.profile.query.QueryTimingType;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -59,39 +63,43 @@
/**
* Calculate query weights and build query scorers.
*/
@Builder
@Log4j2
public class KNNWeight extends Weight {
private static ModelDao modelDao;

private final KNNQuery knnQuery;
private final float boost;

private final NativeMemoryCacheManager nativeMemoryCacheManager;
private final NativeMemoryCacheManager nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
@Getter
private final Weight filterWeight;
private final ExactSearcher exactSearcher;
private final ExactSearcher exactSearcher = DEFAULT_EXACT_SEARCHER;

private static ExactSearcher DEFAULT_EXACT_SEARCHER;
private final QuantizationService quantizationService;
private final QuantizationService quantizationService = QuantizationService.getInstance();;
private ContextualProfileBreakdown<QueryTimingType> knnQueryProfiler;

public KNNWeight(KNNQuery query, float boost) {
super(query);
this.knnQuery = query;
this.boost = boost;
this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
this.filterWeight = null;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
}

public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
super(query);
this.knnQuery = query;
this.boost = boost;
this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
this.filterWeight = filterWeight;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
}

public KNNWeight(KNNQuery query, float boost, Weight filterWeight, ContextualProfileBreakdown<QueryTimingType> knnQueryProfiler) {
super(query);
this.knnQuery = query;
this.boost = boost;
this.filterWeight = filterWeight;
this.knnQueryProfiler = knnQueryProfiler;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -333,6 +341,10 @@ private Map<Integer, Float> doANNSearch(
throw new RuntimeException("Index has already been closed");
}
int[] parentIds = getParentIdsArray(context);
if (knnQueryProfiler != null) {
knnQueryProfiler.context(context).getTimer(KNNProfileTimingType.ANN_SEARCH.name()).start();
}

if (k > 0) {
if (knnQuery.getVectorDataType() == VectorDataType.BINARY
|| quantizedVector != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) {
Expand Down Expand Up @@ -378,18 +390,19 @@ private Map<Integer, Float> doANNSearch(
} finally {
indexAllocation.readUnlock();
indexAllocation.decRef();
if (knnQueryProfiler != null) {
knnQueryProfiler.context(context).getTimer(KNNProfileTimingType.ANN_SEARCH.name()).stop();
}
}

if (results.length == 0) {
log.debug("[KNN] Query yielded 0 results");
return Collections.emptyMap();
}

if (quantizedVector != null) {
return Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), SpaceType.HAMMING)));
}
final SpaceType scoringSpaceType = quantizedVector != null ? SpaceType.HAMMING : spaceType;
return Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), scoringSpaceType)));
}

/**
Expand All @@ -401,7 +414,17 @@ public Map<Integer, Float> exactSearch(
final LeafReaderContext leafReaderContext,
final ExactSearcher.ExactSearcherContext exactSearcherContext
) throws IOException {
return exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
if (knnQueryProfiler != null) {
knnQueryProfiler.context(leafReaderContext).getTimer(KNNProfileTimingType.EXACT_KNN_SEARCH.name()).start();
}

try {
return exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
} finally {
if (knnQueryProfiler != null) {
knnQueryProfiler.context(leafReaderContext).getTimer(KNNProfileTimingType.EXACT_KNN_SEARCH.name()).stop();
}
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.profile;

import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;

public enum KNNProfileTimingType {

EXACT_KNN_SEARCH,
ANN_SEARCH;

public static Set<String> getAllValues() {
return Arrays.stream(KNNProfileTimingType.values()).map(Enum::name).collect(Collectors.toSet());
}
}
15 changes: 15 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin;

import org.apache.lucene.search.Query;
import org.opensearch.cluster.NamedDiff;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.core.ParseField;
Expand All @@ -14,6 +15,9 @@
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.profile.KNNProfileTimingType;
import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;
Expand Down Expand Up @@ -112,6 +116,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -180,6 +185,16 @@ public List<QuerySpec<?>> getQueries() {
return singletonList(new QuerySpec<>(KNNQueryBuilder.NAME, KNNQueryBuilder::new, KNNQueryBuilderParser::fromXContent));
}

@Override
public Map<Class<? extends Query>, Set<String>> registerProfilerTimingTypes() {
return Map.of(
KNNQuery.class,
KNNProfileTimingType.getAllValues(),
NativeEngineKnnVectorQuery.class,
KNNProfileTimingType.getAllValues()
);
}

@Override
public Collection<Object> createComponents(
Client client,
Expand Down