Skip to content

Commit 8abb418

Browse files
authored
Add filter function for NeuralQueryBuilder and HybridQueryBuilder and modify fromXContent function in HybridQueryBuilder to support filter field. (#1206)
Signed-off-by: Chloe Gao <[email protected]>
1 parent 6163f67 commit 8abb418

File tree

9 files changed

+525
-9
lines changed

9 files changed

+525
-9
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
77
### Features
88
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
9+
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
910
### Enhancements
1011
### Bug Fixes
1112
### Infrastructure

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/restart/HybridSearchIT.java

+15-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.opensearch.knn.index.query.rescore.RescoreContext;
2424
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
2525
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
26+
import org.opensearch.index.query.QueryBuilder;
2627

2728
public class HybridSearchIT extends AbstractRestartUpgradeRestTestCase {
2829
private static final String PIPELINE_NAME = "nlp-hybrid-pipeline";
@@ -72,9 +73,15 @@ private void validateNormalizationProcessor(final String fileName, final String
7273
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
7374
loadModel(modelId);
7475
addDocuments(getIndexNameForTest(), false);
75-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
76+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
7677
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
77-
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
78+
hybridQueryBuilder = getQueryBuilder(
79+
modelId,
80+
Boolean.FALSE,
81+
Map.of("ef_search", 100),
82+
RescoreContext.getDefault(),
83+
new MatchQueryBuilder("_id", "5")
84+
);
7885
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
7986
} finally {
8087
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
@@ -120,7 +127,8 @@ private HybridQueryBuilder getQueryBuilder(
120127
final String modelId,
121128
final Boolean expandNestedDocs,
122129
final Map<String, ?> methodParameters,
123-
final RescoreContext rescoreContext
130+
final RescoreContext rescoreContext,
131+
final QueryBuilder filter
124132
) {
125133
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
126134
.fieldName("passage_embedding")
@@ -144,6 +152,10 @@ private HybridQueryBuilder getQueryBuilder(
144152
hybridQueryBuilder.add(matchQueryBuilder);
145153
hybridQueryBuilder.add(neuralQueryBuilder);
146154

155+
if (filter != null) {
156+
hybridQueryBuilder.filter(filter);
157+
}
158+
147159
return hybridQueryBuilder;
148160
}
149161

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/HybridSearchIT.java

+16-5
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
6666
int totalDocsCountMixed;
6767
if (isFirstMixedRound()) {
6868
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
69-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
69+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
7070
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
7171
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
7272
} else {
7373
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
74-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
74+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
7575
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
7676
}
7777
break;
@@ -81,9 +81,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
8181
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
8282
loadModel(modelId);
8383
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
84-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
84+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
8585
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
86-
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
86+
hybridQueryBuilder = getQueryBuilder(
87+
modelId,
88+
Boolean.FALSE,
89+
Map.of("ef_search", 100),
90+
RescoreContext.getDefault(),
91+
new MatchQueryBuilder("_id", "2")
92+
);
8793
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
8894
} finally {
8995
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
@@ -123,7 +129,8 @@ private HybridQueryBuilder getQueryBuilder(
123129
final String modelId,
124130
final Boolean expandNestedDocs,
125131
final Map<String, ?> methodParameters,
126-
final RescoreContext rescoreContextForNeuralQuery
132+
final RescoreContext rescoreContextForNeuralQuery,
133+
final QueryBuilder filter
127134
) {
128135
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
129136
.fieldName(VECTOR_EMBEDDING_FIELD)
@@ -147,6 +154,10 @@ private HybridQueryBuilder getQueryBuilder(
147154
hybridQueryBuilder.add(matchQueryBuilder);
148155
hybridQueryBuilder.add(neuralQueryBuilder);
149156

157+
if (filter != null) {
158+
hybridQueryBuilder.filter(filter);
159+
}
160+
150161
return hybridQueryBuilder;
151162
}
152163
}

src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java

+34-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.ArrayList;
99
import java.util.Collection;
1010
import java.util.List;
11+
import java.util.ListIterator;
1112
import java.util.Locale;
1213
import java.util.Objects;
1314
import java.util.stream.Collectors;
@@ -51,6 +52,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
5152
public static final String NAME = "hybrid";
5253

5354
private static final ParseField QUERIES_FIELD = new ParseField("queries");
55+
private static final ParseField FILTER_FIELD = new ParseField("filter");
5456
private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth");
5557

5658
private final List<QueryBuilder> queries = new ArrayList<>();
@@ -94,6 +96,26 @@ public HybridQueryBuilder add(QueryBuilder queryBuilder) {
9496
return this;
9597
}
9698

99+
/**
100+
* Function to support filter on HybridQueryBuilder filter.
101+
* If the filter is null, then we do nothing and return.
102+
* Otherwise, we push down the filter to queries list.
103+
* @param filter the filter parameter
104+
* @return HybridQueryBuilder itself
105+
*/
106+
public QueryBuilder filter(QueryBuilder filter) {
107+
if (validateFilterParams(filter) == false) {
108+
return this;
109+
}
110+
ListIterator<QueryBuilder> iterator = queries.listIterator();
111+
while (iterator.hasNext()) {
112+
QueryBuilder query = iterator.next();
113+
// set the query again because query.filter(filter) can return new query.
114+
iterator.set(query.filter(filter));
115+
}
116+
return this;
117+
}
118+
97119
/**
98120
* Create builder object with a content of this hybrid query
99121
* @param builder
@@ -155,6 +177,10 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
155177
* }
156178
* }
157179
* ]
180+
* "filter":
181+
* "term": {
182+
* "text": "keyword"
183+
* }
158184
* }
159185
* }
160186
* }
@@ -168,6 +194,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
168194

169195
Integer paginationDepth = null;
170196
final List<QueryBuilder> queries = new ArrayList<>();
197+
QueryBuilder filter = null;
171198
String queryName = null;
172199

173200
String currentFieldName = null;
@@ -178,6 +205,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
178205
} else if (token == XContentParser.Token.START_OBJECT) {
179206
if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
180207
queries.add(parseInnerQueryBuilder(parser));
208+
} else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
209+
filter = parseInnerQueryBuilder(parser);
181210
} else {
182211
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
183212
throw new ParsingException(
@@ -240,7 +269,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
240269
compoundQueryBuilder.paginationDepth(paginationDepth);
241270
}
242271
for (QueryBuilder query : queries) {
243-
compoundQueryBuilder.add(query);
272+
if (filter == null) {
273+
compoundQueryBuilder.add(query);
274+
} else {
275+
compoundQueryBuilder.add(query.filter(filter));
276+
}
244277
}
245278
return compoundQueryBuilder;
246279
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+18
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,24 @@ protected void doWriteTo(StreamOutput out) throws IOException {
309309
RescoreParser.streamOutput(out, rescoreContext);
310310
}
311311

312+
/**
313+
* Add a filter to Neural Query Builder
314+
* @param filterToBeAdded filter to be added
315+
* @return return itself with underlying filter combined with passed in filter
316+
*/
317+
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
318+
if (validateFilterParams(filterToBeAdded) == false) {
319+
return this;
320+
}
321+
if (filter == null) {
322+
filter = filterToBeAdded;
323+
} else {
324+
filter = filter.filter(filterToBeAdded);
325+
}
326+
return this;
327+
328+
}
329+
312330
@Override
313331
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
314332
xContentBuilder.startObject(NAME);

src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java

+136
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.opensearch.index.IndexSettings;
5757
import org.opensearch.index.mapper.MappedFieldType;
5858
import org.opensearch.index.mapper.TextFieldMapper;
59+
import org.opensearch.index.query.BoolQueryBuilder;
5960
import org.opensearch.index.query.MatchAllQueryBuilder;
6061
import org.opensearch.index.query.QueryBuilder;
6162
import org.opensearch.index.query.QueryBuilders;
@@ -82,6 +83,7 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
8283
static final String TEXT_FIELD_NAME = "field";
8384
static final String QUERY_TEXT = "Hello world!";
8485
static final String TERM_QUERY_TEXT = "keyword";
86+
static final String FILTER_TERM_QUERY_TEXT = "filterKeyword";
8587
static final String MODEL_ID = "mfgfgdsfgfdgsde";
8688
static final int K = 10;
8789
static final float BOOST = 1.8f;
@@ -436,6 +438,121 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
436438
assertEquals(TERM_QUERY_TEXT, termQueryBuilder.value());
437439
}
438440

441+
/**
442+
* Tests basic query:
443+
* {
444+
* "query": {
445+
* "hybrid": {
446+
* "queries": [
447+
* {
448+
* "neural": {
449+
* "text_knn": {
450+
* "query_text": "Hello world",
451+
* "model_id": "dcsdcasd",
452+
* "k": 1
453+
* }
454+
* }
455+
* },
456+
* {
457+
* "term": {
458+
* "text": "keyword"
459+
* }
460+
* }
461+
* ]
462+
* "filter": {
463+
* "term": {
464+
* "text": "filterKeyword"
465+
* }
466+
* }
467+
* }
468+
* }
469+
* }
470+
*/
471+
@SneakyThrows
472+
public void testFromXContent_whenMultipleSubQueriesAndFilter_thenBuildSuccessfully() {
473+
setUpClusterService();
474+
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
475+
.startObject()
476+
.startArray("queries")
477+
.startObject()
478+
.startObject(NeuralQueryBuilder.NAME)
479+
.startObject(VECTOR_FIELD_NAME)
480+
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
481+
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
482+
.field(K_FIELD.getPreferredName(), K)
483+
.field(BOOST_FIELD.getPreferredName(), BOOST)
484+
.endObject()
485+
.endObject()
486+
.endObject()
487+
.startObject()
488+
.startObject(TermQueryBuilder.NAME)
489+
.field(TEXT_FIELD_NAME, TERM_QUERY_TEXT)
490+
.endObject()
491+
.endObject()
492+
.endArray()
493+
494+
.field("pagination_depth", 10)
495+
.startObject("filter")
496+
.startObject(TermQueryBuilder.NAME)
497+
.field(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT)
498+
.endObject()
499+
.endObject()
500+
.endObject();
501+
502+
NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
503+
List.of(
504+
new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
505+
new NamedXContentRegistry.Entry(
506+
QueryBuilder.class,
507+
new ParseField(NeuralQueryBuilder.NAME),
508+
NeuralQueryBuilder::fromXContent
509+
),
510+
new NamedXContentRegistry.Entry(
511+
QueryBuilder.class,
512+
new ParseField(HybridQueryBuilder.NAME),
513+
HybridQueryBuilder::fromXContent
514+
)
515+
)
516+
);
517+
XContentParser contentParser = createParser(
518+
namedXContentRegistry,
519+
xContentBuilder.contentType().xContent(),
520+
BytesReference.bytes(xContentBuilder)
521+
);
522+
contentParser.nextToken();
523+
524+
HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder.fromXContent(contentParser);
525+
assertEquals(2, queryTwoSubQueries.queries().size());
526+
assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder);
527+
528+
assertTrue(queryTwoSubQueries.queries().get(1) instanceof BoolQueryBuilder);
529+
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().size());
530+
assertTrue(((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0) instanceof TermQueryBuilder);
531+
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().size());
532+
533+
assertEquals(10, queryTwoSubQueries.paginationDepth().intValue());
534+
// verify knn vector query
535+
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0);
536+
assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName());
537+
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
538+
assertEquals(K, (int) neuralQueryBuilder.k());
539+
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
540+
assertEquals(BOOST, neuralQueryBuilder.boost(), 0f);
541+
assertEquals(
542+
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
543+
((NeuralQueryBuilder) queryTwoSubQueries.queries().get(0)).filter()
544+
);
545+
// verify term query
546+
assertEquals(
547+
new TermQueryBuilder(TEXT_FIELD_NAME, TERM_QUERY_TEXT),
548+
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0)
549+
);
550+
assertEquals(
551+
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
552+
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().get(0)
553+
);
554+
}
555+
439556
@SneakyThrows
440557
public void testFromXContent_whenIncorrectFormat_thenFail() {
441558
XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder()
@@ -960,6 +1077,25 @@ public void testVisit() {
9601077
assertEquals(3, visitedQueries.size());
9611078
}
9621079

1080+
public void testFilter() {
1081+
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(
1082+
NeuralQueryBuilder.builder().fieldName("test").queryText("test").build()
1083+
).add(new NeuralSparseQueryBuilder());
1084+
// Test for Null filter Case
1085+
QueryBuilder queryBuilder = hybridQueryBuilder.filter(null);
1086+
assertEquals(queryBuilder, hybridQueryBuilder);
1087+
1088+
// Test for Non-Null filter case and assert every field as expected
1089+
HybridQueryBuilder updatedHybridQueryBuilder = (HybridQueryBuilder) hybridQueryBuilder.filter(new MatchAllQueryBuilder());
1090+
assertEquals(updatedHybridQueryBuilder.queryName(), hybridQueryBuilder.queryName());
1091+
assertEquals(updatedHybridQueryBuilder.paginationDepth(), hybridQueryBuilder.paginationDepth());
1092+
NeuralQueryBuilder updatedNeuralQueryBuilder = (NeuralQueryBuilder) updatedHybridQueryBuilder.queries().get(0);
1093+
assertEquals(new MatchAllQueryBuilder(), updatedNeuralQueryBuilder.filter());
1094+
BoolQueryBuilder updatedNeuralSparseQueryBuilder = (BoolQueryBuilder) updatedHybridQueryBuilder.queries().get(1);
1095+
assertEquals(new NeuralSparseQueryBuilder(), updatedNeuralSparseQueryBuilder.must().get(0));
1096+
assertEquals(new MatchAllQueryBuilder(), updatedNeuralSparseQueryBuilder.filter().get(0));
1097+
}
1098+
9631099
private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
9641100
if (!(innerObject instanceof Map)) {
9651101
fail("field name does not map to nested object");

0 commit comments

Comments
 (0)