Skip to content

Commit 31f0422

Browse files
authored
Add parser for ModelTensorOutput and ModelTensors (#3658)
Signed-off-by: Sicheng Song <[email protected]>
1 parent c8d1988 commit 31f0422

File tree

4 files changed

+319
-1
lines changed

4 files changed

+319
-1
lines changed

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java

+23
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55

66
package org.opensearch.ml.common.output.model;
77

8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
810
import java.io.IOException;
911
import java.util.ArrayList;
1012
import java.util.List;
1113

1214
import org.opensearch.core.common.io.stream.StreamInput;
1315
import org.opensearch.core.common.io.stream.StreamOutput;
1416
import org.opensearch.core.xcontent.XContentBuilder;
17+
import org.opensearch.core.xcontent.XContentParser;
1518
import org.opensearch.ml.common.annotation.MLAlgoOutput;
1619
import org.opensearch.ml.common.output.MLOutput;
1720
import org.opensearch.ml.common.output.MLOutputType;
@@ -79,4 +82,24 @@ protected MLOutputType getType() {
7982
return OUTPUT_TYPE;
8083
}
8184

85+
public static ModelTensorOutput parse(XContentParser parser) throws IOException {
86+
List<ModelTensors> mlModelOutputs = new ArrayList<>();
87+
88+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
89+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
90+
String fieldName = parser.currentName();
91+
parser.nextToken();
92+
93+
if (fieldName.equals(INFERENCE_RESULT_FIELD)) {
94+
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
95+
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
96+
mlModelOutputs.add(ModelTensors.parse(parser));
97+
}
98+
} else {
99+
parser.skipChildren();
100+
}
101+
}
102+
103+
return ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build();
104+
}
82105
}

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

+32
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.ml.common.output.model;
77

8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
810
import java.io.IOException;
911
import java.nio.ByteBuffer;
1012
import java.util.ArrayList;
@@ -17,6 +19,7 @@
1719
import org.opensearch.core.common.io.stream.Writeable;
1820
import org.opensearch.core.xcontent.ToXContentObject;
1921
import org.opensearch.core.xcontent.XContentBuilder;
22+
import org.opensearch.core.xcontent.XContentParser;
2023
import org.opensearch.ml.common.exception.MLException;
2124

2225
import lombok.Builder;
@@ -139,4 +142,33 @@ public static ModelTensors fromBytes(byte[] bytes) {
139142
throw new MLException(errorMsg, e);
140143
}
141144
}
145+
146+
public static ModelTensors parse(XContentParser parser) throws IOException {
147+
Integer statusCode = null;
148+
List<ModelTensor> mlModelTensors = new ArrayList<>();
149+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
150+
151+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
152+
String fieldName = parser.currentName();
153+
parser.nextToken();
154+
155+
switch (fieldName) {
156+
case STATUS_CODE_FIELD:
157+
statusCode = parser.intValue(false);
158+
break;
159+
case OUTPUT_FIELD:
160+
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
161+
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
162+
mlModelTensors.add(ModelTensor.parser(parser));
163+
}
164+
break;
165+
default:
166+
parser.skipChildren();
167+
break;
168+
}
169+
}
170+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build();
171+
modelTensors.setStatusCode(statusCode);
172+
return modelTensors;
173+
}
142174
}

common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java

+109
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
import org.junit.Before;
1515
import org.junit.Test;
1616
import org.opensearch.common.io.stream.BytesStreamOutput;
17+
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
18+
import org.opensearch.common.xcontent.XContentType;
1719
import org.opensearch.core.common.io.stream.StreamInput;
20+
import org.opensearch.core.xcontent.NamedXContentRegistry;
21+
import org.opensearch.core.xcontent.XContentBuilder;
22+
import org.opensearch.core.xcontent.XContentParser;
23+
import org.opensearch.ml.common.TestHelper;
1824
import org.opensearch.ml.common.output.MLOutputType;
1925

