Skip to content

Commit 3f682e8

Browse files
committed
Add filter function for NeuralQueryBuilder and HybridQueryBuilder and modify fromXContent function in HybridQueryBuilder to support filter field.
Signed-off-by: Chloe Gao <[email protected]>
1 parent 57124dd commit 3f682e8

File tree

11 files changed

+677
-17
lines changed

11 files changed

+677
-17
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
77
### Features
88
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
9+
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
910
### Enhancements
1011
### Bug Fixes
1112
### Infrastructure

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/restart/HybridSearchIT.java

+15-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.opensearch.knn.index.query.rescore.RescoreContext;
2424
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
2525
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
26+
import org.opensearch.index.query.QueryBuilder;
2627

2728
public class HybridSearchIT extends AbstractRestartUpgradeRestTestCase {
2829
private static final String PIPELINE_NAME = "nlp-hybrid-pipeline";
@@ -72,9 +73,15 @@ private void validateNormalizationProcessor(final String fileName, final String
7273
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
7374
loadModel(modelId);
7475
addDocuments(getIndexNameForTest(), false);
75-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
76+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
7677
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
77-
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
78+
hybridQueryBuilder = getQueryBuilder(
79+
modelId,
80+
Boolean.FALSE,
81+
Map.of("ef_search", 100),
82+
RescoreContext.getDefault(),
83+
new MatchQueryBuilder("_id", "5")
84+
);
7885
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
7986
} finally {
8087
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
@@ -120,7 +127,8 @@ private HybridQueryBuilder getQueryBuilder(
120127
final String modelId,
121128
final Boolean expandNestedDocs,
122129
final Map<String, ?> methodParameters,
123-
final RescoreContext rescoreContext
130+
final RescoreContext rescoreContext,
131+
final QueryBuilder filter
124132
) {
125133
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
126134
.fieldName("passage_embedding")
@@ -144,6 +152,10 @@ private HybridQueryBuilder getQueryBuilder(
144152
hybridQueryBuilder.add(matchQueryBuilder);
145153
hybridQueryBuilder.add(neuralQueryBuilder);
146154

155+
if (filter != null) {
156+
hybridQueryBuilder.filter(filter);
157+
}
158+
147159
return hybridQueryBuilder;
148160
}
149161

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/restart/HybridSearchWithRescoreIT.java

+13-3
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,15 @@ public void testHybridQueryWithRescore_whenIndexWithMultipleShards_E2EFlow() thr
6464
modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR);
6565
loadModel(modelId);
6666
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_UPGRADED, null, null);
67-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
67+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
6868
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
6969
validateTestIndex(getIndexNameForTest(), hybridQueryBuilder, rescorer);
70-
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
70+
hybridQueryBuilder = getQueryBuilder(
71+
modelId,
72+
Map.of("ef_search", 100),
73+
RescoreContext.getDefault(),
74+
new MatchQueryBuilder("_id", "1")
75+
);
7176
validateTestIndex(getIndexNameForTest(), hybridQueryBuilder, rescorer);
7277
} finally {
7378
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, null);
@@ -91,7 +96,8 @@ private void validateTestIndex(final String index, HybridQueryBuilder queryBuild
9196
private HybridQueryBuilder getQueryBuilder(
9297
final String modelId,
9398
final Map<String, ?> methodParameters,
94-
final RescoreContext rescoreContextForNeuralQuery
99+
final RescoreContext rescoreContextForNeuralQuery,
100+
final QueryBuilder filter
95101
) {
96102
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
97103
.fieldName(VECTOR_EMBEDDING_FIELD)
@@ -112,6 +118,10 @@ private HybridQueryBuilder getQueryBuilder(
112118
hybridQueryBuilder.add(matchQueryBuilder);
113119
hybridQueryBuilder.add(neuralQueryBuilder);
114120

121+
if (filter != null) {
122+
hybridQueryBuilder.filter(filter);
123+
}
124+
115125
return hybridQueryBuilder;
116126
}
117127
}

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/HybridSearchIT.java

+16-5
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
6666
int totalDocsCountMixed;
6767
if (isFirstMixedRound()) {
6868
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
69-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
69+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
7070
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
7171
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
7272
} else {
7373
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
74-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
74+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
7575
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
7676
}
7777
break;
@@ -81,9 +81,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
8181
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
8282
loadModel(modelId);
8383
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
84-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
84+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
8585
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
86-
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
86+
hybridQueryBuilder = getQueryBuilder(
87+
modelId,
88+
Boolean.FALSE,
89+
Map.of("ef_search", 100),
90+
RescoreContext.getDefault(),
91+
new MatchQueryBuilder("_id", "2")
92+
);
8793
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
8894
} finally {
8995
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
@@ -123,7 +129,8 @@ private HybridQueryBuilder getQueryBuilder(
123129
final String modelId,
124130
final Boolean expandNestedDocs,
125131
final Map<String, ?> methodParameters,
126-
final RescoreContext rescoreContextForNeuralQuery
132+
final RescoreContext rescoreContextForNeuralQuery,
133+
final QueryBuilder filter
127134
) {
128135
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
129136
.fieldName(VECTOR_EMBEDDING_FIELD)
@@ -147,6 +154,10 @@ private HybridQueryBuilder getQueryBuilder(
147154
hybridQueryBuilder.add(matchQueryBuilder);
148155
hybridQueryBuilder.add(neuralQueryBuilder);
149156

157+
if (filter != null) {
158+
hybridQueryBuilder.filter(filter);
159+
}
160+
150161
return hybridQueryBuilder;
151162
}
152163
}

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/HybridSearchWithRescoreIT.java

+15-5
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ public void testHybridQueryWithRescore_whenIndexWithMultipleShards_E2EFlow() thr
6767
int totalDocsCountMixed;
6868
if (isFirstMixedRound()) {
6969
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
70-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
70+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
7171
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
7272
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer);
7373
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
7474
} else {
7575
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
76-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
76+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
7777
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
7878
}
7979
break;
@@ -83,10 +83,15 @@ public void testHybridQueryWithRescore_whenIndexWithMultipleShards_E2EFlow() thr
8383
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
8484
loadModel(modelId);
8585
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
86-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
86+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
8787
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
8888
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
89-
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
89+
hybridQueryBuilder = getQueryBuilder(
90+
modelId,
91+
Map.of("ef_search", 100),
92+
RescoreContext.getDefault(),
93+
new MatchQueryBuilder("_id", "2")
94+
);
9095
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
9196
} finally {
9297
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
@@ -125,7 +130,8 @@ private void validateTestIndexOnUpgrade(
125130
private HybridQueryBuilder getQueryBuilder(
126131
final String modelId,
127132
final Map<String, ?> methodParameters,
128-
final RescoreContext rescoreContextForNeuralQuery
133+
final RescoreContext rescoreContextForNeuralQuery,
134+
final QueryBuilder filter
129135
) {
130136
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
131137
.fieldName(VECTOR_EMBEDDING_FIELD)
@@ -146,6 +152,10 @@ private HybridQueryBuilder getQueryBuilder(
146152
hybridQueryBuilder.add(matchQueryBuilder);
147153
hybridQueryBuilder.add(neuralQueryBuilder);
148154

155+
if (filter != null) {
156+
hybridQueryBuilder.filter(filter);
157+
}
158+
149159
return hybridQueryBuilder;
150160
}
151161
}

src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java

+37-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.ArrayList;
99
import java.util.Collection;
1010
import java.util.List;
11+
import java.util.ListIterator;
1112
import java.util.Locale;
1213
import java.util.Objects;
1314
import java.util.stream.Collectors;
@@ -51,6 +52,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
5152
public static final String NAME = "hybrid";
5253

5354
private static final ParseField QUERIES_FIELD = new ParseField("queries");
55+
private static final ParseField FILTER_FIELD = new ParseField("filter");
5456
private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth");
5557

5658
private final List<QueryBuilder> queries = new ArrayList<>();
@@ -94,6 +96,29 @@ public HybridQueryBuilder add(QueryBuilder queryBuilder) {
9496
return this;
9597
}
9698

99+
/**
100+
* Function to support filter on HybridQueryBuilder filter.
101+
* If the filter is null, then we do nothing and return.
102+
* Otherwise, we push down the filter to queries list.
103+
* @param filter the filter parameter
104+
* @return HybridQueryBuilder itself
105+
*/
106+
public QueryBuilder filter(QueryBuilder filter) {
107+
if (validateFilterParams(filter) == false) {
108+
return this;
109+
}
110+
ListIterator<QueryBuilder> iterator = queries.listIterator();
111+
while (iterator.hasNext()) {
112+
final QueryBuilder query = iterator.next();
113+
final QueryBuilder queryWithFilter = query.filter(filter);
114+
if (!queryWithFilter.equals(query)) {
115+
// set the query again because query.filter(filter) can return new query.
116+
iterator.set(query.filter(filter));
117+
}
118+
}
119+
return this;
120+
}
121+
97122
/**
98123
* Create builder object with a content of this hybrid query
99124
* @param builder
@@ -155,6 +180,10 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
155180
* }
156181
* }
157182
* ]
183+
* "filter":
184+
* "term": {
185+
* "text": "keyword"
186+
* }
158187
* }
159188
* }
160189
* }
@@ -168,6 +197,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
168197

169198
Integer paginationDepth = null;
170199
final List<QueryBuilder> queries = new ArrayList<>();
200+
QueryBuilder filter = null;
171201
String queryName = null;
172202

173203
String currentFieldName = null;
@@ -178,6 +208,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
178208
} else if (token == XContentParser.Token.START_OBJECT) {
179209
if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
180210
queries.add(parseInnerQueryBuilder(parser));
211+
} else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
212+
filter = parseInnerQueryBuilder(parser);
181213
} else {
182214
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
183215
throw new ParsingException(
@@ -240,7 +272,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
240272
compoundQueryBuilder.paginationDepth(paginationDepth);
241273
}
242274
for (QueryBuilder query : queries) {
243-
compoundQueryBuilder.add(query);
275+
if (filter == null) {
276+
compoundQueryBuilder.add(query);
277+
} else {
278+
compoundQueryBuilder.add(query.filter(filter));
279+
}
244280
}
245281
return compoundQueryBuilder;
246282
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+18
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,24 @@ protected void doWriteTo(StreamOutput out) throws IOException {
309309
RescoreParser.streamOutput(out, rescoreContext);
310310
}
311311

312+
/**
313+
* Add a filter to Neural Query Builder
314+
* @param filterToBeAdded filter to be added
315+
* @return return itself with underlying filter combined with passed in filter
316+
*/
317+
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
318+
if (validateFilterParams(filterToBeAdded) == false) {
319+
return this;
320+
}
321+
if (filter == null) {
322+
filter = filterToBeAdded;
323+
} else {
324+
filter = filter.filter(filterToBeAdded);
325+
}
326+
return this;
327+
328+
}
329+
312330
@Override
313331
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
314332
xContentBuilder.startObject(NAME);

0 commit comments

Comments
 (0)