diff --git a/Directory.Packages.props b/Directory.Packages.props index acdc0ee8..8ac0a52c 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -58,6 +58,10 @@ + + + + diff --git a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj index 94a5ccdb..8274f7cd 100644 --- a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj +++ b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj @@ -12,4 +12,11 @@ + + + + + + + diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs index 306a6e8f..687bb6d5 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreSseServer/Program.cs @@ -1,10 +1,23 @@ using TestServerWithHosting.Tools; +using OpenTelemetry.Metrics; +using OpenTelemetry.Trace; +using OpenTelemetry; var builder = WebApplication.CreateBuilder(args); builder.Services.AddMcpServer() .WithTools() .WithTools(); +builder.Services.AddOpenTelemetry() + .WithTracing(b => b.AddSource("*") + .AddAspNetCoreInstrumentation() + .AddHttpClientInstrumentation()) + .WithMetrics(b => b.AddMeter("*") + .AddAspNetCoreInstrumentation() + .AddHttpClientInstrumentation()) + .WithLogging() + .UseOtlpExporter(); + var app = builder.Build(); app.MapMcp(); diff --git a/samples/AspNetCoreSseServer/Properties/launchSettings.json b/samples/AspNetCoreSseServer/Properties/launchSettings.json index 3b6f145d..c789fb47 100644 --- a/samples/AspNetCoreSseServer/Properties/launchSettings.json +++ b/samples/AspNetCoreSseServer/Properties/launchSettings.json @@ -6,7 +6,8 @@ "dotnetRunMessages": true, "applicationUrl": "http://localhost:3001", "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" + "ASPNETCORE_ENVIRONMENT": "Development", + "OTEL_SERVICE_NAME": "sse-server", } }, "https": { @@ -14,7 +15,8 @@ "dotnetRunMessages": true, "applicationUrl": "https://localhost:7133;http://localhost:3001", "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" + "ASPNETCORE_ENVIRONMENT": "Development", + "OTEL_SERVICE_NAME": "sse-server", } } } diff --git a/samples/ChatWithTools/ChatWithTools.csproj b/samples/ChatWithTools/ChatWithTools.csproj index 8e08a455..13bdafc0 100644 --- a/samples/ChatWithTools/ChatWithTools.csproj +++ b/samples/ChatWithTools/ChatWithTools.csproj @@ -15,6 +15,8 @@ + + diff --git a/samples/ChatWithTools/Program.cs b/samples/ChatWithTools/Program.cs index dd09a7c9..8c5ae823 100644 --- a/samples/ChatWithTools/Program.cs +++ b/samples/ChatWithTools/Program.cs @@ -3,15 +3,49 @@ using Microsoft.Extensions.AI; using OpenAI; +using OpenTelemetry; +using OpenTelemetry.Trace; +using Microsoft.Extensions.Logging; +using OpenTelemetry.Logs; +using OpenTelemetry.Metrics; + +using var tracerProvider = Sdk.CreateTracerProviderBuilder() + .AddHttpClientInstrumentation() + .AddSource("*") + .AddOtlpExporter() + .Build(); +using var metricsProvider = Sdk.CreateMeterProviderBuilder() + .AddHttpClientInstrumentation() + .AddMeter("*") + .AddOtlpExporter() + .Build(); +using var loggerFactory = LoggerFactory.Create(builder => builder.AddOpenTelemetry(opt => opt.AddOtlpExporter())); + // Connect to an MCP server Console.WriteLine("Connecting client to MCP 'everything' server"); + +// Create OpenAI client (or any other compatible with IChatClient) +// Provide your own OPENAI_API_KEY via an environment variable. +var openAIClient = new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")).GetChatClient("gpt-4o-mini"); + +// Create a sampling client. +using IChatClient samplingClient = openAIClient.AsIChatClient() + .AsBuilder() + .UseOpenTelemetry(loggerFactory: loggerFactory, configure: o => o.EnableSensitiveData = true) + .Build(); + var mcpClient = await McpClientFactory.CreateAsync( new StdioClientTransport(new() { Command = "npx", Arguments = ["-y", "--verbose", "@modelcontextprotocol/server-everything"], Name = "Everything", - })); + }), + clientOptions: new() + { + Capabilities = new() { Sampling = new() { SamplingHandler = samplingClient.CreateSamplingHandler() } }, + }, + loggerFactory: loggerFactory); // Get all available tools Console.WriteLine("Tools available:"); @@ -20,13 +54,15 @@ { Console.WriteLine($" {tool}"); } + Console.WriteLine(); -// Create an IChatClient. (This shows using OpenAIClient, but it could be any other IChatClient implementation.) -// Provide your own OPENAI_API_KEY via an environment variable. -using IChatClient chatClient = - new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")).GetChatClient("gpt-4o-mini").AsIChatClient() - .AsBuilder().UseFunctionInvocation().Build(); +// Create an IChatClient that can use the tools. +using IChatClient chatClient = openAIClient.AsIChatClient() + .AsBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry(loggerFactory: loggerFactory, configure: o => o.EnableSensitiveData = true) + .Build(); // Have a conversation, making all tools available to the LLM. List messages = []; diff --git a/samples/EverythingServer/EverythingServer.csproj b/samples/EverythingServer/EverythingServer.csproj index 3aee2bc2..d5046f7e 100644 --- a/samples/EverythingServer/EverythingServer.csproj +++ b/samples/EverythingServer/EverythingServer.csproj @@ -9,6 +9,9 @@ + + + diff --git a/samples/EverythingServer/Program.cs b/samples/EverythingServer/Program.cs index a0966fe7..c9bc1272 100644 --- a/samples/EverythingServer/Program.cs +++ b/samples/EverythingServer/Program.cs @@ -8,6 +8,11 @@ using ModelContextProtocol; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; +using OpenTelemetry; +using OpenTelemetry.Logs; +using OpenTelemetry.Metrics; +using OpenTelemetry.Resources; +using OpenTelemetry.Trace; #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously @@ -186,6 +191,13 @@ await ctx.Server.RequestSamplingAsync([ return new EmptyResult(); }); +ResourceBuilder resource = ResourceBuilder.CreateDefault().AddService("everything-server"); +builder.Services.AddOpenTelemetry() + .WithTracing(b => b.AddSource("*").AddHttpClientInstrumentation().SetResourceBuilder(resource)) + .WithMetrics(b => b.AddMeter("*").AddHttpClientInstrumentation().SetResourceBuilder(resource)) + .WithLogging(b => b.SetResourceBuilder(resource)) + .UseOtlpExporter(); + builder.Services.AddSingleton(subscriptions); builder.Services.AddHostedService(); builder.Services.AddHostedService(); diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs index 5b4e31f4..a5d293f0 100644 --- a/src/ModelContextProtocol/Diagnostics.cs +++ b/src/ModelContextProtocol/Diagnostics.cs @@ -1,5 +1,8 @@ using System.Diagnostics; using System.Diagnostics.Metrics; +using System.Text.Json; +using System.Text.Json.Nodes; +using ModelContextProtocol.Protocol.Messages; namespace ModelContextProtocol; @@ -34,4 +37,77 @@ internal static Histogram CreateDurationHistogram(string name, string de HistogramBucketBoundaries = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60, 120, 300], }; #endif + + internal static ActivityContext ExtractActivityContext(this DistributedContextPropagator propagator, IJsonRpcMessage message) + { + propagator.ExtractTraceIdAndState(message, ExtractContext, out var traceparent, out var tracestate); + ActivityContext.TryParse(traceparent, tracestate, true, out var activityContext); + return activityContext; + } + + private static void ExtractContext(object? message, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) + { + fieldValues = null; + fieldValue = null; + + JsonNode? parameters = null; + switch (message) + { + case JsonRpcRequest request: + parameters = request.Params; + break; + + case JsonRpcNotification notification: + parameters = notification.Params; + break; + + default: + break; + } + + if (parameters?[fieldName] is JsonValue value && value.GetValueKind() == JsonValueKind.String) + { + fieldValue = value.GetValue(); + } + } + + internal static void InjectActivityContext(this DistributedContextPropagator propagator, Activity? activity, IJsonRpcMessage message) + { + // noop if activity is null + propagator.Inject(activity, message, InjectContext); + } + + private static void InjectContext(object? message, string key, string value) + { + JsonNode? parameters = null; + switch (message) + { + case JsonRpcRequest request: + parameters = request.Params; + break; + + case JsonRpcNotification notification: + parameters = notification.Params; + break; + + default: + break; + } + + if (parameters is JsonObject jsonObject && jsonObject[key] == null) + { + jsonObject[key] = value; + } + } + + internal static bool ShouldInstrumentMessage(IJsonRpcMessage message) => + ActivitySource.HasListeners() && + message switch + { + JsonRpcRequest => true, + JsonRpcNotification notification => notification.Method != NotificationMethods.LoggingMessageNotification, + _ => false + }; + + internal static ActivityLink[] ActivityLinkFromCurrent() => Activity.Current is null ? [] : [new ActivityLink(Activity.Current.Context)]; } diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 062d8cf4..afbcecab 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; @@ -10,7 +10,9 @@ using System.Diagnostics.Metrics; using System.Text.Json; using System.Text.Json.Nodes; +#if !NET using System.Threading.Channels; +#endif namespace ModelContextProtocol.Shared; @@ -23,10 +25,10 @@ internal sealed class McpSession : IDisposable "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); - private static readonly Histogram s_clientRequestDuration = Diagnostics.CreateDurationHistogram( - "rpc.client.duration", "Measures the duration of outbound RPC.", longBuckets: false); - private static readonly Histogram s_serverRequestDuration = Diagnostics.CreateDurationHistogram( - "rpc.server.duration", "Measures the duration of inbound RPC.", longBuckets: false); + private static readonly Histogram s_clientOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.operation.duration", "Measures the duration of outbound message.", longBuckets: false); + private static readonly Histogram s_serverOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); private readonly bool _isServer; private readonly string _transportKind; @@ -35,6 +37,8 @@ internal sealed class McpSession : IDisposable private readonly NotificationHandlers _notificationHandlers; private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); + private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; + /// Collection of requests sent on this session and waiting for responses. private readonly ConcurrentDictionary> _pendingRequests = []; /// @@ -184,12 +188,17 @@ await _transport.SendMessageAsync(new JsonRpcError private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken) { - Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; string method = GetMethodName(message); long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - Activity? activity = Diagnostics.ActivitySource.HasListeners() ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + + Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity( + CreateActivityName(method), + ActivityKind.Server, + parentContext: _propagator.ExtractActivityContext(message), + links: Diagnostics.ActivityLinkFromCurrent()) : null; TagList tags = default; @@ -198,18 +207,14 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken { if (addTags) { - AddStandardTags(ref tags, method); + AddTags(ref tags, activity, message, method); } switch (message) { case JsonRpcRequest request: - if (addTags) - { - AddRpcRequestTags(ref tags, activity, request); - } - - await HandleRequest(request, cancellationToken).ConfigureAwait(false); + var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); + AddResponseTags(ref tags, activity, result, method); break; case JsonRpcNotification notification: @@ -227,7 +232,7 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken } catch (Exception e) when (addTags) { - AddExceptionTags(ref tags, e); + AddExceptionTags(ref tags, activity, e); throw; } finally @@ -277,7 +282,7 @@ private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId } } - private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) + private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) { if (!_requestHandlers.TryGetValue(request.Method, out var handler)) { @@ -294,6 +299,8 @@ await _transport.SendMessageAsync(new JsonRpcResponse JsonRpc = "2.0", Result = result }, cancellationToken).ConfigureAwait(false); + + return result; } private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, RequestId requestId) @@ -340,12 +347,12 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc cancellationToken.ThrowIfCancellationRequested(); - Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; string method = request.Method; long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ActivitySource.HasListeners() ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + using Activity? activity = Diagnostics.ShouldInstrumentMessage(request) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : null; // Set request ID @@ -354,6 +361,8 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc request.Id = new RequestId($"{_id}-{Interlocked.Increment(ref _nextRequestId)}"); } + _propagator.InjectActivityContext(activity, request); + TagList tags = default; bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; @@ -363,8 +372,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc { if (addTags) { - AddStandardTags(ref tags, method); - AddRpcRequestTags(ref tags, activity, request); + AddTags(ref tags, activity, request, method); } // Expensive logging, use the logging framework to check if the logger is enabled @@ -396,6 +404,11 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc if (response is JsonRpcResponse success) { + if (addTags) + { + AddResponseTags(ref tags, activity, success.Result, method); + } + _logger.RequestResponseReceivedPayload(EndpointName, success.Result?.ToJsonString() ?? "null"); _logger.RequestResponseReceived(EndpointName, request.Method); return success; @@ -407,7 +420,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc } catch (Exception ex) when (addTags) { - AddExceptionTags(ref tags, ex); + AddExceptionTags(ref tags, activity, ex); throw; } finally @@ -429,22 +442,25 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca cancellationToken.ThrowIfCancellationRequested(); - Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; string method = GetMethodName(message); long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ActivitySource.HasListeners() ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + using Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : null; TagList tags = default; bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + // propagate trace context + _propagator?.InjectActivityContext(activity, message); + try { if (addTags) { - AddStandardTags(ref tags, method); + AddTags(ref tags, activity, message, method); } if (_logger.IsEnabled(LogLevel.Debug)) @@ -466,7 +482,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca } catch (Exception ex) when (addTags) { - AddExceptionTags(ref tags, ex); + AddExceptionTags(ref tags, activity, ex); throw; } finally @@ -487,77 +503,118 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca } } - private string CreateActivityName(string method) => - $"mcp.{(_isServer ? "server" : "client")}.{_transportKind}/{method}"; + private string CreateActivityName(string method) => method; private static string GetMethodName(IJsonRpcMessage message) => message switch { JsonRpcRequest request => request.Method, JsonRpcNotification notification => notification.Method, - _ => "unknownMethod", + _ => "unknownMethod" }; - private void AddStandardTags(ref TagList tags, string method) + private void AddTags(ref TagList tags, Activity? activity, IJsonRpcMessage message, string method) { - tags.Add("session.id", _id); - tags.Add("rpc.system", "jsonrpc"); - tags.Add("rpc.jsonrpc.version", "2.0"); - tags.Add("rpc.method", method); + tags.Add("mcp.method.name", method); tags.Add("network.transport", _transportKind); - // RPC spans convention also includes: - // server.address, server.port, client.address, client.port, network.peer.address, network.peer.port, network.type - } + // TODO: When using SSE transport, add: + // - server.address and server.port on client spans and metrics + // - client.address and client.port on server spans (not metrics because of cardinality) when using SSE transport + if (activity is { IsAllDataRequested: true }) + { + // session and request id have high cardinality, so not applying to metric tags + activity.AddTag("mcp.session.id", _id); - private static void AddRpcRequestTags(ref TagList tags, Activity? activity, JsonRpcRequest request) - { - tags.Add("rpc.jsonrpc.request_id", request.Id.ToString()); + if (message is IJsonRpcMessageWithId withId) + { + activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); + } + } - if (request.Params is JsonObject paramsObj) + JsonObject? paramsObj = message switch { - switch (request.Method) - { - case RequestMethods.ToolsCall: - case RequestMethods.PromptsGet: - if (paramsObj.TryGetPropertyValue("name", out var prop) && prop?.GetValueKind() is JsonValueKind.String) - { - string name = prop.GetValue(); - tags.Add("mcp.request.params.name", name); - if (activity is not null) - { - activity.DisplayName = $"{request.Method}({name})"; - } - } - break; + JsonRpcRequest request => request.Params as JsonObject, + JsonRpcNotification notification => notification.Params as JsonObject, + _ => null + }; - case RequestMethods.ResourcesRead: - if (paramsObj.TryGetPropertyValue("uri", out prop) && prop?.GetValueKind() is JsonValueKind.String) - { - string uri = prop.GetValue(); - tags.Add("mcp.request.params.uri", uri); - if (activity is not null) - { - activity.DisplayName = $"{request.Method}({uri})"; - } - } - break; - } + if (paramsObj == null) + { + return; + } + + string? target = null; + switch (method) + { + case RequestMethods.ToolsCall: + case RequestMethods.PromptsGet: + target = GetStringProperty(paramsObj, "name"); + if (target is not null) + { + tags.Add(method == RequestMethods.ToolsCall ? "mcp.tool.name" : "mcp.prompt.name", target); + } + break; + + case RequestMethods.ResourcesRead: + case RequestMethods.ResourcesSubscribe: + case RequestMethods.ResourcesUnsubscribe: + case NotificationMethods.ResourceUpdatedNotification: + target = GetStringProperty(paramsObj, "uri"); + if (target is not null) + { + tags.Add("mcp.resource.uri", target); + } + break; + } + + if (activity is { IsAllDataRequested: true }) + { + activity.DisplayName = target == null ? method : $"{method} {target}"; } } - private static void AddExceptionTags(ref TagList tags, Exception e) + private static void AddExceptionTags(ref TagList tags, Activity? activity, Exception e) { if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) { e = ae.InnerException; } - tags.Add("error.type", e.GetType().FullName); - tags.Add("rpc.jsonrpc.error_code", - (e as McpException)?.ErrorCode is int errorCode ? errorCode : - e is JsonException ? ErrorCodes.ParseError : - ErrorCodes.InternalError); + int? intErrorCode = (e as McpException)?.ErrorCode is int errorCode ? errorCode : + e is JsonException ? ErrorCodes.ParseError : null; + + tags.Add("error.type", intErrorCode == null ? e.GetType().FullName : intErrorCode.ToString()); + if (intErrorCode is not null) + { + tags.Add("rpc.jsonrpc.error_code", intErrorCode.ToString()); + } + + if (activity is { IsAllDataRequested: true }) + { + activity.SetStatus(ActivityStatusCode.Error, e.Message); + } + } + + private static void AddResponseTags(ref TagList tags, Activity? activity, JsonNode? response, string method) + { + if (response is JsonObject jsonObject + && jsonObject.TryGetPropertyValue("isError", out var isError) + && isError?.GetValueKind() == JsonValueKind.True) + { + if (activity is { IsAllDataRequested: true }) + { + string? content = null; + if (jsonObject.TryGetPropertyValue("content", out var prop) && prop != null) + { + content = prop.ToJsonString(); + } + + activity.SetStatus(ActivityStatusCode.Error, content); + } + + tags.Add("error.type", method == RequestMethods.ToolsCall ? "tool_error" : "_OTHER"); + } } private static void FinalizeDiagnostics( @@ -590,8 +647,10 @@ public void Dispose() if (durationMetric.Enabled) { TagList tags = default; - tags.Add("session.id", _id); tags.Add("network.transport", _transportKind); + + // TODO: Add server.address and server.port on client-side when using SSE transport, + // client.* attributes are not added to metrics because of cardinality durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); } @@ -614,4 +673,14 @@ private static TimeSpan GetElapsed(long startingTimestamp) => #else new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); #endif -} \ No newline at end of file + + private static string? GetStringProperty(JsonObject parameters, string propName) + { + if (parameters.TryGetPropertyValue(propName, out var prop) && prop?.GetValueKind() is JsonValueKind.String) + { + return prop.GetValue(); + } + + return null; + } +} diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index fdd3e81f..1d168914 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -33,9 +33,91 @@ await RunConnected(async (client, server) => Assert.NotEmpty(activities); - Activity toolCallActivity = activities.First(a => - a.Tags.Any(t => t.Key == "rpc.method" && t.Value == "tools/call")); - Assert.Equal("DoubleValue", toolCallActivity.Tags.First(t => t.Key == "mcp.request.params.name").Value); + var clientToolCall = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.tool.name" && t.Value == "DoubleValue") && + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && + a.DisplayName == "tools/call DoubleValue" && + a.Kind == ActivityKind.Client && + a.Status == ActivityStatusCode.Unset); + + var serverToolCall = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.tool.name" && t.Value == "DoubleValue") && + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && + a.DisplayName == "tools/call DoubleValue" && + a.Kind == ActivityKind.Server && + a.Status == ActivityStatusCode.Unset); + + Assert.Equal(clientToolCall.SpanId, serverToolCall.ParentSpanId); + Assert.Equal(clientToolCall.TraceId, serverToolCall.TraceId); + + var clientListToolsCall = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/list") && + a.DisplayName == "tools/list" && + a.Kind == ActivityKind.Client && + a.Status == ActivityStatusCode.Unset); + + var serverListToolsCall = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/list") && + a.DisplayName == "tools/list" && + a.Kind == ActivityKind.Server && + a.Status == ActivityStatusCode.Unset); + + Assert.Equal(clientListToolsCall.SpanId, serverListToolsCall.ParentSpanId); + Assert.Equal(clientListToolsCall.TraceId, serverListToolsCall.TraceId); + } + + [Fact] + public async Task Session_FailedToolCall() + { + var activities = new List(); + + using (var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource("Experimental.ModelContextProtocol") + .AddInMemoryExporter(activities) + .Build()) + { + await RunConnected(async (client, server) => + { + await client.CallToolAsync("Throw", cancellationToken: TestContext.Current.CancellationToken); + await Assert.ThrowsAsync(() => client.CallToolAsync("does-not-exist", cancellationToken: TestContext.Current.CancellationToken)); + }); + } + + Assert.NotEmpty(activities); + + var throwToolClient = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.tool.name" && t.Value == "Throw") && + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && + a.DisplayName == "tools/call Throw" && + a.Kind == ActivityKind.Client); + + Assert.Equal(ActivityStatusCode.Error, throwToolClient.Status); + + var throwToolServer = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.tool.name" && t.Value == "Throw") && + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && + a.DisplayName == "tools/call Throw" && + a.Kind == ActivityKind.Server); + + Assert.Equal(ActivityStatusCode.Error, throwToolServer.Status); + + var doesNotExistToolClient = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.tool.name" && t.Value == "does-not-exist") && + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && + a.DisplayName == "tools/call does-not-exist" && + a.Kind == ActivityKind.Client); + + Assert.Equal(ActivityStatusCode.Error, doesNotExistToolClient.Status); + Assert.Equal("-32603", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); + + var doesNotExistToolServer = Assert.Single(activities, a => + a.Tags.Any(t => t.Key == "mcp.tool.name" && t.Value == "does-not-exist") && + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && + a.DisplayName == "tools/call does-not-exist" && + a.Kind == ActivityKind.Server); + + Assert.Equal(ActivityStatusCode.Error, doesNotExistToolServer.Status); + Assert.Equal("-32603", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } private static async Task RunConnected(Func action) @@ -52,7 +134,10 @@ private static async Task RunConnected(Func action { Tools = new() { - ToolCollection = [McpServerTool.Create((int amount) => amount * 2, new() { Name = "DoubleValue", Description = "Doubles the value." })], + ToolCollection = [ + McpServerTool.Create((int amount) => amount * 2, new() { Name = "DoubleValue", Description = "Doubles the value." }), + McpServerTool.Create(() => { throw new Exception("boom"); }, new() { Name = "Throw", Description = "Throws error." }), + ], } } }))