2026
public class ModelTensorOutputTest {
@@ -61,6 +67,109 @@ public void readInputStream_NullField() throws IOException {
6167
});
6268
}
6369

70+
@Test
71+
public void parse_Success() throws IOException {
72+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
73+
builder.startObject();
74+
builder.startArray(ModelTensorOutput.INFERENCE_RESULT_FIELD);
75+
76+
builder.startObject();
77+
builder.startArray("output");
78+
79+
builder.startObject();
80+
builder.field("name", "test");
81+
builder.field("data_type", "FLOAT32");
82+
builder.field("shape", new long[] { 1, 3 });
83+
builder.field("data", value);
84+
builder.endObject();
85+
86+
builder.endArray();
87+
builder.endObject();
88+
89+
builder.endArray();
90+
builder.endObject();
91+
92+
String jsonStr = TestHelper.xContentBuilderToString(builder);
93+
94+
XContentParser parser = XContentType.JSON
95+
.xContent()
96+
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
97+
parser.nextToken();
98+
99+
ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser);
100+
101+
assertEquals(1, parsedOutput.getMlModelOutputs().size());
102+
ModelTensors modelTensors = parsedOutput.getMlModelOutputs().get(0);
103+
assertEquals(1, modelTensors.getMlModelTensors().size());
104+
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
105+
assertEquals("test", modelTensor.getName());
106+
assertEquals(value.length, modelTensor.getData().length);
107+
assertEquals(value[0].doubleValue(), modelTensor.getData()[0].doubleValue(), 0.0001);
108+
assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape());
109+
assertEquals(MLResultDataType.FLOAT32, modelTensor.getDataType());
110+
}
111+
112+
@Test
113+
public void parse_EmptyObject() throws IOException {
114+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
115+
builder.startObject();
116+
builder.endObject();
117+
118+
String jsonStr = TestHelper.xContentBuilderToString(builder);
119+
120+
XContentParser parser = XContentType.JSON
121+
.xContent()
122+
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
123+
parser.nextToken();
124+
125+
ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser);
126+
127+
assertEquals(0, parsedOutput.getMlModelOutputs().size());
128+
}
129+
130+
@Test
131+
public void parse_SkipIrrelevantFields() throws IOException {
132+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
133+
builder.startObject();
134+
135+
builder.field("irrelevant_field", "irrelevant_value");
136+
137+
builder.startArray(ModelTensorOutput.INFERENCE_RESULT_FIELD);
138+
builder.startObject();
139+
builder.startArray("output");
140+
builder.startObject();
141+
builder.field("name", "test");
142+
builder.field("data_type", "FLOAT32");
143+
builder.field("shape", new long[] { 1, 3 });
144+
builder.field("data", value);
145+
builder.endObject();
146+
builder.endArray();
147+
builder.endObject();
148+
builder.endArray();
149+
150+
builder.field("another_irrelevant_field", "another_value");
151+
152+
builder.endObject();
153+
154+
String jsonStr = TestHelper.xContentBuilderToString(builder);
155+
156+
XContentParser parser = XContentType.JSON
157+
.xContent()
158+
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
159+
parser.nextToken();
160+
161+
ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser);
162+
163+
assertEquals(1, parsedOutput.getMlModelOutputs().size());
164+
ModelTensors modelTensors = parsedOutput.getMlModelOutputs().get(0);
165+
assertEquals(1, modelTensors.getMlModelTensors().size());
166+
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
167+
assertEquals("test", modelTensor.getName());
168+
assertEquals(value.length, modelTensor.getData().length);
169+
assertEquals(value[0].doubleValue(), modelTensor.getData()[0].doubleValue(), 0.0001);
170+
assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape());
171+
}
172+
64173
private void readInputStream(ModelTensorOutput input, Consumer<ModelTensorOutput> verify) throws IOException {
65174
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
66175
input.writeTo(bytesStreamOutput);

0 commit comments

Comments
 (0)