diff --git a/Directory.Packages.props b/Directory.Packages.props index 8ac0a52c..e26d68cd 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -54,6 +54,7 @@ + diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index c9a5ba87..8bff4596 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -22,6 +22,7 @@ public static class HttpMcpServerBuilderExtensions public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder, Action? configureOptions = null) { ArgumentNullException.ThrowIfNull(builder); + builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index fed2f131..1b854b94 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -47,13 +47,18 @@ public async ValueTask DisposeAsync() } finally { - if (Server is not null) + try { - await Server.DisposeAsync(); + if (Server is not null) + { + await Server.DisposeAsync(); + } + } + finally + { + await Transport.DisposeAsync(); + _disposeCts.Dispose(); } - - await Transport.DisposeAsync(); - _disposeCts.Dispose(); } } diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 23eeddbe..4880714c 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -26,9 +26,17 @@ public class HttpServerTransportOptions /// Represents the duration of time the server will wait between any active requests before timing out an /// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will /// receive a 404 status code and should restart their session. A client can keep their session open by - /// keeping a GET request open. The default value is set to 2 minutes. + /// keeping a GET request open. The default value is set to 2 hours. /// - public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromMinutes(2); + public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromHours(2); + + /// + /// The maximum number of idle sessions to track. This is used to limit the number of sessions that can be idle at once. + /// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached + /// their until the idle session count is below this limit. Clients that keep their session open by + /// keeping a GET request open will not count towards this limit. The default value is set to 100,000 sessions. + /// + public int MaxIdleSessionCount { get; set; } = 100_000; /// /// Used for testing the . diff --git a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs index df3203b5..d7c57735 100644 --- a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs +++ b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs @@ -8,28 +8,40 @@ namespace ModelContextProtocol.AspNetCore; internal sealed partial class IdleTrackingBackgroundService( StreamableHttpHandler handler, IOptions options, + IHostApplicationLifetime appLifetime, ILogger logger) : BackgroundService { // The compiler will complain about the parameter being unused otherwise despite the source generator. private ILogger _logger = logger; - // We can make this configurable once we properly harden the MCP server. In the meantime, anyone running - // this should be taking a cattle not pets approach to their servers and be able to launch more processes - // to handle more than 10,000 idle sessions at a time. - private const int MaxIdleSessionCount = 10_000; - protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - var timeProvider = options.Value.TimeProvider; - using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider); + // Still run loop given infinite IdleTimeout to enforce the MaxIdleSessionCount and assist graceful shutdown. + if (options.Value.IdleTimeout != Timeout.InfiniteTimeSpan) + { + ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.IdleTimeout, TimeSpan.Zero); + } + ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.MaxIdleSessionCount, 0); try { + var timeProvider = options.Value.TimeProvider; + using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider); + + var idleTimeoutTicks = options.Value.IdleTimeout.Ticks; + var maxIdleSessionCount = options.Value.MaxIdleSessionCount; + + // The default ValueTuple Comparer will check the first item then the second which preserves both order and uniqueness. + var idleSessions = new SortedSet<(long Timestamp, string SessionId)>(); + while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken)) { - var idleActivityCutoff = timeProvider.GetTimestamp() - options.Value.IdleTimeout.Ticks; + var idleActivityCutoff = idleTimeoutTicks switch + { + < 0 => long.MinValue, + var ticks => timeProvider.GetTimestamp() - ticks, + }; - var idleCount = 0; foreach (var (_, session) in handler.Sessions) { if (session.IsActive || session.SessionClosed.IsCancellationRequested) @@ -38,26 +50,32 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) continue; } - idleCount++; - if (idleCount == MaxIdleSessionCount) - { - // Emit critical log at most once every 5 seconds the idle count it exceeded, - //since the IdleTimeout will no longer be respected. - LogMaxSessionIdleCountExceeded(); - } - else if (idleCount < MaxIdleSessionCount && session.LastActivityTicks > idleActivityCutoff) + if (session.LastActivityTicks < idleActivityCutoff) { + RemoveAndCloseSession(session.Id); continue; } - if (handler.Sessions.TryRemove(session.Id, out var removedSession)) + idleSessions.Add((session.LastActivityTicks, session.Id)); + + // Emit critical log at most once every 5 seconds the idle count it exceeded, + // since the IdleTimeout will no longer be respected. + if (idleSessions.Count == maxIdleSessionCount + 1) { - LogSessionIdle(removedSession.Id); + LogMaxSessionIdleCountExceeded(maxIdleSessionCount); + } + } - // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown. - _ = DisposeSessionAsync(removedSession); + if (idleSessions.Count > maxIdleSessionCount) + { + var sessionsToPrune = idleSessions.ToArray()[..^maxIdleSessionCount]; + foreach (var (_, id) in sessionsToPrune) + { + RemoveAndCloseSession(id); } } + + idleSessions.Clear(); } } catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) @@ -65,7 +83,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) } finally { - if (stoppingToken.IsCancellationRequested) + try { List disposeSessionTasks = []; @@ -79,7 +97,29 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) await Task.WhenAll(disposeSessionTasks); } + finally + { + if (!stoppingToken.IsCancellationRequested) + { + // Something went terribly wrong. A very unexpected exception must be bubbling up, but let's ensure we also stop the application, + // so that it hopefully gets looked at and restarted. This shouldn't really be reachable. + appLifetime.StopApplication(); + IdleTrackingBackgroundServiceStoppedUnexpectedly(); + } + } + } + } + + private void RemoveAndCloseSession(string sessionId) + { + if (!handler.Sessions.TryRemove(sessionId, out var session)) + { + return; } + + LogSessionIdle(session.Id); + // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown. + _ = DisposeSessionAsync(session); } private async Task DisposeSessionAsync(HttpMcpSession session) @@ -97,9 +137,12 @@ private async Task DisposeSessionAsync(HttpMcpSession s_errorTypeInfo = GetRequiredJsonTypeInfo(); + private static MediaTypeHeaderValue ApplicationJsonMediaType = new("application/json"); + private static MediaTypeHeaderValue TextEventStreamMediaType = new("text/event-stream"); public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); public async Task HandlePostRequestAsync(HttpContext context) { // The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream. - // ASP.NET Core Minimal APIs mostly ry to stay out of the business of response content negotiation, so - // we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, but it's - // probably good to at least start out trying to be strict. - var acceptHeader = context.Request.Headers.Accept.ToString(); - if (!acceptHeader.Contains("application/json", StringComparison.Ordinal) || - !acceptHeader.Contains("text/event-stream", StringComparison.Ordinal)) + // ASP.NET Core Minimal APIs mostly try to stay out of the business of response content negotiation, + // so we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, + // but it's probably good to at least start out trying to be strict. + var acceptHeaders = context.Request.GetTypedHeaders().Accept; + if (!acceptHeaders.Contains(ApplicationJsonMediaType) || !acceptHeaders.Contains(TextEventStreamMediaType)) { await WriteJsonRpcErrorAsync(context, "Not Acceptable: Client must accept both application/json and text/event-stream", @@ -49,9 +51,8 @@ await WriteJsonRpcErrorAsync(context, } using var _ = session.AcquireReference(); - using var cts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, session.SessionClosed); InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), cts.Token); + var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); if (!wroteResponse) { // We wound up writing nothing, so there should be no Content-Type response header. @@ -62,8 +63,8 @@ await WriteJsonRpcErrorAsync(context, public async Task HandleGetRequestAsync(HttpContext context) { - var acceptHeader = context.Request.Headers.Accept.ToString(); - if (!acceptHeader.Contains("application/json", StringComparison.Ordinal)) + var acceptHeaders = context.Request.GetTypedHeaders().Accept; + if (!acceptHeaders.Contains(TextEventStreamMediaType)) { await WriteJsonRpcErrorAsync(context, "Not Acceptable: Client must accept text/event-stream", @@ -105,12 +106,6 @@ public async Task HandleDeleteRequestAsync(HttpContext context) } } - private void InitializeSessionResponse(HttpContext context, HttpMcpSession session) - { - context.Response.Headers["mcp-session-id"] = session.Id; - context.Features.Set(session.Server); - } - private async ValueTask?> GetSessionAsync(HttpContext context, string sessionId) { if (Sessions.TryGetValue(sessionId, out var existingSession)) @@ -123,7 +118,8 @@ await WriteJsonRpcErrorAsync(context, return null; } - InitializeSessionResponse(context, existingSession); + context.Response.Headers["mcp-session-id"] = existingSession.Id; + context.Features.Set(existingSession.Server); return existingSession; } @@ -138,11 +134,10 @@ await WriteJsonRpcErrorAsync(context, private async ValueTask?> GetOrCreateSessionAsync(HttpContext context) { var sessionId = context.Request.Headers["mcp-session-id"].ToString(); - HttpMcpSession? session; if (string.IsNullOrEmpty(sessionId)) { - session = await CreateSessionAsync(context); + var session = await CreateSessionAsync(context); if (!Sessions.TryAdd(session.Id, session)) { @@ -159,6 +154,9 @@ await WriteJsonRpcErrorAsync(context, private async ValueTask> CreateSessionAsync(HttpContext context) { + var sessionId = MakeNewSessionId(); + context.Response.Headers["mcp-session-id"] = sessionId; + var mcpServerOptions = mcpServerOptionsSnapshot.Value; if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions) { @@ -169,8 +167,9 @@ private async ValueTask> CreateSes var transport = new StreamableHttpServerTransport(); // Use application instead of request services, because the session will likely outlive the first initialization request. var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices); + context.Features.Set(server); - var session = new HttpMcpSession(MakeNewSessionId(), transport, context.User, httpMcpServerOptions.Value.TimeProvider) + var session = new HttpMcpSession(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider) { Server = server, }; @@ -178,7 +177,6 @@ private async ValueTask> CreateSes var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync; session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed); - InitializeSessionResponse(context, session); return session; } diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs index dcfdf687..f6431367 100644 --- a/src/ModelContextProtocol/IMcpEndpoint.cs +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -34,7 +34,8 @@ public interface IMcpEndpoint : IAsyncDisposable /// The JSON-RPC request to send. /// The to monitor for cancellation requests. The default is . /// A task containing the endpoint's response. - /// The transport is not connected, or another error occurs during request processing. + /// The transport is not connected, or another error occurs during request processing. + /// An error occured during request processing. /// /// This method provides low-level access to send raw JSON-RPC requests. For most use cases, /// consider using the strongly-typed extension methods that provide a more convenient API. @@ -50,7 +51,7 @@ public interface IMcpEndpoint : IAsyncDisposable /// /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous send operation. - /// The transport is not connected. + /// The transport is not connected. /// is . /// /// diff --git a/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs b/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs index a3eb0ce4..c2cce9f1 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs @@ -51,7 +51,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can { Throw.IfNull(message); - using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false); if (_disposed) { diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs index 42e3ff70..aa9e522d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs @@ -51,8 +51,8 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); } - using var getCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await _sseWriter.WriteAllAsync(sseResponseStream, getCts.Token).ConfigureAwait(false); + // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully. + await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); } /// diff --git a/tests/ModelContextProtocol.Tests/Utils/DelegatingTestOutputHelper.cs b/tests/Common/Utils/DelegatingTestOutputHelper.cs similarity index 100% rename from tests/ModelContextProtocol.Tests/Utils/DelegatingTestOutputHelper.cs rename to tests/Common/Utils/DelegatingTestOutputHelper.cs diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/LoggedTest.cs b/tests/Common/Utils/LoggedTest.cs similarity index 100% rename from tests/ModelContextProtocol.AspNetCore.Tests/Utils/LoggedTest.cs rename to tests/Common/Utils/LoggedTest.cs diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/MockHttpHandler.cs b/tests/Common/Utils/MockHttpHandler.cs similarity index 100% rename from tests/ModelContextProtocol.AspNetCore.Tests/Utils/MockHttpHandler.cs rename to tests/Common/Utils/MockHttpHandler.cs diff --git a/tests/Common/Utils/MockLoggerProvider.cs b/tests/Common/Utils/MockLoggerProvider.cs new file mode 100644 index 00000000..f5264edc --- /dev/null +++ b/tests/Common/Utils/MockLoggerProvider.cs @@ -0,0 +1,32 @@ +using Microsoft.Extensions.Logging; +using System.Collections.Concurrent; + +namespace ModelContextProtocol.Tests.Utils; + +public class MockLoggerProvider() : ILoggerProvider +{ + public ConcurrentQueue<(string Category, LogLevel LogLevel, string Message, Exception? Exception)> LogMessages { get; } = []; + + public ILogger CreateLogger(string categoryName) + { + return new MockLogger(this, categoryName); + } + + public void Dispose() + { + } + + private class MockLogger(MockLoggerProvider mockProvider, string category) : ILogger + { + public void Log( + LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + mockProvider.LogMessages.Enqueue((category, logLevel, formatter(state, exception), exception)); + } + + public bool IsEnabled(LogLevel logLevel) => true; + + // The MockLoggerProvider is a convenient NoopDisposable + public IDisposable BeginScope(TState state) where TState : notnull => mockProvider; + } +} diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/Common/Utils/TestServerTransport.cs similarity index 100% rename from tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs rename to tests/Common/Utils/TestServerTransport.cs diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/XunitLoggerProvider.cs b/tests/Common/Utils/XunitLoggerProvider.cs similarity index 100% rename from tests/ModelContextProtocol.AspNetCore.Tests/Utils/XunitLoggerProvider.cs rename to tests/Common/Utils/XunitLoggerProvider.cs diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj b/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj index 9d7246bf..fd30b71e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj @@ -22,6 +22,10 @@ true + + + + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -35,6 +39,7 @@ + diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs index cf3aa4f4..fa3f8fe0 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs @@ -1,9 +1,13 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Time.Testing; +using Microsoft.Net.Http.Headers; using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; using System.Net; using System.Net.ServerSentEvents; @@ -27,8 +31,6 @@ public class StreamableHttpTests(ITestOutputHelper outputHelper) : KestrelInMemo private async Task StartAsync() { - AddDefaultHttpClientRequestHeaders(); - Builder.Services.AddMcpServer(options => { options.ServerInfo = new Implementation @@ -43,6 +45,9 @@ private async Task StartAsync() _app.MapMcp(); await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); } public async ValueTask DisposeAsync() @@ -54,6 +59,31 @@ public async ValueTask DisposeAsync() base.Dispose(); } + [Fact] + public async Task NegativeNonInfiniteIdleTimeout_Throws_ArgumentOutOfRangeException() + { + Builder.Services.AddMcpServer().WithHttpTransport(options => + { + options.IdleTimeout = TimeSpan.MinValue; + }); + + var ex = await Assert.ThrowsAnyAsync(StartAsync); + Assert.Contains("IdleTimeout", ex.Message); + } + + + [Fact] + public async Task NegativeMaxIdleSessionCount_Throws_ArgumentOutOfRangeException() + { + Builder.Services.AddMcpServer().WithHttpTransport(options => + { + options.MaxIdleSessionCount = -1; + }); + + var ex = await Assert.ThrowsAnyAsync(StartAsync); + Assert.Contains("MaxIdleSessionCount", ex.Message); + } + [Fact] public async Task InitialPostResponse_Includes_McpSessionIdHeader() { @@ -74,26 +104,16 @@ public async Task PostRequest_IsUnsupportedMediaType_WithoutJsonContentType() Assert.Equal(HttpStatusCode.UnsupportedMediaType, response.StatusCode); } - [Fact] - public async Task PostRequest_IsNotAcceptable_WithoutApplicationJsonAcceptHeader() + [Theory] + [InlineData("text/event-stream")] + [InlineData("application/json")] + [InlineData("application/json-text/event-stream")] + public async Task PostRequest_IsNotAcceptable_WithSingleAcceptHeader(string singleAcceptValue) { await StartAsync(); HttpClient.DefaultRequestHeaders.Accept.Clear(); - HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); - - using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); - Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); - } - - - [Fact] - public async Task PostRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader() - { - await StartAsync(); - - HttpClient.DefaultRequestHeaders.Accept.Clear(); - HttpClient.DefaultRequestHeaders.Accept.Add(new("text/json")); + HttpClient.DefaultRequestHeaders.TryAddWithoutValidation(HeaderNames.Accept, singleAcceptValue); using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); @@ -105,7 +125,7 @@ public async Task GetRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader( await StartAsync(); HttpClient.DefaultRequestHeaders.Accept.Clear(); - HttpClient.DefaultRequestHeaders.Accept.Add(new("text/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); using var response = await HttpClient.GetAsync("", TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); @@ -131,7 +151,6 @@ public async Task PostRequest_IsNotFound_WithUnrecognizedSessionId() [Fact] public async Task InitializeRequest_Matches_CustomRoute() { - AddDefaultHttpClientRequestHeaders(); Builder.Services.AddMcpServer().WithHttpTransport(); await using var app = Builder.Build(); @@ -139,6 +158,8 @@ public async Task InitializeRequest_Matches_CustomRoute() await app.StartAsync(TestContext.Current.CancellationToken); + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); using var response = await HttpClient.PostAsync("/custom-route", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -312,10 +333,21 @@ Task CallLongRunningToolAsync() => for (int i = 0; i < longRunningToolTasks.Length; i++) { longRunningToolTasks[i] = CallLongRunningToolAsync(); + } + + var getResponse = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + + for (int i = 0; i < longRunningToolTasks.Length; i++) + { Assert.False(longRunningToolTasks[i].IsCompleted); } + await HttpClient.DeleteAsync("", TestContext.Current.CancellationToken); + // Get request should complete gracefully. + var sseResponseBody = await getResponse.Content.ReadAsStringAsync(TestContext.Current.CancellationToken); + Assert.Empty(sseResponseBody); + // Currently, the OCE thrown by the canceled session is unhandled and turned into a 500 error by Kestrel. // The spec suggests sending CancelledNotifications. That would be good, but we can do that later. // For now, the important thing is that request completes without indicating success. @@ -361,10 +393,98 @@ public async Task Progress_IsReported_InSameSseResponseAsRpcResponse() Assert.Equal(11, currentSseItem); } - private void AddDefaultHttpClientRequestHeaders() + [Fact] + public async Task IdleSessions_ArePruned_AfterIdleTimeout() { - HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); - HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + var fakeTimeProvider = new FakeTimeProvider(); + Builder.Services.AddMcpServer().WithHttpTransport(options => + { + Assert.Equal(TimeSpan.FromHours(2), options.IdleTimeout); + options.TimeProvider = fakeTimeProvider; + }); + + await StartAsync(); + await CallInitializeAndValidateAsync(); + await CallEchoAndValidateAsync(); + + // Add 5 seconds to idle timeout to account for the interval of the PeriodicTimer. + fakeTimeProvider.Advance(TimeSpan.FromHours(2) + TimeSpan.FromSeconds(5)); + + using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task IdleSessions_AreNotPruned_WithInfiniteIdleTimeoutWhileUnderMaxIdleSessionCount() + { + var fakeTimeProvider = new FakeTimeProvider(); + Builder.Services.AddMcpServer().WithHttpTransport(options => + { + options.IdleTimeout = Timeout.InfiniteTimeSpan; + options.TimeProvider = fakeTimeProvider; + }); + + await StartAsync(); + await CallInitializeAndValidateAsync(); + await CallEchoAndValidateAsync(); + + fakeTimeProvider.Advance(TimeSpan.FromDays(1)); + + // Echo still works because the session has not been pruned. + await CallEchoAndValidateAsync(); + } + + [Fact] + public async Task IdleSessionsPastMaxIdleSessionCount_ArePruned_LongestIdleFirstDespiteIdleTimeout() + { + var fakeTimeProvider = new FakeTimeProvider(); + Builder.Services.AddMcpServer().WithHttpTransport(options => + { + options.IdleTimeout = Timeout.InfiniteTimeSpan; + options.MaxIdleSessionCount = 2; + options.TimeProvider = fakeTimeProvider; + }); + + var mockLoggerProvider = new MockLoggerProvider(); + Builder.Logging.AddProvider(mockLoggerProvider); + + await StartAsync(); + + // Start first session. + var firstSessionId = await CallInitializeAndValidateAsync(); + + // Start a second session to trigger pruning of the original session. + fakeTimeProvider.Advance(TimeSpan.FromTicks(1)); + var secondSessionId = await CallInitializeAndValidateAsync(); + + Assert.NotEqual(firstSessionId, secondSessionId); + + // First session ID still works, since we allow up to 2 idle sessions. + fakeTimeProvider.Advance(TimeSpan.FromTicks(1)); + SetSessionId(firstSessionId); + await CallEchoAndValidateAsync(); + + // Start a third session to trigger pruning of the first session. + fakeTimeProvider.Advance(TimeSpan.FromTicks(1)); + var thirdSessionId = await CallInitializeAndValidateAsync(); + + Assert.NotEqual(secondSessionId, thirdSessionId); + + // Pruning of the second session results in a 404 since we used the first session more recently. + fakeTimeProvider.Advance(TimeSpan.FromSeconds(10)); + SetSessionId(secondSessionId); + using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + + // But the first and third session IDs should still work. + SetSessionId(firstSessionId); + await CallEchoAndValidateAsync(); + + SetSessionId(thirdSessionId); + await CallEchoAndValidateAsync(); + + var logMessage = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Critical); + Assert.StartsWith("Exceeded maximum of 2 idle sessions.", logMessage.Message); } private static StringContent JsonContent(string json) => new StringContent(json, Encoding.UTF8, "application/json"); @@ -437,7 +557,7 @@ private string CallTool(string toolName, string arguments = "{}") => private string CallToolWithProgressToken(string toolName, string arguments = "{}") => Request("tools/call", $$$""" - {"name":"{{{toolName}}}","arguments":{{{arguments}}}, "_meta":{"progressToken": "abc123"}} + {"name":"{{{toolName}}}","arguments":{{{arguments}}},"_meta":{"progressToken":"abc123"}} """); private static InitializeResult AssertServerInfo(JsonRpcResponse rpcResponse) @@ -457,13 +577,21 @@ private static CallToolResponse AssertEchoResponse(JsonRpcResponse rpcResponse) return callToolResponse; } - private async Task CallInitializeAndValidateAsync() + private async Task CallInitializeAndValidateAsync() { + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); var rpcResponse = await AssertSingleSseResponseAsync(response); AssertServerInfo(rpcResponse); var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + SetSessionId(sessionId); + return sessionId; + } + + private void SetSessionId(string sessionId) + { + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/DelegatingTestOutputHelper.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/DelegatingTestOutputHelper.cs deleted file mode 100644 index ef452fcb..00000000 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/DelegatingTestOutputHelper.cs +++ /dev/null @@ -1,13 +0,0 @@ -namespace ModelContextProtocol.Tests.Utils; - -public class DelegatingTestOutputHelper : ITestOutputHelper -{ - public ITestOutputHelper? CurrentTestOutputHelper { get; set; } - - public string Output => CurrentTestOutputHelper?.Output ?? string.Empty; - - public void Write(string message) => CurrentTestOutputHelper?.Write(message); - public void Write(string format, params object[] args) => CurrentTestOutputHelper?.Write(format, args); - public void WriteLine(string message) => CurrentTestOutputHelper?.WriteLine(message); - public void WriteLine(string format, params object[] args) => CurrentTestOutputHelper?.WriteLine(format, args); -} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs deleted file mode 100644 index a221b8a3..00000000 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs +++ /dev/null @@ -1,83 +0,0 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; -using System.Text.Json; -using System.Threading.Channels; - -namespace ModelContextProtocol.Tests.Utils; - -public class TestServerTransport : ITransport -{ - private readonly Channel _messageChannel; - - public bool IsConnected { get; set; } - - public ChannelReader MessageReader => _messageChannel; - - public List SentMessages { get; } = []; - - public Action? OnMessageSent { get; set; } - - public TestServerTransport() - { - _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions - { - SingleReader = true, - SingleWriter = true, - }); - IsConnected = true; - } - - public ValueTask DisposeAsync() - { - _messageChannel.Writer.TryComplete(); - IsConnected = false; - return ValueTask.CompletedTask; - } - - public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - { - SentMessages.Add(message); - if (message is JsonRpcRequest request) - { - if (request.Method == RequestMethods.RootsList) - await ListRoots(request, cancellationToken); - else if (request.Method == RequestMethods.SamplingCreateMessage) - await Sampling(request, cancellationToken); - else - await WriteMessageAsync(request, cancellationToken); - } - else if (message is JsonRpcNotification notification) - { - await WriteMessageAsync(notification, cancellationToken); - } - - OnMessageSent?.Invoke(message); - } - - private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken) - { - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(new ListRootsResult - { - Roots = [] - }), - }, cancellationToken); - } - - private async Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken) - { - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new(), Model = "model", Role = Role.Assistant }), - }, cancellationToken); - } - - private async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - { - await _messageChannel.Writer.WriteAsync(message, cancellationToken); - } -} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index 4b3fe3a9..c5e7c706 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -45,7 +45,10 @@ public async Task Cancellation_ThrowsCancellationException(bool preCanceled) Task t = McpClientFactory.CreateAsync( new StreamClientTransport(new Pipe().Writer.AsStream(), new Pipe().Reader.AsStream()), cancellationToken: cts.Token); - Assert.False(t.IsCompleted); + if (!preCanceled) + { + Assert.False(t.IsCompleted); + } if (!preCanceled) { diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index 75797bb8..c3c45867 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -62,7 +62,7 @@ public async ValueTask DisposeAsync() Dispose(); } - protected async Task CreateMcpClientForServer(McpClientOptions? options = null) + protected async Task CreateMcpClientForServer() { return await McpClientFactory.CreateAsync( new StreamClientTransport( diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 4bcb83fd..b99e2020 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -24,6 +24,10 @@ true + + + + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -55,13 +59,6 @@ - - - - - - - PreserveNewest diff --git a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs deleted file mode 100644 index a2e9e2ba..00000000 --- a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs +++ /dev/null @@ -1,30 +0,0 @@ -using Microsoft.Extensions.Logging; - -namespace ModelContextProtocol.Tests.Utils; - -public class LoggedTest : IDisposable -{ - private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper; - - public LoggedTest(ITestOutputHelper testOutputHelper) - { - _delegatingTestOutputHelper = new() - { - CurrentTestOutputHelper = testOutputHelper, - }; - LoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper); - LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - { - builder.AddProvider(LoggerProvider); - }); - } - - public ITestOutputHelper TestOutputHelper => _delegatingTestOutputHelper; - public ILoggerFactory LoggerFactory { get; } - public ILoggerProvider LoggerProvider { get; } - - public virtual void Dispose() - { - _delegatingTestOutputHelper.CurrentTestOutputHelper = null; - } -} diff --git a/tests/ModelContextProtocol.Tests/Utils/MockHttpHandler.cs b/tests/ModelContextProtocol.Tests/Utils/MockHttpHandler.cs deleted file mode 100644 index 5e58a6cd..00000000 --- a/tests/ModelContextProtocol.Tests/Utils/MockHttpHandler.cs +++ /dev/null @@ -1,20 +0,0 @@ -namespace ModelContextProtocol.Tests.Utils; - -public class MockHttpHandler : HttpMessageHandler -{ - public Func>? RequestHandler { get; set; } - - protected async override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) - { - if (RequestHandler == null) - throw new InvalidOperationException($"No {nameof(RequestHandler)} was set! Please set handler first and make request afterwards."); - - cancellationToken.ThrowIfCancellationRequested(); - - var result = await RequestHandler.Invoke(request); - - cancellationToken.ThrowIfCancellationRequested(); - - return result; - } -} diff --git a/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs b/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs deleted file mode 100644 index f66a828a..00000000 --- a/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System.Globalization; -using System.Text; -using Microsoft.Extensions.Logging; - -namespace ModelContextProtocol.Tests.Utils; - -public class XunitLoggerProvider(ITestOutputHelper output) : ILoggerProvider -{ - public ILogger CreateLogger(string categoryName) - { - return new XunitLogger(output, categoryName); - } - - public void Dispose() - { - } - - private class XunitLogger(ITestOutputHelper output, string category) : ILogger - { - public void Log( - LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) - { - var sb = new StringBuilder(); - - var timestamp = DateTimeOffset.UtcNow.ToString("s", CultureInfo.InvariantCulture); - var prefix = $"| [{timestamp}] {category} {logLevel}: "; - var lines = formatter(state, exception); - sb.Append(prefix); - sb.Append(lines); - - if (exception is not null) - { - sb.AppendLine(); - sb.Append(exception.ToString()); - } - - output.WriteLine(sb.ToString()); - } - - public bool IsEnabled(LogLevel logLevel) => true; - - public IDisposable BeginScope(TState state) where TState : notnull - => new NoopDisposable(); - - private sealed class NoopDisposable : IDisposable - { - public void Dispose() - { - } - } - } -}