Skip to content

Commit 1eff71b

Browse files
Add Unit Tests for MCP feature (#3787)
* Add tests for MCP feature Signed-off-by: rithin-pullela-aws <[email protected]> * Add more UTs Signed-off-by: rithin-pullela-aws <[email protected]> * cleanup code Signed-off-by: rithin-pullela-aws <[email protected]> * Add license header, rename helper function Signed-off-by: rithin-pullela-aws <[email protected]> --------- Signed-off-by: rithin-pullela-aws <[email protected]> Co-authored-by: Dhrubo Saha <[email protected]>
1 parent 6187de8 commit 1eff71b

File tree

4 files changed

+686
-2
lines changed

4 files changed

+686
-2
lines changed
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector;
7+
8+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
9+
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
10+
import static org.opensearch.ml.common.connector.RetryBackoffPolicy.CONSTANT;
11+
12+
import java.io.IOException;
13+
import java.util.Arrays;
14+
import java.util.Collections;
15+
import java.util.HashMap;
16+
import java.util.List;
17+
import java.util.Locale;
18+
import java.util.Map;
19+
import java.util.function.BiFunction;
20+
21+
import org.junit.Assert;
22+
import org.junit.Before;
23+
import org.junit.Rule;
24+
import org.junit.Test;
25+
import org.junit.rules.ExpectedException;
26+
import org.opensearch.common.io.stream.BytesStreamOutput;
27+
import org.opensearch.common.settings.Settings;
28+
import org.opensearch.common.xcontent.XContentFactory;
29+
import org.opensearch.common.xcontent.XContentType;
30+
import org.opensearch.core.xcontent.NamedXContentRegistry;
31+
import org.opensearch.core.xcontent.ToXContent;
32+
import org.opensearch.core.xcontent.XContentBuilder;
33+
import org.opensearch.core.xcontent.XContentParser;
34+
import org.opensearch.ml.common.AccessMode;
35+
import org.opensearch.ml.common.TestHelper;
36+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
37+
import org.opensearch.search.SearchModule;
38+
39+
public class McpConnectorTest {
40+
@Rule
41+
public ExpectedException exceptionRule = ExpectedException.none();
42+
43+
BiFunction<String, String, String> encryptFunction;
44+
BiFunction<String, String, String> decryptFunction;
45+
46+
String TEST_CONNECTOR_JSON_STRING =
47+
"{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}}";
48+
49+
@Before
50+
public void setUp() {
51+
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
52+
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
53+
}
54+
55+
@Test
56+
public void constructor_InvalidProtocol() {
57+
exceptionRule.expect(IllegalArgumentException.class);
58+
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]");
59+
60+
McpConnector.builder().protocol("wrong protocol").build();
61+
}
62+
63+
@Test
64+
public void writeTo() throws IOException {
65+
McpConnector connector = createMcpConnector();
66+
67+
BytesStreamOutput output = new BytesStreamOutput();
68+
connector.writeTo(output);
69+
70+
McpConnector connector2 = new McpConnector(output.bytes().streamInput());
71+
Assert.assertEquals(connector, connector2);
72+
}
73+
74+
@Test
75+
public void toXContent() throws IOException {
76+
McpConnector connector = createMcpConnector();
77+
78+
XContentBuilder builder = XContentFactory.jsonBuilder();
79+
connector.toXContent(builder, ToXContent.EMPTY_PARAMS);
80+
String content = TestHelper.xContentBuilderToString(builder);
81+
82+
Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content);
83+
}
84+
85+
@Test
86+
public void constructor_Parser() throws IOException {
87+
XContentParser parser = XContentType.JSON
88+
.xContent()
89+
.createParser(
90+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
91+
null,
92+
TEST_CONNECTOR_JSON_STRING
93+
);
94+
parser.nextToken();
95+
96+
McpConnector connector = new McpConnector("mcp_sse", parser);
97+
Assert.assertEquals("test_mcp_connector_name", connector.getName());
98+
Assert.assertEquals("1", connector.getVersion());
99+
Assert.assertEquals("this is a test mcp connector", connector.getDescription());
100+
Assert.assertEquals("mcp_sse", connector.getProtocol());
101+
Assert.assertEquals(AccessMode.PUBLIC, connector.getAccess());
102+
Assert.assertEquals("https://test.com", connector.getUrl());
103+
connector.decrypt(PREDICT.name(), decryptFunction, null);
104+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
105+
Assert.assertEquals(1, decryptedCredential.size());
106+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
107+
Assert.assertNotNull(connector.getDecryptedHeaders());
108+
Assert.assertEquals(1, connector.getDecryptedHeaders().size());
109+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key"));
110+
}
111+
112+
@Test
113+
public void cloneConnector() {
114+
McpConnector connector = createMcpConnector();
115+
Connector connector2 = connector.cloneConnector();
116+
Assert.assertEquals(connector, connector2);
117+
}
118+
119+
@Test
120+
public void decrypt() {
121+
McpConnector connector = createMcpConnector();
122+
connector.decrypt("", decryptFunction, null);
123+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
124+
Assert.assertEquals(1, decryptedCredential.size());
125+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
126+
Assert.assertNotNull(connector.getDecryptedHeaders());
127+
Assert.assertEquals(1, connector.getDecryptedHeaders().size());
128+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key"));
129+
130+
connector.removeCredential();
131+
Assert.assertNull(connector.getCredential());
132+
Assert.assertNull(connector.getDecryptedCredential());
133+
Assert.assertNull(connector.getDecryptedHeaders());
134+
}
135+
136+
@Test
137+
public void encrypt() {
138+
McpConnector connector = createMcpConnector();
139+
connector.encrypt(encryptFunction, null);
140+
Map<String, String> credential = connector.getCredential();
141+
Assert.assertEquals(1, credential.size());
142+
Assert.assertEquals("encrypted: test_key_value", credential.get("key"));
143+
144+
connector.removeCredential();
145+
Assert.assertNull(connector.getCredential());
146+
Assert.assertNull(connector.getDecryptedCredential());
147+
Assert.assertNull(connector.getDecryptedHeaders());
148+
}
149+
150+
@Test
151+
public void validateConnectorURL_Invalid() {
152+
exceptionRule.expect(IllegalArgumentException.class);
153+
exceptionRule.expectMessage("Connector URL is not matching the trusted connector endpoint regex");
154+
McpConnector connector = createMcpConnector();
155+
connector
156+
.validateConnectorURL(
157+
Arrays
158+
.asList(
159+
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
160+
"^https://api\\.openai\\.com/.*$",
161+
"^https://api\\.cohere\\.ai/.*$",
162+
"^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$"
163+
)
164+
);
165+
}
166+
167+
@Test
168+
public void validateConnectorURL() {
169+
McpConnector connector = createMcpConnector();
170+
connector
171+
.validateConnectorURL(
172+
Arrays
173+
.asList(
174+
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
175+
"^https://api\\.openai\\.com/.*$",
176+
"^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$",
177+
"^" + connector.getUrl()
178+
)
179+
);
180+
}
181+
182+
@Test
183+
public void testUpdate() {
184+
McpConnector connector = createMcpConnector();
185+
Map<String, String> initialCredential = new HashMap<>(connector.getCredential());
186+
187+
// Create update content
188+
String updatedName = "updated_name";
189+
String updatedDescription = "updated description";
190+
String updatedVersion = "2";
191+
Map<String, String> updatedCredential = new HashMap<>();
192+
updatedCredential.put("new_key", "new_value");
193+
List<String> updatedBackendRoles = List.of("role3", "role4");
194+
AccessMode updatedAccessMode = AccessMode.PRIVATE;
195+
ConnectorClientConfig updatedClientConfig = new ConnectorClientConfig(40, 40000, 40000, 20, 20, 5, CONSTANT);
196+
String updatedUrl = "https://updated.test.com";
197+
Map<String, String> updatedHeaders = new HashMap<>();
198+
updatedHeaders.put("new_header", "new_header_value");
199+
updatedHeaders.put("updated_api_key", "${credential.new_key}"); // Referencing new credential key
200+
201+
MLCreateConnectorInput updateInput = MLCreateConnectorInput
202+
.builder()
203+
.name(updatedName)
204+
.description(updatedDescription)
205+
.version(updatedVersion)
206+
.credential(updatedCredential)
207+
.backendRoles(updatedBackendRoles)
208+
.access(updatedAccessMode)
209+
.connectorClientConfig(updatedClientConfig)
210+
.url(updatedUrl)
211+
.headers(updatedHeaders)
212+
.protocol(MCP_SSE)
213+
.build();
214+
215+
// Call the update method
216+
connector.update(updateInput, encryptFunction);
217+
218+
// Assertions
219+
Assert.assertEquals(updatedName, connector.getName());
220+
Assert.assertEquals(updatedDescription, connector.getDescription());
221+
Assert.assertEquals(updatedVersion, connector.getVersion());
222+
Assert.assertEquals(MCP_SSE, connector.getProtocol()); // Should not change if not provided
223+
Assert.assertEquals(updatedBackendRoles, connector.getBackendRoles());
224+
Assert.assertEquals(updatedAccessMode, connector.getAccess());
225+
Assert.assertEquals(updatedClientConfig, connector.getConnectorClientConfig());
226+
Assert.assertEquals(updatedUrl, connector.getUrl());
227+
Assert.assertEquals(updatedHeaders, connector.getHeaders());
228+
229+
// Check encrypted credentials
230+
Map<String, String> currentCredential = connector.getCredential();
231+
Assert.assertNotNull(currentCredential);
232+
Assert.assertEquals(1, currentCredential.size()); // Should replace old credentials
233+
Assert.assertEquals("encrypted: new_value", currentCredential.get("new_key"));
234+
Assert.assertNotEquals(initialCredential, currentCredential);
235+
236+
// Check decrypted credentials and headers (need to explicitly decrypt after update)
237+
connector.decrypt("", decryptFunction, null); // Use decrypt function from setUp
238+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
239+
Assert.assertNotNull(decryptedCredential);
240+
Assert.assertEquals(1, decryptedCredential.size());
241+
Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedCredential.get("new_key")); // Uses the decrypt function logic
242+
243+
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
244+
Assert.assertNotNull(decryptedHeaders);
245+
Assert.assertEquals(2, decryptedHeaders.size());
246+
Assert.assertEquals("new_header_value", decryptedHeaders.get("new_header"));
247+
Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedHeaders.get("updated_api_key")); // Check header substitution
248+
}
249+
250+
public static McpConnector createMcpConnector() {
251+
Map<String, String> credential = new HashMap<>();
252+
credential.put("key", "test_key_value");
253+
254+
Map<String, String> headers = new HashMap<>();
255+
headers.put("api_key", "${credential.key}");
256+
257+
ConnectorClientConfig clientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT);
258+
259+
return McpConnector
260+
.builder()
261+
.name("test_mcp_connector_name")
262+
.version("1")
263+
.description("this is a test mcp connector")
264+
.protocol(MCP_SSE)
265+
.credential(credential)
266+
.backendRoles(List.of("role1", "role2"))
267+
.accessMode(AccessMode.PUBLIC)
268+
.connectorClientConfig(clientConfig)
269+
.url("https://test.com")
270+
.headers(headers)
271+
.build();
272+
}
273+
}

0 commit comments

Comments
 (0)