Skip to content

Commit 144e86f

Browse files
committed
use threadcontext instead of parameters
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 77d5d79 commit 144e86f

File tree

11 files changed

+210
-114
lines changed

11 files changed

+210
-114
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,10 @@ public class CommonValue {
119119
public static final String MCP_TOOL_DESCRIPTION_FIELD = "description";
120120
public static final String MCP_TOOL_INPUT_SCHEMA_FIELD = "inputSchema";
121121
public static final String MCP_SYNC_CLIENT = "mcp_sync_client";
122-
public static final String MCP_CONNECTOR = "mcp_connector";
123-
public static final String MCP_CONNECTOR_CONFIG = "mcp_connector_config";
124-
public static final String MCP_REQUEST_HEADERS = "mcp_request_headers";
125122
public static final String MCP_TOOLS_FIELD = "tools";
126123
public static final String MCP_CONNECTORS_FIELD = "mcp_connectors";
127124
public static final String MCP_CONNECTOR_ID_FIELD = "mcp_connector_id";
125+
public static final String MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY = "ML_MCP_REQUEST_HEADERS";
128126
public static final String MCP_DEFAULT_SSE_ENDPOINT = "/sse";
129127
public static final String SSE_ENDPOINT_FIELD = "sse_endpoint";
130128
public static final String MCP_DEFAULT_STREAMABLE_HTTP_ENDPOINT = "/mcp/";

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,23 +145,15 @@ public class AgentUtils {
145145
public static final String DEFAULT_DATETIME_PREFIX = "Current date and time: ";
146146
private static final ZoneId UTC_ZONE = ZoneId.of("UTC");
147147

148-
public static Map<String, String> extractRequestHeaders(Map<String, String> parameters) {
149-
if (parameters == null) {
150-
return Collections.emptyMap();
151-
}
152-
153-
String headersJson = parameters.get(CommonValue.MCP_REQUEST_HEADERS);
154-
if (headersJson == null || headersJson.trim().isEmpty()) {
155-
return Collections.emptyMap();
156-
}
157-
148+
public static Map<String, String> extractRequestHeaders(Client client) {
158149
try {
159-
Type mapType = new TypeToken<Map<String, String>>() {
160-
}.getType();
161-
Map<String, String> headers = gson.fromJson(headersJson, mapType);
150+
@SuppressWarnings("unchecked")
151+
Map<String, String> headers = client.threadPool().getThreadContext().getTransient(
152+
CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY
153+
);
162154
return headers != null ? headers : Collections.emptyMap();
163155
} catch (Exception e) {
164-
log.warn("Failed to parse request headers from JSON: {}", headersJson, e);
156+
log.warn("Failed to retrieve MCP request headers from ThreadContext", e);
165157
return Collections.emptyMap();
166158
}
167159
}
@@ -720,7 +712,7 @@ public static void getMcpToolSpecs(
720712
}.getType();
721713
List<Map<String, Object>> mcpConnectorConfigs = gson.fromJson(mcpConnectorConfigJSON, listType);
722714

723-
Map<String, String> requestHeaders = extractRequestHeaders(params);
715+
Map<String, String> requestHeaders = extractRequestHeaders(client);
724716

725717
// Use AtomicInteger to track completion of all async operations
726718
AtomicInteger remainingConnectors = new AtomicInteger(mcpConnectorConfigs.size());

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,18 @@ public List<MLToolSpec> getMcpToolSpecs(Map<String, String> requestHeaders) {
7474
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
7575
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
7676

77-
Map<String, String> mergedHeaders = new HashMap<>();
78-
if (connector.getDecryptedHeaders() != null) {
79-
mergedHeaders.putAll(connector.getDecryptedHeaders());
80-
}
81-
if (requestHeaders != null) {
82-
mergedHeaders.putAll(requestHeaders);
83-
}
84-
8577
Consumer<HttpRequest.Builder> headerConfig = builder -> {
86-
for (Map.Entry<String, String> entry : mergedHeaders.entrySet()) {
87-
builder.header(entry.getKey(), entry.getValue());
78+
// Add connector headers first
79+
if (connector.getDecryptedHeaders() != null) {
80+
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
81+
builder.header(entry.getKey(), entry.getValue());
82+
}
83+
}
84+
// Add request headers second (they override connector headers)
85+
if (requestHeaders != null) {
86+
for (Map.Entry<String, String> entry : requestHeaders.entrySet()) {
87+
builder.header(entry.getKey(), entry.getValue());
88+
}
8889
}
8990
};
9091

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,18 @@ public List<MLToolSpec> getMcpToolSpecs(Map<String, String> requestHeaders) {
7777
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
7878
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
7979

80-
Map<String, String> mergedHeaders = new HashMap<>();
81-
if (connector.getDecryptedHeaders() != null) {
82-
mergedHeaders.putAll(connector.getDecryptedHeaders());
83-
}
84-
if (requestHeaders != null) {
85-
mergedHeaders.putAll(requestHeaders);
86-
}
87-
8880
Consumer<HttpRequest.Builder> headerConfig = builder -> {
89-
for (Map.Entry<String, String> entry : mergedHeaders.entrySet()) {
90-
builder.header(entry.getKey(), entry.getValue());
81+
// Add connector headers first
82+
if (connector.getDecryptedHeaders() != null) {
83+
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
84+
builder.header(entry.getKey(), entry.getValue());
85+
}
86+
}
87+
// Add request headers second (they override connector headers)
88+
if (requestHeaders != null) {
89+
for (Map.Entry<String, String> entry : requestHeaders.entrySet()) {
90+
builder.header(entry.getKey(), entry.getValue());
91+
}
9192
}
9293
};
9394

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,80 +2003,80 @@ private void mockMcpStreamableHttpConnector(MockedStatic<Connector> connectorSta
20032003

20042004
@Test
20052005
public void testExtractRequestHeaders_WithValidHeaders() {
2006-
Map<String, String> parameters = new HashMap<>();
2007-
parameters
2008-
.put(
2009-
org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS,
2010-
"{\"Authorization\":\"Bearer token123\",\"Content-Type\":\"application/json\"}"
2011-
);
2006+
// Setup ThreadContext with headers
2007+
Map<String, String> expectedHeaders = new HashMap<>();
2008+
expectedHeaders.put("x-amzn-fas-accesskey", "access-key-value");
2009+
expectedHeaders.put("x-amzn-datasources", "https://example.aos.us-east-1.on.aws");
2010+
2011+
ThreadContext realThreadContext = new ThreadContext(Settings.EMPTY);
2012+
when(client.threadPool()).thenReturn(threadPool);
2013+
when(threadPool.getThreadContext()).thenReturn(realThreadContext);
2014+
2015+
realThreadContext.putTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, expectedHeaders);
20122016

2013-
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2017+
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
20142018

20152019
assertEquals(2, result.size());
2016-
assertEquals("Bearer token123", result.get("Authorization"));
2017-
assertEquals("application/json", result.get("Content-Type"));
2018-
}
2019-
2020-
@Test
2021-
public void testExtractRequestHeaders_WithNullParameters() {
2022-
Map<String, String> result = AgentUtils.extractRequestHeaders(null);
2023-
2024-
assertEquals(0, result.size());
2025-
assertEquals(Collections.emptyMap(), result);
2026-
}
2027-
2028-
@Test
2029-
public void testExtractRequestHeaders_WithEmptyHeadersJson() {
2030-
Map<String, String> parameters = new HashMap<>();
2031-
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "");
2032-
2033-
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2034-
2035-
assertEquals(0, result.size());
2036-
assertEquals(Collections.emptyMap(), result);
2020+
assertEquals("access-key-value", result.get("x-amzn-fas-accesskey"));
2021+
assertEquals("https://example.aos.us-east-1.on.aws", result.get("x-amzn-datasources"));
20372022
}
20382023

20392024
@Test
2040-
public void testExtractRequestHeaders_WithNullHeadersJson() {
2041-
Map<String, String> parameters = new HashMap<>();
2042-
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, null);
2025+
public void testExtractRequestHeaders_WithNoHeaders() {
2026+
ThreadContext realThreadContext = new ThreadContext(Settings.EMPTY);
2027+
when(client.threadPool()).thenReturn(threadPool);
2028+
when(threadPool.getThreadContext()).thenReturn(realThreadContext);
20432029

2044-
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2030+
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
20452031

20462032
assertEquals(0, result.size());
20472033
assertEquals(Collections.emptyMap(), result);
20482034
}
20492035

20502036
@Test
2051-
public void testExtractRequestHeaders_WithWhitespaceHeadersJson() {
2052-
Map<String, String> parameters = new HashMap<>();
2053-
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, " ");
2037+
public void testExtractRequestHeaders_WithEmptyHeaders() {
2038+
Map<String, String> emptyHeaders = new HashMap<>();
2039+
2040+
ThreadContext realThreadContext = new ThreadContext(Settings.EMPTY);
2041+
when(client.threadPool()).thenReturn(threadPool);
2042+
when(threadPool.getThreadContext()).thenReturn(realThreadContext);
2043+
2044+
realThreadContext.putTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, emptyHeaders);
20542045

2055-
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2046+
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
20562047

20572048
assertEquals(0, result.size());
2058-
assertEquals(Collections.emptyMap(), result);
2049+
assertEquals(emptyHeaders, result);
20592050
}
20602051

20612052
@Test
2062-
public void testExtractRequestHeaders_WithInvalidJson() {
2063-
Map<String, String> parameters = new HashMap<>();
2064-
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{invalid json}");
2053+
public void testExtractRequestHeaders_WithException() {
2054+
// Setup mock to throw exception
2055+
when(client.threadPool()).thenReturn(threadPool);
2056+
when(threadPool.getThreadContext()).thenThrow(new RuntimeException("ThreadContext access failed"));
20652057

2066-
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2058+
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
20672059

20682060
assertEquals(0, result.size());
20692061
assertEquals(Collections.emptyMap(), result);
20702062
}
20712063

20722064
@Test
2073-
public void testExtractRequestHeaders_WithEmptyJsonObject() {
2074-
Map<String, String> parameters = new HashMap<>();
2075-
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{}");
2065+
public void testExtractRequestHeaders_WithPartialHeaders() {
2066+
Map<String, String> partialHeaders = new HashMap<>();
2067+
partialHeaders.put("x-amzn-fas-accesskey", "access-key-value");
2068+
2069+
ThreadContext realThreadContext = new ThreadContext(Settings.EMPTY);
2070+
when(client.threadPool()).thenReturn(threadPool);
2071+
when(threadPool.getThreadContext()).thenReturn(realThreadContext);
2072+
2073+
realThreadContext.putTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, partialHeaders);
20762074

2077-
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2075+
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
20782076

2079-
assertEquals(0, result.size());
2077+
assertEquals(1, result.size());
2078+
assertEquals("access-key-value", result.get("x-amzn-fas-accesskey"));
2079+
assertEquals(partialHeaders, result);
20802080
}
20812081

20822082
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;
1515
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
1616
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TOOL_NAME;
17-
import static org.opensearch.ml.utils.RestActionUtils.addMcpRequestHeaders;
17+
import static org.opensearch.ml.utils.RestActionUtils.storeMcpRequestHeaders;
1818
import static org.opensearch.ml.utils.RestActionUtils.getAlgorithm;
1919
import static org.opensearch.ml.utils.RestActionUtils.isAsync;
2020
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;
@@ -78,7 +78,7 @@ public List<Route> routes() {
7878

7979
@Override
8080
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
81-
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(request);
81+
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(request, client);
8282

8383
return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new ActionListener<>() {
8484
@Override
@@ -107,10 +107,11 @@ public void onFailure(Exception e) {
107107
* Creates a MLExecuteTaskRequest from a RestRequest
108108
*
109109
* @param request RestRequest
110+
* @param client NodeClient
110111
* @return MLExecuteTaskRequest
111112
*/
112113
@VisibleForTesting
113-
MLExecuteTaskRequest getRequest(RestRequest request) throws IOException {
114+
MLExecuteTaskRequest getRequest(RestRequest request, NodeClient client) throws IOException {
114115
XContentParser parser = request.contentParser();
115116
boolean async = isAsync(request);
116117
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
@@ -129,7 +130,7 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException {
129130
((AgentMLInput) input).setAgentId(agentId);
130131
((AgentMLInput) input).setTenantId(tenantId);
131132
((AgentMLInput) input).setIsAsync(async);
132-
addMcpRequestHeaders(request, (AgentMLInput) input);
133+
storeMcpRequestHeaders(request, client);
133134
} else if (uri.startsWith(ML_BASE_URI + "/tools/")) {
134135
if (!mlFeatureEnabledSetting.isToolExecuteEnabled()) {
135136
throw new IllegalStateException(ML_COMMONS_EXECUTE_TOOL_DISABLED_MESSAGE);

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
1515
import static org.opensearch.ml.utils.MLExceptionUtils.STREAM_DISABLED_ERR_MSG;
1616
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;
17-
import static org.opensearch.ml.utils.RestActionUtils.addMcpRequestHeaders;
17+
import static org.opensearch.ml.utils.RestActionUtils.storeMcpRequestHeaders;
1818
import static org.opensearch.ml.utils.RestActionUtils.isAsync;
1919
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;
2020

@@ -161,7 +161,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
161161
Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> {
162162
try {
163163
BytesReference completeContent = combineChunks(chunks);
164-
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent);
164+
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent, client);
165165

166166
final CompletableFuture<HttpChunk> future = new CompletableFuture<>();
167167
StreamTransportResponseHandler<MLTaskResponse> handler = new StreamTransportResponseHandler<MLTaskResponse>() {
@@ -303,11 +303,14 @@ boolean isModelValid(String modelId, RestRequest request, NodeClient client) thr
303303
/**
304304
* Creates a MLExecuteTaskRequest from a RestRequest
305305
*
306+
* @param agentId Agent ID
306307
* @param request RestRequest
308+
* @param content Request content
309+
* @param client NodeClient
307310
* @return MLExecuteTaskRequest
308311
*/
309312
@VisibleForTesting
310-
MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesReference content) throws IOException {
313+
MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesReference content, NodeClient client) throws IOException {
311314
XContentParser parser = request
312315
.getMediaType()
313316
.xContent()
@@ -327,7 +330,7 @@ MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesRefere
327330
agentInput.setIsAsync(async);
328331
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset();
329332
inputDataSet.getParameters().put("stream", String.valueOf(true));
330-
addMcpRequestHeaders(request, agentInput);
333+
storeMcpRequestHeaders(request, client);
331334
return new MLExecuteTaskRequest(functionName, input);
332335
}
333336

0 commit comments

Comments
 (0)