Skip to content

Commit c1b1bb6

Browse files
committed
Create top-level interface TorchServeClient
1 parent 4cc0e28 commit c1b1bb6

File tree

5 files changed

+104
-48
lines changed

5 files changed

+104
-48
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package com.github.tadayosi.torchserve.client;
2+
3+
import com.github.tadayosi.torchserve.client.impl.DefaultInference;
4+
import com.github.tadayosi.torchserve.client.impl.DefaultManagement;
5+
import com.github.tadayosi.torchserve.client.impl.DefaultMetrics;
6+
7+
public class TorchServeClient {
8+
9+
private final Inference inference;
10+
private final Management management;
11+
private final Metrics metrics;
12+
13+
private TorchServeClient() {
14+
this(new DefaultInference(), new DefaultManagement(), new DefaultMetrics());
15+
}
16+
17+
private TorchServeClient(Inference inference, Management management, Metrics metrics) {
18+
this.inference = inference;
19+
this.management = management;
20+
this.metrics = metrics;
21+
}
22+
23+
public static TorchServeClient newInstance() {
24+
return new TorchServeClient();
25+
}
26+
27+
public static Builder builder() {
28+
return new Builder();
29+
}
30+
31+
public Inference inference() {
32+
return inference;
33+
}
34+
35+
public Management management() {
36+
return management;
37+
}
38+
39+
public Metrics metrics() {
40+
return metrics;
41+
}
42+
43+
public static class Builder {
44+
45+
private Integer inferencePort;
46+
private Integer managementPort;
47+
private Integer metricsPort;
48+
49+
public Builder inferencePort(int port) {
50+
this.inferencePort = port;
51+
return this;
52+
}
53+
54+
public Builder managementPort(int port) {
55+
this.managementPort = port;
56+
return this;
57+
}
58+
59+
public Builder metricsPort(Integer metricsPort) {
60+
this.metricsPort = metricsPort;
61+
return this;
62+
}
63+
64+
public TorchServeClient build() {
65+
Inference inference = inferencePort == null ? new DefaultInference() : new DefaultInference(inferencePort);
66+
Management management = managementPort == null ? new DefaultManagement() : new DefaultManagement(managementPort);
67+
Metrics metrics = metricsPort == null ? new DefaultMetrics() : new DefaultMetrics(metricsPort);
68+
return new TorchServeClient(inference, management, metrics);
69+
}
70+
}
71+
}

src/test/java/com/github/tadayosi/torchserve/client/InferenceTest.java

+5-14
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import java.nio.file.Path;
55
import java.util.Map;
66

7-
import com.github.tadayosi.torchserve.client.impl.DefaultInference;
8-
import org.junit.jupiter.api.BeforeEach;
97
import org.junit.jupiter.api.Test;
108
import org.testcontainers.junit.jupiter.Testcontainers;
119

@@ -21,42 +19,35 @@ public class InferenceTest extends TorchServeTestSupport {
2119
private static final String DEFAULT_MODEL_VERSION = "1.0";
2220
private static final String TEST_DATA = "src/test/resources/data/kitten.jpg";
2321

24-
private Inference inference;
25-
26-
@BeforeEach
27-
public void setUp() {
28-
inference = new DefaultInference(torchServe.getMappedPort(8080));
29-
}
30-
3122
@Test
3223
public void testApiDescription() throws Exception {
33-
var response = inference.apiDescription();
24+
var response = client.inference().apiDescription();
3425
assertNotNull(response);
3526
}
3627

3728
@Test
3829
public void testPing() throws Exception {
39-
var response = inference.ping();
30+
var response = client.inference().ping();
4031
assertEquals("Healthy", response.getStatus());
4132
}
4233

4334
@Test
4435
public void testPredictions() throws Exception {
4536
var body = Files.readAllBytes(Path.of(TEST_DATA));
46-
var response = inference.predictions(DEFAULT_MODEL, body);
37+
var response = client.inference().predictions(DEFAULT_MODEL, body);
4738
assertInstanceOf(Map.class, response);
4839
}
4940

5041
@Test
5142
public void testPredictions_version() throws Exception {
5243
var body = Files.readAllBytes(Path.of(TEST_DATA));
53-
var response = inference.predictions(DEFAULT_MODEL, DEFAULT_MODEL_VERSION, body);
44+
var response = client.inference().predictions(DEFAULT_MODEL, DEFAULT_MODEL_VERSION, body);
5445
assertInstanceOf(Map.class, response);
5546
}
5647

5748
@Test
5849
public void testExplanations() {
59-
assertThrows(UnsupportedOperationException.class, () -> inference.explanations(DEFAULT_MODEL));
50+
assertThrows(UnsupportedOperationException.class, () -> client.inference().explanations(DEFAULT_MODEL));
6051
}
6152

6253
}

