Skip to content

Commit 9799c6c

Browse files
authored
Validate Disjunction query in HybridQueryPhaseSearcher (#1127)
* Validate Disjunction query in HybridQueryPhaseSearcher Signed-off-by: Owais <[email protected]>
1 parent c6b8ac4 commit 9799c6c

File tree

3 files changed

+124
-1
lines changed

3 files changed

+124
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2525
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
2626
- Optimize ML inference connection retry logic ([#1054](https://github.com/opensearch-project/neural-search/pull/1054))
2727
- Support for builder constructor in Neural Query Builder ([#1047](https://github.com/opensearch-project/neural-search/pull/1047))
28+
- Validate Disjunction query to avoid having nested hybrid query ([#1127](https://github.com/opensearch-project/neural-search/pull/1127))
2829
### Bug Fixes
2930
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
3031
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))

src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java

+24-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import lombok.NoArgsConstructor;
1515
import org.apache.lucene.search.BooleanClause;
1616
import org.apache.lucene.search.BooleanQuery;
17+
import org.apache.lucene.search.DisjunctionMaxQuery;
1718
import org.apache.lucene.search.Query;
1819
import org.opensearch.common.settings.Settings;
1920
import org.opensearch.index.mapper.MapperService;
@@ -104,7 +105,7 @@ protected Query extractHybridQuery(final SearchContext searchContext, final Quer
104105
* }
105106
* ]
106107
* }
107-
* TODO add similar validation for other compound type queries like dis_max, constant_score etc.
108+
* TODO add similar validation for other compound type queries like constant_score, function_score etc.
108109
*
109110
* @param query query to validate
110111
*/
@@ -114,6 +115,10 @@ private void validateQuery(final SearchContext searchContext, final Query query)
114115
for (BooleanClause booleanClause : booleanClauses) {
115116
validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext));
116117
}
118+
} else if (query instanceof DisjunctionMaxQuery) {
119+
for (Query disjunct : (DisjunctionMaxQuery) query) {
120+
validateNestedDisJunctionQuery(disjunct, getMaxDepthLimit(searchContext));
121+
}
117122
}
118123
}
119124

@@ -135,6 +140,24 @@ private void validateNestedBooleanQuery(final Query query, final int level) {
135140
}
136141
}
137142

143+
private void validateNestedDisJunctionQuery(final Query query, final int level) {
144+
if (query instanceof HybridQuery) {
145+
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
146+
}
147+
if (level <= 0) {
148+
// ideally we should throw an error here but this code is on the main search workflow path and that might block
149+
// execution of some queries. Instead, we're silently exit and allow such query to execute and potentially produce incorrect
150+
// results in case hybrid query is wrapped into such dis_max query
151+
log.error("reached max nested query limit, cannot process dis_max query with that many nested clauses");
152+
return;
153+
}
154+
if (query instanceof DisjunctionMaxQuery) {
155+
for (Query disjunct : (DisjunctionMaxQuery) query) {
156+
validateNestedDisJunctionQuery(disjunct, level - 1);
157+
}
158+
}
159+
}
160+
138161
private int getMaxDepthLimit(final SearchContext searchContext) {
139162
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
140163
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();

src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java

+99
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.opensearch.index.mapper.TextFieldMapper;
5959
import org.opensearch.index.query.BoolQueryBuilder;
6060
import org.opensearch.index.query.QueryBuilders;
61+
import org.opensearch.index.query.DisMaxQueryBuilder;
6162
import org.opensearch.index.query.QueryShardContext;
6263
import org.opensearch.index.query.TermQueryBuilder;
6364
import org.opensearch.index.remote.RemoteStoreEnums;
@@ -516,6 +517,104 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
516517
releaseResources(directory, w, reader);
517518
}
518519

520+
@SneakyThrows
521+
public void testWrappedHybridQuery_whenHybridNestedInDisjunctionQuery_thenFail() {
522+
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
523+
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
524+
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
525+
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
526+
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
527+
MapperService mapperService = mock(MapperService.class);
528+
when(mapperService.hasNested()).thenReturn(false);
529+
530+
Directory directory = newDirectory();
531+
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
532+
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
533+
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
534+
ft.setOmitNorms(random().nextBoolean());
535+
ft.freeze();
536+
int docId1 = RandomizedTest.randomInt();
537+
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
538+
w.commit();
539+
540+
IndexReader reader = DirectoryReader.open(w);
541+
SearchContext searchContext = mock(SearchContext.class);
542+
543+
ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
544+
reader,
545+
IndexSearcher.getDefaultSimilarity(),
546+
IndexSearcher.getDefaultQueryCache(),
547+
IndexSearcher.getDefaultQueryCachingPolicy(),
548+
true,
549+
null,
550+
searchContext
551+
);
552+
553+
ShardId shardId = new ShardId(dummyIndex, 1);
554+
SearchShardTarget shardTarget = new SearchShardTarget(
555+
randomAlphaOfLength(10),
556+
shardId,
557+
randomAlphaOfLength(10),
558+
OriginalIndices.NONE
559+
);
560+
when(searchContext.shardTarget()).thenReturn(shardTarget);
561+
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
562+
when(searchContext.size()).thenReturn(4);
563+
QuerySearchResult querySearchResult = new QuerySearchResult();
564+
when(searchContext.queryResult()).thenReturn(querySearchResult);
565+
when(searchContext.numberOfShards()).thenReturn(1);
566+
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
567+
IndexShard indexShard = mock(IndexShard.class);
568+
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
569+
when(indexShard.getSearchOperationListener()).thenReturn(mock(SearchOperationListener.class));
570+
when(searchContext.indexShard()).thenReturn(indexShard);
571+
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
572+
when(searchContext.mapperService()).thenReturn(mapperService);
573+
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
574+
IndexMetadata indexMetadata = getIndexMetadata();
575+
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
576+
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
577+
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);
578+
579+
LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
580+
boolean hasFilterCollector = randomBoolean();
581+
boolean hasTimeout = randomBoolean();
582+
583+
// Create a HybridQueryBuilder
584+
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
585+
hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
586+
hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
587+
hybridQueryBuilder.paginationDepth(10);
588+
589+
// Create a regular term query
590+
TermQueryBuilder termQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
591+
592+
// Create a disjunction query (OR) with the hybrid query and the term query
593+
DisMaxQueryBuilder disjunctionMaxQueryBuilder = QueryBuilders.disMaxQuery().add(hybridQueryBuilder).add(termQuery);
594+
595+
Query query = disjunctionMaxQueryBuilder.toQuery(mockQueryShardContext);
596+
when(searchContext.query()).thenReturn(query);
597+
598+
IllegalArgumentException exception = expectThrows(
599+
IllegalArgumentException.class,
600+
() -> hybridQueryPhaseSearcher.searchWith(
601+
searchContext,
602+
contextIndexSearcher,
603+
query,
604+
collectors,
605+
hasFilterCollector,
606+
hasTimeout
607+
)
608+
);
609+
610+
org.hamcrest.MatcherAssert.assertThat(
611+
exception.getMessage(),
612+
containsString("hybrid query must be a top level query and cannot be wrapped into other queries")
613+
);
614+
615+
releaseResources(directory, w, reader);
616+
}
617+
519618
@SneakyThrows
520619
public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() {
521620
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();

0 commit comments

Comments
 (0)