Skip to content

Add Unit Tests for MCP feature #3787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package org.opensearch.ml.common.connector;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
import static org.opensearch.ml.common.connector.RetryBackoffPolicy.CONSTANT;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.TestHelper;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.search.SearchModule;

public class McpConnectorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

BiFunction<String, String, String> encryptFunction;
BiFunction<String, String, String> decryptFunction;

String TEST_CONNECTOR_JSON_STRING =
"{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}}";

@Before
public void setUp() {
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
}

@Test
public void constructor_InvalidProtocol() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]");

McpConnector.builder().protocol("wrong protocol").build();
}

@Test
public void writeTo() throws IOException {
McpConnector connector = createMcpConnector();

BytesStreamOutput output = new BytesStreamOutput();
connector.writeTo(output);

McpConnector connector2 = new McpConnector(output.bytes().streamInput());
Assert.assertEquals(connector, connector2);
}

@Test
public void toXContent() throws IOException {
McpConnector connector = createMcpConnector();

XContentBuilder builder = XContentFactory.jsonBuilder();
connector.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content);
}

@Test
public void constructor_Parser() throws IOException {
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
TEST_CONNECTOR_JSON_STRING
);
parser.nextToken();

McpConnector connector = new McpConnector("mcp_sse", parser);
Assert.assertEquals("test_mcp_connector_name", connector.getName());
Assert.assertEquals("1", connector.getVersion());
Assert.assertEquals("this is a test mcp connector", connector.getDescription());
Assert.assertEquals("mcp_sse", connector.getProtocol());
Assert.assertEquals(AccessMode.PUBLIC, connector.getAccess());
Assert.assertEquals("https://test.com", connector.getUrl());
connector.decrypt(PREDICT.name(), decryptFunction, null);
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
Assert.assertEquals(1, decryptedCredential.size());
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
Assert.assertNotNull(connector.getDecryptedHeaders());
Assert.assertEquals(1, connector.getDecryptedHeaders().size());
Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key"));
}

@Test
public void cloneConnector() {
McpConnector connector = createMcpConnector();
Connector connector2 = connector.cloneConnector();
Assert.assertEquals(connector, connector2);
}

@Test
public void decrypt() {
McpConnector connector = createMcpConnector();
connector.decrypt("", decryptFunction, null);
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
Assert.assertEquals(1, decryptedCredential.size());
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
Assert.assertNotNull(connector.getDecryptedHeaders());
Assert.assertEquals(1, connector.getDecryptedHeaders().size());
Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key"));

connector.removeCredential();
Assert.assertNull(connector.getCredential());
Assert.assertNull(connector.getDecryptedCredential());
Assert.assertNull(connector.getDecryptedHeaders());
}

@Test
public void encrypt() {
McpConnector connector = createMcpConnector();
connector.encrypt(encryptFunction, null);
Map<String, String> credential = connector.getCredential();
Assert.assertEquals(1, credential.size());
Assert.assertEquals("encrypted: test_key_value", credential.get("key"));

connector.removeCredential();
Assert.assertNull(connector.getCredential());
Assert.assertNull(connector.getDecryptedCredential());
Assert.assertNull(connector.getDecryptedHeaders());
}

@Test
public void validateConnectorURL_Invalid() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Connector URL is not matching the trusted connector endpoint regex");
McpConnector connector = createMcpConnector();
connector
.validateConnectorURL(
Arrays
.asList(
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$",
"^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$"
)
);
}

@Test
public void validateConnectorURL() {
McpConnector connector = createMcpConnector();
connector
.validateConnectorURL(
Arrays
.asList(
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$",
"^" + connector.getUrl()
)
);
}

