diff --git a/.env.example b/.env.example index ac938a702..857e74eba 100644 --- a/.env.example +++ b/.env.example @@ -115,6 +115,15 @@ MCPGATEWAY_UI_ENABLED=true # Enable the Admin API endpoints (true/false) MCPGATEWAY_ADMIN_API_ENABLED=true +##################################### +# Header Passthrough Configuration +##################################### + +# Default headers to pass through from client requests to backing MCP servers +# Comma-separated list or JSON array format +# Example: ["Authorization", "X-Tenant-Id", "X-Trace-Id"] +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id"] + ##################################### # Security and CORS ##################################### diff --git a/.gitignore b/.gitignore index 79d578025..12c5ec6be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +todo/ *.sarif devskim-results.sarif debug_login_page.png diff --git a/docs/docs/overview/passthrough.md b/docs/docs/overview/passthrough.md new file mode 100644 index 000000000..5454efed4 --- /dev/null +++ b/docs/docs/overview/passthrough.md @@ -0,0 +1,291 @@ +# HTTP Header Passthrough + +The MCP Gateway supports **HTTP Header Passthrough**, allowing specific headers from incoming client requests to be forwarded to backing MCP servers. This feature is essential for maintaining authentication context and request tracing across the gateway infrastructure. + +## Overview + +When clients make requests through the MCP Gateway, certain headers (like authentication tokens or trace IDs) need to be preserved and passed to the underlying MCP servers. The header passthrough feature provides a configurable, secure way to forward these headers while preventing conflicts with existing authentication mechanisms. + +## Key Features + +- **Global Configuration**: Set default passthrough headers for all gateways +- **Per-Gateway Override**: Customize header passthrough on a per-gateway basis +- **Conflict Prevention**: Automatically prevents overriding existing authentication headers +- **Admin UI Integration**: Configure passthrough headers through the web interface +- **API Management**: Programmatic control via REST endpoints + +## Configuration + +### Environment Variables + +Set global default headers using the `DEFAULT_PASSTHROUGH_HEADERS` environment variable: + +```bash +# JSON array format +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id"] + +# Or in .env file +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id"] +``` + +### Admin UI Configuration + +#### Global Configuration +Access the admin interface to set global passthrough headers that apply to all gateways by default. + +#### Per-Gateway Configuration +When creating or editing gateways: + +1. Navigate to the **Gateways** section in the admin UI +2. Click **Add Gateway** or edit an existing gateway +3. In the **Passthrough Headers** field, enter a comma-separated list: + ``` + Authorization, X-Tenant-Id, X-Trace-Id + ``` +4. Gateway-specific headers override global defaults + +### API Configuration + +#### Get Global Configuration +```bash +GET /admin/config/passthrough-headers +``` + +Response: +```json +{ + "passthrough_headers": ["Authorization", "X-Tenant-Id", "X-Trace-Id"] +} +``` + +#### Update Global Configuration +```bash +PUT /admin/config/passthrough-headers +Content-Type: application/json + +{ + "passthrough_headers": ["Authorization", "X-Custom-Header"] +} +``` + +## How It Works + +### Header Processing Flow + +1. **Client Request**: Client sends request with various headers +2. **Header Extraction**: Gateway extracts headers configured for passthrough +3. **Conflict Check**: System verifies no conflicts with existing auth headers +4. **Forwarding**: Allowed headers are added to requests sent to backing MCP servers + +### Configuration Hierarchy + +The system follows this priority order: + +1. **Gateway-specific headers** (highest priority) +2. **Global configuration** (from database) +3. **Environment variable defaults** (lowest priority) + +### Example Flow + +```mermaid +graph LR + A[Client Request] --> B[MCP Gateway] + B --> C{Check Passthrough Config} + C --> D[Extract Configured Headers] + D --> E[Conflict Prevention Check] + E --> F[Forward to MCP Server] + + G[Global Config] --> C + H[Gateway Config] --> C +``` + +## Security Considerations + +### Conflict Prevention + +The system automatically prevents header conflicts: + +- **Basic Auth**: Skips `Authorization` header if gateway uses basic authentication +- **Bearer Auth**: Skips `Authorization` header if gateway uses bearer token authentication +- **Warnings**: Logs warnings when headers are skipped due to conflicts + +### Header Validation + +- Headers are validated before forwarding +- Empty or invalid headers are filtered out +- Only explicitly configured headers are passed through + +## Use Cases + +### Authentication Context +Forward authentication tokens to maintain user context: +```bash +# Client request includes +Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9... + +# Forwarded to MCP server if configured +``` + +### Request Tracing +Maintain trace context across service boundaries: +```bash +# Client request includes +X-Trace-Id: abc123def456 +X-Span-Id: span789 + +# Both forwarded to enable distributed tracing +``` + +### Multi-Tenant Systems +Pass tenant identification: +```bash +# Client request includes +X-Tenant-Id: tenant_12345 +X-Organization: acme_corp + +# Forwarded for tenant-specific processing +``` + +## Configuration Examples + +### Basic Setup +```bash +# .env file +DEFAULT_PASSTHROUGH_HEADERS=["Authorization"] +``` + +### Multi-Header Configuration +```bash +# .env file with multiple headers +DEFAULT_PASSTHROUGH_HEADERS=["Authorization", "X-Tenant-Id", "X-Trace-Id", "X-Request-Id"] +``` + +### Gateway-Specific Override +```json +// Via Admin API for specific gateway +{ + "name": "secure-gateway", + "url": "https://secure-mcp-server.example.com", + "passthrough_headers": ["X-API-Key", "X-Client-Id"] +} +``` + +## Troubleshooting + +### Common Issues + +#### Headers Not Being Forwarded +- Verify header names in configuration match exactly (case-sensitive) +- Check for authentication conflicts in logs +- Ensure gateway configuration overrides aren't blocking headers + +#### Authentication Conflicts +If you see warnings like: +``` +Skipping passthrough header 'Authorization' - conflicts with existing basic auth +``` + +**Solution**: Either: +1. Remove `Authorization` from passthrough headers for that gateway +2. Change the gateway to not use basic/bearer authentication +3. Use a different header name for custom auth tokens + +#### Configuration Not Taking Effect +- Restart the gateway after environment variable changes +- Verify database migration has been applied +- Check admin API responses to confirm configuration is saved + +### Debug Logging + +Enable debug logging to see header processing: +```bash +LOG_LEVEL=DEBUG +``` + +Look for log entries containing: +- `Passthrough headers configured` +- `Skipping passthrough header` +- `Adding passthrough header` + +## API Reference + +### Data Models + +#### GlobalConfig +```python +class GlobalConfig(Base): + id: int + passthrough_headers: Optional[List[str]] +``` + +#### Gateway +```python +class Gateway(Base): + # ... other fields + passthrough_headers: Optional[List[str]] +``` + +### Admin Endpoints + +| Method | Endpoint | Description | +|--------|----------|-------------| +| GET | `/admin/config/passthrough-headers` | Get global configuration | +| PUT | `/admin/config/passthrough-headers` | Update global configuration | +| POST | `/admin/gateways` | Create gateway with headers | +| PUT | `/admin/gateways/{id}` | Update gateway headers | + +## Best Practices + +1. **Minimal Headers**: Only configure headers you actually need to reduce overhead +2. **Security Review**: Regularly audit which headers are being passed through +3. **Environment Consistency**: Use consistent header configuration across environments +4. **Documentation**: Document which headers your MCP servers expect +5. **Monitoring**: Monitor logs for conflict warnings and adjust configuration accordingly + +## Migration Notes + +When upgrading to a version with header passthrough: + +1. **Database Migration**: Ensure the migration `3b17fdc40a8d` has been applied +2. **Configuration Review**: Review existing authentication setup for conflicts +3. **Testing**: Test header forwarding in development before production deployment +4. **Monitoring**: Monitor logs for any unexpected behavior after deployment + +## Testing with the Built-in Test Tool + +The MCP Gateway admin interface includes a built-in test tool with passthrough header support: + +### Using the Test Tool + +1. **Access the Admin UI**: Navigate to the **Tools** section +2. **Select a Tool**: Click the **Test** button on any available tool +3. **Configure Headers**: In the test modal, scroll to the **Passthrough Headers** section +4. **Add Headers**: Enter headers in the format `Header-Name: Value` (one per line): + ``` + Authorization: Bearer your-token-here + X-Tenant-Id: tenant-123 + X-Trace-Id: abc-def-456 + ``` +5. **Run Test**: Click **Run Tool** - the headers will be included in the request + +### Example Test Scenarios + +**Authentication Testing**: +``` +Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9... +``` + +**Multi-Tenant Testing**: +``` +X-Tenant-Id: acme-corp +X-Organization-Id: org-12345 +``` + +**Distributed Tracing**: +``` +X-Trace-Id: trace-abc123 +X-Span-Id: span-def456 +X-Request-Id: req-789xyz +``` + +The test tool provides immediate feedback and allows you to verify that your passthrough header configuration is working correctly before deploying to production. diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 9c33cc4f5..bee408403 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -34,13 +34,15 @@ # First-Party from mcpgateway.config import settings -from mcpgateway.db import get_db +from mcpgateway.db import get_db, GlobalConfig from mcpgateway.schemas import ( GatewayCreate, GatewayRead, GatewayTestRequest, GatewayTestResponse, GatewayUpdate, + GlobalConfigRead, + GlobalConfigUpdate, PromptCreate, PromptMetrics, PromptRead, @@ -88,6 +90,52 @@ #################### +@admin_router.get("/config/passthrough-headers", response_model=GlobalConfigRead) +async def get_global_passthrough_headers( + db: Session = Depends(get_db), + _user: str = Depends(require_auth), +) -> GlobalConfigRead: + """Get the global passthrough headers configuration. + + Args: + db: Database session + _user: Authenticated user + + Returns: + GlobalConfigRead: The current global passthrough headers configuration + """ + config = db.query(GlobalConfig).first() + if not config: + config = GlobalConfig() + return GlobalConfigRead(passthrough_headers=config.passthrough_headers) + + +@admin_router.put("/config/passthrough-headers", response_model=GlobalConfigRead) +async def update_global_passthrough_headers( + config_update: GlobalConfigUpdate, + db: Session = Depends(get_db), + _user: str = Depends(require_auth), +) -> GlobalConfigRead: + """Update the global passthrough headers configuration. + + Args: + config_update: The new configuration + db: Database session + _user: Authenticated user + + Returns: + GlobalConfigRead: The updated configuration + """ + config = db.query(GlobalConfig).first() + if not config: + config = GlobalConfig(passthrough_headers=config_update.passthrough_headers) + db.add(config) + else: + config.passthrough_headers = config_update.passthrough_headers + db.commit() + return GlobalConfigRead(passthrough_headers=config.passthrough_headers) + + @admin_router.get("/servers", response_model=List[ServerRead]) async def admin_list_servers( include_inactive: bool = False, diff --git a/mcpgateway/alembic/versions/3b17fdc40a8d_add_passthrough_headers_to_gateways_and_.py b/mcpgateway/alembic/versions/3b17fdc40a8d_add_passthrough_headers_to_gateways_and_.py new file mode 100644 index 000000000..c461cfee5 --- /dev/null +++ b/mcpgateway/alembic/versions/3b17fdc40a8d_add_passthrough_headers_to_gateways_and_.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +"""Add passthrough headers to gateways and global config + +Revision ID: 3b17fdc40a8d +Revises: e75490e949b1 +Create Date: 2025-08-08 03:45:46.489696 + +""" + +# Standard +from typing import Sequence, Union + +# Third-Party +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "3b17fdc40a8d" +down_revision: Union[str, Sequence[str], None] = "e75490e949b1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Create global_config table + op.create_table("global_config", sa.Column("id", sa.Integer(), nullable=False), sa.Column("passthrough_headers", sa.JSON(), nullable=True), sa.PrimaryKeyConstraint("id")) + + # Add passthrough_headers column to gateways table + op.add_column("gateways", sa.Column("passthrough_headers", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + """Downgrade schema.""" + # Remove passthrough_headers column from gateways table + op.drop_column("gateways", "passthrough_headers") + + # Drop global_config table + op.drop_table("global_config") diff --git a/mcpgateway/alembic/versions/eb17fd368f9d_merge_passthrough_headers_and_tags_.py b/mcpgateway/alembic/versions/eb17fd368f9d_merge_passthrough_headers_and_tags_.py new file mode 100644 index 000000000..071d9c5bc --- /dev/null +++ b/mcpgateway/alembic/versions/eb17fd368f9d_merge_passthrough_headers_and_tags_.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +"""merge passthrough headers and tags support + +Revision ID: eb17fd368f9d +Revises: 3b17fdc40a8d, cc7b95fec5d9 +Create Date: 2025-08-08 05:31:10.857718 + +""" + +# Standard +from typing import Sequence, Union + +# revision identifiers, used by Alembic. +revision: str = "eb17fd368f9d" +down_revision: Union[str, Sequence[str], None] = ("3b17fdc40a8d", "cc7b95fec5d9") +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + + +def downgrade() -> None: + """Downgrade schema.""" diff --git a/mcpgateway/config.py b/mcpgateway/config.py index db80a501e..c58506a13 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -52,6 +52,7 @@ from importlib.resources import files import json import logging +import os from pathlib import Path import re from typing import Annotated, Any, ClassVar, Dict, List, Optional, Set, Union @@ -548,9 +549,27 @@ def validate_database(self) -> None: # Rate limiting validation_max_requests_per_minute: int = 60 + # passthrough headers + default_passthrough_headers: Any = os.environ.get("DEFAULT_PASSTHROUGH_HEADERS", ["Authorization", "X-Tenant-Id", "X-Trace-Id"]) + if not isinstance(default_passthrough_headers, list): + try: + default_passthrough_headers = list(default_passthrough_headers) + except Exception as e: + logger.warning(f"Invalid DEFAULT_PASSTHROUGH_HEADERS format in .env. Must be a list of header names, e.g. ['Authorization', 'X-Tenant-Id'], error: {e}") + default_passthrough_headers = [] + # Masking value for all sensitive data masked_auth_value: str = "*****" + # passthrough headers + default_passthrough_headers: Any = os.environ.get("DEFAULT_PASSTHROUGH_HEADERS", ["Authorization", "X-Tenant-Id", "X-Trace-Id"]) + if not isinstance(default_passthrough_headers, list): + try: + default_passthrough_headers = list(default_passthrough_headers) + except Exception as e: + logger.warning(f"Invalid DEFAULT_PASSTHROUGH_HEADERS format in .env. Must be a list of header names, e.g. ['Authorization', 'X-Tenant-Id'], error: {e}") + default_passthrough_headers = [] + def extract_using_jq(data, jq_filter=""): """ diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 3f49ad464..b205da631 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -28,24 +28,7 @@ # Third-Party import jsonschema -from sqlalchemy import ( - Boolean, - Column, - create_engine, - DateTime, - event, - Float, - ForeignKey, - func, - Integer, - JSON, - make_url, - select, - String, - Table, - Text, - UniqueConstraint, -) +from sqlalchemy import Boolean, Column, create_engine, DateTime, event, Float, ForeignKey, func, Integer, JSON, make_url, select, String, Table, Text, UniqueConstraint from sqlalchemy.event import listen from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.hybrid import hybrid_property @@ -200,6 +183,20 @@ class Base(DeclarativeBase): ) +class GlobalConfig(Base): + """Global configuration settings. + + Attributes: + id (int): Primary key + passthrough_headers (List[str]): List of headers allowed to be passed through globally + """ + + __tablename__ = "global_config" + + id = Column(Integer, primary_key=True) + passthrough_headers: Mapped[Optional[List[str]]] = Column(JSON, nullable=True) # Store list of strings as JSON array + + class ToolMetric(Base): """ ORM model for recording individual metrics for tool executions. @@ -1117,6 +1114,12 @@ class Gateway(Base): last_seen: Mapped[Optional[datetime]] tags: Mapped[List[str]] = mapped_column(JSON, default=list, nullable=False) + # Header passthrough configuration + passthrough_headers: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) # Store list of strings as JSON array + + # Header passthrough configuration + passthrough_headers: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) # Store list of strings as JSON array + # Relationship with local tools this gateway provides tools: Mapped[List["Tool"]] = relationship(back_populates="gateway", foreign_keys="Tool.gateway_id", cascade="all, delete-orphan") diff --git a/mcpgateway/federation/forward.py b/mcpgateway/federation/forward.py index 5ff5f83dc..159ef52d6 100644 --- a/mcpgateway/federation/forward.py +++ b/mcpgateway/federation/forward.py @@ -35,6 +35,7 @@ from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import Tool as DbTool from mcpgateway.models import ToolResult +from mcpgateway.utils.passthrough_headers import get_passthrough_headers logger = logging.getLogger(__name__) @@ -122,6 +123,7 @@ async def forward_request( method: str, params: Optional[Dict[str, Any]] = None, target_gateway_id: Optional[int] = None, + request_headers: Optional[Dict[str, str]] = None, ) -> Any: """Forward a request to gateway(s). @@ -134,6 +136,8 @@ async def forward_request( method: RPC method name (e.g., "tools/list", "resources/read") params: Optional method parameters as key-value pairs target_gateway_id: Optional specific gateway ID for targeted forwarding + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: Any: Single gateway response for targeted requests (when target_gateway_id @@ -173,15 +177,15 @@ async def forward_request( try: if target_gateway_id: # Forward to specific gateway - return await self._forward_to_gateway(db, target_gateway_id, method, params) + return await self._forward_to_gateway(db, target_gateway_id, method, params, request_headers) - # Forward to all relevant gateways - return await self._forward_to_all(db, method, params) + # Forward to all relevant gateways - headers are passed to each gateway + return await self._forward_to_all(db, method, params, request_headers) except Exception as e: raise ForwardingError(f"Forward request failed: {str(e)}") - async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any]) -> ToolResult: + async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any], request_headers: Optional[Dict[str, str]] = None) -> ToolResult: """Forward a tool invocation request. Locates the specified tool in the database, verifies it's federated, @@ -192,6 +196,8 @@ async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dic db: Database session for tool and gateway lookups tool_name: Name of the tool to invoke arguments: Tool arguments as key-value pairs + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: ToolResult object containing the tool execution results @@ -252,12 +258,7 @@ async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dic raise ForwardingError(f"Tool {tool_name} is not federated") # Forward to gateway - result = await self._forward_to_gateway( - db, - tool.gateway_id, - "tools/invoke", - {"name": tool_name, "arguments": arguments}, - ) + result = await self._forward_to_gateway(db, tool.gateway_id, "tools/invoke", {"name": tool_name, "arguments": arguments}, request_headers) # Parse result return ToolResult( @@ -353,6 +354,7 @@ async def _forward_to_gateway( gateway_id: str, method: str, params: Optional[Dict[str, Any]] = None, + request_headers: Optional[Dict[str, str]] = None, ) -> Any: """Forward request to a specific gateway. @@ -365,6 +367,8 @@ async def _forward_to_gateway( gateway_id: ID of the gateway to forward to method: RPC method name params: Optional method parameters + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: The 'result' field from the gateway's JSON-RPC response @@ -433,10 +437,15 @@ async def _forward_to_gateway( # Send request with retries using the persistent client directly for attempt in range(settings.max_tool_retries): try: + # Merge auth headers with passthrough headers + headers = self._get_auth_headers() + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db, gateway) + response = await self._http_client.post( f"{gateway.url}/rpc", json=request, - headers=self._get_auth_headers(), + headers=headers, ) response.raise_for_status() result = response.json() @@ -457,7 +466,7 @@ async def _forward_to_gateway( except Exception as e: raise ForwardingError(f"Failed to forward to {gateway.name}: {str(e)}") - async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None) -> List[Any]: + async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None, request_headers: Optional[Dict[str, str]] = None) -> List[Any]: """Forward request to all active gateways. Broadcasts the same request to all enabled gateways in parallel, @@ -468,6 +477,8 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ db: Database session for gateway queries method: RPC method name to invoke on all gateways params: Optional method parameters + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: List of successful responses from active gateways @@ -492,7 +503,7 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ >>> db.execute.return_value = mock_result >>> >>> # Mock forwarding with mixed results - >>> async def mock_forward(db, gw_id, method, params=None): + >>> async def mock_forward(db, gw_id, method, params=None, request_headers=None): ... if gw_id == 1: ... return {"gateway": "gw1", "status": "ok"} ... elif gw_id == 2: @@ -510,7 +521,7 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ 'gw3' >>> # Test all gateways failing - >>> async def mock_all_fail(db, gw_id, method, params=None): + >>> async def mock_all_fail(db, gw_id, method, params=None, request_headers=None): ... raise Exception(f"Gateway {gw_id} failed") >>> >>> service._forward_to_gateway = mock_all_fail @@ -529,7 +540,7 @@ async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[ for gateway in gateways: try: - result = await self._forward_to_gateway(db, gateway.id, method, params) + result = await self._forward_to_gateway(db, gateway.id, method, params, request_headers) results.append(result) except Exception as e: errors.append(str(e)) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 77d444125..0ddb6e826 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -2134,8 +2134,10 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str # Per the MCP spec, a ping returns an empty result. result = {} else: + # Get request headers + headers = {k.lower(): v for k, v in request.headers.items()} try: - result = await tool_service.invoke_tool(db=db, name=method, arguments=params) + result = await tool_service.invoke_tool(db=db, name=method, arguments=params, request_headers=headers) if hasattr(result, "model_dump"): result = result.model_dump(by_alias=True, exclude_none=True) except ValueError: diff --git a/mcpgateway/models.py b/mcpgateway/models.py index b577f61e8..1efc619b2 100644 --- a/mcpgateway/models.py +++ b/mcpgateway/models.py @@ -610,6 +610,17 @@ class JSONRPCError(BaseModel): data: Optional[Any] = None +# Global configuration types +class GlobalConfig(BaseModel): + """Global server configuration. + + Attributes: + passthrough_headers (Optional[List[str]]): List of headers allowed to be passed through globally + """ + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through globally") + + # Transport message types class SSEEvent(BaseModel): """Server-Sent Events message. diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 6015d2028..aa02fc6bc 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -1658,6 +1658,27 @@ class PromptInvocation(BaseModelWithConfigDict): arguments: Dict[str, str] = Field(default_factory=dict, description="Arguments for template rendering") +# --- Global Config Schemas --- +class GlobalConfigUpdate(BaseModel): + """Schema for updating global configuration. + + Attributes: + passthrough_headers (Optional[List[str]]): List of headers allowed to be passed through globally + """ + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through globally") + + +class GlobalConfigRead(BaseModel): + """Schema for reading global configuration. + + Attributes: + passthrough_headers (Optional[List[str]]): List of headers allowed to be passed through globally + """ + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through globally") + + # --- Gateway Schemas --- @@ -1704,6 +1725,7 @@ class GatewayCreate(BaseModel): url: Union[str, AnyHttpUrl] = Field(..., description="Gateway endpoint URL") description: Optional[str] = Field(None, description="Gateway description") transport: str = Field(default="SSE", description="Transport used by MCP server: SSE or STREAMABLEHTTP") + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations auth_type: Optional[str] = Field(None, description="Type of authentication: basic, bearer, headers, or none") @@ -1888,13 +1910,15 @@ class GatewayUpdate(BaseModelWithConfigDict): url: Optional[Union[str, AnyHttpUrl]] = Field(None, description="Gateway endpoint URL") description: Optional[str] = Field(None, description="Gateway description") transport: str = Field(default="SSE", description="Transport used by MCP server: SSE or STREAMABLEHTTP") + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers or None") auth_username: Optional[str] = Field(None, description="username for basic authentication") auth_password: Optional[str] = Field(None, description="password for basic authentication") auth_token: Optional[str] = Field(None, description="token for bearer authentication") auth_header_key: Optional[str] = Field(None, description="key for custom headers authentication") - auth_header_value: Optional[str] = Field(None, description="vallue for custom headers authentication") + auth_header_value: Optional[str] = Field(None, description="value for custom headers authentication") # Adding `auth_value` as an alias for better access post-validation auth_value: Optional[str] = Field(None, validate_default=True) @@ -2070,6 +2094,7 @@ class GatewayRead(BaseModelWithConfigDict): last_seen: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc), description="Last seen timestamp") + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target") # Authorizations auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, headers or None") auth_value: Optional[str] = Field(None, description="auth value: username/password or token or custom headers") diff --git a/mcpgateway/services/__init__.py b/mcpgateway/services/__init__.py index 88c4fd675..68a13bffd 100644 --- a/mcpgateway/services/__init__.py +++ b/mcpgateway/services/__init__.py @@ -12,6 +12,7 @@ - Gateway coordination """ +# First-Party from mcpgateway.services.gateway_service import GatewayError, GatewayService from mcpgateway.services.prompt_service import PromptError, PromptService from mcpgateway.services.resource_service import ResourceError, ResourceService diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 1c5263089..21ceecbfa 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -45,6 +45,7 @@ ToolUpdate, ) from mcpgateway.utils.create_slug import slugify +from mcpgateway.utils.passthrough_headers import get_passthrough_headers from mcpgateway.utils.retry_manager import ResilientHttpClient from mcpgateway.utils.services_auth import decode_auth @@ -332,7 +333,9 @@ async def register_tool(self, db: Session, tool: ToolCreate) -> ToolRead: except Exception as ex: raise ToolError(f"Failed to register tool: {str(ex)}") - async def list_tools(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None) -> List[ToolRead]: + async def list_tools( + self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None, _request_headers: Optional[Dict[str, str]] = None + ) -> List[ToolRead]: """ Retrieve a list of registered tools from the database. @@ -343,6 +346,8 @@ async def list_tools(self, db: Session, include_inactive: bool = False, cursor: cursor (Optional[str], optional): An opaque cursor token for pagination. Currently, this parameter is ignored. Defaults to None. tags (Optional[List[str]]): Filter tools by tags. If provided, only tools with at least one matching tag will be returned. + _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Currently unused but kept for API consistency. Defaults to None. Returns: List[ToolRead]: A list of registered tools represented as ToolRead objects. @@ -378,7 +383,7 @@ async def list_tools(self, db: Session, include_inactive: bool = False, cursor: tools = db.execute(query).scalars().all() return [self._convert_tool_to_read(t) for t in tools] - async def list_server_tools(self, db: Session, server_id: str, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]: + async def list_server_tools(self, db: Session, server_id: str, include_inactive: bool = False, cursor: Optional[str] = None, _request_headers: Optional[Dict[str, str]] = None) -> List[ToolRead]: """ Retrieve a list of registered tools from the database. @@ -389,6 +394,8 @@ async def list_server_tools(self, db: Session, server_id: str, include_inactive: Defaults to False. cursor (Optional[str], optional): An opaque cursor token for pagination. Currently, this parameter is ignored. Defaults to None. + _request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Currently unused but kept for API consistency. Defaults to None. Returns: List[ToolRead]: A list of registered tools represented as ToolRead objects. @@ -548,7 +555,7 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re db.rollback() raise ToolError(f"Failed to toggle tool status: {str(e)}") - async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -> ToolResult: + async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], request_headers: Optional[Dict[str, str]] = None) -> ToolResult: """ Invoke a registered tool and record execution metrics. @@ -556,6 +563,8 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) - db: Database session. name: Name of tool to invoke. arguments: Tool arguments. + request_headers (Optional[Dict[str, str]], optional): Headers from the request to pass through. + Defaults to None. Returns: Tool invocation result. @@ -595,14 +604,17 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) - success = False error_message = None try: - # tool.validate_arguments(arguments) - # Build headers with auth if necessary. + # Get combined headers for the tool including base headers, auth, and passthrough headers + # headers = self._get_combined_headers(db, tool, tool.headers or {}, request_headers) headers = tool.headers or {} if tool.integration_type == "REST": credentials = decode_auth(tool.auth_value) headers.update(credentials) - # Build the payload based on integration type. + # Only call get_passthrough_headers if we actually have request headers to pass through + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db) + # Build the payload based on integration type payload = arguments.copy() # Handle URL path parameter substitution @@ -647,6 +659,11 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) - gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none() headers = decode_auth(gateway.auth_value) + # Get combined headers including gateway auth and passthrough + # base_headers = decode_auth(gateway.auth_value) if gateway and gateway.auth_value else {} + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db, gateway) + async def connect_to_sse_server(server_url: str) -> str: """ Connect to an MCP server running with SSE transport diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index d21600ff0..d1c81a92f 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -2471,6 +2471,22 @@ async function editGateway(gatewayId) { break; } + // Handle passthrough headers + const passthroughHeadersField = safeGetElement( + "edit-gateway-passthrough-headers", + ); + if (passthroughHeadersField) { + if ( + gateway.passthroughHeaders && + Array.isArray(gateway.passthroughHeaders) + ) { + passthroughHeadersField.value = + gateway.passthroughHeaders.join(", "); + } else { + passthroughHeadersField.value = ""; + } + } + openModal("gateway-edit-modal"); console.log("✓ Gateway edit modal loaded successfully"); } catch (error) { @@ -3763,14 +3779,43 @@ async function runToolTest() { params, }; + // Parse custom headers from the passthrough headers field + const requestHeaders = { + "Content-Type": "application/json", + }; + + const passthroughHeadersField = document.getElementById( + "test-passthrough-headers", + ); + if (passthroughHeadersField && passthroughHeadersField.value.trim()) { + const headerLines = passthroughHeadersField.value + .trim() + .split("\n"); + for (const line of headerLines) { + const trimmedLine = line.trim(); + if (trimmedLine) { + const colonIndex = trimmedLine.indexOf(":"); + if (colonIndex > 0) { + const headerName = trimmedLine + .substring(0, colonIndex) + .trim(); + const headerValue = trimmedLine + .substring(colonIndex + 1) + .trim(); + if (headerName && headerValue) { + requestHeaders[headerName] = headerValue; + } + } + } + } + } + // Use longer timeout for test execution const response = await fetchWithTimeout( `${window.ROOT_PATH}/rpc`, { method: "POST", - headers: { - "Content-Type": "application/json", - }, + headers: requestHeaders, body: JSON.stringify(payload), credentials: "include", }, @@ -4510,6 +4555,23 @@ async function handleGatewayFormSubmit(e) { const isInactiveCheckedBool = isInactiveChecked("gateways"); formData.append("is_inactive_checked", isInactiveCheckedBool); + // Process passthrough headers - convert comma-separated string to array + const passthroughHeadersString = formData.get("passthrough_headers"); + if (passthroughHeadersString && passthroughHeadersString.trim()) { + // Split by comma and clean up each header name + const passthroughHeaders = passthroughHeadersString + .split(",") + .map((header) => header.trim()) + .filter((header) => header.length > 0); + + // Remove the original string and add as JSON array + formData.delete("passthrough_headers"); + formData.append( + "passthrough_headers", + JSON.stringify(passthroughHeaders), + ); + } + const response = await fetchWithTimeout( `${window.ROOT_PATH}/admin/gateways`, { @@ -4921,6 +4983,18 @@ async function handleEditGatewayFormSubmit(e) { throw new Error(urlValidation.error); } + // Handle passthrough headers + const passthroughHeadersString = + formData.get("passthrough_headers") || ""; + const passthroughHeaders = passthroughHeadersString + .split(",") + .map((header) => header.trim()) + .filter((header) => header.length > 0); + formData.append( + "passthrough_headers", + JSON.stringify(passthroughHeaders), + ); + const isInactiveCheckedBool = isInactiveChecked("gateways"); formData.append("is_inactive_checked", isInactiveCheckedBool); // Submit via fetch diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 6532710cd..0ef67b5ff 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -2215,6 +2215,23 @@

