diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 61eeccf9db..3cd32f2de6 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @b4sjoo @dhrubo-os @mingshl @jngz-es @model-collapse @rbhavna @ylwu-amzn @zane-neo @Zhangxunmt @austintlee @HenryL27 @samuel-oci @xinyual \ No newline at end of file +* @b4sjoo @dhrubo-os @mingshl @jngz-es @model-collapse @rbhavna @ylwu-amzn @zane-neo @Zhangxunmt @austintlee @HenryL27 @sam-herman @xinyual @pyek-bot diff --git a/MAINTAINERS.md b/MAINTAINERS.md index c9c1d6ade8..3f7e8893d3 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -12,17 +12,18 @@ This document contains a list of maintainers in this repo. See [opensearch-proje | Jing Zhang | [jngz-es](https://github.com/jngz-es) | Amazon | | Junshen Wu | [wujunshen](https://github.com/wujunshen) | Amazon | | Sicheng Song | [b4sjoo](https://github.com/b4sjoo) | Amazon | -| Mingshi Liu | [mingshl](https://github.com/mingshl) | Amazon | +| Mingshi Liu | [mingshl](https://github.com/mingshl) | Amazon | +| Pavan Yekbote | [pyek-bot](https://github.com/pyek-bot) | Amazon | | Xinyuan Lu | [xinyual](https://github.com/xinyual) | Amazon | | Xun Zhang | [Zhangxunmt](https://github.com/Zhangxunmt) | Amazon | | Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon | | Zan Niu | [zane-neo](https://github.com/zane-neo) | Amazon | | Austin Lee | [austintlee](https://github.com/austintlee) | Aryn | | Henry Lindeman | [HenryL27](https://github.com/HenryL27) | Aryn | -| Samuel Herman | [samuel-oci](https://github.com/samuel-oci/) | Oracle | +| Samuel Herman | [samuel-oci](https://github.com/sam-herman/) | Oracle | ## Emeritus | Maintainer | GitHub ID | Affiliation | | ----------- | ------------------------------------------------- | ----------- | -| Jackie Han | [jackiehanyang](https://github.com/jackiehanyang) | Amazon | \ No newline at end of file +| Jackie Han | [jackiehanyang](https://github.com/jackiehanyang) | Amazon | diff --git a/build.gradle b/build.gradle index 911a9e9da9..de007ef60e 100644 --- a/build.gradle +++ b/build.gradle @@ -41,6 +41,7 @@ buildscript { dependencies { classpath "${opensearch_group}.gradle:build-tools:${opensearch_version}" classpath "gradle.plugin.com.dorongold.plugins:task-tree:1.5" + classpath "com.diffplug.spotless:spotless-plugin-gradle:6.25.0" configurations.all { resolutionStrategy { force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100) @@ -97,6 +98,16 @@ subprojects { resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' } + + apply plugin: 'com.diffplug.spotless' + + spotless { + java { + removeUnusedImports() + importOrder 'java', 'javax', 'org', 'com' + eclipse().configFile rootProject.file('.eclipseformat.xml') + } + } } ext { diff --git a/client/build.gradle b/client/build.gradle index 2df24f77fe..e592cd24ff 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -9,7 +9,6 @@ plugins { id 'jacoco' id 'io.github.goooler.shadow' version "8.1.7" id 'maven-publish' - id 'com.diffplug.spotless' version '6.25.0' id 'signing' } @@ -23,15 +22,6 @@ dependencies { } -spotless { - java { - removeUnusedImports() - importOrder 'java', 'javax', 'org', 'com' - - eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml') - } -} - jacocoTestReport { reports { xml.getRequired().set(true) diff --git a/common/build.gradle b/common/build.gradle index 24cc63046f..9db59f5070 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -9,7 +9,6 @@ plugins { id 'io.github.goooler.shadow' version "8.1.7" id 'jacoco' id "io.freefair.lombok" - id 'com.diffplug.spotless' version '6.25.0' id 'maven-publish' id 'signing' } @@ -77,15 +76,6 @@ jacocoTestCoverageVerification { } check.dependsOn jacocoTestCoverageVerification -spotless { - java { - removeUnusedImports() - importOrder 'java', 'javax', 'org', 'com' - - eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml') - } -} - shadowJar { destinationDirectory = file("${project.buildDir}/distributions") archiveClassifier.set(null) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index efda9c4743..fe10a831f5 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -15,6 +15,7 @@ import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; @@ -35,6 +36,8 @@ public class MLPostProcessFunction { public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank"; public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; public static final String DEFAULT_RERANK = "connector.post_process.default.rerank"; + // ML commons passthrough unwraps a remote ml-commons response and reconstructs model tensors directly based on remote inference + public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough"; private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); @@ -46,6 +49,8 @@ public class MLPostProcessFunction { BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction(); CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction(); + RemoteMlCommonsPassthroughPostProcessFunction remoteMlCommonsPassthroughPostProcessFunction = + new RemoteMlCommonsPassthroughPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float"); @@ -61,6 +66,7 @@ public class MLPostProcessFunction { JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); + JSON_PATH_EXPRESSION.put(ML_COMMONS_PASSTHROUGH, "$"); // Get the entire response POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction); @@ -76,6 +82,7 @@ public class MLPostProcessFunction { POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(ML_COMMONS_PASSTHROUGH, remoteMlCommonsPassthroughPostProcessFunction); } public static String getResponseFilter(String postProcessFunction) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java new file mode 100644 index 0000000000..b991ee82d8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java @@ -0,0 +1,192 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD; + +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +/** + * A post-processing function for calling a remote ml commons instance that preserves the original neural sparse response structure + * to avoid double-wrapping when receiving responses from another ML-Commons instance. + */ +public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction> { + @Override + public void validate(Object input) { + if (!(input instanceof Map) && !(input instanceof List)) { + throw new IllegalArgumentException("Post process function input must be a Map or List"); + } + } + + /** + * Example unwrapped response: + * { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "output", + * "dataAsMap": { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "output", + * "dataAsMap": { + * "response": [ + * { + * "increasingly": 0.028670792, + * "achievements": 0.4906937, + * ... + * } + * ] + * } + * } + * ], + * "status_code": 200.0 + * } + * ] + * } + * } + * ], + * "status_code": 200 + * } + * ] + * } + * + * Example unwrapped response: + * + * { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "output", + * "dataAsMap": { + * "response": [ + * { + * "increasingly": 0.028670792, + * "achievements": 0.4906937, + * ... + * } + * ] + * } + * }, + * ], + * "status_code": 200 + * } + * ] + * } + * + * @param mlCommonsResponse raw remote ml commons response + * @param dataType the datatype of the result, not used since datatype is set based on the response body + * @return a list of model tensors representing the inner model tensors + */ + @Override + public List process(Map mlCommonsResponse, MLResultDataType dataType) { + // Check if this is an ML-Commons response with inference_results + if (mlCommonsResponse.containsKey("inference_results") && mlCommonsResponse.get("inference_results") instanceof List) { + List> inferenceResults = (List>) mlCommonsResponse.get("inference_results"); + + List modelTensors = new ArrayList<>(); + for (Map result : inferenceResults) { + // Extract the output field which contains the ModelTensor data + if (result.containsKey("output") && result.get("output") instanceof List) { + List> outputs = (List>) result.get("output"); + for (Map output : outputs) { + // This inner map should represent a model tensor, so we try to parse and instantiate a new one. + ModelTensor modelTensor = createModelTensorFromMap(output); + if (modelTensor != null) { + modelTensors.add(modelTensor); + } + } + } + } + + return modelTensors; + } + + // Fallback for non-ML-Commons responses + ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(mlCommonsResponse).build(); + + return List.of(tensor); + } + + /** + * Creates a ModelTensor from a Map representation based on the API format + * of the /_predict API + */ + private ModelTensor createModelTensorFromMap(Map map) { + if (map == null || map.isEmpty()) { + return null; + } + + // Get name. If name is null or not a String, default to OUTPUT_FIELD + Object uncastedName = map.get(ModelTensor.NAME_FIELD); + String name = uncastedName instanceof String castedName ? castedName : OUTPUT_FIELD; + String result = (String) map.get(ModelTensor.RESULT_FIELD); + + // Handle data as map + Map dataAsMap = (Map) map.get(ModelTensor.DATA_AS_MAP_FIELD); + + // Handle data type. For certain models like neural sparse and non-dense remote models, this field + // is not populated and left as null instead, which is still valid + MLResultDataType dataType = null; + if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) { + Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD); + if (dataTypeObj instanceof String) { + try { + dataType = MLResultDataType.valueOf((String) dataTypeObj); + } catch (IllegalArgumentException e) { + // Invalid data type, leave as null in case inner data is still useful to be parsed in the future + } + } + } + + // Handle shape. For certain models like neural sparse and non-dense, null is valid since inference result + // is stored in dataAsMap, not data/shape field + long[] shape = null; + if (map.containsKey(ModelTensor.SHAPE_FIELD)) { + Number[] numbers = processNumericalArray(map, ModelTensor.SHAPE_FIELD, Number.class); + if (numbers != null) { + shape = Arrays.stream(numbers).mapToLong(Number::longValue).toArray(); + } + } + + // Handle shape. For certain models like neural sparse and non-dense, null is valid since inference result + // is stored in dataAsMap, not data/shape field + Number[] data = null; + if (map.containsKey(ModelTensor.DATA_FIELD)) { + data = processNumericalArray(map, ModelTensor.DATA_FIELD, Number.class); + } + + // For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases. + + return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build(); + } + + private static T[] processNumericalArray(Map map, String key, Class type) { + Object obj = map.get(key); + if (obj instanceof List list) { + T[] array = (T[]) Array.newInstance(type, list.size()); + for (int i = 0; i < list.size(); i++) { + Object item = list.get(i); + if (type.isInstance(item)) { + array[i] = type.cast(item); + } + } + return array; + } + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java index b3b0cff6c2..09a649f052 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java @@ -203,8 +203,13 @@ public static MLMemory parse(XContentParser parser) throws IOException { lastUpdatedTime = Instant.ofEpochMilli(parser.longValue()); break; case MEMORY_EMBEDDING_FIELD: - // Parse embedding as generic object (could be array or sparse map) - memoryEmbedding = parser.map(); + if (parser.currentToken() == XContentParser.Token.START_ARRAY) { + memoryEmbedding = parser.list(); // Simple list parsing like ModelTensor + } else if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + memoryEmbedding = parser.map(); // For sparse embeddings + } else { + parser.skipChildren(); + } break; default: parser.skipChildren(); diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 9d4f5ad0c9..61a275e5f6 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -78,7 +78,7 @@ public class StringUtils { } public static final String TO_STRING_FUNCTION_NAME = ".toString()"; - private static final ObjectMapper MAPPER = new ObjectMapper(); + public static final ObjectMapper MAPPER = new ObjectMapper(); public static boolean isValidJsonString(String json) { if (json == null || json.isBlank()) { diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 5841a40c7c..29ea295a17 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -193,6 +193,41 @@ public void createPayload() { Assert.assertEquals("{\"input\": \"test input value\"}", predictPayload); } + @Test + public void createPayload_ExtraParams() { + + String requestBody = + "{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}"; + String expected = + "{\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": \"query\" }}"; + + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + parameters.put("input", "test value"); + parameters.put("sparseEmbeddingFormat", "WORD"); + parameters.put("content_type", "query"); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + Assert.assertEquals(expected, predictPayload); + } + + @Test + public void createPayload_MissingParamsInvalidJson() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule + .expectMessage( + "Invalid payload: {\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": ${parameters.content_type} }}" + ); + String requestBody = + "{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": ${parameters.content_type} }}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + parameters.put("input", "test value"); + parameters.put("sparseEmbeddingFormat", "WORD"); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + } + @Test public void parseResponse_modelTensorJson() throws IOException { HttpConnector connector = createHttpConnector(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunctionTest.java new file mode 100644 index 0000000000..b2cc031f70 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunctionTest.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +public class RemoteMlCommonsPassthroughPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + RemoteMlCommonsPassthroughPostProcessFunction function; + + @Before + public void setUp() { + function = new RemoteMlCommonsPassthroughPostProcessFunction(); + } + + @Test + public void process_WrongInput_NotMapOrList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input must be a Map or List"); + function.apply("abc", null); + } + + /** + * Tests processing of ML-Commons response containing sparse vector data with rank features. + * Validates that sparse vectors with dataAsMap containing token-score pairs are correctly parsed. + */ + @Test + public void process_MLCommonsResponse_RankFeatures() { + Map rankFeatures = Map + .of("increasingly", 0.028670792, "achievements", 0.4906937, "nation", 0.15371077, "hello", 0.35982144, "today", 3.0966291); + Map innerDataAsMap = Map.of("response", Arrays.asList(rankFeatures)); + Map output = Map.of("name", "output", "dataAsMap", innerDataAsMap); + Map inferenceResult = Map.of("output", Arrays.asList(output)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(1, result.size()); + ModelTensor tensor = result.get(0); + assertEquals("output", tensor.getName()); + assertEquals(innerDataAsMap, tensor.getDataAsMap()); + + // Verify the nested sparse data structure + Map dataAsMap = (Map) tensor.getDataAsMap(); + List> response = (List>) dataAsMap.get("response"); + assertEquals(1, response.size()); + assertEquals(0.35982144, (Double) response.get(0).get("hello"), 0.0001); + assertEquals(3.0966291, (Double) response.get(0).get("today"), 0.0001); + } + + /** + * Tests processing of ML-Commons response containing dense vector data with numerical arrays. + * Validates that dense vectors with data_type, shape, and data fields are correctly parsed. + */ + @Test + public void process_MLCommonsResponse_DenseVector() { + Map output = Map + .of( + "name", + "sentence_embedding", + "data_type", + "FLOAT32", + "shape", + Arrays.asList(3L), + "data", + Arrays.asList(0.5400895, -0.19082281, 0.4996347) + ); + Map inferenceResult = Map.of("output", Arrays.asList(output)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(1, result.size()); + ModelTensor tensor = result.get(0); + assertEquals("sentence_embedding", tensor.getName()); + assertEquals(MLResultDataType.FLOAT32, tensor.getDataType()); + assertEquals(1, tensor.getShape().length); + assertEquals(3L, tensor.getShape()[0]); + assertEquals(3, tensor.getData().length); + assertEquals(0.5400895, tensor.getData()[0].doubleValue(), 0.0001); + } + + /** + * Tests processing of ML-Commons response with multiple output tensors in a single inference result. + * Ensures all outputs are processed and returned as separate ModelTensor objects. + */ + @Test + public void process_MLCommonsResponse_MultipleOutputs() { + Map output1 = Map.of("name", "output1", "result", "result1"); + Map output2 = Map.of("name", "output2", "result", "result2"); + Map inferenceResult = Map.of("output", Arrays.asList(output1, output2)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(2, result.size()); + assertEquals("output1", result.get(0).getName()); + assertEquals("result1", result.get(0).getResult()); + assertEquals("output2", result.get(1).getName()); + assertEquals("result2", result.get(1).getResult()); + } + + /** + * Tests edge case where ML-Commons response has empty inference_results array. + * Should return empty list without errors. + */ + @Test + public void process_MLCommonsResponse_EmptyInferenceResults() { + Map input = Map.of("inference_results", Arrays.asList()); + + List result = function.apply(input, null); + + assertEquals(0, result.size()); + } + + /** + * Tests edge cases where inference result lacks the expected format. + * Should skip processing and return empty list. + */ + @Test + public void process_MLCommonsResponse_InvalidOutputs() { + Map inferenceResult = Map.of("other_field", "value"); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(0, result.size()); + + // correct format, but with empty output + inferenceResult = Map.of("output", List.of(Map.of())); + input = Map.of("inference_results", List.of(inferenceResult)); + + result = function.apply(input, null); + + assertEquals(0, result.size()); + + // Fallback for non-ml-commons responses + input = Map.of("invalid_format", "invalid value"); + result = function.apply(input, null); + + assertEquals(1, result.size()); + assertEquals(input, result.getFirst().getDataAsMap()); + assertEquals("response", result.getFirst().getName()); + } + + /** + * Tests processing of ML-Commons response containing dense vector data with numerical arrays. + * Validates that when the types are incorrect, values are parsed as nulls. + */ + @Test + public void process_MLCommonsResponse_InvalidDenseVectorFormat() { + Map output = Map + .of( + "name", + List.of("Not a string"), + "data_type", + "NON-EXISTENT TYPE", + "shape", + "not a list of long", + "data", + "not a list of numbers" + ); + Map inferenceResult = Map.of("output", Arrays.asList(output)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(1, result.size()); + ModelTensor tensor = result.getFirst(); + assertEquals(OUTPUT_FIELD, tensor.getName()); + assertNull(tensor.getShape()); + assertNull(tensor.getData()); + assertNull(tensor.getDataType()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java index 4aa1d5a137..09978d0a7a 100644 --- a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java @@ -14,6 +14,7 @@ import java.io.IOException; import java.time.Instant; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.junit.Before; @@ -442,4 +443,192 @@ public void testSpecialCharactersInFields() throws IOException { assertEquals(specialMemory.getRole(), parsed.getRole()); assertEquals(specialMemory.getTags(), parsed.getTags()); } + + @Test + public void testParseWithDenseEmbeddingArray() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-123\"," + + "\"memory\":\"Test memory with dense embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":[0.1,0.2,0.3,0.4]" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-123", parsed.getSessionId()); + assertEquals("Test memory with dense embedding", parsed.getMemory()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof List); + @SuppressWarnings("unchecked") + List embedding = (List) parsed.getMemoryEmbedding(); + assertEquals(4, embedding.size()); + assertEquals(0.1, embedding.get(0), 0.001); + assertEquals(0.2, embedding.get(1), 0.001); + assertEquals(0.3, embedding.get(2), 0.001); + assertEquals(0.4, embedding.get(3), 0.001); + } + + @Test + public void testParseWithSparseEmbeddingObject() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-456\"," + + "\"memory\":\"Test memory with sparse embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":{\"token1\":0.5,\"token2\":0.8,\"token3\":0.2}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-456", parsed.getSessionId()); + assertEquals("Test memory with sparse embedding", parsed.getMemory()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof Map); + @SuppressWarnings("unchecked") + Map embedding = (Map) parsed.getMemoryEmbedding(); + assertEquals(3, embedding.size()); + assertEquals(0.5, embedding.get("token1")); + assertEquals(0.8, embedding.get("token2")); + assertEquals(0.2, embedding.get("token3")); + } + + @Test + public void testParseWithWrappedDenseEmbedding() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-789\"," + + "\"memory\":\"Test memory with wrapped dense embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":{\"values\":[0.1,0.2,0.3]}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-789", parsed.getSessionId()); + assertEquals("Test memory with wrapped dense embedding", parsed.getMemory()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof Map); + @SuppressWarnings("unchecked") + Map embedding = (Map) parsed.getMemoryEmbedding(); + assertTrue(embedding.containsKey("values")); + assertTrue(embedding.get("values") instanceof List); + } + + @Test + public void testParseWithEmptyEmbeddingArray() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-empty\"," + + "\"memory\":\"Test memory with empty embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":[]" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-empty", parsed.getSessionId()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof List); + @SuppressWarnings("unchecked") + List embedding = (List) parsed.getMemoryEmbedding(); + assertTrue(embedding.isEmpty()); + } + + @Test + public void testParseWithEmptyEmbeddingObject() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-empty-obj\"," + + "\"memory\":\"Test memory with empty embedding object\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":{}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-empty-obj", parsed.getSessionId()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof Map); + @SuppressWarnings("unchecked") + Map embedding = (Map) parsed.getMemoryEmbedding(); + assertTrue(embedding.isEmpty()); + } + + @Test + public void testParseWithInvalidEmbeddingType() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-invalid\"," + + "\"memory\":\"Test memory with invalid embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":\"invalid_string_embedding\"" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-invalid", parsed.getSessionId()); + // Should gracefully handle invalid embedding type by skipping it + assertNull(parsed.getMemoryEmbedding()); + } } diff --git a/docs/tutorials/ml_inference/language_identification/ml_inference_with_language_identification_ingest.md b/docs/tutorials/ml_inference/language_identification/ml_inference_with_language_identification_ingest.md index 589f3e7dd5..0692861bdf 100644 --- a/docs/tutorials/ml_inference/language_identification/ml_inference_with_language_identification_ingest.md +++ b/docs/tutorials/ml_inference/language_identification/ml_inference_with_language_identification_ingest.md @@ -42,7 +42,7 @@ huggingface_model = HuggingFaceModel( # Deploy model to SageMaker Inference predictor = huggingface_model.deploy( initial_instance_count=1, - instance_type='ml.m5.xlarge' + instance_type='ml.m7g.xlarge' ) # After deployment, you can find your endpoint name in the diff --git a/memory/build.gradle b/memory/build.gradle index 3541e592b7..b1d6fd6398 100644 --- a/memory/build.gradle +++ b/memory/build.gradle @@ -20,7 +20,6 @@ plugins { id 'java' id 'jacoco' id "io.freefair.lombok" - id 'com.diffplug.spotless' version '6.25.0' } dependencies { @@ -81,12 +80,3 @@ jacocoTestCoverageVerification { dependsOn jacocoTestReport } check.dependsOn jacocoTestCoverageVerification - -spotless { - java { - removeUnusedImports() - importOrder 'java', 'javax', 'org', 'com' - - eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml') - } -} diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 009334a37c..3a43642cd3 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -10,7 +10,6 @@ plugins { id 'java-library' id 'jacoco' id "io.freefair.lombok" - id 'com.diffplug.spotless' version '6.25.0' } repositories { @@ -136,12 +135,3 @@ jacocoTestCoverageVerification { } check.dependsOn jacocoTestCoverageVerification compileJava.dependsOn(':opensearch-ml-common:shadowJar') - -spotless { - java { - removeUnusedImports() - importOrder 'java', 'javax', 'org', 'com' - - eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml') - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 84b827ccc6..19f53eadef 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -5,10 +5,12 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput; +import java.io.IOException; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -28,11 +30,14 @@ import org.opensearch.common.collect.Tuple; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.TokenBucket; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -42,10 +47,12 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.ScriptService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; @@ -83,6 +90,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener parametersMap = getParams(mlInput); + parameters.putAll(parametersMap); + } catch (IOException e) { + actionListener.onFailure(e); + return; + } + } + RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService()); if (inputData.getParameters() != null) { parameters.putAll(inputData.getParameters()); @@ -227,6 +247,15 @@ && getUserRateLimiterMap().get(user.getName()) != null } } + static Map getParams(MLInput mlInput) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + mlInput.getParameters().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.flush(); + String json = builder.toString(); + Map tempMap = StringUtils.MAPPER.readValue(json, Map.class); + return getParameterMap(tempMap); + } + default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) { switch (connectorClientConfig.getRetryBackoffPolicy()) { case EXPONENTIAL_EQUAL_JITTER: diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java index 4ff0c6c815..5d50d6573c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java @@ -7,8 +7,10 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.argThat; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; @@ -17,7 +19,9 @@ import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS; +import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; import org.junit.Assert; @@ -30,6 +34,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.AwsConnector; @@ -39,6 +44,10 @@ import org.opensearch.ml.common.connector.RetryBackoffPolicy; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; @@ -64,6 +73,9 @@ public class RemoteConnectorExecutorTest { @Mock ActionListener> actionListener; + @Mock + private MLAlgoParams mlInputParams; + @Before public void setUp() { MockitoAnnotations.openMocks(this); @@ -169,4 +181,165 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault() ); assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); } + + @Test + public void executePreparePayloadAndInvoke_PassingParameter() { + Map parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "You are a ${parameters.role}")) + .actionType(PREDICT) + .build(); + String actionType = inputDataSet.getActionType().toString(); + AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) + .embeddingContentType(null) + .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .parameters(inputParams) + .inputDataset(inputDataSet) + .build(); + + Exception exception = Assert + .assertThrows( + IllegalArgumentException.class, + () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) + ); + assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); + } + + @Test + public void executePreparePayloadAndInvoke_GetParamsIOException() throws Exception { + Map parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "test input")) + .actionType(PREDICT) + .build(); + String actionType = inputDataSet.getActionType().toString(); + doThrow(new IOException("UT test IOException")).when(mlInputParams).toXContent(any(XContentBuilder.class), any()); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .parameters(mlInputParams) + .inputDataset(inputDataSet) + .build(); + + executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener); + verify(actionListener).onFailure(argThat(e -> e instanceof IOException && e.getMessage().contains("UT test IOException"))); + } + + @Test + public void executeGetParams_MissingParameter() { + Map parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "${parameters.input}")) + .actionType(PREDICT) + .build(); + String actionType = inputDataSet.getActionType().toString(); + AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) + .embeddingContentType(null) + .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .parameters(inputParams) + .inputDataset(inputDataSet) + .build(); + + try { + Map paramsMap = RemoteConnectorExecutor.getParams(mlInput); + Map expectedMap = new HashMap<>(); + expectedMap.put("sparse_embedding_format", "WORD"); + Assert.assertEquals(expectedMap, paramsMap); + } catch (IOException e) { + e.printStackTrace(); + } + } + + @Test + public void executeGetParams_PassingParameter() { + Map parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "${parameters.input}")) + .actionType(PREDICT) + .build(); + String actionType = inputDataSet.getActionType().toString(); + AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) + .embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.PASSAGE) + .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .parameters(inputParams) + .inputDataset(inputDataSet) + .build(); + + try { + Map paramsMap = RemoteConnectorExecutor.getParams(mlInput); + Map expectedMap = new HashMap<>(); + expectedMap.put("sparse_embedding_format", "WORD"); + expectedMap.put("content_type", "PASSAGE"); + Assert.assertEquals(expectedMap, paramsMap); + } catch (IOException e) { + e.printStackTrace(); + } + } + + @Test + public void executeGetParams_ConvertToString() { + Map parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "${parameters.input}")) + .actionType(PREDICT) + .build(); + KMeansParams inputParams = KMeansParams + .builder() + .centroids(5) + .iterations(100) + .distanceType(KMeansParams.DistanceType.EUCLIDEAN) + .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .parameters(inputParams) + .inputDataset(inputDataSet) + .build(); + + try { + Map paramsMap = RemoteConnectorExecutor.getParams(mlInput); + Map expectedMap = new HashMap<>(); + expectedMap.put("centroids", "5"); + expectedMap.put("iterations", "100"); + expectedMap.put("distance_type", "EUCLIDEAN"); + Assert.assertEquals(expectedMap, paramsMap); + } catch (IOException e) { + e.printStackTrace(); + } + } } diff --git a/plugin/build.gradle b/plugin/build.gradle index 13e64430af..935491db0d 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -13,7 +13,6 @@ plugins { id "io.freefair.lombok" id 'jacoco' id 'java-library' - id 'com.diffplug.spotless' version '6.25.0' } ext { @@ -482,15 +481,6 @@ afterEvaluate { } } -spotless { - java { - removeUnusedImports() - importOrder 'java', 'javax', 'org', 'com' - - eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml') - } -} - tasks.withType(licenseHeaders.class) { additionalLicense 'AL ', 'Apache', 'Licensed under the Apache License, Version 2.0 (the "License")' } diff --git a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java index 3054ba54c5..d10ee95796 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java @@ -437,6 +437,74 @@ private String createMemoryJson() { + "}"; } + private String createMemoryJsonWithDenseEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with dense embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":[0.1,0.2,0.3,0.4]" + + "}"; + } + + private String createMemoryJsonWithSparseEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with sparse embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":{\"token1\":0.5,\"token2\":0.8,\"token3\":0.2}" + + "}"; + } + + private String createMemoryJsonWithWrappedEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with wrapped embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":{\"values\":[0.1,0.2,0.3]}" + + "}"; + } + + private String createMemoryJsonWithInvalidEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with invalid embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":\"invalid_string_embedding\"" + + "}"; + } + @Test public void testDoExecuteWithFeatureDisabled() { // Setup feature flag to be disabled @@ -461,4 +529,149 @@ public void testDoExecuteWithFeatureDisabled() { // Verify that no other operations were attempted verify(memoryContainerHelper, never()).getMemoryContainer(any(String.class), any(ActionListener.class)); } + + @Test + public void testDoExecuteWithDenseEmbeddingArray() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with dense embedding as array + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithDenseEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + assertNotNull(returnedMemory.getMemoryEmbedding()); + } + + @Test + public void testDoExecuteWithSparseEmbeddingObject() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with sparse embedding as object + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithSparseEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + assertNotNull(returnedMemory.getMemoryEmbedding()); + } + + @Test + public void testDoExecuteWithWrappedDenseEmbedding() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with wrapped dense embedding + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithWrappedEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + assertNotNull(returnedMemory.getMemoryEmbedding()); + } + + @Test + public void testDoExecuteWithInvalidEmbeddingType() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with invalid embedding type (string) + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithInvalidEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response (should gracefully handle invalid embedding) + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + // Invalid embedding should be null (gracefully handled) + assertNull(returnedMemory.getMemoryEmbedding()); + } } diff --git a/search-processors/build.gradle b/search-processors/build.gradle index 2e827e3db7..d47bc2a8fb 100644 --- a/search-processors/build.gradle +++ b/search-processors/build.gradle @@ -19,7 +19,6 @@ plugins { id 'java' id 'jacoco' id "io.freefair.lombok" - id 'com.diffplug.spotless' version '6.25.0' } repositories { @@ -74,12 +73,3 @@ jacocoTestCoverageVerification { } check.dependsOn jacocoTestCoverageVerification - -spotless { - java { - removeUnusedImports() - importOrder 'java', 'javax', 'org', 'com' - - eclipse().withP2Mirrors(Map.of("https://download.eclipse.org/", "https://mirror.umd.edu/eclipse/")).configFile rootProject.file('.eclipseformat.xml') - } -} diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/MLCommonsExtension.java b/spi/src/main/java/org/opensearch/ml/common/spi/MLCommonsExtension.java index 3761b2ec64..8ca1294c18 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/MLCommonsExtension.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/MLCommonsExtension.java @@ -5,10 +5,10 @@ package org.opensearch.ml.common.spi; -import org.opensearch.ml.common.spi.tools.Tool; - import java.util.List; +import org.opensearch.ml.common.spi.tools.Tool; + /** * ml-commons extension interface. */ diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java b/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java index 4898cea587..3615384fce 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java @@ -5,10 +5,10 @@ package org.opensearch.ml.common.spi.memory; -import org.opensearch.core.action.ActionListener; - import java.util.Map; +import org.opensearch.core.action.ActionListener; + /** * A general memory interface. * @param @@ -28,15 +28,18 @@ public interface Memory { */ default void save(String id, T message) {} - default void save(String id, T message, ActionListener listener){} + default void save(String id, T message, ActionListener listener) {} /** * Get messages of memory id. * @param id memory id * @return */ - default T[] getMessages(String id){return null;} - default void getMessages(String id, ActionListener listener){} + default T[] getMessages(String id) { + return null; + } + + default void getMessages(String id, ActionListener listener) {} /** * Clear all memory. diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java index 64d9e04008..28739c53b1 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java @@ -5,12 +5,11 @@ package org.opensearch.ml.common.spi.tools; -import org.opensearch.core.action.ActionListener; - import java.util.Collections; -import java.util.HashMap; import java.util.Map; +import org.opensearch.core.action.ActionListener; + /** * General tool interface. */ @@ -68,7 +67,9 @@ public interface Tool { String getDescription(); Map getAttributes(); + void setAttributes(Map attributes); + /** * Set tool description. * @param description the description to set diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/WithModelTool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/WithModelTool.java index 3b289681fc..e160e0ec40 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/WithModelTool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/WithModelTool.java @@ -5,7 +5,6 @@ package org.opensearch.ml.common.spi.tools; - import java.util.List; /**