@Test
public void testUpdate() {
McpConnector connector = createMcpConnector();
Map<String, String> initialCredential = new HashMap<>(connector.getCredential());

// Create update content
String updatedName = "updated_name";
String updatedDescription = "updated description";
String updatedVersion = "2";
Map<String, String> updatedCredential = new HashMap<>();
updatedCredential.put("new_key", "new_value");
List<String> updatedBackendRoles = List.of("role3", "role4");
AccessMode updatedAccessMode = AccessMode.PRIVATE;
ConnectorClientConfig updatedClientConfig = new ConnectorClientConfig(40, 40000, 40000, 20, 20, 5, CONSTANT);
String updatedUrl = "https://updated.test.com";
Map<String, String> updatedHeaders = new HashMap<>();
updatedHeaders.put("new_header", "new_header_value");
updatedHeaders.put("updated_api_key", "${credential.new_key}"); // Referencing new credential key

MLCreateConnectorInput updateInput = MLCreateConnectorInput.builder()
.name(updatedName)
.description(updatedDescription)
.version(updatedVersion)
.credential(updatedCredential)
.backendRoles(updatedBackendRoles)
.access(updatedAccessMode)
.connectorClientConfig(updatedClientConfig)
.url(updatedUrl)
.headers(updatedHeaders)
.protocol(MCP_SSE)
.build();

// Call the update method
connector.update(updateInput, encryptFunction);

// Assertions
Assert.assertEquals(updatedName, connector.getName());
Assert.assertEquals(updatedDescription, connector.getDescription());
Assert.assertEquals(updatedVersion, connector.getVersion());
Assert.assertEquals(MCP_SSE, connector.getProtocol()); // Should not change if not provided
Assert.assertEquals(updatedBackendRoles, connector.getBackendRoles());
Assert.assertEquals(updatedAccessMode, connector.getAccess());
Assert.assertEquals(updatedClientConfig, connector.getConnectorClientConfig());
Assert.assertEquals(updatedUrl, connector.getUrl());
Assert.assertEquals(updatedHeaders, connector.getHeaders());

// Check encrypted credentials
Map<String, String> currentCredential = connector.getCredential();
Assert.assertNotNull(currentCredential);
Assert.assertEquals(1, currentCredential.size()); // Should replace old credentials
Assert.assertEquals("encrypted: new_value", currentCredential.get("new_key"));
Assert.assertNotEquals(initialCredential, currentCredential);

// Check decrypted credentials and headers (need to explicitly decrypt after update)
connector.decrypt("", decryptFunction, null); // Use decrypt function from setUp
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
Assert.assertNotNull(decryptedCredential);
Assert.assertEquals(1, decryptedCredential.size());
Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedCredential.get("new_key")); // Uses the decrypt function logic

Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
Assert.assertNotNull(decryptedHeaders);
Assert.assertEquals(2, decryptedHeaders.size());
Assert.assertEquals("new_header_value", decryptedHeaders.get("new_header"));
Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedHeaders.get("updated_api_key")); // Check header substitution
}

public static McpConnector createMcpConnector() {
Map<String, String> credential = new HashMap<>();
credential.put("key", "test_key_value");

Map<String, String> headers = new HashMap<>();
headers.put("api_key", "${credential.key}");

ConnectorClientConfig clientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT);

return McpConnector
.builder()
.name("test_mcp_connector_name")
.version("1")
.description("this is a test mcp connector")
.protocol(MCP_SSE)
.credential(credential)
.backendRoles(List.of("role1", "role2"))
.accessMode(AccessMode.PUBLIC)
.connectorClientConfig(clientConfig)
.url("https://test.com")
.headers(headers)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
Expand Down Expand Up @@ -44,17 +46,32 @@
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.transport.client.Client;


public class AgentUtilsTest {

@Mock
private Tool tool1, tool2;

@Mock
private MLAgent mlAgent;
@Mock
private Client client;
@Mock
private SdkClient sdkClient;
@Mock
private Encryptor encryptor;

private Map<String, Map<String, String>> llmResponseExpectedParseResults;

private String responseForAction = "---------------------\n{\n "
Expand Down Expand Up @@ -1152,6 +1169,16 @@ public void testParseLLMOutputWithDeepseekFormat() {
Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response"));
}

@Test
public void testGetMcpToolSpecs_NoMcpJsonConfig() {
when(mlAgent.getParameters()).thenReturn(null);

ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, listener);

verify(listener).onResponse(Collections.emptyList());
}

private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
Expand Down
Loading
Loading