Skip to content

Commit 7ad1a16

Browse files
committed
change header names and rename method
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 7748952 commit 7ad1a16

File tree

4 files changed

+55
-41
lines changed

4 files changed

+55
-41
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ public class AgentUtils {
145145
public static final String DEFAULT_DATETIME_PREFIX = "Current date and time: ";
146146
private static final ZoneId UTC_ZONE = ZoneId.of("UTC");
147147

148-
public static Map<String, String> extractRequestHeaders(Client client) {
148+
public static Map<String, String> extractMcpRequestHeaders(Client client) {
149149
try {
150150
@SuppressWarnings("unchecked")
151151
Map<String, String> headers = client
@@ -713,7 +713,7 @@ public static void getMcpToolSpecs(
713713
}.getType();
714714
List<Map<String, Object>> mcpConnectorConfigs = gson.fromJson(mcpConnectorConfigJSON, listType);
715715

716-
Map<String, String> requestHeaders = extractRequestHeaders(client);
716+
Map<String, String> requestHeaders = extractMcpRequestHeaders(client);
717717

718718
// Use AtomicInteger to track completion of all async operations
719719
AtomicInteger remainingConnectors = new AtomicInteger(mcpConnectorConfigs.size());

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2002,7 +2002,7 @@ private void mockMcpStreamableHttpConnector(MockedStatic<Connector> connectorSta
20022002
}
20032003

20042004
@Test
2005-
public void testExtractRequestHeaders_WithValidHeaders() {
2005+
public void testExtractMcpRequestHeaders_WithValidHeaders() {
20062006
// Setup ThreadContext with headers
20072007
Map<String, String> expectedHeaders = new HashMap<>();
20082008
expectedHeaders.put("x-amzn-fas-accesskey", "access-key-value");
@@ -2014,27 +2014,27 @@ public void testExtractRequestHeaders_WithValidHeaders() {
20142014

20152015
realThreadContext.putTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, expectedHeaders);
20162016

2017-
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
2017+
Map<String, String> result = AgentUtils.extractMcpRequestHeaders(client);
20182018

20192019
assertEquals(2, result.size());
20202020
assertEquals("access-key-value", result.get("x-amzn-fas-accesskey"));
20212021
assertEquals("https://example.aos.us-east-1.on.aws", result.get("x-amzn-datasources"));
20222022
}
20232023

20242024
@Test
2025-
public void testExtractRequestHeaders_WithNoHeaders() {
2025+
public void testExtractMcpRequestHeaders_WithNoHeaders() {
20262026
ThreadContext realThreadContext = new ThreadContext(Settings.EMPTY);
20272027
when(client.threadPool()).thenReturn(threadPool);
20282028
when(threadPool.getThreadContext()).thenReturn(realThreadContext);
20292029

2030-
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
2030+
Map<String, String> result = AgentUtils.extractMcpRequestHeaders(client);
20312031

20322032
assertEquals(0, result.size());
20332033
assertEquals(Collections.emptyMap(), result);
20342034
}
20352035

20362036
@Test
2037-
public void testExtractRequestHeaders_WithEmptyHeaders() {
2037+
public void testExtractMcpRequestHeaders_WithEmptyHeaders() {
20382038
Map<String, String> emptyHeaders = new HashMap<>();
20392039

20402040
ThreadContext realThreadContext = new ThreadContext(Settings.EMPTY);
@@ -2043,26 +2043,26 @@ public void testExtractRequestHeaders_WithEmptyHeaders() {
20432043

20442044
realThreadContext.putTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, emptyHeaders);
20452045

2046-
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
2046+
Map<String, String> result = AgentUtils.extractMcpRequestHeaders(client);
20472047

20482048
assertEquals(0, result.size());
20492049
assertEquals(emptyHeaders, result);
20502050
}
20512051

20522052
@Test
2053-
public void testExtractRequestHeaders_WithException() {
2053+
public void testExtractMcpRequestHeaders_WithException() {
20542054
// Setup mock to throw exception
20552055
when(client.threadPool()).thenReturn(threadPool);
20562056
when(threadPool.getThreadContext()).thenThrow(new RuntimeException("ThreadContext access failed"));
20572057

2058-
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
2058+
Map<String, String> result = AgentUtils.extractMcpRequestHeaders(client);
20592059

20602060
assertEquals(0, result.size());
20612061
assertEquals(Collections.emptyMap(), result);
20622062
}
20632063

20642064
@Test
2065-
public void testExtractRequestHeaders_WithPartialHeaders() {
2065+
public void testExtractMcpRequestHeaders_WithPartialHeaders() {
20662066
Map<String, String> partialHeaders = new HashMap<>();
20672067
partialHeaders.put("x-amzn-fas-accesskey", "access-key-value");
20682068

@@ -2072,7 +2072,7 @@ public void testExtractRequestHeaders_WithPartialHeaders() {
20722072

20732073
realThreadContext.putTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, partialHeaders);
20742074

2075-
Map<String, String> result = AgentUtils.extractRequestHeaders(client);
2075+
Map<String, String> result = AgentUtils.extractMcpRequestHeaders(client);
20762076

20772077
assertEquals(1, result.size());
20782078
assertEquals("access-key-value", result.get("x-amzn-fas-accesskey"));

plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,12 @@ public class RestActionUtils {
8585
public static final String OPENDISTRO_SECURITY_USER = OPENDISTRO_SECURITY_CONFIG_PREFIX + "user";
8686

8787
// Header names for MCP request passthrough
88-
private static final String HEADER_FAS_ACCESS_KEY = "x-amzn-fas-accesskey";
89-
private static final String HEADER_FAS_SECRET_KEY = "x-amzn-fas-secretkey";
90-
private static final String HEADER_FAS_SESSION_TOKEN = "x-amzn-fas-sessiontoken";
91-
private static final String HEADER_DATASOURCES = "x-amzn-datasources";
88+
private static final String HEADER_AWS_ACCESS_KEY_ID = "aws-access-key-id";
89+
private static final String HEADER_AWS_SECRET_ACCESS_KEY = "aws-secret-access-key";
90+
private static final String HEADER_AWS_SESSION_TOKEN = "aws-session-token";
91+
private static final String HEADER_AWS_REGION = "aws-region";
92+
private static final String HEADER_AWS_SERVICE_NAME = "aws-service-name";
93+
private static final String HEADER_OPENSEARCH_URL = "opensearch-url";
9294

9395
static final Set<LdapName> adminDn = new HashSet<>();
9496
static final Set<String> adminUsernames = new HashSet<String>();
@@ -347,28 +349,36 @@ public static String getActionTypeFromRestRequest(RestRequest request) {
347349

348350
/**
349351
* Extracts MCP (Model Context Protocol) request headers from the REST request and stores them in ThreadContext.
350-
* Extracts FAS credentials and datasources headers for forwarding to MCP connectors.
352+
* Extracts AWS credentials, region, service name, and OpenSearch URL headers for forwarding to MCP connectors.
351353
*
352354
* @param request RestRequest containing the MCP headers
353355
* @param client Client to access ThreadContext
354356
*/
355357
public static void storeMcpRequestHeaders(RestRequest request, Client client) {
356358
Map<String, String> headers = new HashMap<>();
357-
String accessKey = request.header(HEADER_FAS_ACCESS_KEY);
358-
if (accessKey != null && !accessKey.isEmpty()) {
359-
headers.put(HEADER_FAS_ACCESS_KEY, accessKey);
359+
String accessKeyId = request.header(HEADER_AWS_ACCESS_KEY_ID);
360+
if (accessKeyId != null && !accessKeyId.isEmpty()) {
361+
headers.put(HEADER_AWS_ACCESS_KEY_ID, accessKeyId);
360362
}
361-
String secretKey = request.header(HEADER_FAS_SECRET_KEY);
362-
if (secretKey != null && !secretKey.isEmpty()) {
363-
headers.put(HEADER_FAS_SECRET_KEY, secretKey);
363+
String secretAccessKey = request.header(HEADER_AWS_SECRET_ACCESS_KEY);
364+
if (secretAccessKey != null && !secretAccessKey.isEmpty()) {
365+
headers.put(HEADER_AWS_SECRET_ACCESS_KEY, secretAccessKey);
364366
}
365-
String sessionToken = request.header(HEADER_FAS_SESSION_TOKEN);
367+
String sessionToken = request.header(HEADER_AWS_SESSION_TOKEN);
366368
if (sessionToken != null && !sessionToken.isEmpty()) {
367-
headers.put(HEADER_FAS_SESSION_TOKEN, sessionToken);
369+
headers.put(HEADER_AWS_SESSION_TOKEN, sessionToken);
368370
}
369-
String datasources = request.header(HEADER_DATASOURCES);
370-
if (datasources != null && !datasources.isEmpty()) {
371-
headers.put(HEADER_DATASOURCES, datasources);
371+
String region = request.header(HEADER_AWS_REGION);
372+
if (region != null && !region.isEmpty()) {
373+
headers.put(HEADER_AWS_REGION, region);
374+
}
375+
String serviceName = request.header(HEADER_AWS_SERVICE_NAME);
376+
if (serviceName != null && !serviceName.isEmpty()) {
377+
headers.put(HEADER_AWS_SERVICE_NAME, serviceName);
378+
}
379+
String opensearchUrl = request.header(HEADER_OPENSEARCH_URL);
380+
if (opensearchUrl != null && !opensearchUrl.isEmpty()) {
381+
headers.put(HEADER_OPENSEARCH_URL, opensearchUrl);
372382
}
373383
if (!headers.isEmpty()) {
374384
client.threadPool().getThreadContext().putTransient(CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, headers);

plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,12 @@ public synchronized Throwable getCause() {
400400
public void testStoreMcpRequestHeaders_withAllHeaders() {
401401
// Setup
402402
Map<String, List<String>> headers = new HashMap<>();
403-
headers.put("x-amzn-fas-accesskey", List.of("access-key-value"));
404-
headers.put("x-amzn-fas-secretkey", List.of("secret-key-value"));
405-
headers.put("x-amzn-fas-sessiontoken", List.of("session-token-value"));
406-
headers.put("x-amzn-datasources", List.of("https://example.aos.us-east-1.on.aws"));
403+
headers.put("aws-access-key-id", List.of("access-key-value"));
404+
headers.put("aws-secret-access-key", List.of("secret-key-value"));
405+
headers.put("aws-session-token", List.of("session-token-value"));
406+
headers.put("aws-region", List.of("us-east-1"));
407+
headers.put("aws-service-name", List.of("es"));
408+
headers.put("opensearch-url", List.of("https://example.aos.us-east-1.on.aws"));
407409

408410
FakeRestRequest request = new FakeRestRequest.Builder(xContentRegistry())
409411
.withMethod(RestRequest.Method.POST)
@@ -425,11 +427,13 @@ public void testStoreMcpRequestHeaders_withAllHeaders() {
425427
Map<String, String> storedHeaders = threadContext
426428
.getTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY);
427429
assertNotNull(storedHeaders);
428-
assertEquals(4, storedHeaders.size());
429-
assertEquals("access-key-value", storedHeaders.get("x-amzn-fas-accesskey"));
430-
assertEquals("secret-key-value", storedHeaders.get("x-amzn-fas-secretkey"));
431-
assertEquals("session-token-value", storedHeaders.get("x-amzn-fas-sessiontoken"));
432-
assertEquals("https://example.aos.us-east-1.on.aws", storedHeaders.get("x-amzn-datasources"));
430+
assertEquals(6, storedHeaders.size());
431+
assertEquals("access-key-value", storedHeaders.get("aws-access-key-id"));
432+
assertEquals("secret-key-value", storedHeaders.get("aws-secret-access-key"));
433+
assertEquals("session-token-value", storedHeaders.get("aws-session-token"));
434+
assertEquals("us-east-1", storedHeaders.get("aws-region"));
435+
assertEquals("es", storedHeaders.get("aws-service-name"));
436+
assertEquals("https://example.aos.us-east-1.on.aws", storedHeaders.get("opensearch-url"));
433437
}
434438

435439
@Test
@@ -466,8 +470,8 @@ public void testGetMcpRequestHeaders_withHeaders() {
466470
when(threadPool.getThreadContext()).thenReturn(threadContext);
467471

468472
Map<String, String> expectedHeaders = new HashMap<>();
469-
expectedHeaders.put("x-amzn-fas-accesskey", "access-key-value");
470-
expectedHeaders.put("x-amzn-datasources", "https://example.aos.us-east-1.on.aws");
473+
expectedHeaders.put("aws-access-key-id", "access-key-value");
474+
expectedHeaders.put("opensearch-url", "https://example.aos.us-east-1.on.aws");
471475
threadContext.putTransient(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS_THREAD_CONTEXT_KEY, expectedHeaders);
472476

473477
// Execute
@@ -476,8 +480,8 @@ public void testGetMcpRequestHeaders_withHeaders() {
476480
// Verify
477481
assertEquals(expectedHeaders, result);
478482
assertEquals(2, result.size());
479-
assertEquals("access-key-value", result.get("x-amzn-fas-accesskey"));
480-
assertEquals("https://example.aos.us-east-1.on.aws", result.get("x-amzn-datasources"));
483+
assertEquals("access-key-value", result.get("aws-access-key-id"));
484+
assertEquals("https://example.aos.us-east-1.on.aws", result.get("opensearch-url"));
481485
}
482486

483487
@Test

0 commit comments

Comments
 (0)