bug fix #4148 (Closed)
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -1 +1 @@
- * @b4sjoo @dhrubo-os @mingshl @jngz-es @model-collapse @rbhavna @ylwu-amzn @zane-neo @Zhangxunmt @austintlee @HenryL27 @samuel-oci @xinyual
+ * @b4sjoo @dhrubo-os @mingshl @jngz-es @model-collapse @rbhavna @ylwu-amzn @zane-neo @Zhangxunmt @austintlee @HenryL27 @sam-herman @xinyual @pyek-bot
7 changes: 4 additions & 3 deletions MAINTAINERS.md
@@ -12,17 +12,18 @@ This document contains a list of maintainers in this repo. See [opensearch-proje
| Jing Zhang | [jngz-es](https://github.com/jngz-es) | Amazon |
| Junshen Wu | [wujunshen](https://github.com/wujunshen) | Amazon |
| Sicheng Song | [b4sjoo](https://github.com/b4sjoo) | Amazon |
- | Mingshi Liu | [mingshl](https://github.com/mingshl) | Amazon |
+ | Mingshi Liu | [mingshl](https://github.com/mingshl) | Amazon |
+ | Pavan Yekbote | [pyek-bot](https://github.com/pyek-bot) | Amazon |
| Xinyuan Lu | [xinyual](https://github.com/xinyual) | Amazon |
| Xun Zhang | [Zhangxunmt](https://github.com/Zhangxunmt) | Amazon |
| Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon |
| Zan Niu | [zane-neo](https://github.com/zane-neo) | Amazon |
| Austin Lee | [austintlee](https://github.com/austintlee) | Aryn |
| Henry Lindeman | [HenryL27](https://github.com/HenryL27) | Aryn |
- | Samuel Herman | [samuel-oci](https://github.com/samuel-oci/) | Oracle |
+ | Samuel Herman | [samuel-oci](https://github.com/sam-herman/) | Oracle |

## Emeritus

| Maintainer | GitHub ID | Affiliation |
| ----------- | ------------------------------------------------- | ----------- |
- | Jackie Han | [jackiehanyang](https://github.com/jackiehanyang) | Amazon |
+ | Jackie Han | [jackiehanyang](https://github.com/jackiehanyang) | Amazon |
11 changes: 11 additions & 0 deletions build.gradle
@@ -41,6 +41,7 @@ buildscript {
dependencies {
classpath "${opensearch_group}.gradle:build-tools:${opensearch_version}"
classpath "gradle.plugin.com.dorongold.plugins:task-tree:1.5"
classpath "com.diffplug.spotless:spotless-plugin-gradle:6.25.0"
configurations.all {
resolutionStrategy {
force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100)
@@ -97,6 +98,16 @@ subprojects {
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre"
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
}

apply plugin: 'com.diffplug.spotless'

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'
eclipse().configFile rootProject.file('.eclipseformat.xml')
}
}
}

ext {
10 changes: 0 additions & 10 deletions client/build.gradle
@@ -9,7 +9,6 @@ plugins {
id 'jacoco'
id 'io.github.goooler.shadow' version "8.1.7"
id 'maven-publish'
id 'com.diffplug.spotless' version '6.25.0'
id 'signing'
}

@@ -23,15 +22,6 @@ dependencies {

}

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml')
}
}

jacocoTestReport {
reports {
xml.getRequired().set(true)
10 changes: 0 additions & 10 deletions common/build.gradle
@@ -9,7 +9,6 @@ plugins {
id 'io.github.goooler.shadow' version "8.1.7"
id 'jacoco'
id "io.freefair.lombok"
id 'com.diffplug.spotless' version '6.25.0'
id 'maven-publish'
id 'signing'
}
@@ -77,15 +76,6 @@ jacocoTestCoverageVerification {
}
check.dependsOn jacocoTestCoverageVerification

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml')
}
}

shadowJar {
destinationDirectory = file("${project.buildDir}/distributions")
archiveClassifier.set(null)
MLPostProcessFunction.java
@@ -15,6 +15,7 @@
import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

@@ -35,6 +36,8 @@ public class MLPostProcessFunction {
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
// ML Commons passthrough unwraps a remote ml-commons response and reconstructs the model tensors directly from the remote inference results
public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

@@ -46,6 +49,8 @@ public class MLPostProcessFunction {
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
RemoteMlCommonsPassthroughPostProcessFunction remoteMlCommonsPassthroughPostProcessFunction =
new RemoteMlCommonsPassthroughPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float");
@@ -61,6 +66,7 @@ public class MLPostProcessFunction {
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
JSON_PATH_EXPRESSION.put(ML_COMMONS_PASSTHROUGH, "$"); // Get the entire response
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction);
@@ -76,6 +82,7 @@ public class MLPostProcessFunction {
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(ML_COMMONS_PASSTHROUGH, remoteMlCommonsPassthroughPostProcessFunction);
}

public static String getResponseFilter(String postProcessFunction) {
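Side note for reviewers: a minimal sketch of how the new registry entry resolves. Only getResponseFilter and the ML_COMMONS_PASSTHROUGH constant come from the diff above; the import path and the sketch class itself are assumptions.

import org.opensearch.ml.common.connector.MLPostProcessFunction; // assumed package, mirroring the imports shown above

public class PassthroughRegistrySketch {
    public static void main(String[] args) {
        // The passthrough entry keeps the entire response body: its JSONPath filter is "$".
        System.out.println(MLPostProcessFunction.getResponseFilter(MLPostProcessFunction.ML_COMMONS_PASSTHROUGH));
        // A connector blueprint would then reference the function by name, e.g.
        // "post_process_function": "connector.post_process.mlcommons.passthrough"
    }
}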
RemoteMlCommonsPassthroughPostProcessFunction.java (new file)
@@ -0,0 +1,192 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/**
* A post-process function for connectors that call a remote ML Commons instance. It preserves the original response structure
* (for example, neural sparse output) so that inference results received from another ML-Commons instance are not double-wrapped.
*/
public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction<Map<String, Object>> {
@Override
public void validate(Object input) {
if (!(input instanceof Map) && !(input instanceof List)) {
throw new IllegalArgumentException("Post process function input must be a Map or List");
}
}

/**
* Example double-wrapped response (what the connector would produce without this function):
* {
* "inference_results": [
* {
* "output": [
* {
* "name": "output",
* "dataAsMap": {
* "inference_results": [
* {
* "output": [
* {
* "name": "output",
* "dataAsMap": {
* "response": [
* {
* "increasingly": 0.028670792,
* "achievements": 0.4906937,
* ...
* }
* ]
* }
* }
* ],
* "status_code": 200.0
* }
* ]
* }
* }
* ],
* "status_code": 200
* }
* ]
* }
*
* Example unwrapped response (what this function preserves):
*
* {
* "inference_results": [
* {
* "output": [
* {
* "name": "output",
* "dataAsMap": {
* "response": [
* {
* "increasingly": 0.028670792,
* "achievements": 0.4906937,
* ...
* }
* ]
* }
* },
* ],
* "status_code": 200
* }
* ]
* }
*
* @param mlCommonsResponse raw remote ml commons response
* @param dataType the data type of the result; not used, since the data type is derived from the response body
* @return the list of inner model tensors reconstructed from the remote response
*/
@Override
public List<ModelTensor> process(Map<String, Object> mlCommonsResponse, MLResultDataType dataType) {
// Check if this is an ML-Commons response with inference_results
if (mlCommonsResponse.containsKey("inference_results") && mlCommonsResponse.get("inference_results") instanceof List) {
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) mlCommonsResponse.get("inference_results");

List<ModelTensor> modelTensors = new ArrayList<>();
for (Map<String, Object> result : inferenceResults) {
// Extract the output field which contains the ModelTensor data
if (result.containsKey("output") && result.get("output") instanceof List) {
List<Map<String, Object>> outputs = (List<Map<String, Object>>) result.get("output");
for (Map<String, Object> output : outputs) {
// This inner map should represent a model tensor, so we try to parse and instantiate a new one.
ModelTensor modelTensor = createModelTensorFromMap(output);
if (modelTensor != null) {
modelTensors.add(modelTensor);
}
}
}
}

return modelTensors;
}

// Fallback for non-ML-Commons responses
ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(mlCommonsResponse).build();

return List.of(tensor);
}

/**
* Creates a ModelTensor from its Map<String, Object> representation, following the response format
* of the /_predict API
*/
private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
if (map == null || map.isEmpty()) {
return null;
}

// Get name. If name is null or not a String, default to OUTPUT_FIELD
Object uncastedName = map.get(ModelTensor.NAME_FIELD);
String name = uncastedName instanceof String castedName ? castedName : OUTPUT_FIELD;
String result = (String) map.get(ModelTensor.RESULT_FIELD);

// Handle data as map
Map<String, Object> dataAsMap = (Map<String, Object>) map.get(ModelTensor.DATA_AS_MAP_FIELD);

// Handle data type. For certain models like neural sparse and non-dense remote models, this field
// is not populated and left as null instead, which is still valid
MLResultDataType dataType = null;
if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) {
Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD);
if (dataTypeObj instanceof String) {
try {
dataType = MLResultDataType.valueOf((String) dataTypeObj);
} catch (IllegalArgumentException e) {
// Invalid data type; leave as null in case the inner data is still worth parsing in the future
}
}
}

// Handle shape. For certain models like neural sparse and non-dense models, null is valid since the inference result
// is stored in dataAsMap, not the data/shape fields
long[] shape = null;
if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
Number[] numbers = processNumericalArray(map, ModelTensor.SHAPE_FIELD, Number.class);
if (numbers != null) {
shape = Arrays.stream(numbers).mapToLong(Number::longValue).toArray();
}
}

// Handle data. For certain models like neural sparse and non-dense models, null is valid since the inference result
// is stored in dataAsMap, not the data/shape fields
Number[] data = null;
if (map.containsKey(ModelTensor.DATA_FIELD)) {
data = processNumericalArray(map, ModelTensor.DATA_FIELD, Number.class);
}

// For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases.

return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build();
}

private static <T> T[] processNumericalArray(Map<String, Object> map, String key, Class<T> type) {
Object obj = map.get(key);
if (obj instanceof List<?> list) {
T[] array = (T[]) Array.newInstance(type, list.size());
for (int i = 0; i < list.size(); i++) {
Object item = list.get(i);
if (type.isInstance(item)) {
array[i] = type.cast(item);
}
}
return array;
}
return null;
}
}
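To make the unwrapping concrete, here is a minimal usage sketch built around the single-level response shape shown in the javadoc above; the map literals, the sketch class, and passing null for the data type are illustrative assumptions rather than code from this PR.

import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;

public class PassthroughUsageSketch {
    public static void main(String[] args) {
        // Inner tensor map, mirroring the "unwrapped" example in the javadoc above.
        Map<String, Object> tensorMap = Map.of(
            "name", "output",
            "dataAsMap", Map.of("response", List.of(Map.of("increasingly", 0.028670792, "achievements", 0.4906937))));
        Map<String, Object> inferenceResult = Map.of("output", List.of(tensorMap), "status_code", 200);
        Map<String, Object> remoteResponse = Map.of("inference_results", List.of(inferenceResult));

        RemoteMlCommonsPassthroughPostProcessFunction fn = new RemoteMlCommonsPassthroughPostProcessFunction();
        // The data type is derived from the response body, so null is passed here (see the javadoc above).
        List<ModelTensor> tensors = fn.process(remoteResponse, null);
        System.out.println(tensors.size()); // expected: 1 tensor that carries the original dataAsMap
    }
}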
MLMemory.java
@@ -203,8 +203,13 @@ public static MLMemory parse(XContentParser parser) throws IOException {
lastUpdatedTime = Instant.ofEpochMilli(parser.longValue());
break;
case MEMORY_EMBEDDING_FIELD:
- // Parse embedding as generic object (could be array or sparse map)
- memoryEmbedding = parser.map();
+ if (parser.currentToken() == XContentParser.Token.START_ARRAY) {
+     memoryEmbedding = parser.list(); // Simple list parsing like ModelTensor
+ } else if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
+     memoryEmbedding = parser.map(); // For sparse embeddings
+ } else {
+     parser.skipChildren();
+ }
break;
default:
parser.skipChildren();
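For context, a small sketch of the two memory embedding shapes the new branch distinguishes (a dense vector array versus a sparse token-weight map); the field name and values are assumptions used purely for illustration.

public class MemoryEmbeddingShapes {
    public static void main(String[] args) {
        // Dense embedding: the parser sees START_ARRAY, so the new code reads it with parser.list()
        String dense = """
                { "memory_embedding": [0.12, -0.03, 0.88] }
                """;
        // Sparse embedding: the parser sees START_OBJECT, so it is still read with parser.map()
        String sparse = """
                { "memory_embedding": { "increasingly": 0.0287, "achievements": 0.4907 } }
                """;
        // Any other token type falls through to parser.skipChildren(), matching the default branch above.
        System.out.println(dense + sparse);
    }
}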
StringUtils.java
@@ -78,7 +78,7 @@ public class StringUtils {
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

- private static final ObjectMapper MAPPER = new ObjectMapper();
+ public static final ObjectMapper MAPPER = new ObjectMapper();

public static boolean isValidJsonString(String json) {
if (json == null || json.isBlank()) {
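Finally, a minimal sketch of the reuse this visibility change enables: other classes can now share the single ObjectMapper instead of constructing their own. The StringUtils import path and the readTree call site are assumptions, not taken from this PR.

import com.fasterxml.jackson.databind.JsonNode;

import org.opensearch.ml.common.utils.StringUtils; // assumed package for the ml-commons StringUtils

public class SharedMapperSketch {
    public static void main(String[] args) throws Exception {
        // Reuse the now-public shared mapper rather than allocating a new ObjectMapper per call site.
        JsonNode node = StringUtils.MAPPER.readTree("{\"response\":[{\"achievements\":0.49}]}");
        System.out.println(node.at("/response/0/achievements").asDouble()); // 0.49
    }
}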