From a9f21359034d91bfa3f1810bdc2e153eb82c512a Mon Sep 17 00:00:00 2001 From: Kaushika Uppu Date: Wed, 13 Aug 2025 11:34:15 -0700 Subject: [PATCH 1/2] Added expandNested support Signed-off-by: Kaushika Uppu --- CHANGELOG.md | 1 + .../knn/index/query/ExactKNNByteQuery.java | 3 +- .../knn/index/query/ExactKNNFloatQuery.java | 3 +- .../knn/index/query/ExactKNNQuery.java | 16 +- .../knn/index/query/ExactKNNQueryBuilder.java | 38 +++- .../knn/index/query/ExactKNNWeight.java | 2 + .../knn/index/query/ExactSearcher.java | 11 +- .../NestedBinaryVectorIdsKNNIterator.java | 22 +- .../NestedByteVectorIdsKNNIterator.java | 22 +- .../iterators/NestedVectorIdsKNNIterator.java | 47 +++- .../parser/ExactKNNQueryBuilderParser.java | 12 + .../query/ExactKNNQueryBuilderTests.java | 42 ++++ .../knn/index/query/ExactKNNWeightTests.java | 1 + ...NestedBinaryVectorIdsKNNIteratorTests.java | 44 +++- .../NestedByteVectorIdsKNNIteratorTests.java | 45 +++- .../NestedVectorIdsKNNIteratorTests.java | 52 ++++- .../opensearch/knn/integ/ExactKNNQueryIT.java | 214 ++++++++++++++++++ 17 files changed, 551 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae981ca7dc..c20b80c99a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * [BUGFIX] [Remote Vector Index Build] Don't fall back to CPU on terminal failures [#2773](https://github.com/opensearch-project/k-NN/pull/2773) * Add KNN timing info to core profiler [#2785](https://github.com/opensearch-project/k-NN/pull/2785) * Add "exact_knn" query clause type [#2826](https://github.com/opensearch-project/k-NN/pull/2826) +* Add expandNested support for "exact_knn" query [#2846](https://github.com/opensearch-project/k-NN/pull/2846) ### Bug Fixes * Fix @ collision in NativeMemoryCacheKeyHelper for vector index filenames containing @ characters [#2810](https://github.com/opensearch-project/k-NN/pull/2810) diff --git a/src/main/java/org/opensearch/knn/index/query/ExactKNNByteQuery.java b/src/main/java/org/opensearch/knn/index/query/ExactKNNByteQuery.java index 215936ee09..1f26d47a8c 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactKNNByteQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactKNNByteQuery.java @@ -22,9 +22,10 @@ public ExactKNNByteQuery( String indexName, VectorDataType vectorDataType, BitSetProducer parentFilter, + boolean expandNested, byte[] byteQueryVector ) { - super(field, spaceType, indexName, vectorDataType, parentFilter); + super(field, spaceType, indexName, vectorDataType, parentFilter, expandNested); this.byteQueryVector = byteQueryVector; } diff --git a/src/main/java/org/opensearch/knn/index/query/ExactKNNFloatQuery.java b/src/main/java/org/opensearch/knn/index/query/ExactKNNFloatQuery.java index 98e27575e5..c707dc8332 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactKNNFloatQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactKNNFloatQuery.java @@ -22,9 +22,10 @@ public ExactKNNFloatQuery( String indexName, VectorDataType vectorDataType, BitSetProducer parentFilter, + boolean expandNested, float[] queryVector ) { - super(field, spaceType, indexName, vectorDataType, parentFilter); + super(field, spaceType, indexName, vectorDataType, parentFilter, expandNested); this.queryVector = queryVector; } diff --git a/src/main/java/org/opensearch/knn/index/query/ExactKNNQuery.java b/src/main/java/org/opensearch/knn/index/query/ExactKNNQuery.java index 1d3dd4c2d6..855cda6dc4 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactKNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactKNNQuery.java @@ -31,16 +31,25 @@ public abstract class ExactKNNQuery extends Query { private final String indexName; private final VectorDataType vectorDataType; private BitSetProducer parentFilter; + private boolean expandNested; @Setter @Getter private boolean explain; - protected ExactKNNQuery(String field, String spaceType, String indexName, VectorDataType vectorDataType, BitSetProducer parentFilter) { + protected ExactKNNQuery( + String field, + String spaceType, + String indexName, + VectorDataType vectorDataType, + BitSetProducer parentFilter, + boolean expandNested + ) { this.field = field; this.spaceType = spaceType; this.indexName = indexName; this.vectorDataType = vectorDataType; this.parentFilter = parentFilter; + this.expandNested = expandNested; } @Override @@ -60,7 +69,7 @@ public String toString(String field) { @Override public int hashCode() { - return Objects.hash(field, spaceType, indexName, vectorDataType, parentFilter); + return Objects.hash(field, spaceType, indexName, vectorDataType, parentFilter, expandNested); } @Override @@ -73,6 +82,7 @@ public boolean equalsTo(ExactKNNQuery other) { return Objects.equals(field, other.field) && Objects.equals(spaceType, other.spaceType) && Objects.equals(indexName, other.indexName) - && Objects.equals(parentFilter, other.parentFilter); + && Objects.equals(parentFilter, other.parentFilter) + && Objects.equals(expandNested, other.expandNested); } } diff --git a/src/main/java/org/opensearch/knn/index/query/ExactKNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/ExactKNNQueryBuilder.java index 5128527541..43d161ea6a 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactKNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactKNNQueryBuilder.java @@ -33,6 +33,7 @@ import java.util.Locale; import java.util.Objects; +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; /** @@ -45,6 +46,7 @@ public class ExactKNNQueryBuilder extends AbstractQueryBuilder vectorValues; @@ -208,7 +209,8 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea exactSearcherContext.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, spaceType, - exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext) + exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext), + isExpandNested ); } return new BinaryVectorIdsKNNIterator( @@ -227,7 +229,8 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea exactSearcherContext.getFloatQueryVector(), (KNNByteVectorValues) vectorValues, spaceType, - exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext) + exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext), + isExpandNested ); } return new ByteVectorIdsKNNIterator( @@ -271,7 +274,8 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea spaceType, exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext), quantizedQueryVector, - segmentLevelQuantizationInfo + segmentLevelQuantizationInfo, + isExpandNested ); } return new VectorIdsKNNIterator( @@ -317,5 +321,6 @@ public static class ExactSearcherContext { VectorSimilarityFunction similarityFunction; Boolean isMemoryOptimizedSearchEnabled; String exactKNNSpaceType; + boolean expandNested; } } diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java index eb285814a0..4ec7068676 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java @@ -17,19 +17,23 @@ * This iterator iterates filterIdsArray to scoreif filter is provided else it iterates over all docs. * However, it dedupe docs per each parent doc * of which ID is set in parentBitSet and only return best child doc with the highest score. + * When expandNested is true, it returns ALL child docs instead of just the best one. */ public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator { private final BitSet parentBitSet; + private final boolean expandNested; public NestedBinaryVectorIdsKNNIterator( @Nullable final DocIdSetIterator filterIdsIterator, final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType, - final BitSet parentBitSet + final BitSet parentBitSet, + final boolean expandNested ) throws IOException { super(filterIdsIterator, queryVector, binaryVectorValues, spaceType); this.parentBitSet = parentBitSet; + this.expandNested = expandNested; } public NestedBinaryVectorIdsKNNIterator( @@ -40,13 +44,15 @@ public NestedBinaryVectorIdsKNNIterator( ) throws IOException { super(null, queryVector, binaryVectorValues, spaceType); this.parentBitSet = parentBitSet; + this.expandNested = false; } /** * Advance to the next best child doc per parent and update score with the best score among child docs from the parent. + * When expandNested is true, returns ALL child docs instead of just the best one. * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs * - * @return next best child doc id + * @return next child doc id (best child when expandNested=false, all children when expandNested=true) */ @Override public int nextDoc() throws IOException { @@ -54,6 +60,18 @@ public int nextDoc() throws IOException { return DocIdSetIterator.NO_MORE_DOCS; } + if (expandNested) { + int currentParent = parentBitSet.nextSetBit(docId); + if (currentParent != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + currentScore = computeScore(); + int currentDocId = docId; + docId = getNextDocId(); + return currentDocId; + } + docId = getNextDocId(); + return nextDoc(); + } + currentScore = Float.NEGATIVE_INFINITY; int currentParent = parentBitSet.nextSetBit(docId); int bestChild = -1; diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java index 645133ba29..5362529ad8 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java @@ -17,19 +17,23 @@ * This iterator iterates filterIdsArray to score if filter is provided else it iterates over all docs. * However, it dedupe docs per each parent doc * of which ID is set in parentBitSet and only return best child doc with the highest score. + * When expandNested is true, it returns ALL child docs instead of just the best one. */ public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator { private final BitSet parentBitSet; + private final boolean expandNested; public NestedByteVectorIdsKNNIterator( @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNByteVectorValues byteVectorValues, final SpaceType spaceType, - final BitSet parentBitSet + final BitSet parentBitSet, + final boolean expandNested ) throws IOException { super(filterIdsIterator, queryVector, byteVectorValues, spaceType); this.parentBitSet = parentBitSet; + this.expandNested = expandNested; } public NestedByteVectorIdsKNNIterator( @@ -40,13 +44,15 @@ public NestedByteVectorIdsKNNIterator( ) throws IOException { super(null, queryVector, binaryVectorValues, spaceType); this.parentBitSet = parentBitSet; + this.expandNested = false; } /** * Advance to the next best child doc per parent and update score with the best score among child docs from the parent. + * When expandNested is true, returns ALL child docs instead of just the best one. * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs * - * @return next best child doc id + * @return next child doc id (best child when expandNested=false, all children when expandNested=true) */ @Override public int nextDoc() throws IOException { @@ -54,6 +60,18 @@ public int nextDoc() throws IOException { return DocIdSetIterator.NO_MORE_DOCS; } + if (expandNested) { + int currentParent = parentBitSet.nextSetBit(docId); + if (currentParent != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + currentScore = computeScore(); + int currentDocId = docId; + docId = getNextDocId(); + return currentDocId; + } + docId = getNextDocId(); + return nextDoc(); + } + currentScore = Float.NEGATIVE_INFINITY; int currentParent = parentBitSet.nextSetBit(docId); int bestChild = -1; diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java index f356fa02ec..17e62b722d 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java @@ -18,18 +18,21 @@ * This iterator iterates filterIdsArray to score if filter is provided else it iterates over all docs. * However, it dedupe docs per each parent doc * of which ID is set in parentBitSet and only return best child doc with the highest score. + * When expandNested is true, it returns ALL child docs instead of just the best one. */ public class NestedVectorIdsKNNIterator extends VectorIdsKNNIterator { private final BitSet parentBitSet; + private final boolean expandNested; public NestedVectorIdsKNNIterator( @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType, - final BitSet parentBitSet + final BitSet parentBitSet, + final boolean expandNested ) throws IOException { - this(filterIdsIterator, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null); + this(filterIdsIterator, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null, expandNested); } public NestedVectorIdsKNNIterator( @@ -38,7 +41,7 @@ public NestedVectorIdsKNNIterator( final SpaceType spaceType, final BitSet parentBitSet ) throws IOException { - this(null, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null); + this(null, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null, false); } public NestedVectorIdsKNNIterator( @@ -49,16 +52,40 @@ public NestedVectorIdsKNNIterator( final BitSet parentBitSet, final byte[] quantizedVector, final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo + ) throws IOException { + this( + filterIdsIterator, + queryVector, + knnFloatVectorValues, + spaceType, + parentBitSet, + quantizedVector, + segmentLevelQuantizationInfo, + false + ); + } + + public NestedVectorIdsKNNIterator( + @Nullable final DocIdSetIterator filterIdsIterator, + final float[] queryVector, + final KNNFloatVectorValues knnFloatVectorValues, + final SpaceType spaceType, + final BitSet parentBitSet, + final byte[] quantizedVector, + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo, + final boolean expandNested ) throws IOException { super(filterIdsIterator, queryVector, knnFloatVectorValues, spaceType, quantizedVector, segmentLevelQuantizationInfo); this.parentBitSet = parentBitSet; + this.expandNested = expandNested; } /** * Advance to the next best child doc per parent and update score with the best score among child docs from the parent. + * When expandNested is true, returns ALL child docs instead of just the best one. * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs * - * @return next best child doc id + * @return next child doc id (best child when expandNested=false, all children when expandNested=true) */ @Override public int nextDoc() throws IOException { @@ -66,6 +93,18 @@ public int nextDoc() throws IOException { return DocIdSetIterator.NO_MORE_DOCS; } + if (expandNested) { + int currentParent = parentBitSet.nextSetBit(docId); + if (currentParent != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + currentScore = computeScore(); + int currentDocId = docId; + docId = getNextDocId(); + return currentDocId; + } + docId = getNextDocId(); + return nextDoc(); + } + currentScore = Float.NEGATIVE_INFINITY; int currentParent = parentBitSet.nextSetBit(docId); int bestChild = -1; diff --git a/src/main/java/org/opensearch/knn/index/query/parser/ExactKNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/ExactKNNQueryBuilderParser.java index 49e70092a7..1beb2450ad 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/ExactKNNQueryBuilderParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/ExactKNNQueryBuilderParser.java @@ -20,10 +20,12 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; import static org.opensearch.knn.index.query.ExactKNNQueryBuilder.NAME; import static org.opensearch.knn.index.query.ExactKNNQueryBuilder.VECTOR_FIELD; import static org.opensearch.knn.index.query.ExactKNNQueryBuilder.SPACE_TYPE_FIELD; import static org.opensearch.knn.index.query.ExactKNNQueryBuilder.IGNORE_UNMAPPED_FIELD; +import static org.opensearch.knn.index.query.ExactKNNQueryBuilder.EXPAND_NESTED_FIELD; import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion; /** @@ -45,6 +47,7 @@ private static ObjectParser createInternalOb b.ignoreUnmapped(v); } }, IGNORE_UNMAPPED_FIELD); + internalParser.declareBoolean(ExactKNNQueryBuilder.Builder::expandNested, EXPAND_NESTED_FIELD); return internalParser; } @@ -66,6 +69,9 @@ public static ExactKNNQueryBuilder.Builder streamInput(StreamInput in, Function< if (minClusterVersionCheck.apply("ignore_unmapped")) { builder.ignoreUnmapped(in.readOptionalBoolean()); } + if (minClusterVersionCheck.apply(EXPAND_NESTED)) { + builder.expandNested(in.readOptionalBoolean()); + } return builder; } @@ -86,6 +92,9 @@ public static void streamOutput(StreamOutput out, ExactKNNQueryBuilder builder, if (minClusterVersionCheck.apply("ignore_unmapped")) { out.writeOptionalBoolean(builder.isIgnoreUnmapped()); } + if (minClusterVersionCheck.apply(EXPAND_NESTED)) { + out.writeOptionalBoolean(builder.getExpandNested()); + } } /** @@ -148,6 +157,9 @@ public static void toXContent(XContentBuilder builder, ToXContent.Params params, if (exactKNNQueryBuilder.queryName() != null) { builder.field(NAME_FIELD.getPreferredName(), exactKNNQueryBuilder.queryName()); } + if (exactKNNQueryBuilder.getExpandNested() != null) { + builder.field(EXPAND_NESTED, exactKNNQueryBuilder.getExpandNested()); + } builder.endObject(); builder.endObject(); diff --git a/src/test/java/org/opensearch/knn/index/query/ExactKNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/ExactKNNQueryBuilderTests.java index 212e40e0f9..beb6f51153 100644 --- a/src/test/java/org/opensearch/knn/index/query/ExactKNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ExactKNNQueryBuilderTests.java @@ -36,6 +36,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; +import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion; + +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; public class ExactKNNQueryBuilderTests extends KNNTestCase { @@ -92,6 +95,15 @@ public void testIgnoreUnmapped() throws IOException { expectThrows(IllegalArgumentException.class, () -> builder.build().doToQuery(mock(QueryShardContext.class))); } + public void testExpandNested() throws IOException { + ExactKNNQueryBuilder.Builder builder = ExactKNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .spaceType(SPACE_TYPE) + .expandNested(true); + assertTrue(builder.build().getExpandNested()); + } + public void testEmptyFieldName() { /** * empty field name @@ -227,6 +239,7 @@ public void testBuilderDefaults() { assertNull(builder.getSpaceType()); // Should be null by default assertFalse(builder.isIgnoreUnmapped()); // Should be false by default + assertNull(builder.getExpandNested()); // Should be null by default } public void testNestedFields() { @@ -252,6 +265,31 @@ public void testNestedFields() { assertEquals(mockParentFilter, query.getParentFilter()); } + public void testNestedFields_withExpandNested() { + ExactKNNQueryBuilder builder = ExactKNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .spaceType(SPACE_TYPE) + .expandNested(true) + .build(); + + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); + BitSetProducer mockParentFilter = mock(BitSetProducer.class); + + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 4)); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + when(mockQueryShardContext.getParentFilter()).thenReturn(mockParentFilter); + + ExactKNNQuery query = (ExactKNNQuery) builder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertEquals(mockParentFilter, query.getParentFilter()); + assertTrue(query.isExpandNested()); + } + public void testSerialization() throws Exception { assertSerialization(Version.CURRENT); assertSerialization(Version.V_2_3_0); @@ -269,6 +307,7 @@ private void assertSerialization(final Version version) throws Exception { .fieldName(FIELD_NAME) .vector(QUERY_VECTOR) .spaceType(SPACE_TYPE) + .expandNested(true) .build(); final ClusterService clusterService = mockClusterService(version); @@ -290,6 +329,9 @@ private void assertSerialization(final Version version) throws Exception { assertArrayEquals(QUERY_VECTOR, deserializedExactKNNQueryBuilder.getVector(), 0.0f); assertEquals(SPACE_TYPE, deserializedExactKNNQueryBuilder.getSpaceType()); assertFalse(deserializedExactKNNQueryBuilder.isIgnoreUnmapped()); + if (isClusterOnOrAfterMinRequiredVersion(EXPAND_NESTED)) { + assertTrue(deserializedExactKNNQueryBuilder.getExpandNested()); + } } } } diff --git a/src/test/java/org/opensearch/knn/index/query/ExactKNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/ExactKNNWeightTests.java index 022eda07b1..85290106ad 100644 --- a/src/test/java/org/opensearch/knn/index/query/ExactKNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ExactKNNWeightTests.java @@ -256,6 +256,7 @@ public void testExactSearch_thenCorrectDocOrderWithCorrectScores() { INDEX_NAME, VectorDataType.FLOAT, null, + false, QUERY_VECTOR ); diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java index 32ac081569..76a345d395 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java @@ -54,7 +54,8 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { queryVector, values, spaceType, - parentBitSet + parentBitSet, + false ); assertEquals(filterIds[0], iterator.nextDoc()); assertEquals(expectedScores.get(0), iterator.score()); @@ -89,4 +90,45 @@ public void testNextDoc_whenIterateWithoutFilters_thenReturnBestChildDocsPerPare assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); verify(values, never()).advance(anyInt()); } + + @SneakyThrows + public void testNextDoc_whenIterateExpandNested_thenReturnAllChildDocsPerParent() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + NestedBinaryVectorIdsKNNIterator iterator = new NestedBinaryVectorIdsKNNIterator( + new BitSetIterator(filterBitSet, filterBitSet.length()), + queryVector, + values, + spaceType, + parentBitSet, + true + ); + assertEquals(filterIds[0], iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(filterIds[1], iterator.nextDoc()); + assertEquals(expectedScores.get(1), iterator.score()); + assertEquals(filterIds[2], iterator.nextDoc()); + assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java index fcc635aaa7..02f6b1e78a 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java @@ -55,7 +55,8 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { queryVector, values, spaceType, - parentBitSet + parentBitSet, + false ); assertEquals(filterIds[0], iterator.nextDoc()); assertEquals(expectedScores.get(0), iterator.score()); @@ -91,4 +92,46 @@ public void testNextDoc_whenIterateWithoutFilters_thenReturnBestChildDocsPerPare assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); verify(values, never()).advance(anyInt()); } + + @SneakyThrows + public void testNextDoc_whenIterateExpandNested_thenReturnAllChildDocsPerParent() { + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 17, 18, 19 }, new byte[] { 14, 15, 16 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) + .collect(Collectors.toList()); + + KNNByteVectorValues values = mock(KNNByteVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + NestedByteVectorIdsKNNIterator iterator = new NestedByteVectorIdsKNNIterator( + new BitSetIterator(filterBitSet, filterBitSet.length()), + queryVector, + values, + spaceType, + parentBitSet, + true + ); + assertEquals(filterIds[0], iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(filterIds[1], iterator.nextDoc()); + assertEquals(expectedScores.get(1), iterator.score()); + assertEquals(filterIds[2], iterator.nextDoc()); + assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java index b44a90c5ff..78c955fb58 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java @@ -62,7 +62,8 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { queryVector, values, spaceType, - parentBitSet + parentBitSet, + false ); assertEquals(filterIds[0], iterator.nextDoc()); assertEquals(expectedScores.get(0), iterator.score()); @@ -101,4 +102,53 @@ public void testNextDoc_whenIterateWithoutFilters_thenReturnBestChildDocsPerPare assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); verify(values, never()).advance(anyInt()); } + + @SneakyThrows + public void testNextDoc_whenIterateExpandNested_thenReturnAllChildDocsPerParent() { + final SpaceType spaceType = SpaceType.L2; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 17.0f, 18.0f, 19.0f }, + new float[] { 14.0f, 15.0f, 16.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNFloatVectorValues values = mock(KNNFloatVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + // final List byteRefs = dataVectors.stream() + // .map(vector -> new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector))) + // .collect(Collectors.toList()); + // when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + NestedVectorIdsKNNIterator iterator = new NestedVectorIdsKNNIterator( + new BitSetIterator(filterBitSet, filterBitSet.length()), + queryVector, + values, + spaceType, + parentBitSet, + true + ); + assertEquals(filterIds[0], iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(filterIds[1], iterator.nextDoc()); + assertEquals(expectedScores.get(1), iterator.score()); + assertEquals(filterIds[2], iterator.nextDoc()); + assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } } diff --git a/src/test/java/org/opensearch/knn/integ/ExactKNNQueryIT.java b/src/test/java/org/opensearch/knn/integ/ExactKNNQueryIT.java index 7e19f115c3..ef043b3f1f 100644 --- a/src/test/java/org/opensearch/knn/integ/ExactKNNQueryIT.java +++ b/src/test/java/org/opensearch/knn/integ/ExactKNNQueryIT.java @@ -6,6 +6,7 @@ package org.opensearch.knn.integ; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimap; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.hc.core5.http.io.entity.EntityUtils; @@ -24,7 +25,9 @@ import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.io.IOException; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import static org.opensearch.knn.common.KNNConstants.QUERY; import static org.opensearch.knn.common.KNNConstants.EXACT_KNN; @@ -32,6 +35,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.PATH; import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED; +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; @Log4j2 public class ExactKNNQueryIT extends KNNRestTestCase { @@ -498,6 +502,216 @@ public void testKNNExactQuery_HammingFloat_ThenException() { deleteKNNIndex(INDEX_NAME); } + @SneakyThrows + public void testKNNExactQuery_NestedExpandDocs() { + createNestedTestIndex(); + for (int i = 0; i < SIZE; i++) { + NestedKnnDocBuilder builder = NestedKnnDocBuilder.create(FIELD_NAME_NESTED); + for (int j = 0; j < SIZE; j++) { + builder.addVectors(FIELD_NAME, new Float[] { (float) i + j, (float) i + j, (float) i + j }); + } + String doc = builder.build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().field("_source", false); + queryBuilder.startObject(QUERY); + queryBuilder.startObject(TYPE_NESTED); + queryBuilder.field(PATH, FIELD_NAME_NESTED); + queryBuilder.startObject(QUERY).startObject(EXACT_KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME); + queryBuilder.field(VECTOR, QUERY_VECTOR); + queryBuilder.field(EXPAND_NESTED, true); + queryBuilder.field("space_type", "l2"); + queryBuilder.endObject().endObject().endObject(); + queryBuilder.startObject("inner_hits").field("size", 5).endObject(); + queryBuilder.endObject().endObject().endObject(); + + float[] expectedResults = new float[SIZE]; + for (int i = 0; i < SIZE; i++) { + float sum = 0; + for (int j = 0; j < SIZE; j++) { + float[] childVector = { i + j, i + j, i + j }; + float score = SpaceType.L2.getKnnVectorSimilarityFunction().compare(QUERY_VECTOR, childVector); + sum += score; + } + expectedResults[i] = sum / SIZE; + } + + Map> expectedInnerScoreOrder = Map.of( + "0", + List.of(2, 1, 3, 0, 4), // [2,2,2], [1,1,1], [3,3,3], [0,0,0], [4,4,4] + "1", + List.of(1, 0, 2, 3, 4), // [2,2,2], [1,1,1], [3,3,3], [4,4,4], [5,5,5] + "2", + List.of(0, 1, 2, 3, 4), // [2,2,2], [3,3,3], [4,4,4], [5,5,5], [6,6,6] + "3", + List.of(0, 1, 2, 3, 4), // [3,3,3], [4,4,4], [5,5,5], [6,6,6], [7,7,7] + "4", + List.of(0, 1, 2, 3, 4) // [4,4,4], [5,5,5], [6,6,6], [7,7,7], [8,8,8] + ); + + Response searchResponse = searchKNNIndex(INDEX_NAME, queryBuilder, SIZE); + String entity = EntityUtils.toString(searchResponse.getEntity()); + List docIds = parseIds(entity); + assertEquals(SIZE, docIds.size()); + assertEquals(SIZE, parseTotalSearchHits(entity)); + List results = parseScores(entity); + for (int i = 0; i < SIZE; i++) { + assertEquals(expectedResults[i], results.get(i), 0.00001); + } + Multimap docIdToOffsets = parseInnerHits(entity, FIELD_NAME_NESTED); + assertEquals(5, docIdToOffsets.keySet().size()); + for (String key : docIdToOffsets.keySet()) { + assertEquals(5, docIdToOffsets.get(key).size()); + List offsets = new ArrayList<>(docIdToOffsets.get(key)); + for (int i = 0; i < SIZE; i++) { + assertEquals(offsets.get(i), expectedInnerScoreOrder.get(key).get(i)); + } + } + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testKNNExactQuery_NestedExpandDocs_Binary() { + createBinaryNestedTestIndex(false); + for (byte i = 0; i < SIZE; i++) { + NestedKnnDocBuilder builder = NestedKnnDocBuilder.create(FIELD_NAME_NESTED); + for (byte j = 0; j < SIZE; j++) { + builder.addVectors(FIELD_NAME, new Byte[] { (byte) (i + j), (byte) (i + j), (byte) (i + j) }); + } + String doc = builder.build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().field("_source", false); + queryBuilder.startObject(QUERY); + queryBuilder.startObject(TYPE_NESTED); + queryBuilder.field(PATH, FIELD_NAME_NESTED); + queryBuilder.startObject(QUERY).startObject(EXACT_KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME); + queryBuilder.field(VECTOR, QUERY_VECTOR); + queryBuilder.field(EXPAND_NESTED, true); + queryBuilder.endObject().endObject().endObject(); + queryBuilder.startObject("inner_hits").field("size", 5).endObject(); + queryBuilder.endObject().endObject().endObject(); + + float[] expectedResults = new float[SIZE]; + for (byte i = 0; i < SIZE; i++) { + float sum = 0; + for (byte j = 0; j < SIZE; j++) { + byte[] childVector = { (byte) (i + j), (byte) (i + j), (byte) (i + j) }; + float score = SpaceType.HAMMING.getKnnVectorSimilarityFunction().compare(BYTE_QUERY_VECTOR, childVector); + sum += score; + } + expectedResults[i] = sum / SIZE; + } + + Map> expectedInnerScoreOrder = Map.of( + "0", + List.of(3, 1, 2, 0, 4), // [3,3,3], [1,1,1], [2,2,2], [0,0,0], [4,4,4] + "1", + List.of(2, 0, 1, 4, 3), // [3,3,3], [1,1,1], [2,2,2], [5,5,5], [4,4,4] + "2", + List.of(1, 0, 3, 4, 2), // [3,3,3], [2,2,2], [5,5,5], [6,6,6], [4,4,4] + "3", + List.of(0, 4, 2, 3, 1), // [3,3,3], [7,7,7], [5,5,5], [6,6,6], [4,4,4] + "4", + List.of(3, 1, 2, 0, 4) // [7,7,7], [5,5,5], [6,6,6], [4,4,4], [8,8,8] + ); + + Response searchResponse = searchKNNIndex(INDEX_NAME, queryBuilder, SIZE); + String entity = EntityUtils.toString(searchResponse.getEntity()); + List docIds = parseIds(entity); + assertEquals(SIZE, docIds.size()); + assertEquals(SIZE, parseTotalSearchHits(entity)); + List results = parseScores(entity); + for (int i = 0; i < SIZE; i++) { + assertEquals(expectedResults[i], results.get(i), 0.00001); + } + Multimap docIdToOffsets = parseInnerHits(entity, FIELD_NAME_NESTED); + assertEquals(5, docIdToOffsets.keySet().size()); + for (String key : docIdToOffsets.keySet()) { + assertEquals(5, docIdToOffsets.get(key).size()); + List offsets = new ArrayList<>(docIdToOffsets.get(key)); + for (int i = 0; i < SIZE; i++) { + assertEquals(offsets.get(i), expectedInnerScoreOrder.get(key).get(i)); + } + } + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testKNNExactQuery_NestedExpandDocs_ScoreModeMax() { + createNestedTestIndex(); + for (int i = 0; i < SIZE; i++) { + NestedKnnDocBuilder builder = NestedKnnDocBuilder.create(FIELD_NAME_NESTED); + for (int j = 0; j < SIZE; j++) { + builder.addVectors(FIELD_NAME, new Float[] { (float) i + j, (float) i + j, (float) i + j }); + } + String doc = builder.build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().field("_source", false); + queryBuilder.startObject(QUERY); + queryBuilder.startObject(TYPE_NESTED); + queryBuilder.field(PATH, FIELD_NAME_NESTED); + queryBuilder.startObject(QUERY).startObject(EXACT_KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME); + queryBuilder.field(VECTOR, QUERY_VECTOR); + queryBuilder.field(EXPAND_NESTED, true); + queryBuilder.field("space_type", "linf"); + queryBuilder.endObject().endObject().endObject(); + queryBuilder.startObject("inner_hits").field("size", 5).endObject(); + queryBuilder.field("score_mode", "max"); + queryBuilder.endObject().endObject().endObject(); + + float[] expectedResults = new float[SIZE]; + for (int i = 0; i < SIZE; i++) { + float maxScore = Float.NEGATIVE_INFINITY; + for (int j = 0; j < SIZE; j++) { + float[] childVector = { i + j, i + j, i + j }; + float score = 1 / (1 + KNNScoringUtil.lInfNorm(QUERY_VECTOR, childVector)); + maxScore = Math.max(maxScore, score); + } + expectedResults[i] = maxScore; + } + + Map> expectedInnerScoreOrder = Map.of( + "0", + List.of(2, 1, 3, 0, 4), // [2,2,2], [1,1,1], [3,3,3], [0,0,0], [4,4,4] + "1", + List.of(1, 0, 2, 3, 4), // [2,2,2], [1,1,1], [3,3,3], [4,4,4], [5,5,5] + "2", + List.of(0, 1, 2, 3, 4), // [2,2,2], [3,3,3], [4,4,4], [5,5,5], [6,6,6] + "3", + List.of(0, 1, 2, 3, 4), // [3,3,3], [4,4,4], [5,5,5], [6,6,6], [7,7,7] + "4", + List.of(0, 1, 2, 3, 4) // [4,4,4], [5,5,5], [6,6,6], [7,7,7], [8,8,8] + ); + + Response searchResponse = searchKNNIndex(INDEX_NAME, queryBuilder, SIZE); + String entity = EntityUtils.toString(searchResponse.getEntity()); + List docIds = parseIds(entity); + assertEquals(SIZE, docIds.size()); + assertEquals(SIZE, parseTotalSearchHits(entity)); + List results = parseScores(entity); + for (int i = 0; i < SIZE; i++) { + assertEquals(expectedResults[i], results.get(i), 0.00001); + } + Multimap docIdToOffsets = parseInnerHits(entity, FIELD_NAME_NESTED); + assertEquals(5, docIdToOffsets.keySet().size()); + for (String key : docIdToOffsets.keySet()) { + assertEquals(5, docIdToOffsets.get(key).size()); + List offsets = new ArrayList<>(docIdToOffsets.get(key)); + for (int i = 0; i < SIZE; i++) { + assertEquals(offsets.get(i), expectedInnerScoreOrder.get(key).get(i)); + } + } + deleteKNNIndex(INDEX_NAME); + } + private Response validateKNNExactSearch(String testIndex, ExactKNNQueryBuilder exactKNNQueryBuilder) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); exactKNNQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); From 80852f3df68f4f8ec7bd76a8b0d02c2ac7024b74 Mon Sep 17 00:00:00 2001 From: Kaushika Uppu Date: Wed, 13 Aug 2025 13:29:59 -0700 Subject: [PATCH 2/2] Updated changelog Signed-off-by: Kaushika Uppu --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c20b80c99a..5f6b7bec5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * [BUGFIX] [Remote Vector Index Build] Don't fall back to CPU on terminal failures [#2773](https://github.com/opensearch-project/k-NN/pull/2773) * Add KNN timing info to core profiler [#2785](https://github.com/opensearch-project/k-NN/pull/2785) * Add "exact_knn" query clause type [#2826](https://github.com/opensearch-project/k-NN/pull/2826) -* Add expandNested support for "exact_knn" query [#2846](https://github.com/opensearch-project/k-NN/pull/2846) +* Add expandNested support for "exact_knn" query [#2848](https://github.com/opensearch-project/k-NN/pull/2848) ### Bug Fixes * Fix @ collision in NativeMemoryCacheKeyHelper for vector index filenames containing @ characters [#2810](https://github.com/opensearch-project/k-NN/pull/2810)