From f46920af9d9c80e9515afcd0cbea81aa1f2f88a2 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 17 Jun 2024 11:01:37 -0400 Subject: [PATCH 1/5] Add off-heap scalar quantized scoring --- .../codecs/hnsw/FlatVectorScorerUtil.java | 9 + .../Lucene99ScalarQuantizedVectorsFormat.java | 7 +- .../DefaultVectorizationProvider.java | 7 + .../vectorization/VectorizationProvider.java | 3 + ...gmentScalarQuantizedFlatVectorsScorer.java | 102 +++++++++ ...orySegmentScalarQuantizedVectorScorer.java | 161 ++++++++++++++ ...ntScalarQuantizedVectorScorerSupplier.java | 210 ++++++++++++++++++ .../PanamaVectorUtilSupport.java | 37 +-- .../PanamaVectorizationProvider.java | 5 + 9 files changed, 522 insertions(+), 19 deletions(-) create mode 100644 lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java create mode 100644 lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java create mode 100644 lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java index 808d7b3cc882..46efcfd0599d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -37,4 +37,13 @@ private FlatVectorScorerUtil() {} public static FlatVectorsScorer getLucene99FlatVectorsScorer() { return IMPL.getLucene99FlatVectorsScorer(); } + + + /** + * Returns a FlatVectorsScorer that supports the Lucene99 format. Scorers retrieved through this + * method may be optimized on certain platforms. Otherwise, a DefaultFlatVectorScorer is returned. + */ + public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return IMPL.getLucene99ScalarQuantizedVectorsScorer(); + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 552260894a8d..82b7f6db5f91 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -22,9 +22,11 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.util.VectorUtil; /** * Format supporting vector quantization, storage, and retrieval @@ -68,7 +70,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; final boolean compress; - final Lucene99ScalarQuantizedVectorScorer flatVectorScorer; + final FlatVectorsScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -109,8 +111,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = - new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer(); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index c5193aa23de2..145d615620e0 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -19,6 +19,8 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter; /** Default provider returning scalar implementations. */ final class DefaultVectorizationProvider extends VectorizationProvider { @@ -38,4 +40,9 @@ public VectorUtilSupport getVectorUtilSupport() { public FlatVectorsScorer getLucene99FlatVectorsScorer() { return DefaultFlatVectorScorer.INSTANCE; } + + @Override + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer()); + } } diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index a236c303eb4a..199ac315a494 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -95,6 +95,9 @@ public static VectorizationProvider getInstance() { /** Returns a FlatVectorsScorer that supports the Lucene99 format. */ public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); + /** Returns a FlatVectorsScorer that supports scalar quantized vectors in the Lucene99 format. */ + public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer(); + // *** Lookup mechanism: *** private static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName()); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java new file mode 100644 index 000000000000..9542311bbd17 --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; + +public class Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer implements FlatVectorsScorer { + + private final FlatVectorsScorer delegate; + + public Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer(FlatVectorsScorer delegate) { + this.delegate = delegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean similarity + // So, we delegate to the default scorer + if (quantizedByteVectorValues.getScalarQuantizer().getBits() == 4 + && similarityType == VectorSimilarityFunction.EUCLIDEAN + // Indicates that the vector is compressed as the byte length is not equal to the dimension count + && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension() + ) { + return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + } + } + return delegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityType, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean similarity + // So, we delegate to the default scorer + if (quantizedByteVectorValues.getScalarQuantizer().getBits() == 4 + && similarityType == VectorSimilarityFunction.EUCLIDEAN + // Indicates that the vector is compressed as the byte length is not equal to the dimension count + && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension() + ) { + return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + } + checkDimensions(queryVector.length, vectorValues.dimension()); + ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + byte[] targetBytes = new byte[target.length]; + float offsetCorrection = + quantizeQuery(target, targetBytes, similarityType, scalarQuantizer); + // TODO similarity + } + return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityType, + RandomAccessVectorValues vectorValues, + byte[] queryVector) + throws IOException { + return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + } + + static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException( + "vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } + + @Override + public String toString() { + return "Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer()"; + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..0fdd236057be --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { + + final int vectorByteSize; + final MemorySegmentAccessInput input; + final MemorySegment query; + byte[] scratch; + + /** + * Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is + * returned. + */ + public static Optional create( + VectorSimilarityFunction similarityType, + byte[] targetBytes, + float offsetCorrection, + float constMultiplier, + byte bits, + RandomAccessQuantizedByteVectorValues values + ) { + IndexInput input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + if (!(input instanceof MemorySegmentAccessInput msInput)) { + return Optional.empty(); + } + checkInvariants(values.size(), values.getVectorByteLength(), input); + return switch (type) { + case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector)); + case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector)); + case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector)); + case MAXIMUM_INNER_PRODUCT -> Optional.of( + new MaxInnerProductScorer(msInput, values, queryVector)); + }; + } + + Lucene99MemorySegmentScalarQuantizedVectorScorer( + MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector) { + super(values); + this.input = input; + this.vectorByteLength = values.getVectorByteLength(); + this.trueVectorByteSize = values.getVectorByteLength() - Float.Bytes; + this.query = MemorySegment.ofArray(queryVector); + } + + final MemorySegment getSegment(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * vectorByteLength; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteLength); + if (seg == null) { + if (scratch == null) { + scratch = new byte[trueVectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, trueVectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + DotProductScorer( + MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len + float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + return 0.5f + raw / (float) (query.byteSize() * (1 << 15)); + } + } + + static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + Int4DotProductScorer( + MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len + float raw = PanamaVectorUtilSupport.int4DotProduct(query, getSegment(node)); + return 0.5f + raw / (float) (query.byteSize() * (1 << 15)); + } + } + + static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node)); + return 1 / (1f + raw); + } + } + + static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + MaxInnerProductScorer( + MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + if (raw < 0) { + return 1 / (1 + -1 * raw); + } + return raw + 1; + } + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java new file mode 100644 index 000000000000..544e5f102c33 --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** A score supplier of vectors whose element size is byte. */ +public abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier + implements RandomVectorScorerSupplier { + final int vectorByteSize; + final int maxOrd; + final MemorySegmentAccessInput input; + final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds + byte[] scratch1, scratch2; + + /** + * Return an optional whose value, if present, is the scorer supplier. Otherwise, an empty + * optional is returned. + */ + static Optional create( + VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (!(input instanceof MemorySegmentAccessInput msInput)) { + return Optional.empty(); + } + checkInvariants(values.size(), values.getVectorByteLength(), input); + return switch (type) { + case COSINE -> Optional.of(new CosineSupplier(msInput, values)); + case DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values)); + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values)); + }; + } + + Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier( + MemorySegmentAccessInput input, RandomAccessVectorValues values) { + this.input = input; + this.values = values; + this.vectorByteSize = values.getVectorByteLength(); + this.maxOrd = values.size(); + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + final MemorySegment getFirstSegment(int ord) throws IOException { + long byteOffset = (long) ord * vectorByteSize; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch1 == null) { + scratch1 = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch1, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch1); + } + return seg; + } + + final MemorySegment getSecondSegment(int ord) throws IOException { + long byteOffset = (long) ord * vectorByteSize; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch2 == null) { + scratch2 = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch2, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch2); + } + return seg; + } + + static final class CosineSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.cosine(getFirstSegment(ord), getSecondSegment(node)); + return (1 + raw) / 2; + } + }; + } + + @Override + public CosineSupplier copy() throws IOException { + return new CosineSupplier(input.clone(), values); + } + } + + static final class DotProductSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len + float raw = + PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); + return 0.5f + raw / (float) (values.dimension() * (1 << 15)); + } + }; + } + + @Override + public DotProductSupplier copy() throws IOException { + return new DotProductSupplier(input.clone(), values); + } + } + + static final class EuclideanSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = + PanamaVectorUtilSupport.squareDistance(getFirstSegment(ord), getSecondSegment(node)); + return 1 / (1f + raw); + } + }; + } + + @Override + public EuclideanSupplier copy() throws IOException { + return new EuclideanSupplier(input.clone(), values); + } + } + + static final class MaxInnerProductSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = + PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); + if (raw < 0) { + return 1 / (1 + -1 * raw); + } + return raw + 1; + } + }; + } + + @Override + public MaxInnerProductSupplier copy() throws IOException { + return new MaxInnerProductSupplier(input.clone(), values); + } + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 867d0c684cbe..ec8186f71602 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -401,29 +401,34 @@ private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit @Override public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { + assert (apacked && bpacked) == false; + return int4DotProduct(MemorySegment.ofArray(a), apacked, MemorySegment.ofArray(b), bpacked); + } + + public static int public int int4DotProduct(MemorySegment a, boolean apacked, MemorySegment b, boolean bpacked) { assert (apacked && bpacked) == false; int i = 0; int res = 0; if (apacked || bpacked) { - byte[] packed = apacked ? a : b; - byte[] unpacked = apacked ? b : a; - if (packed.length >= 32) { + MemorySegment packed = apacked ? a : b; + MemorySegment unpacked = apacked ? b : a; + if (packed.byteSize() >= 32) { if (VECTOR_BITSIZE >= 512) { - i += ByteVector.SPECIES_256.loopBound(packed.length); + i += ByteVector.SPECIES_256.loopBound(packed.byteSize()); res += dotProductBody512Int4Packed(unpacked, packed, i); } else if (VECTOR_BITSIZE == 256) { - i += ByteVector.SPECIES_128.loopBound(packed.length); + i += ByteVector.SPECIES_128.loopBound(packed.byteSize()); res += dotProductBody256Int4Packed(unpacked, packed, i); } else if (HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_64.loopBound(packed.length); + i += ByteVector.SPECIES_64.loopBound(packed.byteSize()); res += dotProductBody128Int4Packed(unpacked, packed, i); } } // scalar tail - for (; i < packed.length; i++) { - byte packedByte = packed[i]; - byte unpacked1 = unpacked[i]; - byte unpacked2 = unpacked[i + packed.length]; + for (; i < packed.byteSize(); i++) { + byte packedByte = packed.get(JAVA_BYTE, i); + byte unpacked1 = unpacked.get(JAVA_BYTE, i); + byte unpacked2 = unpacked.get(JAVA_BYTE, i + packed.getByte()); res += (packedByte & 0x0F) * unpacked2; res += ((packedByte & 0xFF) >> 4) * unpacked1; } @@ -435,15 +440,15 @@ public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) res += int4DotProductBody128(a, b, i); } // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + for (; i < a.byteSize(); i++) { + res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i); } } return res; } - private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody512Int4Packed(MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 4096) { @@ -476,7 +481,7 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } - private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody256Int4Packed(MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 2048) { @@ -510,7 +515,7 @@ private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limi } /** vectorized dot product body (128 bit vectors) */ - private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody128Int4Packed(MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 1024) { @@ -545,7 +550,7 @@ private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } - private int int4DotProductBody128(byte[] a, byte[] b, int limit) { + private static int int4DotProductBody128(MemorySegment a, MemorySegment b, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 1024) { diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index 87f7cf2baf76..e56c2abaf2c0 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -79,4 +79,9 @@ public VectorUtilSupport getVectorUtilSupport() { public FlatVectorsScorer getLucene99FlatVectorsScorer() { return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE; } + + @Override + public FlatVectorsScorer getLucene99FlatVectorsScorer() { + return new Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer(new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer())); + } } From 32ce6022137782cb0eb9d67e628626e761d797c2 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 17 Jun 2024 11:19:55 -0400 Subject: [PATCH 2/5] iter --- ...orySegmentScalarQuantizedVectorScorer.java | 54 +++++++++++-------- .../PanamaVectorUtilSupport.java | 24 ++++----- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java index 0fdd236057be..ef1c96e4d75c 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -32,6 +32,7 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer final int vectorByteSize; final MemorySegmentAccessInput input; final MemorySegment query; + final float constMultiplier; byte[] scratch; /** @@ -56,16 +57,16 @@ public static Optional create( } checkInvariants(values.size(), values.getVectorByteLength(), input); return switch (type) { - case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector)); - case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector)); - case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector)); + case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector, constMultiplier, offsetCorrection)); + case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection)); + case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector, constMultiplier)); case MAXIMUM_INNER_PRODUCT -> Optional.of( - new MaxInnerProductScorer(msInput, values, queryVector)); + new MaxInnerProductScorer(msInput, values, queryVector, offsetCorrection)); }; } Lucene99MemorySegmentScalarQuantizedVectorScorer( - MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector) { + MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector, float constMultiplier) { super(values); this.input = input; this.vectorByteLength = values.getVectorByteLength(); @@ -101,50 +102,57 @@ final void checkOrdinal(int ord) { static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { DotProductScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { - super(input, values, query); + MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) { + super(input, values, query, constMultiplier); } @Override public float score(int node) throws IOException { checkOrdinal(node); - // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); - return 0.5f + raw / (float) (query.byteSize() * (1 << 15)); + float vectorOffset = values.getScoreCorrectionConstant(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return Math.max((1 + adjustedDistance) / 2, 0); } } static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { Int4DotProductScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { - super(input, values, query); + MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) { + super(input, values, query, constMultiplier); } @Override public float score(int node) throws IOException { checkOrdinal(node); - // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len - float raw = PanamaVectorUtilSupport.int4DotProduct(query, getSegment(node)); - return 0.5f + raw / (float) (query.byteSize() * (1 << 15)); + float raw = PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), false); + float vectorOffset = values.getScoreCorrectionConstant(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return Math.max((1 + adjustedDistance) / 2, 0); } } static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { - EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { - super(input, values, query); + EuclideanScorer(MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier) { + super(input, values, query, constMultiplier); } @Override public float score(int node) throws IOException { checkOrdinal(node); float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node)); - return 1 / (1f + raw); + float adjustedDistance = raw * constMultiplier; + return 1 / (1f + adjustedDistance); } } static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { MaxInnerProductScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) { super(input, values, query); } @@ -152,10 +160,14 @@ static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuan public float score(int node) throws IOException { checkOrdinal(node); float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); - if (raw < 0) { - return 1 / (1 + -1 * raw); + float vectorOffset = values.getScoreCorrectionConstant(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); } - return raw + 1; + return adjustedDistance + 1; } } } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index ec8186f71602..b6ac4892ec5a 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -457,9 +457,9 @@ private static int dotProductBody512Int4Packed(MemorySegment unpacked, MemorySeg int innerLimit = Math.min(limit - i, 4096); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j); // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j + packed.length); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -467,7 +467,7 @@ private static int dotProductBody512Int4Packed(MemorySegment unpacked, MemorySeg acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); acc1 = acc1.add(prod16a); @@ -490,9 +490,9 @@ private static int dotProductBody256Int4Packed(MemorySegment unpacked, MemorySeg int innerLimit = Math.min(limit - i, 2048); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { // packed - var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); + var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j); // unpacked - var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j + packed.length); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -500,7 +500,7 @@ private static int dotProductBody256Int4Packed(MemorySegment unpacked, MemorySeg acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); acc1 = acc1.add(prod16a); @@ -524,7 +524,7 @@ private static int dotProductBody128Int4Packed(MemorySegment unpacked, MemorySeg int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { // packed - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); + ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j); // unpacked ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); @@ -536,7 +536,7 @@ private static int dotProductBody128Int4Packed(MemorySegment unpacked, MemorySeg acc0 = acc0.add(prod16.and((short) 0xFF)); // lower - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j); prod8 = vb8.lanewise(LSHR, 4).mul(va8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); @@ -558,15 +558,15 @@ private static int int4DotProductBody128(MemorySegment a, MemorySegment b, int l ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j); + ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j); + ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j); ByteVector prod8 = va8.mul(vb8); ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc0 = acc0.add(prod16.and((short) 0xFF)); - va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8); - vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8); + vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8); prod8 = va8.mul(vb8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); From 86dac9baaa2cf7d75e3bf0c93e7e27545968382c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:18:36 -0400 Subject: [PATCH 3/5] iter --- .../codecs/hnsw/FlatVectorScorerUtil.java | 1 - .../Lucene99ScalarQuantizedVectorsFormat.java | 2 - .../DefaultVectorizationProvider.java | 1 - ...gmentScalarQuantizedFlatVectorsScorer.java | 61 +++-- ...orySegmentScalarQuantizedVectorScorer.java | 152 ++++++++--- ...ntScalarQuantizedVectorScorerSupplier.java | 246 ++++++++++++++---- .../PanamaVectorUtilSupport.java | 56 ++-- .../PanamaVectorizationProvider.java | 6 +- 8 files changed, 400 insertions(+), 125 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java index 46efcfd0599d..7a92c077bbc7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -38,7 +38,6 @@ public static FlatVectorsScorer getLucene99FlatVectorsScorer() { return IMPL.getLucene99FlatVectorsScorer(); } - /** * Returns a FlatVectorsScorer that supports the Lucene99 format. Scorers retrieved through this * method may be optimized on certain platforms. Otherwise, a DefaultFlatVectorScorer is returned. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 82b7f6db5f91..3533080c963a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -18,7 +18,6 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; @@ -26,7 +25,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.util.VectorUtil; /** * Format supporting vector quantization, storage, and retrieval diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index 145d615620e0..84f464627aca 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -20,7 +20,6 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; -import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter; /** Default provider returning scalar implementations. */ final class DefaultVectorizationProvider extends VectorizationProvider { diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java index 9542311bbd17..cceee059d96f 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java @@ -19,7 +19,6 @@ import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; import java.io.IOException; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; @@ -40,43 +39,65 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues) throws IOException { if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { - // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean similarity + // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean + // similarity // So, we delegate to the default scorer if (quantizedByteVectorValues.getScalarQuantizer().getBits() == 4 - && similarityType == VectorSimilarityFunction.EUCLIDEAN - // Indicates that the vector is compressed as the byte length is not equal to the dimension count - && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension() - ) { - return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + && similarityType == VectorSimilarityFunction.EUCLIDEAN + // Indicates that the vector is compressed as the byte length is not equal to the + // dimension count + && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension()) { + return delegate.getRandomVectorScorerSupplier(similarityType, vectorValues); + } + var scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + var scalarScorerSupplier = + Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.create( + similarityType, + scalarQuantizer.getBits(), + scalarQuantizer.getConstantMultiplier(), + quantizedByteVectorValues); + if (scalarScorerSupplier.isPresent()) { + return scalarScorerSupplier.get(); } } - return delegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + return delegate.getRandomVectorScorerSupplier(similarityType, vectorValues); } @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues, - float[] target) + float[] queryVector) throws IOException { if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { - // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean similarity + // Unoptimized edge case, we don't optimize compressed 4-bit quantization with Euclidean + // similarity // So, we delegate to the default scorer if (quantizedByteVectorValues.getScalarQuantizer().getBits() == 4 && similarityType == VectorSimilarityFunction.EUCLIDEAN - // Indicates that the vector is compressed as the byte length is not equal to the dimension count - && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension() - ) { - return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + // Indicates that the vector is compressed as the byte length is not equal to the + // dimension count + && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension()) { + return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); } checkDimensions(queryVector.length, vectorValues.dimension()); - ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); - byte[] targetBytes = new byte[target.length]; + var scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + byte[] targetBytes = new byte[queryVector.length]; float offsetCorrection = - quantizeQuery(target, targetBytes, similarityType, scalarQuantizer); - // TODO similarity + quantizeQuery(queryVector, targetBytes, similarityType, scalarQuantizer); + var scalarScorer = + Lucene99MemorySegmentScalarQuantizedVectorScorer.create( + similarityType, + targetBytes, + offsetCorrection, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits(), + quantizedByteVectorValues); + if (scalarScorer.isPresent()) { + return scalarScorer.get(); + } } - return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); } @Override @@ -85,7 +106,7 @@ public RandomVectorScorer getRandomVectorScorer( RandomAccessVectorValues vectorValues, byte[] queryVector) throws IOException { - return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); } static void checkDimensions(int queryLen, int fieldLen) { diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java index ef1c96e4d75c..421510ff4a38 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -23,13 +23,13 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - final int vectorByteSize; + final int vectorByteLength, trueVectorByteSize; final MemorySegmentAccessInput input; final MemorySegment query; final float constMultiplier; @@ -40,13 +40,12 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer * returned. */ public static Optional create( - VectorSimilarityFunction similarityType, - byte[] targetBytes, - float offsetCorrection, - float constMultiplier, - byte bits, - RandomAccessQuantizedByteVectorValues values - ) { + VectorSimilarityFunction similarityType, + byte[] targetBytes, + float offsetCorrection, + float constMultiplier, + byte bits, + RandomAccessQuantizedByteVectorValues values) { IndexInput input = values.getSlice(); if (input == null) { return Optional.empty(); @@ -56,22 +55,47 @@ public static Optional create( return Optional.empty(); } checkInvariants(values.size(), values.getVectorByteLength(), input); - return switch (type) { - case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector, constMultiplier, offsetCorrection)); - case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector, constMultiplier, offsetCorrection)); - case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector, constMultiplier)); - case MAXIMUM_INNER_PRODUCT -> Optional.of( - new MaxInnerProductScorer(msInput, values, queryVector, offsetCorrection)); + final boolean compressed = (values.getVectorByteLength() - Float.BYTES) != values.dimension(); + if (compressed) { + assert bits == 4; + assert (values.getVectorByteLength() - Float.BYTES) == values.dimension() / 2; + } + return switch (similarityType) { + case COSINE, DOT_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4DotProductScorer( + msInput, values, targetBytes, constMultiplier, offsetCorrection, compressed)); + } + yield Optional.of( + new DotProductScorer(msInput, values, targetBytes, constMultiplier, offsetCorrection)); + } + case EUCLIDEAN -> Optional.of( + new EuclideanScorer(msInput, values, targetBytes, constMultiplier)); + case MAXIMUM_INNER_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4MaxInnerProductScorer( + msInput, values, targetBytes, constMultiplier, offsetCorrection, compressed)); + } + yield Optional.of( + new MaxInnerProductScorer( + msInput, values, targetBytes, constMultiplier, offsetCorrection)); + } }; } Lucene99MemorySegmentScalarQuantizedVectorScorer( - MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] queryVector, float constMultiplier) { + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] queryVector, + float constMultiplier) { super(values); this.input = input; this.vectorByteLength = values.getVectorByteLength(); - this.trueVectorByteSize = values.getVectorByteLength() - Float.Bytes; + this.trueVectorByteSize = values.getVectorByteLength() - Float.BYTES; this.query = MemorySegment.ofArray(queryVector); + this.constMultiplier = constMultiplier; } final MemorySegment getSegment(int ord) throws IOException { @@ -88,6 +112,13 @@ final MemorySegment getSegment(int ord) throws IOException { return seg; } + final float getOffsetCorrection(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = ((long) ord * vectorByteLength) + trueVectorByteSize; + int floatInts = input.readInt(byteOffset); + return Float.intBitsToFloat(floatInts); + } + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { if (input.length() < (long) vectorByteLength * maxOrd) { throw new IllegalArgumentException("input length is less than expected vector data"); @@ -101,16 +132,23 @@ final void checkOrdinal(int ord) { } static final class DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private final float offsetCorrection; + DotProductScorer( - MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) { + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection) { super(input, values, query, constMultiplier); + this.offsetCorrection = offsetCorrection; } @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); - float vectorOffset = values.getScoreCorrectionConstant(node); + float dotProduct = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + float vectorOffset = getOffsetCorrection(node); // For the current implementation of scalar quantization, all dotproducts should be >= 0; assert dotProduct >= 0; float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; @@ -119,16 +157,27 @@ public float score(int node) throws IOException { } static final class Int4DotProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private final boolean compressed; + private final float offsetCorrection; + Int4DotProductScorer( - MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) { + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection, + boolean compressed) { super(input, values, query, constMultiplier); + this.compressed = compressed; + this.offsetCorrection = offsetCorrection; } @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), false); - float vectorOffset = values.getScoreCorrectionConstant(node); + float dotProduct = + PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), compressed); + float vectorOffset = getOffsetCorrection(node); // For the current implementation of scalar quantization, all dotproducts should be >= 0; assert dotProduct >= 0; float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; @@ -137,7 +186,11 @@ public float score(int node) throws IOException { } static final class EuclideanScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { - EuclideanScorer(MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier) { + EuclideanScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier) { super(input, values, query, constMultiplier); } @@ -150,17 +203,58 @@ public float score(int node) throws IOException { } } - static final class MaxInnerProductScorer extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + static final class MaxInnerProductScorer + extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private final float offsetCorrection; + MaxInnerProductScorer( - MemorySegmentAccessInput input, RandomAccessQuantizedByteVectorValues values, byte[] query, float constMultiplier, float offsetCorrection) { - super(input, values, query); + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection) { + super(input, values, query, constMultiplier); + this.offsetCorrection = offsetCorrection; + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + float vectorOffset = getOffsetCorrection(node); + // For the current implementation of scalar quantization, all dotproducts should be >= 0; + assert dotProduct >= 0; + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); + } + return adjustedDistance + 1; + } + } + + static final class Int4MaxInnerProductScorer + extends Lucene99MemorySegmentScalarQuantizedVectorScorer { + private final boolean compressed; + private final float offsetCorrection; + + Int4MaxInnerProductScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] query, + float constMultiplier, + float offsetCorrection, + boolean compressed) { + super(input, values, query, constMultiplier); + this.compressed = compressed; + this.offsetCorrection = offsetCorrection; } @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); - float vectorOffset = values.getScoreCorrectionConstant(node); + float dotProduct = + PanamaVectorUtilSupport.int4DotProduct(query, false, getSegment(node), compressed); + float vectorOffset = getOffsetCorrection(node); // For the current implementation of scalar quantization, all dotproducts should be >= 0; assert dotProduct >= 0; float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java index 544e5f102c33..acd40a46acd4 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java @@ -23,44 +23,80 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; /** A score supplier of vectors whose element size is byte. */ public abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier implements RandomVectorScorerSupplier { - final int vectorByteSize; + + final int vectorByteSize, trueVectorByteSize; final int maxOrd; final MemorySegmentAccessInput input; - final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds + final RandomAccessQuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds byte[] scratch1, scratch2; + final int scratch1Size; + final float constMultiplier; /** * Return an optional whose value, if present, is the scorer supplier. Otherwise, an empty * optional is returned. */ static Optional create( - VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) { + VectorSimilarityFunction type, + byte bits, + float constMultiplier, + RandomAccessQuantizedByteVectorValues values) { + IndexInput input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } input = FilterIndexInput.unwrapOnlyTest(input); if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); } + final boolean compressed = (values.getVectorByteLength() - Float.BYTES) != values.dimension(); + if (compressed) { + assert bits == 4; + assert (values.getVectorByteLength() - Float.BYTES) == values.dimension() / 2; + } checkInvariants(values.size(), values.getVectorByteLength(), input); return switch (type) { - case COSINE -> Optional.of(new CosineSupplier(msInput, values)); - case DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values)); - case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values)); - case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values)); + case COSINE, DOT_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4DotProductSupplier(msInput, values, constMultiplier, compressed)); + } + yield Optional.of(new DotProductSupplier(msInput, values, constMultiplier)); + } + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values, constMultiplier)); + case MAXIMUM_INNER_PRODUCT -> { + if (bits == 4) { + yield Optional.of( + new Int4MaxInnerProductSupplier(msInput, values, constMultiplier, compressed)); + } + yield Optional.of(new MaxInnerProductSupplier(msInput, values, constMultiplier)); + } }; } Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier( - MemorySegmentAccessInput input, RandomAccessVectorValues values) { + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { this.input = input; this.values = values; this.vectorByteSize = values.getVectorByteLength(); + this.trueVectorByteSize = (values.getVectorByteLength() - Float.BYTES); + if (values.dimension() != trueVectorByteSize) { + assert values.dimension() == trueVectorByteSize / 2; + this.scratch1Size = trueVectorByteSize; + } else { + this.scratch1Size = trueVectorByteSize / 2; + } this.maxOrd = values.size(); + this.constMultiplier = constMultiplier; } static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { @@ -69,6 +105,20 @@ static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) } } + static void decompressBytes(byte[] compressed, int numBytes) { + if (numBytes == compressed.length) { + return; + } + if (numBytes << 1 != compressed.length) { + throw new IllegalArgumentException( + "numBytes: " + numBytes + " does not match compressed length: " + compressed.length); + } + for (int i = 0; i < numBytes; ++i) { + compressed[numBytes + i] = (byte) (compressed[i] & 0x0F); + compressed[i] = (byte) ((compressed[i] & 0xFF) >> 4); + } + } + final void checkOrdinal(int ord) { if (ord < 0 || ord >= maxOrd) { throw new IllegalArgumentException("illegal ordinal: " + ord); @@ -78,11 +128,16 @@ final void checkOrdinal(int ord) { final MemorySegment getFirstSegment(int ord) throws IOException { long byteOffset = (long) ord * vectorByteSize; MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); - if (seg == null) { + // we always read and decompress the full vector if the value is compressed + // Generally, this is OK, as the scorer is used many times after the initial decompression + if (seg == null || values.dimension() != trueVectorByteSize) { if (scratch1 == null) { - scratch1 = new byte[vectorByteSize]; + scratch1 = new byte[this.scratch1Size]; + } + input.readBytes(byteOffset, scratch1, 0, trueVectorByteSize); + if (values.dimension() != trueVectorByteSize) { + decompressBytes(scratch1, scratch1.length); } - input.readBytes(byteOffset, scratch1, 0, vectorByteSize); seg = MemorySegment.ofArray(scratch1); } return seg; @@ -93,118 +148,209 @@ final MemorySegment getSecondSegment(int ord) throws IOException { MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); if (seg == null) { if (scratch2 == null) { - scratch2 = new byte[vectorByteSize]; + scratch2 = new byte[values.dimension()]; } - input.readBytes(byteOffset, scratch2, 0, vectorByteSize); + input.readBytes(byteOffset, scratch2, 0, trueVectorByteSize); seg = MemorySegment.ofArray(scratch2); } return seg; } - static final class CosineSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + final float getOffsetCorrection(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = ((long) ord * vectorByteSize) + trueVectorByteSize; + int floatInts = input.readInt(byteOffset); + return Float.intBitsToFloat(floatInts); + } + + static final class DotProductSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { - CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { - super(input, values); + DotProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { + super(input, values, constMultiplier); } @Override - public RandomVectorScorer scorer(int ord) { + public RandomVectorScorer scorer(int ord) throws IOException { checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); return new RandomVectorScorer.AbstractRandomVectorScorer(values) { @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.cosine(getFirstSegment(ord), getSecondSegment(node)); - return (1 + raw) / 2; + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = PanamaVectorUtilSupport.dotProduct(querySegment, nodeSegment); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + return Math.max((1 + adjustedDistance) / 2, 0); } }; } @Override - public CosineSupplier copy() throws IOException { - return new CosineSupplier(input.clone(), values); + public DotProductSupplier copy() throws IOException { + return new DotProductSupplier(input.clone(), values, constMultiplier); } } - static final class DotProductSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + static final class Int4DotProductSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + private final boolean compressed; - DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { - super(input, values); + Int4DotProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + boolean compressed) { + super(input, values, constMultiplier); + this.compressed = compressed; } @Override - public RandomVectorScorer scorer(int ord) { + public RandomVectorScorer scorer(int ord) throws IOException { checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); return new RandomVectorScorer.AbstractRandomVectorScorer(values) { @Override public float score(int node) throws IOException { checkOrdinal(node); - // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len - float raw = - PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); - return 0.5f + raw / (float) (values.dimension() * (1 << 15)); + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = + PanamaVectorUtilSupport.int4DotProduct(querySegment, false, nodeSegment, compressed); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + return Math.max((1 + adjustedDistance) / 2, 0); } }; } @Override - public DotProductSupplier copy() throws IOException { - return new DotProductSupplier(input.clone(), values); + public Int4DotProductSupplier copy() throws IOException { + return new Int4DotProductSupplier(input.clone(), values, constMultiplier, compressed); } } - static final class EuclideanSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + static final class EuclideanSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { - EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { - super(input, values); + EuclideanSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { + super(input, values, constMultiplier); } @Override - public RandomVectorScorer scorer(int ord) { + public RandomVectorScorer scorer(int ord) throws IOException { checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); return new RandomVectorScorer.AbstractRandomVectorScorer(values) { @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = - PanamaVectorUtilSupport.squareDistance(getFirstSegment(ord), getSecondSegment(node)); - return 1 / (1f + raw); + float raw = PanamaVectorUtilSupport.squareDistance(querySegment, getSecondSegment(node)); + float adjustedDistance = raw * constMultiplier; + return 1 / (1f + adjustedDistance); } }; } @Override public EuclideanSupplier copy() throws IOException { - return new EuclideanSupplier(input.clone(), values); + return new EuclideanSupplier(input.clone(), values, constMultiplier); } } - static final class MaxInnerProductSupplier extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + static final class MaxInnerProductSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { - MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { - super(input, values); + MaxInnerProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier) { + super(input, values, constMultiplier); } @Override - public RandomVectorScorer scorer(int ord) { + public RandomVectorScorer scorer(int ord) throws IOException { checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); return new RandomVectorScorer.AbstractRandomVectorScorer(values) { @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = - PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); - if (raw < 0) { - return 1 / (1 + -1 * raw); + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = PanamaVectorUtilSupport.dotProduct(querySegment, nodeSegment); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); } - return raw + 1; + return adjustedDistance + 1; } }; } @Override public MaxInnerProductSupplier copy() throws IOException { - return new MaxInnerProductSupplier(input.clone(), values); + return new MaxInnerProductSupplier(input.clone(), values, constMultiplier); + } + } + + static final class Int4MaxInnerProductSupplier + extends Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier { + + private final boolean compressed; + + Int4MaxInnerProductSupplier( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + boolean compressed) { + super(input, values, constMultiplier); + this.compressed = compressed; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + checkOrdinal(ord); + MemorySegment querySegment = getFirstSegment(ord); + float offsetCorrection = getOffsetCorrection(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + MemorySegment nodeSegment = getSecondSegment(node); + float dotProduct = + PanamaVectorUtilSupport.int4DotProduct(querySegment, false, nodeSegment, compressed); + float nodeOffsetCorrection = getOffsetCorrection(node); + assert dotProduct >= 0; + float adjustedDistance = + dotProduct * constMultiplier + offsetCorrection + nodeOffsetCorrection; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); + } + return adjustedDistance + 1; + } + }; + } + + @Override + public Int4MaxInnerProductSupplier copy() throws IOException { + return new Int4MaxInnerProductSupplier(input.clone(), values, constMultiplier, compressed); } } } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index b6ac4892ec5a..f4e410cb7b29 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -405,7 +405,8 @@ public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) return int4DotProduct(MemorySegment.ofArray(a), apacked, MemorySegment.ofArray(b), bpacked); } - public static int public int int4DotProduct(MemorySegment a, boolean apacked, MemorySegment b, boolean bpacked) { + public static int int4DotProduct( + MemorySegment a, boolean apacked, MemorySegment b, boolean bpacked) { assert (apacked && bpacked) == false; int i = 0; int res = 0; @@ -428,15 +429,15 @@ public static int public int int4DotProduct(MemorySegment a, boolean apacked, Me for (; i < packed.byteSize(); i++) { byte packedByte = packed.get(JAVA_BYTE, i); byte unpacked1 = unpacked.get(JAVA_BYTE, i); - byte unpacked2 = unpacked.get(JAVA_BYTE, i + packed.getByte()); + byte unpacked2 = unpacked.get(JAVA_BYTE, i + packed.byteSize()); res += (packedByte & 0x0F) * unpacked2; res += ((packedByte & 0xFF) >> 4) * unpacked1; } } else { if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { return dotProduct(a, b); - } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_128.loopBound(a.length); + } else if (a.byteSize() >= 32 && HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.byteSize()); res += int4DotProductBody128(a, b, i); } // scalar tail @@ -448,7 +449,8 @@ public static int public int int4DotProduct(MemorySegment a, boolean apacked, Me return res; } - private static int dotProductBody512Int4Packed(MemorySegment unpacked, MemorySegment packed, int limit) { + private static int dotProductBody512Int4Packed( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 4096) { @@ -457,9 +459,12 @@ private static int dotProductBody512Int4Packed(MemorySegment unpacked, MemorySeg int innerLimit = Math.min(limit - i, 4096); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { // packed - var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j); + var vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_256, packed, i + j, LITTLE_ENDIAN); // unpacked - var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + var va8 = + ByteVector.fromMemorySegment( + ByteVector.SPECIES_256, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -467,7 +472,8 @@ private static int dotProductBody512Int4Packed(MemorySegment unpacked, MemorySeg acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector vc8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_256, unpacked, i + j, LITTLE_ENDIAN); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); acc1 = acc1.add(prod16a); @@ -481,7 +487,8 @@ private static int dotProductBody512Int4Packed(MemorySegment unpacked, MemorySeg return sum; } - private static int dotProductBody256Int4Packed(MemorySegment unpacked, MemorySegment packed, int limit) { + private static int dotProductBody256Int4Packed( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 2048) { @@ -490,9 +497,12 @@ private static int dotProductBody256Int4Packed(MemorySegment unpacked, MemorySeg int innerLimit = Math.min(limit - i, 2048); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { // packed - var vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j); + var vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_128, packed, i + j, LITTLE_ENDIAN); // unpacked - var va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + var va8 = + ByteVector.fromMemorySegment( + ByteVector.SPECIES_128, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -500,7 +510,8 @@ private static int dotProductBody256Int4Packed(MemorySegment unpacked, MemorySeg acc0 = acc0.add(prod16); // lower - ByteVector vc8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector vc8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_128, unpacked, i + j, LITTLE_ENDIAN); ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); acc1 = acc1.add(prod16a); @@ -515,7 +526,8 @@ private static int dotProductBody256Int4Packed(MemorySegment unpacked, MemorySeg } /** vectorized dot product body (128 bit vectors) */ - private static int dotProductBody128Int4Packed(MemorySegment unpacked, MemorySegment packed, int limit) { + private static int dotProductBody128Int4Packed( + MemorySegment unpacked, MemorySegment packed, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 1024) { @@ -524,10 +536,12 @@ private static int dotProductBody128Int4Packed(MemorySegment unpacked, MemorySeg int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { // packed - ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j); + ByteVector vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, packed, i + j, LITTLE_ENDIAN); // unpacked ByteVector va8 = - ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + ByteVector.fromMemorySegment( + ByteVector.SPECIES_64, unpacked, i + j + packed.byteSize(), LITTLE_ENDIAN); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); @@ -536,7 +550,7 @@ private static int dotProductBody128Int4Packed(MemorySegment unpacked, MemorySeg acc0 = acc0.add(prod16.and((short) 0xFF)); // lower - va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, unpacked, i + j, LITTLE_ENDIAN); prod8 = vb8.lanewise(LSHR, 4).mul(va8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); @@ -558,15 +572,17 @@ private static int int4DotProductBody128(MemorySegment a, MemorySegment b, int l ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); int innerLimit = Math.min(limit - i, 1024); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j); - ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j); + ByteVector va8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j, LITTLE_ENDIAN); + ByteVector vb8 = + ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j, LITTLE_ENDIAN); ByteVector prod8 = va8.mul(vb8); ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc0 = acc0.add(prod16.and((short) 0xFF)); - va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8); - vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8); + va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i + j + 8, LITTLE_ENDIAN); + vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i + j + 8, LITTLE_ENDIAN); prod8 = va8.mul(vb8); prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); acc1 = acc1.add(prod16.and((short) 0xFF)); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index e56c2abaf2c0..2201e5f89336 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -22,6 +22,7 @@ import java.util.logging.Logger; import jdk.incubator.vector.FloatVector; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.util.Constants; import org.apache.lucene.util.SuppressForbidden; @@ -81,7 +82,8 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { } @Override - public FlatVectorsScorer getLucene99FlatVectorsScorer() { - return new Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer(new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer())); + public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { + return new Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer( + new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer())); } } From f2c3f853ff0396963e06cd3ccf0dc19e26037a5f Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 17 Jun 2024 18:06:00 -0400 Subject: [PATCH 4/5] iter --- ...gmentScalarQuantizedFlatVectorsScorer.java | 4 +-- ...orySegmentScalarQuantizedVectorScorer.java | 20 +++++------ ...ntScalarQuantizedVectorScorerSupplier.java | 35 ++++++++----------- 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java index cceee059d96f..89a7e96674e0 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer.java @@ -46,7 +46,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( && similarityType == VectorSimilarityFunction.EUCLIDEAN // Indicates that the vector is compressed as the byte length is not equal to the // dimension count - && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension()) { + && vectorValues.getVectorByteLength() != vectorValues.dimension()) { return delegate.getRandomVectorScorerSupplier(similarityType, vectorValues); } var scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); @@ -77,7 +77,7 @@ public RandomVectorScorer getRandomVectorScorer( && similarityType == VectorSimilarityFunction.EUCLIDEAN // Indicates that the vector is compressed as the byte length is not equal to the // dimension count - && (vectorValues.getVectorByteLength() - Float.BYTES) != vectorValues.dimension()) { + && vectorValues.getVectorByteLength() != vectorValues.dimension()) { return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); } checkDimensions(queryVector.length, vectorValues.dimension()); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java index 421510ff4a38..c6ef7f6cea83 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorer.java @@ -29,7 +29,7 @@ abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - final int vectorByteLength, trueVectorByteSize; + final int vectorByteSize, vectorByteOffset; final MemorySegmentAccessInput input; final MemorySegment query; final float constMultiplier; @@ -55,10 +55,10 @@ public static Optional create( return Optional.empty(); } checkInvariants(values.size(), values.getVectorByteLength(), input); - final boolean compressed = (values.getVectorByteLength() - Float.BYTES) != values.dimension(); + final boolean compressed = values.getVectorByteLength() != values.dimension(); if (compressed) { assert bits == 4; - assert (values.getVectorByteLength() - Float.BYTES) == values.dimension() / 2; + assert values.getVectorByteLength() == values.dimension() / 2; } return switch (similarityType) { case COSINE, DOT_PRODUCT -> { @@ -92,21 +92,21 @@ public static Optional create( float constMultiplier) { super(values); this.input = input; - this.vectorByteLength = values.getVectorByteLength(); - this.trueVectorByteSize = values.getVectorByteLength() - Float.BYTES; + this.vectorByteSize = values.getVectorByteLength(); + this.vectorByteOffset = values.getVectorByteLength() + Float.BYTES; this.query = MemorySegment.ofArray(queryVector); this.constMultiplier = constMultiplier; } final MemorySegment getSegment(int ord) throws IOException { checkOrdinal(ord); - long byteOffset = (long) ord * vectorByteLength; - MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteLength); + long byteOffset = (long) ord * vectorByteOffset; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); if (seg == null) { if (scratch == null) { - scratch = new byte[trueVectorByteSize]; + scratch = new byte[vectorByteSize]; } - input.readBytes(byteOffset, scratch, 0, trueVectorByteSize); + input.readBytes(byteOffset, scratch, 0, vectorByteSize); seg = MemorySegment.ofArray(scratch); } return seg; @@ -114,7 +114,7 @@ final MemorySegment getSegment(int ord) throws IOException { final float getOffsetCorrection(int ord) throws IOException { checkOrdinal(ord); - long byteOffset = ((long) ord * vectorByteLength) + trueVectorByteSize; + long byteOffset = ((long) ord * vectorByteOffset) + vectorByteSize; int floatInts = input.readInt(byteOffset); return Float.intBitsToFloat(floatInts); } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java index acd40a46acd4..787bfcfbf316 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier.java @@ -31,12 +31,11 @@ public abstract sealed class Lucene99MemorySegmentScalarQuantizedVectorScorerSupplier implements RandomVectorScorerSupplier { - final int vectorByteSize, trueVectorByteSize; + final int vectorByteSize, vectorByteOffset; final int maxOrd; final MemorySegmentAccessInput input; final RandomAccessQuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds byte[] scratch1, scratch2; - final int scratch1Size; final float constMultiplier; /** @@ -56,10 +55,10 @@ static Optional create( if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); } - final boolean compressed = (values.getVectorByteLength() - Float.BYTES) != values.dimension(); + final boolean compressed = values.getVectorByteLength() != values.dimension(); if (compressed) { assert bits == 4; - assert (values.getVectorByteLength() - Float.BYTES) == values.dimension() / 2; + assert values.getVectorByteLength() == values.dimension() / 2; } checkInvariants(values.size(), values.getVectorByteLength(), input); return switch (type) { @@ -88,13 +87,7 @@ static Optional create( this.input = input; this.values = values; this.vectorByteSize = values.getVectorByteLength(); - this.trueVectorByteSize = (values.getVectorByteLength() - Float.BYTES); - if (values.dimension() != trueVectorByteSize) { - assert values.dimension() == trueVectorByteSize / 2; - this.scratch1Size = trueVectorByteSize; - } else { - this.scratch1Size = trueVectorByteSize / 2; - } + this.vectorByteOffset = values.getVectorByteLength() + Float.BYTES; this.maxOrd = values.size(); this.constMultiplier = constMultiplier; } @@ -126,17 +119,17 @@ final void checkOrdinal(int ord) { } final MemorySegment getFirstSegment(int ord) throws IOException { - long byteOffset = (long) ord * vectorByteSize; + long byteOffset = (long) ord * vectorByteOffset; MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); // we always read and decompress the full vector if the value is compressed // Generally, this is OK, as the scorer is used many times after the initial decompression - if (seg == null || values.dimension() != trueVectorByteSize) { + if (seg == null || values.dimension() != vectorByteSize) { if (scratch1 == null) { - scratch1 = new byte[this.scratch1Size]; + scratch1 = new byte[values.dimension()]; } - input.readBytes(byteOffset, scratch1, 0, trueVectorByteSize); - if (values.dimension() != trueVectorByteSize) { - decompressBytes(scratch1, scratch1.length); + input.readBytes(byteOffset, scratch1, 0, vectorByteSize); + if (values.dimension() != vectorByteSize) { + decompressBytes(scratch1, vectorByteSize); } seg = MemorySegment.ofArray(scratch1); } @@ -144,13 +137,13 @@ final MemorySegment getFirstSegment(int ord) throws IOException { } final MemorySegment getSecondSegment(int ord) throws IOException { - long byteOffset = (long) ord * vectorByteSize; + long byteOffset = (long) ord * vectorByteOffset; MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); if (seg == null) { if (scratch2 == null) { - scratch2 = new byte[values.dimension()]; + scratch2 = new byte[vectorByteSize]; } - input.readBytes(byteOffset, scratch2, 0, trueVectorByteSize); + input.readBytes(byteOffset, scratch2, 0, vectorByteSize); seg = MemorySegment.ofArray(scratch2); } return seg; @@ -158,7 +151,7 @@ final MemorySegment getSecondSegment(int ord) throws IOException { final float getOffsetCorrection(int ord) throws IOException { checkOrdinal(ord); - long byteOffset = ((long) ord * vectorByteSize) + trueVectorByteSize; + long byteOffset = ((long) ord * vectorByteOffset) + vectorByteSize; int floatInts = input.readInt(byteOffset); return Float.intBitsToFloat(floatInts); } From fb73b2bd90efc8d7dfdb8d21daa47a615fc57b22 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:58:24 -0400 Subject: [PATCH 5/5] iter --- .../internal/vectorization/PanamaVectorizationProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index 022c985aab39..d29a75888e9e 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -90,7 +90,7 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() { @Override public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { return new Lucene99MemorySegmentScalarQuantizedFlatVectorsScorer( - new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer())); + new Lucene99ScalarQuantizedVectorScorer(getLucene99FlatVectorsScorer())); } @Override