Skip to content

Commit b084838

Browse files
authored
Pagination in Hybrid query (#1048)
* Pagination in Hybrid query Signed-off-by: Varun Jain <[email protected]> * Remove unwanted code Signed-off-by: Varun Jain <[email protected]> * Adding hybrid query context dto Signed-off-by: Varun Jain <[email protected]> * Adding javadoc in hybridquerycontext and addressing few comments from review Signed-off-by: Varun Jain <[email protected]> * rename hybrid query extraction method Signed-off-by: Varun Jain <[email protected]> * Refactoring to optimize extractHybridQuery method calls Signed-off-by: Varun Jain <[email protected]> * Changes in tests to adapt with builder pattern in querybuilder Signed-off-by: Varun Jain <[email protected]> * Add mapper service mock in tests Signed-off-by: Varun Jain <[email protected]> * Fix error message of index.max_result_window setting Signed-off-by: Varun Jain <[email protected]> * Fix error message of index.max_result_window setting Signed-off-by: Varun Jain <[email protected]> * Fixing validation condition for lower bound Signed-off-by: Varun Jain <[email protected]> * fix tests Signed-off-by: Varun Jain <[email protected]> * Removing version check from doEquals and doHashCode method Signed-off-by: Varun Jain <[email protected]> --------- Signed-off-by: Varun Jain <[email protected]>
1 parent b4cb267 commit b084838

24 files changed

+880
-103
lines changed

Diff for: CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1616

1717
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
1818
### Features
19+
- Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048))
1920
### Enhancements
2021
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
2122
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))

Diff for: src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java

+5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public final class MinClusterVersionUtil {
2424
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
2525
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
2626
private static final Version MINIMAL_SUPPORTED_VERSION_QUERY_IMAGE_FIX = Version.V_2_19_0;
27+
private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0;
2728

2829
// Note this minimal version will act as an override
2930
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
@@ -41,6 +42,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
4142
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
4243
}
4344

45+
public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() {
46+
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY);
47+
}
48+
4449
public static boolean isClusterOnOrAfterMinReqVersion(String key) {
4550
Version version;
4651
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {

Diff for: src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWo
9393
.combinationTechnique(combinationTechnique)
9494
.explain(explain)
9595
.pipelineProcessingContext(requestContextOptional.orElse(null))
96+
.searchPhaseContext(searchPhaseContext)
9697
.build();
9798
normalizationWorkflow.execute(request);
9899
}

Diff for: src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

+59-12
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.apache.lucene.search.Sort;
2020
import org.apache.lucene.search.TopFieldDocs;
2121
import org.apache.lucene.search.FieldDoc;
22+
import org.opensearch.action.search.SearchPhaseContext;
2223
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
2324
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
2425
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
@@ -64,25 +65,30 @@ public void execute(
6465
final List<QuerySearchResult> querySearchResults,
6566
final Optional<FetchSearchResult> fetchSearchResultOptional,
6667
final ScoreNormalizationTechnique normalizationTechnique,
67-
final ScoreCombinationTechnique combinationTechnique
68+
final ScoreCombinationTechnique combinationTechnique,
69+
final SearchPhaseContext searchPhaseContext
6870
) {
6971
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
7072
.querySearchResults(querySearchResults)
7173
.fetchSearchResultOptional(fetchSearchResultOptional)
7274
.normalizationTechnique(normalizationTechnique)
7375
.combinationTechnique(combinationTechnique)
7476
.explain(false)
77+
.searchPhaseContext(searchPhaseContext)
7578
.build();
7679
execute(request);
7780
}
7881

7982
public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
83+
List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
84+
Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();
85+
8086
// save original state
81-
List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());
87+
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
8288

8389
// pre-process data
8490
log.debug("Pre-process query results");
85-
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(request.getQuerySearchResults());
91+
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
8692

8793
explain(request, queryTopDocs);
8894

