diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 633b85d9ea83..710ccbace868 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -112,6 +112,9 @@ API Changes * GITHUB#14899: Deprecate MergeSpecification#segString(Directory) (kitoha) +* GITHUB#14978: Add a bulk scoring interface to RandomVectorScorer + (Trevor McCulloch, Chris Hegarty) + New Features --------------------- * GITHUB#14404: Introducing DocValuesMultiRangeQuery.SortedNumericStabbingBuilder into sandbox. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index a83b1635032c..faa885f7b2ea 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -71,6 +71,8 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class); + // Number of ordinals to score at a time when scoring exhaustively rather than using HNSW. + private static final int EXHAUSTIVE_BULK_SCORE_ORDS = 64; private final FlatVectorsReader flatVectorsReader; private final FieldInfos fieldInfos; @@ -344,15 +346,33 @@ private void search( HnswGraphSearcher.search( scorer, collector, getGraph(fieldEntry), acceptedOrds, filteredDocCount); } else { - // if k is larger than the number of vectors, we can just iterate over all vectors - // and collect them + // if k is larger than the number of vectors we expect to visit in an HNSW search, + // we can just iterate over all vectors and collect them. 
+ int[] ords = new int[EXHAUSTIVE_BULK_SCORE_ORDS]; + float[] scores = new float[EXHAUSTIVE_BULK_SCORE_ORDS]; + int numOrds = 0; for (int i = 0; i < scorer.maxOrd(); i++) { if (acceptedOrds == null || acceptedOrds.get(i)) { if (knnCollector.earlyTerminated()) { break; } + ords[numOrds++] = i; + if (numOrds == ords.length) { + scorer.bulkScore(ords, scores, numOrds); + for (int j = 0; j < numOrds; j++) { + knnCollector.incVisitedCount(1); + knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]); + } + numOrds = 0; + } + } + } + + if (numOrds > 0) { + scorer.bulkScore(ords, scores, numOrds); + for (int j = 0; j < numOrds; j++) { knnCollector.incVisitedCount(1); - knnCollector.collect(scorer.ordToDoc(i), scorer.score(i)); + knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]); } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 74d5c1541af5..cd9b013d27ac 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -42,6 +42,9 @@ public class HnswGraphSearcher extends AbstractHnswGraphSearcher { protected BitSet visited; + protected int[] bulkNodes = null; + protected float[] bulkScores = null; + /** * HNSW search is roughly logarithmic. This doesn't take maxConn into account, but it is a pretty * good approximation. 
@@ -276,6 +279,11 @@ void searchLevel( prepareScratchState(size); + if (bulkNodes == null || bulkNodes.length < graph.maxConn() * 2) { + bulkNodes = new int[graph.maxConn() * 2]; + bulkScores = new float[graph.maxConn() * 2]; + } + for (int ep : eps) { if (visited.getAndSet(ep) == false) { if (results.earlyTerminated()) { @@ -313,6 +321,7 @@ void searchLevel( int topCandidateNode = candidates.pop(); graphSeek(graph, level, topCandidateNode); int friendOrd; + int numNodes = 0; while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) { assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; if (visited.getAndSet(friendOrd)) { @@ -322,18 +331,28 @@ void searchLevel( if (results.earlyTerminated()) { break; } - float friendSimilarity = scorer.score(friendOrd); - results.incVisitedCount(1); - if (friendSimilarity >= minAcceptedSimilarity) { - candidates.add(friendOrd, friendSimilarity); - if (acceptOrds == null || acceptOrds.get(friendOrd)) { - if (results.collect(friendOrd, friendSimilarity)) { - float oldMinAcceptedSimilarity = minAcceptedSimilarity; - minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); - if (minAcceptedSimilarity > oldMinAcceptedSimilarity) { - // we adjusted our minAcceptedSimilarity, so we should explore the next equivalent - // if necessary - shouldExploreMinSim = true; + + bulkNodes[numNodes++] = friendOrd; + } + + if (numNodes > 0) { + numNodes = (int) Math.min((long) numNodes, results.visitLimit() - results.visitedCount()); + scorer.bulkScore(bulkNodes, bulkScores, numNodes); + results.incVisitedCount(numNodes); + for (int i = 0; i < numNodes; i++) { + int node = bulkNodes[i]; + float score = bulkScores[i]; + if (score >= minAcceptedSimilarity) { + candidates.add(node, score); + if (acceptOrds == null || acceptOrds.get(node)) { + if (results.collect(node, score)) { + float oldMinAcceptedSimilarity = minAcceptedSimilarity; + minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); + if 
(minAcceptedSimilarity > oldMinAcceptedSimilarity) { + // we adjusted our minAcceptedSimilarity, so we should explore the next equivalent + // if necessary + shouldExploreMinSim = true; + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index d2ea8e28a246..d481256d7a52 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -34,6 +34,21 @@ public interface RandomVectorScorer { */ float score(int node) throws IOException; + /** + * Score a list of numNodes and store the results in the scores array. + * + *
<p>This may be more efficient than calling {@link #score(int)} for each node. + * + * @param nodes array of nodes to score. + * @param scores output array of scores corresponding to each node. + * @param numNodes number of nodes to score. Must not exceed length of nodes or scores arrays. + */ + default void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + for (int i = 0; i < numNodes; i++) { + scores[i] = score(nodes[i]); + } + } + /** * @return the maximum possible ordinal for this scorer */ diff --git a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java index ec3d7d34eb5f..1dab736945de 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java @@ -30,25 +30,31 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; -import org.apache.lucene.tests.util.LuceneTestCase; import 
org.apache.lucene.util.IOSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.hamcrest.Matcher; import org.hamcrest.MatcherAssert; -public class TestFlatVectorScorer extends LuceneTestCase { +public class TestFlatVectorScorer extends BaseVectorizationTestCase { private static final AtomicInteger count = new AtomicInteger(); private final FlatVectorsScorer flatVectorsScorer; @@ -66,7 +72,8 @@ public static Iterable