|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +# Copyright (c) 2025 Red Hat, Inc. |
| 4 | +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) |
| 5 | + |
1 | 6 | import json |
2 | 7 | import os |
3 | 8 | import select |
|
12 | 17 | from ansible.module_utils.urls import open_url |
13 | 18 |
|
14 | 19 |
|
| 20 | +class MCPError(Exception): |
| 21 | + """Base exception class for MCP related errors. |
| 22 | +
|
| 23 | + This exception is raised when MCP operations fail, such as initialization, |
| 24 | + tool listing, tool execution, or validation errors. |
| 25 | + """ |
| 26 | + |
| 27 | + pass |
| 28 | + |
| 29 | + |
15 | 30 | class Transport(ABC): |
16 | 31 | @abstractmethod |
17 | 32 | def connect(self) -> None: |
@@ -333,3 +348,351 @@ def _extract_session_id(self, response) -> None: |
333 | 348 | session_header = response.headers.get("Mcp-Session-Id") |
334 | 349 | if session_header is not None: |
335 | 350 | self._session_id = session_header |
| 351 | + |
| 352 | + |
| 353 | +class MCPClient: |
| 354 | + """Client for communicating with MCP (Model Context Protocol) servers. |
| 355 | +
|
| 356 | + Attributes: |
| 357 | + transport: The transport layer for communication with the server |
| 358 | + """ |
| 359 | + |
| 360 | + def __init__(self, transport: Transport) -> None: |
| 361 | + """Initialize the MCP client. |
| 362 | +
|
| 363 | + Args: |
| 364 | + transport: Transport implementation for server communication |
| 365 | + """ |
| 366 | + self.transport = transport |
| 367 | + self._connected = False |
| 368 | + self._server_info: Optional[Dict[str, Any]] = None |
| 369 | + self._tools_cache: Optional[Dict[str, Any]] = None |
| 370 | + self._request_id = 0 |
| 371 | + |
| 372 | + def _get_next_id(self) -> int: |
| 373 | + """Generate the next request ID. |
| 374 | +
|
| 375 | + Returns: |
| 376 | + Unique request ID |
| 377 | + """ |
| 378 | + self._request_id += 1 |
| 379 | + return self._request_id |
| 380 | + |
| 381 | + def _build_request( |
| 382 | + self, method: str, params: Optional[Dict[str, Any]] = None |
| 383 | + ) -> Dict[str, Any]: |
| 384 | + """Compose a JSON-RPC 2.0 request for MCP. |
| 385 | +
|
| 386 | + Args: |
| 387 | + method: The JSON-RPC method name |
| 388 | + params: Optional parameters for the request |
| 389 | +
|
| 390 | + Returns: |
| 391 | + Dictionary containing the JSON-RPC request |
| 392 | + """ |
| 393 | + request = { |
| 394 | + "jsonrpc": "2.0", |
| 395 | + "id": self._get_next_id(), |
| 396 | + "method": method, |
| 397 | + } |
| 398 | + if params is not None: |
| 399 | + request["params"] = params |
| 400 | + return request |
| 401 | + |
| 402 | + def _handle_response(self, response: Dict[str, Any], operation: str) -> Dict[str, Any]: |
| 403 | + """Handle JSON-RPC response and extract result or raise appropriate error. |
| 404 | +
|
| 405 | + Args: |
| 406 | + response: JSON-RPC response from server |
| 407 | + operation: Description of the operation being performed (for error messages) |
| 408 | +
|
| 409 | + Returns: |
| 410 | + The result from the response |
| 411 | +
|
| 412 | + Raises: |
| 413 | + MCPError: If the response contains an error |
| 414 | + """ |
| 415 | + if "result" in response: |
| 416 | + return response["result"] |
| 417 | + else: |
| 418 | + raise MCPError( |
| 419 | + f"Failed to {operation}: {response.get('error', f'Error in {operation}')}" |
| 420 | + ) |
| 421 | + |
| 422 | + def initialize(self) -> None: |
| 423 | + """Initialize the connection to the MCP server. |
| 424 | +
|
| 425 | + Raises: |
| 426 | + MCPError: If initialization fails |
| 427 | + """ |
| 428 | + if not self._connected: |
| 429 | + self.transport.connect() |
| 430 | + |
| 431 | + # Send initialize request |
| 432 | + init_request = self._build_request( |
| 433 | + "initialize", |
| 434 | + { |
| 435 | + "protocolVersion": "2025-03-26", |
| 436 | + "capabilities": { |
| 437 | + "roots": {"listChanged": True}, |
| 438 | + "sampling": {}, |
| 439 | + }, |
| 440 | + "clientInfo": { |
| 441 | + "name": "ansible-mcp-client", |
| 442 | + "version": "1.0.0", |
| 443 | + }, |
| 444 | + }, |
| 445 | + ) |
| 446 | + |
| 447 | + response = self.transport.request(init_request) |
| 448 | + |
| 449 | + # Cache server info from response |
| 450 | + self._server_info = self._handle_response(response, "initialize") |
| 451 | + |
| 452 | + # Send initialized notification |
| 453 | + initialized_notification = { |
| 454 | + "jsonrpc": "2.0", |
| 455 | + "method": "notifications/initialized", |
| 456 | + } |
| 457 | + self.transport.notify(initialized_notification) |
| 458 | + |
| 459 | + # Mark as connected only after successful initialization |
| 460 | + self._connected = True |
| 461 | + |
| 462 | + def list_tools(self) -> Dict[str, Any]: |
| 463 | + """List all available tools from the MCP server. |
| 464 | +
|
| 465 | + Returns: |
| 466 | + Dictionary containing the tools list response |
| 467 | +
|
| 468 | + Raises: |
| 469 | + MCPError: If the request fails |
| 470 | + """ |
| 471 | + if not self._connected or self._server_info is None: |
| 472 | + raise MCPError("Client not initialized. Call initialize() first.") |
| 473 | + |
| 474 | + # Return cached result if available |
| 475 | + if self._tools_cache is not None: |
| 476 | + return self._tools_cache |
| 477 | + |
| 478 | + # Make request to server |
| 479 | + request = self._build_request("tools/list") |
| 480 | + |
| 481 | + response = self.transport.request(request) |
| 482 | + |
| 483 | + self._tools_cache = self._handle_response(response, "list tools") |
| 484 | + return self._tools_cache |
| 485 | + |
| 486 | + def get_tool(self, tool: str) -> Dict[str, Any]: |
| 487 | + """Get the definition of a specific tool. |
| 488 | +
|
| 489 | + Args: |
| 490 | + tool: Name of the tool to retrieve |
| 491 | +
|
| 492 | + Returns: |
| 493 | + Dictionary containing the tool definition |
| 494 | +
|
| 495 | + Raises: |
| 496 | + MCPError: If client is not initialized or if the tool is not found |
| 497 | + """ |
| 498 | + if not self._connected or self._server_info is None: |
| 499 | + raise MCPError("Client not initialized. Call initialize() first.") |
| 500 | + |
| 501 | + tools_response = self.list_tools() |
| 502 | + tools = tools_response.get("tools", []) |
| 503 | + |
| 504 | + for tool_def in tools: |
| 505 | + if tool_def.get("name") == tool: |
| 506 | + return tool_def |
| 507 | + |
| 508 | + raise MCPError(f"Tool '{tool}' not found") |
| 509 | + |
| 510 | + def call_tool(self, tool: str, **kwargs: Any) -> Dict[str, Any]: |
| 511 | + """Call a tool on the MCP server with the provided arguments. |
| 512 | +
|
| 513 | + Args: |
| 514 | + tool: Name of the tool to call |
| 515 | + **kwargs: Arguments to pass to the tool |
| 516 | +
|
| 517 | + Returns: |
| 518 | + Dictionary containing the tool call response |
| 519 | +
|
| 520 | + Raises: |
| 521 | + ValueError: If validation fails |
| 522 | + MCPError: If the tool call fails |
| 523 | + """ |
| 524 | + if not self._connected or self._server_info is None: |
| 525 | + raise MCPError("Client not initialized. Call initialize() first.") |
| 526 | + |
| 527 | + # Validate parameters before making the request |
| 528 | + self.validate(tool, **kwargs) |
| 529 | + |
| 530 | + request = self._build_request( |
| 531 | + "tools/call", |
| 532 | + { |
| 533 | + "name": tool, |
| 534 | + "arguments": kwargs, |
| 535 | + }, |
| 536 | + ) |
| 537 | + |
| 538 | + response = self.transport.request(request) |
| 539 | + |
| 540 | + return self._handle_response(response, f"call tool '{tool}'") |
| 541 | + |
| 542 | + @property |
| 543 | + def server_info(self) -> Dict[str, Any]: |
| 544 | + """Return cached server information from initialization. |
| 545 | +
|
| 546 | + Returns: |
| 547 | + Dictionary containing server information |
| 548 | +
|
| 549 | + Raises: |
| 550 | + MCPError: If initialize() has not been called yet |
| 551 | + """ |
| 552 | + if self._server_info is None: |
| 553 | + raise MCPError("Client not initialized. Call initialize() first.") |
| 554 | + return self._server_info |
| 555 | + |
| 556 | + def _validate_schema_type(self, tool: str, schema: Dict[str, Any]) -> None: |
| 557 | + """Validate that the schema type is supported. |
| 558 | +
|
| 559 | + Args: |
| 560 | + tool: Name of the tool being validated |
| 561 | + schema: The input schema from the tool definition |
| 562 | +
|
| 563 | + Raises: |
| 564 | + ValueError: If the schema type is not supported |
| 565 | + """ |
| 566 | + schema_type = schema.get("type") |
| 567 | + if schema_type and schema_type != "object": |
| 568 | + raise ValueError( |
| 569 | + f"Tool '{tool}' has unsupported schema type '{schema_type}', expected 'object'" |
| 570 | + ) |
| 571 | + |
| 572 | + def _validate_required_parameters( |
| 573 | + self, tool: str, kwargs: Dict[str, Any], required_parameters: list |
| 574 | + ) -> None: |
| 575 | + """Validate that all required parameters are provided. |
| 576 | +
|
| 577 | + Args: |
| 578 | + tool: Name of the tool being validated |
| 579 | + kwargs: Arguments provided to the tool |
| 580 | + required_parameters: List of required parameter names |
| 581 | +
|
| 582 | + Raises: |
| 583 | + ValueError: If required parameters are missing |
| 584 | + """ |
| 585 | + missing_required = [param for param in required_parameters if param not in kwargs] |
| 586 | + if missing_required: |
| 587 | + raise ValueError( |
| 588 | + f"Tool '{tool}' missing required parameters: {', '.join(missing_required)}" |
| 589 | + ) |
| 590 | + |
| 591 | + def _validate_unknown_parameters( |
| 592 | + self, tool: str, kwargs: Dict[str, Any], schema_properties: Dict[str, Any] |
| 593 | + ) -> None: |
| 594 | + """Validate that no unknown parameters are provided. |
| 595 | +
|
| 596 | + Args: |
| 597 | + tool: Name of the tool being validated |
| 598 | + kwargs: Arguments provided to the tool |
| 599 | + schema_properties: Properties defined in the schema |
| 600 | +
|
| 601 | + Raises: |
| 602 | + ValueError: If unknown parameters are provided |
| 603 | + """ |
| 604 | + if schema_properties: |
| 605 | + unknown_parameters = [param for param in kwargs if param not in schema_properties] |
| 606 | + if unknown_parameters: |
| 607 | + raise ValueError( |
| 608 | + f"Tool '{tool}' received unknown parameters: {', '.join(unknown_parameters)}" |
| 609 | + ) |
| 610 | + |
| 611 | + def _validate_parameter_type( |
| 612 | + self, tool: str, parameter_name: str, parameter_value: Any, parameter_schema: Dict[str, Any] |
| 613 | + ) -> None: |
| 614 | + """Validate that a parameter value matches its expected type. |
| 615 | +
|
| 616 | + Args: |
| 617 | + tool: Name of the tool being validated |
| 618 | + parameter_name: Name of the parameter being validated |
| 619 | + parameter_value: Value of the parameter |
| 620 | + parameter_schema: Schema definition for the parameter |
| 621 | +
|
| 622 | + Raises: |
| 623 | + ValueError: If the parameter type is invalid |
| 624 | + """ |
| 625 | + parameter_type_in_schema = parameter_schema.get("type") |
| 626 | + if not parameter_type_in_schema: |
| 627 | + return |
| 628 | + |
| 629 | + # Handle None values first |
| 630 | + if parameter_value is None: |
| 631 | + if parameter_type_in_schema != "null": |
| 632 | + raise ValueError( |
| 633 | + f"Parameter '{parameter_name}' for tool '{tool}' cannot be None (expected type '{parameter_type_in_schema}')" |
| 634 | + ) |
| 635 | + return |
| 636 | + |
| 637 | + # Map JSON Schema types to their corresponding Python types |
| 638 | + schema_type_to_python_type = { |
| 639 | + "string": str, |
| 640 | + "number": (int, float), |
| 641 | + "integer": int, |
| 642 | + "boolean": bool, |
| 643 | + "array": list, |
| 644 | + "object": dict, |
| 645 | + "null": type(None), |
| 646 | + } |
| 647 | + |
| 648 | + expected_type = schema_type_to_python_type.get(parameter_type_in_schema) |
| 649 | + if expected_type is None: |
| 650 | + raise ValueError( |
| 651 | + f"Tool '{tool}' has unsupported parameter type '{parameter_type_in_schema}' for parameter '{parameter_name}'" |
| 652 | + ) |
| 653 | + |
| 654 | + if not isinstance(parameter_value, expected_type): # type: ignore[arg-type] |
| 655 | + raise ValueError( |
| 656 | + f"Parameter '{parameter_name}' for tool '{tool}' should be of type " |
| 657 | + f"'{parameter_type_in_schema}', but got '{type(parameter_value).__name__}'" |
| 658 | + ) |
| 659 | + |
| 660 | + def validate(self, tool: str, **kwargs: Any) -> None: |
| 661 | + """Validate that a tool call arguments match the tool's schema. |
| 662 | +
|
| 663 | + Args: |
| 664 | + tool: Name of the tool to validate |
| 665 | + **kwargs: Arguments to validate against the tool schema |
| 666 | +
|
| 667 | + Raises: |
| 668 | + MCPError: If the tool is not found |
| 669 | + ValueError: If validation fails (missing required parameters, etc.) |
| 670 | + """ |
| 671 | + # Get tool definition and schema |
| 672 | + tool_definition = self.get_tool(tool) |
| 673 | + schema = tool_definition.get("inputSchema", {}) |
| 674 | + |
| 675 | + # Extract schema components |
| 676 | + parameters_from_schema_properties = schema.get("properties", {}) |
| 677 | + required_parameters = schema.get("required", []) |
| 678 | + |
| 679 | + # Perform validation |
| 680 | + self._validate_schema_type(tool, schema) |
| 681 | + self._validate_required_parameters(tool, kwargs, required_parameters) |
| 682 | + self._validate_unknown_parameters(tool, kwargs, parameters_from_schema_properties) |
| 683 | + |
| 684 | + # Validate parameter types |
| 685 | + for parameter_name, parameter_value in kwargs.items(): |
| 686 | + if parameter_name in parameters_from_schema_properties: |
| 687 | + parameter_schema = parameters_from_schema_properties[parameter_name] |
| 688 | + self._validate_parameter_type( |
| 689 | + tool, parameter_name, parameter_value, parameter_schema |
| 690 | + ) |
| 691 | + |
| 692 | + def close(self) -> None: |
| 693 | + """Close the connection to the MCP server.""" |
| 694 | + self.transport.close() |
| 695 | + self._connected = False |
| 696 | + self._server_info = None |
| 697 | + self._tools_cache = None |
| 698 | + self._request_id = 0 |
0 commit comments