diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 633b85d9ea83..710ccbace868 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -112,6 +112,9 @@ API Changes * GITHUB#14899: Deprecate MergeSpecification#segString(Directory) (kitoha) +* GITHUB#14978: Add a bulk scoring interface to RandomVectorScorer + (Trevor McCulloch, Chris Hegarty) + New Features --------------------- * GITHUB#14404: Introducing DocValuesMultiRangeQuery.SortedNumericStabbingBuilder into sandbox. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index a83b1635032c..faa885f7b2ea 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -71,6 +71,8 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class); + // Number of ordinals to score at a time when scoring exhaustively rather than using HNSW. + private static final int EXHAUSTIVE_BULK_SCORE_ORDS = 64; private final FlatVectorsReader flatVectorsReader; private final FieldInfos fieldInfos; @@ -344,15 +346,33 @@ private void search( HnswGraphSearcher.search( scorer, collector, getGraph(fieldEntry), acceptedOrds, filteredDocCount); } else { - // if k is larger than the number of vectors, we can just iterate over all vectors - // and collect them + // if k is larger than the number of vectors we expect to visit in an HNSW search, + // we can just iterate over all vectors and collect them. 
+ int[] ords = new int[EXHAUSTIVE_BULK_SCORE_ORDS]; + float[] scores = new float[EXHAUSTIVE_BULK_SCORE_ORDS]; + int numOrds = 0; for (int i = 0; i < scorer.maxOrd(); i++) { if (acceptedOrds == null || acceptedOrds.get(i)) { if (knnCollector.earlyTerminated()) { break; } + ords[numOrds++] = i; + if (numOrds == ords.length) { + scorer.bulkScore(ords, scores, numOrds); + for (int j = 0; j < numOrds; j++) { + knnCollector.incVisitedCount(1); + knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]); + } + numOrds = 0; + } + } + } + + if (numOrds > 0) { + scorer.bulkScore(ords, scores, numOrds); + for (int j = 0; j < numOrds; j++) { knnCollector.incVisitedCount(1); - knnCollector.collect(scorer.ordToDoc(i), scorer.score(i)); + knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]); } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 74d5c1541af5..cd9b013d27ac 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -42,6 +42,9 @@ public class HnswGraphSearcher extends AbstractHnswGraphSearcher { protected BitSet visited; + protected int[] bulkNodes = null; + protected float[] bulkScores = null; + /** * HNSW search is roughly logarithmic. This doesn't take maxConn into account, but it is a pretty * good approximation. 
@@ -276,6 +279,11 @@ void searchLevel( prepareScratchState(size); + if (bulkNodes == null || bulkNodes.length < graph.maxConn() * 2) { + bulkNodes = new int[graph.maxConn() * 2]; + bulkScores = new float[graph.maxConn() * 2]; + } + for (int ep : eps) { if (visited.getAndSet(ep) == false) { if (results.earlyTerminated()) { @@ -313,6 +321,7 @@ void searchLevel( int topCandidateNode = candidates.pop(); graphSeek(graph, level, topCandidateNode); int friendOrd; + int numNodes = 0; while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) { assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; if (visited.getAndSet(friendOrd)) { @@ -322,18 +331,28 @@ void searchLevel( if (results.earlyTerminated()) { break; } - float friendSimilarity = scorer.score(friendOrd); - results.incVisitedCount(1); - if (friendSimilarity >= minAcceptedSimilarity) { - candidates.add(friendOrd, friendSimilarity); - if (acceptOrds == null || acceptOrds.get(friendOrd)) { - if (results.collect(friendOrd, friendSimilarity)) { - float oldMinAcceptedSimilarity = minAcceptedSimilarity; - minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); - if (minAcceptedSimilarity > oldMinAcceptedSimilarity) { - // we adjusted our minAcceptedSimilarity, so we should explore the next equivalent - // if necessary - shouldExploreMinSim = true; + + bulkNodes[numNodes++] = friendOrd; + } + + if (numNodes > 0) { + numNodes = (int) Math.min((long) numNodes, results.visitLimit() - results.visitedCount()); + scorer.bulkScore(bulkNodes, bulkScores, numNodes); + results.incVisitedCount(numNodes); + for (int i = 0; i < numNodes; i++) { + int node = bulkNodes[i]; + float score = bulkScores[i]; + if (score >= minAcceptedSimilarity) { + candidates.add(node, score); + if (acceptOrds == null || acceptOrds.get(node)) { + if (results.collect(node, score)) { + float oldMinAcceptedSimilarity = minAcceptedSimilarity; + minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); + if 
(minAcceptedSimilarity > oldMinAcceptedSimilarity) { + // we adjusted our minAcceptedSimilarity, so we should explore the next equivalent + // if necessary + shouldExploreMinSim = true; + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index d2ea8e28a246..d481256d7a52 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -34,6 +34,21 @@ public interface RandomVectorScorer { */ float score(int node) throws IOException; + /** + * Score a list of numNodes and store the results in the scores array. + * + *
<p>
This may be more efficient than calling {@link #score(int)} for each node. + * + * @param nodes array of nodes to score. + * @param scores output array of scores corresponding to each node. + * @param numNodes number of nodes to score. Must not exceed length of nodes or scores arrays. + */ + default void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException { + for (int i = 0; i < numNodes; i++) { + scores[i] = score(nodes[i]); + } + } + /** * @return the maximum possible ordinal for this scorer */ diff --git a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java index ec3d7d34eb5f..1dab736945de 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java @@ -30,25 +30,31 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; -import org.apache.lucene.tests.util.LuceneTestCase; import 
org.apache.lucene.util.IOSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.hamcrest.Matcher; import org.hamcrest.MatcherAssert; -public class TestFlatVectorScorer extends LuceneTestCase { +public class TestFlatVectorScorer extends BaseVectorizationTestCase { private static final AtomicInteger count = new AtomicInteger(); private final FlatVectorsScorer flatVectorsScorer; @@ -66,7 +72,8 @@ public static Iterable parametersFactory() { List.of( DefaultFlatVectorScorer.INSTANCE, new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()), - FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), + maybePanamaProvider().getLucene99FlatVectorsScorer()); var dirs = List.>of( TestFlatVectorScorer::newDirectory, @@ -180,6 +187,113 @@ public void testCheckFloatDimensions() throws IOException { } } + public void testBulkScorerBytes() throws IOException { + int dims = random().nextInt(1, 1024); + int size = random().nextInt(2, 255); + String fileName = "testBulkScorerBytes"; + try (Directory dir = newDirectory.get()) { + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + byte[] ba = randomByteVector(dims); + out.writeBytes(ba, 0, ba.length); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + assert in.length() == (long) dims * size * Byte.BYTES; + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = byteVectorValues(dims, size, in, sim); + assertBulkEqualsNonBulk(values, sim); + assertBulkEqualsNonBulkSupplier(values, sim); + } + } + } + } + + public void testBulkScorerFloats() throws IOException { + int dims = random().nextInt(1, 1024); + int size = random().nextInt(2, 255); + String fileName = "testBulkScorerFloats"; + try (Directory dir = newDirectory.get()) { + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < 
size; i++) { + byte[] ba = concat(randomFloatVector(dims)); + out.writeBytes(ba, 0, ba.length); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + assert in.length() == (long) dims * size * Float.BYTES; + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = floatVectorValues(dims, size, in, sim); + assertBulkEqualsNonBulk(values, sim); + assertBulkEqualsNonBulkSupplier(values, sim); + } + } + } + } + + void assertBulkEqualsNonBulk(KnnVectorValues values, VectorSimilarityFunction sim) + throws IOException { + final int dims = values.dimension(); + final int size = values.size(); + final float delta = 1e-3f * size; + var scorer = + values.getEncoding() == VectorEncoding.BYTE + ? flatVectorsScorer.getRandomVectorScorer(sim, values, randomByteVector(dims)) + : flatVectorsScorer.getRandomVectorScorer(sim, values, randomFloatVector(dims)); + int[] indices = randomIndices(size); + float[] expectedScores = new float[size]; + for (int i = 0; i < size; i++) { + expectedScores[i] = scorer.score(indices[i]); + } + float[] bulkScores = new float[size]; + scorer.bulkScore(indices, bulkScores, size); + assertArrayEquals(expectedScores, bulkScores, delta); + assertNoScoreBeyondNumNodes(scorer, size); + } + + // score through the supplier/updatableScorer interface + void assertBulkEqualsNonBulkSupplier(KnnVectorValues values, VectorSimilarityFunction sim) + throws IOException { + final int size = values.size(); + final float delta = 1e-3f * size; + var ss = flatVectorsScorer.getRandomVectorScorerSupplier(sim, values); + var updatableScorer = ss.scorer(); + var targetNode = random().nextInt(size); + updatableScorer.setScoringOrdinal(targetNode); + int[] indices = randomIndices(size); + float[] expectedScores = new float[size]; + for (int i = 0; i < size; i++) { + expectedScores[i] = updatableScorer.score(indices[i]); + } + float[] bulkScores = new float[size]; + updatableScorer.bulkScore(indices, bulkScores, 
size); + assertArrayEquals(expectedScores, bulkScores, delta); + assertNoScoreBeyondNumNodes(updatableScorer, size); + } + + void assertNoScoreBeyondNumNodes(RandomVectorScorer scorer, int maxSize) throws IOException { + int numNodes = random().nextInt(0, maxSize); + int[] indices = new int[numNodes + 1]; + float[] bulkScores = new float[numNodes + 1]; + bulkScores[bulkScores.length - 1] = Float.NaN; + scorer.bulkScore(indices, bulkScores, numNodes); + assertEquals(Float.NaN, bulkScores[bulkScores.length - 1], 0.0f); + } + + byte[] randomByteVector(int dims) { + byte[] ba = new byte[dims]; + random().nextBytes(ba); + return ba; + } + + float[] randomFloatVector(int dims) { + float[] fa = new float[dims]; + for (int i = 0; i < dims; ++i) { + fa[i] = random().nextFloat(); + } + return fa; + } + ByteVectorValues byteVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( @@ -221,6 +335,13 @@ public static byte[] concat(byte[]... arrays) throws IOException { } } + /** Returns an int[] of the given size with values from 0 to size, shuffled. */ + public static int[] randomIndices(int size) { + List list = IntStream.range(0, size).boxed().collect(Collectors.toList()); + Collections.shuffle(list, random()); + return list.stream().mapToInt(i -> i).toArray(); + } + public static void assertThat(T actual, Matcher matcher) { MatcherAssert.assertThat("", actual, matcher); }