Skip to content

Commit f7307ef

Browse files
committed
Refine returned model data types & add initial examples
1 parent cc03c8d commit f7307ef

18 files changed

+545
-66
lines changed

examples/mnist.java

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
///usr/bin/env jbang "$0" "$@" ; exit $?
2+
//DEPS com.github.tadayosi.torchserve:torchserve-client:0.1-SNAPSHOT
3+
4+
import java.nio.file.Files;
5+
import java.nio.file.Path;
6+
7+
import com.github.tadayosi.torchserve.client.impl.DefaultInference;
8+
import com.github.tadayosi.torchserve.client.inference.invoker.ApiException;
9+
10+
public class mnist {
11+
12+
private static String MNIST_MODEL = "mnist_v2";
13+
14+
public static void main(String... args) throws Exception {
15+
var zero = Files.readAllBytes(Path.of("src/test/resources/data/0.png"));
16+
var one = Files.readAllBytes(Path.of("src/test/resources/data/1.png"));
17+
try {
18+
var inference = new DefaultInference();
19+
var result0 = inference.predictions(MNIST_MODEL, zero);
20+
System.out.println("Answer> " + result0);
21+
var result1 = inference.predictions(MNIST_MODEL, one);
22+
System.out.println("Answer> " + result1);
23+
} catch (ApiException e) {
24+
System.err.println(e.getResponseBody());
25+
e.printStackTrace();
26+
}
27+
}
28+
}

examples/register_mnist.java

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
///usr/bin/env jbang "$0" "$@" ; exit $?
2+
//DEPS com.github.tadayosi.torchserve:torchserve-client:0.1-SNAPSHOT
3+
4+
import com.github.tadayosi.torchserve.client.impl.DefaultManagement;
5+
import com.github.tadayosi.torchserve.client.management.invoker.ApiException;
6+
import com.github.tadayosi.torchserve.client.model.RegisterModelOptions;
7+
import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions;
8+
9+
public class register_mnist {
10+
11+
private static String MNIST_URL = "https://torchserve.pytorch.org/mar_files/mnist_v2.mar";
12+
private static String MNIST_MODEL = "mnist_v2";
13+
14+
public static void main(String... args) throws Exception {
15+
try {
16+
var management = new DefaultManagement();
17+
var response = management.registerModel(MNIST_URL, RegisterModelOptions.empty());
18+
System.out.println("registerModel> " + response.getStatus());
19+
response = management.setAutoScale(MNIST_MODEL, SetAutoScaleOptions.builder()
20+
.minWorker(1)
21+
.maxWorker(1)
22+
.build());
23+
System.out.println("setAutoScale> " + response.getStatus());
24+
} catch (ApiException e) {
25+
System.err.println(e.getResponseBody());
26+
}
27+
}
28+
}

src/main/java/com/github/tadayosi/torchserve/client/Inference.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package com.github.tadayosi.torchserve.client;
22

3+
import com.github.tadayosi.torchserve.client.model.API;
4+
import com.github.tadayosi.torchserve.client.model.Response;
5+
36
/**
47
* Inference API
58
*/
@@ -8,12 +11,12 @@ public interface Inference {
811
/**
912
* Get openapi description.
1013
*/
11-
Object apiDescription() throws Exception;
14+
API apiDescription() throws Exception;
1215

1316
/**
1417
* Get TorchServe status.
1518
*/
16-
Object ping() throws Exception;
19+
Response ping() throws Exception;
1720

1821
/**
1922
* Predictions entry point to get inference using default model version.

src/main/java/com/github/tadayosi/torchserve/client/Management.java

+14-10
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
import java.util.List;
44

5+
import com.github.tadayosi.torchserve.client.model.API;
6+
import com.github.tadayosi.torchserve.client.model.ModelDetail;
7+
import com.github.tadayosi.torchserve.client.model.ModelList;
58
import com.github.tadayosi.torchserve.client.model.RegisterModelOptions;
9+
import com.github.tadayosi.torchserve.client.model.Response;
610
import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions;
711
import com.github.tadayosi.torchserve.client.model.UnregisterModelOptions;
812

@@ -14,52 +18,52 @@ public interface Management {
1418
/**
1519
* Register a new model in TorchServe.
1620
*/
17-
Object registerModel(String url, RegisterModelOptions options) throws Exception;
21+
Response registerModel(String url, RegisterModelOptions options) throws Exception;
1822

1923
/**
2024
* Configure number of workers for a default version of a model. This is an asynchronous call by default. Caller need to call describeModel to check if the model workers has been changed.
2125
*/
22-
Object setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception;
26+
Response setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception;
2327

2428
/**
2529
* Configure number of workers for a specified version of a model. This is an asynchronous call by default. Caller need to call describeModel to check if the model workers has been changed.
2630
*/
27-
Object setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception;
31+
Response setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception;
2832

2933
/**
3034
* Provides detailed information about the default version of a model.
3135
*/
32-
List<Object> describeModel(String modelName) throws Exception;
36+
List<ModelDetail> describeModel(String modelName) throws Exception;
3337

3438
/**
3539
* Provides detailed information about the specified version of a model.If "all" is specified as version, returns the details about all the versions of the model.
3640
*/
37-
List<Object> describeModel(String modelName, String modelVersion) throws Exception;
41+
List<ModelDetail> describeModel(String modelName, String modelVersion) throws Exception;
3842

3943
/**
4044
* Unregister the default version of a model from TorchServe if it is the only version available. This is an asynchronous call by default. Caller can call listModels to confirm model is unregistered.
4145
*/
42-
Object unregisterModel(String modelName, UnregisterModelOptions options) throws Exception;
46+
Response unregisterModel(String modelName, UnregisterModelOptions options) throws Exception;
4347

4448
/**
4549
* Unregister the specified version of a model from TorchServe. This is an asynchronous call by default. Caller can call listModels to confirm model is unregistered.
4650
*/
47-
Object unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception;
51+
Response unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception;
4852

4953
/**
5054
* List registered models in TorchServe.
5155
*/
52-
Object listModels(Integer limit, String nextPageToken) throws Exception;
56+
ModelList listModels(Integer limit, String nextPageToken) throws Exception;
5357

5458
/**
5559
* Set default version of a model.
5660
*/
57-
Object setDefault(String modelName, String modelVersion) throws Exception;
61+
Response setDefault(String modelName, String modelVersion) throws Exception;
5862

5963
/**
6064
* Get openapi description.
6165
*/
62-
Object apiDescription() throws Exception;
66+
API apiDescription() throws Exception;
6367

6468
/**
6569
* Not supported yet.

src/main/java/com/github/tadayosi/torchserve/client/Metrics.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ public interface Metrics {
88
/**
99
* Get TorchServe application metrics in prometheus format.
1010
*/
11-
Object metrics() throws Exception;
11+
String metrics() throws Exception;
1212

1313
/**
1414
* Get TorchServe application metrics in prometheus format.
1515
*/
16-
Object metrics(String name) throws Exception;
16+
String metrics(String name) throws Exception;
1717

1818
}

