Skip to content

Commit 0ad8bb9

Browse files
committed
Add filter function for NeuralQueryBuilder and HybridQueryBuilder and modify fromXContent function in HybridQueryBuilder to support filter field.
Signed-off-by: Chloe Gao <[email protected]>
1 parent 57124dd commit 0ad8bb9

File tree

8 files changed

+736
-1
lines changed

8 files changed

+736
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
77
### Features
88
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
9+
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
910
### Enhancements
1011
### Bug Fixes
1112
### Infrastructure

src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java

+38-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.ArrayList;
99
import java.util.Collection;
1010
import java.util.List;
11+
import java.util.ListIterator;
1112
import java.util.Locale;
1213
import java.util.Objects;
1314
import java.util.stream.Collectors;
@@ -51,6 +52,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
5152
public static final String NAME = "hybrid";
5253

5354
private static final ParseField QUERIES_FIELD = new ParseField("queries");
55+
private static final ParseField FILTER_FIELD = new ParseField("filter");
5456
private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth");
5557

5658
private final List<QueryBuilder> queries = new ArrayList<>();
@@ -94,6 +96,30 @@ public HybridQueryBuilder add(QueryBuilder queryBuilder) {
9496
return this;
9597
}
9698

99+
/**
100+
* Function to support filter on HybridQueryBuilder filter. Currently pushing down a filter
101+
* to HybridQueryBuilder is not supported by design. We would simply check if the filter is valid
102+
* and throw exception telling this is an unsupported operation. If the filter is null, then we do nothing and
103+
* return.
104+
* @param filter the filter parameter
105+
* @return HybridQueryBuilder itself
106+
*/
107+
public QueryBuilder filter(QueryBuilder filter) {
108+
if (validateFilterParams(filter) == false) {
109+
return this;
110+
}
111+
ListIterator<QueryBuilder> iterator = queries.listIterator();
112+
while(iterator.hasNext()) {
113+
QueryBuilder query = iterator.next();
114+
if (query instanceof HybridQueryBuilder) {
115+
throw new UnsupportedOperationException("Cannot push filter to nested hybridQueryBuilder");
116+
}
117+
// set the query again because query.filter(filter) can return new query.
118+
iterator.set(query.filter(filter));
119+
}
120+
return this;
121+
}
122+
97123
/**
98124
* Create builder object with a content of this hybrid query
99125
* @param builder
@@ -155,6 +181,10 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
155181
* }
156182
* }
157183
* ]
184+
* "filter":
185+
* "term": {
186+
* "text": "keyword"
187+
* }
158188
* }
159189
* }
160190
* }
@@ -168,6 +198,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
168198

169199
Integer paginationDepth = null;
170200
final List<QueryBuilder> queries = new ArrayList<>();
201+
QueryBuilder filter = null;
171202
String queryName = null;
172203

173204
String currentFieldName = null;
@@ -178,6 +209,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
178209
} else if (token == XContentParser.Token.START_OBJECT) {
179210
if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
180211
queries.add(parseInnerQueryBuilder(parser));
212+
} else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
213+
filter = parseInnerQueryBuilder(parser);
181214
} else {
182215
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
183216
throw new ParsingException(
@@ -240,7 +273,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
240273
compoundQueryBuilder.paginationDepth(paginationDepth);
241274
}
242275
for (QueryBuilder query : queries) {
243-
compoundQueryBuilder.add(query);
276+
if (filter == null) {
277+
compoundQueryBuilder.add(query);
278+
} else {
279+
compoundQueryBuilder.add(query.filter(filter));
280+
}
244281
}
245282
return compoundQueryBuilder;
246283
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+17
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,23 @@ protected void doWriteTo(StreamOutput out) throws IOException {
309309
RescoreParser.streamOutput(out, rescoreContext);
310310
}
311311

312+
/**
313+
* Add a filter to Neural Query Builder
314+
* @param filterToBeAdded filter to be added
315+
* @return return itself with underlying filter combined with passed in filter
316+
*/
317+
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
318+
if (validateFilterParams(filterToBeAdded) == false) {
319+
return this;
320+
}
321+
if (filter == null) {
322+
filter = filterToBeAdded;
323+
return this;
324+
}
325+
filter = filter.filter(filterToBeAdded);
326+
return this;
327+
}
328+
312329
@Override
313330
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
314331
xContentBuilder.startObject(NAME);

src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java

+140
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.opensearch.index.IndexSettings;
5757
import org.opensearch.index.mapper.MappedFieldType;
5858
import org.opensearch.index.mapper.TextFieldMapper;
59+
import org.opensearch.index.query.BoolQueryBuilder;
5960
import org.opensearch.index.query.MatchAllQueryBuilder;
6061
import org.opensearch.index.query.QueryBuilder;
6162
import org.opensearch.index.query.QueryBuilders;
@@ -82,6 +83,7 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
8283
static final String TEXT_FIELD_NAME = "field";
8384
static final String QUERY_TEXT = "Hello world!";
8485
static final String TERM_QUERY_TEXT = "keyword";
86+
static final String FILTER_TERM_QUERY_TEXT = "filterKeyword";
8587
static final String MODEL_ID = "mfgfgdsfgfdgsde";
8688
static final int K = 10;
8789
static final float BOOST = 1.8f;
@@ -436,6 +438,121 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
436438
assertEquals(TERM_QUERY_TEXT, termQueryBuilder.value());
437439
}
438440

