diff --git a/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java new file mode 100644 index 0000000000..103ac0f7c5 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java @@ -0,0 +1,267 @@ +package org.opensearch.ml.common.connector; + +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; +import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE; +import static org.opensearch.ml.common.connector.RetryBackoffPolicy.CONSTANT; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.BiFunction; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.search.SearchModule; + +public class McpConnectorTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + BiFunction encryptFunction; + BiFunction decryptFunction; + + String TEST_CONNECTOR_JSON_STRING = + "{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}}"; + + @Before + public void setUp() { + encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT); + } + + @Test + public void constructor_InvalidProtocol() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]"); + + McpConnector.builder().protocol("wrong protocol").build(); + } + + @Test + public void writeTo() throws IOException { + McpConnector connector = createMcpConnector(); + + BytesStreamOutput output = new BytesStreamOutput(); + connector.writeTo(output); + + McpConnector connector2 = new McpConnector(output.bytes().streamInput()); + Assert.assertEquals(connector, connector2); + } + + @Test + public void toXContent() throws IOException { + McpConnector connector = createMcpConnector(); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + connector.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content); + } + + @Test + public void constructor_Parser() throws IOException { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + TEST_CONNECTOR_JSON_STRING + ); + parser.nextToken(); + + McpConnector connector = new McpConnector("mcp_sse", parser); + Assert.assertEquals("test_mcp_connector_name", connector.getName()); + Assert.assertEquals("1", connector.getVersion()); + Assert.assertEquals("this is a test mcp connector", connector.getDescription()); + Assert.assertEquals("mcp_sse", connector.getProtocol()); + Assert.assertEquals(AccessMode.PUBLIC, connector.getAccess()); + Assert.assertEquals("https://test.com", connector.getUrl()); + connector.decrypt(PREDICT.name(), decryptFunction, null); + Map decryptedCredential = connector.getDecryptedCredential(); + Assert.assertEquals(1, decryptedCredential.size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key")); + Assert.assertNotNull(connector.getDecryptedHeaders()); + Assert.assertEquals(1, connector.getDecryptedHeaders().size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key")); + } + + @Test + public void cloneConnector() { + McpConnector connector = createMcpConnector(); + Connector connector2 = connector.cloneConnector(); + Assert.assertEquals(connector, connector2); + } + + @Test + public void decrypt() { + McpConnector connector = createMcpConnector(); + connector.decrypt("", decryptFunction, null); + Map decryptedCredential = connector.getDecryptedCredential(); + Assert.assertEquals(1, decryptedCredential.size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key")); + Assert.assertNotNull(connector.getDecryptedHeaders()); + Assert.assertEquals(1, connector.getDecryptedHeaders().size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key")); + + connector.removeCredential(); + Assert.assertNull(connector.getCredential()); + Assert.assertNull(connector.getDecryptedCredential()); + Assert.assertNull(connector.getDecryptedHeaders()); + } + + @Test + public void encrypt() { + McpConnector connector = createMcpConnector(); + connector.encrypt(encryptFunction, null); + Map credential = connector.getCredential(); + Assert.assertEquals(1, credential.size()); + Assert.assertEquals("encrypted: test_key_value", credential.get("key")); + + connector.removeCredential(); + Assert.assertNull(connector.getCredential()); + Assert.assertNull(connector.getDecryptedCredential()); + Assert.assertNull(connector.getDecryptedHeaders()); + } + + @Test + public void validateConnectorURL_Invalid() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Connector URL is not matching the trusted connector endpoint regex"); + McpConnector connector = createMcpConnector(); + connector + .validateConnectorURL( + Arrays + .asList( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://api\\.cohere\\.ai/.*$", + "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$" + ) + ); + } + + @Test + public void validateConnectorURL() { + McpConnector connector = createMcpConnector(); + connector + .validateConnectorURL( + Arrays + .asList( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$", + "^" + connector.getUrl() + ) + ); + } + + @Test + public void testUpdate() { + McpConnector connector = createMcpConnector(); + Map initialCredential = new HashMap<>(connector.getCredential()); + + // Create update content + String updatedName = "updated_name"; + String updatedDescription = "updated description"; + String updatedVersion = "2"; + Map updatedCredential = new HashMap<>(); + updatedCredential.put("new_key", "new_value"); + List updatedBackendRoles = List.of("role3", "role4"); + AccessMode updatedAccessMode = AccessMode.PRIVATE; + ConnectorClientConfig updatedClientConfig = new ConnectorClientConfig(40, 40000, 40000, 20, 20, 5, CONSTANT); + String updatedUrl = "https://updated.test.com"; + Map updatedHeaders = new HashMap<>(); + updatedHeaders.put("new_header", "new_header_value"); + updatedHeaders.put("updated_api_key", "${credential.new_key}"); // Referencing new credential key + + MLCreateConnectorInput updateInput = MLCreateConnectorInput.builder() + .name(updatedName) + .description(updatedDescription) + .version(updatedVersion) + .credential(updatedCredential) + .backendRoles(updatedBackendRoles) + .access(updatedAccessMode) + .connectorClientConfig(updatedClientConfig) + .url(updatedUrl) + .headers(updatedHeaders) + .protocol(MCP_SSE) + .build(); + + // Call the update method + connector.update(updateInput, encryptFunction); + + // Assertions + Assert.assertEquals(updatedName, connector.getName()); + Assert.assertEquals(updatedDescription, connector.getDescription()); + Assert.assertEquals(updatedVersion, connector.getVersion()); + Assert.assertEquals(MCP_SSE, connector.getProtocol()); // Should not change if not provided + Assert.assertEquals(updatedBackendRoles, connector.getBackendRoles()); + Assert.assertEquals(updatedAccessMode, connector.getAccess()); + Assert.assertEquals(updatedClientConfig, connector.getConnectorClientConfig()); + Assert.assertEquals(updatedUrl, connector.getUrl()); + Assert.assertEquals(updatedHeaders, connector.getHeaders()); + + // Check encrypted credentials + Map currentCredential = connector.getCredential(); + Assert.assertNotNull(currentCredential); + Assert.assertEquals(1, currentCredential.size()); // Should replace old credentials + Assert.assertEquals("encrypted: new_value", currentCredential.get("new_key")); + Assert.assertNotEquals(initialCredential, currentCredential); + + // Check decrypted credentials and headers (need to explicitly decrypt after update) + connector.decrypt("", decryptFunction, null); // Use decrypt function from setUp + Map decryptedCredential = connector.getDecryptedCredential(); + Assert.assertNotNull(decryptedCredential); + Assert.assertEquals(1, decryptedCredential.size()); + Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedCredential.get("new_key")); // Uses the decrypt function logic + + Map decryptedHeaders = connector.getDecryptedHeaders(); + Assert.assertNotNull(decryptedHeaders); + Assert.assertEquals(2, decryptedHeaders.size()); + Assert.assertEquals("new_header_value", decryptedHeaders.get("new_header")); + Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedHeaders.get("updated_api_key")); // Check header substitution + } + + public static McpConnector createMcpConnector() { + Map credential = new HashMap<>(); + credential.put("key", "test_key_value"); + + Map headers = new HashMap<>(); + headers.put("api_key", "${credential.key}"); + + ConnectorClientConfig clientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT); + + return McpConnector + .builder() + .name("test_mcp_connector_name") + .version("1") + .description("this is a test mcp connector") + .protocol(MCP_SSE) + .credential(credential) + .backendRoles(List.of("role1", "role2")) + .accessMode(AccessMode.PUBLIC) + .connectorClientConfig(clientConfig) + .url("https://test.com") + .headers(headers) + .build(); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index e56cb71559..2b8487d1ec 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -7,6 +7,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE; @@ -44,17 +46,32 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.client.Client; + public class AgentUtilsTest { @Mock private Tool tool1, tool2; + @Mock + private MLAgent mlAgent; + @Mock + private Client client; + @Mock + private SdkClient sdkClient; + @Mock + private Encryptor encryptor; + private Map> llmResponseExpectedParseResults; private String responseForAction = "---------------------\n{\n " @@ -1152,6 +1169,16 @@ public void testParseLLMOutputWithDeepseekFormat() { Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response")); } + @Test + public void testGetMcpToolSpecs_NoMcpJsonConfig() { + when(mlAgent.getParameters()).thenReturn(null); + + ActionListener> listener = mock(ActionListener.class); + AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, listener); + + verify(listener).onResponse(Collections.emptyList()); + } + private void verifyConstructToolParams(String question, String actionInput, Consumer> verify) { Map tools = Map.of("tool1", tool1); Map toolSpecMap = Map diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java new file mode 100644 index 0000000000..66808cb9bd --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java @@ -0,0 +1,92 @@ +package org.opensearch.ml.engine.algorithms.remote; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.connector.McpConnector; +import org.opensearch.ml.engine.MLStaticMockBase; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; + +public class McpConnectorExecutorTest extends MLStaticMockBase { + + @Mock + private McpConnector mockConnector; + @Mock + private McpSyncClient mcpClient; + @Mock + private McpClient.SyncSpec builder; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + Map decryptedHeaders = Map.of("Authorization", "Bearer secret-token"); + + when(mockConnector.getUrl()).thenReturn("http://random-url"); + when(mockConnector.getDecryptedHeaders()).thenReturn(decryptedHeaders); + + /* ---------- stub the fluent builder chain ------------------------ */ + when(builder.requestTimeout(any())).thenReturn(builder); + when(builder.capabilities(any())).thenReturn(builder); + when(builder.build()).thenReturn(mcpClient); + } + + @Test + public void getMcpToolSpecs_returnsExpectedSpecs() { + + String inputSchemaJSON = + "{\"type\":\"object\",\"properties\":{\"state\":{\"title\":\"State\",\"type\":\"string\"}},\"required\":[\"state\"],\"additionalProperties\":false}"; + + McpSchema.Tool tool = new McpSchema.Tool("tool1", "desc1", inputSchemaJSON); + McpSchema.ListToolsResult mockTools = new McpSchema.ListToolsResult(List.of(tool), null); + + when(mcpClient.listTools()).thenReturn(mockTools); + when(mcpClient.initialize()).thenReturn(null); + + try (MockedStatic mocked = mockStatic(McpClient.class)) { + mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder); + McpConnectorExecutor exec = new McpConnectorExecutor(mockConnector); + List specs = exec.getMcpToolSpecs(); + + Assert.assertEquals(1, specs.size()); + MLToolSpec spec = specs.get(0); + Assert.assertEquals("tool1", spec.getName()); + Assert.assertEquals("desc1", spec.getDescription()); + Assert.assertEquals(inputSchemaJSON, spec.getAttributes().get("input_schema")); + Assert.assertSame(mcpClient, spec.getRuntimeResources().get("mcp_sync_client")); + mocked.verify(() -> McpClient.sync(any(McpClientTransport.class))); + verify(builder, times(1)).build(); + verify(mcpClient, times(1)).initialize(); + verify(mcpClient, times(1)).listTools(); + } + } + + @Test + public void getMcpToolSpecs_throwsOnInitError() { + + when(mcpClient.initialize()).thenThrow(new RuntimeException("Error initializing")); + try (MockedStatic mocked = mockStatic(McpClient.class)) { + mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder); + McpConnectorExecutor exec = new McpConnectorExecutor(mockConnector); + + assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs()); + } + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java new file mode 100644 index 0000000000..b1b76615e6 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java @@ -0,0 +1,114 @@ +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; + +public class McpSseToolTests { + + @Mock + private McpSyncClient mcpSyncClient; + + @Mock + private ActionListener listener; + + private Tool tool; + private Map validParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + // Initialize the tool with the mocked client + tool = McpSseTool.Factory.getInstance().create( + Map.of(MCP_SYNC_CLIENT, mcpSyncClient) + ); + validParams = Map.of("input", "{\"foo\":\"bar\"}"); + } + + @Test + public void testRunSuccess() { + // Arrange: create a CallToolResult wrapping a JSON string + McpSchema.CallToolResult result = new McpSchema.CallToolResult("{\"foo\":\"bar\"}", false); + when(mcpSyncClient.callTool(any(McpSchema.CallToolRequest.class))) + .thenReturn(result); + + // Act + tool.run(validParams, listener); + + // Assert: ensure onResponse is called with the JSON string + verify(listener).onResponse( + "[{\"text\":\"{\\\"foo\\\":\\\"bar\\\"}\"}]" + ); + verify(listener, never()).onFailure(any()); + } + + @Test + public void testRunInvalidJsonInput() { + // Passing a non-JSON string should trigger failure in parsing + Map badParams = Map.of("input", "not-json"); + tool.run(badParams, listener); + + verify(listener).onFailure(any(Exception.class)); + verify(listener, never()).onResponse(any()); + } + + @Test + public void testRunClientThrows() { + // Simulate the MCP client throwing an exception + when(mcpSyncClient.callTool(any())).thenThrow(new RuntimeException("client error")); + + tool.run(validParams, listener); + + verify(listener).onFailure(any(RuntimeException.class)); + verify(listener, never()).onResponse(any()); + } + + @Test + public void testRunMissingInputParam() { + // No "input" key in parameters should also be caught + tool.run(Collections.emptyMap(), listener); + + verify(listener).onFailure(any(Exception.class)); + verify(listener, never()).onResponse(any()); + } + + @Test + public void testValidateAndMetadata() { + // validate + assertTrue(tool.validate(validParams)); + assertFalse(tool.validate(Collections.emptyMap())); + // metadata + assertEquals(McpSseTool.TYPE, tool.getName()); + assertEquals(McpSseTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertEquals(McpSseTool.DEFAULT_DESCRIPTION, tool.getDescription()); + } + + @Test + public void testFactoryDefaults() { + McpSseTool.Factory factory = McpSseTool.Factory.getInstance(); + assertEquals(McpSseTool.DEFAULT_DESCRIPTION, factory.getDefaultDescription()); + assertEquals(McpSseTool.TYPE, factory.getDefaultType()); + assertNull(factory.getDefaultVersion()); + assertTrue(factory.getAllModelKeys().isEmpty()); + } +}