Skip to content

Commit 03e3094

Browse files
Fix and enhance cancellation operations across MCP Sessions. (#179)
* Propagate CancellationToken request cancellation to remote endpoint * Refactor existing tests to have a shared base class * Add cancellation tests --------- Co-authored-by: Stephen Toub <[email protected]>
1 parent 4dd2f42 commit 03e3094

21 files changed

+334
-336
lines changed

samples/EverythingServer/LoggingUpdateMessageSender.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using Microsoft.Extensions.DependencyInjection;
2-
using Microsoft.Extensions.Hosting;
1+
using Microsoft.Extensions.Hosting;
32
using ModelContextProtocol;
43
using ModelContextProtocol.Protocol.Types;
54
using ModelContextProtocol.Server;

samples/EverythingServer/ResourceGenerator.cs

-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
using ModelContextProtocol.Protocol.Types;
2-
using System;
3-
using System.Collections.Generic;
4-
using System.Linq;
52

63
namespace EverythingServer;
74

samples/EverythingServer/Tools/LongRunningTool.cs

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using ModelContextProtocol;
2-
using ModelContextProtocol.Protocol.Messages;
32
using ModelContextProtocol.Protocol.Types;
43
using ModelContextProtocol.Server;
54
using System.ComponentModel;

src/ModelContextProtocol/Client/McpClientFactory.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using System.Globalization;
2-
using System.Runtime.InteropServices;
3-
using ModelContextProtocol.Logging;
1+
using ModelContextProtocol.Logging;
42
using ModelContextProtocol.Protocol.Transport;
53
using ModelContextProtocol.Utils;
64
using Microsoft.Extensions.Logging;

src/ModelContextProtocol/McpEndpointExtensions.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ public static Task NotifyProgressAsync(
155155
{
156156
Throw.IfNull(endpoint);
157157

158-
return endpoint.SendMessageAsync(new JsonRpcNotification()
159-
{
160-
Method = NotificationMethods.ProgressNotification,
161-
Params = JsonSerializer.SerializeToNode(new ProgressNotification
158+
return endpoint.SendNotificationAsync(
159+
NotificationMethods.ProgressNotification,
160+
new ProgressNotification
162161
{
163162
ProgressToken = progressToken,
164163
Progress = progress,
165-
}, McpJsonUtilities.JsonContext.Default.ProgressNotification),
166-
}, cancellationToken);
164+
},
165+
McpJsonUtilities.JsonContext.Default.ProgressNotification,
166+
cancellationToken);
167167
}
168168
}

src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using ModelContextProtocol.Protocol.Messages;
2-
using ModelContextProtocol.Server;
32
using System.Text.Json.Serialization;
43

54
namespace ModelContextProtocol.Protocol.Types;

src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using ModelContextProtocol.Protocol.Messages;
2-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Server;
32
using System.Text.Json.Serialization;
43

54
namespace ModelContextProtocol.Protocol.Types;

src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using ModelContextProtocol.Protocol.Messages;
2-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Server;
32
using System.Text.Json.Serialization;
43

54
namespace ModelContextProtocol.Protocol.Types;

src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using ModelContextProtocol.Protocol.Messages;
2-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Server;
32
using System.Text.Json.Serialization;
43

54
namespace ModelContextProtocol.Protocol.Types;

src/ModelContextProtocol/Protocol/Types/RootsCapability.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using ModelContextProtocol.Protocol.Messages;
2-
using ModelContextProtocol.Server;
3-
using System.Text.Json.Serialization;
1+
using System.Text.Json.Serialization;
42

53
namespace ModelContextProtocol.Protocol.Types;
64

src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using ModelContextProtocol.Protocol.Messages;
2-
using ModelContextProtocol.Server;
3-
using System.Text.Json.Serialization;
1+
using System.Text.Json.Serialization;
42

53
namespace ModelContextProtocol.Protocol.Types;
64

src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using ModelContextProtocol.Protocol.Messages;
2-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Server;
32
using System.Text.Json.Serialization;
43

54
namespace ModelContextProtocol.Protocol.Types;

src/ModelContextProtocol/Server/McpServerOptions.cs

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11

22
using ModelContextProtocol.Protocol.Types;
3-
using System.Text.Json.Serialization;
43

54
namespace ModelContextProtocol.Server;
65

src/ModelContextProtocol/Shared/McpSession.cs

+31-2
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,24 @@ await _transport.SendMessageAsync(new JsonRpcResponse
296296
}, cancellationToken).ConfigureAwait(false);
297297
}
298298

