Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* [BUGFIX] [Remote Vector Index Build] Don't fall back to CPU on terminal failures [#2773](https://github.com/opensearch-project/k-NN/pull/2773)
* Add KNN timing info to core profiler [#2785](https://github.com/opensearch-project/k-NN/pull/2785)
* Add "exact_knn" query clause type [#2826](https://github.com/opensearch-project/k-NN/pull/2826)
* Add expandNested support for "exact_knn" query [#2848](https://github.com/opensearch-project/k-NN/pull/2848)

### Bug Fixes
* Fix @ collision in NativeMemoryCacheKeyHelper for vector index filenames containing @ characters [#2810](https://github.com/opensearch-project/k-NN/pull/2810)
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ public ExactKNNByteQuery(
String indexName,
VectorDataType vectorDataType,
BitSetProducer parentFilter,
boolean expandNested,
byte[] byteQueryVector
) {
super(field, spaceType, indexName, vectorDataType, parentFilter);
super(field, spaceType, indexName, vectorDataType, parentFilter, expandNested);
this.byteQueryVector = byteQueryVector;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ public ExactKNNFloatQuery(
String indexName,
VectorDataType vectorDataType,
BitSetProducer parentFilter,
boolean expandNested,
float[] queryVector
) {
super(field, spaceType, indexName, vectorDataType, parentFilter);
super(field, spaceType, indexName, vectorDataType, parentFilter, expandNested);
this.queryVector = queryVector;
}

Expand Down
16 changes: 13 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/ExactKNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,25 @@ public abstract class ExactKNNQuery extends Query {
private final String indexName;
private final VectorDataType vectorDataType;
private BitSetProducer parentFilter;
private boolean expandNested;
@Setter
@Getter
private boolean explain;

protected ExactKNNQuery(String field, String spaceType, String indexName, VectorDataType vectorDataType, BitSetProducer parentFilter) {
protected ExactKNNQuery(
String field,
String spaceType,
String indexName,
VectorDataType vectorDataType,
BitSetProducer parentFilter,
boolean expandNested
) {
this.field = field;
this.spaceType = spaceType;
this.indexName = indexName;
this.vectorDataType = vectorDataType;
this.parentFilter = parentFilter;
this.expandNested = expandNested;
}

@Override
Expand All @@ -60,7 +69,7 @@ public String toString(String field) {

@Override
public int hashCode() {
return Objects.hash(field, spaceType, indexName, vectorDataType, parentFilter);
return Objects.hash(field, spaceType, indexName, vectorDataType, parentFilter, expandNested);
}

@Override
Expand All @@ -73,6 +82,7 @@ public boolean equalsTo(ExactKNNQuery other) {
return Objects.equals(field, other.field)
&& Objects.equals(spaceType, other.spaceType)
&& Objects.equals(indexName, other.indexName)
&& Objects.equals(parentFilter, other.parentFilter);
&& Objects.equals(parentFilter, other.parentFilter)
&& Objects.equals(expandNested, other.expandNested);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Locale;
import java.util.Objects;

import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;

/**
Expand All @@ -45,6 +46,7 @@ public class ExactKNNQueryBuilder extends AbstractQueryBuilder<ExactKNNQueryBuil
public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField SPACE_TYPE_FIELD = new ParseField("space_type");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField EXPAND_NESTED_FIELD = new ParseField(EXPAND_NESTED);

/**
* The name for the knn exact query
Expand All @@ -58,6 +60,8 @@ public class ExactKNNQueryBuilder extends AbstractQueryBuilder<ExactKNNQueryBuil
private String spaceType;
@Getter
private boolean ignoreUnmapped;
@Getter
private Boolean expandNested;

public static class Builder {
private String fieldName;
Expand All @@ -66,6 +70,7 @@ public static class Builder {
private boolean ignoreUnmapped;
private String queryName;
private float boost = DEFAULT_BOOST;
private Boolean expandNested;

public Builder() {}

Expand Down Expand Up @@ -99,9 +104,14 @@ public Builder boost(float boost) {
return this;
}

public Builder expandNested(Boolean expandNested) {
this.expandNested = expandNested;
return this;
}

public ExactKNNQueryBuilder build() {
validate();
return new ExactKNNQueryBuilder(fieldName, vector, spaceType, ignoreUnmapped).boost(boost).queryName(queryName);
return new ExactKNNQueryBuilder(fieldName, vector, spaceType, ignoreUnmapped, expandNested).boost(boost).queryName(queryName);
}

private void validate() {
Expand Down Expand Up @@ -135,6 +145,7 @@ public ExactKNNQueryBuilder(StreamInput in) throws IOException {
vector = builder.vector;
spaceType = builder.spaceType;
ignoreUnmapped = builder.ignoreUnmapped;
expandNested = builder.expandNested;
}

@Override
Expand Down Expand Up @@ -194,11 +205,27 @@ protected Query doToQuery(QueryShardContext context) {
}
// validate byteVector here because binary/hamming does not support float vectors
resolvedSpaceType.validateVector(byteVector);
return new ExactKNNByteQuery(fieldName, resolvedSpaceType.getValue(), indexName, vectorDataType, parentFilter, byteVector);
return new ExactKNNByteQuery(
fieldName,
resolvedSpaceType.getValue(),
indexName,
vectorDataType,
parentFilter,
expandNested == null ? false : expandNested,
byteVector
);
// FloatQuery used for bytes + floats because bytes are packed in floats
case BYTE, FLOAT:
resolvedSpaceType.validateVector(vector);
return new ExactKNNFloatQuery(fieldName, resolvedSpaceType.getValue(), indexName, vectorDataType, parentFilter, vector);
return new ExactKNNFloatQuery(
fieldName,
resolvedSpaceType.getValue(),
indexName,
vectorDataType,
parentFilter,
expandNested == null ? false : expandNested,
vector
);
default:
throw new IllegalStateException("Unsupported vector data type found.");
}
Expand All @@ -209,12 +236,13 @@ protected boolean doEquals(ExactKNNQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName)
&& Arrays.equals(vector, other.vector)
&& Objects.equals(ignoreUnmapped, other.ignoreUnmapped)
&& Objects.equals(spaceType, other.spaceType);
&& Objects.equals(spaceType, other.spaceType)
&& Objects.equals(expandNested, other.expandNested);
}

@Override
protected int doHashCode() {
return Objects.hash(fieldName, Arrays.hashCode(vector), ignoreUnmapped, spaceType);
return Objects.hash(fieldName, Arrays.hashCode(vector), ignoreUnmapped, spaceType, expandNested);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ private ExactSearcherContext createExactSearcherContext(LeafReaderContext leafRe
.matchedDocsIterator(matchedDocsIterator)
.parentsFilter(exactKNNQuery.getParentFilter())
.exactKNNSpaceType(exactKNNQuery.getSpaceType())
.expandNested(exactKNNQuery.isExpandNested())
.build();
default:
return ExactSearcher.ExactSearcherContext.builder()
Expand All @@ -102,6 +103,7 @@ private ExactSearcherContext createExactSearcherContext(LeafReaderContext leafRe
.matchedDocsIterator(matchedDocsIterator)
.parentsFilter(exactKNNQuery.getParentFilter())
.exactKNNSpaceType(exactKNNQuery.getSpaceType())
.expandNested(exactKNNQuery.isExpandNested())
.build();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
}
final SpaceType spaceType = resolvedSpaceType;
boolean isNestedRequired = exactSearcherContext.getParentsFilter() != null;
boolean isExpandNested = exactSearcherContext.isExpandNested();

if (VectorDataType.BINARY == vectorDataType) {
KNNVectorValues<byte[]> vectorValues;
Expand All @@ -208,7 +209,8 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
exactSearcherContext.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext)
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext),
isExpandNested
);
}
return new BinaryVectorIdsKNNIterator(
Expand All @@ -227,7 +229,8 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
exactSearcherContext.getFloatQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext)
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext),
isExpandNested
);
}
return new ByteVectorIdsKNNIterator(
Expand Down Expand Up @@ -271,7 +274,8 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
spaceType,
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext),
quantizedQueryVector,
segmentLevelQuantizationInfo
segmentLevelQuantizationInfo,
isExpandNested
);
}
return new VectorIdsKNNIterator(
Expand Down Expand Up @@ -317,5 +321,6 @@ public static class ExactSearcherContext {
VectorSimilarityFunction similarityFunction;
Boolean isMemoryOptimizedSearchEnabled;
String exactKNNSpaceType;
boolean expandNested;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@
* This iterator iterates filterIdsArray to score if filter is provided else it iterates over all docs.
* However, it dedupes docs per parent doc
* whose ID is set in parentBitSet and only returns the best child doc with the highest score.
* When expandNested is true, it returns ALL child docs instead of just the best one.
*/
public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator {
private final BitSet parentBitSet;
private final boolean expandNested;

public NestedBinaryVectorIdsKNNIterator(
@Nullable final DocIdSetIterator filterIdsIterator,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
final BitSet parentBitSet,
final boolean expandNested
) throws IOException {
super(filterIdsIterator, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
this.expandNested = expandNested;
}

public NestedBinaryVectorIdsKNNIterator(
Expand All @@ -40,20 +44,34 @@ public NestedBinaryVectorIdsKNNIterator(
) throws IOException {
super(null, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
this.expandNested = false;
}

/**
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent.
* When expandNested is true, returns ALL child docs instead of just the best one.
* DocIdSetIterator.NO_MORE_DOCS is returned when there are no more docs
*
* @return next best child doc id
* @return next child doc id (best child when expandNested=false, all children when expandNested=true)
*/
@Override
public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}

if (expandNested) {
int currentParent = parentBitSet.nextSetBit(docId);
if (currentParent != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will there be a case where docId won't be less than the parent? It will either be NO_MORE_DOCS or less than the parent, right?

currentScore = computeScore();
int currentDocId = docId;
docId = getNextDocId();
return currentDocId;
}
docId = getNextDocId();
return nextDoc();
Comment on lines +71 to +72
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for clarification, what case is this handling? I think if it hits this case, then docId will already be -1, right? Is it safe to return NO_MORE_DOCS here?

}

currentScore = Float.NEGATIVE_INFINITY;
int currentParent = parentBitSet.nextSetBit(docId);
int bestChild = -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@
* This iterator iterates filterIdsArray to score if filter is provided else it iterates over all docs.
* However, it dedupes docs per parent doc
* whose ID is set in parentBitSet and only returns the best child doc with the highest score.
* When expandNested is true, it returns ALL child docs instead of just the best one.
*/
public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator {
private final BitSet parentBitSet;
private final boolean expandNested;

public NestedByteVectorIdsKNNIterator(
@Nullable final DocIdSetIterator filterIdsIterator,
final float[] queryVector,
final KNNByteVectorValues byteVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
final BitSet parentBitSet,
final boolean expandNested
) throws IOException {
super(filterIdsIterator, queryVector, byteVectorValues, spaceType);
this.parentBitSet = parentBitSet;
this.expandNested = expandNested;
}

public NestedByteVectorIdsKNNIterator(
Expand All @@ -40,20 +44,34 @@ public NestedByteVectorIdsKNNIterator(
) throws IOException {
super(null, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
this.expandNested = false;
}

/**
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent.
* When expandNested is true, returns ALL child docs instead of just the best one.
* DocIdSetIterator.NO_MORE_DOCS is returned when there are no more docs
*
* @return next best child doc id
* @return next child doc id (best child when expandNested=false, all children when expandNested=true)
*/
@Override
public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}

if (expandNested) {
int currentParent = parentBitSet.nextSetBit(docId);
if (currentParent != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
currentScore = computeScore();
int currentDocId = docId;
docId = getNextDocId();
return currentDocId;
}
docId = getNextDocId();
return nextDoc();
}

currentScore = Float.NEGATIVE_INFINITY;
int currentParent = parentBitSet.nextSetBit(docId);
int bestChild = -1;
Expand Down
Loading
Loading