@@ -93,8 +99,9 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
9399
CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
94100
.queryTopDocs(queryTopDocs)
95101
.scoreCombinationTechnique(request.getCombinationTechnique())
96-
.querySearchResults(request.getQuerySearchResults())
97-
.sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs))
102+
.querySearchResults(querySearchResults)
103+
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
104+
.fromValueForSingleShard(getFromValueIfSingleShard(request))
98105
.build();
99106

100107
// combine
@@ -103,8 +110,26 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
103110

104111
// post-process data
105112
log.debug("Post-process query results after score normalization and combination");
106-
updateOriginalQueryResults(combineScoresDTO);
107-
updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
113+
updateOriginalQueryResults(combineScoresDTO, fetchSearchResultOptional.isPresent());
114+
updateOriginalFetchResults(
115+
querySearchResults,
116+
fetchSearchResultOptional,
117+
unprocessedDocIds,
118+
combineScoresDTO.getFromValueForSingleShard()
119+
);
120+
}
121+
122+
/**
123+
* Get value of from parameter when there is a single shard
124+
* and fetch phase is already executed
125+
* Ref https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchService.java#L715
126+
*/
127+
private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) {
128+
final SearchPhaseContext searchPhaseContext = request.getSearchPhaseContext();
129+
if (searchPhaseContext.getNumShards() > 1 || request.fetchSearchResultOptional.isEmpty()) {
130+
return -1;
131+
}
132+
return searchPhaseContext.getRequest().source().from();
108133
}
109134

