diff --git a/README.MD b/README.MD index a55975aa..2f40a114 100644 --- a/README.MD +++ b/README.MD @@ -44,7 +44,7 @@ await foreach (var tool in client.ListToolsAsync()) // Execute a tool (this would normally be driven by LLM tool invocations). var result = await client.CallToolAsync( "echo", - new() { ["message"] = "Hello MCP!" }, + new Dictionary() { ["message"] = "Hello MCP!" }, CancellationToken.None); // echo always returns one and only one text content object @@ -59,16 +59,13 @@ Tools can be exposed easily as `AIFunction` instances so that they are immediate ```csharp // Get available functions. -IList tools = await client.GetAIFunctionsAsync(); +IList tools = await client.ListToolsAsync(); // Call the chat client using the tools. IChatClient chatClient = ...; var response = await chatClient.GetResponseAsync( "your prompt here", - new() - { - Tools = [.. tools], - }); + new() { Tools = [.. tools] }, ``` ## Getting Started (Server) @@ -88,17 +85,47 @@ var builder = Host.CreateEmptyApplicationBuilder(settings: null); builder.Services .AddMcpServer() .WithStdioServerTransport() - .WithTools(); + .WithToolsFromAssembly(); await builder.Build().RunAsync(); -[McpToolType] +[McpServerToolType] public static class EchoTool { - [McpTool, Description("Echoes the message back to the client.")] + [McpServerTool, Description("Echoes the message back to the client.")] public static string Echo(string message) => $"hello {message}"; } ``` +Tools can have the `IMcpServer` representing the server injected via a parameter to the method, and can use that for interaction with +the connected client. Similarly, arguments may be injected via dependency injection. For example, this tool will use the supplied +`IMcpServer` to make sampling requests back to the client in order to summarize content it downloads from the specified url via +an `HttpClient` injected via dependency injection. +```csharp +[McpServerTool("SummarizeContentFromUrl"), Description("Summarizes content downloaded from a specific URI")] +public static async Task SummarizeDownloadedContent( + IMcpServer thisServer, + HttpClient httpClient, + [Description("The url from which to download the content to summarize")] string url, + CancellationToken cancellationToken) +{ + string content = await httpClient.GetStringAsync(url); + + ChatMessage[] messages = + [ + new(ChatRole.User, "Briefly summarize the following downloaded content:"), + new(ChatRole.User, content), + ] + + ChatOptions options = new() + { + MaxOutputTokens = 256, + Temperature = 0.3f, + }; + + return $"Summary: {await thisServer.AsSamplingChatClient().GetResponseAsync(messages, options, cancellationToken)}"; +} +``` + More control is also available, with fine-grained control over configuring the server and how it should handle client requests. For example: ```csharp @@ -124,14 +151,18 @@ McpServerOptions options = new() { Name = "echo", Description = "Echoes the input back to the client.", - InputSchema = new JsonSchema() - { - Type = "object", - Properties = new Dictionary() + InputSchema = JsonSerializer.Deserialize(""" { - ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back" + } + }, + "required": ["message"] } - }, + """), } ] }; diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs index 9e210d8f..a3cd9414 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreSseServer/Program.cs @@ -2,7 +2,7 @@ using AspNetCoreSseServer; var builder = WebApplication.CreateBuilder(args); -builder.Services.AddMcpServer().WithTools(); +builder.Services.AddMcpServer().WithToolsFromAssembly(); var app = builder.Build(); app.MapGet("/", () => "Hello World!"); diff --git a/samples/AspNetCoreSseServer/Tools/EchoTool.cs b/samples/AspNetCoreSseServer/Tools/EchoTool.cs index cb21cc5c..636b4063 100644 --- a/samples/AspNetCoreSseServer/Tools/EchoTool.cs +++ b/samples/AspNetCoreSseServer/Tools/EchoTool.cs @@ -3,10 +3,10 @@ namespace TestServerWithHosting.Tools; -[McpToolType] +[McpServerToolType] public static class EchoTool { - [McpTool, Description("Echoes the input back to the client.")] + [McpServerTool, Description("Echoes the input back to the client.")] public static string Echo(string message) { return "hello " + message; diff --git a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs index 44b59bc6..880787ed 100644 --- a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs +++ b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs @@ -1,51 +1,36 @@ -using ModelContextProtocol.Protocol.Types; +using Microsoft.Extensions.AI; using ModelContextProtocol.Server; using System.ComponentModel; namespace TestServerWithHosting.Tools; /// -/// This tool uses depenency injection and async method +/// This tool uses dependency injection and async method /// -[McpToolType] -public class SampleLlmTool +[McpServerToolType] +public static class SampleLlmTool { - private readonly IMcpServer _server; - - public SampleLlmTool(IMcpServer server) - { - _server = server ?? throw new ArgumentNullException(nameof(server)); - } - - [McpTool("sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] - public async Task SampleLLM( + [McpServerTool("sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] + public static async Task SampleLLM( + IMcpServer thisServer, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) { - var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens); - var sampleResult = await _server.RequestSamplingAsync(samplingParams, cancellationToken); + ChatMessage[] messages = + [ + new(ChatRole.System, "You are a helpful test server."), + new(ChatRole.User, prompt), + ]; - return $"LLM sampling result: {sampleResult.Content.Text}"; - } - - private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) - { - return new CreateMessageRequestParams() + ChatOptions options = new() { - Messages = [new SamplingMessage() - { - Role = Role.User, - Content = new Content() - { - Type = "text", - Text = $"Resource {uri} context: {context}" - } - }], - SystemPrompt = "You are a helpful test server.", - MaxTokens = maxTokens, + MaxOutputTokens = maxTokens, Temperature = 0.7f, - IncludeContext = ContextInclusion.ThisServer }; + + var samplingResponse = await thisServer.AsSamplingChatClient().GetResponseAsync(messages, options, cancellationToken); + + return $"LLM sampling result: {samplingResponse}"; } } diff --git a/samples/ChatWithTools/ChatWithTools.csproj b/samples/ChatWithTools/ChatWithTools.csproj index e3ae1b79..af8fac19 100644 --- a/samples/ChatWithTools/ChatWithTools.csproj +++ b/samples/ChatWithTools/ChatWithTools.csproj @@ -11,7 +11,6 @@ - diff --git a/samples/ChatWithTools/Program.cs b/samples/ChatWithTools/Program.cs index 2dcf06fa..49380674 100644 --- a/samples/ChatWithTools/Program.cs +++ b/samples/ChatWithTools/Program.cs @@ -19,7 +19,7 @@ // Get all available tools Console.WriteLine("Tools available:"); -var tools = await mcpClient.GetAIFunctionsAsync(); +var tools = await mcpClient.ListToolsAsync(); foreach (var tool in tools) { Console.WriteLine($" {tool}"); diff --git a/samples/TestServerWithHosting/Program.cs b/samples/TestServerWithHosting/Program.cs index 1f745086..82b731df 100644 --- a/samples/TestServerWithHosting/Program.cs +++ b/samples/TestServerWithHosting/Program.cs @@ -19,7 +19,7 @@ builder.Services.AddSerilog(); builder.Services.AddMcpServer() .WithStdioServerTransport() - .WithTools(); + .WithToolsFromAssembly(); var app = builder.Build(); diff --git a/samples/TestServerWithHosting/Tools/EchoTool.cs b/samples/TestServerWithHosting/Tools/EchoTool.cs index cb21cc5c..636b4063 100644 --- a/samples/TestServerWithHosting/Tools/EchoTool.cs +++ b/samples/TestServerWithHosting/Tools/EchoTool.cs @@ -3,10 +3,10 @@ namespace TestServerWithHosting.Tools; -[McpToolType] +[McpServerToolType] public static class EchoTool { - [McpTool, Description("Echoes the input back to the client.")] + [McpServerTool, Description("Echoes the input back to the client.")] public static string Echo(string message) { return "hello " + message; diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index 44b59bc6..9c8c02d0 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -7,7 +7,7 @@ namespace TestServerWithHosting.Tools; /// /// This tool uses depenency injection and async method /// -[McpToolType] +[McpServerToolType] public class SampleLlmTool { private readonly IMcpServer _server; @@ -17,7 +17,7 @@ public SampleLlmTool(IMcpServer server) _server = server ?? throw new ArgumentNullException(nameof(server)); } - [McpTool("sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] + [McpServerTool("sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public async Task SampleLLM( [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, diff --git a/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs b/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs index ccfbd392..6a980088 100644 --- a/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs +++ b/src/Common/Polyfills/System/Collections/Generic/CollectionExtensions.cs @@ -15,4 +15,7 @@ public static TValue GetValueOrDefault(this IReadOnlyDictionary ToDictionary(this IEnumerable> source) => + source.ToDictionary(kv => kv.Key, kv => kv.Value); } \ No newline at end of file diff --git a/src/ModelContextProtocol/AIContentExtensions.cs b/src/ModelContextProtocol/AIContentExtensions.cs new file mode 100644 index 00000000..6a3f1773 --- /dev/null +++ b/src/ModelContextProtocol/AIContentExtensions.cs @@ -0,0 +1,104 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Utils; +using System.Runtime.InteropServices; + +namespace ModelContextProtocol; + +/// Provides helpers for conversions related to . +public static class AIContentExtensions +{ + /// Creates a from a . + /// The message to convert. + /// The created . + public static ChatMessage ToChatMessage(this PromptMessage promptMessage) + { + Throw.IfNull(promptMessage); + + return new() + { + RawRepresentation = promptMessage, + Role = promptMessage.Role == Role.User ? ChatRole.User : ChatRole.Assistant, + Contents = [ToAIContent(promptMessage.Content)] + }; + } + + /// Creates a new from the content of a . + /// The to convert. + /// The created . + public static AIContent ToAIContent(this Content content) + { + Throw.IfNull(content); + + AIContent ac; + if (content is { Type: "image", MimeType: not null, Data: not null }) + { + ac = new DataContent(Convert.FromBase64String(content.Data), content.MimeType); + } + else if (content is { Type: "resource" } && content.Resource is { } resourceContents) + { + ac = resourceContents.Blob is not null && resourceContents.MimeType is not null ? + new DataContent(Convert.FromBase64String(resourceContents.Blob), resourceContents.MimeType) : + new TextContent(resourceContents.Text); + + (ac.AdditionalProperties ??= [])["uri"] = resourceContents.Uri; + } + else + { + ac = new TextContent(content.Text); + } + + ac.RawRepresentation = content; + + return ac; + } + + /// Creates a new from the content of a . + /// The to convert. + /// The created . + public static AIContent ToAIContent(this ResourceContents content) + { + Throw.IfNull(content); + + AIContent ac = content.Blob is not null && content.MimeType is not null ? + new DataContent(Convert.FromBase64String(content.Blob), content.MimeType) : + new TextContent(content.Text); + + (ac.AdditionalProperties ??= [])["uri"] = content.Uri; + ac.RawRepresentation = content; + + return ac; + } + + /// Creates a list of from a sequence of . + /// The instances to convert. + /// The created instances. + public static IList ToAIContents(this IEnumerable contents) + { + Throw.IfNull(contents); + + return contents.Select(ToAIContent).ToList(); + } + + /// Creates a list of from a sequence of . + /// The instances to convert. + /// The created instances. + public static IList ToAIContents(this IEnumerable contents) + { + Throw.IfNull(contents); + + return contents.Select(ToAIContent).ToList(); + } + + /// Extracts the data from a as a Base64 string. + internal static string GetBase64Data(this DataContent dataContent) + { +#if NET + return Convert.ToBase64String(dataContent.Data.Span); +#else + return MemoryMarshal.TryGetArray(dataContent.Data, out ArraySegment segment) ? + Convert.ToBase64String(segment.Array!, segment.Offset, segment.Count) : + Convert.ToBase64String(dataContent.Data.ToArray()); +#endif + } +} diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 1a5c57cc..bf453cd9 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -3,8 +3,8 @@ using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using Microsoft.Extensions.AI; -using System.Runtime.CompilerServices; using System.Text.Json; +using System.Runtime.CompilerServices; namespace ModelContextProtocol.Client; @@ -23,6 +23,7 @@ public static class McpClientExtensions public static Task SendNotificationAsync(this IMcpClient client, string method, object? parameters = null, CancellationToken cancellationToken = default) { Throw.IfNull(client); + Throw.IfNullOrWhiteSpace(method); return client.SendMessageAsync( new JsonRpcNotification { Method = method, Params = parameters }, @@ -45,42 +46,67 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat } /// - /// Retrieves a sequence of available tools from the server. + /// Retrieves a list of available tools from the server. /// /// The client. /// A token to cancel the operation. - /// An asynchronous sequence of tool information. - public static async IAsyncEnumerable ListToolsAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + /// A list of all available tools. + public static async Task> ListToolsAsync( + this IMcpClient client, CancellationToken cancellationToken = default) { + Throw.IfNull(client); + + List? tools = null; string? cursor = null; do { - var tools = await ListToolsAsync(client, cursor, cancellationToken).ConfigureAwait(false); - foreach (var tool in tools.Tools) + var toolResults = await client.SendRequestAsync( + CreateRequest("tools/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + tools ??= new List(toolResults.Tools.Count); + foreach (var tool in toolResults.Tools) { - yield return tool; + tools.Add(new McpClientTool(client, tool)); } - cursor = tools.NextCursor; + cursor = toolResults.NextCursor; } while (cursor is not null); + + return tools; } /// - /// Retrieves a sequence of available tools from the server. + /// Creates an enumerable for asynchronously enumerating all available tools from the server. /// /// The client. - /// A cursor to paginate the results. /// A token to cancel the operation. - /// A task containing the server's response with tool information. - public static Task ListToolsAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) + /// An asynchronous sequence of all available tools. + /// + /// Every iteration through the returned + /// will result in requerying the server and yielding the sequence of available tools. + /// + public static async IAsyncEnumerable EnumerateToolsAsync( + this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); - return client.SendRequestAsync( - CreateRequest("tools/list", CreateCursorDictionary(cursor)), - cancellationToken); + string? cursor = null; + do + { + var toolResults = await client.SendRequestAsync( + CreateRequest("tools/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + foreach (var tool in toolResults.Tools) + { + yield return new McpClientTool(client, tool); + } + + cursor = toolResults.NextCursor; + } + while (cursor is not null); } /// @@ -88,38 +114,67 @@ public static Task ListToolsAsync(this IMcpClient client, strin /// /// The client. /// A token to cancel the operation. - /// An asynchronous sequence of prompt information. - public static async IAsyncEnumerable ListPromptsAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + /// A list of all available prompts. + public static async Task> ListPromptsAsync( + this IMcpClient client, CancellationToken cancellationToken = default) { + Throw.IfNull(client); + + List? prompts = null; + string? cursor = null; do { - var prompts = await ListPromptsAsync(client, cursor, cancellationToken).ConfigureAwait(false); - foreach (var prompt in prompts.Prompts) + var promptResults = await client.SendRequestAsync( + CreateRequest("prompts/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + if (prompts is null) { - yield return prompt; + prompts = promptResults.Prompts; + } + else + { + prompts.AddRange(promptResults.Prompts); } - cursor = prompts.NextCursor; + cursor = promptResults.NextCursor; } while (cursor is not null); + + return prompts; } /// - /// Retrieves a list of available prompts from the server. + /// Creates an enumerable for asynchronously enumerating all available prompts from the server. /// /// The client. - /// A cursor to paginate the results. /// A token to cancel the operation. - /// A task containing the server's response with prompt information. - public static Task ListPromptsAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) + /// An asynchronous sequence of all available prompts. + /// + /// Every iteration through the returned + /// will result in requerying the server and yielding the sequence of available prompts. + /// + public static async IAsyncEnumerable EnumeratePromptsAsync( + this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); - return client.SendRequestAsync( - CreateRequest("prompts/list", CreateCursorDictionary(cursor)), - cancellationToken); + string? cursor = null; + do + { + var promptResults = await client.SendRequestAsync( + CreateRequest("prompts/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + foreach (var prompt in promptResults.Prompts) + { + yield return prompt; + } + + cursor = promptResults.NextCursor; + } + while (cursor is not null); } /// @@ -130,9 +185,11 @@ public static Task ListPromptsAsync(this IMcpClient client, s /// Optional arguments for the prompt /// A token to cancel the operation. /// A task containing the prompt's content and messages. - public static Task GetPromptAsync(this IMcpClient client, string name, Dictionary? arguments = null, CancellationToken cancellationToken = default) + public static Task GetPromptAsync( + this IMcpClient client, string name, Dictionary? arguments = null, CancellationToken cancellationToken = default) { Throw.IfNull(client); + Throw.IfNullOrWhiteSpace(name); return client.SendRequestAsync( CreateRequest("prompts/get", CreateParametersDictionary(name, arguments)), @@ -140,79 +197,139 @@ public static Task GetPromptAsync(this IMcpClient client, strin } /// - /// Retrieves a sequence of available resource templates from the server. + /// Retrieves a list of available resource templates from the server. /// /// The client. /// A token to cancel the operation. - /// An asynchronous sequence of resource template information. - public static async IAsyncEnumerable ListResourceTemplatesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + /// A list of all available resource templates. + public static async Task> ListResourceTemplatesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) { + Throw.IfNull(client); + + List? templates = null; + string? cursor = null; do { - var resources = await ListResourceTemplatesAsync(client, cursor, cancellationToken).ConfigureAwait(false); - foreach (var resource in resources.ResourceTemplates) + var templateResults = await client.SendRequestAsync( + CreateRequest("resources/templates/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + if (templates is null) { - yield return resource; + templates = templateResults.ResourceTemplates; + } + else + { + templates.AddRange(templateResults.ResourceTemplates); } - cursor = resources.NextCursor; + cursor = templateResults.NextCursor; } while (cursor is not null); + + return templates; } /// - /// Retrieves a list of available resources from the server. + /// Creates an enumerable for asynchronously enumerating all available resource templates from the server. /// /// The client. - /// A cursor to paginate the results. /// A token to cancel the operation. - public static Task ListResourceTemplatesAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) + /// An asynchronous sequence of all available resource templates. + /// + /// Every iteration through the returned + /// will result in requerying the server and yielding the sequence of available resource templates. + /// + public static async IAsyncEnumerable EnumerateResourceTemplatesAsync( + this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); - return client.SendRequestAsync( - CreateRequest("resources/templates/list", CreateCursorDictionary(cursor)), - cancellationToken); + string? cursor = null; + do + { + var templateResults = await client.SendRequestAsync( + CreateRequest("resources/templates/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + foreach (var template in templateResults.ResourceTemplates) + { + yield return template; + } + + cursor = templateResults.NextCursor; + } + while (cursor is not null); } /// - /// Retrieves a sequence of available resources from the server. + /// Retrieves a list of available resources from the server. /// /// The client. /// A token to cancel the operation. - /// An asynchronous sequence of resource information. - public static async IAsyncEnumerable ListResourcesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + /// A list of all available resources. + public static async Task> ListResourcesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) { + Throw.IfNull(client); + + List? resources = null; + string? cursor = null; do { - var resources = await ListResourcesAsync(client, cursor, cancellationToken).ConfigureAwait(false); - foreach (var resource in resources.Resources) + var resourceResults = await client.SendRequestAsync( + CreateRequest("resources/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + if (resources is null) { - yield return resource; + resources = resourceResults.Resources; + } + else + { + resources.AddRange(resourceResults.Resources); } - cursor = resources.NextCursor; + cursor = resourceResults.NextCursor; } while (cursor is not null); + + return resources; } /// - /// Retrieves a list of available resources from the server. + /// Creates an enumerable for asynchronously enumerating all available resources from the server. /// /// The client. - /// A cursor to paginate the results. /// A token to cancel the operation. - public static Task ListResourcesAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) + /// An asynchronous sequence of all available resources. + /// + /// Every iteration through the returned + /// will result in requerying the server and yielding the sequence of available resources. + /// + public static async IAsyncEnumerable EnumerateResourcesAsync( + this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); - return client.SendRequestAsync( - CreateRequest("resources/list", CreateCursorDictionary(cursor)), - cancellationToken); + string? cursor = null; + do + { + var resourceResults = await client.SendRequestAsync( + CreateRequest("resources/list", CreateCursorDictionary(cursor)), + cancellationToken).ConfigureAwait(false); + + foreach (var resource in resourceResults.Resources) + { + yield return resource; + } + + cursor = resourceResults.NextCursor; + } + while (cursor is not null); } /// @@ -221,9 +338,11 @@ public static Task ListResourcesAsync(this IMcpClient clien /// The client. /// The uri of the resource. /// A token to cancel the operation. - public static Task ReadResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) + public static Task ReadResourceAsync( + this IMcpClient client, string uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); + Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( CreateRequest("resources/read", new() { ["uri"] = uri }), @@ -267,6 +386,7 @@ public static Task GetCompletionAsync(this IMcpClient client, Re public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); + Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( CreateRequest("resources/subscribe", new() { ["uri"] = uri }), @@ -282,6 +402,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { Throw.IfNull(client); + Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( CreateRequest("resources/unsubscribe", new() { ["uri"] = uri }), @@ -296,48 +417,17 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// Optional arguments for the tool. /// A token to cancel the operation. /// A task containing the tool's response. - public static Task CallToolAsync(this IMcpClient client, string toolName, Dictionary arguments, CancellationToken cancellationToken = default) + public static Task CallToolAsync( + this IMcpClient client, string toolName, IReadOnlyDictionary? arguments = null, CancellationToken cancellationToken = default) { Throw.IfNull(client); + Throw.IfNull(toolName); return client.SendRequestAsync( CreateRequest("tools/call", CreateParametersDictionary(toolName, arguments)), cancellationToken); } - /// Gets instances for all of the tools available through the specified . - /// The client for which instances should be created. - /// A token to cancel the operation. - /// A task containing a list of the available functions. - public static async Task> GetAIFunctionsAsync(this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List functions = []; - await foreach (var tool in client.ListToolsAsync(cancellationToken).ConfigureAwait(false)) - { - functions.Add(AsAIFunction(client, tool)); - } - - return functions; - } - - /// Gets an for invoking via this . - /// The client with which to perform the invocation. - /// The tool to be invoked. - /// An for performing the call. - /// - /// This operation does not validate that is valid for the specified . - /// If the tool is not valid for the client, it will fail when invoked. - /// - public static AIFunction AsAIFunction(this IMcpClient client, Tool tool) - { - Throw.IfNull(client); - Throw.IfNull(tool); - - return new McpAIFunction(client, tool); - } - /// /// Converts the contents of a into a pair of /// and instances to use @@ -428,12 +518,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat { Type = "image", MimeType = dc.MediaType, - Data = Convert.ToBase64String(dc.Data -#if NET - .Span), -#else - .ToArray()), -#endif + Data = dc.GetBase64Data(), }; } } @@ -499,7 +584,8 @@ private static JsonRpcRequest CreateRequest(string method, Dictionary? CreateCursorDictionary(string? cursor) => cursor != null ? new() { ["cursor"] = cursor } : null; - private static Dictionary CreateParametersDictionary(string nameParameter, Dictionary? arguments) + private static Dictionary CreateParametersDictionary( + string nameParameter, IReadOnlyDictionary? arguments) { Dictionary parameters = new() { @@ -526,20 +612,20 @@ private sealed class McpAIFunction(IMcpClient client, Tool tool) : AIFunction /// public override JsonElement JsonSchema => tool.InputSchema; + /// + public override JsonSerializerOptions JsonSerializerOptions => McpJsonUtilities.DefaultOptions; + /// protected async override Task InvokeCoreAsync( IEnumerable> arguments, CancellationToken cancellationToken) { - Throw.IfNull(arguments); - - Dictionary argDict = []; - foreach (var arg in arguments) - { - if (arg.Value is not null) - { - argDict[arg.Key] = arg.Value; - } - } + IReadOnlyDictionary argDict = + arguments as IReadOnlyDictionary ?? +#if NET + arguments.ToDictionary(); +#else + arguments.ToDictionary(kv => kv.Key, kv => kv.Value); +#endif CallToolResponse result = await client.CallToolAsync(tool.Name, argDict, cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse); diff --git a/src/ModelContextProtocol/Client/McpClientTool.cs b/src/ModelContextProtocol/Client/McpClientTool.cs new file mode 100644 index 00000000..a2bb172a --- /dev/null +++ b/src/ModelContextProtocol/Client/McpClientTool.cs @@ -0,0 +1,43 @@ +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Utils.Json; +using Microsoft.Extensions.AI; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// Provides an AI function that calls a tool through . +public sealed class McpClientTool : AIFunction +{ + private readonly IMcpClient _client; + private readonly Tool _tool; + + internal McpClientTool(IMcpClient client, Tool tool) + { + _client = client; + _tool = tool; + } + + /// + public override string Name => _tool.Name; + + /// + public override string Description => _tool.Description ?? string.Empty; + + /// + public override JsonElement JsonSchema => _tool.InputSchema; + + /// + public override JsonSerializerOptions JsonSerializerOptions => McpJsonUtilities.DefaultOptions; + + /// + protected async override Task InvokeCoreAsync( + IEnumerable> arguments, CancellationToken cancellationToken) + { + IReadOnlyDictionary argDict = + arguments as IReadOnlyDictionary ?? + arguments.ToDictionary(); + + CallToolResponse result = await _client.CallToolAsync(_tool.Name, argDict, cancellationToken).ConfigureAwait(false); + return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs index 544c8df5..4733fce1 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs @@ -1,12 +1,9 @@ using ModelContextProtocol.Configuration; -using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using Microsoft.Extensions.AI; using System.Diagnostics.CodeAnalysis; using System.Reflection; -using System.Text.Json; +using Microsoft.Extensions.DependencyInjection; namespace ModelContextProtocol; @@ -23,23 +20,17 @@ public static partial class McpServerBuilderExtensions /// The tool type. /// The builder instance. /// is . - public static IMcpServerBuilder WithTool<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)] TTool>(this IMcpServerBuilder builder) + public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] TTool>( + this IMcpServerBuilder builder) { Throw.IfNull(builder); - List functions = []; - PopulateFunctions(typeof(TTool), functions); - return WithTools(builder, functions); - } - /// - /// Adds all tools marked with from the current assembly to the server. - /// - /// The builder instance. - /// is . - [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] - public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder) - { - return WithToolsFromAssembly(builder, Assembly.GetCallingAssembly()); + foreach (var toolMethod in GetToolMethods(typeof(TTool))) + { + builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); + } + + return builder; } /// @@ -55,147 +46,42 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params Throw.IfNull(builder); Throw.IfNull(toolTypes); - List functions = []; - foreach (var toolType in toolTypes) { - if (toolType is null) + if (toolType is not null) { - throw new ArgumentNullException(nameof(toolTypes), $"A tool type provided by the enumerator was null."); - } - - PopulateFunctions(toolType, functions); - } - - return WithTools(builder, functions); - } - - /// - /// Adds tools to the server. - /// - /// The builder instance. - /// instances to use as the tools. - /// is . - /// is . - public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params IEnumerable functions) - { - Throw.IfNull(builder); - Throw.IfNull(functions); - - List tools = []; - Dictionary, CancellationToken, Task>> callbacks = []; - - foreach (AIFunction function in functions) - { - if (function is null) - { - throw new ArgumentNullException(nameof(functions), $"A function provided by the enumerator was null."); - } - - tools.Add(new() - { - Name = function.Name, - Description = function.Description, - InputSchema = function.JsonSchema, - }); - - callbacks.Add(function.Name, async (request, cancellationToken) => - { - cancellationToken.ThrowIfCancellationRequested(); - - object? result; - try + foreach (var toolMethod in GetToolMethods(toolType)) { - result = await function.InvokeAsync((request.Params?.Arguments ?? [])!, cancellationToken).ConfigureAwait(false); + builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); } - catch (Exception e) when (e is not OperationCanceledException) - { - return new CallToolResponse() - { - IsError = true, - Content = [new() { Text = e.Message, Type = "text" }], - }; - } - - switch (result) - { - case JsonElement je when je.ValueKind == JsonValueKind.Null: - return new() { Content = [] }; - - case JsonElement je when je.ValueKind == JsonValueKind.Array: - return new() { Content = je.EnumerateArray().Select(x => new Content() { Text = x.ToString(), Type = "text" }).ToList() }; - - default: - return new() { Content = [new() { Text = result?.ToString(), Type = "text" }] }; - } - }); - } - - builder.WithListToolsHandler((_, _) => Task.FromResult(new ListToolsResult() { Tools = tools })); - - builder.WithCallToolHandler(async (request, cancellationToken) => - { - if (request.Params is null || !callbacks.TryGetValue(request.Params.Name, out var callback)) - { - throw new McpServerException($"Unknown tool '{request.Params?.Name}'"); } - - return await callback(request, cancellationToken).ConfigureAwait(false); - }); - + } + return builder; } /// - /// Adds types marked with the attribute from the given assembly as tools to the server. + /// Adds types marked with the attribute from the given assembly as tools to the server. /// /// The builder instance. - /// The assembly to load the types from. Null to get the current assembly + /// The assembly to load the types from. Null to get the current assembly /// is . [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] - public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder builder, Assembly? assembly = null) + public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder builder, Assembly? toolAssembly = null) { - assembly ??= Assembly.GetCallingAssembly(); - - List toolTypes = []; - - foreach (var type in assembly.GetTypes()) - { - if (type.GetCustomAttribute() is null) - { - continue; - } + Throw.IfNull(builder); - foreach (var method in type.GetMethods(BindingFlags.Public | BindingFlags.Static)) - { - if (method.GetCustomAttribute() is not null) - { - toolTypes.Add(type); - break; - } - } - } + toolAssembly ??= Assembly.GetCallingAssembly(); - return toolTypes.Count > 0 ? - WithTools(builder, toolTypes) : - builder; + return builder.WithTools( + from t in toolAssembly.GetTypes() + where t.GetCustomAttribute() is not null + select t); } - private static void PopulateFunctions( - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)] Type toolType, - List functions) - { - foreach (var method in toolType.GetMethods(BindingFlags.Public | BindingFlags.Static)) - { - if (method.GetCustomAttribute() is not { } attribute) - { - continue; - } - - functions.Add(AIFunctionFactory.Create(method, target: null, new() - { - Name = attribute.Name ?? method.Name, - })); - } - } + private static IEnumerable GetToolMethods( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] Type toolType) => + from method in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static) + where method.GetCustomAttribute() is not null + select method; } diff --git a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs index 4132140e..687a534b 100644 --- a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs @@ -1,5 +1,4 @@ using System.Reflection; -using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Microsoft.Extensions.Options; using ModelContextProtocol.Utils; @@ -7,30 +6,66 @@ namespace ModelContextProtocol.Configuration; /// -/// Configures the McpServerOptions using provided server handlers. +/// Configures the McpServerOptions using addition services from DI. /// /// The server handlers configuration options. -internal sealed class McpServerOptionsSetup(IOptions serverHandlers) : IConfigureOptions +/// Tools individually registered. +internal sealed class McpServerOptionsSetup( + IOptions serverHandlers, + IEnumerable serverTools) : IConfigureOptions { /// /// Configures the given McpServerOptions instance by setting server information - /// and applying custom server handlers. + /// and applying custom server handlers and tools. /// /// The options instance to be configured. public void Configure(McpServerOptions options) { Throw.IfNull(options); - var assemblyName = Assembly.GetEntryAssembly()?.GetName(); + // Configure the option's server information based on the current process, + // if it otherwise lacks server information. + var assemblyName = (Assembly.GetEntryAssembly() ?? Assembly.GetCallingAssembly()).GetName(); + if (options.ServerInfo is not { } serverInfo || + serverInfo.Name is null || + serverInfo.Version is null) + { + options.ServerInfo = options.ServerInfo is null ? + new() + { + Name = assemblyName.Name ?? "McpServer", + Version = assemblyName.Version?.ToString() ?? "1.0.0", + } : + options.ServerInfo with + { + Name = options.ServerInfo.Name ?? assemblyName.Name ?? "McpServer", + Version = options.ServerInfo.Version ?? assemblyName.Version?.ToString() ?? "1.0.0", + }; + } + + // Collect all of the provided tools into a tools collection. If the options already has + // a collection, add to it, otherwise create a new one. We want to maintain the identity + // of an existing collection in case someone has provided their own derived type, wants + // change notifications, etc. + McpServerToolCollection toolsCollection = options.Capabilities?.Tools?.ToolCollection ?? []; + foreach (var tool in serverTools) + { + toolsCollection.TryAdd(tool); + } - // Set server information based on the entry assembly - options.ServerInfo = new Implementation + if (!toolsCollection.IsEmpty) { - Name = assemblyName?.Name ?? "McpServer", - Version = assemblyName?.Version?.ToString() ?? "1.0.0", - }; + options.Capabilities = options.Capabilities is null ? + new() { Tools = new() { ToolCollection = toolsCollection } } : + options.Capabilities with + { + Tools = options.Capabilities.Tools is null ? + new() { ToolCollection = toolsCollection } : + options.Capabilities.Tools with { ToolCollection = toolsCollection }, + }; + } - // Apply custom server handlers + // Apply custom server handlers. serverHandlers.Value.OverwriteWithSetHandlers(options); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index 9b82d4ea..e1f517e5 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -25,6 +25,7 @@ public sealed class StdioServerTransport : TransportBase, IServerTransport private readonly TextReader _stdInReader; private readonly Stream _stdOutStream; + private SemaphoreSlim _sendLock = new(1, 1); private Task? _readTask; private CancellationTokenSource? _shutdownCts; @@ -137,6 +138,8 @@ public Task StartListeningAsync(CancellationToken cancellationToken = default) /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + if (!IsConnected) { _logger.TransportNotConnected(EndpointName); diff --git a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs index 9811f99b..a33d3aac 100644 --- a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs @@ -16,5 +16,5 @@ public class CallToolRequestParams /// Optional arguments to pass to the tool. /// [System.Text.Json.Serialization.JsonPropertyName("arguments")] - public Dictionary? Arguments { get; init; } + public Dictionary? Arguments { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index 0e4b88f9..a82afee0 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -155,7 +155,7 @@ public record ResourcesCapability public record ToolsCapability { /// - /// Whether this server supports notifications for changes to the tool list. + /// Gets or sets whether this server supports notifications for changes to the tool list. /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } @@ -171,4 +171,16 @@ public record ToolsCapability /// [JsonIgnore] public Func, CancellationToken, Task>? CallToolHandler { get; init; } + + /// Gets or sets a collection of tools served by the server. + /// + /// Tools will specified via augment the and + /// , if provided. ListTools requests will output information about every tool + /// in and then also any tools output by , if it's + /// non-. CallTool requests will first check for the tool + /// being requested, and if the tool is not found in the , any specified + /// will be invoked as a fallback. + /// + [JsonIgnore] + public McpServerToolCollection? ToolCollection { get; init; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/ResourceTemplate.cs b/src/ModelContextProtocol/Protocol/Types/ResourceTemplate.cs index 4bf6e60c..f5731188 100644 --- a/src/ModelContextProtocol/Protocol/Types/ResourceTemplate.cs +++ b/src/ModelContextProtocol/Protocol/Types/ResourceTemplate.cs @@ -1,6 +1,4 @@ -using ModelContextProtocol.Protocol.Types; - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Protocol/Types/Tool.cs b/src/ModelContextProtocol/Protocol/Types/Tool.cs index a36cef2e..3349da15 100644 --- a/src/ModelContextProtocol/Protocol/Types/Tool.cs +++ b/src/ModelContextProtocol/Protocol/Types/Tool.cs @@ -10,6 +10,8 @@ namespace ModelContextProtocol.Protocol.Types; /// public class Tool { + private JsonElement _inputSchema = McpJsonUtilities.DefaultMcpToolSchema; + /// /// The name of the tool. /// @@ -42,6 +44,4 @@ public JsonElement InputSchema _inputSchema = value; } } - - private JsonElement _inputSchema = McpJsonUtilities.DefaultMcpToolSchema; } diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs new file mode 100644 index 00000000..d3fbd93c --- /dev/null +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -0,0 +1,231 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Reflection; +using System.Text.Json; + +namespace ModelContextProtocol.Server; + +/// Provides an that's implemented via an . +internal sealed class AIFunctionMcpServerTool : McpServerTool +{ + /// Key used temporarily for flowing request context into an AIFunction. + /// This will be replaced with use of AIFunctionArguments.Context. + private const string RequestContextKey = "__temporary_RequestContext"; + + /// + /// Creates an instance for a method, specified via a instance. + /// + public static new AIFunctionMcpServerTool Create( + Delegate method, + string? name, + string? description, + IServiceProvider? services) + { + Throw.IfNull(method); + + return Create(method.Method, method.Target, name, description, services); + } + + /// + /// Creates an instance for a method, specified via a instance. + /// + public static new AIFunctionMcpServerTool Create( + MethodInfo method, + object? target, + string? name, + string? description, + IServiceProvider? services) + { + Throw.IfNull(method); + + // TODO: Once this repo consumes a new build of Microsoft.Extensions.AI containing + // https://github.com/dotnet/extensions/pull/6158, + // https://github.com/dotnet/extensions/pull/6162, and + // https://github.com/dotnet/extensions/pull/6175, switch over to using the real + // AIFunctionFactory, delete the TemporaryXx types, and fix-up the mechanism by + // which the arguments are passed. + + return Create(TemporaryAIFunctionFactory.Create(method, target, new TemporaryAIFunctionFactoryOptions() + { + Name = name ?? method.GetCustomAttribute()?.Name, + Description = description, + MarshalResult = static (result, _, cancellationToken) => Task.FromResult(result), + ConfigureParameterBinding = pi => + { + if (pi.ParameterType == typeof(RequestContext)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args), + }; + } + + if (pi.ParameterType == typeof(IMcpServer)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args)?.Server, + }; + } + + // We assume that if the services used to create the tool support a particular type, + // so too do the services associated with the server. This is the same basic assumption + // made in ASP.NET. + if (services is not null && + services.GetService() is { } ispis && + ispis.IsService(pi.ParameterType)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + GetRequestContext(args)?.Server?.Services?.GetService(pi.ParameterType) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + if (pi.GetCustomAttribute() is { } keyedAttr) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + (GetRequestContext(args)?.Server?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + return default; + + static RequestContext? GetRequestContext(IReadOnlyDictionary args) + { + if (args.TryGetValue(RequestContextKey, out var orc) && + orc is RequestContext requestContext) + { + return requestContext; + } + + return null; + } + }, + })); + } + + /// Creates an that wraps the specified . + public static new AIFunctionMcpServerTool Create(AIFunction function) + { + Throw.IfNull(function); + + return new AIFunctionMcpServerTool(function); + } + + /// Gets the wrapped by this tool. + internal AIFunction AIFunction { get; } + + /// Initializes a new instance of the class. + private AIFunctionMcpServerTool(AIFunction function) + { + AIFunction = function; + ProtocolTool = new() + { + Name = function.Name, + Description = function.Description, + InputSchema = function.JsonSchema, + }; + } + + /// + public override string ToString() => AIFunction.ToString(); + + /// + public override Tool ProtocolTool { get; } + + /// + public override async Task InvokeAsync( + RequestContext request, CancellationToken cancellationToken = default) + { + Throw.IfNull(request); + + cancellationToken.ThrowIfCancellationRequested(); + + // TODO: Once we shift to the real AIFunctionFactory, the request should be passed via AIFunctionArguments.Context. + Dictionary arguments = request.Params?.Arguments is IDictionary existingArgs ? + new(existingArgs) : + []; + arguments[RequestContextKey] = request; + + object? result; + try + { + result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); + } + catch (Exception e) when (e is not OperationCanceledException) + { + return new CallToolResponse() + { + IsError = true, + Content = [new() { Text = e.Message, Type = "text" }], + }; + } + + switch (result) + { + case null: + return new() + { + Content = [] + }; + + case string text: + return new() + { + Content = [new() { Text = text, Type = "text" }] + }; + + case TextContent textContent: + return new() + { + Content = [new() { Text = textContent.Text, Type = "text" }] + }; + + case DataContent dataContent: + return new() + { + Content = [new() + { + Data = dataContent.GetBase64Data(), + MimeType = dataContent.MediaType, + Type = dataContent.HasTopLevelMediaType("image") ? "image" : "resource", + }] + }; + + case string[] texts: + return new() + { + Content = texts + .Select(x => new Content() { Type = "text", Text = x ?? string.Empty }) + .ToList() + }; + + // TODO https://github.com/modelcontextprotocol/csharp-sdk/issues/69: + // Add specialization for annotations. + + default: + return new() + { + Content = [new() + { + Text = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))), + Type = "text" + }] + }; + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/DelegatingMcpServerTool.cs b/src/ModelContextProtocol/Server/DelegatingMcpServerTool.cs new file mode 100644 index 00000000..d4555d71 --- /dev/null +++ b/src/ModelContextProtocol/Server/DelegatingMcpServerTool.cs @@ -0,0 +1,34 @@ +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol.Server; + +/// Provides an that delegates all operations to an inner . +/// +/// This is recommended as a base type when building tools that can be chained around an underlying . +/// The default implementation simply passes each call to the inner tool instance. +/// +public abstract class DelegatingMcpServerTool : McpServerTool +{ + private readonly McpServerTool _innerTool; + + /// Initializes a new instance of the class around the specified . + /// The inner tool wrapped by this delegating tool. + protected DelegatingMcpServerTool(McpServerTool innerTool) + { + Throw.IfNull(innerTool); + _innerTool = innerTool; + } + + /// + public override Tool ProtocolTool => _innerTool.ProtocolTool; + + /// + public override Task InvokeAsync( + RequestContext request, + CancellationToken cancellationToken = default) => + _innerTool.InvokeAsync(request, cancellationToken); + + /// + public override string ToString() => _innerTool.ToString(); +} diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index 7168a41c..08d4d990 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -23,10 +23,13 @@ public interface IMcpServer : IAsyncDisposable /// Implementation? ClientInfo { get; } + /// Gets the options used to construct this server. + McpServerOptions ServerOptions { get; } + /// /// Gets the service provider for the server. /// - IServiceProvider? ServiceProvider { get; } + IServiceProvider? Services { get; } /// /// Adds a handler for client notifications of a specific method. diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index c2fdb074..a7779947 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -1,12 +1,11 @@ -using ModelContextProtocol.Logging; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Shared; using ModelContextProtocol.Utils; - -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; - using System.Text.Json.Nodes; namespace ModelContextProtocol.Server; @@ -15,9 +14,9 @@ namespace ModelContextProtocol.Server; internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer { private readonly IServerTransport? _serverTransport; - private readonly McpServerOptions _options; - private volatile bool _isInitializing; private readonly ILogger _logger; + private readonly string _serverDescription; + private volatile bool _isInitializing; /// /// Creates a new instance of . @@ -34,10 +33,10 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? Throw.IfNull(options); _serverTransport = transport as IServerTransport; - _options = options; _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; ServerInstructions = options.ServerInstructions; - ServiceProvider = serviceProvider; + Services = serviceProvider; + _serverDescription = $"{options.ServerInfo.Name} {options.ServerInfo.Version}"; AddNotificationHandler("notifications/initialized", _ => { @@ -45,13 +44,16 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? return Task.CompletedTask; }); + SetToolsHandler(ref options); + SetInitializeHandler(options); SetCompletionHandler(options); SetPingHandler(); - SetToolsHandler(options); SetPromptsHandler(options); SetResourcesHandler(options); SetSetLoggingLevelHandler(options); + + ServerOptions = options; } public ClientCapabilities? ClientCapabilities { get; set; } @@ -63,11 +65,14 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? public string? ServerInstructions { get; set; } /// - public IServiceProvider? ServiceProvider { get; } + public McpServerOptions ServerOptions { get; } + + /// + public IServiceProvider? Services { get; } /// public override string EndpointName => - $"Server ({_options.ServerInfo.Name} {_options.ServerInfo.Version}), Client ({ClientInfo?.Name} {ClientInfo?.Version})"; + $"Server ({_serverDescription}), Client ({ClientInfo?.Name} {ClientInfo?.Version})"; /// public async Task StartAsync(CancellationToken cancellationToken = default) @@ -125,7 +130,7 @@ private void SetInitializeHandler(McpServerOptions options) { ProtocolVersion = options.ProtocolVersion, Instructions = ServerInstructions, - ServerInfo = _options.ServerInfo, + ServerInfo = options.ServerInfo, Capabilities = options.Capabilities ?? new ServerCapabilities(), }); }); @@ -195,17 +200,110 @@ private void SetPromptsHandler(McpServerOptions options) SetRequestHandler("prompts/get", (request, ct) => getPromptHandler(new(this, request), ct)); } - private void SetToolsHandler(McpServerOptions options) + private void SetToolsHandler(ref McpServerOptions options) { - if (options.Capabilities?.Tools is not { } toolsCapability) + ToolsCapability? toolsCapability = options.Capabilities?.Tools; + var listToolsHandler = toolsCapability?.ListToolsHandler; + var callToolHandler = toolsCapability?.CallToolHandler; + var tools = toolsCapability?.ToolCollection; + + if (listToolsHandler is null != callToolHandler is null) { - return; + throw new McpServerException("ListTools and CallTool handlers should be specified together."); } - if (toolsCapability.ListToolsHandler is not { } listToolsHandler || - toolsCapability.CallToolHandler is not { } callToolHandler) + // Handle tools provided via DI. + if (tools is { IsEmpty: false }) + { + var originalListToolsHandler = listToolsHandler; + var originalCallToolHandler = callToolHandler; + + // Synthesize the handlers, making sure a ToolsCapability is specified. + listToolsHandler = async (request, cancellationToken) => + { + ListToolsResult result = new(); + foreach (McpServerTool tool in tools) + { + result.Tools.Add(tool.ProtocolTool); + } + + if (originalListToolsHandler is not null) + { + string? nextCursor = null; + do + { + ListToolsResult extraResults = await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false); + result.Tools.AddRange(extraResults.Tools); + + nextCursor = extraResults.NextCursor; + if (nextCursor is not null) + { + request = request with { Params = new() { Cursor = nextCursor } }; + } + } + while (nextCursor is not null); + } + + return result; + }; + + callToolHandler = (request, cancellationToken) => + { + if (request.Params is null || + !tools.TryGetTool(request.Params.Name, out var tool)) + { + if (originalCallToolHandler is not null) + { + return originalCallToolHandler(request, cancellationToken); + } + + throw new McpServerException($"Unknown tool '{request.Params?.Name}'"); + } + + return tool.InvokeAsync(request, cancellationToken); + }; + + toolsCapability = toolsCapability is null ? + new() + { + CallToolHandler = callToolHandler, + ListToolsHandler = listToolsHandler, + ToolCollection = tools, + ListChanged = true, + } : + toolsCapability with + { + CallToolHandler = callToolHandler, + ListToolsHandler = listToolsHandler, + ToolCollection = tools, + ListChanged = true, + }; + + options.Capabilities = options.Capabilities is null ? + new() { Tools = toolsCapability } : + options.Capabilities with { Tools = toolsCapability }; + + tools.Changed += delegate + { + _ = SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ToolListChangedNotification, + }); + }; + } + else { - throw new McpServerException("ListTools and/or CallTool handlers were specified but the Tools capability was not enabled."); + if (toolsCapability is null) + { + // No tools, and no tools capability was declared, so nothing to do. + return; + } + + // Make sure the handlers are provided if the capability is enabled. + if (listToolsHandler is null || callToolHandler is null) + { + throw new McpServerException("ListTools and/or CallTool handlers were not specified but the Tools capability was enabled."); + } } SetRequestHandler("tools/list", (request, ct) => listToolsHandler(new(this, request), ct)); @@ -226,4 +324,4 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) SetRequestHandler("logging/setLevel", (request, ct) => setLoggingLevelHandler(new(this, request), ct)); } -} +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 3f746004..73bd528d 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -96,12 +96,7 @@ public static async Task RequestSamplingAsync( { Type = "image", MimeType = dataContent.MediaType, - Data = Convert.ToBase64String(dataContent.Data. -#if NET - Span), -#else - ToArray()), -#endif + Data = dataContent.GetBase64Data(), }, }); break; @@ -126,33 +121,7 @@ public static async Task RequestSamplingAsync( ModelPreferences = modelPreferences, }, cancellationToken).ConfigureAwait(false); - ChatMessage responseMessage = new() - { - Role = result.Role == "user" ? ChatRole.User : ChatRole.Assistant - }; - - if (result.Content is { Type: "text" }) - { - responseMessage.Contents.Add(new TextContent(result.Content.Text)); - } - else if (result.Content is { Type: "image", MimeType: not null, Data: not null }) - { - responseMessage.Contents.Add(new DataContent(Convert.FromBase64String(result.Content.Data), result.Content.MimeType)); - } - else if (result.Content is { Type: "resource" } && result.Content.Resource is { } resourceContents) - { - if (resourceContents.Text is not null) - { - responseMessage.Contents.Add(new TextContent(resourceContents.Text)); - } - - if (resourceContents.Blob is not null && resourceContents.MimeType is not null) - { - responseMessage.Contents.Add(new DataContent(Convert.FromBase64String(resourceContents.Blob), resourceContents.MimeType)); - } - } - - return new(responseMessage) + return new(new ChatMessage(new(result.Role), [result.Content.ToAIContent()])) { ModelId = result.Model, FinishReason = result.StopReason switch diff --git a/src/ModelContextProtocol/Server/McpServerTool.cs b/src/ModelContextProtocol/Server/McpServerTool.cs new file mode 100644 index 00000000..f6122764 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerTool.cs @@ -0,0 +1,97 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol.Types; +using System.ComponentModel; +using System.Reflection; + +namespace ModelContextProtocol.Server; + +/// Represents an invocable tool used by Model Context Protocol clients and servers. +public abstract class McpServerTool +{ + /// Initializes a new instance of the class. + protected McpServerTool() + { + } + + /// Gets the protocol type for this instance. + public abstract Tool ProtocolTool { get; } + + /// Invokes the . + /// The request information resulting in the invocation of this tool. + /// The to monitor for cancellation requests. The default is . + /// The call response from invoking the tool. + /// is . + public abstract Task InvokeAsync( + RequestContext request, + CancellationToken cancellationToken = default); + + /// + /// Creates an instance for a method, specified via a instance. + /// + /// The method to be represented via the created . + /// + /// The name to use for the . If , but an + /// is applied to , the name from the attribute will be used. If that's not present, the name based + /// on 's name will be used. + /// + /// + /// The description to use for the . If , but a + /// is applied to , the description from that attribute will be used. + /// + /// + /// Optional services used in the construction of the . These services will be + /// used to determine which parameters should be satisifed from dependency injection, and so what services + /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. + /// + /// The created for invoking . + /// is . + public static McpServerTool Create( + Delegate method, + string? name = null, + string? description = null, + IServiceProvider? services = null) => + AIFunctionMcpServerTool.Create(method, name, description, services); + + /// + /// Creates an instance for a method, specified via a instance. + /// + /// The method to be represented via the created . + /// The instance if is an instance method; otherwise, . + /// + /// The name to use for the . If , but an + /// is applied to , the name from the attribute will be used. If that's not present, the name based + /// on 's name will be used. + /// + /// + /// The description to use for the . If , but a + /// is applied to , the description from that attribute will be used. + /// + /// + /// Optional services used in the construction of the . These services will be + /// used to determine which parameters should be satisifed from dependency injection, and so what services + /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. + /// + /// The created for invoking . + /// is . + /// is an instance method but is . + public static McpServerTool Create( + MethodInfo method, + object? target = null, + string? name = null, + string? description = null, + IServiceProvider? services = null) => + AIFunctionMcpServerTool.Create(method, target, name, description, services); + + /// Creates an that wraps the specified . + /// The function to wrap. + /// is . + /// + /// Unlike the other overloads of Create, the created by + /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// + public static McpServerTool Create(AIFunction function) => + AIFunctionMcpServerTool.Create(function); + + /// + public override string ToString() => ProtocolTool.Name; +} diff --git a/src/ModelContextProtocol/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol/Server/McpServerToolAttribute.cs new file mode 100644 index 00000000..d1489f6d --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerToolAttribute.cs @@ -0,0 +1,18 @@ +namespace ModelContextProtocol.Server; + +/// +/// Used to mark a public method as an MCP tool. +/// +[AttributeUsage(AttributeTargets.Method)] +public sealed class McpServerToolAttribute : Attribute +{ + /// Gets the name of the tool. + /// If , the method name will be used. + public string? Name { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The name of the tool. If , the method name will be used. + public McpServerToolAttribute(string? name = null) => Name = name; +} diff --git a/src/ModelContextProtocol/Server/McpServerToolCollection.cs b/src/ModelContextProtocol/Server/McpServerToolCollection.cs new file mode 100644 index 00000000..f5234aa9 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerToolCollection.cs @@ -0,0 +1,164 @@ +using ModelContextProtocol.Utils; +using System.Collections; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Server; + +/// Provides a thread-safe collection of instances, indexed by their names. +public class McpServerToolCollection : ICollection, IReadOnlyCollection +{ + /// Concurrent dictionary of tools, indexed by their names. + private readonly ConcurrentDictionary _tools = []; + + /// + /// Initializes a new instance of the class. + /// + public McpServerToolCollection() + { + } + + /// Occurs when the collection is changed. + /// + /// By default, this is raised when a tool is added or removed. However, a derived implementation + /// may raise this event for other reasons, such as when a tool is modified. + /// + public event EventHandler? Changed; + + /// Gets the number of tools in the collection. + public int Count => _tools.Count; + + /// Gets whether there are any tools in the collection. + public bool IsEmpty => _tools.IsEmpty; + + /// Raises if there are registered handlers. + protected void RaiseChanged() => Changed?.Invoke(this, EventArgs.Empty); + + /// Gets the with the specified from the collection. + /// The name of the tool to retrieve. + /// The with the specified name. + /// is . + /// A tool with the specified name does not exist in the collection. + public McpServerTool this[string name] + { + get + { + Throw.IfNull(name); + return _tools[name]; + } + } + + /// Clears all tools from the collection. + public virtual void Clear() + { + _tools.Clear(); + RaiseChanged(); + } + + /// Adds the specified to the collection. + /// The tool to be added. + /// is . + /// A tool with the same name as already exists in the collection. + public void Add(McpServerTool tool) + { + if (!TryAdd(tool)) + { + throw new ArgumentException($"A tool with the same name '{tool.ProtocolTool.Name}' already exists in the collection.", nameof(tool)); + } + } + + /// Adds the specified to the collection. + /// The tool to be added. + /// if the tool was added; otherwise, . + /// is . + public virtual bool TryAdd(McpServerTool tool) + { + Throw.IfNull(tool); + + bool added = _tools.TryAdd(tool.ProtocolTool.Name, tool); + if (added) + { + RaiseChanged(); + } + + return added; + } + + /// Removes the specified toolfrom the collection. + /// The tool to be removed from the collection. + /// + /// if the tool was found in the collection and removed; otherwise, if it couldn't be found. + /// + /// is . + public virtual bool Remove(McpServerTool tool) + { + Throw.IfNull(tool); + + bool removed = ((ICollection>)_tools).Remove(new(tool.ProtocolTool.Name, tool)); + if (removed) + { + RaiseChanged(); + } + + return removed; + } + + /// Attempts to get the tool with the specified name from the collection. + /// The name of the tool to retrieve. + /// The tool, if found; otherwise, . + /// + /// if the tool was found in the collection and return; otherwise, if it couldn't be found. + /// + /// is . + public virtual bool TryGetTool(string name, [NotNullWhen(true)] out McpServerTool? tool) + { + Throw.IfNull(name); + return _tools.TryGetValue(name, out tool); + } + + /// Checks if a specific tool is present in the collection of tools. + /// The tool to search for in the collection. + /// if the tool was found in the collection and return; otherwise, if it couldn't be found. + /// is . + public virtual bool Contains(McpServerTool tool) + { + Throw.IfNull(tool); + return ((ICollection>)_tools).Contains(new(tool.ProtocolTool.Name, tool)); + } + + /// Gets the names of all of the tools in the collection. + public virtual ICollection ToolNames => _tools.Keys; + + /// Creates an array containing all of the tools in the collection. + /// An array containing all of the tools in the collection. + public virtual McpServerTool[] ToArray() => _tools.Values.ToArray(); + + /// + public virtual void CopyTo(McpServerTool[] array, int arrayIndex) + { + Throw.IfNull(array); + + foreach (var entry in _tools) + { + array[arrayIndex++] = entry.Value; + } + } + + /// + public virtual IEnumerator GetEnumerator() + { + foreach (var entry in _tools) + { + yield return entry.Value; + } + } + + /// + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + bool ICollection.IsReadOnly => false; +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerToolTypeAttribute.cs b/src/ModelContextProtocol/Server/McpServerToolTypeAttribute.cs new file mode 100644 index 00000000..cb160633 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerToolTypeAttribute.cs @@ -0,0 +1,7 @@ +namespace ModelContextProtocol.Server; + +/// +/// Used to attribute a type containing methods that should be exposed as MCP tools. +/// +[AttributeUsage(AttributeTargets.Class)] +public sealed class McpServerToolTypeAttribute : Attribute; diff --git a/src/ModelContextProtocol/Server/McpToolAttribute.cs b/src/ModelContextProtocol/Server/McpToolAttribute.cs deleted file mode 100644 index 26fb140c..00000000 --- a/src/ModelContextProtocol/Server/McpToolAttribute.cs +++ /dev/null @@ -1,20 +0,0 @@ -namespace ModelContextProtocol.Server; - -/// -/// Attribute to mark a method as an MCP tool. -/// -[AttributeUsage(AttributeTargets.Method)] -public sealed class McpToolAttribute : Attribute -{ - /// Gets the name of the tool. - /// If not provided, the method name will be used. - public string? Name { get; } - /// - /// Attribute to mark a method as an MCP tool. - /// - /// The name of the tool. If not provided, the method name will be used. - public McpToolAttribute(string? name = null) - { - Name = name; - } -} diff --git a/src/ModelContextProtocol/Server/McpToolTypeAttribute.cs b/src/ModelContextProtocol/Server/McpToolTypeAttribute.cs deleted file mode 100644 index 798dfd54..00000000 --- a/src/ModelContextProtocol/Server/McpToolTypeAttribute.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace ModelContextProtocol.Server; - -/// -/// Attribute to mark a type as container for MCP tools. -/// -[AttributeUsage(AttributeTargets.Class)] -public sealed class McpToolTypeAttribute : Attribute; diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.Utilities.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.Utilities.cs new file mode 100644 index 00000000..d0596f82 --- /dev/null +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.Utilities.cs @@ -0,0 +1,135 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using ModelContextProtocol.Utils; +using System.Buffers; +using System.Reflection; +using System.Text.RegularExpressions; + +namespace Microsoft.Extensions.AI; + +internal static partial class TemporaryAIFunctionFactory +{ + /// + /// Removes characters from a .NET member name that shouldn't be used in an AI function name. + /// + /// The .NET member name that should be sanitized. + /// + /// Replaces non-alphanumeric characters in the identifier with the underscore character. + /// Primarily intended to remove characters produced by compiler-generated method name mangling. + /// + internal static string SanitizeMemberName(string memberName) + { + Throw.IfNull(memberName); + return InvalidNameCharsRegex().Replace(memberName, "_"); + } + + /// Regex that flags any character other than ASCII digits or letters or the underscore. +#if NET + [GeneratedRegex("[^0-9A-Za-z_]")] + private static partial Regex InvalidNameCharsRegex(); +#else + private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; + private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); +#endif + + /// Invokes the MethodInfo with the specified target object and arguments. + private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) + { +#if NET + return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); +#else + try + { + return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); + } + catch (TargetInvocationException e) when (e.InnerException is not null) + { + // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions + // is ignored, the original exception will be wrapped in a TargetInvocationException. + // Unwrap it and throw that original exception, maintaining its stack information. + System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + throw; + } +#endif + } + + /// + /// Implements a simple write-only memory stream that uses pooled buffers. + /// + private sealed class PooledMemoryStream : Stream + { + private const int DefaultBufferSize = 4096; + private byte[] _buffer; + private int _position; + + public PooledMemoryStream(int initialCapacity = DefaultBufferSize) + { + _buffer = ArrayPool.Shared.Rent(initialCapacity); + _position = 0; + } + + public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position); + public override bool CanWrite => true; + public override bool CanRead => false; + public override bool CanSeek => false; + public override long Length => _position; + public override long Position + { + get => _position; + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + EnsureNotDisposed(); + EnsureCapacity(_position + count); + + Buffer.BlockCopy(buffer, offset, _buffer, _position, count); + _position += count; + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null!; + } + + base.Dispose(disposing); + } + + private void EnsureCapacity(int requiredCapacity) + { + if (requiredCapacity <= _buffer.Length) + { + return; + } + + int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2); + byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); + Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position); + + ArrayPool.Shared.Return(_buffer); + _buffer = newBuffer; + } + + private void EnsureNotDisposed() + { + if (_buffer is null) + { + Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream)); + } + } + } +} diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs new file mode 100644 index 00000000..bf0ae8ae --- /dev/null +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs @@ -0,0 +1,667 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using ModelContextProtocol.Utils; +using System.Collections.Concurrent; +using System.ComponentModel; +using System.Diagnostics; +#if !NET +using System.Linq; +#endif +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; +using static ModelContextProtocol.Utils.Json.McpJsonUtilities; + +#pragma warning disable CA1031 // Do not catch general exception types +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields +#pragma warning disable SA1118 // Parameter should not span multiple lines +#pragma warning disable SA1500 // Braces for multi-line statements should not share line + +namespace Microsoft.Extensions.AI; + +/// Provides factory methods for creating commonly used implementations of . +internal static partial class TemporaryAIFunctionFactory +{ + /// Holds the default options instance used when creating function. + private static readonly TemporaryAIFunctionFactoryOptions _defaultOptions = new(); + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// Metadata to use to override defaults inferred from . + /// The created for invoking . + /// + /// + /// Return values are serialized to using 's + /// . Arguments that are not already of the expected type are + /// marshaled to the expected type via JSON and using 's + /// . If the argument is a , + /// , or , it is deserialized directly. If the argument is anything else unknown, + /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. + /// + /// + /// is . + public static AIFunction Create(Delegate method, TemporaryAIFunctionFactoryOptions? options) + { + Throw.IfNull(method); + + return ReflectionAIFunction.Build(method.Method, method.Target, options ?? _defaultOptions); + } + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// The name to use for the . + /// The description to use for the . + /// The used to marshal function parameters and any return value. + /// The created for invoking . + /// + /// + /// Return values are serialized to using . + /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using + /// . If the argument is a , , + /// or , it is deserialized directly. If the argument is anything else unknown, it is + /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. + /// + /// + /// is . + public static AIFunction Create(Delegate method, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(method); + + TemporaryAIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null + ? _defaultOptions + : new() + { + Name = name, + Description = description, + SerializerOptions = serializerOptions, + }; + + return ReflectionAIFunction.Build(method.Method, method.Target, createOptions); + } + + /// + /// Creates an instance for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// Metadata to use to override defaults inferred from . + /// The created for invoking . + /// + /// + /// Return values are serialized to using 's + /// . Arguments that are not already of the expected type are + /// marshaled to the expected type via JSON and using 's + /// . If the argument is a , + /// , or , it is deserialized directly. If the argument is anything else unknown, + /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. + /// + /// + /// is . + public static AIFunction Create(MethodInfo method, object? target, TemporaryAIFunctionFactoryOptions? options) + { + Throw.IfNull(method); + return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions); + } + + /// + /// Creates an instance for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// The name to use for the . + /// The description to use for the . + /// The used to marshal function parameters and return value. + /// The created for invoking . + /// + /// + /// Return values are serialized to using . + /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using + /// . If the argument is a , , + /// or , it is deserialized directly. If the argument is anything else unknown, it is + /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. + /// + /// + /// is . + public static AIFunction Create(MethodInfo method, object? target, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(method); + + TemporaryAIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null + ? _defaultOptions + : new() + { + Name = name, + Description = description, + SerializerOptions = serializerOptions, + }; + + return ReflectionAIFunction.Build(method, target, createOptions); + } + + private sealed class ReflectionAIFunction : AIFunction + { + public static ReflectionAIFunction Build(MethodInfo method, object? target, TemporaryAIFunctionFactoryOptions options) + { + Throw.IfNull(method); + + if (method.ContainsGenericParameters) + { + throw new ArgumentException("Open generic methods are not supported", nameof(method)); + } + + if (!method.IsStatic && target is null) + { + throw new ArgumentNullException("Target must not be null for an instance method.", nameof(target)); + } + + var functionDescriptor = ReflectionAIFunctionDescriptor.GetOrCreate(method, options); + + if (target is null && options.AdditionalProperties is null) + { + // We can use a cached value for static methods not specifying additional properties. + return functionDescriptor.CachedDefaultInstance ??= new(functionDescriptor, target, options); + } + + return new(functionDescriptor, target, options); + } + + private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, TemporaryAIFunctionFactoryOptions options) + { + FunctionDescriptor = functionDescriptor; + Target = target; + AdditionalProperties = options.AdditionalProperties ?? new Dictionary(); + } + + public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } + public object? Target { get; } + public override IReadOnlyDictionary AdditionalProperties { get; } + public override string Name => FunctionDescriptor.Name; + public override string Description => FunctionDescriptor.Description; + public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; + public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; + public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; + + protected override Task InvokeCoreAsync( + IEnumerable> arguments, + CancellationToken cancellationToken) + { + var paramMarshallers = FunctionDescriptor.ParameterMarshallers; + object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; + + Dictionary argumentsDictionary = arguments.ToDictionary(); + + for (int i = 0; i < args.Length; i++) + { + args[i] = paramMarshallers[i](argumentsDictionary, cancellationToken); + } + + return FunctionDescriptor.ReturnParameterMarshaller( + ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); + } + } + + /// + /// A descriptor for a .NET method-backed AIFunction that precomputes its marshalling delegates and JSON schema. + /// + private sealed class ReflectionAIFunctionDescriptor + { + private const int InnerCacheSoftLimit = 512; + private static readonly ConditionalWeakTable> _descriptorCache = new(); + + /// A boxed . + private static readonly object? _boxedDefaultCancellationToken = default(CancellationToken); + + /// + /// Gets or creates a descriptors using the specified method and options. + /// + public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, TemporaryAIFunctionFactoryOptions options) + { + JsonSerializerOptions serializerOptions = options.SerializerOptions ?? AIJsonUtilities.DefaultOptions; + AIJsonSchemaCreateOptions schemaOptions = options.JsonSchemaCreateOptions ?? AIJsonSchemaCreateOptions.Default; + serializerOptions.MakeReadOnly(); + ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions); + + DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, schemaOptions); + if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor)) + { + return descriptor; + } + + descriptor = new(key, serializerOptions); + return innerCache.Count < InnerCacheSoftLimit + ? innerCache.GetOrAdd(key, descriptor) + : descriptor; + } + + private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) + { + ParameterInfo[] parameters = key.Method.GetParameters(); + + // Determine how each parameter should be bound. + Dictionary? boundParameters = null; + if (parameters.Length != 0 && key.GetBindParameterOptions is not null) + { + boundParameters = new(parameters.Length); + for (int i = 0; i < parameters.Length; i++) + { + boundParameters[parameters[i]] = key.GetBindParameterOptions(parameters[i]); + } + } + + // Get marshaling delegates for parameters. + ParameterMarshallers = parameters.Length > 0 ? new Func, CancellationToken, object?>[parameters.Length] : []; + for (int i = 0; i < parameters.Length; i++) + { + if (boundParameters?.TryGetValue(parameters[i], out TemporaryAIFunctionFactoryOptions.ParameterBindingOptions options) is not true) + { + options = default; + } + + ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]); + } + + // Get a marshaling delegate for the return value. + ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions); + + Method = key.Method; + Name = key.Name ?? GetFunctionName(key.Method); + Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; + JsonSerializerOptions = serializerOptions; + JsonSchema = CreateFunctionJsonSchema( + key.Method, + Name, + Description, + serializerOptions, + parameterInfo => + { + // AIFunctionArguments and IServiceProvider parameters are always excluded from the schema. + if (parameterInfo.ParameterType == typeof(IServiceProvider)) + { + return false; + } + + // If the parameter is marked as excluded by GetBindParameterOptions, exclude it. + if (boundParameters?.TryGetValue(parameterInfo, out var options) is true && + options.ExcludeFromSchema) + { + return false; + } + + // Everything else is included. + return true; + }); + } + + /// + /// Determines a JSON schema for the provided method. + /// + /// The method from which to extract schema information. + /// The title keyword used by the method schema. + /// The description keyword used by the method schema. + /// The options used to extract the schema from the specified type. + /// Delegate controlling whether to include a parameter in the schema. + /// A JSON schema document encoded as a . + /// is . + private static JsonElement CreateFunctionJsonSchema( + MethodBase method, + string? title = null, + string? description = null, + JsonSerializerOptions? serializerOptions = null, + Func? includeParameter = null) + { + Throw.IfNull(method); + + serializerOptions ??= DefaultOptions; + title ??= method.Name; + description ??= method.GetCustomAttribute()?.Description; + + JsonObject parameterSchemas = new(); + JsonArray? requiredProperties = null; + foreach (ParameterInfo parameter in method.GetParameters()) + { + if (string.IsNullOrWhiteSpace(parameter.Name)) + { + throw new ArgumentException("Parameter is missing a name.", nameof(parameter)); + } + + if (parameter.ParameterType == typeof(CancellationToken)) + { + // CancellationToken is a special case that, by convention, we don't want to include in the schema. + // Invocations of methods that include a CancellationToken argument should also special-case CancellationToken + // to pass along what relevant token into the method's invocation. + continue; + } + + if (includeParameter?.Invoke(parameter) is false) + { + continue; + } + + JsonNode? parameterSchema = JsonSerializer.SerializeToNode(AIJsonUtilities.CreateJsonSchema( + type: parameter.ParameterType, + description: parameter.GetCustomAttribute(inherit: true)?.Description, + hasDefaultValue: parameter.HasDefaultValue, + defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null, + serializerOptions), AIJsonUtilities.DefaultOptions.GetTypeInfo()); + + parameterSchemas.Add(parameter.Name, parameterSchema); + if (!parameter.IsOptional) + { + (requiredProperties ??= []).Add((JsonNode)parameter.Name); + } + } + + JsonObject schema = new(); + + if (!string.IsNullOrWhiteSpace(title)) + { + schema["title"] = title; + } + + if (!string.IsNullOrWhiteSpace(description)) + { + schema["description"] = description; + } + + schema["type"] = "object"; // Method schemas always hardcode the type as "object". + schema["properties"] = parameterSchemas; + + if (requiredProperties is not null) + { + schema["required"] = requiredProperties; + } + + return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); + } + + public string Name { get; } + public string Description { get; } + public MethodInfo Method { get; } + public JsonSerializerOptions JsonSerializerOptions { get; } + public JsonElement JsonSchema { get; } + public Func, CancellationToken, object?>[] ParameterMarshallers { get; } + public Func> ReturnParameterMarshaller { get; } + public ReflectionAIFunction? CachedDefaultInstance { get; set; } + + private static string GetFunctionName(MethodInfo method) + { + // Get the function name to use. + string name = SanitizeMemberName(method.Name); + + const string AsyncSuffix = "Async"; + if (IsAsyncMethod(method) && + name.EndsWith(AsyncSuffix, StringComparison.Ordinal) && + name.Length > AsyncSuffix.Length) + { + name = name.Substring(0, name.Length - AsyncSuffix.Length); + } + + return name; + + static bool IsAsyncMethod(MethodInfo method) + { + Type t = method.ReturnType; + + if (t == typeof(Task) || t == typeof(ValueTask)) + { + return true; + } + + if (t.IsGenericType) + { + t = t.GetGenericTypeDefinition(); + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + { + return true; + } + } + + return false; + } + } + + /// + /// Gets a delegate for handling the marshaling of a parameter. + /// + private static Func, CancellationToken, object?> GetParameterMarshaller( + JsonSerializerOptions serializerOptions, + TemporaryAIFunctionFactoryOptions.ParameterBindingOptions bindingOptions, + ParameterInfo parameter) + { + if (string.IsNullOrWhiteSpace(parameter.Name)) + { + throw new ArgumentException("Parameter is missing a name.", nameof(parameter)); + } + + // Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found. + Type parameterType = parameter.ParameterType; + JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType); + + // For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync. + if (parameterType == typeof(CancellationToken)) + { + return static (_, cancellationToken) => + cancellationToken == default ? _boxedDefaultCancellationToken : // optimize common case of a default CT to avoid boxing + cancellationToken; + } + + // CancellationToken is the only parameter type that's handled exclusively by the implementation. + // Now that it's been processed, check to see if the parameter should be handled via BindParameter. + if (bindingOptions.BindParameter is { } bindParameter) + { + return (arguments, _) => bindParameter(parameter, arguments); + } + + // We're now into default handling of everything else. + + // For IServiceProvider parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments. + if (parameterType == typeof(IServiceProvider)) + { + return (arguments, _) => + { + arguments.TryGetValue("__temporary_IServiceProvider", out object? objServices); + IServiceProvider? services = objServices as IServiceProvider; + if (services is null && !parameter.HasDefaultValue) + { + throw new ArgumentException($"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter.", nameof(arguments)); + } + + return services; + }; + } + + // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. + return (arguments, _) => + { + // If the parameter has an argument specified in the dictionary, return that argument. + if (arguments.TryGetValue(parameter.Name, out object? value)) + { + return value switch + { + null => null, // Return as-is if null -- if the parameter is a struct this will be handled by MethodInfo.Invoke + _ when parameterType.IsInstanceOfType(value) => value, // Do nothing if value is assignable to parameter type + JsonElement element => JsonSerializer.Deserialize(element, typeInfo), + JsonDocument doc => JsonSerializer.Deserialize(doc, typeInfo), + JsonNode node => JsonSerializer.Deserialize(node, typeInfo), + _ => MarshallViaJsonRoundtrip(value), + }; + + object? MarshallViaJsonRoundtrip(object value) + { + try + { + string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType())); + return JsonSerializer.Deserialize(json, typeInfo); + } + catch + { + // Eat any exceptions and fall back to the original value to force a cast exception later on. + return value; + } + } + } + + // If the parameter is required and there's no argument specified for it, throw. + if (!parameter.HasDefaultValue) + { + throw new ArgumentException($"Missing required parameter '{parameter.Name}' for method '{parameter.Member.Name}'.", nameof(arguments)); + } + + // Otherwise, use the optional parameter's default value. + return parameter.DefaultValue; + }; + } + + /// + /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. + /// + private static Func> GetReturnParameterMarshaller( + DescriptorKey key, JsonSerializerOptions serializerOptions) + { + Type returnType = key.Method.ReturnType; + JsonTypeInfo returnTypeInfo; + Func>? marshalResult = key.MarshalResult; + + // Void + if (returnType == typeof(void)) + { + if (marshalResult is not null) + { + return (result, cancellationToken) => marshalResult(null, null, cancellationToken); + } + + return static (_, _) => Task.FromResult((object?)null); + } + + // Task + if (returnType == typeof(Task)) + { + if (marshalResult is not null) + { + return async (result, cancellationToken) => + { + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + return await marshalResult(null, null, cancellationToken).ConfigureAwait(false); + }; + } + + return async static (result, _) => + { + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + return null; + }; + } + + // ValueTask + if (returnType == typeof(ValueTask)) + { + if (marshalResult is not null) + { + return async (result, cancellationToken) => + { + await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); + return await marshalResult(null, null, cancellationToken).ConfigureAwait(false); + }; + } + + return async static (result, _) => + { + await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); + return null; + }; + } + + if (returnType.IsGenericType) + { + // Task + if (returnType.GetGenericTypeDefinition() == typeof(Task<>)) + { + MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult); + returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); + return async (taskObj, cancellationToken) => + { + await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); + object? result = ReflectionInvoke(taskResultGetter, taskObj, null); + return marshalResult is not null ? + await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwait(false) : + await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + }; + } + + // ValueTask + if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) + { + MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask); + MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult); + returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType); + return async (taskObj, cancellationToken) => + { + var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; + await task.ConfigureAwait(false); + object? result = ReflectionInvoke(asTaskResultGetter, task, null); + return marshalResult is not null ? + await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwait(false) : + await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + }; + } + } + + // For everything else, just serialize the result as-is. + returnTypeInfo = serializerOptions.GetTypeInfo(returnType); + return marshalResult is not null ? + (result, cancellationToken) => marshalResult(result, returnTypeInfo.Type, cancellationToken) : + (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken); + + static async Task SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) + { + if (returnTypeInfo.Kind is JsonTypeInfoKind.None) + { + // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. + return JsonSerializer.SerializeToElement(result, returnTypeInfo); + } + + // Serialize asynchronously to support potential IAsyncEnumerable responses. + using PooledMemoryStream stream = new(); + await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + Utf8JsonReader reader = new(stream.GetBuffer()); + return JsonElement.ParseValue(ref reader); + } + + // Throws an exception if a result is found to be null unexpectedly + static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly."); + } + + private static readonly MethodInfo _taskGetResult = typeof(Task<>).GetProperty(nameof(Task.Result), BindingFlags.Instance | BindingFlags.Public)!.GetMethod!; + private static readonly MethodInfo _valueTaskAsTask = typeof(ValueTask<>).GetMethod(nameof(ValueTask.AsTask), BindingFlags.Instance | BindingFlags.Public)!; + + private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedType, MethodInfo genericMethodDefinition) + { + Debug.Assert(specializedType.IsGenericType && specializedType.GetGenericTypeDefinition() == genericMethodDefinition.DeclaringType, "generic member definition doesn't match type."); +#if NET + return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition); +#else + const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; + return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); +#endif + } + + private record struct DescriptorKey( + MethodInfo Method, + string? Name, + string? Description, + Func? GetBindParameterOptions, + Func>? MarshalResult, + AIJsonSchemaCreateOptions SchemaOptions); + } +} diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs new file mode 100644 index 00000000..e1f712d1 --- /dev/null +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs @@ -0,0 +1,138 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel; +using System.Reflection; +using System.Text.Json; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents options that can be provided when creating an from a method. +/// +internal sealed class TemporaryAIFunctionFactoryOptions +{ + /// + /// Initializes a new instance of the class. + /// + public TemporaryAIFunctionFactoryOptions() + { + } + + /// Gets or sets the used to marshal .NET values being passed to the underlying delegate. + /// + /// If no value has been specified, the instance will be used. + /// + public JsonSerializerOptions? SerializerOptions { get; set; } + + /// + /// Gets or sets the governing the generation of JSON schemas for the function. + /// + /// + /// If no value has been specified, the instance will be used. + /// + public AIJsonSchemaCreateOptions? JsonSchemaCreateOptions { get; set; } + + /// Gets or sets the name to use for the function. + /// + /// The name to use for the function. The default value is a name derived from the method represented by the passed or . + /// + public string? Name { get; set; } + + /// Gets or sets the description to use for the function. + /// + /// The description for the function. The default value is a description derived from the passed or , if possible + /// (for example, via a on the method). + /// + public string? Description { get; set; } + + /// + /// Gets or sets additional values to store on the resulting property. + /// + /// + /// This property can be used to provide arbitrary information about the function. + /// + public IReadOnlyDictionary? AdditionalProperties { get; set; } + + /// Gets or sets a delegate used to determine how a particular parameter to the function should be bound. + /// + /// + /// If , the default parameter binding logic will be used. If non- value, + /// this delegate will be invoked once for each parameter in the function as part of creating the instance. + /// It is not invoked for parameters of type , which are invariably bound to the token + /// provided to the invocation. + /// + /// + /// Returning a default results in the same behavior as if + /// is . + /// + /// + public Func? ConfigureParameterBinding { get; set; } + + /// Gets or sets a delegate used to determine the returned by . + /// + /// + /// By default, the return value of invoking the method wrapped into an by + /// is then JSON serialized, with the resulting returned from the method. + /// This default behavior is ideal for the common case where the result will be passed back to an AI service. However, if the caller + /// requires more control over the result's marshaling, the property may be set to a delegate that is + /// then provided with complete control over the result's marshaling. The delegate is invoked with the value returned by the method, + /// and its return value is then returned from the method. + /// + /// + /// When set, the delegate is invoked even for -returning methods, in which case it is invoked with + /// a argument. By default, is returned from the + /// method for instances produced by to wrap + /// -returning methods). + /// + /// + /// Methods strongly-typed to return types of , , , + /// and are special-cased. For methods typed to return or , + /// will be invoked with the value after the returned task has successfully completed. + /// For methods typed to return or , the delegate will be invoked with the + /// task's result value after the task has successfully completed.These behaviors keep synchronous and asynchronous methods consistent. + /// + /// + /// In addition to the returned value, which is provided to the delegate as the first argument, the delegate is also provided with + /// a represented the declared return type of the method. This can be used to determine how to marshal the result. + /// This may be different than the actual type of the object () if the method returns a derived type + /// or . If the method is typed to return , , or , + /// the argument will be . + /// + /// + public Func>? MarshalResult { get; set; } + + /// Provides configuration options produced by the delegate. + public readonly record struct ParameterBindingOptions + { + /// Gets a delegate used to determine the value for a bound parameter. + /// + /// + /// The default value is . + /// + /// + /// If , the default binding semantics are used for the parameter. + /// If non- , each time the is invoked, this delegate will be invoked + /// to select the argument value to use for the parameter. The return value of the delegate will be used for the parameter's + /// value. + /// + /// + public Func, object?>? BindParameter { get; init; } + + /// Gets a value indicating whether the parameter should be excluded from the generated schema. + /// + /// + /// The default value is . + /// + /// + /// Typically, this property is set to if and only if is also set to + /// non-. While it's possible to exclude the schema when is , + /// doing so means that default marshaling will be used but the AI service won't be aware of the parameter or able to generate + /// an argument for it. This is likely to result in invocation errors, as the parameter information is unlikely to be available. + /// It, however, is permissible for cases where invocation of the is tightly controlled, and the caller + /// is expected to augment the argument dictionary with the parameter value. + /// + /// + public bool ExcludeFromSchema { get; init; } + } +} diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 78f98483..71e9eb34 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using System.Diagnostics.CodeAnalysis; using System.Text.Json; @@ -62,6 +63,10 @@ private static JsonSerializerOptions CreateDefaultOptions() }; } + // Include all types from AIJsonUtilities, so that anything default usable as part of an AIFunction + // is also usable as part of an McpServerTool. + options.TypeInfoResolverChain.Add(AIJsonUtilities.DefaultOptions.TypeInfoResolver!); + options.MakeReadOnly(); return options; } diff --git a/src/ModelContextProtocol/Utils/SemaphoreSlimExtensions.cs b/src/ModelContextProtocol/Utils/SemaphoreSlimExtensions.cs new file mode 100644 index 00000000..0ee95098 --- /dev/null +++ b/src/ModelContextProtocol/Utils/SemaphoreSlimExtensions.cs @@ -0,0 +1,15 @@ +namespace ModelContextProtocol.Utils; + +internal static class SynchronizationExtensions +{ + public static async ValueTask LockAsync(this SemaphoreSlim semaphore, CancellationToken cancellationToken = default) + { + await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + return new(semaphore); + } + + public readonly struct Releaser(SemaphoreSlim semaphore) : IDisposable + { + public void Dispose() => semaphore.Release(); + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs new file mode 100644 index 00000000..43d51312 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -0,0 +1,86 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Transport; +using System.IO.Pipelines; + +namespace ModelContextProtocol.Tests.Client; + +public class McpClientExtensionsTests +{ + private Pipe _clientToServerPipe = new(); + private Pipe _serverToClientPipe = new(); + private readonly IMcpServer _server; + + public McpClientExtensionsTests() + { + ServiceCollection sc = new(); + sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + sc.AddMcpServer(); + for (int f = 0; f < 10; f++) + { + string name = $"Method{f}"; + sc.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", name)); + } + _server = sc.BuildServiceProvider().GetRequiredService(); + } + + public ValueTask DisposeAsync() + { + _clientToServerPipe.Writer.Complete(); + _serverToClientPipe.Writer.Complete(); + return _server.DisposeAsync(); + } + + private async Task CreateMcpClientForServer() + { + await _server.StartAsync(TestContext.Current.CancellationToken); + + var stdin = new StreamReader(_serverToClientPipe.Reader.AsStream()); + var stdout = new StreamWriter(_clientToServerPipe.Writer.AsStream()); + + var serverConfig = new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }; + + return await McpClientFactory.CreateAsync( + serverConfig, + createTransportFunc: (_, _) => new StreamClientTransport(stdin, stdout), + cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task ListToolsAsync_AllToolsReturned() + { + IMcpClient client = await CreateMcpClientForServer(); + + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(10, tools.Count); + var echo = tools.Single(t => t.Name == "Method4"); + var result = await echo.InvokeAsync(new Dictionary() { ["i"] = 42 }, TestContext.Current.CancellationToken); + Assert.Contains("Method4 Result 42", result?.ToString()); + } + + [Fact] + public async Task EnumerateToolsAsync_AllToolsReturned() + { + IMcpClient client = await CreateMcpClientForServer(); + + await foreach (var tool in client.EnumerateToolsAsync(TestContext.Current.CancellationToken)) + { + if (tool.Name == "Method4") + { + var result = await tool.InvokeAsync(new Dictionary() { ["i"] = 42 }, TestContext.Current.CancellationToken); + Assert.Contains("Method4 Result 42", result?.ToString()); + return; + } + } + + Assert.Fail("Couldn't find target method"); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index b210fb03..69321675 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -7,7 +7,6 @@ using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Tests.Utils; -using Xunit.Sdk; using System.Text.Encodings.Web; using System.Text.Json.Serialization.Metadata; using System.Text.Json.Serialization; @@ -18,6 +17,8 @@ public class ClientIntegrationTests : LoggedTest, IClassFixture string.IsNullOrWhiteSpace(s_openAIKey); + private readonly ClientIntegrationTestFixture _fixture; public ClientIntegrationTests(ClientIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) @@ -68,13 +69,10 @@ public async Task ListTools_Stdio(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); - var aiFunctions = await client.GetAIFunctionsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); // assert Assert.NotEmpty(tools); - Assert.NotEmpty(aiFunctions); - Assert.Equal(tools.Count, aiFunctions.Count); } [Theory] @@ -87,7 +85,7 @@ public async Task CallTool_Stdio_EchoServer(string clientId) await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.CallToolAsync( "echo", - new Dictionary + new Dictionary { ["message"] = "Hello MCP!" }, @@ -109,7 +107,7 @@ public async Task CallTool_Stdio_ViaAIFunction_EchoServer(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var aiFunctions = await client.GetAIFunctionsAsync(TestContext.Current.CancellationToken); + var aiFunctions = await client.ListToolsAsync(TestContext.Current.CancellationToken); var echo = aiFunctions.Single(t => t.Name == "echo"); var result = await echo.InvokeAsync([new KeyValuePair("message", "Hello MCP!")], TestContext.Current.CancellationToken); @@ -126,7 +124,7 @@ public async Task ListPrompts_Stdio(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); // assert Assert.NotEmpty(prompts); @@ -158,7 +156,7 @@ public async Task GetPrompt_Stdio_ComplexPrompt(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var arguments = new Dictionary + var arguments = new Dictionary { { "temperature", "0.7" }, { "style", "formal" } @@ -191,7 +189,7 @@ public async Task ListResourceTemplates_Stdio(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - List allResourceTemplates = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); + IList allResourceTemplates = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken); // The server provides a single test resource template Assert.Single(allResourceTemplates); @@ -206,15 +204,7 @@ public async Task ListResources_Stdio(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - List allResources = []; - string? cursor = null; - do - { - var resources = await client.ListResourcesAsync(cursor, CancellationToken.None); - allResources.AddRange(resources.Resources); - cursor = resources.NextCursor; - } - while (cursor != null); + IList allResources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); // The server provides 100 test resources Assert.Equal(100, allResources.Count); @@ -393,7 +383,7 @@ public async Task Sampling_Stdio(string clientId) // Call the server's sampleLLM tool which should trigger our sampling handler var result = await client.CallToolAsync( "sampleLLM", - new Dictionary + new Dictionary { ["prompt"] = "Test prompt", ["maxTokens"] = 100 @@ -483,7 +473,7 @@ public async Task CallTool_Stdio_MemoryServer() // act var result = await client.CallToolAsync( "read_graph", - [], + new Dictionary(), TestContext.Current.CancellationToken); // assert @@ -494,17 +484,15 @@ public async Task CallTool_Stdio_MemoryServer() await client.DisposeAsync(); } - [Fact] - public async Task GetAIFunctionsAsync_UsingEverythingServer_ToolsAreProperlyCalled() + [Fact(Skip = "Requires OpenAI API Key", SkipWhen = nameof(NoOpenAIKeySet))] + public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() { - SkipTestIfNoOpenAIKey(); - // Get the MCP client and tools from it. await using var client = await McpClientFactory.CreateAsync( _fixture.EverythingServerConfig, _fixture.DefaultOptions, cancellationToken: TestContext.Current.CancellationToken); - var mappedTools = await client.GetAIFunctionsAsync(TestContext.Current.CancellationToken); + var mappedTools = await client.ListToolsAsync(TestContext.Current.CancellationToken); // Create the chat client. using IChatClient chatClient = new OpenAIClient(s_openAIKey).AsChatClient("gpt-4o-mini") @@ -527,11 +515,9 @@ public async Task GetAIFunctionsAsync_UsingEverythingServer_ToolsAreProperlyCall Assert.Contains("Echo: Hello MCP!", response.Text); } - [Fact] + [Fact(Skip = "Requires OpenAI API Key", SkipWhen = nameof(NoOpenAIKeySet))] public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { - SkipTestIfNoOpenAIKey(); - await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() { ClientInfo = new() { Name = nameof(SamplingViaChatClient_RequestResponseProperlyPropagated), Version = "1.0.0" }, @@ -544,7 +530,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() }, }, cancellationToken: TestContext.Current.CancellationToken); - var result = await client.CallToolAsync("sampleLLM", new() + var result = await client.CallToolAsync("sampleLLM", new Dictionary() { ["prompt"] = "In just a few words, what is the most famous tower in Paris?", }, TestContext.Current.CancellationToken); @@ -590,9 +576,4 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) // assert Assert.True(logCounter > 0); } - - private static void SkipTestIfNoOpenAIKey() - { - Assert.SkipWhen(s_openAIKey is null, "No OpenAI key provided. Skipping test."); - } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index c169d1e6..3ecf9a69 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -1,72 +1,139 @@ using System.ComponentModel; using System.Text.Json; -using System.Text.RegularExpressions; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Server; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; -using Moq; +using ModelContextProtocol.Protocol.Transport; +using System.IO.Pipelines; +using ModelContextProtocol.Client; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Tests.Transport; +using System.Text.RegularExpressions; +using Microsoft.Extensions.AI; +using System.Threading.Channels; +using ModelContextProtocol.Protocol.Messages; namespace ModelContextProtocol.Tests.Configuration; -public class McpServerBuilderExtensionsToolsTests +public class McpServerBuilderExtensionsToolsTests : IAsyncDisposable { - private readonly Mock _builder; - private readonly ServiceCollection _services; + private Pipe _clientToServerPipe = new(); + private Pipe _serverToClientPipe = new(); + private readonly IMcpServerBuilder _builder; + private readonly IMcpServer _server; public McpServerBuilderExtensionsToolsTests() { - _services = new ServiceCollection(); - _builder = new Mock(); - _builder.SetupGet(b => b.Services).Returns(_services); + ServiceCollection sc = new(); + sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + _builder = sc.AddMcpServer().WithTools(); + _server = sc.BuildServiceProvider().GetRequiredService(); + } + + public ValueTask DisposeAsync() + { + _clientToServerPipe.Writer.Complete(); + _serverToClientPipe.Writer.Complete(); + return _server.DisposeAsync(); + } + + private async Task CreateMcpClientForServer() + { + await _server.StartAsync(TestContext.Current.CancellationToken); + + var stdin = new StreamReader(_serverToClientPipe.Reader.AsStream()); + var stdout = new StreamWriter(_clientToServerPipe.Writer.AsStream()); + + var serverConfig = new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }; + + return await McpClientFactory.CreateAsync( + serverConfig, + createTransportFunc: (_, _) => new StreamClientTransport(stdin, stdout), + cancellationToken: TestContext.Current.CancellationToken); } [Fact] public void Adds_Tools_To_Server() { - _builder.Object.WithTools(typeof(EchoTool)); - - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var tools = _server.ServerOptions?.Capabilities?.Tools?.ToolCollection; + Assert.NotNull(tools); + Assert.NotEmpty(tools); + } - Assert.NotNull(options.ListToolsHandler); - Assert.NotNull(options.CallToolHandler); + [Fact] + public async Task Can_List_Registered_Tools() + { + IMcpClient client = await CreateMcpClientForServer(); + + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(10, tools.Count); + + McpClientTool echoTool = tools.First(t => t.Name == "Echo"); + Assert.Equal("Echo", echoTool.Name); + Assert.Equal("Echoes the input back to the client.", echoTool.Description); + Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString()); + Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind); + Assert.Equal("the echoes message", echoTool.JsonSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString()); + Assert.Equal(1, echoTool.JsonSchema.GetProperty("required").GetArrayLength()); + + McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo"); + Assert.Equal("double_echo", doubleEchoTool.Name); + Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); } [Fact] - public async Task Can_List_Registered_Tool() + public async Task Can_Be_Notified_Of_Tool_Changes() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(10, tools.Count); - var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); - Assert.NotNull(result); - Assert.NotEmpty(result.Tools); - - var tool = result.Tools[0]; - Assert.Equal("Echo", tool.Name); - Assert.Equal("Echoes the input back to the client.", tool.Description); - Assert.Equal("object", tool.InputSchema.GetProperty("type").GetString()); - Assert.Equal(JsonValueKind.Object, tool.InputSchema.GetProperty("properties").GetProperty("message").ValueKind); - Assert.Equal("the echoes message", tool.InputSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString()); - Assert.Equal(1, tool.InputSchema.GetProperty("required").GetArrayLength()); - - tool = result.Tools[1]; - Assert.Equal("double_echo", tool.Name); - Assert.Equal("Echoes the input back to the client.", tool.Description); + Channel listChanged = Channel.CreateUnbounded(); + client.AddNotificationHandler("notifications/tools/list_changed", notification => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + }); + + var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.False(notificationRead.IsCompleted); + + var serverTools = _server.ServerOptions.Capabilities?.Tools?.ToolCollection; + Assert.NotNull(serverTools); + + var newTool = McpServerTool.Create([McpServerTool(name: "NewTool")] () => "42"); + serverTools.Add(newTool); + await notificationRead; + + tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(11, tools.Count); + Assert.Contains(tools, t => t.Name == "NewTool"); + + notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.False(notificationRead.IsCompleted); + serverTools.Remove(newTool); + await notificationRead; + + tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(10, tools.Count); + Assert.DoesNotContain(tools, t => t.Name == "NewTool"); } [Fact] public async Task Can_Call_Registered_Tool() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var result = await client.CallToolAsync( + "Echo", + new Dictionary() { ["message"] = "Peter" }, + TestContext.Current.CancellationToken); - var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "Echo", Arguments = new() { { "message", "Peter" } } }), CancellationToken.None); Assert.NotNull(result); Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); @@ -78,13 +145,13 @@ public async Task Can_Call_Registered_Tool() [Fact] public async Task Can_Call_Registered_Tool_With_Array_Result() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var result = await client.CallToolAsync( + "EchoArray", + new Dictionary() { ["message"] = "Peter" }, + TestContext.Current.CancellationToken); - var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "EchoArray", Arguments = new() { { "message", "Peter" } } }), CancellationToken.None); - Assert.NotNull(result); Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); @@ -95,12 +162,12 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Null_Result() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var result = await client.CallToolAsync( + "ReturnNull", + cancellationToken: TestContext.Current.CancellationToken); - var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnNull" }), CancellationToken.None); Assert.NotNull(result); Assert.NotNull(result.Content); Assert.Empty(result.Content); @@ -109,30 +176,29 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Json_Result() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var result = await client.CallToolAsync( + "ReturnJson", + cancellationToken: TestContext.Current.CancellationToken); - var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnJson" }), CancellationToken.None); Assert.NotNull(result); Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); - Assert.Equal("{\"SomeProp\":false}", Regex.Replace(result.Content[0].Text ?? string.Empty, "\\s+", "")); + Assert.Equal("""{"SomeProp":false}""", Regex.Replace(result.Content[0].Text ?? string.Empty, "\\s+", "")); Assert.Equal("text", result.Content[0].Type); } [Fact] public async Task Can_Call_Registered_Tool_With_Int_Result() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var result = await client.CallToolAsync( + "ReturnInteger", + cancellationToken: TestContext.Current.CancellationToken); - var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnInteger" }), CancellationToken.None); - Assert.NotNull(result); Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); @@ -140,51 +206,16 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() Assert.Equal("text", result.Content[0].Type); } - [Fact] - public async Task Can_Call_Registered_Tool_And_Pass_Cancellation_Token() - { - _builder.Object.WithTools(typeof(EchoTool)); - - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; - - using var cts = new CancellationTokenSource(); - var token = cts.Token; - - var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnCancellationToken" }), token); - Assert.NotNull(result); - Assert.NotNull(result.Content); - Assert.NotEmpty(result.Content); - - Assert.Equal(token.GetHashCode().ToString(), result.Content[0].Text); - } - - [Fact] - public async Task Can_Call_Registered_Tool_And_Returns_Cancelled_Response() - { - _builder.Object.WithTools(typeof(EchoTool)); - - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; - - using var cts = new CancellationTokenSource(); - var token = cts.Token; - await cts.CancelAsync(); - - var action = async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnCancellationToken" }), token); - - await Assert.ThrowsAsync(action); - } - [Fact] public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var result = await client.CallToolAsync( + "EchoComplex", + new Dictionary() { ["complex"] = JsonDocument.Parse("""{"Name": "Peter", "Age": 25}""").RootElement }, + cancellationToken: TestContext.Current.CancellationToken); - var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "EchoComplex", Arguments = new() { { "complex", JsonDocument.Parse("{\"Name\": \"Peter\", \"Age\": 25}").RootElement } } }), CancellationToken.None); Assert.NotNull(result); Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); @@ -196,166 +227,195 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() [Fact] public async Task Returns_IsError_Content_When_Tool_Fails() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var result = await client.CallToolAsync( + "ThrowException", + cancellationToken: TestContext.Current.CancellationToken); - var response = await options.CallToolHandler!(new(Mock.Of(), new() { Name = nameof(EchoTool.ThrowException) }), CancellationToken.None); - Assert.True(response.IsError); - Assert.NotNull(response.Content); - Assert.NotEmpty(response.Content); - Assert.Contains("Test error", response.Content[0].Text); + Assert.True(result.IsError); + Assert.NotNull(result.Content); + Assert.NotEmpty(result.Content); + Assert.Contains("Test error", result.Content[0].Text); } [Fact] public async Task Throws_Exception_On_Unknown_Tool() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( + "NotRegisteredTool", + cancellationToken: TestContext.Current.CancellationToken)); - var exception = await Assert.ThrowsAsync(async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "NotRegisteredTool" }), CancellationToken.None)); - Assert.Contains("'NotRegisteredTool'", exception.Message); + Assert.Contains("'NotRegisteredTool'", e.Message); } [Fact(Skip = "https://github.com/dotnet/extensions/issues/6124")] public async Task Throws_Exception_Missing_Parameter() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( + "Echo", + cancellationToken: TestContext.Current.CancellationToken)); - var exception = await Assert.ThrowsAsync(async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "Echo" }), CancellationToken.None)); - Assert.Equal("Missing required argument 'message'.", exception.Message); + Assert.Equal("Missing required argument 'message'.", e.Message); } [Fact] - public void Throws_Exception_For_Null_Types() + public void WithTools_InvalidArgs_Throws() { - Assert.Throws("toolTypes", () => _builder.Object.WithTools(toolTypes: null!)); - } + Assert.Throws("toolTypes", () => _builder.WithTools((IEnumerable)null!)); - [Fact] - public void Empty_Types_Is_Allowed() - { - _builder.Object.WithTools(toolTypes: []); // no exception + IMcpServerBuilder nullBuilder = null!; + Assert.Throws("builder", () => nullBuilder.WithTools()); + Assert.Throws("builder", () => nullBuilder.WithTools(Array.Empty())); + Assert.Throws("builder", () => nullBuilder.WithToolsFromAssembly()); } [Fact] - public async Task Register_Tools_From_Current_Assembly() + public void Empty_Enumerables_Is_Allowed() { - _builder.Object.WithTools(); - - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; - - var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); - Assert.NotNull(result); - Assert.NotEmpty(result.Tools); - - var tool = result.Tools[0]; - Assert.Equal("Echo", tool.Name); + _builder.WithTools(toolTypes: []); // no exception + _builder.WithTools(); // no exception even though no tools exposed + _builder.WithToolsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no tools exposed } [Fact] - public void Ok_If_No_Tools_Are_Found_In_Given_Assembly() + public void Register_Tools_From_Current_Assembly() { - _builder.Object.WithToolsFromAssembly(typeof(Mock).Assembly); + ServiceCollection sc = new(); + sc.AddMcpServer().WithToolsFromAssembly(); + IServiceProvider services = sc.BuildServiceProvider(); + + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "Echo"); } [Fact] public async Task Recognizes_Parameter_Types() { - _builder.Object.WithTools(typeof(EchoTool)); + IMcpClient client = await CreateMcpClientForServer(); - var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); - Assert.NotNull(result); - Assert.NotEmpty(result.Tools); + Assert.NotNull(tools); + Assert.NotEmpty(tools); - var tool = result.Tools.First(t => t.Name == "TestTool"); + var tool = tools.First(t => t.Name == "TestTool"); Assert.Equal("TestTool", tool.Name); Assert.Empty(tool.Description!); - Assert.Equal("object", tool.InputSchema.GetProperty("type").GetString()); - - Assert.Contains("integer", tool.InputSchema.GetProperty("properties").GetProperty("number").GetProperty("type").GetString()); - Assert.Contains("number", tool.InputSchema.GetProperty("properties").GetProperty("otherNumber").GetProperty("type").GetString()); - Assert.Contains("boolean", tool.InputSchema.GetProperty("properties").GetProperty("someCheck").GetProperty("type").GetString()); - Assert.Contains("string", tool.InputSchema.GetProperty("properties").GetProperty("someDate").GetProperty("type").GetString()); - Assert.Contains("string", tool.InputSchema.GetProperty("properties").GetProperty("someOtherDate").GetProperty("type").GetString()); - Assert.Contains("array", tool.InputSchema.GetProperty("properties").GetProperty("data").GetProperty("type").GetString()); - Assert.Contains("object", tool.InputSchema.GetProperty("properties").GetProperty("complexObject").GetProperty("type").GetString()); + Assert.Equal("object", tool.JsonSchema.GetProperty("type").GetString()); + + Assert.Contains("integer", tool.JsonSchema.GetProperty("properties").GetProperty("number").GetProperty("type").GetString()); + Assert.Contains("number", tool.JsonSchema.GetProperty("properties").GetProperty("otherNumber").GetProperty("type").GetString()); + Assert.Contains("boolean", tool.JsonSchema.GetProperty("properties").GetProperty("someCheck").GetProperty("type").GetString()); + Assert.Contains("string", tool.JsonSchema.GetProperty("properties").GetProperty("someDate").GetProperty("type").GetString()); + Assert.Contains("string", tool.JsonSchema.GetProperty("properties").GetProperty("someOtherDate").GetProperty("type").GetString()); + Assert.Contains("array", tool.JsonSchema.GetProperty("properties").GetProperty("data").GetProperty("type").GetString()); + Assert.Contains("object", tool.JsonSchema.GetProperty("properties").GetProperty("complexObject").GetProperty("type").GetString()); + } + + [Fact] + public void Register_Tools_From_Multiple_Sources() + { + ServiceCollection sc = new(); + sc.AddMcpServer() + .WithTools() + .WithTools() + .WithTools(typeof(ToolTypeWithNoAttribute)); + IServiceProvider services = sc.BuildServiceProvider(); + + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "double_echo"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "DifferentName"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodB"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodC"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodD"); } - [McpToolType] - public static class EchoTool + [McpServerToolType] + public sealed class EchoTool { - [McpTool, Description("Echoes the input back to the client.")] + [McpServerTool, Description("Echoes the input back to the client.")] public static string Echo([Description("the echoes message")] string message) { return "hello " + message; } - [McpTool("double_echo"), Description("Echoes the input back to the client.")] + [McpServerTool("double_echo"), Description("Echoes the input back to the client.")] public static string Echo2(string message) { return "hello hello" + message; } - [McpTool] + [McpServerTool] public static string TestTool(int number, double otherNumber, bool someCheck, DateTime someDate, DateTimeOffset someOtherDate, string[] data, ComplexObject complexObject) { return "hello hello"; } - [McpTool] + [McpServerTool] public static string[] EchoArray(string message) { return ["hello " + message, "hello2 " + message]; } - [McpTool] + [McpServerTool] public static string? ReturnNull() { return null; } - [McpTool] + [McpServerTool] public static JsonElement ReturnJson() { return JsonDocument.Parse("{\"SomeProp\": false}").RootElement; } - [McpTool] + [McpServerTool] public static int ReturnInteger() { return 5; } - [McpTool] + [McpServerTool] public static string ThrowException() { throw new InvalidOperationException("Test error"); } - [McpTool] + [McpServerTool] public static int ReturnCancellationToken(CancellationToken cancellationToken) { return cancellationToken.GetHashCode(); } - [McpTool] + [McpServerTool] public static string EchoComplex(ComplexObject complex) { return complex.Name!; } } + [McpServerToolType] + internal class AnotherToolType + { + [McpServerTool("DifferentName")] + private static string MethodA(int a) => a.ToString(); + + [McpServerTool] + internal static string MethodB(string b) => b.ToString(); + + [McpServerTool] + protected static string MethodC(long c) => c.ToString(); + } + + internal class ToolTypeWithNoAttribute + { + [McpServerTool] + public static string MethodD(string d) => d.ToString(); + } + public class ComplexObject { public string? Name { get; set; } diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 372e3a05..8d7f65ba 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -1,51 +1,51 @@ - - - - net8.0 - enable - enable - Latest - - false - true - ModelContextProtocol.Tests - - - + + + + net8.0 + enable + enable + Latest + + false + true + ModelContextProtocol.Tests + + + runtime; build; native; contentfiles; analyzers; buildtransitive all - + runtime; build; native; contentfiles; analyzers; buildtransitive all - - - - - - - - - - runtime; build; native; contentfiles; analyzers; buildtransitive - all - - - - - - - - - - - - PreserveNewest - - - PreserveNewest - - - - + + + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + PreserveNewest + + + PreserveNewest + + + + diff --git a/tests/ModelContextProtocol.Tests/Protocol/ProtocolTypeTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ProtocolTypeTests.cs index 81264336..d3881bc3 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ProtocolTypeTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ProtocolTypeTests.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Protocol.Types; -using ModelContextProtocol.Utils.Json; using System.Text.Json; namespace ModelContextProtocol.Tests.Protocol; diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index cddb9926..3f4dd1d8 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; @@ -24,7 +25,7 @@ public McpServerTests(ITestOutputHelper testOutputHelper) _serverTransport = new Mock(); _logger = new Mock(); _options = CreateOptions(); - _serviceProvider = new Mock().Object; + _serviceProvider = new ServiceCollection().BuildServiceProvider(); } private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = null) @@ -133,8 +134,7 @@ public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initiali await transport.SendMessageAsync(new JsonRpcNotification { Method = "notifications/initialized" - } -, TestContext.Current.CancellationToken); + }, TestContext.Current.CancellationToken); await Task.Delay(50, TestContext.Current.CancellationToken); @@ -678,12 +678,15 @@ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken can public bool IsInitialized => throw new NotImplementedException(); public Implementation? ClientInfo => throw new NotImplementedException(); - public IServiceProvider? ServiceProvider => throw new NotImplementedException(); + public McpServerOptions ServerOptions => throw new NotImplementedException(); + public IServiceProvider? Services => throw new NotImplementedException(); public void AddNotificationHandler(string method, Func handler) => throw new NotImplementedException(); public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); public Task StartAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public object? GetService(Type serviceType, object? serviceKey = null) => null; } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs new file mode 100644 index 00000000..49a82319 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -0,0 +1,93 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using Moq; +using System.Reflection; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Server; + +public class McpServerToolTests +{ + [Fact] + public void Create_InvalidArgs_Throws() + { + Assert.Throws("function", () => McpServerTool.Create(null!)); + Assert.Throws("method", () => McpServerTool.Create((MethodInfo)null!)); + Assert.Throws("method", () => McpServerTool.Create((Delegate)null!)); + } + + [Fact] + public async Task SupportsIMcpServer() + { + Mock mockServer = new(); + + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return "42"; + }); + + Assert.DoesNotContain("server", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema)); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Equal("42", result.Content[0].Text); + } + + [Fact] + public async Task SupportsServiceFromDI() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + McpServerTool tool = McpServerTool.Create((MyService actualMyService) => + { + Assert.Same(expectedMyService, actualMyService); + return "42"; + }, services: services); + + Assert.DoesNotContain("actualMyService", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema)); + + Mock mockServer = new(); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.True(result.IsError); + + mockServer.SetupGet(x => x.Services).Returns(services); + + result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Equal("42", result.Content[0].Text); + } + + [Fact] + public async Task SupportsOptionalServiceFromDI() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + McpServerTool tool = McpServerTool.Create((MyService? actualMyService = null) => + { + Assert.Null(actualMyService); + return "42"; + }, services: services); + + var result = await tool.InvokeAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("42", result.Content[0].Text); + } + + private sealed class MyService; +} diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 50188cfa..a2baf5d8 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -89,7 +89,7 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() defaultOptions, loggerFactory: loggerFactory, cancellationToken: TestContext.Current.CancellationToken); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); // assert Assert.NotEmpty(tools); @@ -157,7 +157,7 @@ public async Task Sampling_Sse_EverythingServer() cancellationToken: TestContext.Current.CancellationToken); // Call the server's sampleLLM tool which should trigger our sampling handler - var result = await client.CallToolAsync("sampleLLM", new Dictionary + var result = await client.CallToolAsync("sampleLLM", new Dictionary { ["prompt"] = "Test prompt", ["maxTokens"] = 100 diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 7142ee05..767005ea 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -56,7 +56,7 @@ public async Task ListTools_Sse_TestServer() // act var client = await GetClientAsync(); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); // assert Assert.NotNull(tools); @@ -71,7 +71,7 @@ public async Task CallTool_Sse_EchoServer() var client = await GetClientAsync(); var result = await client.CallToolAsync( "echo", - new Dictionary + new Dictionary { ["message"] = "Hello MCP!" }, @@ -93,15 +93,7 @@ public async Task ListResources_Sse_TestServer() // act var client = await GetClientAsync(); - List allResources = []; - string? cursor = null; - do - { - var resources = await client.ListResourcesAsync(cursor, CancellationToken.None); - allResources.AddRange(resources.Resources); - cursor = resources.NextCursor; - } - while (cursor != null); + IList allResources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); // The everything server provides 100 test resources Assert.Equal(100, allResources.Count); @@ -148,7 +140,7 @@ public async Task ListPrompts_Sse_TestServer() // act var client = await GetClientAsync(); - var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); // assert Assert.NotNull(prompts); @@ -179,7 +171,7 @@ public async Task GetPrompt_Sse_ComplexPrompt() // act var client = await GetClientAsync(); - var arguments = new Dictionary + var arguments = new Dictionary { { "temperature", "0.7" }, { "style", "formal" } @@ -236,7 +228,7 @@ public async Task Sampling_Sse_TestServer() #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously // Call the server's sampleLLM tool which should trigger our sampling handler - var result = await client.CallToolAsync("sampleLLM", new Dictionary + var result = await client.CallToolAsync("sampleLLM", new Dictionary { ["prompt"] = "Test prompt", ["maxTokens"] = 100 diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs new file mode 100644 index 00000000..a8014e3a --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs @@ -0,0 +1,73 @@ +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Transport; + +internal sealed class StreamClientTransport : TransportBase, IClientTransport +{ + private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions; + private Task? _readTask; + private CancellationTokenSource _shutdownCts = new CancellationTokenSource(); + private readonly TextReader _stdin; + private readonly TextWriter _stdout; + + public StreamClientTransport(TextReader stdin, TextWriter stdout) + : base(NullLoggerFactory.Instance) + { + _stdin = stdin; + _stdout = stdout; + _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); + SetConnected(true); + } + + public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + string id = message is IJsonRpcMessageWithId messageWithId ? + messageWithId.Id.ToString() : + "(no id)"; + + await _stdout.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false); + await _stdout.FlushAsync(cancellationToken).ConfigureAwait(false); + } + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + while (await _stdin.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line) + { + if (!string.IsNullOrWhiteSpace(line)) + { + try + { + if (JsonSerializer.Deserialize(line.Trim(), _jsonOptions) is { } message) + { + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + } + catch (JsonException) + { + } + } + } + } + + public override async ValueTask DisposeAsync() + { + if (_shutdownCts is { } shutdownCts) + { + await shutdownCts.CancelAsync().ConfigureAwait(false); + shutdownCts.Dispose(); + } + + if (_readTask is Task readTask) + { + await readTask.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(false); + } + + SetConnected(false); + } +} diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index 44294009..f38d933e 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -76,15 +76,6 @@ await WriteMessageAsync(new JsonRpcResponse }, cancellationToken); } - private async Task Error(JsonRpcRequest request, CancellationToken cancellationToken) - { - await WriteMessageAsync(new JsonRpcError - { - Id = request.Id, - Error = new JsonRpcErrorDetail() { Code = -32601, Message = $"Method '{request.Method}' not supported" } - }, cancellationToken); - } - protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { await _messageChannel.Writer.WriteAsync(message, cancellationToken);