diff --git a/plugins/connection/mcp.py b/plugins/connection/mcp.py new file mode 100644 index 0000000..5419e8e --- /dev/null +++ b/plugins/connection/mcp.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2025 Red Hat, Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +DOCUMENTATION = r""" +--- +name: mcp +author: + - Alina Buzachis (@alinabuzachis) +version_added: 1.0.0 +short_description: Persistent connection to an Model Context Protocol (MCP) server +description: + - This connection plugin allows for a persistent connection to an Model Context Protocol (MCP) server. + - It is designed to run once per host for the duration of a playbook, allowing tasks to communicate with a single, long-lived server session. + - Both stdio and Streamable HTTP transport methods are supported. + - All tasks using this connection plugin are run on the Ansible control node. +options: + server_name: + description: + - The name of the MCP server. + type: str + required: true + vars: + - name: ansible_mcp_server_name + server_args: + description: + - Additional command line arguments to pass to the server when using stdio transport. + type: list + elements: str + vars: + - name: ansible_mcp_server_args + env: + - name: MCP_BEARER_TOKEN + server_env: + description: + - Additional environment variables to pass to the server when using stdio transport. + - These are merged with the current environment. + - Ignored when using http transport. + type: dict + vars: + - name: ansible_mcp_server_env + bearer_token: + description: + - Bearer token for authenticating to the MCP server when using http transport. + - Ignored when using stdio transport. + type: str + vars: + - name: ansible_mcp_bearer_token + manifest_path: + description: + - Path to MCP manifest JSON file to resolve server executable paths for stdio. + type: str + default: "/opt/mcp/mcpservers.json" + vars: + - name: ansible_mcp_manifest_path + validate_certs: + description: + - Whether to validate SSL certificates when using http transport. + type: bool + default: true + vars: + - name: ansible_mcp_validate_certs + persistent_connect_timeout: + description: + - Timeout in seconds for initial connection to persistent transport. + type: int + default: 30 + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + description: + - Timeout for persistent connection commands in seconds. + type: int + default: 30 + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_log_messages: + description: + - Enable logging of messages from persistent connection. + - Be sure to fully understand the security implications of enabling this + option as it could create a security vulnerability by logging sensitive information in log file. + type: boolean + default: False + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" + + +import json +import os +import time + +from functools import wraps +from typing import Any, Dict + +from ansible.errors import AnsibleConnectionFailure +from ansible.utils.display import Display +from ansible_collections.ansible.utils.plugins.plugin_utils.connection_base import ( + PersistentConnectionBase, +) + +from ansible_collections.ansible.mcp.plugins.plugin_utils.client import MCPClient +from ansible_collections.ansible.mcp.plugins.plugin_utils.transport import ( + Stdio, + StreamableHTTP, + Transport, +) + + +display = Display() + + +def ensure_connected(func): + """Decorator ensuring that a connection is established before a method runs.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + # Check the connection status + if not self.connected: + display.vvv( + f"MCP connection not established. Calling _connect() for method: {func.__name__}" + ) + # If not connected, establish the connection + try: + self._connect() + except Exception as e: + raise AnsibleConnectionFailure(f"Failed to connect to MCP server: {e}") + # Call the original method + return func(self, *args, **kwargs) + + return wrapper + + +class Connection(PersistentConnectionBase): + """ + Ansible persistent connection plugin for the Model Context Protocol (MCP) server. + """ + + transport = "ansible.mcp.mcp" + has_pipelining = False + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) + self._client = None + self._connected = False + + @property + def connected(self) -> bool: + """Return True if connected to MCP server.""" + return not self._conn_closed and self._connected and self._client is not None + + def _connect(self): + """ + Establishes the connection and performs the MCP initialization handshake. + This runs only once per host/plugin instance. + """ + if self.connected: + display.vvv("[mcp] Already connected, skipping _connect()") + return + + server_name = self.get_option("server_name") + manifest_path = self.get_option("manifest_path") or "/opt/mcp/mcpservers.json" + + server_info = self._load_server_from_manifest(server_name, manifest_path) + transport = self._create_transport(server_name, server_info) + + # Initialize MCP client + self._client = MCPClient(transport) + + timeout = self.get_option("persistent_connect_timeout") + start_time = time.time() + while True: + try: + self._client.initialize() + break + except Exception as e: + if time.time() - start_time > timeout: + raise AnsibleConnectionFailure( + f"MCP connection timed out after {timeout}s: {e}" + ) + time.sleep(1) + + self._connected = True + display.vvv(f"[mcp] Connection to '{server_name}' successfully initialized") + + def _load_server_from_manifest(self, server_name: str, manifest_path: str) -> dict: + """Load the MCP server info from manifest JSON.""" + if not os.path.exists(manifest_path): + raise AnsibleConnectionFailure(f"MCP manifest not found at {manifest_path}") + + try: + with open(manifest_path, "r", encoding="utf-8") as f: + manifest = json.load(f) + except json.JSONDecodeError as e: + raise AnsibleConnectionFailure(f"[mcp] Failed to parse MCP manifest JSON: {e}") + + if server_name not in manifest: + raise AnsibleConnectionFailure(f"MCP server '{server_name}' not found in manifest") + + return manifest[server_name] + + def _create_transport(self, server_name: str, server_info: dict) -> Transport: + """Create the appropriate transport based on manifest server info.""" + transport_type = server_info.get("type") + + if transport_type == "stdio": + if "command" not in server_info: + raise AnsibleConnectionFailure( + f"[mcp] Manifest for '{server_name}' missing 'command' for stdio transport" + ) + manifest_args = server_info.get("args", []) + plugin_args = self.get_option("server_args") or [] + cmd = [server_info["command"]] + manifest_args + plugin_args + env = self.get_option("server_env") or {} + display.vvv(f"[mcp] Starting stdio MCP server '{server_name}': {' '.join(cmd)}") + return Stdio(cmd=cmd, env=env) + + elif transport_type == "http": + url = server_info.get("url") + + if not url: + raise AnsibleConnectionFailure( + f"[mcp] Manifest for '{server_name}' missing 'url' for http transport" + ) + + headers = {} + token = self.get_option("bearer_token") + if token: + headers["Authorization"] = f"Bearer {token}" + display.vvv(f"[mcp] Connecting to HTTP MCP server '{server_name}': {url}") + return StreamableHTTP( + url=url, headers=headers, validate_certs=self.get_option("validate_certs") + ) + + else: + raise AnsibleConnectionFailure( + f"Invalid transport type '{transport_type}' for server '{server_name}'" + ) + + def close(self) -> None: + """Terminate the persistent connection and reset state.""" + display.vvv("[mcp] Closing MCP connection") + + self._close_client() + super().close() # sets _conn_closed, _connected + + def _close_client(self) -> None: + """Close the MCPClient if it exists and reset the reference.""" + if not self._client: + display.vvv("[mcp] No MCP client to close") + return + + try: + self._client.close() + display.vvv("[mcp] MCP client successfully closed") + except Exception as e: + display.warning(f"[mcp] Error closing MCP client: {e}") + finally: + self._client = None + + @ensure_connected + def list_tools(self) -> Dict[str, Any]: + """Retrieves the list of tools from the MCP server.""" + return self._client.list_tools() + + @ensure_connected + def call_tool(self, tool: str, **kwargs: Any) -> Dict[str, Any]: + """Calls a specific tool on the MCP server.""" + return self._client.call_tool(tool, **kwargs) + + @ensure_connected + def validate(self, tool: str, **kwargs: Any) -> None: + """Validates arguments against a tool's schema (client-side validation).""" + return self._client.validate(tool, **kwargs) + + @ensure_connected + def server_info(self) -> Dict[str, Any]: + """Returns the cached server information from the initialization step.""" + return self._client.server_info diff --git a/tests/unit/plugins/connection/test_connection_mcp.py b/tests/unit/plugins/connection/test_connection_mcp.py new file mode 100644 index 0000000..cc19baf --- /dev/null +++ b/tests/unit/plugins/connection/test_connection_mcp.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 Red Hat, Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import json + +from io import StringIO +from unittest.mock import MagicMock, patch + +import pytest + +from ansible.errors import AnsibleConnectionFailure +from ansible.playbook.play_context import PlayContext + +from ansible_collections.ansible.mcp.plugins.connection.mcp import Connection + + +@pytest.fixture +def manifest_file(tmp_path): + """Create a temporary MCP manifest JSON file.""" + manifest_data = { + "mcp-hello-world": { + "type": "stdio", + "command": "npx --prefix /opt/mcp/npm_installs mcp-hello-world", + "args": [], + }, + "aws-iam-mcp-server": { + "type": "stdio", + "command": "uvx awslabs.iam-mcp-server", + "args": [], + "package": "awslabs.iam-mcp-server", + }, + "github-mcp-server": { + "type": "stdio", + "command": "/opt/mcp/bin/github-mcp-server", + "args": ["stdio"], + "description": "GitHub MCP Server - Access GitHub repositories, issues, and pull requests", + }, + "remote": {"args": [], "type": "http", "url": "https://example.com/mcp"}, + } + + file_path = tmp_path / "mcpservers.json" + with open(file_path, "w", encoding="utf-8") as f: + json.dump(manifest_data, f) + yield file_path + + +@pytest.fixture +def empty_manifest_file(tmp_path): + """Create a temporary empty MCP manifest JSON file.""" + file_path = tmp_path / "empty.json" + with open(file_path, "w", encoding="utf-8") as f: + json.dump({}, f) + yield file_path + + +@pytest.fixture +def malformed_manifest_file(tmp_path): + """Create a temporary malformed MCP manifest JSON file.""" + file_path = tmp_path / "malformed.json" + with open(file_path, "w", encoding="utf-8") as f: + f.write("{invalid json") + yield file_path + + +@pytest.fixture(name="loaded_mcp_connection") +def fixture_loaded_mcp_connection(manifest_file): + """ + Return a Connection instance with test options set. + Network/stdio/http calls are mocked in the tests. + """ + play_context = PlayContext() + conn = Connection(play_context, StringIO()) + + def get_option(key): + return conn.test_options.get(key) + + # Provide a get_option helper + conn.test_options = { + "server_name": "remote", + "server_args": ["--mock"], + "server_env": {"FOO": "BAR"}, + "bearer_token": "token123", + "validate_certs": True, + "manifest_path": str(manifest_file), + "persistent_connect_timeout": 15, + "persistent_command_timeout": 15, + "persistent_log_messages": False, + } + conn.get_option = get_option + + yield conn + + +class TestMCPConnection: + def test_load_server_from_manifest_success(self, loaded_mcp_connection, manifest_file): + """Should successfully load server info for a known server.""" + server_name = "github-mcp-server" + expected_info = { + "type": "stdio", + "command": "/opt/mcp/bin/github-mcp-server", + "args": ["stdio"], + "description": "GitHub MCP Server - Access GitHub repositories, issues, and pull requests", + } + + info = loaded_mcp_connection._load_server_from_manifest(server_name, str(manifest_file)) + assert info == expected_info + + def test_load_server_from_manifest_file_not_found(self, loaded_mcp_connection): + """Should raise AnsibleConnectionFailure if manifest file is not found.""" + with pytest.raises(AnsibleConnectionFailure, match="MCP manifest not found"): + loaded_mcp_connection._load_server_from_manifest( + "any-server", "/nonexistent/manifest.json" + ) + + def test_load_server_from_manifest_server_not_found(self, loaded_mcp_connection, manifest_file): + """Should raise AnsibleConnectionFailure if the server is not in the manifest.""" + server_name = "non-existent-server" + with pytest.raises( + AnsibleConnectionFailure, match=f"MCP server '{server_name}' not found in manifest" + ): + loaded_mcp_connection._load_server_from_manifest(server_name, str(manifest_file)) + + def test_create_transport_stdio_missing_command(self, loaded_mcp_connection): + """Should raise AnsibleConnectionFailure if stdio manifest entry is missing 'command'.""" + server_name = "invalid_stdio" + server_info = {"type": "stdio"} + with pytest.raises( + AnsibleConnectionFailure, match=f"Manifest for '{server_name}' missing 'command'" + ): + loaded_mcp_connection._create_transport(server_name, server_info) + + @patch("ansible_collections.ansible.mcp.plugins.connection.mcp.StreamableHTTP", autospec=True) + def test_create_transport_http_success_no_token(self, mock_http, loaded_mcp_connection): + """Should correctly create an HTTP transport without a bearer token.""" + server_name = "remote" + server_info = {"type": "http", "url": "https://example.com/mcp"} + + loaded_mcp_connection.test_options["bearer_token"] = None # No token + + loaded_mcp_connection._create_transport(server_name, server_info) + + mock_http.assert_called_once_with( + url="https://example.com/mcp", + headers={}, + validate_certs=True, + ) + + def test_load_server_from_manifest_json_decode_error( + self, loaded_mcp_connection, malformed_manifest_file + ): + """Should raise AnsibleConnectionFailure for a malformed JSON file.""" + with pytest.raises(AnsibleConnectionFailure, match="Failed to parse MCP manifest JSON"): + loaded_mcp_connection._load_server_from_manifest( + "any-server", str(malformed_manifest_file) + ) + + @patch("ansible_collections.ansible.mcp.plugins.connection.mcp.Stdio", autospec=True) + def test_create_transport_stdio_success(self, mock_stdio, loaded_mcp_connection): + """Should correctly create a Stdio transport for a stdio server.""" + server_name = "github-mcp-server" + server_info = { + "type": "stdio", + "command": "/opt/mcp/bin/github-mcp-server", + "args": ["stdio"], + } + + loaded_mcp_connection.test_options["server_args"] = ["--verbose"] + loaded_mcp_connection.test_options["server_env"] = {"DEBUG": "1"} + + loaded_mcp_connection._create_transport(server_name, server_info) + + expected_cmd = [ + "/opt/mcp/bin/github-mcp-server", + "stdio", + "--verbose", + ] + mock_stdio.assert_called_once_with( + cmd=expected_cmd, + env={"DEBUG": "1"}, + ) + + @patch("ansible_collections.ansible.mcp.plugins.connection.mcp.StreamableHTTP", autospec=True) + def test_create_transport_http_success_with_token(self, mock_http, loaded_mcp_connection): + """Should correctly create an HTTP transport with a bearer token and validation.""" + server_name = "remote" + server_info = {"type": "http", "url": "https://example.com/mcp"} + + loaded_mcp_connection.test_options["bearer_token"] = "test-token" + loaded_mcp_connection.test_options["validate_certs"] = False + + loaded_mcp_connection._create_transport(server_name, server_info) + + mock_http.assert_called_once_with( + url="https://example.com/mcp", + headers={"Authorization": "Bearer test-token"}, + validate_certs=False, + ) + + def test_create_transport_http_missing_url(self, loaded_mcp_connection): + """Should raise AnsibleConnectionFailure if http manifest entry is missing 'url'.""" + server_name = "invalid_http" + server_info = {"type": "http"} + with pytest.raises( + AnsibleConnectionFailure, match=f"Manifest for '{server_name}' missing 'url'" + ): + loaded_mcp_connection._create_transport(server_name, server_info) + + def test_create_transport_unknown_transport_type(self, loaded_mcp_connection): + """Should raise AnsibleConnectionFailure for an unknown transport type.""" + server_name = "unknown_transport" + server_info = {"type": "ftp"} + with pytest.raises( + AnsibleConnectionFailure, + match=f"Invalid transport type 'ftp' for server '{server_name}'", + ): + loaded_mcp_connection._create_transport(server_name, server_info) + + @patch( + "ansible_collections.ansible.mcp.plugins.connection.mcp.MCPClient.initialize", + return_value=None, + ) + @patch( + "ansible_collections.ansible.mcp.plugins.connection.mcp.Stdio", + autospec=True, + ) + def test_connect_stdio_transport( + self, mock_stdio, mock_initialize, loaded_mcp_connection, manifest_file + ): + """Verify connection._connect() initializes stdio transport correctly.""" + conn = loaded_mcp_connection + conn.test_options["server_name"] = "mcp-hello-world" + + mock_transport = MagicMock() + mock_stdio.return_value = mock_transport + + conn._connect() + + mock_stdio.assert_called_once() + mock_initialize.assert_called_once() + assert conn._connected is True + assert conn._client is not None + + @patch("ansible_collections.ansible.mcp.plugins.connection.mcp.StreamableHTTP", autospec=True) + def test_connect_http_transport(self, mock_http, loaded_mcp_connection): + """Verify connection uses HTTP transport when configured.""" + mock_transport = MagicMock() + mock_http.return_value = mock_transport + # Mock request for initialize + mock_transport.request.return_value = {"result": {"server": "ok"}} + + loaded_mcp_connection._connect() + + mock_http.assert_called_once_with( + url="https://example.com/mcp", + headers={"Authorization": "Bearer token123"}, + validate_certs=True, + ) + assert loaded_mcp_connection._connected is True + assert loaded_mcp_connection._client is not None + + def test_connect_invalid_transport(self, loaded_mcp_connection): + """Invalid transport type should raise.""" + """Unknown server_name should raise AnsibleConnectionFailure.""" + loaded_mcp_connection.test_options["server_name"] = "unknown-server" + with pytest.raises(AnsibleConnectionFailure): + loaded_mcp_connection._connect() + + def test_list_tools_delegates_to_client(self, loaded_mcp_connection): + """list_tools should call MCPClient.list_tools().""" + loaded_mcp_connection._connect = MagicMock(name="_connect") + mock_client = MagicMock() + loaded_mcp_connection._client = mock_client + mock_client.list_tools.return_value = {"tools": []} + + result = loaded_mcp_connection.list_tools() + mock_client.list_tools.assert_called_once() + assert result == {"tools": []} + + def test_close_resets_state(self, loaded_mcp_connection): + """close() should reset client and connection state.""" + loaded_mcp_connection._connect = MagicMock(name="_connect") + mock_client = MagicMock() + loaded_mcp_connection._client = mock_client + client_ref = loaded_mcp_connection._client + + loaded_mcp_connection.close() + + client_ref.close.assert_called_once() + assert loaded_mcp_connection._connected is False + assert loaded_mcp_connection._client is None diff --git a/tests/unit/plugins/plugin_utils/test_client.py b/tests/unit/plugins/plugin_utils/test_client.py index 4249500..1c5a3e5 100644 --- a/tests/unit/plugins/plugin_utils/test_client.py +++ b/tests/unit/plugins/plugin_utils/test_client.py @@ -3,6 +3,12 @@ # Copyright (c) 2025 Red Hat, Inc. # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import absolute_import, division, print_function + + +__metaclass__ = type + + import pytest from ansible_collections.ansible.mcp.plugins.plugin_utils.client import MCPClient, MCPError diff --git a/tox.ini b/tox.ini index cc6e848..8568e96 100644 --- a/tox.ini +++ b/tox.ini @@ -28,9 +28,10 @@ allowlist_externals = mkdir ln rm + ansible-galaxy commands = rm -rf {toxinidir}/.collection_root - mkdir -p {toxinidir}/.collection_root/ansible_collections/ansible + ansible-galaxy collection install git+https://github.com/ansible-collections/ansible.utils.git --collections-path {toxinidir}/.collection_root ln -s {toxinidir} {toxinidir}/.collection_root/ansible_collections/ansible/mcp mypy commands_post =