Skip to content

Commit a8dd556

Browse files
committed
Support semantic sentence highlighter
Signed-off-by: Junqiu Lei <[email protected]>
1 parent f5377c0 commit a8dd556

24 files changed

+2753
-25
lines changed

CHANGELOG.md

+1-6
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
99
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
1010
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
1111
- Add Z Score normalization technique ([#1224](https://github.com/opensearch-project/neural-search/pull/1224))
12-
12+
- Support semantic sentence highlighter ([#1193](https://github.com/opensearch-project/neural-search/pull/1193))
1313
### Enhancements
14-
1514
### Bug Fixes
16-
1715
### Infrastructure
18-
1916
### Documentation
20-
2117
### Maintenance
22-
2318
### Refactoring
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.highlight;
6+
7+
import lombok.extern.log4j.Log4j2;
8+
import org.opensearch.index.mapper.MappedFieldType;
9+
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
10+
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
11+
import org.opensearch.search.fetch.subphase.highlight.Highlighter;
12+
import org.opensearch.core.common.text.Text;
13+
14+
/**
15+
* Semantic highlighter that uses ML models to identify relevant text spans for highlighting
16+
*/
17+
@Log4j2
18+
public class SemanticHighlighter implements Highlighter {
19+
public static final String NAME = "semantic";
20+
21+
private SemanticHighlighterEngine semanticHighlighterEngine;
22+
23+
public void initialize(SemanticHighlighterEngine semanticHighlighterEngine) {
24+
if (this.semanticHighlighterEngine != null) {
25+
throw new IllegalStateException(
26+
"SemanticHighlighterEngine has already been initialized. Multiple initializations are not permitted."
27+
);
28+
}
29+
this.semanticHighlighterEngine = semanticHighlighterEngine;
30+
}
31+
32+
@Override
33+
public boolean canHighlight(MappedFieldType fieldType) {
34+
return true;
35+
}
36+
37+
/**
38+
* Highlights a field using semantic highlighting
39+
*
40+
* @param fieldContext The field context containing the query and field information
41+
* @return The highlighted field or null if highlighting is not possible
42+
*/
43+
@Override
44+
public HighlightField highlight(FieldHighlightContext fieldContext) {
45+
if (semanticHighlighterEngine == null) {
46+
throw new IllegalStateException("SemanticHighlighter has not been initialized");
47+
}
48+
49+
// Extract field text
50+
String fieldText = semanticHighlighterEngine.getFieldText(fieldContext);
51+
52+
// Get model ID
53+
String modelId = semanticHighlighterEngine.getModelId(fieldContext.field.fieldOptions().options());
54+
55+
// Try to extract query text
56+
String originalQueryText = semanticHighlighterEngine.extractOriginalQuery(fieldContext.query, fieldContext.fieldName);
57+
58+
if (originalQueryText == null || originalQueryText.isEmpty()) {
59+
log.warn("No query text found for field {}", fieldContext.fieldName);
60+
return null;
61+
}
62+
63+
// Get highlighted text - allow any exceptions from this call to propagate
64+
String highlightedResponse = semanticHighlighterEngine.getHighlightedSentences(modelId, originalQueryText, fieldText);
65+
66+
if (highlightedResponse == null || highlightedResponse.isEmpty()) {
67+
log.warn("No highlighted text found for field {}", fieldContext.fieldName);
68+
return null;
69+
}
70+
71+
// Create highlight field
72+
Text[] fragments = new Text[] { new Text(highlightedResponse) };
73+
return new HighlightField(fieldContext.fieldName, fragments);
74+
}
75+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.highlight;
6+
7+
import lombok.extern.log4j.Log4j2;
8+
import org.apache.lucene.search.Query;
9+
import org.opensearch.OpenSearchException;
10+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
11+
import org.opensearch.neuralsearch.processor.highlight.SentenceHighlightingRequest;
12+
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
13+
import org.opensearch.neuralsearch.highlight.extractor.QueryTextExtractorRegistry;
14+
import org.opensearch.action.support.PlainActionFuture;
15+
import lombok.NonNull;
16+
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
import java.util.Locale;
20+
import java.util.Map;
21+
import java.util.Objects;
22+
23+
/**
24+
* Engine class for semantic highlighting operations
25+
*/
26+
@Log4j2
27+
public class SemanticHighlighterEngine {
28+
private static final String MODEL_ID_FIELD = "model_id";
29+
private static final String DEFAULT_PRE_TAG = "<em>";
30+
private static final String DEFAULT_POST_TAG = "</em>";
31+
private static final String MODEL_INFERENCE_RESULT_KEY = "highlights";
32+
private static final String MODEL_INFERENCE_RESULT_START_KEY = "start";
33+
private static final String MODEL_INFERENCE_RESULT_END_KEY = "end";
34+
35+
private final MLCommonsClientAccessor mlCommonsClient;
36+
private final QueryTextExtractorRegistry queryTextExtractorRegistry;
37+
38+
/**
 * Creates an engine from its two collaborators. Both parameters carry Lombok's {@code @NonNull}.
 *
 * @param mlCommonsClient Client used to run sentence highlighting inference
 * @param queryTextExtractorRegistry Registry used to extract the query text from a query object
 */
public SemanticHighlighterEngine(
    @NonNull MLCommonsClientAccessor mlCommonsClient,
    @NonNull QueryTextExtractorRegistry queryTextExtractorRegistry
) {
    this.queryTextExtractorRegistry = queryTextExtractorRegistry;
    this.mlCommonsClient = mlCommonsClient;
}
45+
46+
/**
47+
* Gets the field text from the document
48+
*
49+
* @param fieldContext The field highlight context
50+
* @return The field text
51+
*/
52+
public String getFieldText(FieldHighlightContext fieldContext) {
53+
if (fieldContext.hitContext == null || fieldContext.hitContext.sourceLookup() == null) {
54+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is not found in the hit", fieldContext.fieldName));
55+
}
56+
Object fieldTextObject = fieldContext.hitContext.sourceLookup().extractValue(fieldContext.fieldName, null);
57+
if (fieldTextObject == null) {
58+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is not found in the document", fieldContext.fieldName));
59+
}
60+
if (fieldTextObject instanceof String == false) {
61+
throw new IllegalArgumentException(
62+
String.format(
63+
Locale.ROOT,
64+
"Field %s must be a string for highlighting, but was %s",
65+
fieldContext.fieldName,
66+
fieldTextObject.getClass().getSimpleName()
67+
)
68+
);
69+
}
70+
String fieldTextString = (String) fieldTextObject;
71+
if (fieldTextString.isEmpty()) {
72+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is empty", fieldContext.fieldName));
73+
}
74+
return fieldTextString;
75+
}
76+
77+
/**
78+
* Extracts the original query text from the search query object.
79+
*
80+
* @param query The query object from which to extract the original text
81+
* @param fieldName The name of the field being highlighted
82+
* @return The extracted original query text for highlighting
83+
* @throws IllegalArgumentException if the extracted query text is empty
84+
*/
85+
public String extractOriginalQuery(Query query, String fieldName) {
86+
if (fieldName == null) {
87+
log.warn("Field name is null, extraction may be less accurate");
88+
}
89+
return queryTextExtractorRegistry.extractQueryText(query, fieldName);
90+
}
91+
92+
/**
93+
* Gets the model ID from the options
94+
*
95+
* @param options The options map
96+
* @return The model ID
97+
*/
98+
public String getModelId(Map<String, Object> options) {
99+
Object modelId = options.get(MODEL_ID_FIELD);
100+
if (Objects.isNull(modelId)) {
101+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Missing required option: %s", MODEL_ID_FIELD));
102+
}
103+
return modelId.toString();
104+
}
105+
106+
/**
107+
* Gets highlighted text from the ML model
108+
*
109+
* @param modelId The ID of the model to use
110+
* @param question The search query
111+
* @param context The document text
112+
* @return Formatted text with highlighting
113+
*/
114+
public String getHighlightedSentences(String modelId, String question, String context) {
115+
List<Map<String, Object>> results = fetchModelResults(modelId, question, context);
116+
if (results == null || results.isEmpty()) {
117+
return null;
118+
}
119+
120+
return applyHighlighting(context, results.getFirst());
121+
}
122+
123+
/**
124+
* Fetches highlighting results from the ML model
125+
*
126+
* @param modelId The ID of the model to use
127+
* @param question The search query
128+
* @param context The document text
129+
* @return The highlighting results
130+
*/
131+
public List<Map<String, Object>> fetchModelResults(String modelId, String question, String context) {
132+
PlainActionFuture<List<Map<String, Object>>> future = PlainActionFuture.newFuture();
133+
134+
SentenceHighlightingRequest request = SentenceHighlightingRequest.builder()
135+
.modelId(modelId)
136+
.question(question)
137+
.context(context)
138+
.build();
139+
140+
mlCommonsClient.inferenceSentenceHighlighting(request, future);
141+
142+
try {
143+
return future.actionGet();
144+
} catch (Exception e) {
145+
log.error(
146+
"Error during sentence highlighting inference - modelId: [{}], question: [{}], context: [{}]",
147+
modelId,
148+
question,
149+
context,
150+
e
151+
);
152+
throw new OpenSearchException(
153+
String.format(Locale.ROOT, "Error during sentence highlighting inference from model [%s]", modelId),
154+
e
155+
);
156+
}
157+
}
158+
159+
/**
160+
* Applies highlighting to the original context based on the ML model response
161+
*
162+
* @param context The original document text
163+
* @param highlightResult The highlighting result from the ML model
164+
* @return Formatted text with highlighting
165+
* @throws IllegalArgumentException if highlight positions are invalid
166+
*/
167+
public String applyHighlighting(String context, Map<String, Object> highlightResult) {
168+
// Get the "highlights" list from the result
169+
Object highlightsObj = highlightResult.get(MODEL_INFERENCE_RESULT_KEY);
170+
171+
if (!(highlightsObj instanceof List<?> highlightsList)) {
172+
log.error(String.format(Locale.ROOT, "No valid highlights found in model inference result, highlightsObj: %s", highlightsObj));
173+
return null;
174+
}
175+
176+
if (highlightsList.isEmpty()) {
177+
// No highlights found, return context as is
178+
return context;
179+
}
180+
181+
// Pre-allocate size * 2 since we store start and end positions as consecutive pairs
182+
// Format: [start1, end1, start2, end2, start3, end3, ...]
183+
ArrayList<Integer> validHighlights = new ArrayList<>(highlightsList.size() * 2);
184+
185+
for (Object item : highlightsList) {
186+
Map<String, Number> map = getHighlightsPositionMap(item);
187+
188+
Number start = map.get(MODEL_INFERENCE_RESULT_START_KEY);
189+
Number end = map.get(MODEL_INFERENCE_RESULT_END_KEY);
190+
191+
if (start == null || end == null) {
192+
throw new OpenSearchException("Missing start or end position in highlight data");
193+
}
194+
195+
// Validate positions and add them as a pair to maintain the start-end relationship
196+
validateHighlightPositions(start.intValue(), end.intValue(), context.length());
197+
validHighlights.add(start.intValue()); // Even indices (0,2,4,...) store start positions
198+
validHighlights.add(end.intValue()); // Odd indices (1,3,5,...) store end positions
199+
}
200+
201+
// Verify highlights are sorted by start position (ascending)
202+
// We start from i=2 (second start position) and compare with previous start position (i-2)
203+
// Using i+=2 to skip end positions and only compare start positions with each other
204+
for (int i = 2; i < validHighlights.size(); i += 2) {
205+
// Compare current start position with previous start position
206+
if (validHighlights.get(i) < validHighlights.get(i - 2)) {
207+
log.error(String.format(Locale.ROOT, "Highlights are not sorted: %s", validHighlights));
208+
throw new OpenSearchException("Internal error while applying semantic highlight: received unsorted highlights from model");
209+
}
210+
}
211+
212+
return constructHighlightedText(context, validHighlights);
213+
}
214+
215+
/**
216+
* Validates highlight position values
217+
*
218+
* @param start The start position
219+
* @param end The end position
220+
* @param textLength The length of the text being highlighted
221+
* @throws OpenSearchException if positions are invalid
222+
*/
223+
private void validateHighlightPositions(int start, int end, int textLength) {
224+
if (start < 0 || end > textLength || start >= end) {
225+
throw new OpenSearchException(
226+
String.format(
227+
Locale.ROOT,
228+
"Invalid highlight positions: start=%d, end=%d, textLength=%d. Positions must satisfy: 0 <= start < end <= textLength",
229+
start,
230+
end,
231+
textLength
232+
)
233+
);
234+
}
235+
}
236+
237+
/**
238+
* Constructs highlighted text by iterating through the text once
239+
*
240+
* @param text The original text
241+
* @param highlights The list of valid highlight positions in pairs [start1, end1, start2, end2, ...]
242+
* @return The highlighted text
243+
*/
244+
private String constructHighlightedText(String text, List<Integer> highlights) {
245+
StringBuilder result = new StringBuilder();
246+
int currentPos = 0;
247+
248+
// Iterate through highlight positions in pairs (start, end)
249+
// i increments by 2 to move from one pair to the next
250+
for (int i = 0; i < highlights.size(); i += 2) {
251+
int start = highlights.get(i); // Get start position from even index
252+
int end = highlights.get(i + 1); // Get end position from odd index
253+
254+
// Add text before the highlight if there is any
255+
if (start > currentPos) {
256+
result.append(text, currentPos, start);
257+
}
258+
259+
// Add the highlighted text with highlight tags
260+
result.append(DEFAULT_PRE_TAG);
261+
result.append(text, start, end);
262+
result.append(DEFAULT_POST_TAG);
263+
264+
// Update current position to end of this highlight
265+
currentPos = end;
266+
}
267+
268+
// Add any remaining text after the last highlight
269+
if (currentPos < text.length()) {
270+
result.append(text, currentPos, text.length());
271+
}
272+
273+
return result.toString();
274+
}
275+
276+
/**
277+
* Extracts the highlight position map from a highlight item
278+
*
279+
* @param item The highlight item
280+
* @return The highlight position map
281+
* @throws OpenSearchException if the item cannot be cast to Map<String, Number>
282+
*/
283+
private static Map<String, Number> getHighlightsPositionMap(Object item) {
284+
try {
285+
return (Map<String, Number>) item;
286+
} catch (ClassCastException e) {
287+
throw new OpenSearchException(String.format(Locale.ROOT, "Expect item to be map of string to number, but was: %s", item));
288+
}
289+
}
290+
}

0 commit comments

Comments
 (0)