From 81d3b6243d3991b997880456208954a0a552e560 Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Wed, 14 Oct 2020 11:42:09 -0700
Subject: [PATCH 1/8] add inference proto

---
 frontend/build.gradle                      |  8 +++
 frontend/server/build.gradle               | 14 ++++++
 .../server/src/main/proto/inference.proto  | 50 +++++++++++++++++++
 frontend/tools/conf/findbugs-exclude.xml   |  4 +-
 4 files changed, 75 insertions(+), 1 deletion(-)
 create mode 100644 frontend/server/src/main/proto/inference.proto

diff --git a/frontend/build.gradle b/frontend/build.gradle
index cc06f55fd..e6274389d 100644
--- a/frontend/build.gradle
+++ b/frontend/build.gradle
@@ -45,6 +45,14 @@ configure(javaProjects()) {
     }

     jacocoTestCoverageVerification {
+        afterEvaluate {
+            classDirectories = files(classDirectories.files.collect {
+                fileTree(dir: it, exclude: [
+                        'com/amazonaws/ml/mms/protobuf/codegen/*',
+                ])
+            })
+        }
+
         violationRules {
             rule {
                 limit {
diff --git a/frontend/server/build.gradle b/frontend/server/build.gradle
index 3ad36d395..65333d532 100644
--- a/frontend/server/build.gradle
+++ b/frontend/server/build.gradle
@@ -1,8 +1,14 @@
+plugins {
+    id "com.google.protobuf" version "0.8.10"
+    id "java"
+}
+
 dependencies {
     compile "io.netty:netty-all:${netty_version}"
     compile project(":modelarchive")
     compile "commons-cli:commons-cli:${commons_cli_version}"
     compile "software.amazon.ai:mms-plugins-sdk:${mms_server_sdk_version}"
+    compile "com.google.protobuf:protobuf-java:3.13.0"

     testCompile "org.testng:testng:${testng_version}"
 }
@@ -22,6 +28,14 @@ jar {
     exclude "META-INF//NOTICE*"
 }

+protobuf {
+    // Configure the protoc executable
+    protoc {
+        // Download from repositories
+        artifact = 'com.google.protobuf:protoc:3.13.0'
+    }
+}
+
 test.doFirst {
     systemProperty "mmsConfigFile", 'src/test/resources/config.properties'
     systemProperty "METRICS_LOCATION","build/logs"
diff --git a/frontend/server/src/main/proto/inference.proto b/frontend/server/src/main/proto/inference.proto
new file mode 100644
index 000000000..4b76b17f2
--- /dev/null
+++ b/frontend/server/src/main/proto/inference.proto
@@ -0,0 +1,50 @@
+syntax = "proto3";
+
+package com.amazonaws.ml.mms.protobuf.codegen;
+
+option optimize_for = SPEED;
+option java_package = "com.amazonaws.ml.mms.protobuf.codegen";
+option java_multiple_files = true;
+
+message InferenceRequest {
+    string modelName = 1;
+    WorkerCommands command = 2;
+    repeated RequestInput batch = 3;
+}
+
+message InferenceResponse {
+    // http code
+    int32 code = 1;
+    string message = 2;
+    repeated Prediction predictions = 3;
+}
+
+enum WorkerCommands {
+    ping = 0;
+    models = 1;
+    predictions = 2;
+    apiDescription = 3;
+    invocations = 4;
+    predict = 5;
+}
+
+message RequestInput {
+    string requestId = 1;
+    map<string, string> headers = 2;
+    repeated InputParameter parameters = 3;
+}
+
+message InputParameter {
+    string name = 1;
+    bytes value = 2;
+    string contentType = 3;
+}
+
+message Prediction {
+    string requestId = 1;
+    int32 statusCode = 2;
+    string reasonPhrase = 3;
+    string contentType = 4;
+    map<string, string> headers = 5;
+    bytes resp = 6;
+}
\ No newline at end of file
diff --git a/frontend/tools/conf/findbugs-exclude.xml b/frontend/tools/conf/findbugs-exclude.xml
index d2c4821a7..f8b8f8dc7 100644
--- a/frontend/tools/conf/findbugs-exclude.xml
+++ b/frontend/tools/conf/findbugs-exclude.xml
@@ -10,5 +10,7 @@
-
+
+
+

From f667fd2451f09f4776f1d2e1de763c5b543b74aa Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Thu, 15 Oct 2020 18:34:19 -0700
Subject: [PATCH 2/8] update gradle

---
 frontend/server/build.gradle                   | 7 +++++++
 
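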
frontend/server/src/main/proto/inference.proto | 2 +-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/frontend/server/build.gradle b/frontend/server/build.gradle
index 65333d532..e96cc42f4 100644
--- a/frontend/server/build.gradle
+++ b/frontend/server/build.gradle
@@ -1,6 +1,7 @@
 plugins {
     id "com.google.protobuf" version "0.8.10"
     id "java"
+    id "idea"
 }

 dependencies {
@@ -36,6 +37,12 @@ protobuf {
     }
 }

+idea {
+    module {
+        generatedSourceDirs += file('build/generated/source/proto')
+    }
+}
+
 test.doFirst {
     systemProperty "mmsConfigFile", 'src/test/resources/config.properties'
     systemProperty "METRICS_LOCATION","build/logs"
diff --git a/frontend/server/src/main/proto/inference.proto b/frontend/server/src/main/proto/inference.proto
index 4b76b17f2..ed5d05475 100644
--- a/frontend/server/src/main/proto/inference.proto
+++ b/frontend/server/src/main/proto/inference.proto
@@ -9,7 +9,7 @@ option java_multiple_files = true;

 message InferenceRequest {
     string modelName = 1;
     WorkerCommands command = 2;
-    repeated RequestInput batch = 3;
+    RequestInput request = 3;
 }

 message InferenceResponse {

From fec16e4aceaca9a0504717b24221601fb53abc30 Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Fri, 16 Oct 2020 11:11:58 -0700
Subject: [PATCH 3/8] add protobuf encode and decode

---
 .../ml/mms/http/HttpRequestHandler.java       |   8 +-
 .../ml/mms/http/InferenceRequestHandler.java  | 120 +++++++++++-----
 .../com/amazonaws/ml/mms/util/NettyUtils.java |  14 ++
 .../ml/mms/util/messages/RequestInput.java    |  10 ++
 .../java/com/amazonaws/ml/mms/wlm/Job.java    |  48 +++++--
 .../server/src/main/proto/inference.proto     |  57 ++++-----
 6 files changed, 186 insertions(+), 71 deletions(-)

diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java
index ce3d28c1b..d71005e8c 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java
@@ -32,6 +32,7 @@ public class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> {

diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java
@@ ... @@ public InferenceRequestHandler(Map<String, ModelServerEndpoint> ep) {
         endpointMap = ep;
     }

-    @Override
-    protected void handleRequest(
-            ChannelHandlerContext ctx,
-            FullHttpRequest req,
-            QueryStringDecoder decoder,
-            String[] segments)
-            throws ModelException {
-        if (isInferenceReq(segments)) {
-            if (endpointMap.getOrDefault(segments[1], null) != null) {
-                handleCustomEndpoint(ctx, req, segments, decoder);
-            } else {
-                switch (segments[1]) {
-                    case "ping":
-                        ModelManager.getInstance().workerStatus(ctx);
-                        break;
-                    case "models":
-                    case "invocations":
-                        validatePredictionsEndpoint(segments);
-                        handleInvocations(ctx, req, decoder, segments);
-                        break;
-                    case "predictions":
-                        handlePredictions(ctx, req, segments);
-                        break;
-                    default:
-                        handleLegacyPredict(ctx, req, decoder, segments);
-                        break;
-                }
-            }
-        } else {
-            chain.handleRequest(ctx, req, decoder, segments);
-        }
-    }
-
     private boolean isInferenceReq(String[] segments) {
         return segments.length == 0
                 || segments[1].equals("ping")
@@ -116,6 +85,61 @@ private void handlePredictions(
         predict(ctx, req, null, segments[2]);
     }

+    @Override
+    protected void handleRequest(
+            ChannelHandlerContext ctx,
+            FullHttpRequest req,
+            QueryStringDecoder decoder,
+            String[] segments)
+            throws ModelException {
+        if (decoder == null) {
+            try {
+                InferenceRequest inferenceRequest =
+                        InferenceRequest.parseFrom(req.content().nioBuffer());
+
+                switch (inferenceRequest.getCommandValue()) {
+                    case 
com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.ping_VALUE: + ModelManager.getInstance().workerStatus(ctx); + break; + case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.models_VALUE: + case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.invocations_VALUE: + case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.predictions_VALUE: + case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.predict_VALUE: + handlePredictions(ctx, inferenceRequest); + break; + default: + chain.handleRequest(ctx, req, null, segments); + break; + } + } catch (InvalidProtocolBufferException e) { + chain.handleRequest(ctx, req, null, segments); + } + } else if (isInferenceReq(segments)) { + if (endpointMap.getOrDefault(segments[1], null) != null) { + handleCustomEndpoint(ctx, req, segments, decoder); + } else { + switch (segments[1]) { + case "ping": + ModelManager.getInstance().workerStatus(ctx); + break; + case "models": + case "invocations": + validatePredictionsEndpoint(segments); + handleInvocations(ctx, req, decoder, segments); + break; + case "predictions": + handlePredictions(ctx, req, segments); + break; + default: + handleLegacyPredict(ctx, req, decoder, segments); + break; + } + } + } else { + chain.handleRequest(ctx, req, decoder, segments); + } + } + private void handleInvocations( ChannelHandlerContext ctx, FullHttpRequest req, @@ -147,6 +171,27 @@ private void handleLegacyPredict( predict(ctx, req, decoder, segments[1]); } + private void handlePredictions(ChannelHandlerContext ctx, InferenceRequest inferenceRequest) + throws ModelNotFoundException { + String modelName = inferenceRequest.getModelName(); + if (modelName.isEmpty()) { + if (ModelManager.getInstance().getStartupModels().size() == 1) { + modelName = ModelManager.getInstance().getStartupModels().iterator().next(); + } + } + RequestInput input = new RequestInput(NettyUtils.getRequestId(ctx.channel())); + input.setProto(true); + com.amazonaws.ml.mms.protobuf.codegen.RequestInput protoInput = + inferenceRequest.getRequest(); + input.setHeaders(protoInput.getHeadersMap()); + for (com.amazonaws.ml.mms.protobuf.codegen.InputParameter parameter : + protoInput.getParametersList()) { + input.addParameter( + new InputParameter(parameter.getName(), parameter.getValue().toByteArray())); + } + predict(ctx, modelName, input); + } + private void predict( ChannelHandlerContext ctx, FullHttpRequest req, @@ -177,6 +222,15 @@ private void predict( } } + private void predict(ChannelHandlerContext ctx, String modelName, RequestInput input) + throws ModelNotFoundException, BadRequestException { + Job job = new Job(ctx, modelName, WorkerCommands.PREDICT, input); + if (!ModelManager.getInstance().addJob(job)) { + throw new ServiceUnavailableException( + "No worker is available to serve request: " + modelName); + } + } + private static RequestInput parseRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) { String requestId = NettyUtils.getRequestId(ctx.channel()); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java index c2caf9a19..30cfc79c8 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java @@ -132,6 +132,20 @@ public static void sendError( sendJsonResponse(ctx, error, status); } + public static void sendErrorProto( + ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t) { + 
com.amazonaws.ml.mms.protobuf.codegen.ErrorResponse errorResponse =
+                com.amazonaws.ml.mms.protobuf.codegen.ErrorResponse.newBuilder()
+                        .setCode(status.code())
+                        .setType(t.getClass().getSimpleName())
+                        .setMessage(t.getMessage())
+                        .build();
+        FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false);
+        resp.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
+        resp.content().writeBytes(errorResponse.toByteArray());
+        sendHttpResponse(ctx, resp, true);
+    }
+
     public static void sendError(
             ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t, String msg) {
         ErrorResponse error = new ErrorResponse(status.code(), t.getClass().getSimpleName(), msg);
diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java
index f3a889914..8563ebc10 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java
@@ -23,11 +23,13 @@ public class RequestInput {
     private String requestId;
     private Map<String, String> headers;
     private List<InputParameter> parameters;
+    private boolean proto;

     public RequestInput(String requestId) {
         this.requestId = requestId;
         headers = new HashMap<>();
         parameters = new ArrayList<>();
+        proto = false;
     }

     public String getRequestId() {
@@ -62,6 +64,14 @@ public void addParameter(InputParameter modelInput) {
         parameters.add(modelInput);
     }

+    public void setProto(boolean isProto) {
+        this.proto = isProto;
+    }
+
+    public boolean isProto() {
+        return this.proto;
+    }
+
     public String getStringParameter(String key) {
         for (InputParameter param : parameters) {
             if (key.equals(param.getName())) {
diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java
index b669106c6..501a8cb68 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java
@@ -13,13 +13,16 @@
 package com.amazonaws.ml.mms.wlm;

 import com.amazonaws.ml.mms.http.InternalServerException;
+import com.amazonaws.ml.mms.protobuf.codegen.Predictions;
 import com.amazonaws.ml.mms.util.NettyUtils;
 import com.amazonaws.ml.mms.util.messages.RequestInput;
 import com.amazonaws.ml.mms.util.messages.WorkerCommands;
+import com.google.protobuf.ByteString;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.codec.http.DefaultFullHttpResponse;
 import io.netty.handler.codec.http.FullHttpResponse;
 import io.netty.handler.codec.http.HttpHeaderNames;
+import io.netty.handler.codec.http.HttpHeaderValues;
 import io.netty.handler.codec.http.HttpResponseStatus;
 import io.netty.handler.codec.http.HttpVersion;
 import java.util.Map;
@@ -37,6 +40,7 @@ public class Job {
     private RequestInput input;
     private long begin;
     private long scheduled;
+    private boolean proto;

     public Job(
             ChannelHandlerContext ctx, String modelName, WorkerCommands cmd, RequestInput input) {
@@ -44,6 +48,7 @@ public Job(
         this.modelName = modelName;
         this.cmd = cmd;
         this.input = input;
+        this.proto = input.isProto();

         begin = System.currentTimeMillis();
         scheduled = begin;
@@ -69,6 +74,14 @@ public RequestInput getPayload() {
         return input;
     }

+    public void setProto(boolean isProto) {
+        this.proto = isProto;
+    }
+
+    public boolean isProto() {
+        return this.proto;
+    }
+
     public void setScheduled() {
         scheduled = System.currentTimeMillis();
     }
@@ -85,15 +98,30 @@ public void response(
                         : HttpResponseStatus.valueOf(statusCode, statusPhrase);

         FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false);
-        if (contentType != null && contentType.length() > 0) {
-            resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType);
-        }
-        if (responseHeaders != null) {
-            for (Map.Entry<String, String> e : responseHeaders.entrySet()) {
-                resp.headers().set(e.getKey(), e.getValue());
+        if (proto) {
+            Predictions predictions =
+                    Predictions.newBuilder()
+                            .setRequestId(getJobId())
+                            .setStatusCode(statusCode)
+                            .setReasonPhrase(statusPhrase)
+                            .setContentType(contentType.toString())
+                            .putAllHeaders(responseHeaders)
+                            .setResp(ByteString.copyFrom(body))
+                            .build();
+            resp.headers()
+                    .set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
+            resp.content().writeBytes(predictions.toByteArray());
+        } else {
+            if (contentType != null && contentType.length() > 0) {
+                resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType);
+            }
+            if (responseHeaders != null) {
+                for (Map.Entry<String, String> e : responseHeaders.entrySet()) {
+                    resp.headers().set(e.getKey(), e.getValue());
+                }
             }
+            resp.content().writeBytes(body);
         }
-        resp.content().writeBytes(body);

         /*
          * We can load the models based on the configuration file.Since this Job is
@@ -119,7 +147,11 @@ public void sendError(HttpResponseStatus status, String error) {
          * by external clients.
          */
         if (ctx != null) {
-            NettyUtils.sendError(ctx, status, new InternalServerException(error));
+            if (proto) {
+                NettyUtils.sendErrorProto(ctx, status, new InternalServerException(error));
+            } else {
+                NettyUtils.sendError(ctx, status, new InternalServerException(error));
+            }
         }

         logger.debug(
diff --git a/frontend/server/src/main/proto/inference.proto b/frontend/server/src/main/proto/inference.proto
index ed5d05475..b0d50a1a4 100644
--- a/frontend/server/src/main/proto/inference.proto
+++ b/frontend/server/src/main/proto/inference.proto
@@ -7,44 +7,43 @@
 option java_package = "com.amazonaws.ml.mms.protobuf.codegen";
 option java_multiple_files = true;

 message InferenceRequest {
-    string modelName = 1;
-    WorkerCommands command = 2;
-    RequestInput request = 3;
+  string modelName = 1;
+  WorkerCommands command = 2;
+  RequestInput request = 3;
 }

-message InferenceResponse {
-    // http code
-    int32 code = 1;
-    string message = 2;
-    repeated Prediction predictions = 3;
+message Predictions {
+  string requestId = 1;
+  int32 statusCode = 2;
+  string reasonPhrase = 3;
+  string contentType = 4;
+  map<string, string> headers = 5;
+  bytes resp = 6;
+}
+
+message ErrorResponse {
+  int32 code = 1;
+  string type = 2;
+  string message = 3;
 }

 enum WorkerCommands {
-    ping = 0;
-    models = 1;
-    predictions = 2;
-    apiDescription = 3;
-    invocations = 4;
-    predict = 5;
+  ping = 0;
+  models = 1;
+  predictions = 2;
+  apiDescription = 3;
+  invocations = 4;
+  predict = 5;
 }

 message RequestInput {
-    string requestId = 1;
-    map<string, string> headers = 2;
-    repeated InputParameter parameters = 3;
+  string requestId = 1;
+  map<string, string> headers = 2;
+  repeated InputParameter parameters = 3;
 }

 message InputParameter {
-    string name = 1;
-    bytes value = 2;
-    string contentType = 3;
-}
-
-message Prediction {
-    string requestId = 1;
-    int32 statusCode = 2;
-    string reasonPhrase = 3;
-    string contentType = 4;
-    map<string, string> headers = 5;
-    bytes resp = 6;
+  string name = 1;
+  bytes value = 2;
+  string contentType = 3;
 }
\ No newline at end of file

From da7b3b5272596f54adfe42508d769efe193543f8 Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Fri, 16 Oct 2020 
18:28:16 -0700 Subject: [PATCH 4/8] add custom endpoint support --- .../ml/mms/http/HttpRequestHandlerChain.java | 91 ++++++++++++------- .../ml/mms/http/InferenceRequestHandler.java | 34 +++---- .../ml/mms/http/ManagementRequestHandler.java | 2 +- .../servingsdk/impl/ModelServerRequest.java | 39 +++++++- .../server/src/main/proto/inference.proto | 4 +- 5 files changed, 118 insertions(+), 52 deletions(-) diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java index cd942ca92..4416d20d5 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java @@ -2,6 +2,8 @@ import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; +import com.amazonaws.ml.mms.protobuf.codegen.InferenceRequest; +import com.amazonaws.ml.mms.protobuf.codegen.RequestInput; import com.amazonaws.ml.mms.servingsdk.impl.ModelServerContext; import com.amazonaws.ml.mms.servingsdk.impl.ModelServerRequest; import com.amazonaws.ml.mms.servingsdk.impl.ModelServerResponse; @@ -49,32 +51,30 @@ private void run( FullHttpRequest req, FullHttpResponse rsp, QueryStringDecoder decoder, - String method) + RequestInput input) throws IOException { - switch (method) { + ModelServerRequest modelServerRequest; + if (decoder == null) { + modelServerRequest = new ModelServerRequest(req, input); + } else { + modelServerRequest = new ModelServerRequest(req, decoder); + } + switch (req.method().toString()) { case "GET": endpoint.doGet( - new ModelServerRequest(req, decoder), - new ModelServerResponse(rsp), - new ModelServerContext()); + modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext()); break; case "PUT": endpoint.doPut( - new ModelServerRequest(req, decoder), - new ModelServerResponse(rsp), - new ModelServerContext()); + modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext()); break; case "DELETE": endpoint.doDelete( - new ModelServerRequest(req, decoder), - new ModelServerResponse(rsp), - new ModelServerContext()); + modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext()); break; case "POST": endpoint.doPost( - new ModelServerRequest(req, decoder), - new ModelServerResponse(rsp), - new ModelServerContext()); + modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext()); break; default: throw new ServiceUnavailableException("Invalid HTTP method received"); @@ -85,7 +85,8 @@ protected void handleCustomEndpoint( ChannelHandlerContext ctx, FullHttpRequest req, String[] segments, - QueryStringDecoder decoder) { + QueryStringDecoder decoder, + InferenceRequest inferenceRequest) { ModelServerEndpoint endpoint = endpointMap.get(segments[1]); Runnable r = () -> { @@ -94,32 +95,60 @@ protected void handleCustomEndpoint( new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); try { - run(endpoint, req, rsp, decoder, req.method().toString()); + if (decoder == null) { + run(endpoint, req, rsp, null, inferenceRequest.getRequest()); + } else { + run(endpoint, req, rsp, decoder, null); + } NettyUtils.sendHttpResponse(ctx, rsp, true); logger.info( "Running \"{}\" endpoint took {} ms", - segments[0], + decoder == null ? 
inferenceRequest.getCustomCommand() : segments[0], System.currentTimeMillis() - start); } catch (ModelServerEndpointException me) { - NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); + if (decoder == null) { + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); + } logger.error("Error thrown by the model endpoint plugin.", me); } catch (OutOfMemoryError oom) { - NettyUtils.sendError( - ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom, "Out of memory"); + if (decoder == null) { + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom); + } else { + NettyUtils.sendError( + ctx, + HttpResponseStatus.INSUFFICIENT_STORAGE, + oom, + "Out of memory"); + } + logger.error("Out of memory while running the custom endpoint.", oom); } catch (IOException ioe) { - NettyUtils.sendError( - ctx, - HttpResponseStatus.INTERNAL_SERVER_ERROR, - ioe, - "I/O error while running the custom endpoint"); + if (decoder == null) { + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, ioe); + } else { + NettyUtils.sendError( + ctx, + HttpResponseStatus.INTERNAL_SERVER_ERROR, + ioe, + "I/O error while running the custom endpoint"); + } logger.error("I/O error while running the custom endpoint.", ioe); } catch (Throwable e) { - NettyUtils.sendError( - ctx, - HttpResponseStatus.INTERNAL_SERVER_ERROR, - e, - "Unknown exception"); - logger.error("Unknown exception", e); + if (decoder == null) { + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + } else { + NettyUtils.sendError( + ctx, + HttpResponseStatus.INTERNAL_SERVER_ERROR, + e, + "Unknown exception"); + logger.error("Unknown exception", e); + } } }; ModelManager.getInstance().submitTask(r); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java index ccd7ed388..93c43bdfb 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java @@ -104,11 +104,15 @@ protected void handleRequest( case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.models_VALUE: case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.invocations_VALUE: case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.predictions_VALUE: - case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.predict_VALUE: - handlePredictions(ctx, inferenceRequest); + handlePredictions(ctx, inferenceRequest, req.method()); break; default: - chain.handleRequest(ctx, req, null, segments); + if (endpointMap.getOrDefault(inferenceRequest.getCustomCommand(), null) + != null) { + handleCustomEndpoint(ctx, req, segments, null, inferenceRequest); + } else { + chain.handleRequest(ctx, req, null, segments); + } break; } } catch (InvalidProtocolBufferException e) { @@ -116,7 +120,7 @@ protected void handleRequest( } } else if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { - handleCustomEndpoint(ctx, req, segments, decoder); + handleCustomEndpoint(ctx, req, segments, decoder, null); } else { switch (segments[1]) { case "ping": @@ -171,7 +175,8 @@ private void handleLegacyPredict( predict(ctx, req, decoder, segments[1]); } - private void handlePredictions(ChannelHandlerContext ctx, InferenceRequest inferenceRequest) + 
private void handlePredictions( + ChannelHandlerContext ctx, InferenceRequest inferenceRequest, HttpMethod method) throws ModelNotFoundException { String modelName = inferenceRequest.getModelName(); if (modelName.isEmpty()) { @@ -189,7 +194,7 @@ private void handlePredictions(ChannelHandlerContext ctx, InferenceRequest infer input.addParameter( new InputParameter(parameter.getName(), parameter.getValue().toByteArray())); } - predict(ctx, modelName, input); + predict(ctx, modelName, input, method); } private void predict( @@ -199,11 +204,17 @@ private void predict( String modelName) throws ModelNotFoundException, BadRequestException { RequestInput input = parseRequest(ctx, req, decoder); + predict(ctx, modelName, input, req.method()); + } + + private void predict( + ChannelHandlerContext ctx, String modelName, RequestInput input, HttpMethod method) + throws ModelNotFoundException, BadRequestException { if (modelName == null) { throw new BadRequestException("Parameter model_name is required."); } - if (HttpMethod.OPTIONS.equals(req.method())) { + if (HttpMethod.OPTIONS.equals(method)) { ModelManager modelManager = ModelManager.getInstance(); Model model = modelManager.getModels().get(modelName); if (model == null) { @@ -222,15 +233,6 @@ private void predict( } } - private void predict(ChannelHandlerContext ctx, String modelName, RequestInput input) - throws ModelNotFoundException, BadRequestException { - Job job = new Job(ctx, modelName, WorkerCommands.PREDICT, input); - if (!ModelManager.getInstance().addJob(job)) { - throw new ServiceUnavailableException( - "No worker is available to serve request: " + modelName); - } - } - private static RequestInput parseRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) { String requestId = NettyUtils.getRequestId(ctx.channel()); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java index 114ff4f13..a1bc2cada 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java @@ -63,7 +63,7 @@ protected void handleRequest( throws ModelException { if (isManagementReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { - handleCustomEndpoint(ctx, req, segments, decoder); + handleCustomEndpoint(ctx, req, segments, decoder, null); } else { if (!"models".equals(segments[1])) { throw new ResourceNotFoundException(); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java index 9423a1eee..850255c57 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java @@ -13,11 +13,14 @@ package com.amazonaws.ml.mms.servingsdk.impl; +import com.amazonaws.ml.mms.protobuf.codegen.InputParameter; +import com.amazonaws.ml.mms.protobuf.codegen.RequestInput; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import java.io.ByteArrayInputStream; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import software.amazon.ai.mms.servingsdk.http.Request; 
@@ -25,14 +28,27 @@ public class ModelServerRequest implements Request {
     private FullHttpRequest req;
     private QueryStringDecoder decoder;
+    private RequestInput input;
+    private Map<String, List<String>> parameterMap;

     public ModelServerRequest(FullHttpRequest r, QueryStringDecoder d) {
         req = r;
         decoder = d;
+        parameterMap = null;
+    }
+
+    public ModelServerRequest(FullHttpRequest r, RequestInput input) {
+        req = r;
+        decoder = null;
+        this.input = input;
+        parameterMap = null;
     }

     @Override
     public List<String> getHeaderNames() {
+        if (decoder == null) {
+            return new ArrayList<>(input.getHeadersMap().keySet());
+        }
         return new ArrayList<>(req.headers().names());
     }

@@ -43,12 +59,28 @@ public String getRequestURI() {
     }

     @Override
     public Map<String, List<String>> getParameterMap() {
-        return decoder.parameters();
+        if (parameterMap == null) {
+            if (decoder == null) {
+                parameterMap = new HashMap<>();
+                for (InputParameter parameter : input.getParametersList()) {
+                    List<String> values =
+                            parameterMap.computeIfAbsent(
+                                    parameter.getName(), r -> new ArrayList<>());
+                    values.add(parameter.getValue().toString());
+                }
+            } else {
+                parameterMap = decoder.parameters();
+            }
+        }
+        return parameterMap;
     }

     @Override
     public List<String> getParameter(String k) {
-        return decoder.parameters().get(k);
+        if (parameterMap == null) {
+            getParameterMap();
+        }
+        return parameterMap.get(k);
     }

     @Override
     public String getContentType() {
@@ -58,6 +90,9 @@ public String getContentType() {
     }

     @Override
     public ByteArrayInputStream getInputStream() {
+        if (decoder == null) {
+            return new ByteArrayInputStream(input.toByteArray());
+        }
         return new ByteArrayInputStream(req.content().array());
     }
 }
diff --git a/frontend/server/src/main/proto/inference.proto b/frontend/server/src/main/proto/inference.proto
index b0d50a1a4..7fb697d1f 100644
--- a/frontend/server/src/main/proto/inference.proto
+++ b/frontend/server/src/main/proto/inference.proto
@@ -9,7 +9,8 @@ option java_multiple_files = true;

 message InferenceRequest {
   string modelName = 1;
   WorkerCommands command = 2;
-  RequestInput request = 3;
+  string customCommand = 3;
+  RequestInput request = 4;
 }

 message Predictions {

From 40d516a3c708a656a3b03db5ccb4598d1586fa80 Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Sun, 18 Oct 2020 10:06:40 -0700
Subject: [PATCH 5/8] add test for protobuf

---
 .../http/ApiDescriptionRequestHandler.java    |   2 +-
 .../ml/mms/http/HttpRequestHandler.java       |  69 ++++++++---
 .../ml/mms/http/HttpRequestHandlerChain.java  |  98 ++++++++--------
 .../ml/mms/http/InferenceRequestHandler.java  |  13 +--
 .../ml/mms/http/ManagementRequestHandler.java |   2 +-
 .../servingsdk/impl/ModelServerRequest.java   |   1 +
 .../amazonaws/ml/mms/util/ConfigManager.java  |   2 +
 .../com/amazonaws/ml/mms/util/NettyUtils.java |  26 ++++-
 .../java/com/amazonaws/ml/mms/wlm/Job.java    |   4 +-
 .../amazonaws/ml/mms/wlm/ModelManager.java    |  26 ++++-
 .../server/src/main/proto/inference.proto     |   7 +-
 .../com/amazonaws/ml/mms/ModelServerTest.java | 109 +++++++++++++++++-
 12 files changed, 272 insertions(+), 87 deletions(-)

diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java
index f0d36ad70..850f65a60 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java
@@ -25,7 +25,7 @@ protected void handleRequest(
             
String[] segments) throws ModelException { - if (isApiDescription(segments)) { + if (decoder != null && isApiDescription(segments)) { String path = decoder.path(); if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method())) || (segments.length == 2 && segments[1].equals("api-description"))) { diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java index d71005e8c..8f2969c28 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java @@ -14,11 +14,13 @@ import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; +import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.NettyUtils; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,52 +45,89 @@ public HttpRequestHandler(HttpRequestHandlerChain chain) { /** {@inheritDoc} */ @Override protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) { + boolean proto = false; try { NettyUtils.requestReceived(ctx.channel(), req); if (!req.decoderResult().isSuccess()) { throw new BadRequestException("Invalid HTTP message."); } - - QueryStringDecoder decoder = new QueryStringDecoder(req.uri()); - String path = decoder.path(); - - String[] segments = path.split("/"); - if (segments.length == 1) { + CharSequence contentType = HttpUtil.getMimeType(req); + if (contentType != null + && contentType + .toString() + .toLowerCase() + .contains(ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF)) { + proto = true; handlerChain.handleRequest(ctx, req, null, null); } else { + QueryStringDecoder decoder = new QueryStringDecoder(req.uri()); + String path = decoder.path(); + String[] segments = path.split("/"); handlerChain.handleRequest(ctx, req, decoder, segments); } } catch (ResourceNotFoundException | ModelNotFoundException e) { logger.trace("", e); - NettyUtils.sendError(ctx, HttpResponseStatus.NOT_FOUND, e); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.NOT_FOUND, e); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.NOT_FOUND, e); + } } catch (BadRequestException | ModelException e) { logger.trace("", e); - NettyUtils.sendError(ctx, HttpResponseStatus.BAD_REQUEST, e); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.BAD_REQUEST, e); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.BAD_REQUEST, e); + } } catch (ConflictStatusException e) { logger.trace("", e); - NettyUtils.sendError(ctx, HttpResponseStatus.CONFLICT, e); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.CONFLICT, e); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.CONFLICT, e); + } } catch (RequestTimeoutException e) { logger.trace("", e); - NettyUtils.sendError(ctx, HttpResponseStatus.REQUEST_TIMEOUT, e); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.REQUEST_TIMEOUT, e); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.REQUEST_TIMEOUT, e); + } } catch (MethodNotAllowedException e) { logger.trace("", e); - NettyUtils.sendError(ctx, 
HttpResponseStatus.METHOD_NOT_ALLOWED, e); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED, e); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED, e); + } } catch (ServiceUnavailableException e) { logger.trace("", e); - NettyUtils.sendError(ctx, HttpResponseStatus.SERVICE_UNAVAILABLE, e); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.SERVICE_UNAVAILABLE, e); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.SERVICE_UNAVAILABLE, e); + } } catch (OutOfMemoryError e) { logger.trace("", e); - NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, e); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, e); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, e); + } } catch (Throwable t) { logger.error("", t); - NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, t); + if (proto) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, t); + } else { + NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, t); + } } } /** {@inheritDoc} */ @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - logger.error("", cause); + logger.error("exceptionCaught:", cause); if (cause instanceof OutOfMemoryError) { NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, cause); } diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java index 4416d20d5..9b6e49745 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java @@ -85,8 +85,7 @@ protected void handleCustomEndpoint( ChannelHandlerContext ctx, FullHttpRequest req, String[] segments, - QueryStringDecoder decoder, - InferenceRequest inferenceRequest) { + QueryStringDecoder decoder) { ModelServerEndpoint endpoint = endpointMap.get(segments[1]); Runnable r = () -> { @@ -95,60 +94,69 @@ protected void handleCustomEndpoint( new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); try { - if (decoder == null) { - run(endpoint, req, rsp, null, inferenceRequest.getRequest()); - } else { - run(endpoint, req, rsp, decoder, null); - } + run(endpoint, req, rsp, decoder, null); + NettyUtils.sendHttpResponse(ctx, rsp, true); + logger.info( + "Running \"{}\" endpoint took {} ms", + segments[0], + System.currentTimeMillis() - start); + } catch (ModelServerEndpointException me) { + NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); + logger.error("Error thrown by the model endpoint plugin.", me); + } catch (OutOfMemoryError oom) { + NettyUtils.sendError( + ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom, "Out of memory"); + logger.error("Out of memory while running the custom endpoint.", oom); + } catch (IOException ioe) { + NettyUtils.sendError( + ctx, + HttpResponseStatus.INTERNAL_SERVER_ERROR, + ioe, + "I/O error while running the custom endpoint"); + logger.error("I/O error while running the custom endpoint.", ioe); + } catch (Throwable e) { + NettyUtils.sendError( + ctx, + HttpResponseStatus.INTERNAL_SERVER_ERROR, + e, + "Unknown exception"); + logger.error("Unknown exception", e); + } + }; + ModelManager.getInstance().submitTask(r); + } + + protected void handleCustomEndpoint( + ChannelHandlerContext ctx, 
FullHttpRequest req, InferenceRequest inferenceRequest) { + ModelServerEndpoint endpoint = endpointMap.get(inferenceRequest.getCustomCommand()); + Runnable r = + () -> { + Long start = System.currentTimeMillis(); + FullHttpResponse rsp = + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); + try { + run(endpoint, req, rsp, null, inferenceRequest.getRequest()); NettyUtils.sendHttpResponse(ctx, rsp, true); logger.info( "Running \"{}\" endpoint took {} ms", - decoder == null ? inferenceRequest.getCustomCommand() : segments[0], + inferenceRequest.getCustomCommand(), System.currentTimeMillis() - start); } catch (ModelServerEndpointException me) { - if (decoder == null) { - NettyUtils.sendErrorProto( - ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); - } else { - NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); - } + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); logger.error("Error thrown by the model endpoint plugin.", me); } catch (OutOfMemoryError oom) { - if (decoder == null) { - NettyUtils.sendErrorProto( - ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom); - } else { - NettyUtils.sendError( - ctx, - HttpResponseStatus.INSUFFICIENT_STORAGE, - oom, - "Out of memory"); - } + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom, "Out of memory"); logger.error("Out of memory while running the custom endpoint.", oom); } catch (IOException ioe) { - if (decoder == null) { - NettyUtils.sendErrorProto( - ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, ioe); - } else { - NettyUtils.sendError( - ctx, - HttpResponseStatus.INTERNAL_SERVER_ERROR, - ioe, - "I/O error while running the custom endpoint"); - } + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, ioe); logger.error("I/O error while running the custom endpoint.", ioe); } catch (Throwable e) { - if (decoder == null) { - NettyUtils.sendErrorProto( - ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e); - } else { - NettyUtils.sendError( - ctx, - HttpResponseStatus.INTERNAL_SERVER_ERROR, - e, - "Unknown exception"); - logger.error("Unknown exception", e); - } + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + logger.error("Unknown exception", e); } }; ModelManager.getInstance().submitTask(r); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java index 93c43bdfb..f0c0b93bb 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java @@ -91,7 +91,7 @@ protected void handleRequest( FullHttpRequest req, QueryStringDecoder decoder, String[] segments) - throws ModelException { + throws ModelNotFoundException, ModelException { if (decoder == null) { try { InferenceRequest inferenceRequest = @@ -99,17 +99,15 @@ protected void handleRequest( switch (inferenceRequest.getCommandValue()) { case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.ping_VALUE: - ModelManager.getInstance().workerStatus(ctx); + ModelManager.getInstance().workerStatus(ctx, true); break; - case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.models_VALUE: - case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.invocations_VALUE: case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.predictions_VALUE: handlePredictions(ctx, inferenceRequest, 
req.method()); break; default: if (endpointMap.getOrDefault(inferenceRequest.getCustomCommand(), null) != null) { - handleCustomEndpoint(ctx, req, segments, null, inferenceRequest); + handleCustomEndpoint(ctx, req, inferenceRequest); } else { chain.handleRequest(ctx, req, null, segments); } @@ -120,11 +118,11 @@ protected void handleRequest( } } else if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { - handleCustomEndpoint(ctx, req, segments, decoder, null); + handleCustomEndpoint(ctx, req, segments, decoder); } else { switch (segments[1]) { case "ping": - ModelManager.getInstance().workerStatus(ctx); + ModelManager.getInstance().workerStatus(ctx, false); break; case "models": case "invocations": @@ -184,6 +182,7 @@ private void handlePredictions( modelName = ModelManager.getInstance().getStartupModels().iterator().next(); } } + RequestInput input = new RequestInput(NettyUtils.getRequestId(ctx.channel())); input.setProto(true); com.amazonaws.ml.mms.protobuf.codegen.RequestInput protoInput = diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java index a1bc2cada..114ff4f13 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java @@ -63,7 +63,7 @@ protected void handleRequest( throws ModelException { if (isManagementReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { - handleCustomEndpoint(ctx, req, segments, decoder, null); + handleCustomEndpoint(ctx, req, segments, decoder); } else { if (!"models".equals(segments[1])) { throw new ResourceNotFoundException(); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java index 850255c57..d8fa571f8 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java @@ -34,6 +34,7 @@ public class ModelServerRequest implements Request { public ModelServerRequest(FullHttpRequest r, QueryStringDecoder d) { req = r; decoder = d; + this.input = null; parameterMap = null; } diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java index 3f4031c9a..4d0dd3b78 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java @@ -99,6 +99,8 @@ public final class ConfigManager { public static final String MODEL_LOGGER = "MODEL_LOG"; public static final String MODEL_SERVER_METRICS_LOGGER = "MMS_METRICS"; + public static final String HTTP_CONTENT_TYPE_PROTOBUF = "application/x-protobuf"; + private Pattern blacklistPattern; private Properties prop; private static Pattern pattern = Pattern.compile("\\$\\$([^$]+[^$])\\$\\$"); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java index 30cfc79c8..b5b560abd 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java 
@@ -132,24 +132,38 @@ public static void sendError( sendJsonResponse(ctx, error, status); } + public static void sendError( + ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t, String msg) { + ErrorResponse error = new ErrorResponse(status.code(), t.getClass().getSimpleName(), msg); + sendJsonResponse(ctx, error, status); + } + public static void sendErrorProto( ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t) { - com.amazonaws.ml.mms.protobuf.codegen.ErrorResponse errorResponse = - com.amazonaws.ml.mms.protobuf.codegen.ErrorResponse.newBuilder() + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse errorResponse = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.newBuilder() .setCode(status.code()) .setType(t.getClass().getSimpleName()) .setMessage(t.getMessage()) .build(); FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false); - resp.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM); + resp.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); resp.content().writeBytes(errorResponse.toByteArray()); sendHttpResponse(ctx, resp, true); } - public static void sendError( + public static void sendErrorProto( ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t, String msg) { - ErrorResponse error = new ErrorResponse(status.code(), t.getClass().getSimpleName(), msg); - sendJsonResponse(ctx, error, status); + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse errorResponse = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.newBuilder() + .setCode(status.code()) + .setType(t.getClass().getSimpleName()) + .setMessage(msg) + .build(); + FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false); + resp.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + resp.content().writeBytes(errorResponse.toByteArray()); + sendHttpResponse(ctx, resp, true); } /** diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java index 501a8cb68..ea69096cd 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java @@ -14,6 +14,7 @@ import com.amazonaws.ml.mms.http.InternalServerException; import com.amazonaws.ml.mms.protobuf.codegen.Predictions; +import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.NettyUtils; import com.amazonaws.ml.mms.util.messages.RequestInput; import com.amazonaws.ml.mms.util.messages.WorkerCommands; @@ -22,7 +23,6 @@ import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; import java.util.Map; @@ -109,7 +109,7 @@ public void response( .setResp(ByteString.copyFrom(body)) .build(); resp.headers() - .set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM); + .set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); resp.content().writeBytes(predictions.toByteArray()); } else { if (contentType != null && contentType.length() > 0) { diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java 
b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java index cc478bc10..9dccf37aa 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java @@ -21,7 +21,11 @@ import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.NettyUtils; import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; import java.io.IOException; import java.util.HashSet; import java.util.List; @@ -218,7 +222,7 @@ public boolean addJob(Job job) throws ModelNotFoundException { return model.addJob(job); } - public void workerStatus(final ChannelHandlerContext ctx) { + public void workerStatus(final ChannelHandlerContext ctx, boolean isProto) { Runnable r = () -> { String response = "Healthy"; @@ -237,8 +241,24 @@ public void workerStatus(final ChannelHandlerContext ctx) { // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" // and "Unhealthy" - NettyUtils.sendJsonResponse( - ctx, new StatusResponse(response), HttpResponseStatus.OK); + if (isProto) { + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse statusResponse = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.newBuilder() + .setMessage(response) + .build(); + FullHttpResponse resp = + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); + resp.headers() + .set( + HttpHeaderNames.CONTENT_TYPE, + ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + resp.content().writeBytes(statusResponse.toByteArray()); + NettyUtils.sendHttpResponse(ctx, resp, true); + } else { + NettyUtils.sendJsonResponse( + ctx, new StatusResponse(response), HttpResponseStatus.OK); + } }; wlm.scheduleAsync(r); } diff --git a/frontend/server/src/main/proto/inference.proto b/frontend/server/src/main/proto/inference.proto index 7fb697d1f..06c8f6449 100644 --- a/frontend/server/src/main/proto/inference.proto +++ b/frontend/server/src/main/proto/inference.proto @@ -22,7 +22,7 @@ message Predictions { bytes resp = 6; } -message ErrorResponse { +message StatusResponse { int32 code = 1; string type = 2; string message = 3; @@ -30,10 +30,7 @@ message ErrorResponse { enum WorkerCommands { ping = 0; - models = 1; - predictions = 2; - apiDescription = 3; - invocations = 4; + predictions = 1; } message RequestInput { diff --git a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java index 8aca550d6..b99746439 100644 --- a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java +++ b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java @@ -19,11 +19,18 @@ import com.amazonaws.ml.mms.metrics.Dimension; import com.amazonaws.ml.mms.metrics.Metric; import com.amazonaws.ml.mms.metrics.MetricManager; +import com.amazonaws.ml.mms.protobuf.codegen.InferenceRequest; +import com.amazonaws.ml.mms.protobuf.codegen.InputParameter; +import com.amazonaws.ml.mms.protobuf.codegen.Predictions; +import com.amazonaws.ml.mms.protobuf.codegen.RequestInput; +import com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands; import com.amazonaws.ml.mms.servingsdk.impl.PluginsManager; import com.amazonaws.ml.mms.util.ConfigManager; import 
com.amazonaws.ml.mms.util.Connector; import com.amazonaws.ml.mms.util.JsonUtils; import com.google.gson.JsonParseException; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -61,6 +68,7 @@ import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.util.List; @@ -87,6 +95,7 @@ public class ModelServerTest { CountDownLatch latch; HttpResponseStatus httpStatus; String result; + ByteBuffer resultBuf; HttpHeaders headers; private String listInferenceApisResult; private String listManagementApisResult; @@ -150,6 +159,7 @@ public void test() Assert.assertNotNull(channel, "Failed to connect to inference port."); Assert.assertNotNull(managementChannel, "Failed to connect to management port."); testPing(channel); + // testPingProto(channel); testRoot(channel, listInferenceApisResult); testRoot(managementChannel, listManagementApisResult); @@ -167,6 +177,7 @@ public void test() testPredictions(channel); testPredictionsBinary(channel); testPredictionsJson(channel); + testPredictionsProto(channel); testInvocationsJson(channel); testInvocationsMultipart(channel); testModelsInvokeJson(channel); @@ -195,7 +206,7 @@ public void test() testInvalidPredictionsUri(); testInvalidDescribeModel(); testPredictionsModelNotFound(); - + testPredictionsModelNotFoundProto(); testInvalidManagementUri(); testInvalidModelsMethod(); testInvalidModelMethod(); @@ -237,6 +248,25 @@ private void testPing(Channel channel) throws InterruptedException { Assert.assertTrue(headers.contains("x-request-id")); } + private void testPingProto(Channel channel) + throws InterruptedException, InvalidProtocolBufferException { + resultBuf = null; + latch = new CountDownLatch(1); + InferenceRequest inferenceRequest = + InferenceRequest.newBuilder().setCommand(WorkerCommands.ping).build(); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/ping"); + req.headers().add("Content-Type", ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + req.content().writeBytes(inferenceRequest.toByteArray()); + channel.writeAndFlush(req); + latch.await(); + + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse resp = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.parseFrom(resultBuf); + Assert.assertEquals(resp.getMessage(), "Healthy"); + Assert.assertTrue(headers.contains("x-request-id")); + } + private void testApiDescription(Channel channel, String expected) throws InterruptedException { result = null; latch = new CountDownLatch(1); @@ -430,6 +460,35 @@ private void testPredictionsBinary(Channel channel) throws InterruptedException Assert.assertEquals(result, "OK"); } + private void testPredictionsProto(Channel channel) + throws InterruptedException, InvalidProtocolBufferException { + resultBuf = null; + latch = new CountDownLatch(1); + + InputParameter parameter = + InputParameter.newBuilder() + .setName("data") + .setValue(ByteString.copyFrom("test", CharsetUtil.UTF_8)) + .build(); + InferenceRequest inferenceRequest = + InferenceRequest.newBuilder() + .setCommand(WorkerCommands.predictions) + .setModelName("noop") + .setRequest(RequestInput.newBuilder().addParameters(parameter).build()) + .build(); + + DefaultFullHttpRequest req = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, 
HttpMethod.POST, "/"); + req.content().writeBytes(inferenceRequest.toByteArray()); + HttpUtil.setContentLength(req, req.content().readableBytes()); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + channel.writeAndFlush(req); + + latch.await(); + Predictions predictions = Predictions.parseFrom(resultBuf); + Assert.assertEquals(predictions.getResp().toStringUtf8(), "OK"); + } + private void testInvocationsJson(Channel channel) throws InterruptedException { result = null; latch = new CountDownLatch(1); @@ -778,6 +837,44 @@ private void testPredictionsModelNotFound() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel"); } + private void testPredictionsModelNotFoundProto() + throws InterruptedException, InvalidProtocolBufferException { + Channel channel = connect(false); + Assert.assertNotNull(channel); + + resultBuf = null; + latch = new CountDownLatch(1); + + InputParameter parameter = + InputParameter.newBuilder() + .setName("data") + .setValue(ByteString.copyFrom("test", CharsetUtil.UTF_8)) + .build(); + InferenceRequest inferenceRequest = + InferenceRequest.newBuilder() + .setCommand(WorkerCommands.predictions) + .setModelName("InvalidModel") + .setRequest(RequestInput.newBuilder().addParameters(parameter).build()) + .build(); + + DefaultFullHttpRequest req = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + req.content().writeBytes(inferenceRequest.toByteArray()); + HttpUtil.setContentLength(req, req.content().readableBytes()); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + channel.writeAndFlush(req).sync(); + channel.closeFuture().sync(); + + latch.await(); + + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse resp = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.parseFrom(resultBuf); + + Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code()); + Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel"); + channel.close(); + } + private void testInvalidManagementUri() throws InterruptedException { Channel channel = connect(true); Assert.assertNotNull(channel); @@ -1361,8 +1458,16 @@ private class TestHandler extends SimpleChannelInboundHandler @Override public void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) { + CharSequence contentType = HttpUtil.getMimeType(msg); + if (contentType != null + && contentType + .toString() + .equalsIgnoreCase(ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF)) { + resultBuf = msg.content().nioBuffer(); + } else { + result = msg.content().toString(StandardCharsets.UTF_8); + } httpStatus = msg.status(); - result = msg.content().toString(StandardCharsets.UTF_8); headers = msg.headers(); latch.countDown(); } From cc54c2ffdbd8d5656fb773754b1c7354713bbb1d Mon Sep 17 00:00:00 2001 From: Li Ning Date: Sun, 18 Oct 2020 10:32:23 -0700 Subject: [PATCH 6/8] enable testPingProto --- .../src/test/java/com/amazonaws/ml/mms/ModelServerTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java index b99746439..c0aa30465 100644 --- a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java +++ b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java @@ -159,7 +159,7 @@ public void test() Assert.assertNotNull(channel, "Failed to connect to inference port."); 
From cc54c2ffdbd8d5656fb773754b1c7354713bbb1d Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Sun, 18 Oct 2020 10:32:23 -0700
Subject: [PATCH 6/8] enable testPingProto

---
 .../src/test/java/com/amazonaws/ml/mms/ModelServerTest.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
index b99746439..c0aa30465 100644
--- a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
+++ b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
@@ -159,7 +159,7 @@ public void test()
         Assert.assertNotNull(channel, "Failed to connect to inference port.");
         Assert.assertNotNull(managementChannel, "Failed to connect to management port.");
         testPing(channel);
-        // testPingProto(channel);
+        testPingProto(channel);
         testRoot(channel, listInferenceApisResult);
         testRoot(managementChannel, listManagementApisResult);

From 62dd819f711febd00c9072142171cc2264d88587 Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Mon, 19 Oct 2020 18:03:23 -0700
Subject: [PATCH 7/8] fix mms-ci-python3 error

---
 ci/Dockerfile.python3.6                     | 1 +
 mms/tests/unit_tests/test_beckend_metric.py | 1 +
 mms/tests/unit_tests/test_worker_service.py | 1 +
 3 files changed, 3 insertions(+)

diff --git a/ci/Dockerfile.python3.6 b/ci/Dockerfile.python3.6
index 76c45a314..f77a07e7d 100644
--- a/ci/Dockerfile.python3.6
+++ b/ci/Dockerfile.python3.6
@@ -190,6 +190,7 @@ RUN set -ex \
     && pip install retrying \
    && pip install mock \
    && pip install pytest -U \
+    && pip install pytest-mock \
    && pip install pylint

 # Install protobuf
diff --git a/mms/tests/unit_tests/test_beckend_metric.py b/mms/tests/unit_tests/test_beckend_metric.py
index 4ef2e6ad0..6732973bb 100644
--- a/mms/tests/unit_tests/test_beckend_metric.py
+++ b/mms/tests/unit_tests/test_beckend_metric.py
@@ -29,6 +29,7 @@ def test_metrics(caplog):
     """
     Test if metric classes methods behave as expected
     Also checks global metric service methods
     """
+    caplog.set_level(logging.INFO)
     # Create a batch of request ids
     request_ids = {0: 'abcd', 1: "xyz", 2: "qwerty", 3: "hjshfj"}
     all_req_ids = ','.join(request_ids.values())
diff --git a/mms/tests/unit_tests/test_worker_service.py b/mms/tests/unit_tests/test_worker_service.py
index 575f13d73..bafba308c 100644
--- a/mms/tests/unit_tests/test_worker_service.py
+++ b/mms/tests/unit_tests/test_worker_service.py
@@ -50,6 +50,7 @@ def test_valid_req(self, service):

 class TestEmitMetrics:
     def test_emit_metrics(self, caplog):
+        caplog.set_level(logging.INFO)
         metrics = {'test_emit_metrics': True}
         emit_metrics(metrics)
         assert "[METRICS]" in caplog.text
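Note: the next patch ([PATCH 8/8]) compares three encodings of the same 50-row float batch by logging their serialized sizes; Java object serialization via ObjectOutputStream carries class metadata and per-array framing, while raw little-endian float bytes pick up only a few bytes of protobuf tag/length overhead. A minimal sketch of that accounting, assuming only the generated InputParameter class (the class name PayloadSizeDemo and the 300-element vector are illustrative):

    import com.amazonaws.ml.mms.protobuf.codegen.InputParameter;
    import com.google.protobuf.ByteString;
    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    public final class PayloadSizeDemo {
        public static void main(String[] args) {
            float[] features = new float[300];
            // Pack the floats little-endian, four bytes each, with no per-element framing.
            ByteBuffer buf =
                    ByteBuffer.allocate(Float.BYTES * features.length)
                            .order(ByteOrder.LITTLE_ENDIAN);
            for (float f : features) {
                buf.putFloat(f);
            }
            InputParameter param =
                    InputParameter.newBuilder()
                            .setValue(ByteString.copyFrom(buf.array()))
                            .build();
            // getSerializedSize() reports the encoded length without materializing the bytes;
            // the protobuf envelope adds only tag/length overhead to the 1200-byte payload.
            System.out.println("raw bytes:         " + buf.array().length);
            System.out.println("encoded parameter: " + param.getSerializedSize());
        }
    }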
From e0a6724de619df1e7d526094431080a3064e6ec6 Mon Sep 17 00:00:00 2001
From: Li Ning
Date: Fri, 23 Oct 2020 12:49:46 -0700
Subject: [PATCH 8/8] test protobuf size

---
 .../com/amazonaws/ml/mms/ModelServerTest.java | 177 ++++++++++++++++--
 1 file changed, 157 insertions(+), 20 deletions(-)

diff --git a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
index c0aa30465..69bbae9fd 100644
--- a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
+++ b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
@@ -21,7 +21,6 @@
 import com.amazonaws.ml.mms.metrics.MetricManager;
 import com.amazonaws.ml.mms.protobuf.codegen.InferenceRequest;
 import com.amazonaws.ml.mms.protobuf.codegen.InputParameter;
-import com.amazonaws.ml.mms.protobuf.codegen.Predictions;
 import com.amazonaws.ml.mms.protobuf.codegen.RequestInput;
 import com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands;
 import com.amazonaws.ml.mms.servingsdk.impl.PluginsManager;
@@ -67,15 +66,20 @@
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.ObjectOutputStream;
 import java.lang.reflect.Field;
 import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
 import java.security.GeneralSecurityException;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Properties;
+import java.util.Random;
 import java.util.Scanner;
 import java.util.concurrent.CountDownLatch;
 import org.apache.commons.io.IOUtils;
+import org.apache.commons.io.output.ByteArrayOutputStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.testng.Assert;
@@ -177,7 +181,7 @@
         testPredictions(channel);
         testPredictionsBinary(channel);
         testPredictionsJson(channel);
-        testPredictionsProto(channel);
+        testPredictionsProto();
         testInvocationsJson(channel);
         testInvocationsMultipart(channel);
         testModelsInvokeJson(channel);
@@ -460,33 +464,166 @@ private void testPredictionsBinary(Channel channel) throws InterruptedException
         Assert.assertEquals(result, "OK");
     }

-    private void testPredictionsProto(Channel channel)
-            throws InterruptedException, InvalidProtocolBufferException {
-        resultBuf = null;
-        latch = new CountDownLatch(1);
+    private void testPredictionsProto() throws InterruptedException, IOException {
+        Logger logger = LoggerFactory.getLogger(ModelServerTest.class);
+
+        float[] featureVec = {
+            0.8241127f, 0.77719664f, 0.47123995f, 0.27323001f, 0.24874457f, 0.77869387f,
+            0.50711921f, 0.10696663f, 0.60663805f, 0.76063525f, 0.96358908f, 0.71026102f,
+            0.57714464f, 0.58250422f, 0.91595038f, 0.24119576f, 0.58981158f, 0.67119473f,
+            0.94832165f, 0.91711728f, 0.0323646f, 0.07007003f, 0.89158581f, 0.01916486f,
+            0.5647568f, 0.99879008f, 0.58311515f, 0.87001143f, 0.50620349f, 0.65268692f,
+            0.83657373f, 0.31589474f, 0.70910797f, 0.62886395f, 0.03498501f, 0.36503007f,
+            0.94178899f, 0.21739391f, 0.29688258f, 0.34630696f, 0.30494259f, 0.04302086f,
+            0.3578226f, 0.04361075f, 0.91962488f, 0.24961093f, 0.0124245f, 0.31004002f,
+            0.61543447f, 0.34500444f, 0.30441186f, 0.44085924f, 0.67489625f, 0.03938287f,
+            0.89307169f, 0.22283647f, 0.44441515f, 0.82044036f, 0.37541783f, 0.25868981f,
+            0.46510721f, 0.51640271f, 0.40917042f, 0.65912921f, 0.72228879f, 0.42611241f,
+            0.71283259f, 0.37417586f, 0.786403f, 0.6912011f, 0.4338622f, 0.29868897f,
+            0.0342538f, 0.16938266f, 0.90234809f, 0.3051922f, 0.92377579f, 0.97883088f,
+            0.2028601f, 0.50478822f, 0.84762944f, 0.11011502f, 0.70006246f, 0.34329564f,
+            0.49022718f, 0.8569296f, 0.75698334f, 0.84864789f, 0.9477985f, 0.46994381f,
+            0.05319027f, 0.07369953f, 0.08497094f, 0.54536333f, 0.87922514f, 0.97857665f,
+            0.06930542f, 0.27101086f, 0.03069235f, 0.13432096f, 0.96021588f, 0.9484153f,
+            0.75365465f, 0.76216408f, 0.43294879f, 0.41034781f, 0.01088872f, 0.29060839f,
+            0.94462721f, 0.83999491f, 0.4364634f, 0.63611379f, 0.32102346f, 0.10418961f,
+            0.2776194f, 0.73166493f, 0.76387601f, 0.83429646f, 0.94348065f, 0.85956626f,
+            0.81160069f, 0.1650624f, 0.79505978f, 0.67288331f, 0.3204887f, 0.89388283f,
+            0.85290859f, 0.11308228f, 0.81252801f, 0.87276483f, 0.76737167f, 0.16166891f,
+            0.78767838f, 0.79160494f, 0.80843258f, 0.39723985f, 0.47062281f, 0.96028728f,
+            0.55309858f, 0.05378428f, 0.3619188f, 0.69888766f, 0.76134346f, 0.60911425f,
+            0.85562674f, 0.58098788f, 0.5438003f, 0.61229528f, 0.14350196f, 0.75286178f,
+            0.88131248f, 0.69132185f, 0.12576858f, 0.23459534f, 0.26883056f, 0.98129534f,
+            0.74060036f, 0.9607236f, 0.99617814f, 0.75829678f, 0.06310486f, 0.55572225f,
+            0.72709395f, 0.77374732f, 0.81625695f, 0.13475297f, 0.89352917f, 0.19805313f,
+            0.34789188f, 0.08422005f, 0.67733949f, 0.94300965f, 0.22116594f, 0.10948816f,
+            0.50651639f, 0.40402931f, 0.46181863f, 0.14743327f, 0.33300708f, 0.87358395f,
+            0.79312213f, 0.54662338f, 0.83890467f, 0.87690315f, 0.24570711f, 0.01534696f,
+            0.11803501f, 0.21333099f, 0.75169896f, 0.42758898f, 0.80780874f, 0.57331851f,
+            0.96341639f, 0.52078203f, 0.22610806f, 0.83348684f, 0.76036637f, 0.99407179f,
+            0.96098997f, 0.2451298f, 0.41848766f, 0.01584927f, 0.28213452f, 0.04494721f,
+            0.16963578f, 0.68096619f, 0.39404686f, 0.7621266f, 0.02721071f, 0.5481559f,
+            0.59972178f, 0.61725009f, 0.76405802f, 0.83030081f, 0.87232659f, 0.16119207f,
+            0.51143718f, 0.13040968f, 0.57453206f, 0.63200166f, 0.27077547f, 0.72281371f,
+            0.44055048f, 0.51538986f, 0.29096202f, 0.99726975f, 0.50958807f, 0.87792484f,
+            0.03956957f, 0.42187308f, 0.87694541f, 0.88974026f, 0.65590356f, 0.35029236f,
+            0.18853136f, 0.50500502f, 0.95545852f, 0.94636341f, 0.84731837f, 0.13936297f,
+            0.32537976f, 0.41430316f, 0.18574781f, 0.97574309f, 0.26483325f, 0.79840404f,
+            0.74069621f, 0.98526361f, 0.63957011f, 0.30924823f, 0.20429374f, 0.09850504f,
+            0.77676228f, 0.40561045f, 0.71999222f, 0.42545573f, 0.78092917f, 0.74532941f,
+            0.52263514f, 0.01771433f, 0.15041333f, 0.41157879f, 0.15047035f, 0.66149007f,
+            0.95970903f, 0.97348663f, 0.30155038f, 0.06596597f, 0.3317747f, 0.09346482f,
+            0.71672818f, 0.13279156f, 0.19758743f, 0.20143709f, 0.84517665f, 0.767672f,
+            0.21471986f, 0.75663108f, 0.35878468f, 0.58943601f, 0.98005496f, 0.30451585f,
+            0.34754926f, 0.3298018f, 0.36859658f, 0.52568727f, 0.45107675f, 0.27778918f,
+            0.4825746f, 0.6521011f, 0.16924284f, 0.54550222f, 0.33862934f, 0.88247624f,
+            0.97012639f, 0.64496125f, 0.09514454f, 0.90497989f, 0.82705286f, 0.5232794f,
+            0.80558394f, 0.86949601f, 0.78825486f, 0.23086437f, 0.64405503f, 0.02989425f,
+            0.61423185f, 0.45341492f, 0.52462891f, 0.93029992f, 0.74040612f, 0.45227326f,
+            0.35339424f, 0.30661544f, 0.70083487f, 0.68725394f, 0.2036894f, 0.85478822f,
+            0.13176267f, 0.10494695f, 0.17226407f, 0.88662847f, 0.42744141f, 0.44540842f,
+            0.94161152f, 0.46699513f, 0.36795051f, 0.0234292f, 0.68830582f, 0.33571055f,
+            0.93930267f, 0.76513689f, 0.69002036f, 0.11983312f, 0.05524331f, 0.28743821f,
+            0.53563344f, 0.00152629f, 0.50295284f, 0.24351331f, 0.6770774f, 0.42484211f,
+            0.10956752f, 0.01239354f, 0.57630947f, 0.16575461f, 0.7870273f, 0.64387019f,
+            0.65514058f, 0.62808722f, 0.29263556f, 0.8159863f, 0.18642033f
+        };
+        List<float[]> instances = new LinkedList<>();
+        Random rand = new Random();
+        for (int i = 0; i < 50; i++) {
+            float[] data = new float[featureVec.length];
+            for (int j = 0; j < featureVec.length; j++) {
+                data[j] = featureVec[rand.nextInt(featureVec.length)];
+            }
+            instances.add(data);
+        }
+
+        ByteArrayOutputStream bos = new ByteArrayOutputStream();
+        ObjectOutputStream oos = new ObjectOutputStream(bos);
+        oos.writeObject(instances);
+        byte[] bytes = bos.toByteArray();
+        byte[] byteString = ByteString.copyFrom(bytes).toByteArray();

         InputParameter parameter =
-                InputParameter.newBuilder()
-                        .setName("data")
-                        .setValue(ByteString.copyFrom("test", CharsetUtil.UTF_8))
-                        .build();
+                InputParameter.newBuilder().setValue(ByteString.copyFrom(bytes)).build();
         InferenceRequest inferenceRequest =
                 InferenceRequest.newBuilder()
                         .setCommand(WorkerCommands.predictions)
-                        .setModelName("noop")
+                        .setModelName("test")
                         .setRequest(RequestInput.newBuilder().addParameters(parameter).build())
                         .build();
+        logger.info(
+                "2D random float size="
+                        + featureVec.length * 50
+                        + ", byteString size="
+                        + byteString.length
+                        + ", bytes size="
+                        + bytes.length
+                        + ", parameter size="
+                        + parameter.toByteArray().length
+                        + ", protobuf size="
+                        + inferenceRequest.toByteArray().length);
+        oos.close();
+
+        List<float[]> instances1 = new LinkedList<>();
+        for (int i = 0; i < 50; i++) {
+            instances1.add(featureVec);
+        }

-        DefaultFullHttpRequest req =
-                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/");
-        req.content().writeBytes(inferenceRequest.toByteArray());
-        HttpUtil.setContentLength(req, req.content().readableBytes());
-        req.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF);
-        channel.writeAndFlush(req);
+        ByteArrayOutputStream bos1 = new ByteArrayOutputStream();
+        ObjectOutputStream oos1 = new ObjectOutputStream(bos1);
+        oos1.writeObject(instances1);
+        byte[] bytes1 = bos1.toByteArray();
+        byte[] byteString1 = ByteString.copyFrom(bytes1).toByteArray();

-        latch.await();
-        Predictions predictions = Predictions.parseFrom(resultBuf);
-        Assert.assertEquals(predictions.getResp().toStringUtf8(), "OK");
+        InputParameter parameter1 =
+                InputParameter.newBuilder().setValue(ByteString.copyFrom(bytes1)).build();
+        InferenceRequest inferenceRequest1 =
+                InferenceRequest.newBuilder()
+                        .setCommand(WorkerCommands.predictions)
+                        .setModelName("test")
+                        .setRequest(RequestInput.newBuilder().addParameters(parameter1).build())
+                        .build();
+        logger.info(
+                "2D repeated float size="
+                        + featureVec.length * 50
+                        + ", byteString size="
+                        + byteString1.length
+                        + ", bytes size="
+                        + bytes1.length
+                        + ", parameter size="
+                        + parameter1.toByteArray().length
+                        + ", protobuf size="
+                        + inferenceRequest1.toByteArray().length);
+        oos1.close();
+
+        ByteBuffer fBuffer = ByteBuffer.allocate(Float.BYTES * featureVec.length * 50);
+        fBuffer.order(ByteOrder.LITTLE_ENDIAN);
+        for (int i = 0; i < 50; i++) {
+            for (float feature : featureVec) {
+                fBuffer.putFloat(feature);
+            }
+        }
+        byte[] bytes2 = fBuffer.array();
+        InputParameter parameter2 =
+                InputParameter.newBuilder().setValue(ByteString.copyFrom(bytes2)).build();
+        InferenceRequest inferenceRequest2 =
+                InferenceRequest.newBuilder()
+                        .setCommand(WorkerCommands.predictions)
+                        .setModelName("test")
+                        .setRequest(RequestInput.newBuilder().addParameters(parameter2).build())
+                        .build();
+        logger.info(
+                "1D repeated float size="
+                        + featureVec.length * 50
+                        + ", fBuffer size="
+                        + fBuffer.array().length
+                        + ", bytes size="
+                        + bytes2.length
+                        + ", parameter size="
+                        + parameter2.toByteArray().length
+                        + ", protobuf size="
+                        + inferenceRequest2.toByteArray().length);
     }

     private void testInvocationsJson(Channel channel) throws InterruptedException {