Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions mcpgateway/services/gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
from datetime import datetime, timezone
import logging
import os
import socket
import tempfile
from typing import Any, AsyncGenerator, Dict, List, Optional, Set, TYPE_CHECKING
from urllib.parse import urlparse, urlunparse
import uuid

# Third-Party
Expand Down Expand Up @@ -248,6 +250,33 @@ def __init__(self) -> None:
else:
self._redis_client = None

@staticmethod
def normalize_url(url: str) -> str:
"""
Normalize a URL by resolving the hostname to its IP address.

Args:
url (str): The URL to normalize.

Returns:
str: The normalized URL with the hostname replaced by its IP address.

Examples:
>>> GatewayService.normalize_url('http://localhost:8080/path')
'http://127.0.0.1:8080/path'
"""
parsed = urlparse(url)
hostname = parsed.hostname
try:
ip = socket.gethostbyname(hostname)
except Exception:
ip = hostname
netloc = ip
if parsed.port:
netloc += f":{parsed.port}"
normalized = parsed._replace(netloc=netloc)
return urlunparse(normalized)

async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout: Optional[int] = None):
"""
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
Expand Down Expand Up @@ -393,6 +422,9 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
gateway_id=existing_gateway.id,
)

# Normalize the gateway URL
normalized_url = self.normalize_url(gateway.url)

auth_type = getattr(gateway, "auth_type", None)
# Support multiple custom headers
auth_value = getattr(gateway, "auth_value", {})
Expand All @@ -401,13 +433,13 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
auth_value = encode_auth(header_dict) # Encode the dict for consistency

capabilities, tools = await self._initialize_gateway(gateway.url, auth_value, gateway.transport)
capabilities, tools = await self._initialize_gateway(normalized_url, auth_value, gateway.transport)

tools = [
DbTool(
original_name=tool.name,
original_name_slug=slugify(tool.name),
url=gateway.url,
url=normalized_url,
description=tool.description,
integration_type="MCP", # Gateway-discovered tools are MCP type
request_type=tool.request_type,
Expand All @@ -425,7 +457,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
db_gateway = DbGateway(
name=gateway.name,
slug=slugify(gateway.name),
url=gateway.url,
url=normalized_url,
description=gateway.description,
tags=gateway.tags,
transport=gateway.transport,
Expand Down Expand Up @@ -566,7 +598,8 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat
gateway.name = gateway_update.name
gateway.slug = slugify(gateway_update.name)
if gateway_update.url is not None:
gateway.url = gateway_update.url
# Normalize the updated URL
gateway.url = self.normalize_url(gateway_update.url)
if gateway_update.description is not None:
gateway.description = gateway_update.description
if gateway_update.transport is not None:
Expand Down
41 changes: 27 additions & 14 deletions tests/unit/mcpgateway/services/test_gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import socket

# Third-Party
import httpx
Expand Down Expand Up @@ -172,12 +173,13 @@ async def test_register_gateway(self, gateway_service, test_db, monkeypatch):
)
)
gateway_service._notify_gateway_added = AsyncMock()

normalize_url = lambda url: f"http://{socket.gethostbyname(url)}/gateway"
url = normalize_url("example.com")
# Patch GatewayRead.model_validate to return a mock with .masked()
mock_model = Mock()
mock_model.masked.return_value = mock_model
mock_model.name = "test_gateway"
mock_model.url = "http://example.com/gateway"
mock_model.url = url
mock_model.description = "A test gateway"

monkeypatch.setattr(
Expand All @@ -187,7 +189,7 @@ async def test_register_gateway(self, gateway_service, test_db, monkeypatch):

gateway_create = GatewayCreate(
name="test_gateway",
url="http://example.com/gateway",
url=url,
description="A test gateway",
)

Expand All @@ -202,9 +204,10 @@ async def test_register_gateway(self, gateway_service, test_db, monkeypatch):
# `result` is the same GatewayCreate instance because we stubbed
# GatewayRead.model_validate → just check its fields:
assert result.name == "test_gateway"
assert result.url == "http://example.com/gateway"
expected_url = url
assert result.url == expected_url
assert result.description == "A test gateway"

mock_model.url = expected_url
@pytest.mark.asyncio
async def test_register_gateway_name_conflict(self, gateway_service, mock_gateway, test_db):
"""Trying to register a gateway whose *name* already exists raises a conflict error."""
Expand All @@ -229,7 +232,6 @@ async def test_register_gateway_name_conflict(self, gateway_service, mock_gatewa
async def test_register_gateway_connection_error(self, gateway_service, test_db):
"""Initial connection to the remote gateway fails and the error propagates."""
test_db.execute = Mock(return_value=_make_execute_result(scalar=None))

# _initialize_gateway blows up before any DB work happens
gateway_service._initialize_gateway = AsyncMock(side_effect=GatewayConnectionError("Failed to connect"))

Expand Down Expand Up @@ -257,28 +259,39 @@ async def test_register_gateway_with_auth(self, gateway_service, test_db, monkey
test_db.commit = Mock()
test_db.refresh = Mock()

#url = f"http://{socket.gethostbyname('example.com')}/gateway"
normalize_url = lambda url: f"http://{socket.gethostbyname(url)}/gateway"
url = normalize_url("example.com")
print(f"url:{url}")
gateway_service._initialize_gateway = AsyncMock(
return_value=(
{
"prompts": {"listChanged": True},
"resources": {"listChanged": True},
"tools": {"listChanged": True},
},
[],
)
)

gateway_service._notify_gateway_added = AsyncMock()

mock_model = Mock()
mock_model.masked.return_value = mock_model
mock_model.name = "auth_gateway"
mock_model.url = url

monkeypatch.setattr(
"mcpgateway.services.gateway_service.GatewayRead.model_validate",
lambda x: mock_model,
)

gateway_create = GatewayCreate(name="auth_gateway", url="http://example.com/gateway", description="Gateway with auth", auth_type="bearer", auth_token="test-token")
gateway_create = GatewayCreate(
name="auth_gateway",
url=url,
description="Gateway with auth",
auth_type="bearer",
auth_token="test-token"
)

await gateway_service.register_gateway(test_db, gateway_create)

Expand Down Expand Up @@ -973,16 +986,16 @@ async def test_update_gateway_url_change_with_tools(self, gateway_service, mock_

gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, new_tools))
gateway_service._notify_gateway_updated = AsyncMock()

gateway_update = GatewayUpdate(url="http://example.com/new-url")
url = GatewayService.normalize_url("http://example.com/new-url")
gateway_update = GatewayUpdate(url=url)

mock_gateway_read = MagicMock()
mock_gateway_read.masked.return_value = mock_gateway_read

with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
await gateway_service.update_gateway(test_db, 1, gateway_update)

assert mock_gateway.url == "http://example.com/new-url"
assert mock_gateway.url == url
gateway_service._initialize_gateway.assert_called_once()
test_db.commit.assert_called_once()

Expand All @@ -997,8 +1010,8 @@ async def test_update_gateway_url_initialization_failure(self, gateway_service,
# Mock initialization failure
gateway_service._initialize_gateway = AsyncMock(side_effect=GatewayConnectionError("Connection failed"))
gateway_service._notify_gateway_updated = AsyncMock()

gateway_update = GatewayUpdate(url="http://example.com/bad-url")
url = GatewayService.normalize_url("http://example.com/bad-url")
gateway_update = GatewayUpdate(url=url)

mock_gateway_read = MagicMock()
mock_gateway_read.masked.return_value = mock_gateway_read
Expand All @@ -1007,7 +1020,7 @@ async def test_update_gateway_url_initialization_failure(self, gateway_service,
with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
await gateway_service.update_gateway(test_db, 1, gateway_update)

assert mock_gateway.url == "http://example.com/bad-url"
assert mock_gateway.url == url
test_db.commit.assert_called_once()

@pytest.mark.asyncio
Expand Down
Loading
Loading