Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import logging
import os
import sys
from datetime import timedelta
from urllib.parse import ParseResult, parse_qs, urlparse

import httpx
Expand Down Expand Up @@ -263,8 +262,8 @@ async def _run_session(server_url: str, oauth_auth: OAuthClientProvider) -> None
async with streamablehttp_client(
url=server_url,
auth=oauth_auth,
timeout=timedelta(seconds=30),
sse_read_timeout=timedelta(seconds=60),
timeout=30.0,
sse_read_timeout=60.0,
) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import threading
import time
import webbrowser
from datetime import timedelta
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs, urlparse
Expand Down Expand Up @@ -215,7 +214,7 @@ async def _default_redirect_handler(authorization_url: str) -> None:
async with streamablehttp_client(
url=self.server_url,
auth=oauth_auth,
timeout=timedelta(seconds=60),
timeout=60.0,
) as (read_stream, write_stream, get_session_id):
await self._run_session(read_stream, write_stream, get_session_id)

Expand Down
5 changes: 2 additions & 3 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from datetime import timedelta
from typing import Any, Protocol, overload

import anyio.lowlevel
Expand Down Expand Up @@ -113,7 +112,7 @@ def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
Expand Down Expand Up @@ -369,7 +368,7 @@ async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
Expand Down
25 changes: 12 additions & 13 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from types import TracebackType
from typing import Any, TypeAlias, overload

Expand All @@ -39,11 +38,11 @@ class SseServerParameters(BaseModel):
# Optional headers to include in requests.
headers: dict[str, Any] | None = None

# HTTP timeout for regular operations.
timeout: float = 5
# HTTP timeout for regular operations (in seconds).
timeout: float = 5.0

# Timeout for SSE read operations.
sse_read_timeout: float = 60 * 5
# Timeout for SSE read operations (in seconds).
sse_read_timeout: float = 300.0


class StreamableHttpParameters(BaseModel):
Expand All @@ -55,11 +54,11 @@ class StreamableHttpParameters(BaseModel):
# Optional headers to include in requests.
headers: dict[str, Any] | None = None

# HTTP timeout for regular operations.
timeout: timedelta = timedelta(seconds=30)
# HTTP timeout for regular operations (in seconds).
timeout: float = 30.0

# Timeout for SSE read operations.
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
# Timeout for SSE read operations (in seconds).
sse_read_timeout: float = 300.0

# Close the client session when the transport closes.
terminate_on_close: bool = True
Expand All @@ -74,7 +73,7 @@ class StreamableHttpParameters(BaseModel):
class ClientSessionParameters:
"""Parameters for establishing a client session to an MCP server."""

read_timeout_seconds: timedelta | None = None
read_timeout_seconds: float | None = None
sampling_callback: SamplingFnT | None = None
elicitation_callback: ElicitationFnT | None = None
list_roots_callback: ListRootsFnT | None = None
Expand Down Expand Up @@ -195,7 +194,7 @@ async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
Expand All @@ -208,7 +207,7 @@ async def call_tool(
name: str,
*,
args: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult: ...
Expand All @@ -217,7 +216,7 @@ async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
Expand Down
8 changes: 4 additions & 4 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
timeout: float = 5.0,
sse_read_timeout: float = 300.0,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
on_session_created: Callable[[str], None] | None = None,
Expand All @@ -46,8 +46,8 @@ async def sse_client(
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
timeout: HTTP timeout for regular operations (in seconds).
sse_read_timeout: Timeout for SSE read operations (in seconds).
auth: Optional HTTPX authentication handler.
on_session_created: Optional callback invoked with the session ID when received.
"""
Expand Down
19 changes: 8 additions & 11 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta

import anyio
import httpx
Expand Down Expand Up @@ -82,25 +81,23 @@ def __init__(
self,
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: float = 30.0,
sse_read_timeout: float = 300.0,
auth: httpx.Auth | None = None,
) -> None:
"""Initialize the StreamableHTTP transport.

Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
timeout: HTTP timeout for regular operations (in seconds).
sse_read_timeout: Timeout for SSE read operations (in seconds).
auth: Optional HTTPX authentication handler.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.auth = auth
self.session_id = None
self.protocol_version = None
Expand Down Expand Up @@ -563,8 +560,8 @@ def get_session_id(self) -> str | None:
async def streamablehttp_client(
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: float = 30.0,
sse_read_timeout: float = 300.0,
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
Expand Down
3 changes: 1 addition & 2 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Any

import anyio
Expand Down Expand Up @@ -49,7 +48,7 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS
@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server[Any] | FastMCP,
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
Expand Down
9 changes: 4 additions & 5 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar

Expand Down Expand Up @@ -189,7 +188,7 @@ def __init__(
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
Expand Down Expand Up @@ -241,7 +240,7 @@ async def send_request(
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
request_read_timeout_seconds: float | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
) -> ReceiveResultT:
Expand Down Expand Up @@ -283,9 +282,9 @@ async def send_request(
# request read timeout takes precedence over session read timeout
timeout = None
if request_read_timeout_seconds is not None: # pragma: no cover
timeout = request_read_timeout_seconds.total_seconds()
timeout = request_read_timeout_seconds
elif self._session_read_timeout_seconds is not None: # pragma: no cover
timeout = self._session_read_timeout_seconds.total_seconds()
timeout = self._session_read_timeout_seconds

try:
with anyio.fail_after(timeout):
Expand Down
5 changes: 2 additions & 3 deletions tests/issues/test_88_random_error.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test to reproduce issue #88: Random error thrown on response."""

from collections.abc import Sequence
from datetime import timedelta
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -93,10 +92,10 @@ async def client(
assert not slow_request_lock.is_set()

# Second call should timeout (slow operation with minimal timeout)
# Use 10ms timeout to trigger quickly without waiting
# Use very small timeout to trigger quickly without waiting
with pytest.raises(McpError) as exc_info:
await session.call_tool(
"slow", read_timeout_seconds=timedelta(microseconds=1)
"slow", read_timeout_seconds=0.000001
) # artificial timeout that always fails
assert "Timed out while waiting" in str(exc_info.value)

Expand Down
4 changes: 1 addition & 3 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,10 @@ async def mock_server():
async def make_request(client_session: ClientSession):
try:
# Use a short timeout since we expect this to fail
from datetime import timedelta

await client_session.send_request(
ClientRequest(types.PingRequest()),
types.EmptyResult,
request_read_timeout_seconds=timedelta(seconds=0.5),
request_read_timeout_seconds=0.5,
)
pytest.fail("Expected timeout") # pragma: no cover
except McpError as e:
Expand Down
Loading