19
19
import org .apache .lucene .search .Sort ;
20
20
import org .apache .lucene .search .TopFieldDocs ;
21
21
import org .apache .lucene .search .FieldDoc ;
22
+ import org .opensearch .action .search .SearchPhaseContext ;
22
23
import org .opensearch .common .lucene .search .TopDocsAndMaxScore ;
23
24
import org .opensearch .neuralsearch .processor .combination .CombineScoresDto ;
24
25
import org .opensearch .neuralsearch .processor .combination .ScoreCombinationTechnique ;
@@ -64,25 +65,30 @@ public void execute(
64
65
final List <QuerySearchResult > querySearchResults ,
65
66
final Optional <FetchSearchResult > fetchSearchResultOptional ,
66
67
final ScoreNormalizationTechnique normalizationTechnique ,
67
- final ScoreCombinationTechnique combinationTechnique
68
+ final ScoreCombinationTechnique combinationTechnique ,
69
+ final SearchPhaseContext searchPhaseContext
68
70
) {
69
71
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest .builder ()
70
72
.querySearchResults (querySearchResults )
71
73
.fetchSearchResultOptional (fetchSearchResultOptional )
72
74
.normalizationTechnique (normalizationTechnique )
73
75
.combinationTechnique (combinationTechnique )
74
76
.explain (false )
77
+ .searchPhaseContext (searchPhaseContext )
75
78
.build ();
76
79
execute (request );
77
80
}
78
81
79
82
public void execute (final NormalizationProcessorWorkflowExecuteRequest request ) {
83
+ List <QuerySearchResult > querySearchResults = request .getQuerySearchResults ();
84
+ Optional <FetchSearchResult > fetchSearchResultOptional = request .getFetchSearchResultOptional ();
85
+
80
86
// save original state
81
- List <Integer > unprocessedDocIds = unprocessedDocIds (request . getQuerySearchResults () );
87
+ List <Integer > unprocessedDocIds = unprocessedDocIds (querySearchResults );
82
88
83
89
// pre-process data
84
90
log .debug ("Pre-process query results" );
85
- List <CompoundTopDocs > queryTopDocs = getQueryTopDocs (request . getQuerySearchResults () );
91
+ List <CompoundTopDocs > queryTopDocs = getQueryTopDocs (querySearchResults );
86
92
87
93
explain (request , queryTopDocs );
88
94
@@ -93,8 +99,9 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
93
99
CombineScoresDto combineScoresDTO = CombineScoresDto .builder ()
94
100
.queryTopDocs (queryTopDocs )
95
101
.scoreCombinationTechnique (request .getCombinationTechnique ())
96
- .querySearchResults (request .getQuerySearchResults ())
97
- .sort (evaluateSortCriteria (request .getQuerySearchResults (), queryTopDocs ))
102
+ .querySearchResults (querySearchResults )
103
+ .sort (evaluateSortCriteria (querySearchResults , queryTopDocs ))
104
+ .fromValueForSingleShard (getFromValueIfSingleShard (request ))
98
105
.build ();
99
106
100
107
// combine
@@ -103,8 +110,26 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
103
110
104
111
// post-process data
105
112
log .debug ("Post-process query results after score normalization and combination" );
106
- updateOriginalQueryResults (combineScoresDTO );
107
- updateOriginalFetchResults (request .getQuerySearchResults (), request .getFetchSearchResultOptional (), unprocessedDocIds );
113
+ updateOriginalQueryResults (combineScoresDTO , fetchSearchResultOptional .isPresent ());
114
+ updateOriginalFetchResults (
115
+ querySearchResults ,
116
+ fetchSearchResultOptional ,
117
+ unprocessedDocIds ,
118
+ combineScoresDTO .getFromValueForSingleShard ()
119
+ );
120
+ }
121
+
122
+ /**
123
+ * Returns the {@code from} (pagination offset) value from the original search request,
+ * but only when the search ran against a single shard AND the fetch phase has already
+ * produced results; otherwise returns {@code -1} as a sentinel meaning "not applicable".
+ * In the single-shard case OpenSearch runs query and fetch in one pass, so trimming by
+ * {@code from} has not happened yet and must be done by this processor.
+ * Ref https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchService.java#L715
126
+ */
127
+ private int getFromValueIfSingleShard (final NormalizationProcessorWorkflowExecuteRequest request ) {
128
+ final SearchPhaseContext searchPhaseContext = request .getSearchPhaseContext ();
129
+ // Multiple shards, or no fetch result yet: the coordinator will apply `from` later.
+ // NOTE(review): reads the field `fetchSearchResultOptional` directly instead of the
+ // getter used elsewhere — confirm both resolve to the same value (e.g. Lombok getter).
+ if (searchPhaseContext .getNumShards () > 1 || request .fetchSearchResultOptional .isEmpty ()) {
130
+ return -1 ;
131
+ }
132
+ // NOTE(review): assumes getRequest().source() is non-null here — TODO confirm a
+ // SearchSourceBuilder is always present when this processor runs.
+ return searchPhaseContext .getRequest ().source ().from ();
108
133
}
109
134
110
135
/**
@@ -173,19 +198,33 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
173
198
return queryTopDocs ;
174
199
}
175
200
176
// Writes the normalized/combined top docs back into each shard's QuerySearchResult.
// When the fetch phase already executed (single-shard case), also resets each shard's
// `from` so the coordinator trims the already-fetched hits correctly.
- private void updateOriginalQueryResults (final CombineScoresDto combineScoresDTO ) {
201
+ private void updateOriginalQueryResults (final CombineScoresDto combineScoresDTO , final boolean isFetchPhaseExecuted ) {
177
202
final List <QuerySearchResult > querySearchResults = combineScoresDTO .getQuerySearchResults ();
178
203
final List <CompoundTopDocs > queryTopDocs = getCompoundTopDocs (combineScoresDTO , querySearchResults );
179
204
final Sort sort = combineScoresDTO .getSort ();
205
+ // Running total of combined score docs across all shards; used below to detect
+ // a `from` offset that points past the end of the available results.
+ int totalScoreDocsCount = 0 ;
180
206
for (int index = 0 ; index < querySearchResults .size (); index ++) {
181
207
QuerySearchResult querySearchResult = querySearchResults .get (index );
182
208
CompoundTopDocs updatedTopDocs = queryTopDocs .get (index );
209
+ totalScoreDocsCount += updatedTopDocs .getScoreDocs ().size ();
183
210
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore (
184
211
buildTopDocs (updatedTopDocs , sort ),
185
212
maxScoreForShard (updatedTopDocs , sort != null )
186
213
);
214
+ // The fetch phase ran before the normalization phase, therefore update the `from` value in the result of each shard.
215
+ // This ensures the search results are trimmed to the requested page.
216
+ if (isFetchPhaseExecuted ) {
217
+ querySearchResult .from (combineScoresDTO .getFromValueForSingleShard ());
218
+ }
187
219
querySearchResult .topDocs (updatedTopDocsAndMaxScore , querySearchResult .sortValueFormats ());
188
220
}
221
+
222
+ // Fail fast when the requested offset exceeds the total number of combined docs —
+ // the caller must raise pagination_depth to page this far.
+ // NOTE(review): only the first shard's `from` is inspected — confirm all shards
+ // carry the same `from` at this point.
+ final int from = querySearchResults .get (0 ).from ();
223
+ if (from > totalScoreDocsCount ) {
224
+ throw new IllegalArgumentException (
225
+ String .format (Locale .ROOT , "Reached end of search result, increase pagination_depth value to see more results" )
226
+ );
227
+ }
189
228
}
190
229
191
230
private List <CompoundTopDocs > getCompoundTopDocs (CombineScoresDto combineScoresDTO , List <QuerySearchResult > querySearchResults ) {
@@ -244,7 +283,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
244
283
private void updateOriginalFetchResults (
245
284
final List <QuerySearchResult > querySearchResults ,
246
285
final Optional <FetchSearchResult > fetchSearchResultOptional ,
247
- final List <Integer > docIds
286
+ final List <Integer > docIds ,
287
+ final int fromValueForSingleShard
248
288
) {
249
289
if (fetchSearchResultOptional .isEmpty ()) {
250
290
return ;
@@ -276,14 +316,21 @@ private void updateOriginalFetchResults(
276
316
277
317
QuerySearchResult querySearchResult = querySearchResults .get (0 );
278
318
TopDocs topDocs = querySearchResult .topDocs ().topDocs ;
319
+ // Scenario to handle when calculating the trimmed length of updated search hits
320
+ // When normalization process runs after fetch phase, then search hits already fetched. Therefore, use the from value sent in the
321
+ // search request to calculate the effective length of updated search hits array.
322
+ int trimmedLengthOfSearchHits = topDocs .scoreDocs .length - fromValueForSingleShard ;
279
323
// iterate over the normalized/combined scores, that solves (1) and (3)
280
- SearchHit [] updatedSearchHitArray = Arrays .stream (topDocs .scoreDocs ).map (scoreDoc -> {
324
+ SearchHit [] updatedSearchHitArray = new SearchHit [trimmedLengthOfSearchHits ];
325
+ for (int i = 0 ; i < trimmedLengthOfSearchHits ; i ++) {
326
+ // Read topDocs after the desired from length
327
+ ScoreDoc scoreDoc = topDocs .scoreDocs [i + fromValueForSingleShard ];
281
328
// get fetched hit content by doc_id
282
329
SearchHit searchHit = docIdToSearchHit .get (scoreDoc .doc );
283
330
// update score to normalized/combined value (3)
284
331
searchHit .score (scoreDoc .score );
285
- return searchHit ;
286
- }). toArray ( SearchHit []:: new );
332
+ updatedSearchHitArray [ i ] = searchHit ;
333
+ }
287
334
SearchHits updatedSearchHits = new SearchHits (
288
335
updatedSearchHitArray ,
289
336
querySearchResult .getTotalHits (),
0 commit comments