|
3 | 3 | # Copyright (c) 2025 Red Hat, Inc. |
4 | 4 | # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) |
5 | 5 |
|
6 | | -import json |
7 | | -import os |
8 | | -import select |
9 | | -import subprocess |
10 | | -import time |
11 | 6 |
|
12 | | -from abc import ABC, abstractmethod |
13 | | -from functools import wraps |
14 | | -from typing import Any, Callable, Dict, Optional, Union |
| 7 | +from typing import Any, Dict, Optional |
15 | 8 |
|
16 | | -from ansible.errors import AnsibleConnectionFailure |
17 | | -from ansible.module_utils.urls import open_url |
18 | | - |
19 | | - |
20 | | -class MCPError(Exception): |
21 | | - """Base exception class for MCP related errors. |
22 | | -
|
23 | | - This exception is raised when MCP operations fail, such as initialization, |
24 | | - tool listing, tool execution, or validation errors. |
25 | | - """ |
26 | | - |
27 | | - pass |
28 | | - |
29 | | - |
30 | | -class Transport(ABC): |
31 | | - @abstractmethod |
32 | | - def connect(self) -> None: |
33 | | - """Connect to the MCP server. |
34 | | -
|
35 | | - This is called before attempting to perform initialization. |
36 | | - """ |
37 | | - pass |
38 | | - |
39 | | - @abstractmethod |
40 | | - def notify(self, data: dict) -> None: |
41 | | - """Send a notification message to the server. |
42 | | -
|
43 | | - This sends a JSON-RPC payload to the server when no response is |
44 | | - expected. |
45 | | -
|
46 | | - Args: |
47 | | - data: JSON-RPC payload. |
48 | | - """ |
49 | | - pass |
50 | | - |
51 | | - @abstractmethod |
52 | | - def request(self, data: dict) -> dict: |
53 | | - """Send a request to the server. |
54 | | -
|
55 | | - This sends a JSON-RPC payload to the server when a response is expected. |
56 | | -
|
57 | | - Args: |
58 | | - data: JSON-RPC payload. |
59 | | - Returns: |
60 | | - The JSON-RPC response from the server. |
61 | | - """ |
62 | | - pass |
63 | | - |
64 | | - @abstractmethod |
65 | | - def close(self) -> None: |
66 | | - """Close the server connection. |
67 | | -
|
68 | | - This is called to perform any final actions to close and clean up the |
69 | | - connection. |
70 | | - """ |
71 | | - pass |
72 | | - |
73 | | - |
74 | | -class Stdio(Transport): |
75 | | - def __init__(self, cmd: Union[list[str], str], env: Optional[dict] = None): |
76 | | - """Initialize the stdio transport class. |
77 | | -
|
78 | | - Args: |
79 | | - cmd: Command used to run the MCP server. |
80 | | - env: Environment variables to set for the MCP server process. |
81 | | - """ |
82 | | - self._cmd = cmd |
83 | | - self._env = env |
84 | | - self._process: Optional[Any] = None |
85 | | - |
86 | | - def connect(self) -> None: |
87 | | - """Spawn a local MCP server subprocess.""" |
88 | | - params: dict[str, Any] = { |
89 | | - "stdin": subprocess.PIPE, |
90 | | - "stdout": subprocess.PIPE, |
91 | | - "stderr": subprocess.PIPE, |
92 | | - "text": True, |
93 | | - "bufsize": 0, # Unbuffered for real-time communication |
94 | | - } |
95 | | - |
96 | | - if self._env: |
97 | | - # Prepare environment for command |
98 | | - env: dict[str, Any] = os.environ.copy() |
99 | | - env.update(self._env) |
100 | | - params.update({"env": env}) |
101 | | - |
102 | | - try: |
103 | | - cmd = self._cmd |
104 | | - if isinstance(self._cmd, str): |
105 | | - cmd = [self._cmd] |
106 | | - self._process = subprocess.Popen(cmd, **params) |
107 | | - |
108 | | - # Give the server a moment to start |
109 | | - time.sleep(0.1) |
110 | | - |
111 | | - # Check if process started successfully |
112 | | - if self._process.poll() is not None: |
113 | | - try: |
114 | | - stdout, stderr = self._process.communicate(timeout=3) |
115 | | - except subprocess.TimeoutExpired: |
116 | | - stdout, stderr = "", "" |
117 | | - pass |
118 | | - raise AnsibleConnectionFailure( |
119 | | - f"MCP server exited immediately. stdout: {stdout}, stderr: {stderr}" |
120 | | - ) |
121 | | - except AnsibleConnectionFailure: |
122 | | - raise |
123 | | - except Exception as e: |
124 | | - raise AnsibleConnectionFailure(f"Failed to start MCP server: {str(e)}") |
125 | | - |
126 | | - def _stdout_read(self, wait_timeout: int = 5) -> dict: |
127 | | - """Read response from MCP server with timeout. |
128 | | -
|
129 | | - Args: |
130 | | - wait_timeout: The wait timeout value, default: 5. |
131 | | - Returns: |
132 | | - A JSON-RPC response dictionary from the MCP server. |
133 | | - """ |
134 | | - |
135 | | - response = {} |
136 | | - if self._process: |
137 | | - rfd, wfd, efd = select.select([self._process.stdout], [], [], wait_timeout) |
138 | | - if not (rfd or wfd or efd): |
139 | | - # Process has timeout |
140 | | - raise AnsibleConnectionFailure( |
141 | | - f"MCP server response timeout after {wait_timeout} seconds." |
142 | | - ) |
143 | | - |
144 | | - if self._process.stdout in rfd: |
145 | | - response = json.loads( |
146 | | - os.read(self._process.stdout.fileno(), 4096).decode("utf-8").strip() |
147 | | - ) |
148 | | - return response |
149 | | - |
150 | | - def _stdin_write(self, data: dict) -> None: |
151 | | - """Write data to process standard input. |
152 | | -
|
153 | | - Args: |
154 | | - data: JSON-RPC payload. |
155 | | - """ |
156 | | - data_json = json.dumps(data) + "\n" |
157 | | - if self._process is not None: |
158 | | - self._process.stdin.write(data_json) |
159 | | - self._process.stdin.flush() |
160 | | - |
161 | | - def _ensure_server_started(func: Callable): # type: ignore # see https://github.com/python/mypy/issues/7778 # pylint: disable=no-self-argument |
162 | | - """Decorator to ensure that the MCP server process is running before method execution.""" |
163 | | - |
164 | | - @wraps(func) |
165 | | - def wrapped(self, *args, **kwargs: dict[str, Any]): |
166 | | - if self._process is None: |
167 | | - raise AnsibleConnectionFailure("MCP server process not started.") |
168 | | - if self._process.poll() is not None: |
169 | | - stdout, stderr = self._process.communicate() |
170 | | - raise AnsibleConnectionFailure( |
171 | | - f"MCP server process terminated unexpectedly. stdout: {stdout}, stderr: {stderr}" |
172 | | - ) |
173 | | - return func(self, *args, **kwargs) |
174 | | - |
175 | | - return wrapped |
176 | | - |
177 | | - @_ensure_server_started |
178 | | - def notify(self, data: dict) -> None: |
179 | | - """Send a notification message to the server. |
180 | | -
|
181 | | - This sends a JSON-RPC payload to the server when no response is |
182 | | - expected. |
183 | | -
|
184 | | - Args: |
185 | | - data: JSON-RPC payload. |
186 | | - """ |
187 | | - try: |
188 | | - self._stdin_write(data) |
189 | | - except Exception as e: |
190 | | - raise AnsibleConnectionFailure(f"Error sending notification to MCP server: {str(e)}") |
191 | | - |
192 | | - @_ensure_server_started |
193 | | - def request(self, data: dict) -> dict: |
194 | | - """Send a request to the server. |
195 | | -
|
196 | | - This sends a JSON-RPC payload to the server when a response is expected. |
197 | | -
|
198 | | - Args: |
199 | | - data: JSON-RPC payload. |
200 | | - Returns: |
201 | | - The JSON-RPC response from the server. |
202 | | - """ |
203 | | - try: |
204 | | - # Send request to the server |
205 | | - self._stdin_write(data) |
206 | | - # Read response |
207 | | - return self._stdout_read() |
208 | | - except Exception as e: |
209 | | - raise AnsibleConnectionFailure(f"Error sending request to MCP server: {str(e)}") |
210 | | - |
211 | | - def close(self) -> None: |
212 | | - """Close the server connection.""" |
213 | | - if self._process: |
214 | | - try: |
215 | | - # Try to terminate gracefully first |
216 | | - self._process.terminate() |
217 | | - |
218 | | - # Wait for process to terminate |
219 | | - self._process.wait(timeout=5) |
220 | | - except subprocess.TimeoutExpired: |
221 | | - # Force kill if it doesn't terminate gracefully |
222 | | - self._process.kill() |
223 | | - self._process.wait() |
224 | | - except Exception as e: |
225 | | - raise AnsibleConnectionFailure(f"Error closing MCP process: {str(e)}") |
226 | | - finally: |
227 | | - self._process = None |
228 | | - |
229 | | - |
230 | | -class StreamableHTTP(Transport): |
231 | | - def __init__(self, url: str, headers: Optional[dict] = None, validate_certs: bool = True): |
232 | | - """Initialize the StreamableHTTP transport. |
233 | | -
|
234 | | - Args: |
235 | | - url: The MCP server URL endpoint |
236 | | - headers: Optional headers to include with requests |
237 | | - validate_certs: Whether to validate SSL certificates (default: True) |
238 | | - """ |
239 | | - self.url = url |
240 | | - self._headers: Dict[str, str] = headers.copy() if headers else {} |
241 | | - self.validate_certs = validate_certs |
242 | | - self._session_id = None |
243 | | - |
244 | | - def connect(self) -> None: |
245 | | - """Connect to the MCP server. |
246 | | -
|
247 | | - For HTTP transport, this is a no-op as connection is established |
248 | | - per-request. |
249 | | - """ |
250 | | - pass |
251 | | - |
252 | | - def notify(self, data: dict) -> None: |
253 | | - """Send a notification message to the server. |
254 | | -
|
255 | | - Args: |
256 | | - data: JSON-RPC payload. |
257 | | - """ |
258 | | - headers = self._build_headers() |
259 | | - |
260 | | - try: |
261 | | - response = open_url( |
262 | | - self.url, |
263 | | - method="POST", |
264 | | - data=json.dumps(data), |
265 | | - headers=headers, |
266 | | - validate_certs=self.validate_certs, |
267 | | - ) |
268 | | - |
269 | | - if response.getcode() != 202: |
270 | | - raise Exception(f"Unexpected response code: {response.getcode()}") |
271 | | - |
272 | | - self._extract_session_id(response) |
273 | | - |
274 | | - except Exception as e: |
275 | | - raise Exception(f"Failed to send notification: {str(e)}") |
276 | | - |
277 | | - def request(self, data: dict) -> dict: |
278 | | - """Send a request to the server. |
279 | | -
|
280 | | - Args: |
281 | | - data: JSON-RPC payload. |
282 | | -
|
283 | | - Returns: |
284 | | - The JSON-RPC response from the server. |
285 | | - """ |
286 | | - headers = self._build_headers() |
287 | | - |
288 | | - try: |
289 | | - response = open_url( |
290 | | - self.url, |
291 | | - method="POST", |
292 | | - data=json.dumps(data), |
293 | | - headers=headers, |
294 | | - validate_certs=self.validate_certs, |
295 | | - ) |
296 | | - |
297 | | - if response.getcode() != 200: |
298 | | - raise Exception(f"Unexpected response code: {response.getcode()}") |
299 | | - |
300 | | - self._extract_session_id(response) |
301 | | - |
302 | | - response_data = response.read() |
303 | | - |
304 | | - # Parse JSON response |
305 | | - try: |
306 | | - return json.loads(response_data.decode("utf-8")) |
307 | | - except json.JSONDecodeError as e: |
308 | | - raise Exception(f"Invalid JSON response: {str(e)}") |
309 | | - |
310 | | - except Exception as e: |
311 | | - raise Exception(f"Failed to send request: {str(e)}") |
312 | | - |
313 | | - def close(self) -> None: |
314 | | - """Close the server connection. |
315 | | -
|
316 | | - For HTTP transport, this is a no-op as connections are not persistent. |
317 | | - """ |
318 | | - pass |
319 | | - |
320 | | - def _build_headers(self) -> dict: |
321 | | - """Build headers for HTTP requests. |
322 | | -
|
323 | | - Returns: |
324 | | - Dictionary of headers to include in the request. |
325 | | - """ |
326 | | - headers = { |
327 | | - "Content-Type": "application/json", |
328 | | - "Accept": "application/json, text/event-stream", |
329 | | - "MCP-Protocol-Version": "2025-06-18", |
330 | | - } |
331 | | - |
332 | | - # Add custom headers |
333 | | - headers.update(self._headers) |
334 | | - |
335 | | - # Add session ID if available |
336 | | - if self._session_id: |
337 | | - headers["Mcp-Session-Id"] = self._session_id |
338 | | - |
339 | | - return headers |
340 | | - |
341 | | - def _extract_session_id(self, response) -> None: |
342 | | - """Extract session ID from response headers. |
343 | | -
|
344 | | - Args: |
345 | | - response: The HTTP response object |
346 | | - """ |
347 | | - # Check for Mcp-Session-Id header in response |
348 | | - session_header = response.headers.get("Mcp-Session-Id") |
349 | | - if session_header is not None: |
350 | | - self._session_id = session_header |
| 9 | +from ansible_collections.ansible.mcp.plugins.plugin_utils.errors import MCPError |
| 10 | +from ansible_collections.ansible.mcp.plugins.plugin_utils.transport import Transport |
351 | 11 |
|
352 | 12 |
|
353 | 13 | class MCPClient: |
|
0 commit comments