Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
Expand Down Expand Up @@ -242,7 +243,7 @@ public ByteVectorValues getByteVectorValues(String field) {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
Expand All @@ -260,7 +261,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
vectorValues,
fieldEntry.similarityFunction,
getGraphValues(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry),
getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null, fieldEntry),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set the contract of this method so that acceptDocs is not allowed to be null? This would help save all these (annoying) null checks, and it wouldn't be a great burden on the caller side?

knnCollector.visitLimit(),
random);
knnCollector.incVisitedCount(results.visitedCount());
Expand All @@ -273,7 +274,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
Expand Down Expand Up @@ -238,7 +239,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
Expand All @@ -253,11 +254,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry));
getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null, fieldEntry));
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
Expand Down Expand Up @@ -236,7 +236,7 @@ public ByteVectorValues getByteVectorValues(String field) {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
Expand All @@ -251,11 +251,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null));
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
Expand Down Expand Up @@ -270,7 +270,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
Expand All @@ -285,11 +285,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null));
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
Expand All @@ -304,7 +304,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null));
}

private HnswGraph getGraph(FieldEntry entry) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
Expand Down Expand Up @@ -290,7 +290,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
Expand All @@ -314,11 +314,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null));
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
Expand All @@ -342,7 +342,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null));
}

/** Get knn graph values; used for testing */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;
Expand Down Expand Up @@ -181,7 +181,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
FloatVectorValues values = getFloatVectorValues(field);
if (target.length != values.dimension()) {
Expand All @@ -195,7 +195,9 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
for (int ord = 0; ord < values.size(); ord++) {
int doc = values.ordToDoc(ord);
if (acceptDocs != null && acceptDocs.get(doc) == false) {
if (acceptDocs != null
&& acceptDocs.getBits() != null
&& acceptDocs.getBits().get(doc) == false) {
continue;
}

Expand All @@ -211,7 +213,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
ByteVectorValues values = getByteVectorValues(field);
if (target.length != values.dimension()) {
Expand All @@ -226,7 +228,9 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits

for (int ord = 0; ord < values.size(); ord++) {
int doc = values.ordToDoc(ord);
if (acceptDocs != null && acceptDocs.get(doc) == false) {
if (acceptDocs != null
&& acceptDocs.getBits() != null
&& acceptDocs.getBits().get(doc) == false) {
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.NamedSPILoader;

/**
Expand Down Expand Up @@ -140,13 +140,13 @@ public ByteVectorValues getByteVectorValues(String field) {

@Override
public void search(
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {
String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
throw new UnsupportedOperationException();
}

@Override
public void search(
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {
String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
Expand Down Expand Up @@ -88,7 +89,8 @@ protected KnnVectorsReader() {}
* if they are all allowed to match.
*/
public abstract void search(
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException;
String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException;

/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
Expand Down Expand Up @@ -116,7 +118,8 @@ public abstract void search(
* if they are all allowed to match.
*/
public abstract void search(
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException;
String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException;

/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

/**
Expand Down Expand Up @@ -56,13 +56,13 @@ public FlatVectorsScorer getFlatVectorScorer() {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
Expand Down Expand Up @@ -223,20 +224,20 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
throws IOException {
if (knnCollector.k() == 0) return;
final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
if (scorer == null) return;
OrdinalTranslatedKnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs != null ? acceptDocs.getBits() : null);
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
collector.collect(i, scorer.score(i));
Expand Down
Loading
Loading