diff --git a/.env.example b/.env.example index ac938a702..857e74eba 100644 --- a/.env.example +++ b/.env.example @@ -115,6 +115,15 @@ MCPGATEWAY_UI_ENABLED=true # Enable the Admin API endpoints (true/false) MCPGATEWAY_ADMIN_API_ENABLED=true +##################################### +# Header Passthrough Configuration +##################################### + +# Default headers to pass through from client requests to backing MCP servers +# Comma-separated list or JSON array format +# Example: ["Authorization", "X-Tenant-Id", "X-Trace-Id"] +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id"] + ##################################### # Security and CORS ##################################### diff --git a/.gitignore b/.gitignore index 79d578025..12c5ec6be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +todo/ *.sarif devskim-results.sarif debug_login_page.png diff --git a/docs/docs/overview/passthrough.md b/docs/docs/overview/passthrough.md new file mode 100644 index 000000000..5454efed4 --- /dev/null +++ b/docs/docs/overview/passthrough.md @@ -0,0 +1,291 @@ +# HTTP Header Passthrough + +The MCP Gateway supports **HTTP Header Passthrough**, allowing specific headers from incoming client requests to be forwarded to backing MCP servers. This feature is essential for maintaining authentication context and request tracing across the gateway infrastructure. + +## Overview + +When clients make requests through the MCP Gateway, certain headers (like authentication tokens or trace IDs) need to be preserved and passed to the underlying MCP servers. The header passthrough feature provides a configurable, secure way to forward these headers while preventing conflicts with existing authentication mechanisms. + +## Key Features + +- **Global Configuration**: Set default passthrough headers for all gateways +- **Per-Gateway Override**: Customize header passthrough on a per-gateway basis +- **Conflict Prevention**: Automatically prevents overriding existing authentication headers +- **Admin UI Integration**: Configure passthrough headers through the web interface +- **API Management**: Programmatic control via REST endpoints + +## Configuration + +### Environment Variables + +Set global default headers using the `DEFAULT_PASSTHROUGH_HEADERS` environment variable: + +```bash +# JSON array format +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id"] + +# Or in .env file +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id"] +``` + +### Admin UI Configuration + +#### Global Configuration +Access the admin interface to set global passthrough headers that apply to all gateways by default. + +#### Per-Gateway Configuration +When creating or editing gateways: + +1. Navigate to the **Gateways** section in the admin UI +2. Click **Add Gateway** or edit an existing gateway +3. In the **Passthrough Headers** field, enter a comma-separated list: + ``` + Authorization, X-Tenant-Id, X-Trace-Id + ``` +4. Gateway-specific headers override global defaults + +### API Configuration + +#### Get Global Configuration +```bash +GET /admin/config/passthrough-headers +``` + +Response: +```json +{ + "passthrough_headers": ["Authorization", "X-Tenant-Id", "X-Trace-Id"] +} +``` + +#### Update Global Configuration +```bash +PUT /admin/config/passthrough-headers +Content-Type: application/json + +{ + "passthrough_headers": ["Authorization", "X-Custom-Header"] +} +``` + +## How It Works + +### Header Processing Flow + +1. **Client Request**: Client sends request with various headers +2. **Header Extraction**: Gateway extracts headers configured for passthrough +3. **Conflict Check**: System verifies no conflicts with existing auth headers +4. **Forwarding**: Allowed headers are added to requests sent to backing MCP servers + +### Configuration Hierarchy + +The system follows this priority order: + +1. **Gateway-specific headers** (highest priority) +2. **Global configuration** (from database) +3. **Environment variable defaults** (lowest priority) + +### Example Flow + +```mermaid +graph LR + A[Client Request] --> B[MCP Gateway] + B --> C{Check Passthrough Config} + C --> D[Extract Configured Headers] + D --> E[Conflict Prevention Check] + E --> F[Forward to MCP Server] + + G[Global Config] --> C + H[Gateway Config] --> C +``` + +## Security Considerations + +### Conflict Prevention + +The system automatically prevents header conflicts: + +- **Basic Auth**: Skips `Authorization` header if gateway uses basic authentication +- **Bearer Auth**: Skips `Authorization` header if gateway uses bearer token authentication +- **Warnings**: Logs warnings when headers are skipped due to conflicts + +### Header Validation + +- Headers are validated before forwarding +- Empty or invalid headers are filtered out +- Only explicitly configured headers are passed through + +## Use Cases + +### Authentication Context +Forward authentication tokens to maintain user context: +```bash +# Client request includes +Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9... + +# Forwarded to MCP server if configured +``` + +### Request Tracing +Maintain trace context across service boundaries: +```bash +# Client request includes +X-Trace-Id: abc123def456 +X-Span-Id: span789 + +# Both forwarded to enable distributed tracing +``` + +### Multi-Tenant Systems +Pass tenant identification: +```bash +# Client request includes +X-Tenant-Id: tenant_12345 +X-Organization: acme_corp + +# Forwarded for tenant-specific processing +``` + +## Configuration Examples + +### Basic Setup +```bash +# .env file +DEFAULT_PASSTHROUGH_HEADERS=["Authorization"] +``` + +### Multi-Header Configuration +```bash +# .env file with multiple headers +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id", "X-Request-Id"] +``` + +### Gateway-Specific Override +```json +// Via Admin API for specific gateway +{ + "name": "secure-gateway", + "url": "https://secure-mcp-server.example.com", + "passthrough_headers": ["X-API-Key", "X-Client-Id"] +} +``` + +## Troubleshooting + +### Common Issues + +#### Headers Not Being Forwarded +- Verify header names in configuration match exactly (case-sensitive) +- Check for authentication conflicts in logs +- Ensure gateway configuration overrides aren't blocking headers + +#### Authentication Conflicts +If you see warnings like: +``` +Skipping passthrough header 'Authorization' - conflicts with existing basic auth +``` + +**Solution**: Either: +1. Remove `Authorization` from passthrough headers for that gateway +2. Change the gateway to not use basic/bearer authentication +3. Use a different header name for custom auth tokens + +#### Configuration Not Taking Effect +- Restart the gateway after environment variable changes +- Verify database migration has been applied +- Check admin API responses to confirm configuration is saved + +### Debug Logging + +Enable debug logging to see header processing: +```bash +LOG_LEVEL=DEBUG +``` + +Look for log entries containing: +- `Passthrough headers configured` +- `Skipping passthrough header` +- `Adding passthrough header` + +## API Reference + +### Data Models + +#### GlobalConfig +```python +class GlobalConfig(Base): + id: int + passthrough_headers: Optional[List[str]] +``` + +#### Gateway +```python +class Gateway(Base): + # ... other fields + passthrough_headers: Optional[List[str]] +``` + +### Admin Endpoints + +| Method | Endpoint | Description | +|--------|----------|-------------| +| GET | `/admin/config/passthrough-headers` | Get global configuration | +| PUT | `/admin/config/passthrough-headers` | Update global configuration | +| POST | `/admin/gateways` | Create gateway with headers | +| PUT | `/admin/gateways/{id}` | Update gateway headers | + +## Best Practices + +1. **Minimal Headers**: Only configure headers you actually need to reduce overhead +2. **Security Review**: Regularly audit which headers are being passed through +3. **Environment Consistency**: Use consistent header configuration across environments +4. **Documentation**: Document which headers your MCP servers expect +5. **Monitoring**: Monitor logs for conflict warnings and adjust configuration accordingly + +## Migration Notes + +When upgrading to a version with header passthrough: + +1. **Database Migration**: Ensure the migration `3b17fdc40a8d` has been applied +2. **Configuration Review**: Review existing authentication setup for conflicts +3. **Testing**: Test header forwarding in development before production deployment +4. **Monitoring**: Monitor logs for any unexpected behavior after deployment + +## Testing with the Built-in Test Tool + +The MCP Gateway admin interface includes a built-in test tool with passthrough header support: + +### Using the Test Tool + +1. **Access the Admin UI**: Navigate to the **Tools** section +2. **Select a Tool**: Click the **Test** button on any available tool +3. **Configure Headers**: In the test modal, scroll to the **Passthrough Headers** section +4. **Add Headers**: Enter headers in the format `Header-Name: Value` (one per line): + ``` + Authorization: Bearer your-token-here + X-Tenant-Id: tenant-123 + X-Trace-Id: abc-def-456 + ``` +5. **Run Test**: Click **Run Tool** - the headers will be included in the request + +### Example Test Scenarios + +**Authentication Testing**: +``` +Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9... +``` + +**Multi-Tenant Testing**: +``` +X-Tenant-Id: acme-corp +X-Organization-Id: org-12345 +``` + +**Distributed Tracing**: +``` +X-Trace-Id: trace-abc123 +X-Span-Id: span-def456 +X-Request-Id: req-789xyz +``` + +The test tool provides immediate feedback and allows you to verify that your passthrough header configuration is working correctly before deploying to production. diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 9c33cc4f5..bee408403 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -34,13 +34,15 @@ # First-Party from mcpgateway.config import settings -from mcpgateway.db import get_db +from mcpgateway.db import get_db, GlobalConfig from mcpgateway.schemas import ( GatewayCreate, GatewayRead, GatewayTestRequest, GatewayTestResponse, GatewayUpdate, + GlobalConfigRead, + GlobalConfigUpdate, PromptCreate, PromptMetrics, PromptRead, @@ -88,6 +90,52 @@ #################### +@admin_router.get("/config/passthrough-headers", response_model=GlobalConfigRead) +async def get_global_passthrough_headers( + db: Session = Depends(get_db), + _user: str = Depends(require_auth), +) -> GlobalConfigRead: + """Get the global passthrough headers configuration. + + Args: + db: Database session + _user: Authenticated user + + Returns: + GlobalConfigRead: The current global passthrough headers configuration + """ + config = db.query(GlobalConfig).first() + if not config: + config = GlobalConfig() + return GlobalConfigRead(passthrough_headers=config.passthrough_headers) + + +@admin_router.put("/config/passthrough-headers", response_model=GlobalConfigRead) +async def update_global_passthrough_headers( + config_update: GlobalConfigUpdate, + db: Session = Depends(get_db), + _user: str = Depends(require_auth), +) -> GlobalConfigRead: + """Update the global passthrough headers configuration. + + Args: + config_update: The new configuration + db: Database session + _user: Authenticated user + + Returns: + GlobalConfigRead: The updated configuration + """ + config = db.query(GlobalConfig).first() + if not config: + config = GlobalConfig(passthrough_headers=config_update.passthrough_headers) + db.add(config) + else: + config.passthrough_headers = config_update.passthrough_headers + db.commit() + return GlobalConfigRead(passthrough_headers=config.passthrough_headers) + + @admin_router.get("/servers", response_model=List[ServerRead]) async def admin_list_servers( include_inactive: bool = False, diff --git a/mcpgateway/alembic/versions/3b17fdc40a8d_add_passthrough_headers_to_gateways_and_.py b/mcpgateway/alembic/versions/3b17fdc40a8d_add_passthrough_headers_to_gateways_and_.py new file mode 100644 index 000000000..c461cfee5 --- /dev/null +++ b/mcpgateway/alembic/versions/3b17fdc40a8d_add_passthrough_headers_to_gateways_and_.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +"""Add passthrough headers to gateways and global config + +Revision ID: 3b17fdc40a8d +Revises: e75490e949b1 +Create Date: 2025-08-08 03:45:46.489696 + +""" + +# Standard +from typing import Sequence, Union + +# Third-Party +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "3b17fdc40a8d" +down_revision: Union[str, Sequence[str], None] = "e75490e949b1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Create global_config table + op.create_table("global_config", sa.Column("id", sa.Integer(), nullable=False), sa.Column("passthrough_headers", sa.JSON(), nullable=True), sa.PrimaryKeyConstraint("id")) + + # Add passthrough_headers column to gateways table + op.add_column("gateways", sa.Column("passthrough_headers", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + """Downgrade schema.""" + # Remove passthrough_headers column from gateways table + op.drop_column("gateways", "passthrough_headers") + + # Drop global_config table + op.drop_table("global_config") diff --git a/mcpgateway/alembic/versions/eb17fd368f9d_merge_passthrough_headers_and_tags_.py b/mcpgateway/alembic/versions/eb17fd368f9d_merge_passthrough_headers_and_tags_.py new file mode 100644 index 000000000..071d9c5bc --- /dev/null +++ b/mcpgateway/alembic/versions/eb17fd368f9d_merge_passthrough_headers_and_tags_.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +"""merge passthrough headers and tags support + +Revision ID: eb17fd368f9d +Revises: 3b17fdc40a8d, cc7b95fec5d9 +Create Date: 2025-08-08 05:31:10.857718 + +""" + +# Standard +from typing import Sequence, Union + +# revision identifiers, used by Alembic. +revision: str = "eb17fd368f9d" +down_revision: Union[str, Sequence[str], None] = ("3b17fdc40a8d", "cc7b95fec5d9") +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + + +def downgrade() -> None: + """Downgrade schema.""" diff --git a/mcpgateway/config.py b/mcpgateway/config.py index db80a501e..c58506a13 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -52,6 +52,7 @@ from importlib.resources import files import json import logging +import os from pathlib import Path import re from typing import Annotated, Any, ClassVar, Dict, List, Optional, Set, Union @@ -548,9 +549,27 @@ def validate_database(self) -> None: # Rate limiting validation_max_requests_per_minute: int = 60 + # passthrough headers + default_passthrough_headers: Any = os.environ.get("DEFAULT_PASSTHROUGH_HEADERS", ["Authorization", "X-Tenant-Id", "X-Trace-Id"]) + if not isinstance(default_passthrough_headers, list): + try: + default_passthrough_headers = list(default_passthrough_headers) + except Exception as e: + logger.warning(f"Invalid DEFAULT_PASSTHROUGH_HEADERS format in .env. Must be a list of header names, e.g. ['Authorization', 'X-Tenant-Id'], error: {e}") + default_passthrough_headers = [] + # Masking value for all sensitive data masked_auth_value: str = "*****" + # passthrough headers + default_passthrough_headers: Any = os.environ.get("DEFAULT_PASSTHROUGH_HEADERS", ["Authorization", "X-Tenant-Id", "X-Trace-Id"]) + if not isinstance(default_passthrough_headers, list): + try: + default_passthrough_headers = list(default_passthrough_headers) + except Exception as e: + logger.warning(f"Invalid DEFAULT_PASSTHROUGH_HEADERS format in .env. Must be a list of header names, e.g. ['Authorization', 'X-Tenant-Id'], error: {e}") + default_passthrough_headers = [] + def extract_using_jq(data, jq_filter=""): """ diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 3f49ad464..b205da631 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -28,24 +28,7 @@ # Third-Party import jsonschema -from sqlalchemy import ( - Boolean, - Column, - create_engine, - DateTime, - event, - Float, - ForeignKey, - func, - Integer, - JSON, - make_url, - select, - String, - Table, - Text, - UniqueConstraint, -) +from sqlalchemy import Boolean, Column, create_engine, DateTime, event, Float, ForeignKey, func, Integer, JSON, make_url, select, String, Table, Text, UniqueConstraint from sqlalchemy.event import listen from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.hybrid import hybrid_property @@ -200,6 +183,20 @@ class Base(DeclarativeBase): ) +class GlobalConfig(Base): + """Global configuration settings. + + Attributes: + id (int): Primary key + passthrough_headers (List[str]): List of headers allowed to be passed through globally + """ + + __tablename__ = "global_config" + + id = Column(Integer, primary_key=True) + passthrough_headers: Mapped[Optional[List[str]]] = Column(JSON, nullable=True) # Store list of strings as JSON array + + class ToolMetric(Base): """ ORM model for recording individual metrics for tool executions. @@ -1117,6 +1114,12 @@ class Gateway(Base): last_seen: Mapped[Optional[datetime]] tags: Mapped[List[str]] = mapped_column(JSON, default=list, nullable=False) + # Header passthrough configuration + passthrough_headers: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) # Store list of strings as JSON array + + # Header passthrough configuration + passthrough_headers: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) # Store list of strings as JSON array + # Relationship with local tools this gateway provides tools: Mapped[List["Tool"]] = relationship(back_populates="gateway", foreign_keys="Tool.gateway_id", cascade="all, delete-orphan") diff --git a/mcpgateway/federation/forward.py b/mcpgateway/federation/forward.py index 5ff5f83dc..159ef52d6 100644 --- a/mcpgateway/federation/forward.py +++ b/mcpgateway/federation/forward.py @@ -35,6 +35,7 @@ from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import Tool as DbTool from mcpgateway.models import ToolResult +from mcpgateway.utils.passthrough_headers import get_passthrough_headers logger = logging.getLogger(__name__) @@ -122,6 +123,7 @@ async def forward_request( method: str, params: Optional[Dict[str, Any]] = None, target_gateway_id: Optional[int] = None, + request_headers: Optional[Dict[str, str]] = None, ) -> Any: """Forward a request to gateway(s). @@ -134,6 +136,8 @@ async def forward_request( method: RPC method name (e.g., "tools/list", "resources/read") params: Optional method parameters as key-value pairs target_gateway_id: Optional specific gateway ID for targeted forwarding + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: Any: Single gateway response for targeted requests (when target_gateway_id @@ -173,15 +177,15 @@ async def forward_request( try: if target_gateway_id: # Forward to specific gateway - return await self._forward_to_gateway(db, target_gateway_id, method, params) + return await self._forward_to_gateway(db, target_gateway_id, method, params, request_headers) - # Forward to all relevant gateways - return await self._forward_to_all(db, method, params) + # Forward to all relevant gateways - headers are passed to each gateway + return await self._forward_to_all(db, method, params, request_headers) except Exception as e: raise ForwardingError(f"Forward request failed: {str(e)}") - async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any]) -> ToolResult: + async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any], request_headers: Optional[Dict[str, str]] = None) -> ToolResult: """Forward a tool invocation request. Locates the specified tool in the database, verifies it's federated, @@ -192,6 +196,8 @@ async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dic db: Database session for tool and gateway lookups tool_name: Name of the tool to invoke arguments: Tool arguments as key-value pairs + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: ToolResult object containing the tool execution results @@ -252,12 +258,7 @@ async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dic raise ForwardingError(f"Tool {tool_name} is not federated") # Forward to gateway - result = await self._forward_to_gateway( - db, - tool.gateway_id, - "tools/invoke", - {"name": tool_name, "arguments": arguments}, - ) + result = await self._forward_to_gateway(db, tool.gateway_id, "tools/invoke", {"name": tool_name, "arguments": arguments}, request_headers) # Parse result return ToolResult( @@ -353,6 +354,7 @@ async def _forward_to_gateway( gateway_id: str, method: str, params: Optional[Dict[str, Any]] = None, + request_headers: Optional[Dict[str, str]] = None, ) -> Any: """Forward request to a specific gateway. @@ -365,6 +367,8 @@ async def _forward_to_gateway( gateway_id: ID of the gateway to forward to method: RPC method name params: Optional method parameters + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: The 'result' field from the gateway's JSON-RPC response @@ -433,10 +437,15 @@ async def _forward_to_gateway( # Send request with retries using the persistent client directly for attempt in range(settings.max_tool_retries): try: + # Merge auth headers with passthrough headers + headers = self._get_auth_headers() + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db, gateway) + response = await self._http_client.post( f"{gateway.url}/rpc", json=request, - headers=self._get_auth_headers(), + headers=headers, ) response.raise_for_status() result = response.json() @@ -457,7 +466,7 @@ async def _forward_to_gateway( except Exception as e: raise ForwardingError(f"Failed to forward to {gateway.name}: {str(e)}") - async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None) -> List[Any]: + async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None, request_headers: Optional[Dict[str, str]] = None) -> List[Any]: """Forward request to all active gateways. Broadcasts the same request to all enabled gateways in parallel, @@ -468,6 +477,8 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ db: Database session for gateway queries method: RPC method name to invoke on all gateways params: Optional method parameters + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: List of successful responses from active gateways @@ -492,7 +503,7 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ >>> db.execute.return_value = mock_result >>> >>> # Mock forwarding with mixed results - >>> async def mock_forward(db, gw_id, method, params=None): + >>> async def mock_forward(db, gw_id, method, params=None, request_headers=None): ... if gw_id == 1: ... return {"gateway": "gw1", "status": "ok"} ... elif gw_id == 2: @@ -510,7 +521,7 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ 'gw3' >>> # Test all gateways failing - >>> async def mock_all_fail(db, gw_id, method, params=None): + >>> async def mock_all_fail(db, gw_id, method, params=None, request_headers=None): ... raise Exception(f"Gateway {gw_id} failed") >>> >>> service._forward_to_gateway = mock_all_fail @@ -529,7 +540,7 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ for gateway in gateways: try: - result = await self._forward_to_gateway(db, gateway.id, method, params) + result = await self._forward_to_gateway(db, gateway.id, method, params, request_headers) results.append(result) except Exception as e: errors.append(str(e)) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 77d444125..0ddb6e826 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -2134,8 +2134,10 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str # Per the MCP spec, a ping returns an empty result. result = {} else: + # Get request headers + headers = {k.lower(): v for k, v in request.headers.items()} try: - result = await tool_service.invoke_tool(db=db, name=method, arguments=params) + result = await tool_service.invoke_tool(db=db, name=method, arguments=params, request_headers=headers) if hasattr(result, "model_dump"): result = result.model_dump(by_alias=True, exclude_none=True) except ValueError: diff --git a/mcpgateway/models.py b/mcpgateway/models.py index b577f61e8..1efc619b2 100644 --- a/mcpgateway/models.py +++ b/mcpgateway/models.py @@ -610,6 +610,17 @@ class JSONRPCError(BaseModel): data: Optional[Any] = None +# Global configuration types +class GlobalConfig(BaseModel): + """Global server configuration. + + Attributes: + passthrough_headers (Optional[List[str]]): List of headers allowed to be passed through globally + """ + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through globally") + + # Transport message types class SSEEvent(BaseModel): """Server-Sent Events message. diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 6015d2028..aa02fc6bc 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -1658,6 +1658,27 @@ class PromptInvocation(BaseModelWithConfigDict): arguments: Dict[str, str] = Field(default_factory=dict, description="Arguments for template rendering") +# --- Global Config Schemas --- +class GlobalConfigUpdate(BaseModel): + """Schema for updating global configuration. + + Attributes: + passthrough_headers (Optional[List[str]]): List of headers allowed to be passed through globally + """ + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through globally") + + +class GlobalConfigRead(BaseModel): + """Schema for reading global configuration. + + Attributes: + passthrough_headers (Optional[List[str]]): List of headers allowed to be passed through globally + """ + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through globally") + + # --- Gateway Schemas --- @@ -1704,6 +1725,7 @@ class GatewayCreate(BaseModel): url: Union[str, AnyHttpUrl] = Field(..., description="Gateway endpoint URL") description: Optional[str] = Field(None, description="Gateway description") transport: str = Field(default="SSE", description="Transport used by MCP server: SSE or STREAMABLEHTTP") + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations auth_type: Optional[str] = Field(None, description="Type of authentication: basic, bearer, headers, or none") @@ -1888,13 +1910,15 @@ class GatewayUpdate(BaseModelWithConfigDict): url: Optional[Union[str, AnyHttpUrl]] = Field(None, description="Gateway endpoint URL") description: Optional[str] = Field(None, description="Gateway description") transport: str = Field(default="SSE", description="Transport used by MCP server: SSE or STREAMABLEHTTP") + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers or None") auth_username: Optional[str] = Field(None, description="username for basic authentication") auth_password: Optional[str] = Field(None, description="password for basic authentication") auth_token: Optional[str] = Field(None, description="token for bearer authentication") auth_header_key: Optional[str] = Field(None, description="key for custom headers authentication") - auth_header_value: Optional[str] = Field(None, description="vallue for custom headers authentication") + auth_header_value: Optional[str] = Field(None, description="value for custom headers authentication") # Adding `auth_value` as an alias for better access post-validation auth_value: Optional[str] = Field(None, validate_default=True) @@ -2070,6 +2094,7 @@ class GatewayRead(BaseModelWithConfigDict): last_seen: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc), description="Last seen timestamp") + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers or None") auth_value: Optional[str] = Field(None, description="auth value: username/password or token or custom headers") diff --git a/mcpgateway/services/__init__.py b/mcpgateway/services/__init__.py index 88c4fd675..68a13bffd 100644 --- a/mcpgateway/services/__init__.py +++ b/mcpgateway/services/__init__.py @@ -12,6 +12,7 @@ - Gateway coordination """ +# First-Party from mcpgateway.services.gateway_service import GatewayError, GatewayService from mcpgateway.services.prompt_service import PromptError, PromptService from mcpgateway.services.resource_service import ResourceError, ResourceService diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 1c5263089..21ceecbfa 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -45,6 +45,7 @@ ToolUpdate, ) from mcpgateway.utils.create_slug import slugify +from mcpgateway.utils.passthrough_headers import get_passthrough_headers from mcpgateway.utils.retry_manager import ResilientHttpClient from mcpgateway.utils.services_auth import decode_auth @@ -332,7 +333,9 @@ async def register_tool(self, db: Session, tool: ToolCreate) -> ToolRead: except Exception as ex: raise ToolError(f"Failed to register tool: {str(ex)}") - async def list_tools(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None) -> List[ToolRead]: + async def list_tools( + self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None, _request_headers: Optional[Dict[str, str]] = None + ) -> List[ToolRead]: """ Retrieve a list of registered tools from the database. @@ -343,6 +346,8 @@ async def list_tools(self, db: Session, include_inactive: bool = False, cursor: cursor (Optional[str], optional): An opaque cursor token for pagination. Currently, this parameter is ignored. Defaults to None. tags (Optional[List[str]]): Filter tools by tags. If provided, only tools with at least one matching tag will be returned. + _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Currently unused but kept for API consistency. Defaults to None. Returns: List[ToolRead]: A list of registered tools represented as ToolRead objects. @@ -378,7 +383,7 @@ async def list_tools(self, db: Session, include_inactive: bool = False, cursor: tools = db.execute(query).scalars().all() return [self._convert_tool_to_read(t) for t in tools] - async def list_server_tools(self, db: Session, server_id: str, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]: + async def list_server_tools(self, db: Session, server_id: str, include_inactive: bool = False, cursor: Optional[str] = None, _request_headers: Optional[Dict[str, str]] = None) -> List[ToolRead]: """ Retrieve a list of registered tools from the database. @@ -389,6 +394,8 @@ async def list_server_tools(self, db: Session, server_id: str, include_inactive: Defaults to False. cursor (Optional[str], optional): An opaque cursor token for pagination. Currently, this parameter is ignored. Defaults to None. + _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Currently unused but kept for API consistency. Defaults to None. Returns: List[ToolRead]: A list of registered tools represented as ToolRead objects. @@ -548,7 +555,7 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re db.rollback() raise ToolError(f"Failed to toggle tool status: {str(e)}") - async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -> ToolResult: + async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], request_headers: Optional[Dict[str, str]] = None) -> ToolResult: """ Invoke a registered tool and record execution metrics. @@ -556,6 +563,8 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) - db: Database session. name: Name of tool to invoke. arguments: Tool arguments. + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: Tool invocation result. @@ -595,14 +604,17 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) - success = False error_message = None try: - # tool.validate_arguments(arguments) - # Build headers with auth if necessary. + # Get combined headers for the tool including base headers, auth, and passthrough headers + # headers = self._get_combined_headers(db, tool, tool.headers or {}, request_headers) headers = tool.headers or {} if tool.integration_type == "REST": credentials = decode_auth(tool.auth_value) headers.update(credentials) - # Build the payload based on integration type. + # Only call get_passthrough_headers if we actually have request headers to pass through + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db) + # Build the payload based on integration type payload = arguments.copy() # Handle URL path parameter substitution @@ -647,6 +659,11 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) - gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none() headers = decode_auth(gateway.auth_value) + # Get combined headers including gateway auth and passthrough + # base_headers = decode_auth(gateway.auth_value) if gateway and gateway.auth_value else {} + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db, gateway) + async def connect_to_sse_server(server_url: str) -> str: """ Connect to an MCP server running with SSE transport diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index d21600ff0..d1c81a92f 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -2471,6 +2471,22 @@ async function editGateway(gatewayId) { break; } + // Handle passthrough headers + const passthroughHeadersField = safeGetElement( + "edit-gateway-passthrough-headers", + ); + if (passthroughHeadersField) { + if ( + gateway.passthroughHeaders && + Array.isArray(gateway.passthroughHeaders) + ) { + passthroughHeadersField.value = + gateway.passthroughHeaders.join(", "); + } else { + passthroughHeadersField.value = ""; + } + } + openModal("gateway-edit-modal"); console.log("✓ Gateway edit modal loaded successfully"); } catch (error) { @@ -3763,14 +3779,43 @@ async function runToolTest() { params, }; + // Parse custom headers from the passthrough headers field + const requestHeaders = { + "Content-Type": "application/json", + }; + + const passthroughHeadersField = document.getElementById( + "test-passthrough-headers", + ); + if (passthroughHeadersField && passthroughHeadersField.value.trim()) { + const headerLines = passthroughHeadersField.value + .trim() + .split("\n"); + for (const line of headerLines) { + const trimmedLine = line.trim(); + if (trimmedLine) { + const colonIndex = trimmedLine.indexOf(":"); + if (colonIndex > 0) { + const headerName = trimmedLine + .substring(0, colonIndex) + .trim(); + const headerValue = trimmedLine + .substring(colonIndex + 1) + .trim(); + if (headerName && headerValue) { + requestHeaders[headerName] = headerValue; + } + } + } + } + } + // Use longer timeout for test execution const response = await fetchWithTimeout( `${window.ROOT_PATH}/rpc`, { method: "POST", - headers: { - "Content-Type": "application/json", - }, + headers: requestHeaders, body: JSON.stringify(payload), credentials: "include", }, @@ -4510,6 +4555,23 @@ async function handleGatewayFormSubmit(e) { const isInactiveCheckedBool = isInactiveChecked("gateways"); formData.append("is_inactive_checked", isInactiveCheckedBool); + // Process passthrough headers - convert comma-separated string to array + const passthroughHeadersString = formData.get("passthrough_headers"); + if (passthroughHeadersString && passthroughHeadersString.trim()) { + // Split by comma and clean up each header name + const passthroughHeaders = passthroughHeadersString + .split(",") + .map((header) => header.trim()) + .filter((header) => header.length > 0); + + // Remove the original string and add as JSON array + formData.delete("passthrough_headers"); + formData.append( + "passthrough_headers", + JSON.stringify(passthroughHeaders), + ); + } + const response = await fetchWithTimeout( `${window.ROOT_PATH}/admin/gateways`, { @@ -4921,6 +4983,18 @@ async function handleEditGatewayFormSubmit(e) { throw new Error(urlValidation.error); } + // Handle passthrough headers + const passthroughHeadersString = + formData.get("passthrough_headers") || ""; + const passthroughHeaders = passthroughHeadersString + .split(",") + .map((header) => header.trim()) + .filter((header) => header.length > 0); + formData.append( + "passthrough_headers", + JSON.stringify(passthroughHeaders), + ); + const isInactiveCheckedBool = isInactiveChecked("gateways"); formData.append("is_inactive_checked", isInactiveCheckedBool); // Submit via fetch diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 6532710cd..0ef67b5ff 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -2215,6 +2215,23 @@