Skip to content

Commit 8063726

Browse files
committed
fixing bugs
1 parent 2bb4af4 commit 8063726

9 files changed

+174
-223
lines changed

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
8686
/**
8787
* Constructs a format using the given graph construction parameters and scalar quantization.
8888
*
89+
* @param encoding the quantization encoding used to encode the vectors
8990
* @param maxConn the maximum number of connections to a node in the HNSW graph
9091
* @param beamWidth the size of the queue maintained during graph construction.
9192
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
@@ -99,31 +100,8 @@ public Lucene104HnswScalarQuantizedVectorsFormat(
99100
int beamWidth,
100101
int numMergeWorkers,
101102
ExecutorService mergeExec) {
102-
this(encoding, encoding, maxConn, beamWidth, numMergeWorkers, mergeExec);
103-
}
104-
105-
/**
106-
* Constructs a format using the given graph construction parameters and scalar quantization.
107-
*
108-
* @param encoding the encoding used to encode the indexed vectors
109-
* @param queryEncoding the encoding used to encode the query vectors. This may be different from
110-
* the encoding used to encode the indexed vectors.
111-
* @param maxConn the maximum number of connections to a node in the HNSW graph
112-
* @param beamWidth the size of the queue maintained during graph construction.
113-
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
114-
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
115-
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
116-
* generated by this format to do the merge
117-
*/
118-
public Lucene104HnswScalarQuantizedVectorsFormat(
119-
ScalarEncoding encoding,
120-
ScalarEncoding queryEncoding,
121-
int maxConn,
122-
int beamWidth,
123-
int numMergeWorkers,
124-
ExecutorService mergeExec) {
125103
super(NAME);
126-
flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding, queryEncoding);
104+
flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding);
127105
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
128106
throw new IllegalArgumentException(
129107
"maxConn must be positive and less than or equal to "

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,14 @@ public RandomVectorScorer getRandomVectorScorer(
6464
if (vectorValues instanceof QuantizedByteVectorValues qv) {
6565
checkDimensions(target.length, qv.dimension());
6666
OptimizedScalarQuantizer quantizer = qv.getQuantizer();
67-
byte[] scratch = new byte[qv.discretizedDimension()];
67+
Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding = qv.getScalarEncoding();
68+
byte[] scratch = new byte[scalarEncoding.getDiscreteDimensions(qv.dimension())];
6869
final byte[] targetQuantized;
69-
if (qv.getScalarEncoding() == qv.getQueryScalarEncoding()) {
70-
assert qv.getScalarEncoding()
71-
!= Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT;
70+
if (scalarEncoding.isAsymmetric() == false) {
7271
targetQuantized = scratch;
7372
} else {
7473
// This is asymmetric quantization: the quantized query is packed into a separate buffer
75-
assert qv.getScalarEncoding()
76-
== Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT;
77-
assert qv.getQueryScalarEncoding()
78-
== Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE;
79-
targetQuantized = new byte[qv.getQueryScalarEncoding().getPackedLength(scratch.length)];
74+
targetQuantized = new byte[scalarEncoding.getQueryPackedLength(scratch.length)];
8075
}
8176
// We make a copy as the quantization process mutates the input
8277
float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
@@ -86,12 +81,10 @@ public RandomVectorScorer getRandomVectorScorer(
8681
target = copy;
8782
var targetCorrectiveTerms =
8883
quantizer.scalarQuantize(
89-
target, scratch, qv.getQueryScalarEncoding().getBits(), qv.getCentroid());
90-
if (qv.getScalarEncoding() != qv.getQueryScalarEncoding()) {
91-
assert qv.getScalarEncoding()
92-
== Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT;
93-
assert qv.getQueryScalarEncoding()
94-
== Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.PACKED_NIBBLE;
84+
target, scratch, scalarEncoding.getQueryBits(), qv.getCentroid());
85+
// for single bit query nibble, we need to transpose the nibbles for fast scoring comparisons
86+
if (scalarEncoding
87+
== Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE) {
9588
OptimizedScalarQuantizer.transposeHalfByte(scratch, targetQuantized);
9689
}
9790
return new RandomVectorScorer.AbstractRandomVectorScorer(qv) {
@@ -137,9 +130,7 @@ static class AsymmetricQuantizedRandomVectorScorerSupplier implements RandomVect
137130
QuantizedByteVectorValues queryVectors,
138131
QuantizedByteVectorValues targetVectors,
139132
VectorSimilarityFunction similarityFunction) {
140-
assert targetVectors.getQueryScalarEncoding() != targetVectors.getScalarEncoding();
141-
assert queryVectors.getScalarEncoding() == targetVectors.getQueryScalarEncoding();
142-
assert queryVectors.getScalarEncoding() == queryVectors.getQueryScalarEncoding();
133+
assert targetVectors.getScalarEncoding().isAsymmetric();
143134
this.queryVectors = queryVectors;
144135
this.targetVectors = targetVectors;
145136
this.similarityFunction = similarityFunction;
@@ -155,15 +146,16 @@ public UpdateableRandomVectorScorer scorer() throws IOException {
155146

156147
@Override
157148
public void setScoringOrdinal(int node) throws IOException {
158-
queryCorrections = queryVectors.getCorrectiveTerms(node);
159149
vector = queryVectors.vectorValue(node);
150+
queryCorrections = queryVectors.getCorrectiveTerms(node);
160151
}
161152

162153
@Override
163154
public float score(int node) throws IOException {
164155
if (vector == null || queryCorrections == null) {
165156
throw new IllegalStateException("setScoringOrdinal was not called");
166157
}
158+
167159
return quantizedScore(vector, queryCorrections, targetVectors, node, similarityFunction);
168160
}
169161
};
@@ -184,7 +176,7 @@ private static final class ScalarQuantizedVectorScorerSupplier
184176

185177
public ScalarQuantizedVectorScorerSupplier(
186178
QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException {
187-
assert values.getQueryScalarEncoding() == values.getScalarEncoding();
179+
assert values.getScalarEncoding().isAsymmetric() == false;
188180
this.targetValues = values.copy();
189181
this.values = values;
190182
this.similarity = similarity;
@@ -212,9 +204,9 @@ public void setScoringOrdinal(int node) throws IOException {
212204
}
213205
OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector);
214206
}
215-
case SINGLE_BIT -> {
207+
case SINGLE_BIT_QUERY_NIBBLE -> {
216208
throw new IllegalStateException(
217-
"SINGLE_BIT encoding is not supported for symmetric quantization");
209+
"SINGLE_BIT_QUERY_NIBBLE encoding is not supported for symmetric quantization");
218210
}
219211
}
220212
targetCorrectiveTerms = targetValues.getCorrectiveTerms(node);
@@ -248,18 +240,18 @@ private static float quantizedScore(
248240
VectorSimilarityFunction similarityFunction)
249241
throws IOException {
250242
var scalarEncoding = targetVectors.getScalarEncoding();
251-
var queryScalarEncoding = targetVectors.getQueryScalarEncoding();
252243
byte[] quantizedDoc = targetVectors.vectorValue(targetOrd);
253244
float qcDist =
254245
switch (scalarEncoding) {
255246
case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc);
256247
case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc);
257248
case PACKED_NIBBLE -> VectorUtil.int4DotProductSinglePacked(quantizedQuery, quantizedDoc);
258-
case SINGLE_BIT -> VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc);
249+
case SINGLE_BIT_QUERY_NIBBLE ->
250+
VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc);
259251
};
260252
OptimizedScalarQuantizer.QuantizationResult indexCorrections =
261253
targetVectors.getCorrectiveTerms(targetOrd);
262-
float queryScale = SCALE_LUT[queryScalarEncoding.getBits() - 1];
254+
float queryScale = SCALE_LUT[scalarEncoding.getQueryBits() - 1];
263255
float scale = SCALE_LUT[scalarEncoding.getBits() - 1];
264256
float x1 = indexCorrections.quantizedComponentSum();
265257
float ax = indexCorrections.lowerInterval();

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@
7979
* <ul>
8080
* <li><b>int</b> the field number
8181
* <li><b>int</b> the vector encoding ordinal
82-
* <li><b>int</b> the query encoding ordinal
8382
* <li><b>int</b> the vector similarity ordinal
8483
* <li><b>vint</b> the vector dimensions
8584
* <li><b>vlong</b> the offset to the vector data in the .veq file
@@ -110,7 +109,6 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
110109
new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
111110

112111
private final ScalarEncoding encoding;
113-
private final ScalarEncoding queryEncoding;
114112

115113
/**
116114
* Allowed encodings for scalar quantization.
@@ -132,14 +130,13 @@ public enum ScalarEncoding {
132130
*/
133131
SEVEN_BIT(2, (byte) 7, 8),
134132
/**
135-
* Each dimension is quantized to a single bit and packed into bytes.
133+
* Each dimension is quantized to a single bit and packed into bytes. During query time, the
134+
* query vector is quantized to 4 bits per dimension.
136135
*
137136
* <p>This is the most space efficient encoding, and will produce an index 8x smaller than
138-
* {@link #UNSIGNED_BYTE}. However, this comes at the cost of accuracy. This encoding is
139-
* recommended for use when the number of dimensions is high (e.g. &gt; 128) and with an
140-
* asymmetric quantization scheme where query vectors are quantized to 4 bits.
137+
* {@link #UNSIGNED_BYTE}. However, this comes at the cost of accuracy.
141138
*/
142-
SINGLE_BIT(3, (byte) 1, 1);
139+
SINGLE_BIT_QUERY_NIBBLE(3, (byte) 1, 1, (byte) 4, 4);
143140

144141
public static ScalarEncoding fromNumBits(int bits) {
145142
for (ScalarEncoding encoding : values()) {
@@ -153,13 +150,27 @@ public static ScalarEncoding fromNumBits(int bits) {
153150
/** The number used to identify this encoding on the wire, rather than relying on ordinal. */
154151
private final int wireNumber;
155152

156-
private final byte bits;
157-
private final int bitsPerDim;
153+
private final byte bits, queryBits;
154+
private final int bitsPerDim, queryBitsPerDim;
158155

159156
ScalarEncoding(int wireNumber, byte bits, int bitsPerDim) {
160157
this.wireNumber = wireNumber;
161158
this.bits = bits;
159+
this.queryBits = bits;
162160
this.bitsPerDim = bitsPerDim;
161+
this.queryBitsPerDim = bitsPerDim;
162+
}
163+
164+
ScalarEncoding(int wireNumber, byte bits, int bitsPerDim, byte queryBits, int queryBitsPerDim) {
165+
this.wireNumber = wireNumber;
166+
this.bits = bits;
167+
this.queryBits = queryBits;
168+
this.bitsPerDim = bitsPerDim;
169+
this.queryBitsPerDim = queryBitsPerDim;
170+
}
171+
172+
boolean isAsymmetric() {
173+
return bits != queryBits;
163174
}
164175

165176
int getWireNumber() {
@@ -171,20 +182,48 @@ public byte getBits() {
171182
return bits;
172183
}
173184

185+
public byte getQueryBits() {
186+
return queryBits;
187+
}
188+
174189
/** Return the number of dimensions rounded up to fit into whole bytes. */
175190
public int getDiscreteDimensions(int dimensions) {
176-
int totalBits = dimensions * bitsPerDim;
177-
return (totalBits + 7) / 8 * 8 / bitsPerDim;
191+
if (queryBits == bits) {
192+
int totalBits = dimensions * bitsPerDim;
193+
return (totalBits + 7) / 8 * 8 / bitsPerDim;
194+
}
195+
int queryDiscretized = (dimensions * queryBitsPerDim + 7) / 8 * 8 / queryBitsPerDim;
196+
int docDiscretized = (dimensions * bitsPerDim + 7) / 8 * 8 / bitsPerDim;
197+
int maxDiscretized = Math.max(queryDiscretized, docDiscretized);
198+
assert maxDiscretized % (8.0 / queryBitsPerDim) == 0
199+
: "bad discretized=" + maxDiscretized + " for dim=" + dimensions;
200+
assert maxDiscretized % (8.0 / bitsPerDim) == 0
201+
: "bad discretized=" + maxDiscretized + " for dim=" + dimensions;
202+
return maxDiscretized;
178203
}
179204

180205
/** Return the number of bits used to store each dimension of an indexed (document) vector. */
181-
public int getBitsPerDim() {
206+
public int getDocBitsPerDim() {
182207
return this.bitsPerDim;
183208
}
184209

210+
public int getQueryBitsPerDim() {
211+
return this.queryBitsPerDim;
212+
}
213+
185214
/** Return the number of bytes required to store a packed vector of the given dimensions. */
186-
public int getPackedLength(int dimensions) {
187-
return (dimensions * bitsPerDim + 7) / 8;
215+
public int getDocPackedLength(int dimensions) {
216+
int discretized = getDiscreteDimensions(dimensions);
217+
// how many bytes do we need to store the quantized vector?
218+
int totalBits = discretized * bitsPerDim;
219+
return (totalBits + 7) / 8;
220+
}
221+
222+
public int getQueryPackedLength(int dimensions) {
223+
int discretized = getDiscreteDimensions(dimensions);
224+
// how many bytes do we need to store the quantized query vector?
225+
int totalBits = discretized * queryBitsPerDim;
226+
return (totalBits + 7) / 8;
188227
}
189228

190229
/** Returns the encoding for the given wire number, or empty if unknown. */
@@ -203,35 +242,16 @@ public Lucene104ScalarQuantizedVectorsFormat() {
203242
this(ScalarEncoding.UNSIGNED_BYTE);
204243
}
205244

206-
/** Creates a new instance with the chosen symmetric quantization encoding. */
245+
/** Creates a new instance with the chosen quantization encoding. */
207246
public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) {
208-
this(encoding, encoding);
209-
}
210-
211-
/** Creates a new instance with the chosen asymmetric quantization encoding. */
212-
public Lucene104ScalarQuantizedVectorsFormat(
213-
ScalarEncoding encoding, ScalarEncoding queryEncoding) {
214247
super(NAME);
215248
this.encoding = encoding;
216-
this.queryEncoding = queryEncoding;
217-
// until we have optimized scorers for various other asymmetric encodings, maybe we only allow 1
218-
// bit -> 4 bit
219-
// Technically, we should be able to do 2 bit -> 4 bit, and 1, 2 -> 8, and 4 -> 8. But these
220-
// will take time to
221-
// have optimized scorers, and we don't want users to accidentally use poorly optimized
222-
// combinations.
223-
if (encoding != queryEncoding) {
224-
if (encoding != ScalarEncoding.SINGLE_BIT || queryEncoding != ScalarEncoding.PACKED_NIBBLE) {
225-
throw new IllegalArgumentException(
226-
"Only SINGLE_BIT -> PACKED_NIBBLE asymmetric encoding is supported");
227-
}
228-
}
229249
}
230250

231251
@Override
232252
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
233253
return new Lucene104ScalarQuantizedVectorsWriter(
234-
state, encoding, queryEncoding, rawVectorFormat.fieldsWriter(state), scorer);
254+
state, encoding, rawVectorFormat.fieldsWriter(state), scorer);
235255
}
236256

237257
@Override
@@ -251,8 +271,6 @@ public String toString() {
251271
+ NAME
252272
+ ", encoding="
253273
+ encoding
254-
+ ", queryEncoding="
255-
+ queryEncoding
256274
+ ", flatVectorScorer="
257275
+ scorer
258276
+ ", rawVectorFormat="

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
141141

142142
long numQuantizedVectorBytes =
143143
Math.multiplyExact(
144-
(fieldEntry.scalarEncoding.getPackedLength(dimension)
144+
(fieldEntry.scalarEncoding.getDocPackedLength(dimension)
145145
+ (Float.BYTES * 3)
146146
+ Integer.BYTES),
147147
(long) fieldEntry.size);
@@ -173,7 +173,6 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th
173173
fi.size,
174174
new OptimizedScalarQuantizer(fi.similarityFunction),
175175
fi.scalarEncoding,
176-
fi.queryEncoding,
177176
fi.similarityFunction,
178177
vectorScorer,
179178
fi.centroid,
@@ -217,7 +216,6 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
217216
fi.size,
218217
new OptimizedScalarQuantizer(fi.similarityFunction),
219218
fi.scalarEncoding,
220-
fi.queryEncoding,
221219
fi.similarityFunction,
222220
vectorScorer,
223221
fi.centroid,
@@ -367,7 +365,6 @@ public org.apache.lucene.util.quantization.QuantizedByteVectorValues getQuantize
367365
fi.size,
368366
new OptimizedScalarQuantizer(fi.similarityFunction),
369367
fi.scalarEncoding,
370-
fi.queryEncoding,
371368
fi.similarityFunction,
372369
vectorScorer,
373370
fi.centroid,
@@ -411,7 +408,6 @@ private record FieldEntry(
411408
long vectorDataLength,
412409
int size,
413410
ScalarEncoding scalarEncoding,
414-
ScalarEncoding queryEncoding,
415411
float[] centroid,
416412
float centroidDP,
417413
OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) {
@@ -428,7 +424,6 @@ static FieldEntry create(
428424
final float[] centroid;
429425
float centroidDP = 0;
430426
ScalarEncoding scalarEncoding = ScalarEncoding.UNSIGNED_BYTE;
431-
ScalarEncoding queryEncoding = ScalarEncoding.UNSIGNED_BYTE;
432427
if (size > 0) {
433428
int wireNumber = input.readVInt();
434429
scalarEncoding =
@@ -437,13 +432,6 @@ static FieldEntry create(
437432
() ->
438433
new IllegalStateException(
439434
"Could not get ScalarEncoding from wire number: " + wireNumber));
440-
int queryWireNumber = input.readVInt();
441-
queryEncoding =
442-
ScalarEncoding.fromWireNumber(queryWireNumber)
443-
.orElseThrow(
444-
() ->
445-
new IllegalStateException(
446-
"Could not get ScalarEncoding from wire number: " + queryWireNumber));
447435
centroid = new float[dimension];
448436
input.readFloats(centroid, 0, dimension);
449437
centroidDP = Float.intBitsToFloat(input.readInt());
@@ -460,7 +448,6 @@ static FieldEntry create(
460448
vectorDataLength,
461449
size,
462450
scalarEncoding,
463-
queryEncoding,
464451
centroid,
465452
centroidDP,
466453
conf);

0 commit comments

Comments
 (0)