src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultInference.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import com.github.tadayosi.torchserve.client.Inference;
44
import com.github.tadayosi.torchserve.client.inference.api.DefaultApi;
55
import com.github.tadayosi.torchserve.client.inference.invoker.ApiClient;
6+
import com.github.tadayosi.torchserve.client.model.API;
7+
import com.github.tadayosi.torchserve.client.model.Response;
68

79
public class DefaultInference implements Inference {
810

@@ -18,13 +20,13 @@ public DefaultInference(int port) {
1820
}
1921

2022
@Override
21-
public Object apiDescription() throws Exception {
22-
return api.apiDescription();
23+
public API apiDescription() throws Exception {
24+
return API.from(api.apiDescription());
2325
}
2426

2527
@Override
26-
public Object ping() throws Exception {
27-
return api.ping();
28+
public Response ping() throws Exception {
29+
return Response.from(api.ping());
2830
}
2931

3032
@Override

src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultManagement.java

+32-25
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
package com.github.tadayosi.torchserve.client.impl;
22

33
import java.util.List;
4+
import java.util.Map;
45

56
import com.github.tadayosi.torchserve.client.Management;
67
import com.github.tadayosi.torchserve.client.management.api.DefaultApi;
78
import com.github.tadayosi.torchserve.client.management.invoker.ApiClient;
9+
import com.github.tadayosi.torchserve.client.model.API;
10+
import com.github.tadayosi.torchserve.client.model.Model;
11+
import com.github.tadayosi.torchserve.client.model.ModelDetail;
12+
import com.github.tadayosi.torchserve.client.model.ModelList;
813
import com.github.tadayosi.torchserve.client.model.RegisterModelOptions;
14+
import com.github.tadayosi.torchserve.client.model.Response;
915
import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions;
1016
import com.github.tadayosi.torchserve.client.model.UnregisterModelOptions;
1117

@@ -23,8 +29,8 @@ public DefaultManagement(int port) {
2329
}
2430

2531
@Override
26-
public Object registerModel(String url, RegisterModelOptions options) throws Exception {
27-
return api.registerModel(url, null,
32+
public Response registerModel(String url, RegisterModelOptions options) throws Exception {
33+
return Response.from(api.registerModel(url, null,
2834
options.getModelName(),
2935
options.getHandler(),
3036
options.getRuntime(),
@@ -33,66 +39,67 @@ public Object registerModel(String url, RegisterModelOptions options) throws Exc
3339
options.getResponseTimeout(),
3440
options.getInitialWorkers(),
3541
options.getSynchronous(),
36-
options.getS3SseKms());
42+
options.getS3SseKms()));
3743
}
3844

3945
@Override
40-
public Object setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception {
41-
return api.setAutoScale(modelName,
46+
public Response setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception {
47+
return Response.from(api.setAutoScale(modelName,
4248
options.getMinWorker(),
4349
options.getMaxWorker(),
4450
options.getNumberGpu(),
4551
options.getSynchronous(),
46-
options.getTimeout());
52+
options.getTimeout()));
4753
}
4854

4955
@Override
50-
public Object setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception {
51-
return api.versionSetAutoScale(modelName, modelVersion,
56+
public Response setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception {
57+
return Response.from(api.versionSetAutoScale(modelName, modelVersion,
5258
options.getMinWorker(),
5359
options.getMaxWorker(),
5460
options.getNumberGpu(),
5561
options.getSynchronous(),
56-
options.getTimeout());
62+
options.getTimeout()));
5763
}
5864

5965
@Override
60-
public List<Object> describeModel(String modelName) throws Exception {
61-
return List.copyOf(api.describeModel(modelName));
66+
public List<ModelDetail> describeModel(String modelName) throws Exception {
67+
return ModelDetail.from(api.describeModel(modelName));
6268
}
6369

6470
@Override
65-
public List<Object> describeModel(String modelName, String modelVersion) throws Exception {
66-
return List.copyOf(api.versionDescribeModel(modelName, modelVersion));
71+
public List<ModelDetail> describeModel(String modelName, String modelVersion) throws Exception {
72+
return ModelDetail.from(api.versionDescribeModel(modelName, modelVersion));
6773
}
6874

6975
@Override
70-
public Object unregisterModel(String modelName, UnregisterModelOptions options) throws Exception {
71-
return api.unregisterModel(modelName,
76+
public Response unregisterModel(String modelName, UnregisterModelOptions options) throws Exception {
77+
return Response.from(api.unregisterModel(modelName,
7278
options.getSynchronous(),
73-
options.getTimeout());
79+
options.getTimeout()));
7480
}
7581

7682
@Override
77-
public Object unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception {
78-
return api.versionUnregisterModel(modelName, modelVersion,
83+
public Response unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options)
84+
throws Exception {
85+
return Response.from(api.versionUnregisterModel(modelName, modelVersion,
7986
options.getSynchronous(),
80-
options.getTimeout());
87+
options.getTimeout()));
8188
}
8289

8390
@Override
84-
public Object listModels(Integer limit, String nextPageToken) throws Exception {
85-
return api.listModels(limit, nextPageToken);
91+
public ModelList listModels(Integer limit, String nextPageToken) throws Exception {
92+
return ModelList.from(api.listModels(limit, nextPageToken));
8693
}
8794

8895
@Override
89-
public Object setDefault(String modelName, String modelVersion) throws Exception {
90-
return api.setDefault(modelName, modelVersion);
96+
public Response setDefault(String modelName, String modelVersion) throws Exception {
97+
return Response.from(api.setDefault(modelName, modelVersion));
9198
}
9299

93100
@Override
94-
public Object apiDescription() throws Exception {
95-
return api.apiDescription();
101+
public API apiDescription() throws Exception {
102+
return API.from(api.apiDescription());
96103
}
97104

98105
@Override

src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultMetrics.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ public DefaultMetrics(int port) {
1818
}
1919

2020
@Override
21-
public Object metrics() throws Exception {
21+
public String metrics() throws Exception {
2222
return metrics(null);
2323
}
2424

2525
@Override
26-
public Object metrics(String name) throws Exception {
26+
public String metrics(String name) throws Exception {
2727
return api.metrics(name);
2828
}
2929

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package com.github.tadayosi.torchserve.client.model;
2+
3+
import java.util.HashMap;
4+
import java.util.Map;
5+
6+
public class API {
7+
8+
private String openapi = null;
9+
private Map<String, String> info = new HashMap<>();
10+
private Map<String, Object> paths = new HashMap<>();
11+
12+
public API() {
13+
}
14+
15+
@SuppressWarnings("unchecked")
16+
public static API from(com.github.tadayosi.torchserve.client.inference.model.InlineResponse200 src) {
17+
API api = new API();
18+
api.setOpenapi(src.getOpenapi());
19+
api.setInfo((Map<String, String>) src.getInfo());
20+
api.setPaths((Map<String, Object>) src.getPaths());
21+
return api;
22+
}
23+
24+
@SuppressWarnings("unchecked")
25+
public static API from(com.github.tadayosi.torchserve.client.management.model.InlineResponse200 src) {
26+
API api = new API();
27+
api.setOpenapi(src.getOpenapi());
28+
api.setInfo((Map<String, String>) src.getInfo());
29+
api.setPaths((Map<String, Object>) src.getPaths());
30+
return api;
31+
}
32+
33+
public String getOpenapi() {
34+
return openapi;
35+
}
36+
37+
public void setOpenapi(String openapi) {
38+
this.openapi = openapi;
39+
}
40+
41+
public Map<String, String> getInfo() {
42+
return info;
43+
}
44+
45+
public void setInfo(Map<String, String> info) {
46+
this.info = info;
47+
}
48+
49+
public Map<String, Object> getPaths() {
50+
return paths;
51+
}
52+
53+
public void setPaths(Map<String, Object> paths) {
54+
this.paths = paths;
55+
}
56+
}

0 commit comments

Comments
 (0)