diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs index 4733fce1..e11a4ab9 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs @@ -14,32 +14,51 @@ public static partial class McpServerBuilderExtensions { private const string RequiresUnreferencedCodeMessage = "This method requires dynamic lookup of method metadata and might not work in Native AOT."; - /// <summary> - /// Adds a tool to the server. - /// </summary> + /// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary> /// <typeparam name="TTool">The tool type.</typeparam> /// <param name="builder">The builder instance.</param> /// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception> - public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] TTool>( + /// <remarks> + /// This method discovers all instance and static methods (public and non-public) on the specified <typeparamref name="TTool"/> + /// type, where the methods are attributed as <see cref="McpServerToolAttribute"/>, and adds an <see cref="McpServerTool"/> + /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. + /// </remarks> + public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods | + DynamicallyAccessedMemberTypes.PublicConstructors)] TTool>( this IMcpServerBuilder builder) { Throw.IfNull(builder); - foreach (var toolMethod in GetToolMethods(typeof(TTool))) + foreach (var toolMethod in typeof(TTool).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) { - builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); + if (toolMethod.GetCustomAttribute<McpServerToolAttribute>() is not null) + { + if (toolMethod.IsStatic) + { + builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); + } + else + { + builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, typeof(TTool), services: services)); + } + } } return builder; } - /// <summary> - /// Adds tools to the server. - /// </summary> + /// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary> /// <param name="builder">The builder instance.</param> /// <param name="toolTypes">Types with marked methods to add as tools to the server.</param> /// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception> /// <exception cref="ArgumentNullException"><paramref name="toolTypes"/> is <see langword="null"/>.</exception> + /// <remarks> + /// This method discovers all instance and static methods (public and non-public) on the specified <paramref name="toolTypes"/> + /// types, where the methods are attributed as <see cref="McpServerToolAttribute"/>, and adds an <see cref="McpServerTool"/> + /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. + /// </remarks> [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params IEnumerable<Type> toolTypes) { @@ -50,13 +69,23 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params { if (toolType is not null) { - foreach (var toolMethod in GetToolMethods(toolType)) + foreach (var method in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) { - builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); + if (method.GetCustomAttribute<McpServerToolAttribute>() is not null) + { + if (method.IsStatic) + { + builder.Services.AddSingleton(services => McpServerTool.Create(method, services: services)); + } + else + { + builder.Services.AddSingleton(services => McpServerTool.Create(method, toolType, services: services)); + } + } } } } - + return builder; } @@ -78,10 +107,4 @@ from t in toolAssembly.GetTypes() where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null select t); } - - private static IEnumerable<MethodInfo> GetToolMethods( - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] Type toolType) => - from method in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static) - where method.GetCustomAttribute<McpServerToolAttribute>() is not null - select method; } diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index d3fbd93c..ff3f9288 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; +using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Text.Json; @@ -13,7 +14,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool { /// <summary>Key used temporarily for flowing request context into an AIFunction.</summary> /// <remarks>This will be replaced with use of AIFunctionArguments.Context.</remarks> - private const string RequestContextKey = "__temporary_RequestContext"; + internal const string RequestContextKey = "__temporary_RequestContext"; /// <summary> /// Creates an <see cref="McpServerTool"/> instance for a method, specified via a <see cref="Delegate"/> instance. @@ -48,7 +49,27 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool // AIFunctionFactory, delete the TemporaryXx types, and fix-up the mechanism by // which the arguments are passed. - return Create(TemporaryAIFunctionFactory.Create(method, target, new TemporaryAIFunctionFactoryOptions() + return Create(TemporaryAIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, name, description, services))); + } + + /// <summary> + /// Creates an <see cref="McpServerTool"/> instance for a method, specified via a <see cref="Delegate"/> instance. + /// </summary> + public static new AIFunctionMcpServerTool Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + string? name = null, + string? description = null, + IServiceProvider? services = null) + { + Throw.IfNull(method); + + return Create(TemporaryAIFunctionFactory.Create(method, targetType, CreateAIFunctionFactoryOptions(method, name, description, services))); + } + + private static TemporaryAIFunctionFactoryOptions CreateAIFunctionFactoryOptions( + MethodInfo method, string? name, string? description, IServiceProvider? services) => + new TemporaryAIFunctionFactoryOptions() { Name = name ?? method.GetCustomAttribute<McpServerToolAttribute>()?.Name, Description = description, @@ -115,8 +136,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool return null; } }, - })); - } + }; /// <summary>Creates an <see cref="McpServerTool"/> that wraps the specified <see cref="AIFunction"/>.</summary> public static new AIFunctionMcpServerTool Create(AIFunction function) diff --git a/src/ModelContextProtocol/Server/McpServerTool.cs b/src/ModelContextProtocol/Server/McpServerTool.cs index f6122764..c262df75 100644 --- a/src/ModelContextProtocol/Server/McpServerTool.cs +++ b/src/ModelContextProtocol/Server/McpServerTool.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.AI; using ModelContextProtocol.Protocol.Types; using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; using System.Reflection; namespace ModelContextProtocol.Server; @@ -40,7 +41,7 @@ public abstract Task<CallToolResponse> InvokeAsync( /// </param> /// <param name="services"> /// Optional services used in the construction of the <see cref="McpServerTool"/>. These services will be - /// used to determine which parameters should be satisifed from dependency injection, and so what services + /// used to determine which parameters should be satisifed from dependency injection; what services /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. /// </param> /// <returns>The created <see cref="McpServerTool"/> for invoking <paramref name="method"/>.</returns> @@ -68,7 +69,7 @@ public static McpServerTool Create( /// </param> /// <param name="services"> /// Optional services used in the construction of the <see cref="McpServerTool"/>. These services will be - /// used to determine which parameters should be satisifed from dependency injection, and so what services + /// used to determine which parameters should be satisifed from dependency injection; what services /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. /// </param> /// <returns>The created <see cref="McpServerTool"/> for invoking <paramref name="method"/>.</returns> @@ -82,6 +83,43 @@ public static McpServerTool Create( IServiceProvider? services = null) => AIFunctionMcpServerTool.Create(method, target, name, description, services); + /// <summary> + /// Creates an <see cref="McpServerTool"/> instance for a method, specified via an <see cref="MethodInfo"/> for + /// and instance method, along with a <see cref="Type"/> representing the type of the target object to + /// instantiate each time the method is invoked. + /// </summary> + /// <param name="method">The instance method to be represented via the created <see cref="AIFunction"/>.</param> + /// <param name="targetType"> + /// The <see cref="Type"/> to construct an instance of on which to invoke <paramref name="method"/> when + /// the resulting <see cref="AIFunction"/> is invoked. If services are provided, + /// ActivatorUtilities.CreateInstance will be used to construct the instance using those services; otherwise, + /// <see cref="Activator.CreateInstance(Type)"/> is used, utilizing the type's public parameterless constructor. + /// If an instance can't be constructed, an exception is thrown during the function's invocation. + /// </param> + /// <param name="name"> + /// The name to use for the <see cref="McpServerTool"/>. If <see langword="null"/>, but an <see cref="McpServerToolAttribute"/> + /// is applied to <paramref name="method"/>, the name from the attribute will be used. If that's not present, the name based + /// on <paramref name="method"/>'s name will be used. + /// </param> + /// <param name="description"> + /// The description to use for the <see cref="McpServerTool"/>. If <see langword="null"/>, but a <see cref="DescriptionAttribute"/> + /// is applied to <paramref name="method"/>, the description from that attribute will be used. + /// </param> + /// <param name="services"> + /// Optional services used in the construction of the <see cref="McpServerTool"/>. These services will be + /// used to determine which parameters should be satisifed from dependency injection; what services + /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. + /// </param> + /// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns> + /// <exception cref="ArgumentNullException"><paramref name="method"/> is <see langword="null"/>.</exception> + public static McpServerTool Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + string? name = null, + string? description = null, + IServiceProvider? services = null) => + AIFunctionMcpServerTool.Create(method, targetType, name, description, services); + /// <summary>Creates an <see cref="McpServerTool"/> that wraps the specified <see cref="AIFunction"/>.</summary> /// <param name="function">The function to wrap.</param> /// <exception cref="ArgumentNullException"><paramref name="function"/> is <see langword="null"/>.</exception> diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs index bf0ae8ae..67e7a99b 100644 --- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs @@ -1,10 +1,15 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using System.Collections.Concurrent; using System.ComponentModel; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + #if !NET using System.Linq; #endif @@ -110,6 +115,42 @@ public static AIFunction Create(MethodInfo method, object? target, TemporaryAIFu return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions); } + /// <summary> + /// Creates an <see cref="AIFunction"/> instance for a method, specified via an <see cref="MethodInfo"/> instance + /// and an optional target object if the method is an instance method. + /// </summary> + /// <param name="method">The instance method to be represented via the created <see cref="AIFunction"/>.</param> + /// <param name="targetType"> + /// The <see cref="Type"/> to construct an instance of on which to invoke <paramref name="method"/> when + /// the resulting <see cref="AIFunction"/> is invoked. If services are provided, + /// ActivatorUtilities.CreateInstance will be used to construct the instance using those services; otherwise, + /// <see cref="Activator.CreateInstance(Type)"/> is used, utilizing the type's public parameterless constructor. + /// If an instance can't be constructed, an exception is thrown during the function's invocation. + /// </param> + /// <param name="options">Metadata to use to override defaults inferred from <paramref name="method"/>.</param> + /// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns> + /// <remarks> + /// <para> + /// Return values are serialized to <see cref="JsonElement"/> using <paramref name="options"/>'s + /// <see cref="AIFunctionFactoryOptions.SerializerOptions"/>. Arguments that are not already of the expected type are + /// marshaled to the expected type via JSON and using <paramref name="options"/>'s + /// <see cref="AIFunctionFactoryOptions.SerializerOptions"/>. If the argument is a <see cref="JsonElement"/>, + /// <see cref="JsonDocument"/>, or <see cref="JsonNode"/>, it is deserialized directly. If the argument is anything else unknown, + /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. + /// </para> + /// </remarks> + /// <exception cref="ArgumentNullException"><paramref name="method"/> is <see langword="null"/>.</exception> + public static AIFunction Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + TemporaryAIFunctionFactoryOptions? options = null) + { + Throw.IfNull(method); + Throw.IfNull(targetType); + + return ReflectionAIFunction.Build(method, targetType, options ?? _defaultOptions); + } + /// <summary> /// Creates an <see cref="AIFunction"/> instance for a method, specified via an <see cref="MethodInfo"/> instance /// and an optional target object if the method is an instance method. @@ -176,6 +217,32 @@ public static ReflectionAIFunction Build(MethodInfo method, object? target, Temp return new(functionDescriptor, target, options); } + public static ReflectionAIFunction Build( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + TemporaryAIFunctionFactoryOptions options) + { + Throw.IfNull(method); + + if (method.ContainsGenericParameters) + { + throw new ArgumentException("Open generic methods are not supported", nameof(method)); + } + + if (method.IsStatic) + { + throw new ArgumentException("The method must be an instance method.", nameof(method)); + } + + if (method.DeclaringType is { } declaringType && + !declaringType.IsAssignableFrom(targetType)) + { + throw new ArgumentException("The target type must be assignable to the method's declaring type.", nameof(targetType)); + } + + return new(ReflectionAIFunctionDescriptor.GetOrCreate(method, options), targetType, options); + } + private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, TemporaryAIFunctionFactoryOptions options) { FunctionDescriptor = functionDescriptor; @@ -183,8 +250,20 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, AdditionalProperties = options.AdditionalProperties ?? new Dictionary<string, object?>(); } + private ReflectionAIFunction( + ReflectionAIFunctionDescriptor functionDescriptor, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + TemporaryAIFunctionFactoryOptions options) + { + FunctionDescriptor = functionDescriptor; + TargetType = targetType; + AdditionalProperties = options.AdditionalProperties ?? new Dictionary<string, object?>(); + } + public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } public object? Target { get; } + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] + public Type? TargetType { get; } public override IReadOnlyDictionary<string, object?> AdditionalProperties { get; } public override string Name => FunctionDescriptor.Name; public override string Description => FunctionDescriptor.Description; @@ -192,22 +271,59 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; - protected override Task<object?> InvokeCoreAsync( + protected override async Task<object?> InvokeCoreAsync( IEnumerable<KeyValuePair<string, object?>> arguments, CancellationToken cancellationToken) { - var paramMarshallers = FunctionDescriptor.ParameterMarshallers; - object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; - Dictionary<string, object?> argumentsDictionary = arguments.ToDictionary(); - for (int i = 0; i < args.Length; i++) + bool disposeTarget = false; + object? target = Target; + try { - args[i] = paramMarshallers[i](argumentsDictionary, cancellationToken); - } + if (TargetType is { } targetType) + { + Debug.Assert(target is null, "Expected target to be null when we have a non-null target type"); + Debug.Assert(!FunctionDescriptor.Method.IsStatic, "Expected an instance method"); + + if (argumentsDictionary.TryGetValue(AIFunctionMcpServerTool.RequestContextKey, out object? value) && + value is RequestContext<CallToolRequestParams> requestContext && + requestContext.Server?.Services is { } services) + { + target = ActivatorUtilities.CreateInstance(services, targetType!); + } + else + { + target = Activator.CreateInstance(targetType); + } + + disposeTarget = true; + } + var paramMarshallers = FunctionDescriptor.ParameterMarshallers; + object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; - return FunctionDescriptor.ReturnParameterMarshaller( - ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); + for (int i = 0; i < args.Length; i++) + { + args[i] = paramMarshallers[i](argumentsDictionary, cancellationToken); + } + + return await FunctionDescriptor.ReturnParameterMarshaller( + ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken).ConfigureAwait(false); + } + finally + { + if (disposeTarget) + { + if (target is IAsyncDisposable ad) + { + await ad.DisposeAsync().ConfigureAwait(false); + } + else if (target is IDisposable d) + { + d.Dispose(); + } + } + } } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 3ecf9a69..dbf13536 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -25,6 +25,7 @@ public McpServerBuilderExtensionsToolsTests() { ServiceCollection sc = new(); sc.AddSingleton<IServerTransport>(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + sc.AddSingleton(new ObjectWithId()); _builder = sc.AddMcpServer().WithTools<EchoTool>(); _server = sc.BuildServiceProvider().GetRequiredService<IMcpServer>(); } @@ -70,7 +71,7 @@ public async Task Can_List_Registered_Tools() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(10, tools.Count); + Assert.Equal(11, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -91,7 +92,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(10, tools.Count); + Assert.Equal(11, tools.Count); Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>(); client.AddNotificationHandler("notifications/tools/list_changed", notification => @@ -111,7 +112,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(11, tools.Count); + Assert.Equal(12, tools.Count); Assert.Contains(tools, t => t.Name == "NewTool"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -120,7 +121,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(10, tools.Count); + Assert.Equal(11, tools.Count); Assert.DoesNotContain(tools, t => t.Name == "NewTool"); } @@ -224,6 +225,35 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() Assert.Equal("text", result.Content[0].Type); } + [Fact] + public async Task Can_Call_Registered_Tool_With_Instance_Method() + { + IMcpClient client = await CreateMcpClientForServer(); + + string[][] parts = new string[2][]; + for (int i = 0; i < 2; i++) + { + var result = await client.CallToolAsync( + nameof(EchoTool.GetCtorParameter), + cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.NotEmpty(result.Content); + + parts[i] = result.Content[0].Text?.Split(':') ?? []; + Assert.Equal(2, parts[i].Length); + } + + string random1 = parts[0][0]; + string random2 = parts[1][0]; + Assert.NotEqual(random1, random2); + + string id1 = parts[0][1]; + string id2 = parts[1][1]; + Assert.Equal(id1, id2); + } + [Fact] public async Task Returns_IsError_Content_When_Tool_Fails() { @@ -334,8 +364,10 @@ public void Register_Tools_From_Multiple_Sources() } [McpServerToolType] - public sealed class EchoTool + public sealed class EchoTool(ObjectWithId objectFromDI) { + private string _randomValue = Guid.NewGuid().ToString("N"); + [McpServerTool, Description("Echoes the input back to the client.")] public static string Echo([Description("the echoes message")] string message) { @@ -395,6 +427,9 @@ public static string EchoComplex(ComplexObject complex) { return complex.Name!; } + + [McpServerTool] + public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}"; } [McpServerToolType] @@ -421,4 +456,9 @@ public class ComplexObject public string? Name { get; set; } public int Age { get; set; } } + + public class ObjectWithId + { + public string Id { get; set; } = Guid.NewGuid().ToString("N"); + } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 49a82319..3f066dd5 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -4,6 +4,7 @@ using Moq; using System.Reflection; using System.Text.Json; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Tests.Server; @@ -89,5 +90,120 @@ public async Task SupportsOptionalServiceFromDI() Assert.Equal("42", result.Content[0].Text); } + [Fact] + public async Task SupportsDisposingInstantiatedDisposableTargets() + { + McpServerTool tool1 = McpServerTool.Create( + typeof(DisposableToolType).GetMethod(nameof(DisposableToolType.InstanceMethod))!, + typeof(DisposableToolType)); + + var result = await tool1.InvokeAsync( + new RequestContext<CallToolRequestParams>(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("""{"disposals":1}""", result.Content[0].Text); + } + + [Fact] + public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() + { + McpServerTool tool1 = McpServerTool.Create( + typeof(AsyncDisposableToolType).GetMethod(nameof(AsyncDisposableToolType.InstanceMethod))!, + typeof(AsyncDisposableToolType)); + + var result = await tool1.InvokeAsync( + new RequestContext<CallToolRequestParams>(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("""{"asyncDisposals":1}""", result.Content[0].Text); + } + + [Fact] + public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposableTargets() + { + McpServerTool tool1 = McpServerTool.Create( + typeof(AsyncDisposableAndDisposableToolType).GetMethod(nameof(AsyncDisposableAndDisposableToolType.InstanceMethod))!, + typeof(AsyncDisposableAndDisposableToolType)); + + var result = await tool1.InvokeAsync( + new RequestContext<CallToolRequestParams>(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", result.Content[0].Text); + } + private sealed class MyService; + + private class DisposableToolType : IDisposable + { + public int Disposals { get; private set; } + + public void Dispose() + { + Disposals++; + } + + public object InstanceMethod() + { + if (Disposals != 0) + { + throw new InvalidOperationException("Dispose was called"); + } + + return this; + } + } + + private class AsyncDisposableToolType : IAsyncDisposable + { + public int AsyncDisposals { get; private set; } + + public ValueTask DisposeAsync() + { + AsyncDisposals++; + return default; + } + + public object InstanceMethod() + { + if (AsyncDisposals != 0) + { + throw new InvalidOperationException("DisposeAsync was called"); + } + + return this; + } + } + + private class AsyncDisposableAndDisposableToolType : IAsyncDisposable, IDisposable + { + [JsonPropertyOrder(0)] + public int AsyncDisposals { get; private set; } + + [JsonPropertyOrder(1)] + public int Disposals { get; private set; } + + public void Dispose() + { + Disposals++; + } + + public ValueTask DisposeAsync() + { + AsyncDisposals++; + return default; + } + + public object InstanceMethod() + { + if (Disposals != 0) + { + throw new InvalidOperationException("Dispose was called"); + } + + if (AsyncDisposals != 0) + { + throw new InvalidOperationException("DisposeAsync was called"); + } + + return this; + } + } }