Skip to content

Commit cc2bf06

Browse files
authored
[Inference API] Propagate usage context to Elastic Inference Service (#120698) (#121023)
1 parent 9d017d6 commit cc2bf06

10 files changed

+183
-28
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.action.elastic;
99

10+
import org.elasticsearch.inference.InputType;
1011
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1112
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1213
import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceSparseEmbeddingsRequestManager;
@@ -29,15 +30,23 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
2930

3031
private final TraceContext traceContext;
3132

32-
public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
33+
private final InputType inputType;
34+
35+
public ElasticInferenceServiceActionCreator(
36+
Sender sender,
37+
ServiceComponents serviceComponents,
38+
TraceContext traceContext,
39+
InputType inputType
40+
) {
3341
this.sender = Objects.requireNonNull(sender);
3442
this.serviceComponents = Objects.requireNonNull(serviceComponents);
3543
this.traceContext = traceContext;
44+
this.inputType = inputType;
3645
}
3746

3847
@Override
3948
public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
40-
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext);
49+
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext, inputType);
4150
var errorMessage = constructFailedToSendRequestMessage(
4251
model.uri(),
4352
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.inference.InputType;
1415
import org.elasticsearch.xpack.inference.common.Truncator;
1516
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
1617
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
@@ -40,6 +41,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast
4041

4142
private final TraceContext traceContext;
4243

44+
private final InputType inputType;
45+
4346
private static ResponseHandler createSparseEmbeddingsHandler() {
4447
return new ElasticInferenceServiceResponseHandler(
4548
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
@@ -50,12 +53,14 @@ private static ResponseHandler createSparseEmbeddingsHandler() {
5053
public ElasticInferenceServiceSparseEmbeddingsRequestManager(
5154
ElasticInferenceServiceSparseEmbeddingsModel model,
5255
ServiceComponents serviceComponents,
53-
TraceContext traceContext
56+
TraceContext traceContext,
57+
InputType inputType
5458
) {
5559
super(serviceComponents.threadPool(), model);
5660
this.model = model;
5761
this.truncator = serviceComponents.truncator();
5862
this.traceContext = traceContext;
63+
this.inputType = inputType;
5964
}
6065

6166
@Override
@@ -72,7 +77,8 @@ public void execute(
7277
truncator,
7378
truncatedInput,
7479
model,
75-
traceContext
80+
traceContext,
81+
inputType
7682
);
7783
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
7884
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

+31-3
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
import org.apache.http.entity.ByteArrayEntity;
1313
import org.apache.http.message.BasicHeader;
1414
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.inference.InputType;
1516
import org.elasticsearch.xcontent.XContentType;
1617
import org.elasticsearch.xpack.inference.common.Truncator;
1718
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1819
import org.elasticsearch.xpack.inference.external.request.Request;
1920
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
21+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
2022
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2123
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
2224

@@ -31,24 +33,30 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticIn
3133
private final Truncator.TruncationResult truncationResult;
3234
private final Truncator truncator;
3335
private final TraceContextHandler traceContextHandler;
36+
private final InputType inputType;
3437

3538
public ElasticInferenceServiceSparseEmbeddingsRequest(
3639
Truncator truncator,
3740
Truncator.TruncationResult truncationResult,
3841
ElasticInferenceServiceSparseEmbeddingsModel model,
39-
TraceContext traceContext
42+
TraceContext traceContext,
43+
InputType inputType
4044
) {
4145
this.truncator = truncator;
4246
this.truncationResult = truncationResult;
4347
this.model = Objects.requireNonNull(model);
4448
this.uri = model.uri();
4549
this.traceContextHandler = new TraceContextHandler(traceContext);
50+
this.inputType = inputType;
4651
}
4752

4853
@Override
4954
public HttpRequest createHttpRequest() {
5055
var httpPost = new HttpPost(uri);
51-
var requestEntity = Strings.toString(new ElasticInferenceServiceSparseEmbeddingsRequestEntity(truncationResult.input()));
56+
var usageContext = inputTypeToUsageContext(inputType);
57+
var requestEntity = Strings.toString(
58+
new ElasticInferenceServiceSparseEmbeddingsRequestEntity(truncationResult.input(), usageContext)
59+
);
5260

5361
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
5462
httpPost.setEntity(byteEntity);
@@ -76,12 +84,32 @@ public URI getURI() {
7684
@Override
7785
public Request truncate() {
7886
var truncatedInput = truncator.truncate(truncationResult.input());
79-
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContextHandler.traceContext());
87+
return new ElasticInferenceServiceSparseEmbeddingsRequest(
88+
truncator,
89+
truncatedInput,
90+
model,
91+
traceContextHandler.traceContext(),
92+
inputType
93+
);
8094
}
8195

8296
@Override
8397
public boolean[] getTruncationInfo() {
8498
return truncationResult.truncated().clone();
8599
}
86100

101+
// visible for testing
102+
static ElasticInferenceServiceUsageContext inputTypeToUsageContext(InputType inputType) {
103+
switch (inputType) {
104+
case SEARCH -> {
105+
return ElasticInferenceServiceUsageContext.SEARCH;
106+
}
107+
case INGEST -> {
108+
return ElasticInferenceServiceUsageContext.INGEST;
109+
}
110+
default -> {
111+
return ElasticInferenceServiceUsageContext.UNSPECIFIED;
112+
}
113+
}
114+
}
87115
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntity.java

+14-1
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,22 @@
77

88
package org.elasticsearch.xpack.inference.external.request.elastic;
99

10+
import org.elasticsearch.core.Nullable;
1011
import org.elasticsearch.xcontent.ToXContentObject;
1112
import org.elasticsearch.xcontent.XContentBuilder;
13+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
1214

1315
import java.io.IOException;
1416
import java.util.List;
1517
import java.util.Objects;
1618

17-
public record ElasticInferenceServiceSparseEmbeddingsRequestEntity(List<String> inputs) implements ToXContentObject {
19+
public record ElasticInferenceServiceSparseEmbeddingsRequestEntity(
20+
List<String> inputs,
21+
@Nullable ElasticInferenceServiceUsageContext usageContext
22+
) implements ToXContentObject {
1823

1924
private static final String INPUT_FIELD = "input";
25+
private static final String USAGE_CONTEXT = "usage_context";
2026

2127
public ElasticInferenceServiceSparseEmbeddingsRequestEntity {
2228
Objects.requireNonNull(inputs);
@@ -34,8 +40,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
3440
}
3541

3642
builder.endArray();
43+
44+
// optional field
45+
if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) {
46+
builder.field(USAGE_CONTEXT, usageContext);
47+
}
48+
3749
builder.endObject();
3850

3951
return builder;
4052
}
53+
4154
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ protected void doInfer(
329329
var currentTraceInfo = getCurrentTraceInfo();
330330

331331
ElasticInferenceServiceExecutableActionModel elasticInferenceServiceModel = (ElasticInferenceServiceExecutableActionModel) model;
332-
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), currentTraceInfo);
332+
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), currentTraceInfo, inputType);
333333

334334
var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings);
335335
action.execute(inputs, timeout, listener);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic;
9+
10+
import java.util.Locale;
11+
12+
/**
13+
* Specifies the usage context for a request to the Elastic Inference Service.
14+
* This helps to determine the type of resources that are allocated in the Elastic Inference Service for the particular request.
15+
*/
16+
public enum ElasticInferenceServiceUsageContext {
17+
18+
SEARCH,
19+
INGEST,
20+
UNSPECIFIED;
21+
22+
@Override
23+
public String toString() {
24+
return name().toLowerCase(Locale.ROOT);
25+
}
26+
27+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java

+25-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.common.settings.Settings;
1414
import org.elasticsearch.core.TimeValue;
1515
import org.elasticsearch.inference.InferenceServiceResults;
16+
import org.elasticsearch.inference.InputType;
1617
import org.elasticsearch.test.ESTestCase;
1718
import org.elasticsearch.test.http.MockResponse;
1819
import org.elasticsearch.test.http.MockWebServer;
@@ -90,7 +91,12 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
9091
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9192

9293
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
93-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
94+
var actionCreator = new ElasticInferenceServiceActionCreator(
95+
sender,
96+
createWithEmptySettings(threadPool),
97+
createTraceContext(),
98+
InputType.UNSPECIFIED
99+
);
94100
var action = actionCreator.create(model);
95101

96102
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -146,7 +152,12 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
146152
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
147153

148154
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
149-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
155+
var actionCreator = new ElasticInferenceServiceActionCreator(
156+
sender,
157+
createWithEmptySettings(threadPool),
158+
createTraceContext(),
159+
InputType.UNSPECIFIED
160+
);
150161
var action = actionCreator.create(model);
151162

152163
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -198,7 +209,12 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
198209
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199210

200211
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
201-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
212+
var actionCreator = new ElasticInferenceServiceActionCreator(
213+
sender,
214+
createWithEmptySettings(threadPool),
215+
createTraceContext(),
216+
InputType.UNSPECIFIED
217+
);
202218
var action = actionCreator.create(model);
203219

204220
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -258,7 +274,12 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
258274

259275
// truncated to 1 token = 3 characters
260276
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), 1);
261-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
277+
var actionCreator = new ElasticInferenceServiceActionCreator(
278+
sender,
279+
createWithEmptySettings(threadPool),
280+
createTraceContext(),
281+
InputType.UNSPECIFIED
282+
);
262283
var action = actionCreator.create(model);
263284

