Skip to content

Commit 2bb4af4

Browse files
committed
iter
1 parent e3b583e commit 2bb4af4

File tree

4 files changed

+53
-20
lines changed

4 files changed

+53
-20
lines changed

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,8 @@ static DocsWithFieldSet writeBinarizedVectorAndQueryData(
410410
Math.max(
411411
encoding.getDiscreteDimensions(floatVectorValues.dimension()),
412412
queryEncoding.getDiscreteDimensions(floatVectorValues.dimension()));
413-
assert discretizedDims % encoding.getBits() == 0;
414-
assert discretizedDims % queryEncoding.getBits() == 0;
413+
assert discretizedDims % (8 / encoding.getBitsPerDim()) == 0;
414+
assert discretizedDims % (8 / queryEncoding.getBitsPerDim()) == 0;
415415
byte[][] quantizationScratch = new byte[2][];
416416
quantizationScratch[0] = new byte[discretizedDims];
417417
quantizationScratch[1] = new byte[discretizedDims];
@@ -563,7 +563,9 @@ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
563563
if (finalQuantizedScoreDataInput != null) {
564564
scoreVectorValues =
565565
new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues(
566-
fieldInfo.getVectorDimension(),
566+
// pass the commonly discretized dimension to ensure both vectorValues and
567+
// scoreVectorValues have the same discretized dimension
568+
vectorValues.discretizedDimension(),
567569
docsWithField.cardinality(),
568570
centroid,
569571
cDotC,
@@ -579,7 +581,7 @@ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
579581
? vectorsScorer.getRandomVectorScorerSupplier(
580582
fieldInfo.getVectorSimilarityFunction(), vectorValues)
581583
: vectorsScorer.getRandomVectorScorerSupplier(
582-
fieldInfo.getVectorSimilarityFunction(), vectorValues, scoreVectorValues);
584+
fieldInfo.getVectorSimilarityFunction(), scoreVectorValues, vectorValues);
583585
return new QuantizedCloseableRandomVectorScorerSupplier(
584586
scorerSupplier,
585587
vectorValues,

lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVe
4646
final byte[] vectorValue;
4747
final ByteBuffer byteBuffer;
4848
final int byteSize;
49+
final int discretizedDimension;
4950
private int lastOrd = -1;
5051
final float[] correctiveValues;
5152
int quantizedComponentSum;
@@ -73,12 +74,14 @@ public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVe
7374
this.centroid = centroid;
7475
this.centroidDp = centroidDp;
7576
this.correctiveValues = new float[3];
76-
this.byteSize = encoding.getPackedLength(dimension) + (Float.BYTES * 3) + Integer.BYTES;
77-
this.byteBuffer = ByteBuffer.allocate(encoding.getPackedLength(dimension));
78-
this.vectorValue = byteBuffer.array();
79-
this.quantizer = quantizer;
8077
this.encoding = encoding;
8178
this.queryEncoding = queryEncoding;
79+
this.discretizedDimension = calculateDiscretizedDimension(dimension, encoding, queryEncoding);
80+
this.byteSize =
81+
encoding.getPackedLength(discretizedDimension) + (Float.BYTES * 3) + Integer.BYTES;
82+
this.byteBuffer = ByteBuffer.allocate(encoding.getPackedLength(discretizedDimension));
83+
this.vectorValue = byteBuffer.array();
84+
this.quantizer = quantizer;
8285
}
8386

8487
@Override

lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,19 @@
2525
/** Scalar quantized byte vector values */
2626
abstract class QuantizedByteVectorValues extends ByteVectorValues {
2727

28+
public static int calculateDiscretizedDimension(
29+
int dims, ScalarEncoding storageEncoding, ScalarEncoding queryEncoding) {
30+
if (storageEncoding == queryEncoding) {
31+
return storageEncoding.getDiscreteDimensions(dims);
32+
}
33+
int queryDiscretized = queryEncoding.getDiscreteDimensions(dims);
34+
int docDiscretized = storageEncoding.getDiscreteDimensions(dims);
35+
int maxDiscretized = Math.max(queryDiscretized, docDiscretized);
36+
assert maxDiscretized % (8 / queryEncoding.getBitsPerDim()) == 0;
37+
assert maxDiscretized % (8 / storageEncoding.getBitsPerDim()) == 0;
38+
return maxDiscretized;
39+
}
40+
2841
/**
2942
* Retrieve the corrective terms for the given vector ordinal. For the dot-product family of
3043
* distances, the corrective terms are, in order
@@ -72,16 +85,8 @@ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(i
7285
* @return The correctly discretized dimension given the configured encodings
7386
*/
7487
public int discretizedDimension() {
75-
int dims = dimension();
76-
if (getScalarEncoding() == getQueryScalarEncoding()) {
77-
return getScalarEncoding().getDiscreteDimensions(dims);
78-
}
79-
int queryDiscretized = getQueryScalarEncoding().getDiscreteDimensions(dims);
80-
int docDiscretized = getScalarEncoding().getDiscreteDimensions(dims);
81-
int maxDiscretized = Math.max(queryDiscretized, docDiscretized);
82-
assert maxDiscretized % getQueryScalarEncoding().getBits() == 0;
83-
assert maxDiscretized % getScalarEncoding().getBits() == 0;
84-
return maxDiscretized;
88+
return calculateDiscretizedDimension(
89+
dimension(), getScalarEncoding(), getQueryScalarEncoding());
8590
}
8691

8792
/**

lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.lucene.codecs.KnnVectorsFormat;
3131
import org.apache.lucene.codecs.KnnVectorsReader;
3232
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
33+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
3334
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
3435
import org.apache.lucene.document.Document;
3536
import org.apache.lucene.document.KnnFloatVectorField;
@@ -47,14 +48,36 @@
4748
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
4849
import org.apache.lucene.tests.util.TestUtil;
4950
import org.apache.lucene.util.SameThreadExecutorService;
51+
import org.junit.Before;
5052

5153
public class TestLucene104HnswScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
5254

53-
private static final KnnVectorsFormat FORMAT = new Lucene104HnswScalarQuantizedVectorsFormat();
55+
private KnnVectorsFormat format;
56+
57+
@Before
58+
@Override
59+
public void setUp() throws Exception {
60+
var encodingValues = ScalarEncoding.values();
61+
var encoding = encodingValues[random().nextInt(encodingValues.length)];
62+
var queryEncoding = encoding;
63+
// always assume asymmetric for now. Eventually make this more general
64+
if (encoding == ScalarEncoding.SINGLE_BIT) {
65+
queryEncoding = ScalarEncoding.PACKED_NIBBLE;
66+
}
67+
format =
68+
new Lucene104HnswScalarQuantizedVectorsFormat(
69+
encoding,
70+
queryEncoding,
71+
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
72+
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
73+
1,
74+
null);
75+
super.setUp();
76+
}
5477

5578
@Override
5679
protected Codec getCodec() {
57-
return TestUtil.alwaysKnnVectorsFormat(FORMAT);
80+
return TestUtil.alwaysKnnVectorsFormat(format);
5881
}
5982

6083
public void testToString() {

0 commit comments

Comments
 (0)