299+
private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, RequestId requestId)
300+
{
301+
if (!cancellationToken.CanBeCanceled)
302+
{
303+
return default;
304+
}
305+
306+
return cancellationToken.Register(static objState =>
307+
{
308+
var state = (Tuple<McpSession, RequestId>)objState!;
309+
_ = state.Item1.SendMessageAsync(new JsonRpcNotification
310+
{
311+
Method = NotificationMethods.CancelledNotification,
312+
Params = JsonSerializer.SerializeToNode(new CancelledNotification { RequestId = state.Item2 }, McpJsonUtilities.JsonContext.Default.CancelledNotification)
313+
});
314+
}, Tuple.Create(this, requestId));
315+
}
316+
299317
public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler)
300318
{
301319
Throw.IfNullOrWhiteSpace(method);
@@ -320,6 +338,8 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
320338
throw new McpException("Transport is not connected");
321339
}
322340

341+
cancellationToken.ThrowIfCancellationRequested();
342+
323343
Histogram<double> durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration;
324344
string method = request.Method;
325345

@@ -357,9 +377,16 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
357377
_logger.SendingRequest(EndpointName, request.Method);
358378

359379
await _transport.SendMessageAsync(request, cancellationToken).ConfigureAwait(false);
360-
361380
_logger.RequestSentAwaitingResponse(EndpointName, request.Method, request.Id.ToString());
362-
var response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
381+
382+
// Now that the request has been sent, register for cancellation. If we registered before,
383+
// a cancellation request could arrive before the server knew about that request ID, in which
384+
// case the server could ignore it.
385+
IJsonRpcMessage? response;
386+
using (var registration = RegisterCancellation(cancellationToken, request.Id))
387+
{
388+
response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
389+
}
363390

364391
if (response is JsonRpcError error)
365392
{
@@ -400,6 +427,8 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
400427
throw new McpException("Transport is not connected");
401428
}
402429

430+
cancellationToken.ThrowIfCancellationRequested();
431+
403432
Histogram<double> durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration;
404433
string method = GetMethodName(message);
405434

tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs

+12-49
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,32 @@
33
using Microsoft.Extensions.Logging;
44
using ModelContextProtocol.Client;
55
using ModelContextProtocol.Protocol.Messages;
6-
using ModelContextProtocol.Protocol.Transport;
76
using ModelContextProtocol.Protocol.Types;
87
using ModelContextProtocol.Server;
9-
using ModelContextProtocol.Tests.Utils;
108
using Moq;
11-
using System.IO.Pipelines;
129
using System.Text.Json;
1310
using System.Text.Json.Serialization.Metadata;
1411
using System.Threading.Channels;
1512

1613
namespace ModelContextProtocol.Tests.Client;
1714

18-
public class McpClientExtensionsTests : LoggedTest
15+
public class McpClientExtensionsTests : ClientServerTestBase
1916
{
20-
private readonly Pipe _clientToServerPipe = new();
21-
private readonly Pipe _serverToClientPipe = new();
22-
private readonly ServiceProvider _serviceProvider;
23-
private readonly CancellationTokenSource _cts;
24-
private readonly IMcpServer _server;
25-
private readonly Task _serverTask;
26-
2717
public McpClientExtensionsTests(ITestOutputHelper outputHelper)
2818
: base(outputHelper)
2919
{
30-
ServiceCollection sc = new();
31-
sc.AddSingleton(LoggerFactory);
32-
sc.AddMcpServer().WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream());
20+
}
21+
22+
protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder)
23+
{
3324
for (int f = 0; f < 10; f++)
3425
{
3526
string name = $"Method{f}";
36-
sc.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name }));
27+
services.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name }));
3728
}
38-
sc.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)](string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" }));
39-
sc.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)](string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true }));
40-
_serviceProvider = sc.BuildServiceProvider();
29+
services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" }));
30+
services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true }));
4131

42-
_server = _serviceProvider.GetRequiredService<IMcpServer>();
43-
_cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken);
44-
_serverTask = _server.RunAsync(cancellationToken: _cts.Token);
4532
}
4633

4734
[Theory]
@@ -218,30 +205,6 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages()
218205
Assert.Equal("endTurn", result.StopReason);
219206
}
220207

