Skip to content

Commit ff0223e

Browse files
authored
Support semantic sentence highlighter (#1193)
Signed-off-by: Junqiu Lei <[email protected]>
1 parent 8506daa commit ff0223e

24 files changed

+2772
-19
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
99
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
1010
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
1111
- Add Z Score normalization technique ([#1224](https://github.com/opensearch-project/neural-search/pull/1224))
12+
- Support semantic sentence highlighter ([#1193](https://github.com/opensearch-project/neural-search/pull/1193))
1213

1314
### Enhancements
1415

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.highlight;
6+
7+
import lombok.extern.log4j.Log4j2;
8+
import org.opensearch.index.mapper.MappedFieldType;
9+
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
10+
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
11+
import org.opensearch.search.fetch.subphase.highlight.Highlighter;
12+
import org.opensearch.core.common.text.Text;
13+
14+
/**
15+
* Semantic highlighter that uses ML models to identify relevant text spans for highlighting
16+
*/
17+
@Log4j2
18+
public class SemanticHighlighter implements Highlighter {
19+
public static final String NAME = "semantic";
20+
21+
private SemanticHighlighterEngine semanticHighlighterEngine;
22+
23+
public void initialize(SemanticHighlighterEngine semanticHighlighterEngine) {
24+
if (this.semanticHighlighterEngine != null) {
25+
throw new IllegalStateException(
26+
"SemanticHighlighterEngine has already been initialized. Multiple initializations are not permitted."
27+
);
28+
}
29+
this.semanticHighlighterEngine = semanticHighlighterEngine;
30+
}
31+
32+
@Override
33+
public boolean canHighlight(MappedFieldType fieldType) {
34+
return true;
35+
}
36+
37+
/**
38+
* Highlights a field using semantic highlighting
39+
*
40+
* @param fieldContext The field context containing the query and field information
41+
* @return The highlighted field or null if highlighting is not possible
42+
*/
43+
@Override
44+
public HighlightField highlight(FieldHighlightContext fieldContext) {
45+
if (semanticHighlighterEngine == null) {
46+
throw new IllegalStateException("SemanticHighlighter has not been initialized");
47+
}
48+
49+
// Extract field text
50+
String fieldText = semanticHighlighterEngine.getFieldText(fieldContext);
51+
52+
// Get model ID
53+
String modelId = semanticHighlighterEngine.getModelId(fieldContext.field.fieldOptions().options());
54+
55+
// Try to extract query text
56+
String originalQueryText = semanticHighlighterEngine.extractOriginalQuery(fieldContext.query, fieldContext.fieldName);
57+
58+
if (originalQueryText == null || originalQueryText.isEmpty()) {
59+
log.warn("No query text found for field {}", fieldContext.fieldName);
60+
return null;
61+
}
62+
63+
// Get highlighted text - allow any exceptions from this call to propagate
64+
String highlightedResponse = semanticHighlighterEngine.getHighlightedSentences(modelId, originalQueryText, fieldText);
65+
66+
if (highlightedResponse == null || highlightedResponse.isEmpty()) {
67+
log.warn("No highlighted text found for field {}", fieldContext.fieldName);
68+
return null;
69+
}
70+
71+
// Create highlight field
72+
Text[] fragments = new Text[] { new Text(highlightedResponse) };
73+
return new HighlightField(fieldContext.fieldName, fragments);
74+
}
75+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.highlight;
6+
7+
import lombok.extern.log4j.Log4j2;
8+
import org.apache.lucene.search.Query;
9+
import org.opensearch.OpenSearchException;
10+
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
11+
import org.opensearch.neuralsearch.processor.highlight.SentenceHighlightingRequest;
12+
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
13+
import org.opensearch.neuralsearch.highlight.extractor.QueryTextExtractorRegistry;
14+
import org.opensearch.action.support.PlainActionFuture;
15+
import lombok.NonNull;
16+
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
import java.util.Locale;
20+
import java.util.Map;
21+
import java.util.Objects;
22+
23+
/**
24+
* Engine class for semantic highlighting operations
25+
*/
26+
@Log4j2
27+
public class SemanticHighlighterEngine {
28+
private static final String MODEL_ID_FIELD = "model_id";
29+
private static final String DEFAULT_PRE_TAG = "<em>";
30+
private static final String DEFAULT_POST_TAG = "</em>";
31+
private static final String MODEL_INFERENCE_RESULT_KEY = "highlights";
32+
private static final String MODEL_INFERENCE_RESULT_START_KEY = "start";
33+
private static final String MODEL_INFERENCE_RESULT_END_KEY = "end";
34+
35+
private final MLCommonsClientAccessor mlCommonsClient;
36+
private final QueryTextExtractorRegistry queryTextExtractorRegistry;
37+
38+
public SemanticHighlighterEngine(
39+
@NonNull MLCommonsClientAccessor mlCommonsClient,
40+
@NonNull QueryTextExtractorRegistry queryTextExtractorRegistry
41+
) {
42+
this.mlCommonsClient = mlCommonsClient;
43+
this.queryTextExtractorRegistry = queryTextExtractorRegistry;
44+
}
45+
46+
/**
47+
* Gets the field text from the document
48+
*
49+
* @param fieldContext The field highlight context
50+
* @return The field text
51+
*/
52+
public String getFieldText(FieldHighlightContext fieldContext) {
53+
if (fieldContext.hitContext == null || fieldContext.hitContext.sourceLookup() == null) {
54+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is not found in the hit", fieldContext.fieldName));
55+
}
56+
Object fieldTextObject = fieldContext.hitContext.sourceLookup().extractValue(fieldContext.fieldName, null);
57+
if (fieldTextObject == null) {
58+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is not found in the document", fieldContext.fieldName));
59+
}
60+
if (fieldTextObject instanceof String == false) {
61+
throw new IllegalArgumentException(
62+
String.format(
63+
Locale.ROOT,
64+
"Field %s must be a string for highlighting, but was %s",
65+
fieldContext.fieldName,
66+
fieldTextObject.getClass().getSimpleName()
67+
)
68+
);
69+
}
70+
String fieldTextString = (String) fieldTextObject;
71+
if (fieldTextString.isEmpty()) {
72+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is empty", fieldContext.fieldName));
73+
}
74+
return fieldTextString;
75+
}
76+
77+
/**
78+
* Extracts the original query text from the search query object.
79+
*
80+
* @param query The query object from which to extract the original text
81+
* @param fieldName The name of the field being highlighted
82+
* @return The extracted original query text for highlighting
83+
* @throws IllegalArgumentException if the extracted query text is empty
84+
*/
85+
public String extractOriginalQuery(Query query, String fieldName) {
86+
if (fieldName == null) {
87+
log.warn("Field name is null, extraction may be less accurate");
88+
}
89+
return queryTextExtractorRegistry.extractQueryText(query, fieldName);
90+
}
91+
92+
/**
93+
* Gets the model ID from the options
94+
*
95+
* @param options The options map
96+
* @return The model ID
97+
*/
98+
public String getModelId(Map<String, Object> options) {
99+
Object modelId = options.get(MODEL_ID_FIELD);
100+
if (Objects.isNull(modelId) || (modelId instanceof String) == false) {
101+
throw new IllegalArgumentException(
102+
String.format(
103+
Locale.ROOT,
104+
"%s must be a non-null string, but was %s",
105+
MODEL_ID_FIELD,
106+
modelId == null ? "null" : modelId.getClass().getSimpleName()
107+
)
108+
);
109+
}
110+
return (String) modelId;
111+
}
112+
113+
/**
114+
* Gets highlighted text from the ML model
115+
*
116+
* @param modelId The ID of the model to use
117+
* @param question The search query
118+
* @param context The document text
119+
* @return Formatted text with highlighting
120+
*/
121+
public String getHighlightedSentences(String modelId, String question, String context) {
122+
List<Map<String, Object>> results = fetchModelResults(modelId, question, context);
123+
if (results == null || results.isEmpty()) {
124+
return null;
125+
}
126+
127+
return applyHighlighting(context, results.getFirst());
128+
}
129+
130+
/**
131+
* Fetches highlighting results from the ML model
132+
*
133+
* @param modelId The ID of the model to use
134+
* @param question The search query
135+
* @param context The document text
136+
* @return The highlighting results
137+
*/
138+
public List<Map<String, Object>> fetchModelResults(String modelId, String question, String context) {
139+
PlainActionFuture<List<Map<String, Object>>> future = PlainActionFuture.newFuture();
140+
141+
SentenceHighlightingRequest request = SentenceHighlightingRequest.builder()
142+
.modelId(modelId)
143+
.question(question)
144+
.context(context)
145+
.build();
146+
147+
mlCommonsClient.inferenceSentenceHighlighting(request, future);
148+
149+
try {
150+
return future.actionGet();
151+
} catch (Exception e) {
152+
log.error(
153+
"Error during sentence highlighting inference - modelId: [{}], question: [{}], context: [{}]",
154+
modelId,
155+
question,
156+
context,
157+
e
158+
);
159+
throw new OpenSearchException(
160+
String.format(Locale.ROOT, "Error during sentence highlighting inference from model [%s]", modelId),
161+
e
162+
);
163+
}
164+
}
165+
166+
/**
167+
* Applies highlighting to the original context based on the ML model response
168+
*
169+
* @param context The original document text
170+
* @param highlightResult The highlighting result from the ML model
171+
* @return Formatted text with highlighting
172+
* @throws IllegalArgumentException if highlight positions are invalid
173+
*/
174+
public String applyHighlighting(String context, Map<String, Object> highlightResult) {
175+
// Get the "highlights" list from the result
176+
Object highlightsObj = highlightResult.get(MODEL_INFERENCE_RESULT_KEY);
177+
178+
if (!(highlightsObj instanceof List<?> highlightsList)) {
179+
log.error(String.format(Locale.ROOT, "No valid highlights found in model inference result, highlightsObj: %s", highlightsObj));
180+
return null;
181+
}
182+
183+
if (highlightsList.isEmpty()) {
184+
// No highlights found, return context as is
185+
return context;
186+
}
187+
188+
// Pre-allocate size * 2 since we store start and end positions as consecutive pairs
189+
// Format: [start1, end1, start2, end2, start3, end3, ...]
190+
ArrayList<Integer> validHighlights = new ArrayList<>(highlightsList.size() * 2);
191+
192+
for (Object item : highlightsList) {
193+
Map<String, Number> map = getHighlightsPositionMap(item);
194+
195+
Number start = map.get(MODEL_INFERENCE_RESULT_START_KEY);
196+
Number end = map.get(MODEL_INFERENCE_RESULT_END_KEY);
197+
198+
if (start == null || end == null) {
199+
throw new OpenSearchException("Missing start or end position in highlight data");
200+
}
201+
202+
// Validate positions and add them as a pair to maintain the start-end relationship
203+
validateHighlightPositions(start.intValue(), end.intValue(), context.length());
204+
validHighlights.add(start.intValue()); // Even indices (0,2,4,...) store start positions
205+
validHighlights.add(end.intValue()); // Odd indices (1,3,5,...) store end positions
206+
}
207+
208+
// Verify highlights are sorted by start position (ascending)
209+
// We start from i=2 (second start position) and compare with previous start position (i-2)
210+
// Using i+=2 to skip end positions and only compare start positions with each other
211+
for (int i = 2; i < validHighlights.size(); i += 2) {
212+
// Compare current start position with previous start position
213+
if (validHighlights.get(i) < validHighlights.get(i - 2)) {
214+
log.error(String.format(Locale.ROOT, "Highlights are not sorted: %s", validHighlights));
215+
throw new OpenSearchException("Internal error while applying semantic highlight: received unsorted highlights from model");
216+
}
217+
}
218+
219+
return constructHighlightedText(context, validHighlights);
220+
}
221+
222+
/**
223+
* Validates highlight position values
224+
*
225+
* @param start The start position
226+
* @param end The end position
227+
* @param textLength The length of the text being highlighted
228+
* @throws OpenSearchException if positions are invalid
229+
*/
230+
private void validateHighlightPositions(int start, int end, int textLength) {
231+
if (start < 0 || end > textLength || start >= end) {
232+
throw new OpenSearchException(
233+
String.format(
234+
Locale.ROOT,
235+
"Invalid highlight positions: start=%d, end=%d, textLength=%d. Positions must satisfy: 0 <= start < end <= textLength",
236+
start,
237+
end,
238+
textLength
239+
)
240+
);
241+
}
242+
}
243+
244+
/**
245+
* Constructs highlighted text by iterating through the text once
246+
*
247+
* @param text The original text
248+
* @param highlights The list of valid highlight positions in pairs [start1, end1, start2, end2, ...]
249+
* @return The highlighted text
250+
*/
251+
private String constructHighlightedText(String text, List<Integer> highlights) {
252+
StringBuilder result = new StringBuilder();
253+
int currentPos = 0;
254+
255+
// Iterate through highlight positions in pairs (start, end)
256+
// i increments by 2 to move from one pair to the next
257+
for (int i = 0; i < highlights.size(); i += 2) {
258+
int start = highlights.get(i); // Get start position from even index
259+
int end = highlights.get(i + 1); // Get end position from odd index
260+
261+
// Add text before the highlight if there is any
262+
if (start > currentPos) {
263+
result.append(text, currentPos, start);
264+
}
265+
266+
// Add the highlighted text with highlight tags
267+
result.append(DEFAULT_PRE_TAG);
268+
result.append(text, start, end);
269+
result.append(DEFAULT_POST_TAG);
270+
271+
// Update current position to end of this highlight
272+
currentPos = end;
273+
}
274+
275+
// Add any remaining text after the last highlight
276+
if (currentPos < text.length()) {
277+
result.append(text, currentPos, text.length());
278+
}
279+
280+
return result.toString();
281+
}
282+
283+
/**
284+
* Extracts the highlight position map from a highlight item
285+
*
286+
* @param item The highlight item
287+
* @return The highlight position map
288+
* @throws OpenSearchException if the item cannot be cast to Map<String, Number>
289+
*/
290+
private static Map<String, Number> getHighlightsPositionMap(Object item) {
291+
try {
292+
return (Map<String, Number>) item;
293+
} catch (ClassCastException e) {
294+
throw new OpenSearchException(String.format(Locale.ROOT, "Expect item to be map of string to number, but was: %s", item));
295+
}
296+
}
297+
}

0 commit comments

Comments
 (0)