Skip to content

Commit 7963e77

Browse files
csteegzshahasad
authored andcommitted
Cherry pick 2 OnnxRuntime Server fixes
1 parent 7d5b089 commit 7963e77

File tree

5 files changed

+78
-19
lines changed

5 files changed

+78
-19
lines changed

onnxruntime/server/http/core/session.cc

-5
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,6 @@ http::status HttpSession::ExecuteUserFunction(HttpContext& context) {
121121
context.client_request_id = context.request[util::MS_CLIENT_REQUEST_ID_HEADER].to_string();
122122
}
123123

124-
if (path == "/score") {
125-
// This is a shortcut since we have only one model instance currently.
126-
// This code path will be removed once we start supporting multiple models or multiple versions of one model.
127-
path = "/v1/models/default/versions/1:predict";
128-
}
129124

130125
auto status = routes_.ParseUrl(context.request.method(), path, model_name, model_version, action, func);
131126

onnxruntime/server/http/predict_request_handler.cc

+4-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ void Predict(const std::string& name,
3939
auto logger = env->GetLogger(context.request_id);
4040
logger->info("Model Name: {}, Version: {}, Action: {}", name, version, action);
4141

42+
auto effective_name = name.empty() ? "default" : name;
43+
auto effective_version = version.empty() ? "1" : version;
44+
4245
if (!context.client_request_id.empty()) {
4346
logger->info("{}: [{}]", util::MS_CLIENT_REQUEST_ID_HEADER, context.client_request_id);
4447
}
@@ -64,7 +67,7 @@ void Predict(const std::string& name,
6467
// Run Prediction
6568
Executor executor(env.get(), context.request_id);
6669
PredictResponse predict_response{};
67-
auto status = executor.Predict(name, version, predict_request, predict_response);
70+
auto status = executor.Predict(effective_name, effective_version, predict_request, predict_response);
6871
if (!status.ok()) {
6972
GenerateErrorResponse(logger, GetHttpStatusCode((status)), status.error_message(), context);
7073
return;

onnxruntime/server/main.cc

+8-1
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,18 @@ int main(int argc, char* argv[]) {
104104
});
105105

106106
app.RegisterPost(
107-
R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))",
107+
R"(/(?:v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))|(?:score()()()))",
108108
[&env](const auto& name, const auto& version, const auto& action, auto& context) -> void {
109109
server::Predict(name, version, action, context, env);
110110
});
111111

112+
app.RegisterPost(
113+
R"(/score()()())",
114+
[&env](const auto& name, const auto& version, const auto& action, auto& context) -> void {
115+
server::Predict(name, version, action, context, env);
116+
}
117+
);
118+
112119
app.Bind(boost_address, config.http_port)
113120
.NumThreads(config.num_http_threads)
114121
.Run();

onnxruntime/test/server/integration_tests/function_tests.py

+44
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,50 @@ def test_single_model_shortcut(self):
192192
for i in range(0, 10):
193193
self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))
194194

195+
def test_single_version_shortcut(self):
196+
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
197+
output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json')
198+
199+
with open(input_data_file, 'r') as f:
200+
request_payload = f.read()
201+
202+
with open(output_data_file, 'r') as f:
203+
expected_response_json = f.read()
204+
expected_response = json.loads(expected_response_json)
205+
206+
request_headers = {
207+
'Content-Type': 'application/json',
208+
'Accept': 'application/json',
209+
'x-ms-client-request-id': 'This~is~my~id'
210+
}
211+
212+
url = "http://{0}:{1}/v1/models/{2}:predict".format(self.server_ip, self.server_port, 'default')
213+
test_util.test_log(url)
214+
r = requests.post(url, headers=request_headers, data=request_payload)
215+
self.assertEqual(r.status_code, 200)
216+
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
217+
self.assertTrue(r.headers.get('x-ms-request-id'))
218+
self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')
219+
220+
actual_response = json.loads(r.content.decode('utf-8'))
221+
222+
# Note:
223+
# The 'dims' field is defined as "repeated int64" in protobuf.
224+
# When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string
225+
# Reference: https://developers.google.com/protocol-buffers/docs/proto3#json
226+
227+
self.assertTrue(actual_response['outputs'])
228+
self.assertTrue(actual_response['outputs']['Plus214_Output_0'])
229+
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims'])
230+
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10'])
231+
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType'])
232+
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1)
233+
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData'])
234+
actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f')
235+
expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f')
236+
237+
for i in range(0, 10):
238+
self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))
195239

196240
class HttpProtobufPayloadTests(unittest.TestCase):
197241
server_ip = '127.0.0.1'

onnxruntime/test/server/unit_tests/http_routes_tests.cc

