diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 577ed1b9a..60abdb8d9 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -87,6 +87,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): debug: bool = False log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + # STDIO settings + stateless_stdio: bool = False + # HTTP settings host: str = "127.0.0.1" port: int = 8000 @@ -597,6 +600,7 @@ async def run_stdio_async(self) -> None: read_stream, write_stream, self._mcp_server.create_initialization_options(), + stateless=self.settings.stateless_stdio, ) async def run_sse_async(self, mount_path: str | None = None) -> None: diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 1375df12f..422c50de1 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -200,3 +200,106 @@ async def mock_client(): assert received_initialized assert received_protocol_version == "2024-11-05" + + +@pytest.mark.anyio +async def test_server_session_requires_initialization(): + """Test that ServerSession requires initialization before accepting requests.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](1) + + try: + init_options = InitializationOptions( + server_name="TestServer", + server_version="1.0", + capabilities=ServerCapabilities(), + ) + + async with ServerSession( + client_to_server_receive, + server_to_client_send, + init_options, + stateless=False, + ) as server_session: + request = types.ClientRequest( + root=types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="test_tool", arguments={}), + ) + ) + + responder = RequestResponder( + request_id="test-id", + request_meta=None, # Using None instead of {} to fix type error + request=request, + session=server_session, + on_complete=lambda _: None, + ) + + with pytest.raises(RuntimeError) as excinfo: + await server_session._received_request(responder) + + assert "initialization" in str(excinfo.value).lower() + assert "before initialization was complete" in str(excinfo.value) + finally: + # Clean up the streams to prevent ResourceWarning + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_server_session_stateless_mode(): + """Test that ServerSession in stateless mode doesn't require initialization.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](1) + + try: + init_options = InitializationOptions( + server_name="TestServer", + server_version="1.0", + capabilities=ServerCapabilities(), + ) + + async with ServerSession( + client_to_server_receive, + server_to_client_send, + init_options, + stateless=True, + ) as server_session: + request = types.ClientRequest( + root=types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="test_tool", arguments={}), + ) + ) + + responder = RequestResponder( + request_id="test-id", + request_meta=None, # Using None instead of {} to fix type error + request=request, + session=server_session, + on_complete=lambda _: None, + ) + + try: + await server_session._received_request(responder) + except RuntimeError as e: + if "initialization" in str(e).lower(): + msg = f"Unexpected initialization error in stateless mode: {e}" + pytest.fail(msg) + finally: + # Clean up the streams to prevent ResourceWarning + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index c546a7167..4f03057a5 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,8 +1,12 @@ import io +import tempfile +from pathlib import Path import anyio import pytest +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -68,3 +72,37 @@ async def test_stdio_server(): assert received_responses[1] == JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) ) + + +@pytest.mark.anyio +async def test_stateless_stdio(): + """Test that stateless stdio mode allows tool calls without initialization.""" + with tempfile.TemporaryDirectory() as temp_dir: + server_path = Path(temp_dir) / "server.py" + + with open(server_path, "w") as f: + f.write(""" +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("StatelessServer") +mcp.settings.stateless_stdio = True + +@mcp.tool() +def echo(message: str) -> str: + return f"Echo: {message}" + +if __name__ == "__main__": + mcp.run() +""") + + server_params = StdioServerParameters( + command="python", + args=[str(server_path)], + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + result = await session.call_tool("echo", {"message": "hello"}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert getattr(result.content[0], "text") == "Echo: hello"