diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java index 32f3318718..ca485ec05b 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.output.model; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -12,6 +14,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.annotation.MLAlgoOutput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLOutputType; @@ -79,4 +82,24 @@ protected MLOutputType getType() { return OUTPUT_TYPE; } + public static ModelTensorOutput parse(XContentParser parser) throws IOException { + List mlModelOutputs = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + if (fieldName.equals(INFERENCE_RESULT_FIELD)) { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + mlModelOutputs.add(ModelTensors.parse(parser)); + } + } else { + parser.skipChildren(); + } + } + + return ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build(); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 5622057951..8177a6ed56 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.output.model; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -17,6 +19,7 @@ import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.exception.MLException; import lombok.Builder; @@ -139,4 +142,33 @@ public static ModelTensors fromBytes(byte[] bytes) { throw new MLException(errorMsg, e); } } + + public static ModelTensors parse(XContentParser parser) throws IOException { + Integer statusCode = null; + List mlModelTensors = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case STATUS_CODE_FIELD: + statusCode = parser.intValue(false); + break; + case OUTPUT_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + mlModelTensors.add(ModelTensor.parser(parser)); + } + break; + default: + parser.skipChildren(); + break; + } + } + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build(); + modelTensors.setStatusCode(statusCode); + return modelTensors; + } } diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java index 67690ed2bf..b4a79afb98 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java @@ -14,7 +14,13 @@ import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; import org.opensearch.ml.common.output.MLOutputType; public class ModelTensorOutputTest { @@ -61,6 +67,109 @@ public void readInputStream_NullField() throws IOException { }); } + @Test + public void parse_Success() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + builder.startArray(ModelTensorOutput.INFERENCE_RESULT_FIELD); + + builder.startObject(); + builder.startArray("output"); + + builder.startObject(); + builder.field("name", "test"); + builder.field("data_type", "FLOAT32"); + builder.field("shape", new long[] { 1, 3 }); + builder.field("data", value); + builder.endObject(); + + builder.endArray(); + builder.endObject(); + + builder.endArray(); + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + + ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser); + + assertEquals(1, parsedOutput.getMlModelOutputs().size()); + ModelTensors modelTensors = parsedOutput.getMlModelOutputs().get(0); + assertEquals(1, modelTensors.getMlModelTensors().size()); + ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); + assertEquals("test", modelTensor.getName()); + assertEquals(value.length, modelTensor.getData().length); + assertEquals(value[0].doubleValue(), modelTensor.getData()[0].doubleValue(), 0.0001); + assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape()); + assertEquals(MLResultDataType.FLOAT32, modelTensor.getDataType()); + } + + @Test + public void parse_EmptyObject() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + + ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser); + + assertEquals(0, parsedOutput.getMlModelOutputs().size()); + } + + @Test + public void parse_SkipIrrelevantFields() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + + builder.field("irrelevant_field", "irrelevant_value"); + + builder.startArray(ModelTensorOutput.INFERENCE_RESULT_FIELD); + builder.startObject(); + builder.startArray("output"); + builder.startObject(); + builder.field("name", "test"); + builder.field("data_type", "FLOAT32"); + builder.field("shape", new long[] { 1, 3 }); + builder.field("data", value); + builder.endObject(); + builder.endArray(); + builder.endObject(); + builder.endArray(); + + builder.field("another_irrelevant_field", "another_value"); + + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + + ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser); + + assertEquals(1, parsedOutput.getMlModelOutputs().size()); + ModelTensors modelTensors = parsedOutput.getMlModelOutputs().get(0); + assertEquals(1, modelTensors.getMlModelTensors().size()); + ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); + assertEquals("test", modelTensor.getName()); + assertEquals(value.length, modelTensor.getData().length); + assertEquals(value[0].doubleValue(), modelTensor.getData()[0].doubleValue(), 0.0001); + assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape()); + } + private void readInputStream(ModelTensorOutput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java index a4f7dc51b1..f3f7f98b6c 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.output.model; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.IOException; @@ -17,9 +19,12 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; public class ModelTensorsTest { @@ -28,6 +33,7 @@ public class ModelTensorsTest { public ExpectedException exceptionRule = ExpectedException.none(); private ModelTensors modelTensors; private ModelResultFilter modelResultFilter; + private Number[] testData; @Before public void setUp() { @@ -40,10 +46,11 @@ public void setUp() { .targetResponsePositions(Arrays.asList(position)) .build(); + testData = new Number[] { 1, 2, 3 }; ModelTensor modelTensor = ModelTensor .builder() .name("model_tensor") - .data(new Number[] { 1, 2, 3 }) + .data(testData) .shape(new long[] { 1, 2, 3, }) .dataType(MLResultDataType.INT32) .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) @@ -120,4 +127,151 @@ public void test_ToAndFromBytes() throws IOException { ModelTensors tensors = ModelTensors.fromBytes(bytes); // assertEquals(modelTensors.getMlModelTensors(), tensors.getMlModelTensors()); } + + @Test + public void parse_Success_WithOutput() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + builder.startArray(ModelTensors.OUTPUT_FIELD); + + builder.startObject(); + builder.field("name", "test_tensor"); + builder.field("data_type", "FLOAT32"); + builder.field("shape", new long[] { 1, 3 }); + builder.field("data", new Float[] { 1.0f, 2.0f, 3.0f }); + builder.endObject(); + + builder.endArray(); + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + + ModelTensors parsedTensors = ModelTensors.parse(parser); + + assertNotNull(parsedTensors.getMlModelTensors()); + assertEquals(1, parsedTensors.getMlModelTensors().size()); + ModelTensor modelTensor = parsedTensors.getMlModelTensors().get(0); + assertEquals("test_tensor", modelTensor.getName()); + assertEquals(3, modelTensor.getData().length); + // Compare the first value using double conversion to handle type differences + assertEquals(1.0, modelTensor.getData()[0].doubleValue(), 0.0001); + assertNull(parsedTensors.getStatusCode()); + } + + @Test + public void parse_Success_WithStatusCode() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + builder.field(ModelTensors.STATUS_CODE_FIELD, 200); + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + + ModelTensors parsedTensors = ModelTensors.parse(parser); + + assertEquals(Integer.valueOf(200), parsedTensors.getStatusCode()); + assertNotNull(parsedTensors.getMlModelTensors()); + assertEquals(0, parsedTensors.getMlModelTensors().size()); + } + + @Test + public void parse_Success_WithOutputAndStatusCode() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + + builder.field(ModelTensors.STATUS_CODE_FIELD, 200); + + builder.startArray(ModelTensors.OUTPUT_FIELD); + builder.startObject(); + builder.field("name", "test_tensor"); + builder.field("data_type", "INT32"); + builder.field("shape", new long[] { 1, 2 }); + builder.field("data", new Integer[] { 1, 2 }); + builder.endObject(); + builder.endArray(); + + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + + ModelTensors parsedTensors = ModelTensors.parse(parser); + + assertEquals(Integer.valueOf(200), parsedTensors.getStatusCode()); + assertNotNull(parsedTensors.getMlModelTensors()); + assertEquals(1, parsedTensors.getMlModelTensors().size()); + ModelTensor modelTensor = parsedTensors.getMlModelTensors().get(0); + assertEquals("test_tensor", modelTensor.getName()); + } + + @Test + public void parse_EmptyObject() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + + ModelTensors parsedTensors = ModelTensors.parse(parser); + + assertNotNull(parsedTensors.getMlModelTensors()); + assertEquals(0, parsedTensors.getMlModelTensors().size()); + assertNull(parsedTensors.getStatusCode()); + } + + @Test + public void parse_SkipIrrelevantFields() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + + builder.field("irrelevant_field", "irrelevant_value"); + + builder.startArray(ModelTensors.OUTPUT_FIELD); + builder.startObject(); + builder.field("name", "test_tensor"); + builder.field("data_type", "INT32"); + builder.field("shape", new long[] { 1, 2 }); + builder.field("data", new Integer[] { 1, 2 }); + builder.endObject(); + builder.endArray(); + + builder.field(ModelTensors.STATUS_CODE_FIELD, 404); + + builder.field("another_irrelevant_field", "another_value"); + + builder.endObject(); + + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + ModelTensors parsedTensors = ModelTensors.parse(parser); + + assertEquals(Integer.valueOf(404), parsedTensors.getStatusCode()); + assertNotNull(parsedTensors.getMlModelTensors()); + assertEquals(1, parsedTensors.getMlModelTensors().size()); + ModelTensor modelTensor = parsedTensors.getMlModelTensors().get(0); + assertEquals("test_tensor", modelTensor.getName()); + } }