Skip to content

Commit 41e9344

Browse files
Explainability in hybrid query (#970) (#1014)
* Added Explainability support for hybrid query Signed-off-by: Martin Gaievski <[email protected]> (cherry picked from commit 393d49a)
1 parent 80f9c0a commit 41e9344

File tree

41 files changed

+2461
-95
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+2461
-95
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1616
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
1717
### Features
1818
### Enhancements
19+
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
1920
### Bug Fixes
2021
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
2122
### Infrastructure

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@
3232
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
3333
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
3434
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
35+
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
3536
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
3637
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
3738
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
3839
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
3940
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
4041
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
42+
import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory;
4143
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
4244
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
4345
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
@@ -80,6 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
8082
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
8183
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
8284
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
85+
public static final String EXPLANATION_RESPONSE_KEY = "explanation_response";
8386

8487
@Override
8588
public Collection<Object> createComponents(
@@ -181,7 +184,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRespon
181184
) {
182185
return Map.of(
183186
RerankProcessor.TYPE,
184-
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService())
187+
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()),
188+
ExplanationResponseProcessor.TYPE,
189+
new ExplanationResponseProcessorFactory()
185190
);
186191
}
187192

src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import lombok.Setter;
2222
import lombok.ToString;
2323
import lombok.extern.log4j.Log4j2;
24+
import org.opensearch.search.SearchShardTarget;
25+
import org.opensearch.search.query.QuerySearchResult;
2426

2527
/**
2628
* Class stores collection of TopDocs for each sub query from hybrid query. Collection of results is at shard level. We do store
@@ -37,15 +39,23 @@ public class CompoundTopDocs {
3739
private List<TopDocs> topDocs;
3840
@Setter
3941
private List<ScoreDoc> scoreDocs;
42+
@Getter
43+
private SearchShard searchShard;
4044

41-
public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs, final boolean isSortEnabled) {
42-
initialize(totalHits, topDocs, isSortEnabled);
45+
public CompoundTopDocs(
46+
final TotalHits totalHits,
47+
final List<TopDocs> topDocs,
48+
final boolean isSortEnabled,
49+
final SearchShard searchShard
50+
) {
51+
initialize(totalHits, topDocs, isSortEnabled, searchShard);
4352
}
4453

45-
private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSortEnabled) {
54+
private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSortEnabled, SearchShard searchShard) {
4655
this.totalHits = totalHits;
4756
this.topDocs = topDocs;
4857
scoreDocs = cloneLargestScoreDocs(topDocs, isSortEnabled);
58+
this.searchShard = searchShard;
4959
}
5060

5161
/**
@@ -72,14 +82,17 @@ private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSo
7282
* 6, 0.15
7383
* 0, 9549511920.4881596047
7484
*/
75-
public CompoundTopDocs(final TopDocs topDocs) {
85+
public CompoundTopDocs(final QuerySearchResult querySearchResult) {
86+
final TopDocs topDocs = querySearchResult.topDocs().topDocs;
87+
final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
88+
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
7689
boolean isSortEnabled = false;
7790
if (topDocs instanceof TopFieldDocs) {
7891
isSortEnabled = true;
7992
}
8093
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
8194
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
82-
initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled);
95+
initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled, searchShard);
8396
return;
8497
}
8598
// skipping first two elements, it's a start-stop element and delimiter for first series
@@ -103,7 +116,7 @@ public CompoundTopDocs(final TopDocs topDocs) {
103116
scoreDocList.add(scoreDoc);
104117
}
105118
}
106-
initialize(topDocs.totalHits, topDocsList, isSortEnabled);
119+
initialize(topDocs.totalHits, topDocsList, isSortEnabled, searchShard);
107120
}
108121