221-
public async ValueTask DisposeAsync()
222-
{
223-
await _cts.CancelAsync();
224-
225-
_clientToServerPipe.Writer.Complete();
226-
_serverToClientPipe.Writer.Complete();
227-
228-
await _serverTask;
229-
230-
await _serviceProvider.DisposeAsync();
231-
_cts.Dispose();
232-
}
233-
234-
private async Task<IMcpClient> CreateMcpClientForServer()
235-
{
236-
return await McpClientFactory.CreateAsync(
237-
new StreamClientTransport(
238-
serverInput: _clientToServerPipe.Writer.AsStream(),
239-
serverOutput: _serverToClientPipe.Reader.AsStream(),
240-
LoggerFactory),
241-
loggerFactory: LoggerFactory,
242-
cancellationToken: TestContext.Current.CancellationToken);
243-
}
244-
245208
[Fact]
246209
public async Task ListToolsAsync_AllToolsReturned()
247210
{
@@ -377,15 +340,15 @@ public async Task AsClientLoggerProvider_MessagesSentToClient()
377340
{
378341
IMcpClient client = await CreateMcpClientForServer();
379342

380-
ILoggerProvider loggerProvider = _server.AsClientLoggerProvider();
343+
ILoggerProvider loggerProvider = Server.AsClientLoggerProvider();
381344
Assert.Throws<ArgumentNullException>("categoryName", () => loggerProvider.CreateLogger(null!));
382345

383346
ILogger logger = loggerProvider.CreateLogger("TestLogger");
384347
Assert.NotNull(logger);
385348

386349
Assert.Null(logger.BeginScope(""));
387350

388-
Assert.Null(_server.LoggingLevel);
351+
Assert.Null(Server.LoggingLevel);
389352
Assert.False(logger.IsEnabled(LogLevel.Trace));
390353
Assert.False(logger.IsEnabled(LogLevel.Debug));
391354
Assert.False(logger.IsEnabled(LogLevel.Information));
@@ -396,13 +359,13 @@ public async Task AsClientLoggerProvider_MessagesSentToClient()
396359
await client.SetLoggingLevel(LoggingLevel.Info, TestContext.Current.CancellationToken);
397360

398361
DateTime start = DateTime.UtcNow;
399-
while (_server.LoggingLevel is null)
362+
while (Server.LoggingLevel is null)
400363
{
401364
await Task.Delay(1, TestContext.Current.CancellationToken);
402365
Assert.True(DateTime.UtcNow - start < TimeSpan.FromSeconds(10), "Timed out waiting for logging level to be set");
403366
}
404367

405-
Assert.Equal(LoggingLevel.Info, _server.LoggingLevel);
368+
Assert.Equal(LoggingLevel.Info, Server.LoggingLevel);
406369
Assert.False(logger.IsEnabled(LogLevel.Trace));
407370
Assert.False(logger.IsEnabled(LogLevel.Debug));
408371
Assert.True(logger.IsEnabled(LogLevel.Information));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using Microsoft.Extensions.AI;
2+
using Microsoft.Extensions.DependencyInjection;
3+
using ModelContextProtocol.Client;
4+
using ModelContextProtocol.Protocol.Transport;
5+
using ModelContextProtocol.Server;
6+
using ModelContextProtocol.Tests.Utils;
7+
using System.IO.Pipelines;
8+
9+
namespace ModelContextProtocol.Tests;
10+
11+
public abstract class ClientServerTestBase : LoggedTest, IAsyncDisposable
12+
{
13+
private readonly Pipe _clientToServerPipe = new();
14+
private readonly Pipe _serverToClientPipe = new();
15+
private readonly IMcpServerBuilder _builder;
16+
private readonly CancellationTokenSource _cts;
17+
private readonly Task _serverTask;
18+
19+
public ClientServerTestBase(ITestOutputHelper testOutputHelper)
20+
: base(testOutputHelper)
21+
{
22+
ServiceCollection sc = new();
23+
sc.AddSingleton(LoggerFactory);
24+
_builder = sc
25+
.AddMcpServer()
26+
.WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream());
27+
ConfigureServices(sc, _builder);
28+
ServiceProvider = sc.BuildServiceProvider();
29+
30+
_cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken);
31+
Server = ServiceProvider.GetRequiredService<IMcpServer>();
32+
_serverTask = Server.RunAsync(_cts.Token);
33+
}
34+
35+
protected IMcpServer Server { get; }
36+
37+
protected IServiceProvider ServiceProvider { get; }
38+
39+
protected virtual void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder)
40+
{
41+
}
42+
43+
public async ValueTask DisposeAsync()
44+
{
45+
await _cts.CancelAsync();
46+
47+
_clientToServerPipe.Writer.Complete();
48+
_serverToClientPipe.Writer.Complete();
49+
50+
await _serverTask;
51+
52+
if (ServiceProvider is IAsyncDisposable asyncDisposable)
53+
{
54+
await asyncDisposable.DisposeAsync();
55+
}
56+
else if (ServiceProvider is IDisposable disposable)
57+
{
58+
disposable.Dispose();
59+
}
60+
61+
_cts.Dispose();
62+
Dispose();
63+
}
64+
65+
protected async Task<IMcpClient> CreateMcpClientForServer(McpClientOptions? options = null)
66+
{
67+
return await McpClientFactory.CreateAsync(
68+
new StreamClientTransport(
69+
serverInput: _clientToServerPipe.Writer.AsStream(),
70+
_serverToClientPipe.Reader.AsStream(),
71+
LoggerFactory),
72+
loggerFactory: LoggerFactory,
73+
cancellationToken: TestContext.Current.CancellationToken);
74+
}
75+
}

0 commit comments

Comments
 (0)