/> +
+ + + List of headers to pass through from client requests + (comma-separated, e.g., "Authorization, X-Tenant-Id, + X-Trace-Id"). Leave empty to use global defaults. + + +
@@ -2533,6 +2550,27 @@

> +
+
+ + + Additional headers to send with the request (format: + "Header-Name: Value", one per line) + + +
+
+
+ + + List of headers to pass through from client requests + (comma-separated, e.g., "Authorization, X-Tenant-Id, + X-Trace-Id"). Leave empty to use global defaults. + + +
List[Union[types.TextContent, >>> sig.return_annotation typing.List[typing.Union[mcp.types.TextContent, mcp.types.ImageContent, mcp.types.EmbeddedResource]] """ + request_headers = request_headers_var.get() try: async with get_db() as db: - result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments) + result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments, request_headers=request_headers) if not result or not result.content: logger.warning(f"No content returned by tool: {name}") return [] @@ -392,11 +394,12 @@ async def list_tools() -> List[types.Tool]: typing.List[mcp.types.Tool] """ server_id = server_id_var.get() + request_headers = request_headers_var.get() if server_id: try: async with get_db() as db: - tools = await tool_service.list_server_tools(db, server_id) + tools = await tool_service.list_server_tools(db, server_id, _request_headers=request_headers) return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, annotations=tool.annotations) for tool in tools] except Exception as e: logger.exception(f"Error listing tools:{e}") @@ -404,7 +407,7 @@ async def list_tools() -> List[types.Tool]: else: try: async with get_db() as db: - tools = await tool_service.list_tools(db) + tools = await tool_service.list_tools(db, False, None, None, request_headers) return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, annotations=tool.annotations) for tool in tools] except Exception as e: logger.exception(f"Error listing tools:{e}") @@ -519,6 +522,11 @@ async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Sen path = scope["modified_path"] match = re.search(r"/servers/(?P[a-fA-F0-9\-]+)/mcp", path) + # Extract request headers from scope + headers = dict(Headers(scope=scope)) + # Store headers in context for tool invocations + request_headers_var.set(headers) + if match: server_id = match.group("server_id") server_id_var.set(server_id) diff --git a/mcpgateway/utils/passthrough_headers.py b/mcpgateway/utils/passthrough_headers.py new file mode 100644 index 000000000..beaad24c5 --- /dev/null +++ b/mcpgateway/utils/passthrough_headers.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- +"""HTTP Header Passthrough Utilities. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +This module provides utilities for handling HTTP header passthrough functionality +in the MCP Gateway. It enables forwarding of specific headers from incoming +client requests to backing MCP servers while preventing conflicts with +existing authentication mechanisms. + +Key Features: +- Global configuration support via environment variables and database +- Per-gateway header configuration overrides +- Intelligent conflict detection with existing authentication headers +- Security-first approach with explicit allowlist handling +- Comprehensive logging for debugging and monitoring + +The header passthrough system follows a priority hierarchy: +1. Gateway-specific headers (highest priority) +2. Global database configuration +3. Environment variable defaults (lowest priority) + +Example Usage: + Basic header passthrough with global configuration: + >>> from unittest.mock import Mock + >>> mock_db = Mock() + >>> mock_global_config = Mock() + >>> mock_global_config.passthrough_headers = ["X-Tenant-Id"] + >>> mock_db.query.return_value.first.return_value = mock_global_config + >>> headers = get_passthrough_headers( + ... request_headers={"x-tenant-id": "123"}, + ... base_headers={"Content-Type": "application/json"}, + ... db=mock_db + ... ) + >>> sorted(headers.items()) + [('Content-Type', 'application/json'), ('X-Tenant-Id', '123')] +""" + +# Standard +import logging +from typing import Dict, Optional + +# Third-Party +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import Gateway as DbGateway +from mcpgateway.db import GlobalConfig + +logger = logging.getLogger(__name__) + + +def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[str, str], db: Session, gateway: Optional[DbGateway] = None) -> Dict[str, str]: + """Get headers that should be passed through to the target gateway. + + This function implements the core logic for HTTP header passthrough in the MCP Gateway. + It determines which headers from incoming client requests should be forwarded to + backing MCP servers based on configuration settings and security policies. + + Configuration Priority (highest to lowest): + 1. Gateway-specific passthrough_headers setting + 2. Global database configuration (GlobalConfig.passthrough_headers) + 3. Environment variable DEFAULT_PASSTHROUGH_HEADERS + + Security Features: + - Prevents conflicts with existing base headers (e.g., Content-Type) + - Blocks Authorization header conflicts with gateway authentication + - Logs all conflicts and skipped headers for debugging + - Uses case-insensitive header matching for robustness + + Args: + request_headers (Dict[str, str]): Headers from the incoming HTTP request. + Keys should be header names, values should be header values. + Example: {"Authorization": "Bearer token123", "X-Tenant-Id": "acme"} + base_headers (Dict[str, str]): Base headers that should always be included + in the final result. These take precedence over passthrough headers. + Example: {"Content-Type": "application/json", "User-Agent": "MCPGateway/1.0"} + db (Session): SQLAlchemy database session for querying global configuration. + Used to retrieve GlobalConfig.passthrough_headers setting. + gateway (Optional[DbGateway]): Target gateway instance. If provided, uses + gateway.passthrough_headers to override global settings. Also checks + gateway.auth_type to prevent Authorization header conflicts. + + Returns: + Dict[str, str]: Combined dictionary of base headers plus allowed passthrough + headers from the request. Base headers are preserved, and passthrough + headers are added only if they don't conflict with security policies. + + Raises: + No exceptions are raised. Errors are logged as warnings and processing continues. + Database connection issues may propagate from the db.query() call. + + Examples: + Basic usage with global configuration: + >>> # Mock database and settings for doctest + >>> from unittest.mock import Mock, MagicMock + >>> mock_db = Mock() + >>> mock_global_config = Mock() + >>> mock_global_config.passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"] + >>> mock_db.query.return_value.first.return_value = mock_global_config + >>> + >>> request_headers = { + ... "authorization": "Bearer token123", + ... "x-tenant-id": "acme-corp", + ... "x-trace-id": "trace-456", + ... "user-agent": "TestClient/1.0" + ... } + >>> base_headers = {"Content-Type": "application/json"} + >>> + >>> result = get_passthrough_headers(request_headers, base_headers, mock_db) + >>> sorted(result.items()) + [('Content-Type', 'application/json'), ('X-Tenant-Id', 'acme-corp'), ('X-Trace-Id', 'trace-456')] + + Gateway-specific configuration override: + >>> mock_gateway = Mock() + >>> mock_gateway.passthrough_headers = ["X-Custom-Header"] + >>> mock_gateway.auth_type = None + >>> request_headers = { + ... "x-custom-header": "custom-value", + ... "x-tenant-id": "should-be-ignored" + ... } + >>> + >>> result = get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + >>> sorted(result.items()) + [('Content-Type', 'application/json'), ('X-Custom-Header', 'custom-value')] + + Authorization header conflict with basic auth: + >>> mock_gateway.auth_type = "basic" + >>> mock_gateway.passthrough_headers = ["Authorization", "X-Tenant-Id"] + >>> request_headers = { + ... "authorization": "Bearer should-be-blocked", + ... "x-tenant-id": "acme-corp" + ... } + >>> + >>> result = get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + >>> sorted(result.items()) # Authorization blocked due to basic auth conflict + [('Content-Type', 'application/json'), ('X-Tenant-Id', 'acme-corp')] + + Base header conflict prevention: + >>> base_headers_with_conflict = {"Content-Type": "application/json", "x-tenant-id": "from-base"} + >>> request_headers = {"x-tenant-id": "from-request"} + >>> mock_gateway.auth_type = None + >>> mock_gateway.passthrough_headers = ["X-Tenant-Id"] + >>> + >>> result = get_passthrough_headers(request_headers, base_headers_with_conflict, mock_db, mock_gateway) + >>> result["x-tenant-id"] # Base header preserved, request header blocked + 'from-base' + + Empty allowed headers (no passthrough): + >>> empty_global_config = Mock() + >>> empty_global_config.passthrough_headers = [] + >>> mock_db.query.return_value.first.return_value = empty_global_config + >>> + >>> request_headers = {"x-tenant-id": "should-be-ignored"} + >>> result = get_passthrough_headers(request_headers, {"Content-Type": "application/json"}, mock_db) + >>> result + {'Content-Type': 'application/json'} + + Note: + Header names are matched case-insensitively but preserved in their original + case from the allowed_headers configuration. Request header values are + matched case-insensitively against the request_headers dictionary. + """ + passthrough_headers = base_headers.copy() + + # Get global passthrough headers first + global_config = db.query(GlobalConfig).first() + allowed_headers = global_config.passthrough_headers if global_config else settings.default_passthrough_headers + + # Gateway specific headers override global config + if gateway: + if gateway.passthrough_headers is not None: + allowed_headers = gateway.passthrough_headers + + # Get auth headers to check for conflicts + base_headers_keys = {key.lower(): key for key in passthrough_headers.keys()} + + # Copy allowed headers from request + if request_headers and allowed_headers: + for header_name in allowed_headers: + header_value = request_headers.get(header_name.lower()) + if header_value: + + header_lower = header_name.lower() + # Skip if header would conflict with existing auth headers + if header_lower in base_headers_keys: + logger.warning(f"Skipping {header_name} header passthrough as it conflicts with pre-defined headers") + continue + + # Skip if header would conflict with gateway auth + if gateway: + if gateway.auth_type == "basic" and header_lower == "authorization": + logger.warning(f"Skipping Authorization header passthrough due to basic auth configuration on gateway {gateway.name}") + continue + if gateway.auth_type == "bearer" and header_lower == "authorization": + logger.warning(f"Skipping Authorization header passthrough due to bearer auth configuration on gateway {gateway.name}") + continue + + passthrough_headers[header_name] = header_value + else: + logger.warning(f"Header {header_name} not found in request headers, skipping passthrough") + + return passthrough_headers diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index dc045d559..f7f632235 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -248,7 +248,7 @@ def test_rpc_tool_invocation_flow( resp = test_client.post("/rpc/", json=rpc_body, headers=auth_headers) assert resp.status_code == 200 assert resp.json()["content"][0]["text"] == "ok" - mock_invoke.assert_awaited_once_with(db=ANY, name="test_tool", arguments={"foo": "bar"}) + mock_invoke.assert_awaited_once_with(db=ANY, name="test_tool", arguments={"foo": "bar"}, request_headers=ANY) # --------------------------------------------------------------------- # # 5. Metrics aggregation endpoint # diff --git a/tests/unit/mcpgateway/federation/test_forward.py b/tests/unit/mcpgateway/federation/test_forward.py index 3c073bd40..4b885c51c 100644 --- a/tests/unit/mcpgateway/federation/test_forward.py +++ b/tests/unit/mcpgateway/federation/test_forward.py @@ -179,7 +179,7 @@ async def test_forward_request_targeted(fwd_service): gw = DummyGateway(1, "Alpha", "http://alpha") db = FakeSession(gateways=[gw]) - result = await fwd_service.forward_request(db, "test/method", {"param": "value"}, target_gateway_id=1) + result = await fwd_service.forward_request(db, "test/method", {"param": "value"}, target_gateway_id=1, request_headers=None) assert result == {"method": "test/method"} @@ -190,7 +190,7 @@ async def test_forward_request_broadcast(fwd_service): gw2 = DummyGateway(2, "Beta", "http://beta") db = FakeSession(gateways=[gw1, gw2]) - results = await fwd_service.forward_request(db, "tools/list") + results = await fwd_service.forward_request(db, "tools/list", request_headers=None) assert len(results) == 2 assert all(r == {"method": "tools/list"} for r in results) @@ -206,7 +206,7 @@ async def failing_forward(*args, **kwargs): db = FakeSession() with pytest.raises(ForwardingError) as exc_info: - await fwd_service.forward_request(db, "test", target_gateway_id=1) + await fwd_service.forward_request(db, "test", target_gateway_id=1, request_headers=None) assert "Forward request failed: Network error" in str(exc_info.value) @@ -219,7 +219,7 @@ async def failing_forward(*args, **kwargs): async def test_forward_tool_request_success(monkeypatch, fwd_service): """Test successful tool forwarding.""" - async def fake_forward(db, gid, method, params): + async def fake_forward(db, gid, method, params, request_headers=None): assert method == "tools/invoke" assert params["name"] == "calculator" return { @@ -232,7 +232,7 @@ async def fake_forward(db, gid, method, params): tool = DummyTool(1, "calculator", gateway_id=42) db = FakeSession(gateways=[DummyGateway(42, "CalcGW", "http://calc")], tools=[tool]) - result = await fwd_service.forward_tool_request(db, "calculator", {"operation": "add", "a": 20, "b": 22}) + result = await fwd_service.forward_tool_request(db, "calculator", {"operation": "add", "a": 20, "b": 22}, request_headers=None) assert not result.is_error assert len(result.content) == 1 # Access TextContent object attributes, not dictionary keys @@ -246,7 +246,7 @@ async def test_forward_tool_request_not_found(fwd_service): db = FakeSession(tools=[]) with pytest.raises(ForwardingError) as exc_info: - await fwd_service.forward_tool_request(db, "unknown_tool", {}) + await fwd_service.forward_tool_request(db, "unknown_tool", {}, request_headers=None) assert "Tool not found: unknown_tool" in str(exc_info.value) @@ -257,7 +257,7 @@ async def test_forward_tool_request_not_federated(fwd_service): db = FakeSession(tools=[tool]) with pytest.raises(ForwardingError) as exc_info: - await fwd_service.forward_tool_request(db, "local_tool", {}) + await fwd_service.forward_tool_request(db, "local_tool", {}, request_headers=None) assert "Tool local_tool is not federated" in str(exc_info.value) @@ -274,7 +274,7 @@ def failing_execute(query): monkeypatch.setattr(db, "execute", failing_execute) with pytest.raises(ForwardingError) as exc_info: - await fwd_service.forward_tool_request(db, "test_tool", {}) + await fwd_service.forward_tool_request(db, "test_tool", {}, request_headers=None) assert "Failed to forward tool request: Database error" in str(exc_info.value) @@ -291,7 +291,7 @@ async def test_forward_resource_request_text(monkeypatch, fwd_service): async def fake_find_gateway(db, uri): return gateway - async def fake_forward(db, gid, method, params): + async def fake_forward(db, gid, method, params, request_headers=None): assert method == "resources/read" assert params["uri"] == "file://hello.txt" return {"text": "Hello, World!", "mime_type": "text/plain"} @@ -313,7 +313,7 @@ async def test_forward_resource_request_binary(monkeypatch, fwd_service): async def fake_find_gateway(db, uri): return gateway - async def fake_forward(db, gid, method, params): + async def fake_forward(db, gid, method, params, request_headers=None): return {"blob": b"\x89PNG...", "mime_type": "image/png"} monkeypatch.setattr(fwd_service, "_find_resource_gateway", fake_find_gateway) @@ -349,7 +349,7 @@ async def test_forward_resource_request_invalid_format(monkeypatch, fwd_service) async def fake_find_gateway(db, uri): return gateway - async def fake_forward(db, gid, method, params): + async def fake_forward(db, gid, method, params, request_headers=None): return {"invalid": "response"} monkeypatch.setattr(fwd_service, "_find_resource_gateway", fake_find_gateway) @@ -387,7 +387,7 @@ async def test_forward_to_gateway_success(fwd_service): gw = DummyGateway(1, "Alpha", "http://alpha") db = FakeSession(gateways=[gw]) - result = await fwd_service._forward_to_gateway(db, 1, "ping", {"x": 1}) + result = await fwd_service._forward_to_gateway(db, 1, "ping", {"x": 1}, request_headers=None) assert result == {"method": "ping"} assert isinstance(gw.last_seen, datetime) @@ -398,7 +398,7 @@ async def test_forward_to_gateway_not_found(fwd_service): db = FakeSession() with pytest.raises(ForwardingError) as exc_info: - await fwd_service._forward_to_gateway(db, 999, "test") + await fwd_service._forward_to_gateway(db, 999, "test", request_headers=None) assert "Gateway not found: 999" in str(exc_info.value) @@ -409,7 +409,7 @@ async def test_forward_to_gateway_disabled(fwd_service): db = FakeSession(gateways=[gw]) with pytest.raises(ForwardingError) as exc_info: - await fwd_service._forward_to_gateway(db, 1, "test") + await fwd_service._forward_to_gateway(db, 1, "test", request_headers=None) assert "Gateway not found: 1" in str(exc_info.value) @@ -422,7 +422,7 @@ async def test_forward_to_gateway_rate_limited(monkeypatch, fwd_service): monkeypatch.setattr(fwd_service, "_check_rate_limit", lambda url: False) with pytest.raises(ForwardingError) as exc_info: - await fwd_service._forward_to_gateway(db, 1, "test") + await fwd_service._forward_to_gateway(db, 1, "test", request_headers=None) assert "Rate limit exceeded" in str(exc_info.value) @@ -445,7 +445,7 @@ def json(self): monkeypatch.setattr(fwd_service._http_client, "post", fake_post) with pytest.raises(ForwardingError) as exc_info: - await fwd_service._forward_to_gateway(db, 1, "unknown_method") + await fwd_service._forward_to_gateway(db, 1, "unknown_method", request_headers=None) assert "Gateway error: Method not found" in str(exc_info.value) @@ -480,7 +480,7 @@ def json(self): monkeypatch.setattr(fwd_service._http_client, "post", fake_post) - result = await fwd_service._forward_to_gateway(db, 1, "test") + result = await fwd_service._forward_to_gateway(db, 1, "test", request_headers=None) assert result == {"success": True} assert call_count == 3 @@ -502,7 +502,7 @@ async def fake_post(url, json=None, headers=None): monkeypatch.setattr(fwd_service._http_client, "post", fake_post) with pytest.raises(ForwardingError) as exc_info: - await fwd_service._forward_to_gateway(db, 1, "test") + await fwd_service._forward_to_gateway(db, 1, "test", request_headers=None) assert "Failed to forward to Alpha" in str(exc_info.value) @@ -518,7 +518,7 @@ async def fake_post(url, json=None, headers=None): monkeypatch.setattr(fwd_service._http_client, "post", fake_post) with pytest.raises(ForwardingError) as exc_info: - await fwd_service._forward_to_gateway(db, 1, "test") + await fwd_service._forward_to_gateway(db, 1, "test", request_headers=None) assert "Failed to forward to Alpha: Connection refused" in str(exc_info.value) @@ -534,7 +534,7 @@ async def test_forward_to_all_success(fwd_service): gw2 = DummyGateway(2, "Beta", "http://beta") db = FakeSession(gateways=[gw1, gw2]) - results = await fwd_service._forward_to_all(db, "tools/list") + results = await fwd_service._forward_to_all(db, "tools/list", request_headers=None) assert len(results) == 2 assert all(r == {"method": "tools/list"} for r in results) @@ -545,7 +545,7 @@ async def test_forward_to_all_partial_success(monkeypatch, fwd_service): gw_ok = DummyGateway(1, "GoodGW", "http://good") gw_bad = DummyGateway(2, "BadGW", "http://bad") - async def fake_forward(db, gid, method, params=None): # Add default params=None + async def fake_forward(db, gid, method, params=None, request_headers=None): # Add default params=None if gid == 1: return "ok!" raise ForwardingError("boom") @@ -553,7 +553,7 @@ async def fake_forward(db, gid, method, params=None): # Add default params=None monkeypatch.setattr(fwd_service, "_forward_to_gateway", fake_forward) db = FakeSession(gateways=[gw_ok, gw_bad]) - results = await fwd_service._forward_to_all(db, "stats/get") + results = await fwd_service._forward_to_all(db, "stats/get", request_headers=None) assert results == ["ok!"] @@ -563,14 +563,14 @@ async def test_forward_to_all_complete_failure(monkeypatch, fwd_service): gw1 = DummyGateway(1, "BadGW1", "http://bad1") gw2 = DummyGateway(2, "BadGW2", "http://bad2") - async def fake_forward(db, gid, method, params=None): # Add default params=None + async def fake_forward(db, gid, method, params=None, request_headers=None): # Add default params=None raise ForwardingError(f"Gateway {gid} failed") monkeypatch.setattr(fwd_service, "_forward_to_gateway", fake_forward) db = FakeSession(gateways=[gw1, gw2]) with pytest.raises(ForwardingError) as exc_info: - await fwd_service._forward_to_all(db, "test") + await fwd_service._forward_to_all(db, "test", request_headers=None) assert "All forwards failed" in str(exc_info.value) assert "Gateway 1 failed" in str(exc_info.value) assert "Gateway 2 failed" in str(exc_info.value) @@ -580,7 +580,7 @@ async def fake_forward(db, gid, method, params=None): # Add default params=None async def test_forward_to_all_no_gateways(fwd_service): """Test forwarding with no active gateways.""" db = FakeSession(gateways=[]) - results = await fwd_service._forward_to_all(db, "test") + results = await fwd_service._forward_to_all(db, "test", request_headers=None) assert results == [] @@ -595,7 +595,7 @@ async def test_find_resource_gateway_found(monkeypatch, fwd_service): gw1 = DummyGateway(1, "Gateway 1", "http://gw1") gw2 = DummyGateway(2, "Gateway 2", "http://gw2") - async def fake_forward(db, gid, method, params=None): # Add default params=None + async def fake_forward(db, gid, method, params=None, request_headers=None): # Add default params=None assert method == "resources/list" # This is the actual method called if gid == 1: return [{"uri": "file://doc1.txt"}, {"uri": "file://doc2.txt"}] @@ -614,7 +614,7 @@ async def test_find_resource_gateway_not_found(monkeypatch, fwd_service): """Test resource not found in any gateway.""" gw1 = DummyGateway(1, "Gateway 1", "http://gw1") - async def fake_forward(db, gid, method, params=None): # Add default params=None + async def fake_forward(db, gid, method, params=None, request_headers=None): # Add default params=None assert method == "resources/list" return [{"uri": "file://other.txt"}] @@ -631,7 +631,7 @@ async def test_find_resource_gateway_with_errors(monkeypatch, fwd_service, caplo gw1 = DummyGateway(1, "Gateway 1", "http://gw1") gw2 = DummyGateway(2, "Gateway 2", "http://gw2") - async def fake_forward(db, gid, method, params=None): # Add default params=None + async def fake_forward(db, gid, method, params=None, request_headers=None): # Add default params=None assert method == "resources/list" if gid == 1: raise Exception("Gateway unavailable") @@ -765,7 +765,7 @@ async def test_forward_with_no_params(fwd_service): gw = DummyGateway(1, "Alpha", "http://alpha") db = FakeSession(gateways=[gw]) - result = await fwd_service._forward_to_gateway(db, 1, "status") + result = await fwd_service._forward_to_gateway(db, 1, "status", request_headers=None) assert result == {"method": "status"} @@ -776,7 +776,7 @@ async def test_concurrent_forwards(monkeypatch, fwd_service): call_times = [] - async def fake_forward(db, gid, method, params=None): # Add default params=None + async def fake_forward(db, gid, method, params=None, request_headers=None): # Add default params=None start = asyncio.get_event_loop().time() await asyncio.sleep(0.1) # Simulate network delay call_times.append((gid, asyncio.get_event_loop().time() - start)) @@ -785,7 +785,7 @@ async def fake_forward(db, gid, method, params=None): # Add default params=None monkeypatch.setattr(fwd_service, "_forward_to_gateway", fake_forward) db = FakeSession(gateways=gateways) - results = await fwd_service._forward_to_all(db, "health/check") + results = await fwd_service._forward_to_all(db, "health/check", request_headers=None) # All gateways should respond assert len(results) == 5 @@ -798,7 +798,7 @@ async def fake_forward(db, gid, method, params=None): # Add default params=None async def test_forward_tool_with_empty_content(monkeypatch, fwd_service): """Test tool forwarding with empty content.""" - async def fake_forward(db, gid, method, params): + async def fake_forward(db, gid, method, params, request_headers=None): return {"content": [], "is_error": False} monkeypatch.setattr(fwd_service, "_forward_to_gateway", fake_forward) @@ -806,7 +806,7 @@ async def fake_forward(db, gid, method, params): tool = DummyTool(1, "empty_tool", gateway_id=1) db = FakeSession(gateways=[DummyGateway(1, "GW", "http://gw")], tools=[tool]) - result = await fwd_service.forward_tool_request(db, "empty_tool", {}) + result = await fwd_service.forward_tool_request(db, "empty_tool", {}, request_headers=None) assert not result.is_error assert result.content == [] @@ -819,7 +819,7 @@ async def test_forward_resource_with_defaults(monkeypatch, fwd_service): async def fake_find_gateway(db, uri): return gateway - async def fake_forward_text(db, gid, method, params): + async def fake_forward_text(db, gid, method, params, request_headers=None): return {"text": "Plain text"} # No mime_type specified monkeypatch.setattr(fwd_service, "_find_resource_gateway", fake_find_gateway) @@ -831,7 +831,7 @@ async def fake_forward_text(db, gid, method, params): assert mime_type == "text/plain" # Default # Test binary with default - async def fake_forward_binary(db, gid, method, params): + async def fake_forward_binary(db, gid, method, params, request_headers=None): return {"blob": b"binary data"} # No mime_type specified monkeypatch.setattr(fwd_service, "_forward_to_gateway", fake_forward_binary) diff --git a/tests/unit/mcpgateway/services/test_tag_service.py b/tests/unit/mcpgateway/services/test_tag_service.py index b3fd9536b..56ca71338 100644 --- a/tests/unit/mcpgateway/services/test_tag_service.py +++ b/tests/unit/mcpgateway/services/test_tag_service.py @@ -411,7 +411,7 @@ async def test_update_stats(tag_service): # Test invalid entity type (should not crash or increment) tag_service._update_stats(stats, "invalid") - assert stats.total == 5 # Should remain unchanged + assert stats.total == 5 # Should remain unmodified @pytest.mark.asyncio diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 64e76aa3d..fe166d492 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -1118,7 +1118,7 @@ async def test_invoke_tool_not_found(self, tool_service, test_db): # Should raise NotFoundError with pytest.raises(ToolNotFoundError) as exc_info: - await tool_service.invoke_tool(test_db, "nonexistent_tool", {}) + await tool_service.invoke_tool(test_db, "nonexistent_tool", {}, request_headers=None) assert "Tool not found: nonexistent_tool" in str(exc_info.value) @@ -1139,7 +1139,7 @@ async def test_invoke_tool_inactive(self, tool_service, mock_tool, test_db): # Should raise NotFoundError with "inactive" message with pytest.raises(ToolNotFoundError) as exc_info: - await tool_service.invoke_tool(test_db, "test_tool", {}) + await tool_service.invoke_tool(test_db, "test_tool", {}, request_headers=None) assert "Tool 'test_tool' exists but is inactive" in str(exc_info.value) @@ -1169,7 +1169,7 @@ async def test_invoke_tool_rest_get(self, tool_service, mock_tool, test_db): tool_service._record_tool_metric = AsyncMock() # -------------- invoke ----------------- - result = await tool_service.invoke_tool(test_db, "test_tool", {}) + result = await tool_service.invoke_tool(test_db, "test_tool", {}, request_headers=None) # ------------- asserts ----------------- tool_service._http_client.get.assert_called_once_with( @@ -1192,7 +1192,7 @@ async def test_invoke_tool_rest_get(self, tool_service, mock_tool, test_db): tool_service._record_tool_metric = AsyncMock() # -------------- invoke ----------------- - result = await tool_service.invoke_tool(test_db, "test_tool", {}) + result = await tool_service.invoke_tool(test_db, "test_tool", {}, request_headers=None) assert result.content[0].text == "Request completed successfully (No Content)" @@ -1208,7 +1208,7 @@ async def test_invoke_tool_rest_get(self, tool_service, mock_tool, test_db): tool_service._record_tool_metric = AsyncMock() # -------------- invoke ----------------- - result = await tool_service.invoke_tool(test_db, "test_tool", {}) + result = await tool_service.invoke_tool(test_db, "test_tool", {}, request_headers=None) assert result.content[0].text == "Tool error encountered" @@ -1240,7 +1240,7 @@ async def test_invoke_tool_rest_post(self, tool_service, mock_tool, test_db): # Mock extract_using_jq to return the input unmodified when filter is empty with patch("mcpgateway.services.tool_service.decode_auth", return_value={}), patch("mcpgateway.config.extract_using_jq", return_value={"result": "REST tool response"}): # Invoke tool - result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}) + result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) # Verify HTTP request tool_service._http_client.request.assert_called_once_with( @@ -1286,7 +1286,7 @@ async def test_invoke_tool_rest_parameter_substitution(self, tool_service, mock_ tool_service._http_client.request = AsyncMock(return_value=mock_response) - await tool_service.invoke_tool(test_db, "test_tool", payload) + await tool_service.invoke_tool(test_db, "test_tool", payload, request_headers=None) tool_service._http_client.request.assert_called_once_with( "POST", @@ -1313,7 +1313,7 @@ async def test_invoke_tool_rest_parameter_substitution_missed_input(self, tool_s test_db.execute = Mock(return_value=mock_scalar) with pytest.raises(ToolInvocationError) as exc_info: - await tool_service.invoke_tool(test_db, "test_tool", payload) + await tool_service.invoke_tool(test_db, "test_tool", payload, request_headers=None) assert "Required URL parameter 'type' not found in arguments" in str(exc_info.value) @@ -1382,7 +1382,7 @@ async def mock_streamable_client(*_args, **_kwargs): # ------------------------------------------------------------------ # 4. Act # ------------------------------------------------------------------ - result = await tool_service.invoke_tool(test_db, "dummy_tool", {"param": "value"}) + result = await tool_service.invoke_tool(test_db, "dummy_tool", {"param": "value"}, request_headers=None) session_mock.initialize.assert_awaited_once() session_mock.call_tool.assert_awaited_once_with("dummy_tool", {"param": "value"}) @@ -1468,7 +1468,7 @@ def execute_side_effect(*_args, **_kwargs): # ------------------------------------------------------------------ # 4. Act # ------------------------------------------------------------------ - result = await tool_service.invoke_tool(test_db, "dummy_tool", {"param": "value"}) + result = await tool_service.invoke_tool(test_db, "dummy_tool", {"param": "value"}, request_headers=None) # Our ToolResult bubbled back out assert result.content[0].text == "" @@ -1517,7 +1517,7 @@ async def test_invoke_tool_invalid_tool_type(self, tool_service, mock_tool, test mock_scalar.scalar_one_or_none.return_value = mock_tool test_db.execute = Mock(return_value=mock_scalar) - response = await tool_service.invoke_tool(test_db, "test_tool", payload) + response = await tool_service.invoke_tool(test_db, "test_tool", payload, request_headers=None) assert response.content[0].text == "Invalid tool type" @@ -1581,7 +1581,7 @@ async def test_invoke_tool_mcp_tool_basic_auth(self, tool_service, mock_tool, mo # ------------------------------------------------------------------ # 4. Act # ------------------------------------------------------------------ - result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}) + result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) session_mock.initialize.assert_awaited_once() session_mock.call_tool.assert_awaited_once_with("test_tool", {"param": "value"}) @@ -1616,7 +1616,7 @@ async def test_invoke_tool_error(self, tool_service, mock_tool, test_db): # Should raise ToolInvocationError with pytest.raises(ToolInvocationError) as exc_info: - await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}) + await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) assert "Tool invocation failed: HTTP error" in str(exc_info.value) diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index 436b8f19a..34d471a75 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -897,7 +897,7 @@ def test_rpc_tool_invocation(self, mock_invoke_tool, test_client, auth_headers): assert response.status_code == 200 body = response.json() assert body["content"][0]["text"] == "Tool response" - mock_invoke_tool.assert_called_once_with(db=ANY, name="test_tool", arguments={"param": "value"}) + mock_invoke_tool.assert_called_once_with(db=ANY, name="test_tool", arguments={"param": "value"}, request_headers=ANY) @patch("mcpgateway.main.prompt_service.get_prompt") # @patch("mcpgateway.main.validate_request") diff --git a/tests/unit/mcpgateway/utils/test_passthrough_headers.py b/tests/unit/mcpgateway/utils/test_passthrough_headers.py new file mode 100644 index 000000000..77ddd69f3 --- /dev/null +++ b/tests/unit/mcpgateway/utils/test_passthrough_headers.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +"""Unit tests for HTTP header passthrough functionality. + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +This module contains comprehensive unit tests for the passthrough_headers utility +module, covering all scenarios including configuration priorities, conflict +detection, case sensitivity, and security features. +""" + +# Standard +import logging +from unittest.mock import Mock, patch + +# Third-Party +import pytest + +# First-Party +from mcpgateway.db import Gateway as DbGateway, GlobalConfig +from mcpgateway.utils.passthrough_headers import get_passthrough_headers + + +class TestPassthroughHeaders: + """Test suite for HTTP header passthrough functionality.""" + + def test_basic_header_passthrough_global_config(self): + """Test basic header passthrough with global configuration.""" + # Mock database and global config + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = { + "x-tenant-id": "acme-corp", + "x-trace-id": "trace-456", + "user-agent": "TestClient/1.0" # Not in allowed headers + } + base_headers = {"Content-Type": "application/json"} + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + expected = { + "Content-Type": "application/json", + "X-Tenant-Id": "acme-corp", + "X-Trace-Id": "trace-456" + } + assert result == expected + + def test_gateway_specific_override(self): + """Test that gateway-specific headers override global configuration.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"] + mock_db.query.return_value.first.return_value = mock_global_config + + # Gateway with custom headers + mock_gateway = Mock(spec=DbGateway) + mock_gateway.passthrough_headers = ["X-Custom-Header"] + mock_gateway.auth_type = None + + request_headers = { + "x-custom-header": "custom-value", + "x-tenant-id": "should-be-ignored", # Not in gateway config + "x-trace-id": "also-ignored" + } + base_headers = {"Content-Type": "application/json"} + + result = get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + + expected = { + "Content-Type": "application/json", + "X-Custom-Header": "custom-value" + } + assert result == expected + + def test_authorization_conflict_basic_auth(self, caplog): + """Test that Authorization header is blocked when gateway uses basic auth.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["Authorization", "X-Tenant-Id"] + mock_db.query.return_value.first.return_value = mock_global_config + + mock_gateway = Mock(spec=DbGateway) + mock_gateway.passthrough_headers = ["Authorization", "X-Tenant-Id"] + mock_gateway.auth_type = "basic" + mock_gateway.name = "test-gateway" + + request_headers = { + "authorization": "Bearer should-be-blocked", + "x-tenant-id": "acme-corp" + } + base_headers = {"Content-Type": "application/json"} + + with caplog.at_level(logging.WARNING): + result = get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + + # Authorization should be blocked, X-Tenant-Id should pass through + expected = { + "Content-Type": "application/json", + "X-Tenant-Id": "acme-corp" + } + assert result == expected + + # Check warning was logged + assert any("Skipping Authorization header passthrough due to basic auth" in record.message + for record in caplog.records) + + def test_authorization_conflict_bearer_auth(self, caplog): + """Test that Authorization header is blocked when gateway uses bearer auth.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["Authorization"] + mock_db.query.return_value.first.return_value = mock_global_config + + mock_gateway = Mock(spec=DbGateway) + mock_gateway.passthrough_headers = None # Use global + mock_gateway.auth_type = "bearer" + mock_gateway.name = "bearer-gateway" + + request_headers = {"authorization": "Bearer should-be-blocked"} + base_headers = {"Content-Type": "application/json"} + + with caplog.at_level(logging.WARNING): + result = get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + + # Only base headers should remain + expected = {"Content-Type": "application/json"} + assert result == expected + + # Check warning was logged + assert any("Skipping Authorization header passthrough due to bearer auth" in record.message + for record in caplog.records) + + def test_base_header_conflict_prevention(self, caplog): + """Test that request headers don't override base headers.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["Content-Type", "X-Tenant-Id"] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = { + "content-type": "text/plain", # Conflicts with base header + "x-tenant-id": "acme-corp" # Should pass through + } + base_headers = {"Content-Type": "application/json"} + + with caplog.at_level(logging.WARNING): + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Base header preserved, tenant ID added + expected = { + "Content-Type": "application/json", + "X-Tenant-Id": "acme-corp" + } + assert result == expected + + # Check conflict warning was logged + assert any("conflicts with pre-defined headers" in record.message + for record in caplog.records) + + def test_case_insensitive_header_matching(self): + """Test that header matching works with lowercase request headers.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Tenant-ID", "Authorization"] + mock_db.query.return_value.first.return_value = mock_global_config + + # Request headers are expected to be normalized to lowercase + request_headers = { + "x-tenant-id": "mixed-case-value", # Lowercase key + "authorization": "bearer lowercase-header" + } + base_headers = {} + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Headers should preserve config case in output keys + expected = { + "X-Tenant-ID": "mixed-case-value", + "Authorization": "bearer lowercase-header" + } + assert result == expected + + def test_missing_request_headers(self, caplog): + """Test behavior when configured headers are missing from request.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Missing", "X-Present"] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = {"x-present": "present-value"} + base_headers = {"Content-Type": "application/json"} + + with caplog.at_level(logging.WARNING): + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Only present header should be included + expected = { + "Content-Type": "application/json", + "X-Present": "present-value" + } + assert result == expected + + # Check warning for missing header + assert any("Header X-Missing not found in request headers" in record.message + for record in caplog.records) + + def test_empty_allowed_headers(self): + """Test behavior with empty allowed headers configuration.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = [] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = {"x-tenant-id": "should-be-ignored"} + base_headers = {"Content-Type": "application/json"} + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Only base headers should remain + expected = {"Content-Type": "application/json"} + assert result == expected + + def test_none_allowed_headers(self): + """Test behavior when allowed headers is None.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = None + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = {"x-tenant-id": "should-be-ignored"} + base_headers = {"Content-Type": "application/json"} + + # Mock settings fallback + with patch('mcpgateway.utils.passthrough_headers.settings') as mock_settings: + mock_settings.default_passthrough_headers = ["X-Default"] + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Should fall back to settings, but request doesn't have X-Default + expected = {"Content-Type": "application/json"} + assert result == expected + + def test_no_global_config_fallback_to_settings(self): + """Test fallback to settings when no global config exists.""" + mock_db = Mock() + mock_db.query.return_value.first.return_value = None # No global config + + request_headers = {"x-default": "default-value"} + base_headers = {"Content-Type": "application/json"} + + # Mock settings fallback + with patch('mcpgateway.utils.passthrough_headers.settings') as mock_settings: + mock_settings.default_passthrough_headers = ["X-Default"] + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + expected = { + "Content-Type": "application/json", + "X-Default": "default-value" + } + assert result == expected + + def test_empty_request_headers(self): + """Test behavior with empty request headers.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Tenant-Id"] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = {} + base_headers = {"Content-Type": "application/json"} + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Only base headers should remain + expected = {"Content-Type": "application/json"} + assert result == expected + + def test_none_request_headers(self): + """Test behavior with None request headers.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Tenant-Id"] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = None + base_headers = {"Content-Type": "application/json"} + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Only base headers should remain + expected = {"Content-Type": "application/json"} + assert result == expected + + def test_base_headers_not_modified(self): + """Test that original base_headers dictionary is not modified.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Tenant-Id"] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = {"x-tenant-id": "acme-corp"} + base_headers = {"Content-Type": "application/json"} + original_base = base_headers.copy() + + result = get_passthrough_headers(request_headers, base_headers, mock_db) + + # Original base_headers should not be modified + assert base_headers == original_base + + # Result should include both base and passthrough headers + assert "Content-Type" in result + assert "X-Tenant-Id" in result + + def test_multiple_auth_type_conflicts(self, caplog): + """Test various auth type conflict scenarios.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["Authorization"] + mock_db.query.return_value.first.return_value = mock_global_config + + request_headers = {"authorization": "Bearer token"} + base_headers = {} + + # Test with different auth types + auth_types = ["basic", "bearer", "api-key", None] + + for auth_type in auth_types: + caplog.clear() + mock_gateway = Mock(spec=DbGateway) + mock_gateway.passthrough_headers = None + mock_gateway.auth_type = auth_type + mock_gateway.name = f"gateway-{auth_type or 'none'}" + + with caplog.at_level(logging.WARNING): + result = get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + + if auth_type in ["basic", "bearer"]: + # Authorization should be blocked + assert "Authorization" not in result + assert any("Skipping Authorization header passthrough" in record.message + for record in caplog.records) + else: + # Authorization should pass through + assert result.get("Authorization") == "Bearer token" + + def test_complex_mixed_scenario(self): + """Test complex scenario with multiple headers, conflicts, and overrides.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["Authorization", "X-Global", "X-Conflict"] + mock_db.query.return_value.first.return_value = mock_global_config + + mock_gateway = Mock(spec=DbGateway) + mock_gateway.passthrough_headers = ["X-Gateway", "X-Conflict", "Authorization"] + mock_gateway.auth_type = "basic" # Will block Authorization + mock_gateway.name = "complex-gateway" + + request_headers = { + "authorization": "Bearer token", # Blocked by basic auth + "x-global": "global-value", # Not in gateway config, ignored + "x-gateway": "gateway-value", # Should pass through + "x-conflict": "conflict-value", # Should pass through (in both configs) + "x-random": "random-value" # Not configured, ignored + } + base_headers = { + "Content-Type": "application/json", + "User-Agent": "MCPGateway/1.0" + } + + result = get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + + expected = { + "Content-Type": "application/json", + "User-Agent": "MCPGateway/1.0", + "X-Gateway": "gateway-value", + "X-Conflict": "conflict-value" + } + assert result == expected + + def test_database_query_called_correctly(self): + """Test that database is queried correctly for GlobalConfig.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = [] + mock_db.query.return_value.first.return_value = mock_global_config + + get_passthrough_headers({}, {}, mock_db) + + # Verify database was queried for GlobalConfig + mock_db.query.assert_called_once_with(GlobalConfig) + mock_db.query.return_value.first.assert_called_once() + + def test_logging_levels(self, caplog): + """Test that appropriate log levels are used for different scenarios.""" + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Missing", "Authorization", "X-Conflict"] + mock_db.query.return_value.first.return_value = mock_global_config + + mock_gateway = Mock(spec=DbGateway) + mock_gateway.passthrough_headers = None + mock_gateway.auth_type = "basic" + mock_gateway.name = "test-gateway" + + request_headers = { + "authorization": "Bearer token", # Will be blocked by basic auth + "x-conflict": "request-value" # Will conflict with base header + } + base_headers = {"X-Conflict": "base-value"} # Conflicts with x-conflict + + with caplog.at_level(logging.WARNING): + get_passthrough_headers(request_headers, base_headers, mock_db, mock_gateway) + + # Should have warnings for: missing header, auth conflict, base header conflict + warning_messages = [record.message for record in caplog.records if record.levelno == logging.WARNING] + + assert len(warning_messages) == 3 + assert any("not found in request headers" in msg for msg in warning_messages) + assert any("due to basic auth" in msg for msg in warning_messages) + assert any("conflicts with pre-defined headers" in msg for msg in warning_messages)