+22-12
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace onnxruntime {
1010
namespace server {
1111
namespace test {
1212

13+
static const std::string predict_regex = R"(/(?:v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict)))";
1314
using test_data = std::tuple<http::verb, std::string, std::string, std::string, std::string, http::status>;
1415

1516
void do_something(const std::string& name, const std::string& version,
@@ -20,7 +21,6 @@ void do_something(const std::string& name, const std::string& version,
2021
void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data);
2122

2223
TEST(HttpRouteTests, RegisterTest) {
23-
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
2424
Routes routes;
2525
EXPECT_TRUE(routes.RegisterController(http::verb::post, predict_regex, do_something));
2626

@@ -29,20 +29,22 @@ TEST(HttpRouteTests, RegisterTest) {
2929
}
3030

3131
TEST(HttpRouteTests, PostRouteTest) {
32-
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
3332

3433
std::vector<test_data> actions{
3534
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::ok),
3635
std::make_tuple(http::verb::post, "/v1/models/abc:predict", "abc", "", "predict", http::status::ok),
3736
std::make_tuple(http::verb::post, "/v1/models/models/versions/45:predict", "models", "45", "predict", http::status::ok),
3837
std::make_tuple(http::verb::post, "/v1/models/??$$%%@@$^^/versions/45:predict", "??$$%%@@$^^", "45", "predict", http::status::ok),
39-
std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok)};
38+
std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok),
39+
std::make_tuple(http::verb::post, "/v1/models/versions:predict", "versions", "", "predict", http::status::ok),
40+
std::make_tuple(http::verb::post, "/v1/models/default:predict", "default", "", "predict", http::status::ok)
41+
};
42+
4043

4144
run_route(predict_regex, http::verb::post, actions, true);
4245
}
4346

4447
TEST(HttpRouteTests, PostRouteInvalidURLTest) {
45-
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
4648

4749
std::vector<test_data> actions{
4850
std::make_tuple(http::verb::post, "", "", "", "", http::status::not_found),
@@ -56,27 +58,35 @@ TEST(HttpRouteTests, PostRouteInvalidURLTest) {
5658
std::make_tuple(http::verb::post, "/models/abc/versions/2:predict", "", "", "", http::status::not_found),
5759
std::make_tuple(http::verb::post, "/v1/models/versions/2:predict", "", "", "", http::status::not_found),
5860
std::make_tuple(http::verb::post, "/v1/models/foo/versions/:predict", "", "", "", http::status::not_found),
59-
std::make_tuple(http::verb::post, "/v1/models/foo/versions:predict", "", "", "", http::status::not_found),
6061
std::make_tuple(http::verb::post, "v1/models/foo/versions/12:predict", "", "", "", http::status::not_found),
61-
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)};
62+
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)
63+
};
6264

6365
run_route(predict_regex, http::verb::post, actions, false);
6466
}
6567

6668
// These tests are because we currently only support POST and GET
6769
// Some HTTP methods should be removed from test data if we support more (e.g. PUT)
6870
TEST(HttpRouteTests, PostRouteInvalidMethodTest) {
69-
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
7071

7172
std::vector<test_data> actions{
7273
std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed),
7374
std::make_tuple(http::verb::put, "/v1/models", "", "", "", http::status::method_not_allowed),
7475
std::make_tuple(http::verb::delete_, "/v1/models", "", "", "", http::status::method_not_allowed),
75-
std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)};
76+
std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)
77+
};
7678

7779
run_route(predict_regex, http::verb::post, actions, false);
7880
}
7981

82+
TEST(HttpRouteTests, PostRouteSpecialMethodTest){
83+
std::vector<test_data> actions{
84+
std::make_tuple(http::verb::post, "/score", "", "", "", http::status::ok)
85+
};
86+
87+
run_route(R"(/score()()())", http::verb::post, actions, true);
88+
}
89+
8090
void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data) {
8191
Routes routes;
8292
EXPECT_TRUE(routes.RegisterController(method, pattern, do_something));
@@ -95,11 +105,11 @@ void run_route(const std::string& pattern, http::verb method, const std::vector<
95105
http::status expected_status;
96106

97107
std::tie(test_method, url_string, expected_name, expected_version, expected_action, expected_status) = i;
98-
EXPECT_EQ(expected_status, routes.ParseUrl(test_method, url_string, name, version, action, fn));
108+
EXPECT_EQ(expected_status, routes.ParseUrl(test_method, url_string, name, version, action, fn)) << "On route " << url_string;
99109
if (does_validate_data) {
100-
EXPECT_EQ(name, expected_name);
101-
EXPECT_EQ(version, expected_version);
102-
EXPECT_EQ(action, expected_action);
110+
EXPECT_EQ(name, expected_name) << "On route " << url_string;
111+
EXPECT_EQ(version, expected_version) << "On route " << url_string;
112+
EXPECT_EQ(action, expected_action) << "On route " << url_string;
103113
}
104114
}
105115
}

0 commit comments

Comments
 (0)