Skip to content

Commit 1579db4

Browse files
refactor: Simplify StreamableHTTPTransport to respect client configuration
Implements "principle of least surprise" by making the httpx client the single source of truth for HTTP configuration (headers, timeout, auth). Changes: - StreamableHTTPTransport constructor now only takes url parameter - Transport reads configuration from client when making requests - Removed redundant config extraction and storage - Removed headers and sse_read_timeout from RequestContext - Removed MCP_DEFAULT_TIMEOUT and MCP_DEFAULT_SSE_READ_TIMEOUT from _httpx_utils public API (__all__) This addresses PR feedback about awkward config extraction when client is provided. The transport now only adds protocol requirements (MCP headers, session headers) on top of the client's configuration rather than extracting and overriding it. All tests pass, no type errors.
1 parent a14eeb2 commit 1579db4

File tree

2 files changed

+26
-54
lines changed

2 files changed

+26
-54
lines changed

src/mcp/client/streamable_http.py

Lines changed: 25 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,7 @@
2020
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2121
from typing_extensions import deprecated
2222

23-
from mcp.shared._httpx_utils import (
24-
MCP_DEFAULT_SSE_READ_TIMEOUT,
25-
MCP_DEFAULT_TIMEOUT,
26-
McpHttpClientFactory,
27-
create_mcp_http_client,
28-
)
23+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
2924
from mcp.shared.message import ClientMessageMetadata, SessionMessage
3025
from mcp.types import (
3126
ErrorData,
@@ -70,52 +65,41 @@ class RequestContext:
7065
"""Context for a request operation."""
7166

7267
client: httpx.AsyncClient
73-
headers: dict[str, str]
7468
session_id: str | None
7569
session_message: SessionMessage
7670
metadata: ClientMessageMetadata | None
7771
read_stream_writer: StreamWriter
78-
sse_read_timeout: float
7972

8073

8174
class StreamableHTTPTransport:
8275
"""StreamableHTTP client transport implementation."""
8376

84-
def __init__(
85-
self,
86-
url: str,
87-
headers: dict[str, str] | None = None,
88-
timeout: float | timedelta = 30,
89-
sse_read_timeout: float | timedelta = 60 * 5,
90-
auth: httpx.Auth | None = None,
91-
) -> None:
77+
def __init__(self, url: str) -> None:
9278
"""Initialize the StreamableHTTP transport.
9379
9480
Args:
9581
url: The endpoint URL.
96-
headers: Optional headers to include in requests.
97-
timeout: HTTP timeout for regular operations.
98-
sse_read_timeout: Timeout for SSE read operations.
99-
auth: Optional HTTPX authentication handler.
10082
"""
10183
self.url = url
102-
self.headers = headers or {}
103-
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
104-
self.sse_read_timeout = (
105-
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
106-
)
107-
self.auth = auth
10884
self.session_id = None
10985
self.protocol_version = None
110-
self.request_headers = {
111-
**self.headers,
112-
ACCEPT: f"{JSON}, {SSE}",
113-
CONTENT_TYPE: JSON,
114-
}
115-
116-
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
117-
"""Update headers with session ID and protocol version if available."""
118-
headers = base_headers.copy()
86+
87+
def _prepare_headers(self, client: httpx.AsyncClient) -> dict[str, str]:
88+
"""Build request headers from client config and protocol requirements.
89+
90+
Merges the client's default headers with MCP protocol headers and session headers.
91+
92+
Args:
93+
client: The httpx client whose headers to use as base.
94+
95+
Returns:
96+
Complete headers dict with client headers, protocol headers, and session headers.
97+
"""
98+
headers = dict(client.headers) if client.headers else {}
99+
# Add MCP protocol headers
100+
headers[ACCEPT] = f"{JSON}, {SSE}"
101+
headers[CONTENT_TYPE] = JSON
102+
# Add session headers if available
119103
if self.session_id:
120104
headers[MCP_SESSION_ID] = self.session_id
121105
if self.protocol_version:
@@ -206,14 +190,13 @@ async def handle_get_stream(
206190
if not self.session_id:
207191
return
208192

209-
headers = self._prepare_request_headers(self.request_headers)
193+
headers = self._prepare_headers(client)
210194

211195
async with aconnect_sse(
212196
client,
213197
"GET",
214198
self.url,
215199
headers=headers,
216-
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
217200
) as event_source:
218201
event_source.response.raise_for_status()
219202
logger.debug("GET SSE connection established")
@@ -226,7 +209,7 @@ async def handle_get_stream(
226209

227210
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
228211
"""Handle a resumption request using GET with SSE."""
229-
headers = self._prepare_request_headers(ctx.headers)
212+
headers = self._prepare_headers(ctx.client)
230213
if ctx.metadata and ctx.metadata.resumption_token:
231214
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
232215
else:
@@ -242,7 +225,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
242225
"GET",
243226
self.url,
244227
headers=headers,
245-
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
246228
) as event_source:
247229
event_source.response.raise_for_status()
248230
logger.debug("Resumption GET SSE connection established")
@@ -260,7 +242,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
260242

261243
async def _handle_post_request(self, ctx: RequestContext) -> None:
262244
"""Handle a POST request with response processing."""
263-
headers = self._prepare_request_headers(ctx.headers)
245+
headers = self._prepare_headers(ctx.client)
264246
message = ctx.session_message.message
265247
is_initialization = self._is_initialization_request(message)
266248

@@ -401,12 +383,10 @@ async def post_writer(
401383

402384
ctx = RequestContext(
403385
client=client,
404-
headers=self.request_headers,
405386
session_id=self.session_id,
406387
session_message=session_message,
407388
metadata=metadata,
408389
read_stream_writer=read_stream_writer,
409-
sse_read_timeout=self.sse_read_timeout,
410390
)
411391

412392
async def handle_request_async():
@@ -433,7 +413,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
433413
return
434414

435415
try:
436-
headers = self._prepare_request_headers(self.request_headers)
416+
headers = self._prepare_headers(client)
437417
response = await client.delete(self.url, headers=headers)
438418

439419
if response.status_code == 405:
@@ -493,16 +473,8 @@ async def streamable_http_client(
493473
# Create default client with recommended MCP timeouts
494474
client = create_mcp_http_client()
495475

496-
# Extract configuration from the client to pass to transport
497-
headers_dict = dict(client.headers) if client.headers else None
498-
timeout = client.timeout.connect if (client.timeout and client.timeout.connect is not None) else MCP_DEFAULT_TIMEOUT
499-
sse_read_timeout = (
500-
client.timeout.read if (client.timeout and client.timeout.read is not None) else MCP_DEFAULT_SSE_READ_TIMEOUT
501-
)
502-
auth = client.auth
503-
504-
# Create transport with extracted configuration
505-
transport = StreamableHTTPTransport(url, headers_dict, timeout, sse_read_timeout, auth)
476+
# Create transport
477+
transport = StreamableHTTPTransport(url)
506478

507479
async with anyio.create_task_group() as tg:
508480
try:

src/mcp/shared/_httpx_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import httpx
66

7-
__all__ = ["create_mcp_http_client", "MCP_DEFAULT_TIMEOUT", "MCP_DEFAULT_SSE_READ_TIMEOUT"]
7+
__all__ = ["create_mcp_http_client"]
88

99
# Default MCP timeout configuration
1010
MCP_DEFAULT_TIMEOUT = 30.0 # General operations (seconds)

0 commit comments

Comments
 (0)