Skip to content

Commit 212e782

Browse files
Handle pagination_depth when from =0 (#1132) (#1136)
* Handle pagination_depth when from =0 Signed-off-by: Varun Jain <[email protected]> * Add changelog Signed-off-by: Varun Jain <[email protected]> * Remove unecessary logs Signed-off-by: Varun Jain <[email protected]> --------- Signed-off-by: Varun Jain <[email protected]> (cherry picked from commit 3dbdcba) Co-authored-by: Varun Jain <[email protected]>
1 parent 62fd9a2 commit 212e782

File tree

6 files changed

+55
-45
lines changed

6 files changed

+55
-45
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
3030
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
3131
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
3232
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
33+
- Handle pagination_depth when from =0 and removes default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132))
3334
### Infrastructure
3435
- Update batch related tests to use batch_size in processor & refactor BWC version check ([#852](https://github.com/opensearch-project/neural-search/pull/852))
3536
- Fix CI for JDK upgrade towards 21 ([#835](https://github.com/opensearch-project/neural-search/pull/835))

src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
5858
private Integer paginationDepth;
5959

6060
static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
61-
private final static int DEFAULT_PAGINATION_DEPTH = 10;
6261
private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 0;
6362

6463
public HybridQueryBuilder(StreamInput in) throws IOException {
@@ -167,7 +166,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
167166
public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException {
168167
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
169168

170-
int paginationDepth = DEFAULT_PAGINATION_DEPTH;
169+
Integer paginationDepth = null;
171170
final List<QueryBuilder> queries = new ArrayList<>();
172171
String queryName = null;
173172

@@ -324,7 +323,7 @@ private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, Quer
324323
return queries;
325324
}
326325

327-
private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) {
326+
private static void validatePaginationDepth(final Integer paginationDepth, final QueryShardContext queryShardContext) {
328327
if (Objects.isNull(paginationDepth)) {
329328
return;
330329
}

src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java

+11-6
Original file line numberDiff line numberDiff line change
@@ -485,14 +485,19 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchRe
485485
*/
486486
private static int getSubqueryResultsRetrievalSize(final SearchContext searchContext) {
487487
HybridQuery hybridQuery = unwrapHybridQuery(searchContext);
488-
int paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();
488+
Integer paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();
489489

490-
// Switch to from+size retrieval size during standard hybrid query execution.
491-
if (searchContext.from() == 0) {
492-
return searchContext.size();
490+
// Pagination is expected to work only when pagination_depth is provided in the search request.
491+
if (Objects.isNull(paginationDepth) && searchContext.from() > 0) {
492+
throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth param is missing in the search request"));
493493
}
494-
log.info("pagination_depth is {}", paginationDepth);
495-
return paginationDepth;
494+
495+
if (Objects.nonNull(paginationDepth)) {
496+
return paginationDepth;
497+
}
498+
499+
// Switch to from+size retrieval size during standard hybrid query execution where from is 0.
500+
return searchContext.size();
496501
}
497502

498503
/**

src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java

+1
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
398398
.endObject()
399399
.endObject()
400400
.endArray()
401+
.field("pagination_depth", 10)
401402
.endObject();
402403

403404
NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(

src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java

+1-34
Original file line numberDiff line numberDiff line change
@@ -870,40 +870,6 @@ public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSucc
870870
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
871871
}
872872

873-
@SneakyThrows
874-
public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() {
875-
try {
876-
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
877-
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
878-
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
879-
HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder();
880-
hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder());
881-
882-
Map<String, Object> searchResponseAsMap = search(
883-
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
884-
hybridQueryBuilderOnlyMatchAll,
885-
null,
886-
10,
887-
Map.of("search_pipeline", SEARCH_PIPELINE),
888-
null,
889-
null,
890-
null,
891-
false,
892-
null,
893-
2
894-
);
895-
896-
assertEquals(2, getHitCount(searchResponseAsMap));
897-
Map<String, Object> total = getTotalHits(searchResponseAsMap);
898-
assertNotNull(total.get("value"));
899-
assertEquals(4, total.get("value"));
900-
assertNotNull(total.get("relation"));
901-
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
902-
} finally {
903-
wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE);
904-
}
905-
}
906-
907873
@SneakyThrows
908874
public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
909875
try {
@@ -912,6 +878,7 @@ public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
912878
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
913879
HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder();
914880
hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder());
881+
hybridQueryBuilderOnlyMatchAll.paginationDepth(10);
915882

916883
ResponseException responseException = assertThrows(
917884
ResponseException.class,

src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java

+39-2
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() {
439439
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
440440
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
441441
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
442-
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
442+
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();
443443

444444
HybridQuery hybridQueryWithMatchAll = new HybridQuery(
445445
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
@@ -633,7 +633,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD
633633
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
634634
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
635635
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
636-
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
636+
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();
637637

638638
HybridQuery hybridQueryWithTerm = new HybridQuery(
639639
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
@@ -1169,4 +1169,41 @@ public void testScrollWithHybridQuery_thenFail() {
11691169
illegalArgumentException.getMessage()
11701170
);
11711171
}
1172+
1173+
@SneakyThrows
1174+
public void testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail() {
1175+
SearchContext searchContext = mock(SearchContext.class);
1176+
// From >0
1177+
when(searchContext.from()).thenReturn(5);
1178+
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
1179+
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
1180+
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
1181+
TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
1182+
1183+
HybridQuery hybridQuery = new HybridQuery(
1184+
List.of(termSubQuery.toQuery(mockQueryShardContext)),
1185+
HybridQueryContext.builder().build() // pagination_depth is set to null
1186+
);
1187+
1188+
when(searchContext.query()).thenReturn(hybridQuery);
1189+
ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
1190+
IndexReader indexReader = mock(IndexReader.class);
1191+
when(indexSearcher.getIndexReader()).thenReturn(indexReader);
1192+
when(searchContext.searcher()).thenReturn(indexSearcher);
1193+
MapperService mapperService = createMapperService();
1194+
when(searchContext.mapperService()).thenReturn(mapperService);
1195+
1196+
Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
1197+
when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
1198+
when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
1199+
1200+
IllegalArgumentException illegalArgumentException = assertThrows(
1201+
IllegalArgumentException.class,
1202+
() -> HybridCollectorManager.createHybridCollectorManager(searchContext)
1203+
);
1204+
assertEquals(
1205+
String.format(Locale.ROOT, "pagination_depth param is missing in the search request"),
1206+
illegalArgumentException.getMessage()
1207+
);
1208+
}
11721209
}

0 commit comments

Comments
 (0)