Skip to content

Commit 8f39ed6

Browse files
mandar242gravesm
andauthored
Add MCPClient class (#7)
* add MCPClient class * black linter fix * add unit tests * minor fix * minor fix * modified based on feedback * modified based on feedback * sanity fix * minor fix * restructure validate, other minox fixes * move tests to unit/plugins/module_utils * linter fix * Restrict sanity tests to >=3.6 --------- Co-authored-by: Mike Graves <[email protected]>
1 parent 51eefed commit 8f39ed6

File tree

3 files changed

+622
-0
lines changed

3 files changed

+622
-0
lines changed

plugins/plugin_utils/mcp.py

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright (c) 2025 Red Hat, Inc.
4+
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
5+
16
import json
27
import os
38
import select
@@ -12,6 +17,16 @@
1217
from ansible.module_utils.urls import open_url
1318

1419

20+
class MCPError(Exception):
21+
"""Base exception class for MCP related errors.
22+
23+
This exception is raised when MCP operations fail, such as initialization,
24+
tool listing, tool execution, or validation errors.
25+
"""
26+
27+
pass
28+
29+
1530
class Transport(ABC):
1631
@abstractmethod
1732
def connect(self) -> None:
@@ -333,3 +348,351 @@ def _extract_session_id(self, response) -> None:
333348
session_header = response.headers.get("Mcp-Session-Id")
334349
if session_header is not None:
335350
self._session_id = session_header
351+
352+
353+
class MCPClient:
354+
"""Client for communicating with MCP (Model Context Protocol) servers.
355+
356+
Attributes:
357+
transport: The transport layer for communication with the server
358+
"""
359+
360+
def __init__(self, transport: Transport) -> None:
361+
"""Initialize the MCP client.
362+
363+
Args:
364+
transport: Transport implementation for server communication
365+
"""
366+
self.transport = transport
367+
self._connected = False
368+
self._server_info: Optional[Dict[str, Any]] = None
369+
self._tools_cache: Optional[Dict[str, Any]] = None
370+
self._request_id = 0
371+
372+
def _get_next_id(self) -> int:
373+
"""Generate the next request ID.
374+
375+
Returns:
376+
Unique request ID
377+
"""
378+
self._request_id += 1
379+
return self._request_id
380+
381+
def _build_request(
382+
self, method: str, params: Optional[Dict[str, Any]] = None
383+
) -> Dict[str, Any]:
384+
"""Compose a JSON-RPC 2.0 request for MCP.
385+
386+
Args:
387+
method: The JSON-RPC method name
388+
params: Optional parameters for the request
389+
390+
Returns:
391+
Dictionary containing the JSON-RPC request
392+
"""
393+
request = {
394+
"jsonrpc": "2.0",
395+
"id": self._get_next_id(),
396+
"method": method,
397+
}
398+
if params is not None:
399+
request["params"] = params
400+
return request
401+
402+
def _handle_response(self, response: Dict[str, Any], operation: str) -> Dict[str, Any]:
403+
"""Handle JSON-RPC response and extract result or raise appropriate error.
404+
405+
Args:
406+
response: JSON-RPC response from server
407+
operation: Description of the operation being performed (for error messages)
408+
409+
Returns:
410+
The result from the response
411+
412+
Raises:
413+
MCPError: If the response contains an error
414+
"""
415+
if "result" in response:
416+
return response["result"]
417+
else:
418+
raise MCPError(
419+
f"Failed to {operation}: {response.get('error', f'Error in {operation}')}"
420+
)
421+
422+
def initialize(self) -> None:
423+
"""Initialize the connection to the MCP server.
424+
425+
Raises:
426+
MCPError: If initialization fails
427+
"""
428+
if not self._connected:
429+
self.transport.connect()
430+
431+
# Send initialize request
432+
init_request = self._build_request(
433+
"initialize",
434+
{
435+
"protocolVersion": "2025-03-26",
436+
"capabilities": {
437+
"roots": {"listChanged": True},
438+
"sampling": {},
439+
},
440+
"clientInfo": {
441+
"name": "ansible-mcp-client",
442+
"version": "1.0.0",
443+
},
444+
},
445+
)
446+
447+
response = self.transport.request(init_request)
448+
449+
# Cache server info from response
450+
self._server_info = self._handle_response(response, "initialize")
451+
452+
# Send initialized notification
453+
initialized_notification = {
454+
"jsonrpc": "2.0",
455+
"method": "notifications/initialized",
456+
}
457+
self.transport.notify(initialized_notification)
458+
459+
# Mark as connected only after successful initialization
460+
self._connected = True
461+
462+
def list_tools(self) -> Dict[str, Any]:
463+
"""List all available tools from the MCP server.
464+
465+
Returns:
466+
Dictionary containing the tools list response
467+
468+
Raises:
469+
MCPError: If the request fails
470+
"""
471+
if not self._connected or self._server_info is None:
472+
raise MCPError("Client not initialized. Call initialize() first.")
473+
474+
# Return cached result if available
475+
if self._tools_cache is not None:
476+
return self._tools_cache
477+
478+
# Make request to server
479+
request = self._build_request("tools/list")
480+
481+
response = self.transport.request(request)
482+
483+
self._tools_cache = self._handle_response(response, "list tools")
484+
return self._tools_cache
485+
486+
def get_tool(self, tool: str) -> Dict[str, Any]:
487+
"""Get the definition of a specific tool.
488+
489+
Args:
490+
tool: Name of the tool to retrieve
491+
492+
Returns:
493+
Dictionary containing the tool definition
494+
495+
Raises:
496+
MCPError: If client is not initialized or if the tool is not found
497+
"""
498+
if not self._connected or self._server_info is None:
499+
raise MCPError("Client not initialized. Call initialize() first.")
500+
501+
tools_response = self.list_tools()
502+
tools = tools_response.get("tools", [])
503+
504+
for tool_def in tools:
505+
if tool_def.get("name") == tool:
506+
return tool_def
507+
508+
raise MCPError(f"Tool '{tool}' not found")
509+
510+
def call_tool(self, tool: str, **kwargs: Any) -> Dict[str, Any]:
511+
"""Call a tool on the MCP server with the provided arguments.
512+
513+
Args:
514+
tool: Name of the tool to call
515+
**kwargs: Arguments to pass to the tool
516+
517+
Returns:
518+
Dictionary containing the tool call response
519+
520+
Raises:
521+
ValueError: If validation fails
522+
MCPError: If the tool call fails
523+
"""
524+
if not self._connected or self._server_info is None:
525+
raise MCPError("Client not initialized. Call initialize() first.")
526+
527+
# Validate parameters before making the request
528+
self.validate(tool, **kwargs)
529+
530+
request = self._build_request(
531+
"tools/call",
532+
{
533+
"name": tool,
534+
"arguments": kwargs,
535+
},
536+
)
537+
538+
response = self.transport.request(request)
539+
540+
return self._handle_response(response, f"call tool '{tool}'")
541+
542+
@property
543+
def server_info(self) -> Dict[str, Any]:
544+
"""Return cached server information from initialization.
545+
546+
Returns:
547+
Dictionary containing server information
548+
549+
Raises:
550+
MCPError: If initialize() has not been called yet
551+
"""
552+
if self._server_info is None:
553+
raise MCPError("Client not initialized. Call initialize() first.")
554+
return self._server_info
555+
556+
def _validate_schema_type(self, tool: str, schema: Dict[str, Any]) -> None:
557+
"""Validate that the schema type is supported.
558+
559+
Args:
560+
tool: Name of the tool being validated
561+
schema: The input schema from the tool definition
562+
563+
Raises:
564+
ValueError: If the schema type is not supported
565+
"""
566+
schema_type = schema.get("type")
567+
if schema_type and schema_type != "object":
568+
raise ValueError(
569+
f"Tool '{tool}' has unsupported schema type '{schema_type}', expected 'object'"
570+
)
571+
572+
def _validate_required_parameters(
573+
self, tool: str, kwargs: Dict[str, Any], required_parameters: list
574+
) -> None:
575+
"""Validate that all required parameters are provided.
576+
577+
Args:
578+
tool: Name of the tool being validated
579+
kwargs: Arguments provided to the tool
580+
required_parameters: List of required parameter names
581+
582+
Raises:
583+
ValueError: If required parameters are missing
584+
"""
585+
missing_required = [param for param in required_parameters if param not in kwargs]
586+
if missing_required:
587+
raise ValueError(
588+
f"Tool '{tool}' missing required parameters: {', '.join(missing_required)}"
589+
)
590+
591+
def _validate_unknown_parameters(
592+
self, tool: str, kwargs: Dict[str, Any], schema_properties: Dict[str, Any]
593+
) -> None:
594+
"""Validate that no unknown parameters are provided.
595+
596+
Args:
597+
tool: Name of the tool being validated
598+
kwargs: Arguments provided to the tool
599+
schema_properties: Properties defined in the schema
600+
601+
Raises:
602+
ValueError: If unknown parameters are provided
603+
"""
604+
if schema_properties:
605+
unknown_parameters = [param for param in kwargs if param not in schema_properties]
606+
if unknown_parameters:
607+
raise ValueError(
608+
f"Tool '{tool}' received unknown parameters: {', '.join(unknown_parameters)}"
609+
)
610+
611+
def _validate_parameter_type(
612+
self, tool: str, parameter_name: str, parameter_value: Any, parameter_schema: Dict[str, Any]
613+
) -> None:
614+
"""Validate that a parameter value matches its expected type.
615+
616+
Args:
617+
tool: Name of the tool being validated
618+
parameter_name: Name of the parameter being validated
619+
parameter_value: Value of the parameter
620+
parameter_schema: Schema definition for the parameter
621+
622+
Raises:
623+
ValueError: If the parameter type is invalid
624+
"""
625+
parameter_type_in_schema = parameter_schema.get("type")
626+
if not parameter_type_in_schema:
627+
return
628+
629+
# Handle None values first
630+
if parameter_value is None:
631+
if parameter_type_in_schema != "null":
632+
raise ValueError(
633+
f"Parameter '{parameter_name}' for tool '{tool}' cannot be None (expected type '{parameter_type_in_schema}')"
634+
)
635+
return
636+
637+
# Map JSON Schema types to their corresponding Python types
638+
schema_type_to_python_type = {
639+
"string": str,
640+
"number": (int, float),
641+
"integer": int,
642+
"boolean": bool,
643+
"array": list,
644+
"object": dict,
645+
"null": type(None),
646+
}
647+
648+
expected_type = schema_type_to_python_type.get(parameter_type_in_schema)
649+
if expected_type is None:
650+
raise ValueError(
651+
f"Tool '{tool}' has unsupported parameter type '{parameter_type_in_schema}' for parameter '{parameter_name}'"
652+
)
653+
654+
if not isinstance(parameter_value, expected_type): # type: ignore[arg-type]
655+
raise ValueError(
656+
f"Parameter '{parameter_name}' for tool '{tool}' should be of type "
657+
f"'{parameter_type_in_schema}', but got '{type(parameter_value).__name__}'"
658+
)
659+
660+
def validate(self, tool: str, **kwargs: Any) -> None:
661+
"""Validate that a tool call arguments match the tool's schema.
662+
663+
Args:
664+
tool: Name of the tool to validate
665+
**kwargs: Arguments to validate against the tool schema
666+
667+
Raises:
668+
MCPError: If the tool is not found
669+
ValueError: If validation fails (missing required parameters, etc.)
670+
"""
671+
# Get tool definition and schema
672+
tool_definition = self.get_tool(tool)
673+
schema = tool_definition.get("inputSchema", {})
674+
675+
# Extract schema components
676+
parameters_from_schema_properties = schema.get("properties", {})
677+
required_parameters = schema.get("required", [])
678+
679+
# Perform validation
680+
self._validate_schema_type(tool, schema)
681+
self._validate_required_parameters(tool, kwargs, required_parameters)
682+
self._validate_unknown_parameters(tool, kwargs, parameters_from_schema_properties)
683+
684+
# Validate parameter types
685+
for parameter_name, parameter_value in kwargs.items():
686+
if parameter_name in parameters_from_schema_properties:
687+
parameter_schema = parameters_from_schema_properties[parameter_name]
688+
self._validate_parameter_type(
689+
tool, parameter_name, parameter_value, parameter_schema
690+
)
691+
692+
def close(self) -> None:
693+
"""Close the connection to the MCP server."""
694+
self.transport.close()
695+
self._connected = False
696+
self._server_info = None
697+
self._tools_cache = None
698+
self._request_id = 0

tests/config.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
modules:
2+
python_requires: ">=3.6"

0 commit comments

Comments
 (0)