Skip to content

Commit a720399

Browse files
committed
Support configuration and token authorization
1 parent 8b455ed commit a720399

File tree

13 files changed

+256
-41
lines changed

13 files changed

+256
-41
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
[![Release](https://jitpack.io/v/tadayosi/torchserve-client-java.svg)](<https://jitpack.io/#tadayosi/torchserve-client-java>)
44
[![Test](https://github.com/tadayosi/torchserve-client-java/actions/workflows/test.yml/badge.svg)](https://github.com/tadayosi/torchserve-client-java/actions/workflows/test.yml)
55

6-
TorchServe Client for Java is a Java client library for [TorchServe](https://pytorch.org/serve/index.html).
6+
TorchServe Client for Java (TSC4J) is a Java client library for [TorchServe](https://pytorch.org/serve/index.html).
77

88
## Install
99

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package com.github.tadayosi.torchserve.client;
2+
3+
import java.io.InputStream;
4+
import java.util.Optional;
5+
import java.util.Properties;
6+
7+
import org.slf4j.Logger;
8+
import org.slf4j.LoggerFactory;
9+
10+
public class Configuration {
11+
12+
private static final Logger LOG = LoggerFactory.getLogger(Configuration.class);
13+
14+
public static final String TSC4J_PROPERTIES = "tsc4j.properties";
15+
public static final String TSC4J_PREFIX = "tsc4j.";
16+
17+
public static final String INFERENCE_KEY = "inference.key";
18+
public static final String INFERENCE_ADDRESS = "inference.address";
19+
public static final String INFERENCE_PORT = "inference.port";
20+
21+
public static final String MANAGEMENT_KEY = "management.key";
22+
public static final String MANAGEMENT_ADDRESS = "management.address";
23+
public static final String MANAGEMENT_PORT = "management.port";
24+
25+
public static final String METRICS_ADDRESS = "metrics.address";
26+
public static final String METRICS_PORT = "metrics.port";
27+
28+
private Optional<String> inferenceKey;
29+
private Optional<String> inferenceAddress;
30+
private Optional<Integer> inferencePort;
31+
32+
private Optional<String> managementKey;
33+
private Optional<String> managementAddress;
34+
private Optional<Integer> managementPort;
35+
36+
private Optional<String> metricsAddress;
37+
private Optional<Integer> metricsPort;
38+
39+
private Configuration() {
40+
Properties props = loadProperties();
41+
42+
this.inferenceKey = loadProperty(INFERENCE_KEY, props);
43+
this.inferenceAddress = loadProperty(INFERENCE_ADDRESS, props);
44+
this.inferencePort = loadProperty(INFERENCE_PORT, props).map(Integer::parseInt);
45+
this.managementKey = loadProperty(MANAGEMENT_KEY, props);
46+
this.managementAddress = loadProperty(MANAGEMENT_ADDRESS, props);
47+
this.managementPort = loadProperty(MANAGEMENT_PORT, props).map(Integer::parseInt);
48+
this.metricsAddress = loadProperty(METRICS_ADDRESS, props);
49+
this.metricsPort = loadProperty(METRICS_PORT, props).map(Integer::parseInt);
50+
}
51+
52+
static Properties loadProperties() {
53+
Properties properties = new Properties();
54+
try {
55+
InputStream is = Configuration.class.getClassLoader().getResourceAsStream(TSC4J_PROPERTIES);
56+
properties.load(is);
57+
} catch (Exception e) {
58+
// Ignore
59+
LOG.debug("Failed to load properties file: {}", e.getMessage());
60+
}
61+
return properties;
62+
}
63+
64+
/**
65+
* Order of precedence: System properties > environment variables > properties file
66+
*/
67+
static Optional<String> loadProperty(String key, Properties properties) {
68+
String tsc4jKey = TSC4J_PREFIX + key;
69+
return Optional.ofNullable(System.getProperty(tsc4jKey))
70+
.or(() -> Optional.ofNullable(System.getenv(tsc4jKey.toUpperCase().replace(".", "_"))))
71+
.or(() -> Optional.ofNullable(properties.getProperty(key)));
72+
}
73+
74+
public static Configuration load() {
75+
return new Configuration();
76+
}
77+
78+
public Optional<String> getInferenceKey() {
79+
return inferenceKey;
80+
}
81+
82+
public Optional<String> getInferenceAddress() {
83+
return inferenceAddress;
84+
}
85+
86+
public Optional<Integer> getInferencePort() {
87+
return inferencePort;
88+
}
89+
90+
public Optional<String> getManagementKey() {
91+
return managementKey;
92+
}
93+
94+
public Optional<String> getManagementAddress() {
95+
return managementAddress;
96+
}
97+
98+
public Optional<Integer> getManagementPort() {
99+
return managementPort;
100+
}
101+
102+
public Optional<String> getMetricsAddress() {
103+
return metricsAddress;
104+
}
105+
106+
public Optional<Integer> getMetricsPort() {
107+
return metricsPort;
108+
}
109+
}

src/main/java/com/github/tadayosi/torchserve/client/TorchServeClient.java

+56-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.github.tadayosi.torchserve.client;
22

3+
import java.util.Optional;
4+
35
import com.github.tadayosi.torchserve.client.impl.DefaultInference;
46
import com.github.tadayosi.torchserve.client.impl.DefaultManagement;
57
import com.github.tadayosi.torchserve.client.impl.DefaultMetrics;
@@ -42,29 +44,73 @@ public Metrics metrics() {
4244

4345
public static class Builder {
4446

45-
private Integer inferencePort;
46-
private Integer managementPort;
47-
private Integer metricsPort;
47+
private final Configuration configuration = Configuration.load();
48+
49+
private Optional<String> inferenceKey = configuration.getInferenceKey();
50+
private Optional<String> inferenceAddress = configuration.getInferenceAddress();
51+
private Optional<Integer> inferencePort = configuration.getInferencePort();
52+
53+
private Optional<String> managementKey = configuration.getManagementKey();
54+
private Optional<String> managementAddress = configuration.getManagementAddress();
55+
private Optional<Integer> managementPort = configuration.getManagementPort();
56+
57+
private Optional<String> metricsAddress = configuration.getMetricsAddress();
58+
private Optional<Integer> metricsPort = configuration.getMetricsPort();
59+
60+
public Builder inferenceKey(String key) {
61+
this.inferenceKey = Optional.of(key);
62+
return this;
63+
}
64+
65+
public Builder inferenceAddress(String address) {
66+
this.inferenceAddress = Optional.of(address);
67+
return this;
68+
}
4869

4970
public Builder inferencePort(int port) {
50-
this.inferencePort = port;
71+
this.inferencePort = Optional.of(port);
72+
return this;
73+
}
74+
75+
public Builder managementKey(String key) {
76+
this.managementKey = Optional.of(key);
77+
return this;
78+
}
79+
80+
public Builder managementAddress(String address) {
81+
this.managementAddress = Optional.of(address);
5182
return this;
5283
}
5384

5485
public Builder managementPort(int port) {
55-
this.managementPort = port;
86+
this.managementPort = Optional.of(port);
87+
return this;
88+
}
89+
90+
public Builder metricsAddress(String address) {
91+
this.metricsAddress = Optional.of(address);
5692
return this;
5793
}
5894

59-
public Builder metricsPort(Integer metricsPort) {
60-
this.metricsPort = metricsPort;
95+
public Builder metricsPort(Integer port) {
96+
this.metricsPort = Optional.of(port);
6197
return this;
6298
}
6399

64100
public TorchServeClient build() {
65-
Inference inference = inferencePort == null ? new DefaultInference() : new DefaultInference(inferencePort);
66-
Management management = managementPort == null ? new DefaultManagement() : new DefaultManagement(managementPort);
67-
Metrics metrics = metricsPort == null ? new DefaultMetrics() : new DefaultMetrics(metricsPort);
101+
DefaultInference inference = inferenceAddress.map(DefaultInference::new)
102+
.orElse(inferencePort.map(DefaultInference::new)
103+
.orElse(new DefaultInference()));
104+
inferenceKey.ifPresent(inference::setAuthToken);
105+
106+
DefaultManagement management = managementAddress.map(DefaultManagement::new)
107+
.orElse(managementPort.map(DefaultManagement::new)
108+
.orElse(new DefaultManagement()));
109+
managementKey.ifPresent(management::setAuthToken);
110+
111+
DefaultMetrics metrics = metricsAddress.map(DefaultMetrics::new)
112+
.orElse(metricsPort.map(DefaultMetrics::new)
113+
.orElse(new DefaultMetrics()));
68114
return new TorchServeClient(inference, management, metrics);
69115
}
70116
}

src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultInference.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,18 @@ public DefaultInference() {
1616
}
1717

1818
public DefaultInference(int port) {
19-
ApiClient client = new ApiClient().setBasePath("http://localhost:" + port);
19+
this("http://localhost:" + port);
20+
}
21+
22+
public DefaultInference(String address) {
23+
ApiClient client = new ApiClient().setBasePath(address);
2024
this.api = new DefaultApi(client);
2125
}
2226

27+
public void setAuthToken(String token) {
28+
api.getApiClient().addDefaultHeader("Authorization", "Bearer " + token);
29+
}
30+
2331
@Override
2432
public Api apiDescription() throws ApiException {
2533
try {

src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultManagement.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,18 @@ public DefaultManagement() {
2323
}
2424

2525
public DefaultManagement(int port) {
26-
ApiClient client = new ApiClient().setBasePath("http://localhost:" + port);
26+
this("http://localhost:" + port);
27+
}
28+
29+
public DefaultManagement(String address) {
30+
ApiClient client = new ApiClient().setBasePath(address);
2731
this.api = new DefaultApi(client);
2832
}
2933

34+
public void setAuthToken(String token) {
35+
api.getApiClient().addDefaultHeader("Authorization", "Bearer " + token);
36+
}
37+
3038
@Override
3139
public Response registerModel(String url, RegisterModelOptions options) throws ApiException {
3240
try {
@@ -145,5 +153,4 @@ public Api apiDescription() throws ApiException {
145153
public Object token(String type) throws ApiException {
146154
throw new UnsupportedOperationException("Not supported yet");
147155
}
148-
149156
}

src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultMetrics.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ public DefaultMetrics() {
1414
}
1515

1616
public DefaultMetrics(int port) {
17-
ApiClient client = new ApiClient().setBasePath("http://localhost:" + port);
17+
this("http://localhost:" + port);
18+
}
19+
20+
public DefaultMetrics(String address) {
21+
ApiClient client = new ApiClient().setBasePath(address);
1822
this.api = new DefaultApi(client);
1923
}
2024

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.github.tadayosi.torchserve.client;
2+
3+
import com.github.tadayosi.torchserve.client.model.ApiException;
4+
import org.junit.jupiter.api.Test;
5+
6+
import static org.junit.jupiter.api.Assertions.assertEquals;
7+
import static org.junit.jupiter.api.Assertions.assertNotNull;
8+
9+
class ConfigurationTest {
10+
11+
@Test
12+
void testLoad() {
13+
var config = Configuration.load();
14+
assertNotNull(config);
15+
}
16+
17+
@Test
18+
void testSystemProperties() {
19+
System.setProperty("tsc4j.inference.key", "aaaaa");
20+
System.setProperty("tsc4j.inference.address", "https://test.com:8180");
21+
System.setProperty("tsc4j.inference.port", "8180");
22+
System.setProperty("tsc4j.management.key", "bbbbb");
23+
System.setProperty("tsc4j.management.address", "https://test.com:8181");
24+
System.setProperty("tsc4j.management.port", "8181");
25+
System.setProperty("tsc4j.metrics.address", "https://test.com:8182");
26+
System.setProperty("tsc4j.metrics.port", "8182");
27+
28+
var config = Configuration.load();
29+
30+
assertEquals("aaaaa", config.getInferenceKey().get());
31+
assertEquals("https://test.com:8180", config.getInferenceAddress().get());
32+
assertEquals(8180, config.getInferencePort().get());
33+
assertEquals("bbbbb", config.getManagementKey().get());
34+
assertEquals("https://test.com:8181", config.getManagementAddress().get());
35+
assertEquals(8181, config.getManagementPort().get());
36+
assertEquals("https://test.com:8182", config.getMetricsAddress().get());
37+
assertEquals(8182, config.getMetricsPort().get());
38+
}
39+
}

src/test/java/com/github/tadayosi/torchserve/client/InferenceTest.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,40 @@
1313
import static org.junit.jupiter.api.Assertions.assertThrows;
1414

1515
@Testcontainers
16-
public class InferenceTest extends TorchServeTestSupport {
16+
class InferenceTest extends TorchServeTestSupport {
1717

1818
private static final String DEFAULT_MODEL = "squeezenet1_1";
1919
private static final String DEFAULT_MODEL_VERSION = "1.0";
2020
private static final String TEST_DATA = "src/test/resources/data/kitten.jpg";
2121

2222
@Test
23-
public void testApiDescription() throws Exception {
23+
void testApiDescription() throws Exception {
2424
var response = client.inference().apiDescription();
2525
assertNotNull(response);
2626
}
2727

2828
@Test
29-
public void testPing() throws Exception {
29+
void testPing() throws Exception {
3030
var response = client.inference().ping();
3131
assertEquals("Healthy", response.getStatus());
3232
}
3333

3434
@Test
35-
public void testPredictions() throws Exception {
35+
void testPredictions() throws Exception {
3636
var body = Files.readAllBytes(Path.of(TEST_DATA));
3737
var response = client.inference().predictions(DEFAULT_MODEL, body);
3838
assertInstanceOf(Map.class, response);
3939
}
4040

4141
@Test
42-
public void testPredictions_version() throws Exception {
42+
void testPredictions_version() throws Exception {
4343
var body = Files.readAllBytes(Path.of(TEST_DATA));
4444
var response = client.inference().predictions(DEFAULT_MODEL, DEFAULT_MODEL_VERSION, body);
4545
assertInstanceOf(Map.class, response);
4646
}
4747

4848
@Test
49-
public void testExplanations() {
49+
void testExplanations() {
5050
assertThrows(UnsupportedOperationException.class, () -> client.inference().explanations(DEFAULT_MODEL));
5151
}
5252

0 commit comments

Comments
 (0)