src/test/java/com/github/tadayosi/torchserve/client/ManagementTest.java

+14-23
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import java.nio.file.Path;
55

66
import com.github.tadayosi.torchserve.client.impl.DefaultInference;
7-
import com.github.tadayosi.torchserve.client.impl.DefaultManagement;
87
import com.github.tadayosi.torchserve.client.model.ApiException;
98
import com.github.tadayosi.torchserve.client.model.RegisterModelOptions;
109
import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions;
@@ -31,18 +30,11 @@ public class ManagementTest extends TorchServeTestSupport {
3130
private static final String ADDED_MODEL_VERSION = "2.0";
3231
private static final String TEST_DATA_DIR = "src/test/resources/data";
3332

34-
private Management management;
35-
36-
@BeforeEach
37-
public void setUp() throws Exception {
38-
management = new DefaultManagement(torchServe.getMappedPort(8081));
39-
}
40-
4133
@Test
4234
public void testRegisterModel() throws Exception {
4335
var url = "https://torchserve.pytorch.org/mar_files/mnist_v2.mar";
4436
try {
45-
var response = management.registerModel(url, RegisterModelOptions.empty());
37+
var response = client.management().registerModel(url, RegisterModelOptions.empty());
4638
assertTrue(response.getStatus().contains("registered"));
4739
} catch (ApiException e) {
4840
e.printStackTrace();
@@ -57,21 +49,21 @@ class AfterRegisteringModel {
5749
public void registerModel() throws Exception {
5850
var url = "https://torchserve.pytorch.org/mar_files/mnist_v2.mar";
5951
try {
60-
management.registerModel(url, RegisterModelOptions.empty());
52+
client.management().registerModel(url, RegisterModelOptions.empty());
6153
} catch (ApiException e) {
6254
// Ignore if the model is already registered
6355
}
6456
}
6557

6658
@Test
6759
public void testUnregisterModel() throws Exception {
68-
var response = management.unregisterModel(ADDED_MODEL, UnregisterModelOptions.empty());
60+
var response = client.management().unregisterModel(ADDED_MODEL, UnregisterModelOptions.empty());
6961
assertTrue(response.getStatus().contains("unregistered"));
7062
}
7163

7264
@Test
7365
public void testUnregisterModel_version() throws Exception {
74-
var response = management.unregisterModel(ADDED_MODEL, ADDED_MODEL_VERSION, UnregisterModelOptions.empty());
66+
var response = client.management().unregisterModel(ADDED_MODEL, ADDED_MODEL_VERSION, UnregisterModelOptions.empty());
7567
assertTrue(response.getStatus().contains("unregistered"));
7668
}
7769

@@ -80,12 +72,12 @@ class BeforeUnregisteringModel {
8072

8173
@AfterEach
8274
public void unregisterModel() throws Exception {
83-
management.unregisterModel(ADDED_MODEL, UnregisterModelOptions.empty());
75+
client.management().unregisterModel(ADDED_MODEL, UnregisterModelOptions.empty());
8476
}
8577

8678
@Test
8779
public void testSetAutoScale() throws Exception {
88-
var response1 = management.setAutoScale(ADDED_MODEL,
80+
var response1 = client.management().setAutoScale(ADDED_MODEL,
8981
SetAutoScaleOptions.builder()
9082
.minWorker(1)
9183
.build());
@@ -101,16 +93,15 @@ public void testSetAutoScale() throws Exception {
10193

10294
@Test
10395
public void testSetAutoScale_version() throws Exception {
104-
var response1 = management.setAutoScale(ADDED_MODEL, ADDED_MODEL_VERSION,
96+
var response1 = client.management().setAutoScale(ADDED_MODEL, ADDED_MODEL_VERSION,
10597
SetAutoScaleOptions.builder()
10698
.minWorker(1)
10799
.build());
108100
assertTrue(response1.getStatus().contains("Processing worker updates"));
109101

110102
// Testing inference with MNIST V2
111-
var inference = new DefaultInference(torchServe.getMappedPort(8080));
112103
var body = Files.readAllBytes(Path.of(TEST_DATA_DIR, "1.png"));
113-
var response2 = inference.predictions(ADDED_MODEL, body);
104+
var response2 = client.inference().predictions(ADDED_MODEL, body);
114105
assertInstanceOf(Double.class, response2);
115106
assertEquals(1.0, (Double) response2, 0.001);
116107
}
@@ -119,14 +110,14 @@ public void testSetAutoScale_version() throws Exception {
119110

120111
@Test
121112
public void testDescribeModel() throws Exception {
122-
var response = management.describeModel(DEFAULT_MODEL);
113+
var response = client.management().describeModel(DEFAULT_MODEL);
123114
assertEquals(1, response.size());
124115
assertEquals("squeezenet1_1", response.get(0).getModelName());
125116
}
126117

127118
@Test
128119
public void testDescribeModel_version() throws Exception {
129-
var response = management.describeModel(DEFAULT_MODEL, DEFAULT_MODEL_VERSION);
120+
var response = client.management().describeModel(DEFAULT_MODEL, DEFAULT_MODEL_VERSION);
130121
assertEquals(1, response.size());
131122
assertEquals("squeezenet1_1", response.get(0).getModelName());
132123
assertEquals("1.0", response.get(0).getModelVersion());
@@ -136,27 +127,27 @@ public void testDescribeModel_version() throws Exception {
136127
public void testListModels() throws Exception {
137128
int limit = 10;
138129
String nextPageToken = null;
139-
var response = management.listModels(limit, nextPageToken);
130+
var response = client.management().listModels(limit, nextPageToken);
140131
var models = response.getModels();
141132
assertFalse(models.isEmpty());
142133
assertEquals(DEFAULT_MODEL, models.get(0).getModelName());
143134
}
144135

145136
@Test
146137
public void testSetDefault() throws Exception {
147-
var response = management.setDefault(DEFAULT_MODEL, DEFAULT_MODEL_VERSION);
138+
var response = client.management().setDefault(DEFAULT_MODEL, DEFAULT_MODEL_VERSION);
148139
assertTrue(response.getStatus().contains("Default vesion succsesfully updated"));
149140
}
150141

151142
@Test
152143
public void testApiDescription() throws Exception {
153-
var response = management.apiDescription();
144+
var response = client.management().apiDescription();
154145
assertEquals("TorchServe APIs", response.getInfo().get("title"));
155146
}
156147

157148
@Test
158149
public void testToken() throws Exception {
159-
assertThrows(UnsupportedOperationException.class, () -> management.token("management"));
150+
assertThrows(UnsupportedOperationException.class, () -> client.management().token("management"));
160151
}
161152

162153
}
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package com.github.tadayosi.torchserve.client;
22

3-
import com.github.tadayosi.torchserve.client.impl.DefaultMetrics;
4-
import org.junit.jupiter.api.BeforeEach;
53
import org.junit.jupiter.api.Test;
64
import org.testcontainers.junit.jupiter.Testcontainers;
75

@@ -10,16 +8,9 @@
108
@Testcontainers
119
public class MetricsTest extends TorchServeTestSupport {
1210

13-
private Metrics metrics;
14-
15-
@BeforeEach
16-
public void setUp() {
17-
metrics = new DefaultMetrics(torchServe.getMappedPort(8082));
18-
}
19-
2011
@Test
2112
public void testMetrics() throws Exception {
22-
var response = metrics.metrics();
13+
var response = client.metrics().metrics();
2314
assertNotNull(response);
2415
}
2516
}

src/test/java/com/github/tadayosi/torchserve/client/TorchServeTestSupport.java

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.github.tadayosi.torchserve.client;
22

3+
import org.junit.jupiter.api.BeforeEach;
34
import org.testcontainers.containers.GenericContainer;
45
import org.testcontainers.containers.wait.strategy.Wait;
56
import org.testcontainers.junit.jupiter.Container;
@@ -15,6 +16,17 @@ public class TorchServeTestSupport {
1516
.withExposedPorts(8080, 8081, 8082)
1617
.withCopyFileToContainer(MountableFile.forClasspathResource("config.properties"), "/home/model-server/config.properties")
1718
.withCopyFileToContainer(MountableFile.forClasspathResource("models/squeezenet1_1.mar"), "/home/model-server/model-store/squeezenet1_1.mar")
18-
.waitingFor(Wait.forListeningPorts(8080))
19+
.waitingFor(Wait.forListeningPorts(8080, 8081, 8082))
1920
.withCommand("torchserve --ncs --disable-token-auth --enable-model-api --model-store /home/model-server/model-store --models squeezenet1_1.mar");
21+
22+
protected TorchServeClient client;
23+
24+
@BeforeEach
25+
public void setUp() {
26+
client = TorchServeClient.builder()
27+
.inferencePort(torchServe.getMappedPort(8080))
28+
.managementPort(torchServe.getMappedPort(8081))
29+
.metricsPort(torchServe.getMappedPort(8082))
30+
.build();
31+
}
2032
}

0 commit comments

Comments
 (0)