Skip to content

Commit d05a418

Browse files
committed
Add filter function for NeuralQueryBuilder and HybridQueryBuilder and modify fromXContent function in HybridQueryBuilder to support filter field.
Signed-off-by: Chloe Gao <[email protected]>
1 parent da5eebb commit d05a418

File tree

7 files changed

+417
-2
lines changed

7 files changed

+417
-2
lines changed

src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java

+38-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
5151
public static final String NAME = "hybrid";
5252

5353
private static final ParseField QUERIES_FIELD = new ParseField("queries");
54+
private static final ParseField FILTER_FIELD = new ParseField("filter");
5455
private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth");
5556

5657
private final List<QueryBuilder> queries = new ArrayList<>();
@@ -94,6 +95,31 @@ public HybridQueryBuilder add(QueryBuilder queryBuilder) {
9495
return this;
9596
}
9697

98+
/**
99+
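* Applies a filter to this hybrid query by pushing it down to each of its sub-queries. If the filter
100+
* fails validation (for example, when it is null), this builder is returned unchanged. Otherwise a new
101+
* HybridQueryBuilder is created with the same query name, boost and pagination depth, and the filter is applied
102+
* to every sub-query. A nested HybridQueryBuilder sub-query is not supported and causes an UnsupportedOperationException.
103+
* @param filter the filter to push down to every sub-query
104+
* @return this builder if the filter is not valid, otherwise a new HybridQueryBuilder whose sub-queries carry the filter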
105+
*/
106+
public QueryBuilder filter(QueryBuilder filter) {
107+
if (!validateFilterParams(filter)) {
108+
return this;
109+
}
110+
HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder();
111+
compoundQueryBuilder.queryName(queryName);
112+
compoundQueryBuilder.boost(boost);
113+
compoundQueryBuilder.paginationDepth(this.paginationDepth);
114+
for (QueryBuilder query : queries) {
115+
if (query instanceof HybridQueryBuilder) {
116+
throw new UnsupportedOperationException("Cannot push filter to nested hybridQueryBuilder");
117+
}
118+
compoundQueryBuilder.add(query.filter(filter));
119+
}
120+
return compoundQueryBuilder;
121+
}
122+
97123
/**
98124
* Create builder object with a content of this hybrid query
99125
* @param builder
@@ -155,6 +181,10 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
155181
* }
156182
* }
157183
* ],
184+
* "filter": {
185+
* "term": {
186+
* "text": "keyword"
187+
* } }
158188
* }
159189
* }
160190
* }
@@ -168,6 +198,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
168198

169199
Integer paginationDepth = null;
170200
final List<QueryBuilder> queries = new ArrayList<>();
201+
QueryBuilder filter = null;
171202
String queryName = null;
172203

173204
String currentFieldName = null;
@@ -178,6 +209,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
178209
} else if (token == XContentParser.Token.START_OBJECT) {
179210
if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
180211
queries.add(parseInnerQueryBuilder(parser));
212+
} else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
213+
filter = parseInnerQueryBuilder(parser);
181214
} else {
182215
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
183216
throw new ParsingException(
@@ -240,7 +273,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
240273
compoundQueryBuilder.paginationDepth(paginationDepth);
241274
}
242275
for (QueryBuilder query : queries) {
243-
compoundQueryBuilder.add(query);
276+
if (filter == null) {
277+
compoundQueryBuilder.add(query);
278+
} else {
279+
compoundQueryBuilder.add(query.filter(filter));
280+
}
244281
}
245282
return compoundQueryBuilder;
246283
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+17
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,23 @@ protected void doWriteTo(StreamOutput out) throws IOException {
309309
RescoreParser.streamOutput(out, rescoreContext);
310310
}
311311

312+
/**
313+
* Add a filter to Neural Query Builder
314+
* @param filterToBeAdded filter to be added
315+
* @return this builder, with its existing filter (if any) combined with the passed-in filter
316+
*/
317+
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
318+
if (validateFilterParams(filterToBeAdded) == false) {
319+
return this;
320+
}
321+
if (filter == null) {
322+
filter = filterToBeAdded;
323+
return this;
324+
}
325+
filter = filter.filter(filterToBeAdded);
326+
return this;
327+
}
328+
312329
@Override
313330
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
314331
xContentBuilder.startObject(NAME);

src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java

+140
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.opensearch.index.IndexSettings;
5757
import org.opensearch.index.mapper.MappedFieldType;
5858
import org.opensearch.index.mapper.TextFieldMapper;
59+
import org.opensearch.index.query.BoolQueryBuilder;
5960
import org.opensearch.index.query.MatchAllQueryBuilder;
6061
import org.opensearch.index.query.QueryBuilder;
6162
import org.opensearch.index.query.QueryBuilders;
@@ -82,6 +83,7 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
8283
static final String TEXT_FIELD_NAME = "field";
8384
static final String QUERY_TEXT = "Hello world!";
8485
static final String TERM_QUERY_TEXT = "keyword";
86+
static final String FILTER_TERM_QUERY_TEXT = "filterKeyword";
8587
static final String MODEL_ID = "mfgfgdsfgfdgsde";
8688
static final int K = 10;
8789
static final float BOOST = 1.8f;
@@ -436,6 +438,121 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
436438
assertEquals(TERM_QUERY_TEXT, termQueryBuilder.value());
437439
}
438440

441+
/**
442+
* Tests basic query:
443+
* {
444+
* "query": {
445+
* "hybrid": {
446+
* "queries": [
447+
* {
448+
* "neural": {
449+
* "text_knn": {
450+
* "query_text": "Hello world",
451+
* "model_id": "dcsdcasd",
452+
* "k": 1
453+
* }
454+
* }
455+
* },
456+
* {
457+
* "term": {
458+
* "text": "keyword"
459+
* }
460+
* }
461+
* ],
462+
* "filter": {
463+
* "term": {
464+
* "text": "filterKeyword"
465+
* }
466+
* }
467+
* }
468+
* }
469+
* }
470+
*/
471+
@SneakyThrows
472+
public void testFromXContent_whenMultipleSubQueriesAndFilter_thenBuildSuccessfully() {
473+
setUpClusterService();
474+
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
475+
.startObject()
476+
.startArray("queries")
477+
.startObject()
478+
.startObject(NeuralQueryBuilder.NAME)
479+
.startObject(VECTOR_FIELD_NAME)
480+
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
481+
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
482+
.field(K_FIELD.getPreferredName(), K)
483+
.field(BOOST_FIELD.getPreferredName(), BOOST)
484+
.endObject()
485+
.endObject()
486+
.endObject()
487+
.startObject()
488+
.startObject(TermQueryBuilder.NAME)
489+
.field(TEXT_FIELD_NAME, TERM_QUERY_TEXT)
490+
.endObject()
491+
.endObject()
492+
.endArray()
493+
494+
.field("pagination_depth", 10)
495+
.startObject("filter")
496+
.startObject(TermQueryBuilder.NAME)
497+
.field(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT)
498+
.endObject()
499+
.endObject()
500+
.endObject();
501+
502+
NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
503+
List.of(
504+
new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
505+
new NamedXContentRegistry.Entry(
506+
QueryBuilder.class,
507+
new ParseField(NeuralQueryBuilder.NAME),
508+
NeuralQueryBuilder::fromXContent
509+
),
510+
new NamedXContentRegistry.Entry(
511+
QueryBuilder.class,
512+
new ParseField(HybridQueryBuilder.NAME),
513+
HybridQueryBuilder::fromXContent
514+
)
515+
)
516+
);
517+
XContentParser contentParser = createParser(
518+
namedXContentRegistry,
519+
xContentBuilder.contentType().xContent(),
520+
BytesReference.bytes(xContentBuilder)
521+
);
522+
contentParser.nextToken();
523+
524+
HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder.fromXContent(contentParser);
525+
assertEquals(2, queryTwoSubQueries.queries().size());
526+
assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder);
527+
528+
assertTrue(queryTwoSubQueries.queries().get(1) instanceof BoolQueryBuilder);
529+
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().size());
530+
assertTrue(((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0) instanceof TermQueryBuilder);
531+
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().size());
532+
533+
assertEquals(10, queryTwoSubQueries.paginationDepth().intValue());
534+
// verify knn vector query
535+
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0);
536+
assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName());
537+
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
538+
assertEquals(K, (int) neuralQueryBuilder.k());
539+
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
540+
assertEquals(BOOST, neuralQueryBuilder.boost(), 0f);
541+
assertEquals(
542+
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
543+
((NeuralQueryBuilder) queryTwoSubQueries.queries().get(0)).filter()
544+
);
545+
// verify term query
546+
assertEquals(
547+
new TermQueryBuilder(TEXT_FIELD_NAME, TERM_QUERY_TEXT),
548+
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0)
549+
);
550+
assertEquals(
551+
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
552+
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().get(0)
553+
);
554+
}
555+
439556
@SneakyThrows
440557
public void testFromXContent_whenIncorrectFormat_thenFail() {
441558
XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder()
@@ -960,6 +1077,29 @@ public void testVisit() {
9601077
assertEquals(3, visitedQueries.size());
9611078
}
9621079

1080+
public void testFilter() {
1081+
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(
1082+
NeuralQueryBuilder.builder().fieldName("test").queryText("test").build()
1083+
).add(new NeuralSparseQueryBuilder());
1084+
// Test for Null filter Case
1085+
QueryBuilder queryBuilder = hybridQueryBuilder.filter(null);
1086+
assertEquals(queryBuilder, hybridQueryBuilder);
1087+
1088+
// Test for Non-Null filter case and assert every field as expected
1089+
HybridQueryBuilder updatedHybridQueryBuilder = (HybridQueryBuilder) hybridQueryBuilder.filter(new MatchAllQueryBuilder());
1090+
assertEquals(updatedHybridQueryBuilder.queryName(), hybridQueryBuilder.queryName());
1091+
assertEquals(updatedHybridQueryBuilder.paginationDepth(), hybridQueryBuilder.paginationDepth());
1092+
NeuralQueryBuilder updatedNeuralQueryBuilder = (NeuralQueryBuilder) updatedHybridQueryBuilder.queries().get(0);
1093+
assertEquals(new MatchAllQueryBuilder(), updatedNeuralQueryBuilder.filter());
1094+
BoolQueryBuilder updatedNeuralSparseQueryBuilder = (BoolQueryBuilder) updatedHybridQueryBuilder.queries().get(1);
1095+
assertEquals(new NeuralSparseQueryBuilder(), updatedNeuralSparseQueryBuilder.must().get(0));
1096+
assertEquals(new MatchAllQueryBuilder(), updatedNeuralSparseQueryBuilder.filter().get(0));
1097+
1098+
// Test that a non-null filter throws UnsupportedOperationException when a nested HybridQueryBuilder is present
1099+
updatedHybridQueryBuilder.add(new HybridQueryBuilder());
1100+
assertThrows(UnsupportedOperationException.class, () -> updatedHybridQueryBuilder.filter(new MatchAllQueryBuilder()));
1101+
}
1102+
9631103
private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
9641104
if (!(innerObject instanceof Map)) {
9651105
fail("field name does not map to nested object");

src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java

+94
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_the
7474
testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD);
7575
}
7676

77+
@SneakyThrows
78+
public void testFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_thenSuccessful() {
79+
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
80+
prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER);
81+
testRangeQueryAsFilter(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD);
82+
}
83+
7784
@SneakyThrows
7885
public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_thenSuccessful() {
7986
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
@@ -167,6 +174,93 @@ private void testPostFilterRangeQuery(String indexName) {
167174
assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
168175
}
169176

177+
/**{
178+
"query": {
179+
"hybrid":{
180+
"queries":[
181+
"bool":{
182+
"must": [
183+
"match": {
184+
"name": "mission"
185+
},
186+
],
187+
"filter": [
188+
"range": {
189+
"stock": {
190+
"gte": 230,
191+
"lte": 400
192+
}
193+
}
194+
]
195+
},
196+
"bool":{
197+
"must": [
198+
"term": {
199+
"name": {"value": "part"}
200+
},
201+
],
202+
"filter": [
203+
"range": {
204+
"stock": {
205+
"gte": 230,
206+
"lte": 400
207+
}
208+
}
209+
]
210+
},
211+
212+
"bool":{
213+
"must": [
214+
"range": {
215+
"stock": {
216+
"gte": 200,
217+
"lte": 400
218+
}
219+
}
220+
],
221+
"filter": [
222+
"range": {
223+
"stock": {
224+
"gte": 230,
225+
"lte": 400
226+
}
227+
}
228+
]
229+
}
230+
]
231+
}
232+
}
233+
}*/
234+
@SneakyThrows
235+
private void testRangeQueryAsFilter(String indexName) {
236+
QueryBuilder postFilterQuery = createQueryBuilderWithRangeQuery(
237+
LTE_OF_RANGE_IN_POST_FILTER_QUERY,
238+
GTE_OF_RANGE_IN_POST_FILTER_QUERY
239+
);
240+
HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
241+
"mission",
242+
"part",
243+
LTE_OF_RANGE_IN_HYBRID_QUERY,
244+
GTE_OF_RANGE_IN_HYBRID_QUERY
245+
);
246+
hybridQueryBuilder = (HybridQueryBuilder) hybridQueryBuilder.filter(postFilterQuery);
247+
248+
Map<String, Object> searchResponseAsMap = search(
249+
indexName,
250+
hybridQueryBuilder,
251+
null,
252+
10,
253+
Map.of("search_pipeline", SEARCH_PIPELINE),
254+
null,
255+
null,
256+
null,
257+
false,
258+
null,
259+
0
260+
);
261+
assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
262+
}
263+
170264
/*{
171265
"query": {
172266
"hybrid":{

0 commit comments

Comments
 (0)