Skip to content

Commit 43f4fb2

Browse files
sse endpoint validation + test
1 parent 0126d1b commit 43f4fb2

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

src/mcp/server/sse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self, endpoint: str) -> None:
7878
messages to the relative path given.
7979
8080
Args:
81-
endpoint: A relative path where messages should be posted
81+
endpoint: A relative path where messages should be posted
8282
(e.g., "/messages/").
8383
8484
Note:
@@ -97,9 +97,10 @@ def __init__(self, endpoint: str) -> None:
9797
super().__init__()
9898

9999
# Validate that endpoint is a relative path and not a full URL
100-
if "://" in endpoint or endpoint.startswith("//"):
100+
if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint:
101101
raise ValueError(
102-
"Endpoint must be a relative path (e.g., '/messages/'), not a full URL."
102+
f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), \
103+
expecting a relative path(e.g., '/messages/')."
103104
)
104105

105106
# Ensure endpoint starts with a forward slash

tests/shared/test_sse.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,31 @@ def test_sse_message_id_coercion():
464464
json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
465465
msg = types.JSONRPCMessage.model_validate_json(json_message)
466466
assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)))
467+
468+
469+
@pytest.mark.parametrize(
470+
"endpoint, expected_result",
471+
[
472+
# Valid endpoints - should normalize and work
473+
("/messages/", "/messages/"),
474+
("messages/", "/messages/"),
475+
("/", "/"),
476+
# Invalid endpoints - should raise ValueError
477+
("http://example.com/messages/", ValueError),
478+
("//example.com/messages/", ValueError),
479+
("ftp://example.com/messages/", ValueError),
480+
("/messages/?param=value", ValueError),
481+
("/messages/#fragment", ValueError),
482+
],
483+
)
484+
def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]):
485+
"""Test that SseServerTransport properly validates and normalizes endpoints."""
486+
if isinstance(expected_result, type) and issubclass(expected_result, Exception):
487+
# Test invalid endpoints that should raise an exception
488+
with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"):
489+
SseServerTransport(endpoint)
490+
else:
491+
# Test valid endpoints that should normalize correctly
492+
sse = SseServerTransport(endpoint)
493+
assert sse._endpoint == expected_result
494+
assert sse._endpoint.startswith("/")

0 commit comments

Comments
 (0)