110135
/**
@@ -173,19 +198,33 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
173198
return queryTopDocs;
174199
}
175200

176-
private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) {
201+
private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO, final boolean isFetchPhaseExecuted) {
177202
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
178203
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
179204
final Sort sort = combineScoresDTO.getSort();
205+
int totalScoreDocsCount = 0;
180206
for (int index = 0; index < querySearchResults.size(); index++) {
181207
QuerySearchResult querySearchResult = querySearchResults.get(index);
182208
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
209+
totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
183210
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
184211
buildTopDocs(updatedTopDocs, sort),
185212
maxScoreForShard(updatedTopDocs, sort != null)
186213
);
214+
// Fetch phase had run before the normalization phase, therefore update the from value in the result of each shard.
215+
// This will ensure the trimming of the search results.
216+
if (isFetchPhaseExecuted) {
217+
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
218+
}
187219
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
188220
}
221+
222+
final int from = querySearchResults.get(0).from();
223+
if (from > totalScoreDocsCount) {
224+
throw new IllegalArgumentException(
225+
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
226+
);
227+
}
189228
}
190229

191230
private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
@@ -244,7 +283,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
244283
private void updateOriginalFetchResults(
245284
final List<QuerySearchResult> querySearchResults,
246285
final Optional<FetchSearchResult> fetchSearchResultOptional,
247-
final List<Integer> docIds
286+
final List<Integer> docIds,
287+
final int fromValueForSingleShard
248288
) {
249289
if (fetchSearchResultOptional.isEmpty()) {
250290
return;
@@ -276,14 +316,21 @@ private void updateOriginalFetchResults(
276316

277317
QuerySearchResult querySearchResult = querySearchResults.get(0);
278318
TopDocs topDocs = querySearchResult.topDocs().topDocs;
319+
// Scenario to handle when calculating the trimmed length of updated search hits
320+
// When the normalization process runs after the fetch phase, the search hits are already fetched. Therefore, use the from value sent in the
321+
// search request to calculate the effective length of updated search hits array.
322+
int trimmedLengthOfSearchHits = topDocs.scoreDocs.length - fromValueForSingleShard;
279323
// iterate over the normalized/combined scores, that solves (1) and (3)
280-
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
324+
SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits];
325+
for (int i = 0; i < trimmedLengthOfSearchHits; i++) {
326+
// Read topDocs after the desired from length
327+
ScoreDoc scoreDoc = topDocs.scoreDocs[i + fromValueForSingleShard];
281328
// get fetched hit content by doc_id
282329
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
283330
// update score to normalized/combined value (3)
284331
searchHit.score(scoreDoc.score);
285-
return searchHit;
286-
}).toArray(SearchHit[]::new);
332+
updatedSearchHitArray[i] = searchHit;
333+
}
287334
SearchHits updatedSearchHits = new SearchHits(
288335
updatedSearchHitArray,
289336
querySearchResult.getTotalHits(),

Diff for: src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import lombok.AllArgsConstructor;
88
import lombok.Builder;
99
import lombok.Getter;
10+
import org.opensearch.action.search.SearchPhaseContext;
1011
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
1112
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
1213
import org.opensearch.search.fetch.FetchSearchResult;
@@ -29,4 +30,5 @@ public class NormalizationProcessorWorkflowExecuteRequest {
2930
final ScoreCombinationTechnique combinationTechnique;
3031
boolean explain;
3132
final PipelineProcessingContext pipelineProcessingContext;
33+
final SearchPhaseContext searchPhaseContext;
3234
}

Diff for: src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java

+1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ public class CombineScoresDto {
2929
private List<QuerySearchResult> querySearchResults;
3030
@Nullable
3131
private Sort sort;
32+
private int fromValueForSingleShard;
3233
}

Diff for: src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java

+3-7
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,10 @@ public class ScoreCombiner {
7070
public void combineScores(final CombineScoresDto combineScoresDTO) {
7171
// iterate over results from each shard. Every CompoundTopDocs object has results from
7272
// multiple sub queries, doc ids may repeat for each sub query results
73+
ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique();
74+
Sort sort = combineScoresDTO.getSort();
7375
combineScoresDTO.getQueryTopDocs()
74-
.forEach(
75-
compoundQueryTopDocs -> combineShardScores(
76-
combineScoresDTO.getScoreCombinationTechnique(),
77-
compoundQueryTopDocs,
78-
combineScoresDTO.getSort()
79-
)
80-
);
76+
.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));
8177
}
8278

8379
private void combineShardScores(

Diff for: src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java

+13-4
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,21 @@
3434
public final class HybridQuery extends Query implements Iterable<Query> {
3535

3636
private final List<Query> subQueries;
37+
private final HybridQueryContext queryContext;
3738

3839
/**
3940
* Create new instance of hybrid query object based on collection of sub queries and filter query
4041
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores
4142
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
4243
*/
43-
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) {
44+
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final HybridQueryContext hybridQueryContext) {
4445
Objects.requireNonNull(subQueries, "collection of queries must not be null");
4546
if (subQueries.isEmpty()) {
4647
throw new IllegalArgumentException("collection of queries must not be empty");
4748
}
49+
if (hybridQueryContext.getPaginationDepth() == 0) {
50+
throw new IllegalArgumentException("pagination_depth must not be zero");
51+
}
4852
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
4953
this.subQueries = new ArrayList<>(subQueries);
5054
} else {
@@ -57,10 +61,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
5761
}
5862
this.subQueries = modifiedSubQueries;
5963
}
64+
this.queryContext = hybridQueryContext;
6065
}
6166

62-
public HybridQuery(final Collection<Query> subQueries) {
63-
this(subQueries, List.of());
67+
public HybridQuery(final Collection<Query> subQueries, final HybridQueryContext hybridQueryContext) {
68+
this(subQueries, List.of(), hybridQueryContext);
6469
}
6570

6671
/**
@@ -128,7 +133,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
128133
return super.rewrite(indexSearcher);
129134
}
130135
final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors);
131-
return new HybridQuery(rewrittenSubQueries);
136+
return new HybridQuery(rewrittenSubQueries, queryContext);
132137
}
133138

134139
private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) {
@@ -190,6 +195,10 @@ public Collection<Query> getSubQueries() {
190195
return Collections.unmodifiableCollection(subQueries);
191196
}
192197

198+
public HybridQueryContext getQueryContext() {
199+
return queryContext;
200+
}
201+
193202
/**
194203
* Create the Weight used to score this query
195204
*

0 commit comments

Comments
 (0)