Skip to content

Commit 9c9407e

Browse files
authored
use standard config in ingest processor intead of always return list (#3008)
Signed-off-by: Mingshi Liu <[email protected]>
1 parent fe74150 commit 9c9407e

File tree

4 files changed

+281
-27
lines changed

4 files changed

+281
-27
lines changed

Diff for: common/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ dependencies {
3434
exclude group: 'com.google.j2objc', module: 'j2objc-annotations'
3535
exclude group: 'com.google.guava', module: 'listenablefuture'
3636
}
37+
compileOnly 'com.jayway.jsonpath:json-path:2.9.0'
3738
}
3839

3940
lombok {

Diff for: common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

+88
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636

3737
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
3838
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;
39+
import com.jayway.jsonpath.JsonPath;
40+
3941

4042
@Log4j2
4143
public class StringUtils {
@@ -56,6 +58,7 @@ public class StringUtils {
5658
static {
5759
gson = new Gson();
5860
}
61+
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
5962

6063
public static boolean isValidJsonString(String Json) {
6164
try {
@@ -239,4 +242,89 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
239242
return errorMessage + " Model ID: " + modelId;
240243
}
241244
}
245+
246+
public static String obtainFieldNameFromJsonPath(String jsonPath) {
247+
String[] parts = jsonPath.split("\\.");
248+
249+
// Get the last part which is the field name
250+
return parts[parts.length - 1];
251+
}
252+
253+
public static String getJsonPath(String jsonPathWithSource) {
254+
// Find the index of the first occurrence of "$."
255+
int startIndex = jsonPathWithSource.indexOf("$.");
256+
257+
// Extract the substring from the startIndex to the end of the input string
258+
return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource;
259+
}
260+
261+
/**
262+
* Checks if the given input string matches the JSONPath format.
263+
*
264+
* <p>The JSONPath format is a way to navigate and extract data from JSON documents.
265+
* It uses a syntax similar to XPath for XML documents. This method attempts to compile
266+
* the input string as a JSONPath expression using the {@link com.jayway.jsonpath.JsonPath}
267+
* library. If the compilation succeeds, it means the input string is a valid JSONPath
268+
* expression.
269+
*
270+
* @param input the input string to be checked for JSONPath format validity
271+
* @return true if the input string is a valid JSONPath expression, false otherwise
272+
*/
273+
public static boolean isValidJSONPath(String input) {
274+
if (input == null || input.isBlank()) {
275+
return false;
276+
}
277+
try {
278+
JsonPath.compile(input); // This will throw an exception if the path is invalid
279+
return true;
280+
} catch (Exception e) {
281+
return false;
282+
}
283+
}
284+
285+
286+
/**
287+
* Collects the prefixes of the toString() method calls present in the values of the given map.
288+
*
289+
* @param map A map containing key-value pairs where the values may contain toString() method calls.
290+
* @return A list of prefixes for the toString() method calls found in the map values.
291+
*/
292+
public static List<String> collectToStringPrefixes(Map<String, String> map) {
293+
List<String> prefixes = new ArrayList<>();
294+
for (String key : map.keySet()) {
295+
String value = map.get(key);
296+
if (value != null) {
297+
Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}");
298+
Matcher matcher = pattern.matcher(value);
299+
while (matcher.find()) {
300+
String prefix = matcher.group(1);
301+
prefixes.add(prefix);
302+
}
303+
}
304+
}
305+
return prefixes;
306+
}
307+
308+
/**
309+
* Parses the given parameters map and processes the values containing toString() method calls.
310+
*
311+
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
312+
* @return A new map with the processed values for the toString() method calls.
313+
*/
314+
public static Map<String, String> parseParameters(Map<String, String> parameters) {
315+
if (parameters != null) {
316+
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);
317+
318+
if (!toStringParametersPrefixes.isEmpty()) {
319+
for (String prefix : toStringParametersPrefixes) {
320+
String value = parameters.get(prefix);
321+
if (value != null) {
322+
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
323+
}
324+
}
325+
}
326+
}
327+
return parameters;
328+
}
329+
242330
}

Diff for: plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java

+19-22
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;
88

99
import java.io.IOException;
10-
import java.util.ArrayList;
1110
import java.util.Collection;
1211
import java.util.HashMap;
1312
import java.util.HashSet;
@@ -37,9 +36,7 @@
3736
import org.opensearch.script.ScriptService;
3837
import org.opensearch.script.TemplateScript;
3938

40-
import com.jayway.jsonpath.Configuration;
4139
import com.jayway.jsonpath.JsonPath;
42-
import com.jayway.jsonpath.Option;
4340

4441
/**
4542
* MLInferenceIngestProcessor requires a modelId string to call model inferences
@@ -75,11 +72,6 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
7572
public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
7673
private final NamedXContentRegistry xContentRegistry;
7774

78-
private Configuration suppressExceptionConfiguration = Configuration
79-
.builder()
80-
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST)
81-
.build();
82-
8375
protected MLInferenceIngestProcessor(
8476
String modelId,
8577
List<Map<String, String>> inputMaps,
@@ -320,24 +312,29 @@ private void getMappedModelInputFromDocuments(
320312
Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class);
321313
String documentFieldValueAsString = toString(documentFieldValue);
322314
updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters);
315+
return;
323316
}
324-
// else when cannot find field path in document, try check for nested array using json path
325-
else {
326-
if (documentFieldName.contains(DOT_SYMBOL)) {
327-
328-
Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
329-
ArrayList<Object> fieldValueList = JsonPath
330-
.using(suppressExceptionConfiguration)
331-
.parse(sourceObject)
332-
.read(documentFieldName);
333-
if (!fieldValueList.isEmpty()) {
334-
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
335-
} else if (!ignoreMissing) {
336-
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
317+
// If the standard dot path fails, try to check for a nested array using JSON path
318+
if (StringUtils.isValidJSONPath(documentFieldName)) {
319+
Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
320+
Object fieldValue = JsonPath.using(suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName);
321+
322+
if (fieldValue != null) {
323+
if (fieldValue instanceof List) {
324+
List<?> fieldValueList = (List<?>) fieldValue;
325+
if (!fieldValueList.isEmpty()) {
326+
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
327+
} else if (!ignoreMissing) {
328+
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
329+
}
330+
} else {
331+
updateModelParameters(modelInputFieldName, toString(fieldValue), modelParameters);
337332
}
338333
} else if (!ignoreMissing) {
339-
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
334+
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
340335
}
336+
} else {
337+
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
341338
}
342339
}
343340

0 commit comments

Comments
 (0)