Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add parser for ModelTensorOutput and ModelTensors #3662

Open
wants to merge 1 commit into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@

package org.opensearch.ml.common.output.model;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.annotation.MLAlgoOutput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
Expand Down Expand Up @@ -79,4 +82,24 @@ protected MLOutputType getType() {
return OUTPUT_TYPE;
}

public static ModelTensorOutput parse(XContentParser parser) throws IOException {
List<ModelTensors> mlModelOutputs = new ArrayList<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

if (fieldName.equals(INFERENCE_RESULT_FIELD)) {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
mlModelOutputs.add(ModelTensors.parse(parser));
}
} else {
parser.skipChildren();
}
}

return ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.common.output.model;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand All @@ -17,6 +19,7 @@
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.exception.MLException;

import lombok.Builder;
Expand Down Expand Up @@ -139,4 +142,33 @@ public static ModelTensors fromBytes(byte[] bytes) {
throw new MLException(errorMsg, e);
}
}

public static ModelTensors parse(XContentParser parser) throws IOException {
Integer statusCode = null;
List<ModelTensor> mlModelTensors = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);

while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case STATUS_CODE_FIELD:
statusCode = parser.intValue(false);
break;
case OUTPUT_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
mlModelTensors.add(ModelTensor.parser(parser));
}
break;
default:
parser.skipChildren();
break;
}
}
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build();
modelTensors.setStatusCode(statusCode);
return modelTensors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;
import org.opensearch.ml.common.output.MLOutputType;

public class ModelTensorOutputTest {
Expand Down Expand Up @@ -61,6 +67,109 @@ public void readInputStream_NullField() throws IOException {
});
}

@Test
public void parse_Success() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
builder.startObject();
builder.startArray(ModelTensorOutput.INFERENCE_RESULT_FIELD);

builder.startObject();
builder.startArray("output");

builder.startObject();
builder.field("name", "test");
builder.field("data_type", "FLOAT32");
builder.field("shape", new long[] { 1, 3 });
builder.field("data", value);
builder.endObject();

builder.endArray();
builder.endObject();

builder.endArray();
builder.endObject();

String jsonStr = TestHelper.xContentBuilderToString(builder);

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
parser.nextToken();

ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser);

assertEquals(1, parsedOutput.getMlModelOutputs().size());
ModelTensors modelTensors = parsedOutput.getMlModelOutputs().get(0);
assertEquals(1, modelTensors.getMlModelTensors().size());
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
assertEquals("test", modelTensor.getName());
assertEquals(value.length, modelTensor.getData().length);
assertEquals(value[0].doubleValue(), modelTensor.getData()[0].doubleValue(), 0.0001);
assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape());
assertEquals(MLResultDataType.FLOAT32, modelTensor.getDataType());
}

@Test
public void parse_EmptyObject() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
builder.startObject();
builder.endObject();

String jsonStr = TestHelper.xContentBuilderToString(builder);

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
parser.nextToken();

ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser);

assertEquals(0, parsedOutput.getMlModelOutputs().size());
}

@Test
public void parse_SkipIrrelevantFields() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
builder.startObject();

builder.field("irrelevant_field", "irrelevant_value");

builder.startArray(ModelTensorOutput.INFERENCE_RESULT_FIELD);
builder.startObject();
builder.startArray("output");
builder.startObject();
builder.field("name", "test");
builder.field("data_type", "FLOAT32");
builder.field("shape", new long[] { 1, 3 });
builder.field("data", value);
builder.endObject();
builder.endArray();
builder.endObject();
builder.endArray();

builder.field("another_irrelevant_field", "another_value");

builder.endObject();

String jsonStr = TestHelper.xContentBuilderToString(builder);

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
parser.nextToken();

ModelTensorOutput parsedOutput = ModelTensorOutput.parse(parser);

assertEquals(1, parsedOutput.getMlModelOutputs().size());
ModelTensors modelTensors = parsedOutput.getMlModelOutputs().get(0);
assertEquals(1, modelTensors.getMlModelTensors().size());
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
assertEquals("test", modelTensor.getName());
assertEquals(value.length, modelTensor.getData().length);
assertEquals(value[0].doubleValue(), modelTensor.getData()[0].doubleValue(), 0.0001);
assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape());
}

private void readInputStream(ModelTensorOutput input, Consumer<ModelTensorOutput> verify) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
input.writeTo(bytesStreamOutput);
Expand Down
Loading
Loading