Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filter function for NeuralQueryBuilder and HybridQueryBuilder and… #1206

Merged
merged 1 commit into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
### Features
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.index.query.QueryBuilder;

public class HybridSearchIT extends AbstractRestartUpgradeRestTestCase {
private static final String PIPELINE_NAME = "nlp-hybrid-pipeline";
Expand Down Expand Up @@ -72,9 +73,15 @@ private void validateNormalizationProcessor(final String fileName, final String
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
hybridQueryBuilder = getQueryBuilder(
modelId,
Boolean.FALSE,
Map.of("ef_search", 100),
RescoreContext.getDefault(),
new MatchQueryBuilder("_id", "5")
);
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
} finally {
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
Expand Down Expand Up @@ -120,7 +127,8 @@ private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Boolean expandNestedDocs,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
final RescoreContext rescoreContext,
final QueryBuilder filter
) {
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName("passage_embedding")
Expand All @@ -144,6 +152,10 @@ private HybridQueryBuilder getQueryBuilder(
hybridQueryBuilder.add(matchQueryBuilder);
hybridQueryBuilder.add(neuralQueryBuilder);

if (filter != null) {
hybridQueryBuilder.filter(filter);
}

return hybridQueryBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountMixed;
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
}
break;
Expand All @@ -81,9 +81,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
hybridQueryBuilder = getQueryBuilder(
modelId,
Boolean.FALSE,
Map.of("ef_search", 100),
RescoreContext.getDefault(),
new MatchQueryBuilder("_id", "2")
);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
Expand Down Expand Up @@ -123,7 +129,8 @@ private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Boolean expandNestedDocs,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContextForNeuralQuery
final RescoreContext rescoreContextForNeuralQuery,
final QueryBuilder filter
) {
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName(VECTOR_EMBEDDING_FIELD)
Expand All @@ -147,6 +154,10 @@ private HybridQueryBuilder getQueryBuilder(
hybridQueryBuilder.add(matchQueryBuilder);
hybridQueryBuilder.add(neuralQueryBuilder);

if (filter != null) {
hybridQueryBuilder.filter(filter);
}

return hybridQueryBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -51,6 +52,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
public static final String NAME = "hybrid";

private static final ParseField QUERIES_FIELD = new ParseField("queries");
private static final ParseField FILTER_FIELD = new ParseField("filter");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new field should also be added to HybridQueryBuilder(StreamInput in) and doXContent(XContentBuilder builder, Params params) so that we will not lose the filter info when we pass it across nodes. Even though the filter info is already included in the sub queries I think we still should follow the best practice to include it in those two functions.

Besides we also want to add it to the doEquals and doHashCode function in case they are used to compare two hybrid queries.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we are missing some test case here unless this filter setting is happening in coordination node.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Bo, quick question while I am trying to follow your comments.

The new field should also be added to HybridQueryBuilder(StreamInput in) and doXContent(XContentBuilder builder, Params params) so that we will lose the filter info.

What do you mean by we will lose the filter info? Could you please give an example?

add it to the doEquals and doHashCode function in case they are used to compare two hybrid queries.

The field I added here is not the filter itself but the filter parser field name. I believe the filter itself is pushed down to queries. So we are automatically comparing the queries and its filter in doEquals and doHashCode. What do you mean by adding it to the doEquals and doHashCode function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For your first question when we pass the query across nodes we will convert it to a stream and later in another node build it based on the stream. If we don't clearly write the filter to the stream we will lose it.

For the second question we override the doEquals and doHashCode function so if we don't clearly check it those two functions will ignore it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Synced offline. Since currently we always want to push down the filter to sub-queries and we always do this work in the fromXContent function there is no need to persist the filter in the HybridQuery. In this case there is no need to write it to the stream since it's already included in the sub-queries. Same for the doEqual and doHash functions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be fine, one concern I do have is some edge case in the context of cluster upgrade. Can you add a bwc test? It may be skipped for now for the 2.x -> 3.x migration, but it will run in the future for the next 3.x, e.g. 3.0 -> 3.1. You can check how it's done in one of the recent PRs for the stats API.

private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth");

private final List<QueryBuilder> queries = new ArrayList<>();
Expand Down Expand Up @@ -94,6 +96,26 @@ public HybridQueryBuilder add(QueryBuilder queryBuilder) {
return this;
}

/**
* Pushes the given filter down to every sub-query in the queries list.
* If the filter does not pass validation (e.g. it is null), the builder
* is returned unchanged.
* @param filter the filter parameter
* @return HybridQueryBuilder itself
*/
public QueryBuilder filter(QueryBuilder filter) {
if (validateFilterParams(filter) == false) {
return this;
}
ListIterator<QueryBuilder> iterator = queries.listIterator();
while (iterator.hasNext()) {
QueryBuilder query = iterator.next();
// set the query again because query.filter(filter) can return new query.
iterator.set(query.filter(filter));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can save few CPU cycles here if we check for a queryWithFilter ref, if that's same as query then there is no need in calling set() method

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@martin-gaievski, If you're okay with it, let's keep this as is. The difference between equals and set is negligible, so a simpler approach would be preferable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, let's keep the set() method, I agree that original version of the code is simpler

}
return this;
}

/**
* Create builder object with a content of this hybrid query
* @param builder
Expand Down Expand Up @@ -155,6 +177,10 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
* }
* }
* ],
* "filter": {
*   "term": {
*     "text": "keyword"
*   }
* }
* }
* }
* }
Expand All @@ -168,6 +194,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx

Integer paginationDepth = null;
final List<QueryBuilder> queries = new ArrayList<>();
QueryBuilder filter = null;
String queryName = null;

String currentFieldName = null;
Expand All @@ -178,6 +205,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
} else if (token == XContentParser.Token.START_OBJECT) {
if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queries.add(parseInnerQueryBuilder(parser));
} else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
filter = parseInnerQueryBuilder(parser);
} else {
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
throw new ParsingException(
Expand Down Expand Up @@ -240,7 +269,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
compoundQueryBuilder.paginationDepth(paginationDepth);
}
for (QueryBuilder query : queries) {
compoundQueryBuilder.add(query);
if (filter == null) {
compoundQueryBuilder.add(query);
} else {
compoundQueryBuilder.add(query.filter(filter));
}
}
return compoundQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,24 @@ protected void doWriteTo(StreamOutput out) throws IOException {
RescoreParser.streamOutput(out, rescoreContext);
}

/**
 * Combines the given filter with this neural query's existing filter.
 * <p>
 * When no filter has been set yet, the supplied filter becomes this query's
 * filter; otherwise the existing filter is merged with the new one via
 * {@code QueryBuilder#filter(QueryBuilder)}. A filter that fails parameter
 * validation (e.g. {@code null}) leaves the builder untouched.
 *
 * @param filterToBeAdded the filter to push into this query
 * @return this builder, possibly with an updated filter
 */
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
    if (validateFilterParams(filterToBeAdded) == false) {
        return this;
    }
    // Merge with any pre-existing filter instead of overwriting it.
    filter = filter == null ? filterToBeAdded : filter.filter(filterToBeAdded);
    return this;
}

@Override
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject(NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand All @@ -82,6 +83,7 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
static final String TEXT_FIELD_NAME = "field";
static final String QUERY_TEXT = "Hello world!";
static final String TERM_QUERY_TEXT = "keyword";
static final String FILTER_TERM_QUERY_TEXT = "filterKeyword";
static final String MODEL_ID = "mfgfgdsfgfdgsde";
static final int K = 10;
static final float BOOST = 1.8f;
Expand Down Expand Up @@ -436,6 +438,121 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
assertEquals(TERM_QUERY_TEXT, termQueryBuilder.value());
}

/**
 * Tests parsing of a hybrid query with a top-level filter:
 * {
 *   "query": {
 *     "hybrid": {
 *       "queries": [
 *         {
 *           "neural": {
 *             "text_knn": {
 *               "query_text": "Hello world",
 *               "model_id": "dcsdcasd",
 *               "k": 1
 *             }
 *           }
 *         },
 *         {
 *           "term": {
 *             "text": "keyword"
 *           }
 *         }
 *       ],
 *       "filter": {
 *         "term": {
 *           "text": "filterKeyword"
 *         }
 *       }
 *     }
 *   }
 * }
 * Verifies the filter is pushed down to each sub-query: the neural query keeps it
 * in its own filter slot, while the term query gets wrapped in a bool query whose
 * filter clause carries it.
 */
@SneakyThrows
public void testFromXContent_whenMultipleSubQueriesAndFilter_thenBuildSuccessfully() {
    setUpClusterService();
    // Build the hybrid query JSON: two sub-queries, a pagination depth, and a top-level filter.
    XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startArray("queries")
        .startObject()
        .startObject(NeuralQueryBuilder.NAME)
        .startObject(VECTOR_FIELD_NAME)
        .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
        .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
        .field(K_FIELD.getPreferredName(), K)
        .field(BOOST_FIELD.getPreferredName(), BOOST)
        .endObject()
        .endObject()
        .endObject()
        .startObject()
        .startObject(TermQueryBuilder.NAME)
        .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT)
        .endObject()
        .endObject()
        .endArray()

        .field("pagination_depth", 10)
        .startObject("filter")
        .startObject(TermQueryBuilder.NAME)
        .field(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT)
        .endObject()
        .endObject()
        .endObject();

    // Register the parsers needed to resolve the inner query types by name.
    NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
        List.of(
            new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
            new NamedXContentRegistry.Entry(
                QueryBuilder.class,
                new ParseField(NeuralQueryBuilder.NAME),
                NeuralQueryBuilder::fromXContent
            ),
            new NamedXContentRegistry.Entry(
                QueryBuilder.class,
                new ParseField(HybridQueryBuilder.NAME),
                HybridQueryBuilder::fromXContent
            )
        )
    );
    XContentParser contentParser = createParser(
        namedXContentRegistry,
        xContentBuilder.contentType().xContent(),
        BytesReference.bytes(xContentBuilder)
    );
    // Position the parser on the first token before handing it to fromXContent.
    contentParser.nextToken();

    HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder.fromXContent(contentParser);
    assertEquals(2, queryTwoSubQueries.queries().size());
    assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder);

    // The term sub-query cannot hold a filter itself, so it is wrapped in a bool
    // query: must = original term query, filter = pushed-down filter.
    assertTrue(queryTwoSubQueries.queries().get(1) instanceof BoolQueryBuilder);
    assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().size());
    assertTrue(((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0) instanceof TermQueryBuilder);
    assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().size());

    assertEquals(10, queryTwoSubQueries.paginationDepth().intValue());
    // verify knn vector query: fields survived parsing and the filter landed in its filter slot
    NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0);
    assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName());
    assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
    assertEquals(K, (int) neuralQueryBuilder.k());
    assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
    assertEquals(BOOST, neuralQueryBuilder.boost(), 0f);
    assertEquals(
        new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
        ((NeuralQueryBuilder) queryTwoSubQueries.queries().get(0)).filter()
    );
    // verify term query
    assertEquals(
        new TermQueryBuilder(TEXT_FIELD_NAME, TERM_QUERY_TEXT),
        ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0)
    );
    assertEquals(
        new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
        ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().get(0)
    );
}

@SneakyThrows
public void testFromXContent_whenIncorrectFormat_thenFail() {
XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -960,6 +1077,25 @@ public void testVisit() {
assertEquals(3, visitedQueries.size());
}

public void testFilter() {
    // Build a hybrid query with one neural and one neural-sparse sub-query.
    NeuralQueryBuilder neuralSubQuery = NeuralQueryBuilder.builder().fieldName("test").queryText("test").build();
    HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
    hybridQueryBuilder.add(neuralSubQuery);
    hybridQueryBuilder.add(new NeuralSparseQueryBuilder());

    // A null filter is a no-op: the builder comes back unchanged.
    QueryBuilder afterNullFilter = hybridQueryBuilder.filter(null);
    assertEquals(hybridQueryBuilder, afterNullFilter);

    // A non-null filter is pushed down into every sub-query; top-level
    // attributes of the hybrid builder stay intact.
    HybridQueryBuilder filteredBuilder = (HybridQueryBuilder) hybridQueryBuilder.filter(new MatchAllQueryBuilder());
    assertEquals(hybridQueryBuilder.queryName(), filteredBuilder.queryName());
    assertEquals(hybridQueryBuilder.paginationDepth(), filteredBuilder.paginationDepth());

    // The neural sub-query carries the filter in its own filter slot.
    NeuralQueryBuilder filteredNeuralQuery = (NeuralQueryBuilder) filteredBuilder.queries().get(0);
    assertEquals(new MatchAllQueryBuilder(), filteredNeuralQuery.filter());

    // The neural-sparse sub-query gets wrapped in a bool query:
    // must = original query, filter = pushed-down filter.
    BoolQueryBuilder wrappedSparseQuery = (BoolQueryBuilder) filteredBuilder.queries().get(1);
    assertEquals(new NeuralSparseQueryBuilder(), wrappedSparseQuery.must().get(0));
    assertEquals(new MatchAllQueryBuilder(), wrappedSparseQuery.filter().get(0));
}

private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
if (!(innerObject instanceof Map)) {
fail("field name does not map to nested object");
Expand Down
Loading
Loading