109122
private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs, boolean isSortEnabled) {
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.search.Explanation;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR;

/**
 * Search response processor that enriches each hit's {@link Explanation} with the
 * normalization and combination details produced by the hybrid-query normalization
 * processor, read from the pipeline processing context.
 */
@Getter
@AllArgsConstructor
public class ExplanationResponseProcessor implements SearchResponseProcessor {

    public static final String TYPE = "explanation_response_processor";

    private final String description;
    private final String tag;
    private final boolean ignoreFailure;

    /**
     * Context-free entry point; without a pipeline context there is no explanation
     * payload to merge, so the response passes through unchanged.
     */
    @Override
    public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
        return processResponse(request, response, null);
    }

    /**
     * Combines the processor-level explanation stored in the request context with the
     * per-hit (query-level) explanations and writes the merged result onto each hit.
     *
     * @param request        the original search request (unused here)
     * @param response       the search response whose hits are enriched in place
     * @param requestContext pipeline context that may carry an {@link ExplanationPayload}
     * @return the same response instance, with hit explanations rewritten when a payload is present
     */
    @Override
    public SearchResponse processResponse(
        final SearchRequest request,
        final SearchResponse response,
        final PipelineProcessingContext requestContext
    ) {
        // instanceof is null-safe, so this single guard covers: no context,
        // no attribute, or an attribute of an unexpected type.
        final Object rawPayload = Objects.isNull(requestContext) ? null : requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
        if (!(rawPayload instanceof ExplanationPayload)) {
            return response;
        }
        final ExplanationPayload explanationPayload = (ExplanationPayload) rawPayload;
        final Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
        if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
            enrichHitsWithExplanations(response.getHits(), explainPayload.get(NORMALIZATION_PROCESSOR));
        }
        return response;
    }

    /**
     * Merges shard-level normalization/combination details into each hit's explanation.
     * Processor-level explanations are sorted per shard, so a per-shard cursor walks
     * each shard's list in step with the hits' final (globally sorted) order.
     *
     * @param searchHits           hits from the response, in final sorted order
     * @param normalizationPayload expected to be a Map of SearchShard to its ordered
     *                             list of {@link CombinedExplanationDetails}; ignored otherwise
     */
    private void enrichHitsWithExplanations(final SearchHits searchHits, final Object normalizationPayload) {
        final SearchHit[] hits = searchHits.getHits();
        // Positions of hits per shard, preserving the final result ordering.
        final Map<SearchShard, List<Integer>> hitPositionsByShard = new HashMap<>();
        // Cursor into each shard's explanation list; starts one before the first entry.
        final Map<SearchShard, Integer> explanationCursorByShard = new HashMap<>();
        for (int position = 0; position < hits.length; position++) {
            final SearchShardTarget shardTarget = hits[position].getShard();
            final SearchShard searchShard = SearchShard.createSearchShard(shardTarget);
            hitPositionsByShard.computeIfAbsent(searchShard, key -> new ArrayList<>()).add(position);
            explanationCursorByShard.putIfAbsent(searchShard, -1);
        }
        // Only proceed when the payload has the expected map shape.
        if (!(normalizationPayload instanceof Map<?, ?>)) {
            return;
        }
        @SuppressWarnings("unchecked")
        final Map<SearchShard, List<CombinedExplanationDetails>> detailsByShard = (Map<
            SearchShard,
            List<CombinedExplanationDetails>>) normalizationPayload;
        for (final SearchHit hit : hits) {
            final SearchShard searchShard = SearchShard.createSearchShard(hit.getShard());
            final int cursor = explanationCursorByShard.get(searchShard) + 1;
            final CombinedExplanationDetails combinedDetails = detailsByShard.get(searchShard).get(cursor);
            final Explanation queryLevelExplanation = hit.getExplanation();
            final ExplanationDetails normalizationExplanation = combinedDetails.getNormalizationExplanations();
            final ExplanationDetails combinationExplanation = combinedDetails.getCombinationExplanations();
            // Wrap each sub-query detail with its normalized score and description.
            final Explanation[] perSubQueryExplanations = new Explanation[queryLevelExplanation.getDetails().length];
            for (int subQuery = 0; subQuery < perSubQueryExplanations.length; subQuery++) {
                perSubQueryExplanations[subQuery] = Explanation.match(
                    // normalized score for this sub-query
                    normalizationExplanation.getScoreDetails().get(subQuery).getKey(),
                    // description of how the score was normalized
                    normalizationExplanation.getScoreDetails().get(subQuery).getValue(),
                    // original shard-level detail for this sub-query
                    queryLevelExplanation.getDetails()[subQuery]
                );
            }
            // Combination-level explanation always carries a single entry (index 0).
            final Explanation finalExplanation = Explanation.match(
                hit.getScore(),
                combinationExplanation.getScoreDetails().get(0).getValue(),
                perSubQueryExplanations
            );
            hit.explanation(finalExplanation);
            explanationCursorByShard.put(searchShard, cursor);
        }
    }

    @Override
    public String getType() {
        return TYPE;
    }
}

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.search.SearchPhaseResult;
2121
import org.opensearch.search.fetch.FetchSearchResult;
2222
import org.opensearch.search.internal.SearchContext;
23+
import org.opensearch.search.pipeline.PipelineProcessingContext;
2324
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
2425
import org.opensearch.search.query.QuerySearchResult;
2526

@@ -43,22 +44,57 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
4344

4445
/**
4546
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
46-
* are set as part of class constructor
47+
* are set as part of class constructor. This method is called when there is no pipeline context
4748
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
4849
* @param searchPhaseContext {@link SearchContext}
4950
*/
5051
@Override
5152
public <Result extends SearchPhaseResult> void process(
5253
final SearchPhaseResults<Result> searchPhaseResult,
5354
final SearchPhaseContext searchPhaseContext
55+
) {
56+
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
57+
}
58+
59+
/**
60+
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
61+
* are set as part of class constructor
62+
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
63+
* @param searchPhaseContext {@link SearchContext}
64+
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
65+
* @param <Result>
66+
*/
67+
@Override
68+
public <Result extends SearchPhaseResult> void process(
69+
final SearchPhaseResults<Result> searchPhaseResult,
70+
final SearchPhaseContext searchPhaseContext,
71+
final PipelineProcessingContext requestContext
72+
) {
73+
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
74+
}
75+
76+
private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
77+
SearchPhaseResults<Result> searchPhaseResult,
78+
SearchPhaseContext searchPhaseContext,
79+
Optional<PipelineProcessingContext> requestContextOptional
5480
) {
5581
if (shouldSkipProcessor(searchPhaseResult)) {
5682
log.debug("Query results are not compatible with normalization processor");
5783
return;
5884
}
5985
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
6086
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
61-
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
87+
boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain())
88+
&& searchPhaseContext.getRequest().source().explain();
89+
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
90+
.querySearchResults(querySearchResults)
91+
.fetchSearchResultOptional(fetchSearchResult)
92+
.normalizationTechnique(normalizationTechnique)
93+
.combinationTechnique(combinationTechnique)
94+
.explain(explain)
95+
.pipelineProcessingContext(requestContextOptional.orElse(null))
96+
.build();
97+
normalizationWorkflow.execute(request);
6298
}
6399

64100
@Override

0 commit comments

Comments
 (0)