diff --git a/CHANGELOG.md b/CHANGELOG.md index cf98033cf..952fb6822 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD) ### Features - Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195)) +- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/restart/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/restart/HybridSearchIT.java index d08d208da..a334f19f8 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/restart/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/restart/HybridSearchIT.java @@ -23,6 +23,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.index.query.QueryBuilder; public class HybridSearchIT extends AbstractRestartUpgradeRestTestCase { private static final String PIPELINE_NAME = "nlp-hybrid-pipeline"; @@ -72,9 +73,15 @@ private void validateNormalizationProcessor(final String fileName, final String modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR); loadModel(modelId); addDocuments(getIndexNameForTest(), false); - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null); validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder); - 
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault()); + hybridQueryBuilder = getQueryBuilder( + modelId, + Boolean.FALSE, + Map.of("ef_search", 100), + RescoreContext.getDefault(), + new MatchQueryBuilder("_id", "5") + ); validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder); } finally { wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName); @@ -120,7 +127,8 @@ private HybridQueryBuilder getQueryBuilder( final String modelId, final Boolean expandNestedDocs, final Map methodParameters, - final RescoreContext rescoreContext + final RescoreContext rescoreContext, + final QueryBuilder filter ) { NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() .fieldName("passage_embedding") @@ -144,6 +152,10 @@ private HybridQueryBuilder getQueryBuilder( hybridQueryBuilder.add(matchQueryBuilder); hybridQueryBuilder.add(neuralQueryBuilder); + if (filter != null) { + hybridQueryBuilder.filter(filter); + } + return hybridQueryBuilder; } diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/HybridSearchIT.java index de7ddef55..2642678ab 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/HybridSearchIT.java @@ -66,12 +66,12 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr int totalDocsCountMixed; if (isFirstMixedRound()) { totalDocsCountMixed = NUM_DOCS_PER_ROUND; - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null); validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null); addDocument(getIndexNameForTest(), "1", 
TEST_FIELD, TEXT_MIXED, null, null); } else { totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null); validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null); } break; @@ -81,9 +81,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND; loadModel(modelId); addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null); - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null); validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null); - hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault()); + hybridQueryBuilder = getQueryBuilder( + modelId, + Boolean.FALSE, + Map.of("ef_search", 100), + RescoreContext.getDefault(), + new MatchQueryBuilder("_id", "2") + ); validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null); } finally { wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); @@ -123,7 +129,8 @@ private HybridQueryBuilder getQueryBuilder( final String modelId, final Boolean expandNestedDocs, final Map methodParameters, - final RescoreContext rescoreContextForNeuralQuery + final RescoreContext rescoreContextForNeuralQuery, + final QueryBuilder filter ) { NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() .fieldName(VECTOR_EMBEDDING_FIELD) @@ -147,6 +154,10 @@ private HybridQueryBuilder getQueryBuilder( hybridQueryBuilder.add(matchQueryBuilder); hybridQueryBuilder.add(neuralQueryBuilder); + if (filter != null) { + hybridQueryBuilder.filter(filter); + } + return hybridQueryBuilder; } } diff 
--git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index c04fbda09..ad144b716 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.ListIterator; import java.util.Locale; import java.util.Objects; import java.util.stream.Collectors; @@ -51,6 +52,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries = new ArrayList<>(); @@ -94,6 +96,26 @@ public HybridQueryBuilder add(QueryBuilder queryBuilder) { return this; } + /** + * Applies a filter to this HybridQueryBuilder. + * If the filter is null, then we do nothing and return. + * Otherwise, we push down the filter to every sub-query in the queries list. + * @param filter the filter parameter + * @return HybridQueryBuilder itself + */ + public QueryBuilder filter(QueryBuilder filter) { + if (validateFilterParams(filter) == false) { + return this; + } + ListIterator iterator = queries.listIterator(); + while (iterator.hasNext()) { + QueryBuilder query = iterator.next(); + // set the query again because query.filter(filter) can return a new query.
+ iterator.set(query.filter(filter)); + } + return this; + } + /** * Create builder object with a content of this hybrid query * @param builder @@ -155,6 +177,10 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio * } * } * ] + * "filter": + * "term": { + * "text": "keyword" + * } * } * } * } @@ -168,6 +194,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx Integer paginationDepth = null; final List queries = new ArrayList<>(); + QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -178,6 +205,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx } else if (token == XContentParser.Token.START_OBJECT) { if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queries.add(parseInnerQueryBuilder(parser)); + } else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + filter = parseInnerQueryBuilder(parser); } else { log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName)); throw new ParsingException( @@ -240,7 +269,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx compoundQueryBuilder.paginationDepth(paginationDepth); } for (QueryBuilder query : queries) { - compoundQueryBuilder.add(query); + if (filter == null) { + compoundQueryBuilder.add(query); + } else { + compoundQueryBuilder.add(query.filter(filter)); + } } return compoundQueryBuilder; } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index d90906836..ff0af42fa 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -309,6 +309,24 @@ protected void doWriteTo(StreamOutput out) throws IOException { RescoreParser.streamOutput(out, 
rescoreContext); } + /** + * Adds a filter to this NeuralQueryBuilder. + * @param filterToBeAdded filter to be added + * @return this builder, with its existing filter (if any) combined with the passed-in filter + */ + public QueryBuilder filter(QueryBuilder filterToBeAdded) { + if (validateFilterParams(filterToBeAdded) == false) { + return this; + } + if (filter == null) { + filter = filterToBeAdded; + } else { + filter = filter.filter(filterToBeAdded); + } + return this; + + } + @Override protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { xContentBuilder.startObject(NAME); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index ad59ff471..69e68428d 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -56,6 +56,7 @@ import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -82,6 +83,7 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase { static final String TEXT_FIELD_NAME = "field"; static final String QUERY_TEXT = "Hello world!"; static final String TERM_QUERY_TEXT = "keyword"; + static final String FILTER_TERM_QUERY_TEXT = "filterKeyword"; static final String MODEL_ID = "mfgfgdsfgfdgsde"; static final int K = 10; static final float BOOST = 1.8f; @@ -436,6 +438,121 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { assertEquals(TERM_QUERY_TEXT, termQueryBuilder.value()); } + /** + * Tests basic query: + * { + * "query": { + * "hybrid": { + * "queries": [ + * {
+ * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 1 + * } + * } + * }, + * { + * "term": { + * "text": "keyword" + * } + * } + * ], + * "filter": { + * "term": { + * "text": "filterKeyword" + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testFromXContent_whenMultipleSubQueriesAndFilter_thenBuildSuccessfully() { + setUpClusterService(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startArray("queries") + .startObject() + .startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .endArray() + .field("pagination_depth", 10) + .startObject("filter") + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT) + .endObject() + .endObject() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(NeuralQueryBuilder.NAME), + NeuralQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilder.contentType().xContent(), + BytesReference.bytes(xContentBuilder) + ); + contentParser.nextToken(); + + HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder.fromXContent(contentParser); + assertEquals(2,
queryTwoSubQueries.queries().size()); + assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder); + + assertTrue(queryTwoSubQueries.queries().get(1) instanceof BoolQueryBuilder); + assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().size()); + assertTrue(((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0) instanceof TermQueryBuilder); + assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().size()); + + assertEquals(10, queryTwoSubQueries.paginationDepth().intValue()); + // verify knn vector query + NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0); + assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(K, (int) neuralQueryBuilder.k()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(BOOST, neuralQueryBuilder.boost(), 0f); + assertEquals( + new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT), + ((NeuralQueryBuilder) queryTwoSubQueries.queries().get(0)).filter() + ); + // verify term query + assertEquals( + new TermQueryBuilder(TEXT_FIELD_NAME, TERM_QUERY_TEXT), + ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0) + ); + assertEquals( + new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT), + ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().get(0) + ); + } + @SneakyThrows public void testFromXContent_whenIncorrectFormat_thenFail() { XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder() @@ -960,6 +1077,25 @@ public void testVisit() { assertEquals(3, visitedQueries.size()); } + public void testFilter() { + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add( + NeuralQueryBuilder.builder().fieldName("test").queryText("test").build() + ).add(new NeuralSparseQueryBuilder()); + // Test for Null filter Case + QueryBuilder 
queryBuilder = hybridQueryBuilder.filter(null); + assertEquals(queryBuilder, hybridQueryBuilder); + + // Test for Non-Null filter case and assert every field as expected + HybridQueryBuilder updatedHybridQueryBuilder = (HybridQueryBuilder) hybridQueryBuilder.filter(new MatchAllQueryBuilder()); + assertEquals(updatedHybridQueryBuilder.queryName(), hybridQueryBuilder.queryName()); + assertEquals(updatedHybridQueryBuilder.paginationDepth(), hybridQueryBuilder.paginationDepth()); + NeuralQueryBuilder updatedNeuralQueryBuilder = (NeuralQueryBuilder) updatedHybridQueryBuilder.queries().get(0); + assertEquals(new MatchAllQueryBuilder(), updatedNeuralQueryBuilder.filter()); + BoolQueryBuilder updatedNeuralSparseQueryBuilder = (BoolQueryBuilder) updatedHybridQueryBuilder.queries().get(1); + assertEquals(new NeuralSparseQueryBuilder(), updatedNeuralSparseQueryBuilder.must().get(0)); + assertEquals(new MatchAllQueryBuilder(), updatedNeuralSparseQueryBuilder.filter().get(0)); + } + private Map getInnerMap(Object innerObject, String queryName, String fieldName) { if (!(innerObject instanceof Map)) { fail("field name does not map to nested object"); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryFilterIT.java new file mode 100644 index 000000000..680d48114 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryFilterIT.java @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import com.google.common.primitives.Floats; +import lombok.SneakyThrows; +import org.junit.BeforeClass; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import 
org.opensearch.neuralsearch.BaseNeuralSearchIT; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; + +public class HybridQueryFilterIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = + "test-hybrid-post-filter-multi-doc-index-multiple-shards"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = + "test-hybrid-post-filter-multi-doc-index-single-shard"; + private static final String SEARCH_PIPELINE = "phase-results-hybrid-post-filter-pipeline"; + private static final String INTEGER_FIELD_1_STOCK = "stock"; + private static final String TEXT_FIELD_1_NAME = "name"; + private static final String KEYWORD_FIELD_2_CATEGORY = "category"; + private static final String TEXT_FIELD_VALUE_1_DUNES = "Dunes part 1"; + private static final String TEXT_FIELD_VALUE_2_DUNES = "Dunes part 2"; + private static final String TEXT_FIELD_VALUE_3_MI_1 = "Mission Impossible 1"; + private static final String TEXT_FIELD_VALUE_4_MI_2 = "Mission Impossible 2"; + private static final String TEXT_FIELD_VALUE_5_TERMINAL = "The Terminal"; + private static final String TEXT_FIELD_VALUE_6_AVENGERS = "Avengers"; + private static final String TEST_QUERY_TEXT = "Hello world"; + private static final int INTEGER_FIELD_STOCK_1_25 = 25; + private static final int INTEGER_FIELD_STOCK_2_22 = 22; + private static final int INTEGER_FIELD_STOCK_3_256 = 256; + private static final int INTEGER_FIELD_STOCK_4_25 = 25; + private static final int INTEGER_FIELD_STOCK_5_20 = 20; + private static final String KEYWORD_FIELD_CATEGORY_1_DRAMA = "Drama"; + private static final String KEYWORD_FIELD_CATEGORY_2_ACTION = "Action"; + private static final String 
KEYWORD_FIELD_CATEGORY_3_SCI_FI = "Sci-fi"; + private static final int SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER = 1; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + + @BeforeClass + @SneakyThrows + public static void setUpCluster() { + // we need new instance because we're calling non-static methods from static method. + // main purpose is to minimize network calls, initialization is only needed once + HybridQueryFilterIT instance = new HybridQueryFilterIT(); + instance.initClient(); + instance.updateClusterSettings(); + } + + @SneakyThrows + public void testFilterOnNeuralQueryFilterAndTermQueryFilter_thenSuccessful() { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); + testNeuralQueryBuilder(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + } + + @SneakyThrows + private void testNeuralQueryBuilder(String indexName) { + String modelId = null; + modelId = prepareModel(); + NeuralQueryBuilder neuralQueryBuilderTextQuery = NeuralQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .k(1) + .build(); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, TEST_QUERY_TEXT); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilderTextQuery); + hybridQueryBuilder.add(termQueryBuilder); + + hybridQueryBuilder.filter(new MatchQueryBuilder("_id", "1")); + + Map searchResponseAsMapTextQuery = search(indexName, hybridQueryBuilder, 1); + assertEquals(1, getHitCount(searchResponseAsMapTextQuery)); + + Map firstInnerHitTextQuery = getFirstInnerHit(searchResponseAsMapTextQuery); + 
assertEquals("1", firstInnerHitTextQuery.get("_id")); + } + + @SneakyThrows + void prepareResourcesBeforeTestExecution(int numShards) { + if (numShards == 1) { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, numShards); + } else { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, numShards); + } + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + } + + @SneakyThrows + private void initializeIndexIfNotExists(String indexName, int numShards) { + if (!indexExists(indexName)) { + prepareKnnIndex( + indexName, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) + ); + + addKnnDoc( + indexName, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_2_DUNES), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_1_25), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_1_DUNES), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_2_22), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_3_MI_1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + 
List.of(INTEGER_FIELD_STOCK_3_256), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_2_ACTION), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "4", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_4_MI_2), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_4_25), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_2_ACTION), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "5", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_5_TERMINAL), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "6", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_6_AVENGERS), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_3_SCI_FI), + List.of(), + List.of() + ); + } + } + + private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) { + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder).add(termQueryBuilder).add(rangeQueryBuilder); + return hybridQueryBuilder; + } + + private QueryBuilder createQueryBuilderWithRangeQuery(int lte, int gte) { + return 
QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 68ba11f9b..8ebc95714 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -36,6 +36,8 @@ import java.util.function.Supplier; import org.opensearch.Version; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.transport.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -77,9 +79,12 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final Float MIN_SCORE = 0.985f; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; + private static final String TERM_QUERY_FIELD_NAME = "termQueryFiledName"; + private static final String TERM_QUERY_FIELD_VALUE = "termQueryFiledValue"; private static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[10]; private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); + private static final QueryBuilder ADDITIONAL_TEST_FILTER = new TermQueryBuilder(TERM_QUERY_FIELD_NAME, TERM_QUERY_FIELD_VALUE); @SneakyThrows public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { @@ -763,6 +768,29 @@ private NeuralQueryBuilder getBaselineNeuralQueryBuilder() { .build(); } + public void testFilter_whenAddBoolQueryBuilderToNeuralQueryBuilder_thenFilterSuccessful() { + // Test for Null Case + NeuralQueryBuilder neuralQueryBuilder = getBaselineNeuralQueryBuilder(); + QueryBuilder updatedNeuralQueryBuilder = neuralQueryBuilder.filter(null); + assertEquals(neuralQueryBuilder, updatedNeuralQueryBuilder); + + // Test for valid case 
+ neuralQueryBuilder = getBaselineNeuralQueryBuilder(); + updatedNeuralQueryBuilder = neuralQueryBuilder.filter(ADDITIONAL_TEST_FILTER); + BoolQueryBuilder expectedUpdatedQueryFilter = new BoolQueryBuilder(); + expectedUpdatedQueryFilter.must(TEST_FILTER); + expectedUpdatedQueryFilter.filter(ADDITIONAL_TEST_FILTER); + assertEquals(neuralQueryBuilder, updatedNeuralQueryBuilder); + assertEquals(expectedUpdatedQueryFilter, neuralQueryBuilder.filter()); + + // Test for queryBuilder without filter initialized where filter function would + // simply assign filter to its filter field. + neuralQueryBuilder = NeuralQueryBuilder.builder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K).build(); + updatedNeuralQueryBuilder = neuralQueryBuilder.filter(TEST_FILTER); + assertEquals(neuralQueryBuilder, updatedNeuralQueryBuilder); + assertEquals(TEST_FILTER, neuralQueryBuilder.filter()); + } + @SneakyThrows public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index de3313654..58c692190 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -176,6 +176,50 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { ); } + /** + * Tests a basic neural query with a match query applied as filter: + * { + * "query": { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 3, + * "boost": 2.0 + * }, + * "filter": { + * "match": { + * "_id": { + * "query": "3" + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testQueryWithBoostAndFilterApplied() { + String modelId = null; + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + modelId = prepareModel(); + NeuralQueryBuilder
neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .k(3) + .build(); + + // Test with a Filter Applied + neuralQueryBuilder.filter(new MatchQueryBuilder("_id", "3")); + Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, neuralQueryBuilder, 3); + assertEquals(1, getHitCount(searchResponseAsMap)); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("3", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA_FOR_SCORE_ASSERTION); + } + /** * Tests rescore query: * {