Skip to content

Commit bc5d219

Browse files
[8.19] [ML] Adding timeout to request for creating inference endpoint (#126805) (#127779)
* [ML] Fixing bug with TransportPutModelAction listener and adding timeout to request (#126805)
* Fixing bug with listener and adding timeout
* Update docs/changelog/126805.yaml
* Fixing tests
* Fixing writeTo

(cherry picked from commit 4c507e2)

# Conflicts:
# server/src/main/java/org/elasticsearch/TransportVersions.java

* Fixing test issue
* Fixing comment length
1 parent 9633d6c commit bc5d219

File tree

7 files changed

+90
-33
lines changed

7 files changed

+90
-33
lines changed

docs/changelog/126805.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126805
2+
summary: Adding timeout to request for creating inference endpoint
3+
area: Machine Learning
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ static TransportVersion def(int id) {
216216
public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
217217
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
218218
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
219-
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_28);
219+
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
220+
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
220221

221222
/*
222223
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java

+22-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.action.ActionRequestValidationException;
1112
import org.elasticsearch.action.ActionResponse;
1213
import org.elasticsearch.action.ActionType;
@@ -15,6 +16,7 @@
1516
import org.elasticsearch.common.io.stream.StreamInput;
1617
import org.elasticsearch.common.io.stream.StreamOutput;
1718
import org.elasticsearch.common.xcontent.XContentHelper;
19+
import org.elasticsearch.core.TimeValue;
1820
import org.elasticsearch.inference.ModelConfigurations;
1921
import org.elasticsearch.inference.TaskType;
2022
import org.elasticsearch.xcontent.ToXContentObject;
@@ -41,13 +43,15 @@ public static class Request extends AcknowledgedRequest<Request> {
4143
private final String inferenceEntityId;
4244
private final BytesReference content;
4345
private final XContentType contentType;
46+
private final TimeValue timeout;
4447

45-
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
48+
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
4649
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
4750
this.taskType = taskType;
4851
this.inferenceEntityId = inferenceEntityId;
4952
this.content = content;
5053
this.contentType = contentType;
54+
this.timeout = timeout;
5155
}
5256

5357
public Request(StreamInput in) throws IOException {
@@ -56,6 +60,12 @@ public Request(StreamInput in) throws IOException {
5660
this.taskType = TaskType.fromStream(in);
5761
this.content = in.readBytesReference();
5862
this.contentType = in.readEnum(XContentType.class);
63+
64+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
65+
this.timeout = in.readTimeValue();
66+
} else {
67+
this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
68+
}
5969
}
6070

6171
public TaskType getTaskType() {
@@ -74,13 +84,21 @@ public XContentType getContentType() {
7484
return contentType;
7585
}
7686

87+
public TimeValue getTimeout() {
88+
return timeout;
89+
}
90+
7791
@Override
7892
public void writeTo(StreamOutput out) throws IOException {
7993
super.writeTo(out);
8094
out.writeString(inferenceEntityId);
8195
taskType.writeTo(out);
8296
out.writeBytesReference(content);
8397
XContentHelper.writeTo(out, contentType);
98+
99+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
100+
out.writeTimeValue(timeout);
101+
}
84102
}
85103

86104
@Override
@@ -105,12 +123,13 @@ public boolean equals(Object o) {
105123
return taskType == request.taskType
106124
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
107125
&& Objects.equals(content, request.content)
108-
&& contentType == request.contentType;
126+
&& contentType == request.contentType
127+
&& Objects.equals(timeout, request.timeout);
109128
}
110129

111130
@Override
112131
public int hashCode() {
113-
return Objects.hash(taskType, inferenceEntityId, content, contentType);
132+
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
114133
}
115134
}
116135

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java

+23-4
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,45 @@ public void setup() throws Exception {
3434

3535
public void testValidate() {
3636
// valid model ID
37-
var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
37+
var request = new PutInferenceModelAction.Request(
38+
TASK_TYPE,
39+
MODEL_ID + "_-0",
40+
BYTES,
41+
X_CONTENT_TYPE,
42+
InferenceAction.Request.DEFAULT_TIMEOUT
43+
);
3844
ActionRequestValidationException validationException = request.validate();
3945
assertNull(validationException);
4046

4147
// invalid model IDs
4248

43-
var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
49+
var invalidRequest = new PutInferenceModelAction.Request(
50+
TASK_TYPE,
51+
"",
52+
BYTES,
53+
X_CONTENT_TYPE,
54+
InferenceAction.Request.DEFAULT_TIMEOUT
55+
);
4456
validationException = invalidRequest.validate();
4557
assertNotNull(validationException);
4658

4759
var invalidRequest2 = new PutInferenceModelAction.Request(
4860
TASK_TYPE,
4961
randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
5062
BYTES,
51-
X_CONTENT_TYPE
63+
X_CONTENT_TYPE,
64+
InferenceAction.Request.DEFAULT_TIMEOUT
5265
);
5366
validationException = invalidRequest2.validate();
5467
assertNotNull(validationException);
5568

56-
var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
69+
var invalidRequest3 = new PutInferenceModelAction.Request(
70+
TASK_TYPE,
71+
null,
72+
BYTES,
73+
X_CONTENT_TYPE,
74+
InferenceAction.Request.DEFAULT_TIMEOUT
75+
);
5776
validationException = invalidRequest3.validate();
5877
assertNotNull(validationException);
5978
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ protected void masterOperation(
181181
return;
182182
}
183183

184-
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
184+
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
185185
}
186186

187187
private void parseAndStoreModel(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import static org.elasticsearch.rest.RestRequest.Method.PUT;
23+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
2324
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
2425
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
2526
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
@@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4950
taskType = TaskType.ANY; // task type must be defined in the body
5051
}
5152

53+
var inferTimeout = parseTimeout(restRequest);
5254
var content = restRequest.requiredContent();
53-
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
55+
var request = new PutInferenceModelAction.Request(
56+
taskType,
57+
inferenceEntityId,
58+
content,
59+
restRequest.getXContentType(),
60+
inferTimeout
61+
);
5462
return channel -> client.execute(
5563
PutInferenceModelAction.INSTANCE,
5664
request,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java

+28-23
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77

88
package org.elasticsearch.xpack.inference.action;
99

10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1012
import org.elasticsearch.common.io.stream.Writeable;
1113
import org.elasticsearch.inference.TaskType;
12-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1314
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1416
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
17+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1518

16-
public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
19+
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
1720
@Override
1821
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
1922
return PutInferenceModelAction.Request::new;
@@ -25,38 +28,40 @@ protected PutInferenceModelAction.Request createTestInstance() {
2528
randomFrom(TaskType.values()),
2629
randomAlphaOfLength(6),
2730
randomBytesReference(50),
28-
randomFrom(XContentType.values())
31+
randomFrom(XContentType.values()),
32+
randomTimeValue()
2933
);
3034
}
3135

3236
@Override
3337
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
34-
return switch (randomIntBetween(0, 3)) {
35-
case 0 -> new PutInferenceModelAction.Request(
36-
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
37-
instance.getInferenceEntityId(),
38-
instance.getContent(),
39-
instance.getContentType()
40-
);
41-
case 1 -> new PutInferenceModelAction.Request(
42-
instance.getTaskType(),
43-
instance.getInferenceEntityId() + "foo",
44-
instance.getContent(),
45-
instance.getContentType()
46-
);
47-
case 2 -> new PutInferenceModelAction.Request(
38+
return randomValueOtherThan(instance, this::createTestInstance);
39+
}
40+
41+
@Override
42+
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
43+
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
44+
return instance;
45+
} else if (version.onOrAfter(TransportVersions.V_8_0_0)) {
46+
return new PutInferenceModelAction.Request(
4847
instance.getTaskType(),
4948
instance.getInferenceEntityId(),
50-
randomBytesReference(instance.getContent().length() + 1),
51-
instance.getContentType()
49+
instance.getContent(),
50+
instance.getContentType(),
51+
InferenceAction.Request.DEFAULT_TIMEOUT
5252
);
53-
case 3 -> new PutInferenceModelAction.Request(
53+
} else {
54+
return new PutInferenceModelAction.Request(
5455
instance.getTaskType(),
5556
instance.getInferenceEntityId(),
5657
instance.getContent(),
57-
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
58+
/*
59+
* See XContentHelper.java#L733
60+
* for versions prior to 8.0.0, the content type does not have the VND_ instances
61+
*/
62+
XContentType.ofOrdinal(instance.getContentType().canonical().ordinal()),
63+
InferenceAction.Request.DEFAULT_TIMEOUT
5864
);
59-
default -> throw new IllegalStateException();
60-
};
65+
}
6166
}
6267
}

0 commit comments

Comments
 (0)