264285
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntityTests.java

+33-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.xcontent.XContentBuilder;
1313
import org.elasticsearch.xcontent.XContentFactory;
1414
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
1516

1617
import java.io.IOException;
1718
import java.util.List;
@@ -20,17 +21,23 @@
2021

2122
public class ElasticInferenceServiceSparseEmbeddingsRequestEntityTests extends ESTestCase {
2223

23-
public void testToXContent_SingleInput() throws IOException {
24-
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc"));
24+
public void testToXContent_SingleInput_UnspecifiedUsageContext() throws IOException {
25+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
26+
List.of("abc"),
27+
ElasticInferenceServiceUsageContext.UNSPECIFIED
28+
);
2529
String xContentString = xContentEntityToString(entity);
2630
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
2731
{
2832
"input": ["abc"]
2933
}"""));
3034
}
3135

32-
public void testToXContent_MultipleInputs() throws IOException {
33-
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc", "def"));
36+
public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOException {
37+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
38+
List.of("abc", "def"),
39+
ElasticInferenceServiceUsageContext.UNSPECIFIED
40+
);
3441
String xContentString = xContentEntityToString(entity);
3542
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
3643
{
@@ -42,6 +49,28 @@ public void testToXContent_MultipleInputs() throws IOException {
4249
"""));
4350
}
4451

52+
public void testToXContent_MultipleInputs_SearchUsageContext() throws IOException {
53+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc"), ElasticInferenceServiceUsageContext.SEARCH);
54+
String xContentString = xContentEntityToString(entity);
55+
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
56+
{
57+
"input": ["abc"],
58+
"usage_context": "search"
59+
}
60+
"""));
61+
}
62+
63+
public void testToXContent_MultipleInputs_IngestUsageContext() throws IOException {
64+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc"), ElasticInferenceServiceUsageContext.INGEST);
65+
String xContentString = xContentEntityToString(entity);
66+
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
67+
{
68+
"input": ["abc"],
69+
"usage_context": "ingest"
70+
}
71+
"""));
72+
}
73+
4574
private String xContentEntityToString(ElasticInferenceServiceSparseEmbeddingsRequestEntity entity) throws IOException {
4675
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
4776
entity.toXContent(builder, null);

0 commit comments

Comments
 (0)