|
| 1 | +/* |
| 2 | + * Copyright OpenSearch Contributors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | +package org.opensearch.neuralsearch.processor; |
| 6 | + |
| 7 | +import lombok.AllArgsConstructor; |
| 8 | +import lombok.Getter; |
| 9 | +import org.apache.lucene.search.Explanation; |
| 10 | +import org.opensearch.action.search.SearchRequest; |
| 11 | +import org.opensearch.action.search.SearchResponse; |
| 12 | +import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails; |
| 13 | +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; |
| 14 | +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; |
| 15 | +import org.opensearch.search.SearchHit; |
| 16 | +import org.opensearch.search.SearchHits; |
| 17 | +import org.opensearch.search.SearchShardTarget; |
| 18 | +import org.opensearch.search.pipeline.PipelineProcessingContext; |
| 19 | +import org.opensearch.search.pipeline.SearchResponseProcessor; |
| 20 | + |
| 21 | +import java.util.ArrayList; |
| 22 | +import java.util.HashMap; |
| 23 | +import java.util.List; |
| 24 | +import java.util.Map; |
| 25 | +import java.util.Objects; |
| 26 | + |
| 27 | +import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY; |
| 28 | +import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR; |
| 29 | + |
| 30 | +/** |
| 31 | + * Processor to add explanation details to search response |
| 32 | + */ |
| 33 | +@Getter |
| 34 | +@AllArgsConstructor |
| 35 | +public class ExplanationResponseProcessor implements SearchResponseProcessor { |
| 36 | + |
| 37 | + public static final String TYPE = "explanation_response_processor"; |
| 38 | + |
| 39 | + private final String description; |
| 40 | + private final String tag; |
| 41 | + private final boolean ignoreFailure; |
| 42 | + |
| 43 | + /** |
| 44 | + * Add explanation details to search response if it is present in request context |
| 45 | + */ |
| 46 | + @Override |
| 47 | + public SearchResponse processResponse(SearchRequest request, SearchResponse response) { |
| 48 | + return processResponse(request, response, null); |
| 49 | + } |
| 50 | + |
| 51 | + /** |
| 52 | + * Combines explanation from processor with search hits level explanations and adds it to search response |
| 53 | + */ |
| 54 | + @Override |
| 55 | + public SearchResponse processResponse( |
| 56 | + final SearchRequest request, |
| 57 | + final SearchResponse response, |
| 58 | + final PipelineProcessingContext requestContext |
| 59 | + ) { |
| 60 | + if (Objects.isNull(requestContext) |
| 61 | + || (Objects.isNull(requestContext.getAttribute(EXPLANATION_RESPONSE_KEY))) |
| 62 | + || requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) { |
| 63 | + return response; |
| 64 | + } |
| 65 | + // Extract explanation payload from context |
| 66 | + ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY); |
| 67 | + Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload(); |
| 68 | + if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { |
| 69 | + // for score normalization, processor level explanations will be sorted in scope of each shard, |
| 70 | + // and we are merging both into a single sorted list |
| 71 | + SearchHits searchHits = response.getHits(); |
| 72 | + SearchHit[] searchHitsArray = searchHits.getHits(); |
| 73 | + // create a map of searchShard and list of indexes of search hit objects in search hits array |
| 74 | + // the list will keep original order of sorting as per final search results |
| 75 | + Map<SearchShard, List<Integer>> searchHitsByShard = new HashMap<>(); |
| 76 | + // we keep index for each shard, where index is a position in searchHitsByShard list |
| 77 | + Map<SearchShard, Integer> explainsByShardCount = new HashMap<>(); |
| 78 | + // Build initial shard mappings |
| 79 | + for (int i = 0; i < searchHitsArray.length; i++) { |
| 80 | + SearchHit searchHit = searchHitsArray[i]; |
| 81 | + SearchShardTarget searchShardTarget = searchHit.getShard(); |
| 82 | + SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget); |
| 83 | + searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i); |
| 84 | + explainsByShardCount.putIfAbsent(searchShard, -1); |
| 85 | + } |
| 86 | + // Process normalization details if available in correct format |
| 87 | + if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map<?, ?>) { |
| 88 | + @SuppressWarnings("unchecked") |
| 89 | + Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map< |
| 90 | + SearchShard, |
| 91 | + List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR); |
| 92 | + // Process each search hit to add processor level explanations |
| 93 | + for (SearchHit searchHit : searchHitsArray) { |
| 94 | + SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard()); |
| 95 | + int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1; |
| 96 | + CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard); |
| 97 | + // Extract various explanation components |
| 98 | + Explanation queryLevelExplanation = searchHit.getExplanation(); |
| 99 | + ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations(); |
| 100 | + ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations(); |
| 101 | + // Create normalized explanations for each detail |
| 102 | + Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length]; |
| 103 | + for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) { |
| 104 | + normalizedExplanation[i] = Explanation.match( |
| 105 | + // normalized score |
| 106 | + normalizationExplanation.getScoreDetails().get(i).getKey(), |
| 107 | + // description of normalized score |
| 108 | + normalizationExplanation.getScoreDetails().get(i).getValue(), |
| 109 | + // shard level details |
| 110 | + queryLevelExplanation.getDetails()[i] |
| 111 | + ); |
| 112 | + } |
| 113 | + // Create and set final explanation combining all components |
| 114 | + Explanation finalExplanation = Explanation.match( |
| 115 | + searchHit.getScore(), |
| 116 | + // combination level explanation is always a single detail |
| 117 | + combinationExplanation.getScoreDetails().get(0).getValue(), |
| 118 | + normalizedExplanation |
| 119 | + ); |
| 120 | + searchHit.explanation(finalExplanation); |
| 121 | + explainsByShardCount.put(searchShard, explanationIndexByShard); |
| 122 | + } |
| 123 | + } |
| 124 | + } |
| 125 | + return response; |
| 126 | + } |
| 127 | + |
| 128 | + @Override |
| 129 | + public String getType() { |
| 130 | + return TYPE; |
| 131 | + } |
| 132 | +} |
0 commit comments