1616 */
1717package org .apache .lucene .codecs .lucene102 ;
1818
19- import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .INDEX_BITS ;
2019import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .QUERY_BITS ;
2120import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
21+ import static org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
22+ import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
2223import static org .apache .lucene .util .quantization .OptimizedScalarQuantizer .transposeHalfByte ;
2324
2425import java .io .IOException ;
3031import org .apache .lucene .util .hnsw .RandomVectorScorer ;
3132import org .apache .lucene .util .hnsw .RandomVectorScorerSupplier ;
3233import org .apache .lucene .util .hnsw .UpdateableRandomVectorScorer ;
33- import org .apache .lucene .util .quantization .OptimizedScalarQuantizedVectorSimilarity ;
3434import org .apache .lucene .util .quantization .OptimizedScalarQuantizer ;
3535import org .apache .lucene .util .quantization .OptimizedScalarQuantizer .QuantizationResult ;
3636
3737/** Vector scorer over binarized vector values */
3838public class Lucene102BinaryFlatVectorsScorer implements FlatVectorsScorer {
3939 private final FlatVectorsScorer nonQuantizedDelegate ;
40+ private static final float FOUR_BIT_SCALE = 1f / ((1 << 4 ) - 1 );
4041
4142 public Lucene102BinaryFlatVectorsScorer (FlatVectorsScorer nonQuantizedDelegate ) {
4243 this .nonQuantizedDelegate = nonQuantizedDelegate ;
@@ -72,20 +73,10 @@ public RandomVectorScorer getRandomVectorScorer(
7273 quantizer .scalarQuantize (target , initial , (byte ) 4 , centroid );
7374 transposeHalfByte (initial , quantized );
7475 return new RandomVectorScorer .AbstractRandomVectorScorer (binarizedVectors ) {
75- private final OptimizedScalarQuantizedVectorSimilarity similarity =
76- new OptimizedScalarQuantizedVectorSimilarity (
77- similarityFunction ,
78- binarizedVectors .dimension (),
79- binarizedVectors .getCentroidDP (),
80- QUERY_BITS ,
81- INDEX_BITS );
82-
8376 @ Override
8477 public float score (int node ) throws IOException {
85- var indexVector = binarizedVectors .vectorValue (node );
86- var indexCorrections = binarizedVectors .getCorrectiveTerms (node );
87- float dotProduct = VectorUtil .int4BitDotProduct (quantized , indexVector );
88- return similarity .score (dotProduct , queryCorrections , indexCorrections );
78+ return quantizedScore (
79+ quantized , queryCorrections , binarizedVectors , node , similarityFunction );
8980 }
9081 };
9182 }
@@ -102,8 +93,7 @@ public RandomVectorScorer getRandomVectorScorer(
10293 RandomVectorScorerSupplier getRandomVectorScorerSupplier (
10394 VectorSimilarityFunction similarityFunction ,
10495 Lucene102BinaryQuantizedVectorsWriter .OffHeapBinarizedQueryVectorValues scoringVectors ,
105- BinarizedByteVectorValues targetVectors )
106- throws IOException {
96+ BinarizedByteVectorValues targetVectors ) {
10797 return new BinarizedRandomVectorScorerSupplier (
10898 scoringVectors , targetVectors , similarityFunction );
10999 }
@@ -118,31 +108,15 @@ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSu
118108 private final Lucene102BinaryQuantizedVectorsWriter .OffHeapBinarizedQueryVectorValues
119109 queryVectors ;
120110 private final BinarizedByteVectorValues targetVectors ;
121- private final OptimizedScalarQuantizedVectorSimilarity similarity ;
122-
123- BinarizedRandomVectorScorerSupplier (
124- Lucene102BinaryQuantizedVectorsWriter .OffHeapBinarizedQueryVectorValues queryVectors ,
125- BinarizedByteVectorValues targetVectors ,
126- VectorSimilarityFunction similarityFunction )
127- throws IOException {
128- this .queryVectors = queryVectors ;
129- this .targetVectors = targetVectors ;
130- this .similarity =
131- new OptimizedScalarQuantizedVectorSimilarity (
132- similarityFunction ,
133- targetVectors .dimension (),
134- targetVectors .getCentroidDP (),
135- QUERY_BITS ,
136- INDEX_BITS );
137- }
111+ private final VectorSimilarityFunction similarityFunction ;
138112
139113 BinarizedRandomVectorScorerSupplier (
140114 Lucene102BinaryQuantizedVectorsWriter .OffHeapBinarizedQueryVectorValues queryVectors ,
141115 BinarizedByteVectorValues targetVectors ,
142- OptimizedScalarQuantizedVectorSimilarity similarity ) {
116+ VectorSimilarityFunction similarityFunction ) {
143117 this .queryVectors = queryVectors ;
144118 this .targetVectors = targetVectors ;
145- this .similarity = similarity ;
119+ this .similarityFunction = similarityFunction ;
146120 }
147121
148122 @ Override
@@ -165,20 +139,57 @@ public float score(int node) throws IOException {
165139 if (vector == null || queryCorrections == null ) {
166140 throw new IllegalStateException ("setScoringOrdinal was not called" );
167141 }
168- var indexVector = targetVectors .vectorValue (node );
169- var indexCorrections = targetVectors .getCorrectiveTerms (node );
170- return similarity .score (
171- (float ) VectorUtil .int4BitDotProduct (vector , indexVector ),
172- queryCorrections ,
173- indexCorrections );
142+ return quantizedScore (vector , queryCorrections , targetVectors , node , similarityFunction );
174143 }
175144 };
176145 }
177146
178147 @ Override
179148 public RandomVectorScorerSupplier copy () throws IOException {
180149 return new BinarizedRandomVectorScorerSupplier (
181- queryVectors .copy (), targetVectors .copy (), similarity );
150+ queryVectors .copy (), targetVectors .copy (), similarityFunction );
151+ }
152+ }
153+
154+ static float quantizedScore (
155+ byte [] quantizedQuery ,
156+ OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
157+ BinarizedByteVectorValues targetVectors ,
158+ int targetOrd ,
159+ VectorSimilarityFunction similarityFunction )
160+ throws IOException {
161+ byte [] binaryCode = targetVectors .vectorValue (targetOrd );
162+ float qcDist = VectorUtil .int4BitDotProduct (quantizedQuery , binaryCode );
163+ OptimizedScalarQuantizer .QuantizationResult indexCorrections =
164+ targetVectors .getCorrectiveTerms (targetOrd );
165+ float x1 = indexCorrections .quantizedComponentSum ();
166+ float ax = indexCorrections .lowerInterval ();
167+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
168+ float lx = indexCorrections .upperInterval () - ax ;
169+ float ay = queryCorrections .lowerInterval ();
170+ float ly = (queryCorrections .upperInterval () - ay ) * FOUR_BIT_SCALE ;
171+ float y1 = queryCorrections .quantizedComponentSum ();
172+ float score =
173+ ax * ay * targetVectors .dimension () + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist ;
174+ // For euclidean, we need to invert the score and apply the additional correction, which is
175+ // assumed to be the squared l2norm of the centroid centered vectors.
176+ if (similarityFunction == EUCLIDEAN ) {
177+ score =
178+ queryCorrections .additionalCorrection ()
179+ + indexCorrections .additionalCorrection ()
180+ - 2 * score ;
181+ return Math .max (1 / (1f + score ), 0 );
182+ } else {
183+ // For cosine and max inner product, we need to apply the additional correction, which is
184+ // assumed to be the non-centered dot-product between the vector and the centroid
185+ score +=
186+ queryCorrections .additionalCorrection ()
187+ + indexCorrections .additionalCorrection ()
188+ - targetVectors .getCentroidDP ();
189+ if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
190+ return VectorUtil .scaleMaxInnerProductScore (score );
191+ }
192+ return Math .max ((1f + score ) / 2f , 0 );
182193 }
183194 }
184195}
0 commit comments