441+
/**
442+
* Tests basic query:
443+
* {
444+
* "query": {
445+
* "hybrid": {
446+
* "queries": [
447+
* {
448+
* "neural": {
449+
* "text_knn": {
450+
* "query_text": "Hello world",
451+
* "model_id": "dcsdcasd",
452+
* "k": 1
453+
* }
454+
* }
455+
* },
456+
* {
457+
* "term": {
458+
* "text": "keyword"
459+
* }
460+
* }
461+
* ]
462+
* "filter": {
463+
* "term": {
464+
* "text": "filterKeyword"
465+
* }
466+
* }
467+
* }
468+
* }
469+
* }
470+
*/
471+
@SneakyThrows
472+
public void testFromXContent_whenMultipleSubQueriesAndFilter_thenBuildSuccessfully() {
473+
setUpClusterService();
474+
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
475+
.startObject()
476+
.startArray("queries")
477+
.startObject()
478+
.startObject(NeuralQueryBuilder.NAME)
479+
.startObject(VECTOR_FIELD_NAME)
480+
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
481+
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
482+
.field(K_FIELD.getPreferredName(), K)
483+
.field(BOOST_FIELD.getPreferredName(), BOOST)
484+
.endObject()
485+
.endObject()
486+
.endObject()
487+
.startObject()
488+
.startObject(TermQueryBuilder.NAME)
489+
.field(TEXT_FIELD_NAME, TERM_QUERY_TEXT)
490+
.endObject()
491+
.endObject()
492+
.endArray()
493+
494+
.field("pagination_depth", 10)
495+
.startObject("filter")
496+
.startObject(TermQueryBuilder.NAME)
497+
.field(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT)
498+
.endObject()
499+
.endObject()
500+
.endObject();
501+
502+
NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
503+
List.of(
504+
new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
505+
new NamedXContentRegistry.Entry(
506+
QueryBuilder.class,
507+
new ParseField(NeuralQueryBuilder.NAME),
508+
NeuralQueryBuilder::fromXContent
509+
),
510+
new NamedXContentRegistry.Entry(
511+
QueryBuilder.class,
512+
new ParseField(HybridQueryBuilder.NAME),
513+
HybridQueryBuilder::fromXContent
514+
)
515+
)
516+
);
517+
XContentParser contentParser = createParser(
518+
namedXContentRegistry,
519+
xContentBuilder.contentType().xContent(),
520+
BytesReference.bytes(xContentBuilder)
521+
);
522+
contentParser.nextToken();
523+
524+
HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder.fromXContent(contentParser);
525+
assertEquals(2, queryTwoSubQueries.queries().size());
526+
assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder);
527+
528+
assertTrue(queryTwoSubQueries.queries().get(1) instanceof BoolQueryBuilder);
529+
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().size());
530+
assertTrue(((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0) instanceof TermQueryBuilder);
531+
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().size());
532+
533+
assertEquals(10, queryTwoSubQueries.paginationDepth().intValue());
534+
// verify knn vector query
535+
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0);
536+
assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName());
537+
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
538+
assertEquals(K, (int) neuralQueryBuilder.k());
539+
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
540+
assertEquals(BOOST, neuralQueryBuilder.boost(), 0f);
541+
assertEquals(
542+
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
543+
((NeuralQueryBuilder) queryTwoSubQueries.queries().get(0)).filter()
544+
);
545+
// verify term query
546+
assertEquals(
547+
new TermQueryBuilder(TEXT_FIELD_NAME, TERM_QUERY_TEXT),
548+
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0)
549+
);
550+
assertEquals(
551+
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
552+
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().get(0)
553+
);
554+
}
555+
439556
@SneakyThrows
440557
public void testFromXContent_whenIncorrectFormat_thenFail() {
441558
XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder()
@@ -960,6 +1077,29 @@ public void testVisit() {
9601077
assertEquals(3, visitedQueries.size());
9611078
}
9621079

1080+
public void testFilter() {
1081+
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(
1082+
NeuralQueryBuilder.builder().fieldName("test").queryText("test").build()
1083+
).add(new NeuralSparseQueryBuilder());
1084+
// Test for Null filter Case
1085+
QueryBuilder queryBuilder = hybridQueryBuilder.filter(null);
1086+
assertEquals(queryBuilder, hybridQueryBuilder);
1087+
1088+
// Test for Non-Null filter case and assert every field as expected
1089+
HybridQueryBuilder updatedHybridQueryBuilder = (HybridQueryBuilder) hybridQueryBuilder.filter(new MatchAllQueryBuilder());
1090+
assertEquals(updatedHybridQueryBuilder.queryName(), hybridQueryBuilder.queryName());
1091+
assertEquals(updatedHybridQueryBuilder.paginationDepth(), hybridQueryBuilder.paginationDepth());
1092+
NeuralQueryBuilder updatedNeuralQueryBuilder = (NeuralQueryBuilder) updatedHybridQueryBuilder.queries().get(0);
1093+
assertEquals(new MatchAllQueryBuilder(), updatedNeuralQueryBuilder.filter());
1094+
BoolQueryBuilder updatedNeuralSparseQueryBuilder = (BoolQueryBuilder) updatedHybridQueryBuilder.queries().get(1);
1095+
assertEquals(new NeuralSparseQueryBuilder(), updatedNeuralSparseQueryBuilder.must().get(0));
1096+
assertEquals(new MatchAllQueryBuilder(), updatedNeuralSparseQueryBuilder.filter().get(0));
1097+
1098+
// Test for Non-Null filter case but encountered Nested HybridQueryBuilder to throw Unsupported Exception
1099+
updatedHybridQueryBuilder.add(new HybridQueryBuilder());
1100+
assertThrows(UnsupportedOperationException.class, () -> updatedHybridQueryBuilder.filter(new MatchAllQueryBuilder()));
1101+
}
1102+
9631103
private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
9641104
if (!(innerObject instanceof Map)) {
9651105
fail("field name does not map to nested object");

0 commit comments

Comments
 (0)