Skip to content

Commit 90d3aea

Browse files
committed
implementing in memory transports
1 parent a583bac commit 90d3aea

File tree

5 files changed

+548
-470
lines changed

5 files changed

+548
-470
lines changed

src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs

+17-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
using ModelContextProtocol.Configuration;
1+
using Microsoft.Extensions.DependencyInjection;
2+
3+
using ModelContextProtocol.Configuration;
24
using ModelContextProtocol.Hosting;
35
using ModelContextProtocol.Protocol.Transport;
46
using ModelContextProtocol.Utils;
5-
using Microsoft.Extensions.DependencyInjection;
67

78
namespace ModelContextProtocol;
89

@@ -11,6 +12,20 @@ namespace ModelContextProtocol;
1112
/// </summary>
1213
public static partial class McpServerBuilderExtensions
1314
{
15+
/// <summary>
16+
/// Adds a server transport that uses in memory communication.
17+
/// </summary>
18+
/// <param name="builder">The builder instance.</param>
19+
public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder)
20+
{
21+
Throw.IfNull(builder);
22+
var (clientTransport, serverTransport) = InMemoryTransport.Create();
23+
builder.Services.AddSingleton<IServerTransport>(serverTransport);
24+
builder.Services.AddSingleton<IClientTransport>(clientTransport);
25+
builder.Services.AddHostedService<McpServerHostedService>();
26+
return builder;
27+
}
28+
1429
/// <summary>
1530
/// Adds a server transport that uses stdin/stdout for communication.
1631
/// </summary>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.Extensions.Logging.Abstractions;
3+
4+
using ModelContextProtocol.Logging;
5+
using ModelContextProtocol.Protocol.Messages;
6+
7+
using System.Threading.Channels;
8+
9+
namespace ModelContextProtocol.Protocol.Transport;
10+
11+
/// <summary>
12+
/// Provides an in-memory implementation of the MCP client transport.
13+
/// </summary>
14+
public sealed class InMemoryClientTransport : TransportBase, IClientTransport
15+
{
16+
private readonly string _endpointName = "InMemoryClientTransport";
17+
private readonly ILogger _logger;
18+
private readonly ChannelWriter<IJsonRpcMessage> _outgoingChannel;
19+
private readonly ChannelReader<IJsonRpcMessage> _incomingChannel;
20+
private CancellationTokenSource? _cancellationTokenSource;
21+
private Task? _readTask;
22+
private SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1);
23+
private volatile bool _disposed;
24+
25+
/// <summary>
26+
/// Gets or sets the server transport this client connects to.
27+
/// </summary>
28+
internal InMemoryServerTransport? ServerTransport { get; set; }
29+
30+
/// <summary>
31+
/// Initializes a new instance of the <see cref="InMemoryClientTransport"/> class.
32+
/// </summary>
33+
/// <param name="loggerFactory">Optional logger factory for logging transport operations.</param>
34+
/// <param name="outgoingChannel">Channel for sending messages to the server.</param>
35+
/// <param name="incomingChannel">Channel for receiving messages from the server.</param>
36+
internal InMemoryClientTransport(
37+
ILoggerFactory? loggerFactory,
38+
ChannelWriter<IJsonRpcMessage> outgoingChannel,
39+
ChannelReader<IJsonRpcMessage> incomingChannel)
40+
: base(loggerFactory)
41+
{
42+
_logger = loggerFactory?.CreateLogger<InMemoryClientTransport>()
43+
?? NullLogger<InMemoryClientTransport>.Instance;
44+
_outgoingChannel = outgoingChannel;
45+
_incomingChannel = incomingChannel;
46+
}
47+
48+
49+
50+
/// <inheritdoc/>
51+
public async Task ConnectAsync(CancellationToken cancellationToken = default)
52+
{
53+
await _connectLock.WaitAsync(cancellationToken).ConfigureAwait(false);
54+
try
55+
{
56+
ThrowIfDisposed();
57+
58+
if (IsConnected)
59+
{
60+
_logger.TransportAlreadyConnected(_endpointName);
61+
throw new McpTransportException("Transport is already connected");
62+
}
63+
64+
_logger.TransportConnecting(_endpointName);
65+
66+
try
67+
{
68+
// Start the server if it exists and is not already connected
69+
if (ServerTransport != null && !ServerTransport.IsConnected)
70+
{
71+
await ServerTransport.StartListeningAsync(cancellationToken).ConfigureAwait(false);
72+
}
73+
74+
_cancellationTokenSource = new CancellationTokenSource();
75+
_readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token);
76+
77+
SetConnected(true);
78+
}
79+
catch (Exception ex)
80+
{
81+
_logger.TransportConnectFailed(_endpointName, ex);
82+
await CleanupAsync(cancellationToken).ConfigureAwait(false);
83+
throw new McpTransportException("Failed to connect transport", ex);
84+
}
85+
}
86+
finally
87+
{
88+
_connectLock.Release();
89+
}
90+
}
91+
92+
/// <inheritdoc/>
93+
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
94+
{
95+
ThrowIfDisposed();
96+
97+
if (!IsConnected)
98+
{
99+
_logger.TransportNotConnected(_endpointName);
100+
throw new McpTransportException("Transport is not connected");
101+
}
102+
103+
string id = "(no id)";
104+
if (message is IJsonRpcMessageWithId messageWithId)
105+
{
106+
id = messageWithId.Id.ToString();
107+
}
108+
109+
try
110+
{
111+
_logger.TransportSendingMessage(_endpointName, id);
112+
await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false);
113+
_logger.TransportSentMessage(_endpointName, id);
114+
}
115+
catch (Exception ex)
116+
{
117+
_logger.TransportSendFailed(_endpointName, id, ex);
118+
throw new McpTransportException("Failed to send message", ex);
119+
}
120+
}
121+
122+
/// <inheritdoc/>
123+
public override async ValueTask DisposeAsync()
124+
{
125+
await CleanupAsync(CancellationToken.None).ConfigureAwait(false);
126+
GC.SuppressFinalize(this);
127+
}
128+
129+
private async Task ReadMessagesAsync(CancellationToken cancellationToken)
130+
{
131+
try
132+
{
133+
_logger.TransportEnteringReadMessagesLoop(_endpointName);
134+
135+
await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken))
136+
{
137+
string id = "(no id)";
138+
if (message is IJsonRpcMessageWithId messageWithId)
139+
{
140+
id = messageWithId.Id.ToString();
141+
}
142+
143+
_logger.TransportReceivedMessageParsed(_endpointName, id);
144+
145+
// Write to the base class's message channel that's exposed via MessageReader
146+
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
147+
148+
_logger.TransportMessageWritten(_endpointName, id);
149+
}
150+
151+
_logger.TransportExitingReadMessagesLoop(_endpointName);
152+
}
153+
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
154+
{
155+
_logger.TransportReadMessagesCancelled(_endpointName);
156+
// Normal shutdown
157+
}
158+
catch (Exception ex)
159+
{
160+
_logger.TransportReadMessagesFailed(_endpointName, ex);
161+
}
162+
}
163+
164+
private async Task CleanupAsync(CancellationToken cancellationToken)
165+
{
166+
if (_disposed)
167+
{
168+
return;
169+
}
170+
171+
_disposed = true;
172+
_logger.TransportCleaningUp(_endpointName);
173+
174+
try
175+
{
176+
if (_cancellationTokenSource != null)
177+
{
178+
await _cancellationTokenSource.CancelAsync().ConfigureAwait(false);
179+
_cancellationTokenSource.Dispose();
180+
_cancellationTokenSource = null;
181+
}
182+
183+
if (_readTask != null)
184+
{
185+
try
186+
{
187+
_logger.TransportWaitingForReadTask(_endpointName);
188+
await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false);
189+
}
190+
catch (TimeoutException)
191+
{
192+
_logger.TransportCleanupReadTaskTimeout(_endpointName);
193+
}
194+
catch (OperationCanceledException)
195+
{
196+
_logger.TransportCleanupReadTaskCancelled(_endpointName);
197+
}
198+
catch (Exception ex)
199+
{
200+
_logger.TransportCleanupReadTaskFailed(_endpointName, ex);
201+
}
202+
finally
203+
{
204+
_readTask = null;
205+
}
206+
}
207+
208+
_connectLock.Dispose();
209+
}
210+
finally
211+
{
212+
SetConnected(false);
213+
_logger.TransportCleanedUp(_endpointName);
214+
}
215+
}
216+
217+
private void ThrowIfDisposed()
218+
{
219+
if (_disposed)
220+
{
221+
throw new ObjectDisposedException(nameof(InMemoryClientTransport));
222+
}
223+
}
224+
}

0 commit comments

Comments
 (0)