Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add [McpServerTool] support for instance methods #100

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,51 @@ public static partial class McpServerBuilderExtensions
{
private const string RequiresUnreferencedCodeMessage = "This method requires dynamic lookup of method metadata and might not work in Native AOT.";

/// <summary>
/// Adds a tool to the server.
/// </summary>
/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <typeparam name="TTool">The tool type.</typeparam>
/// <param name="builder">The builder instance.</param>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] TTool>(
/// <remarks>
/// This method discovers all instance and static methods (public and non-public) on the specified <typeparamref name="TTool"/>
/// type, where the methods are attributed as <see cref="McpServerToolAttribute"/>, and adds an <see cref="McpServerTool"/>
/// instance for each. For instance methods, an instance will be constructed for each invocation of the tool.
/// </remarks>
public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicMethods |
DynamicallyAccessedMemberTypes.NonPublicMethods |
DynamicallyAccessedMemberTypes.PublicConstructors)] TTool>(
this IMcpServerBuilder builder)
{
Throw.IfNull(builder);

foreach (var toolMethod in GetToolMethods(typeof(TTool)))
foreach (var toolMethod in typeof(TTool).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
{
builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services));
if (toolMethod.GetCustomAttribute<McpServerToolAttribute>() is not null)
{
if (toolMethod.IsStatic)
{
builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services));
}
else
{
builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, typeof(TTool), services: services));
}
}
}

return builder;
}

/// <summary>
/// Adds tools to the server.
/// </summary>
/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <param name="builder">The builder instance.</param>
/// <param name="toolTypes">Types with marked methods to add as tools to the server.</param>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="toolTypes"/> is <see langword="null"/>.</exception>
/// <remarks>
/// This method discovers all instance and static methods (public and non-public) on the specified <paramref name="toolTypes"/>
/// types, where the methods are attributed as <see cref="McpServerToolAttribute"/>, and adds an <see cref="McpServerTool"/>
/// instance for each. For instance methods, an instance will be constructed for each invocation of the tool.
/// </remarks>
[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params IEnumerable<Type> toolTypes)
{
Expand All @@ -50,13 +69,23 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params
{
if (toolType is not null)
{
foreach (var toolMethod in GetToolMethods(toolType))
foreach (var method in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
{
builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services));
if (method.GetCustomAttribute<McpServerToolAttribute>() is not null)
{
if (method.IsStatic)
{
builder.Services.AddSingleton(services => McpServerTool.Create(method, services: services));
}
else
{
builder.Services.AddSingleton(services => McpServerTool.Create(method, toolType, services: services));
}
}
}
}
}

return builder;
}

Expand All @@ -78,10 +107,4 @@ from t in toolAssembly.GetTypes()
where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
select t);
}

private static IEnumerable<MethodInfo> GetToolMethods(
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] Type toolType) =>
from method in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static)
where method.GetCustomAttribute<McpServerToolAttribute>() is not null
select method;
}
28 changes: 24 additions & 4 deletions src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Text.Json;

Expand All @@ -13,7 +14,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool
{
/// <summary>Key used temporarily for flowing request context into an AIFunction.</summary>
/// <remarks>This will be replaced with use of AIFunctionArguments.Context.</remarks>
private const string RequestContextKey = "__temporary_RequestContext";
internal const string RequestContextKey = "__temporary_RequestContext";

/// <summary>
/// Creates an <see cref="McpServerTool"/> instance for a method, specified via a <see cref="Delegate"/> instance.
Expand Down Expand Up @@ -48,7 +49,27 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool
// AIFunctionFactory, delete the TemporaryXx types, and fix-up the mechanism by
// which the arguments are passed.

return Create(TemporaryAIFunctionFactory.Create(method, target, new TemporaryAIFunctionFactoryOptions()
return Create(TemporaryAIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, name, description, services)));
}

/// <summary>
/// Creates an <see cref="McpServerTool"/> instance for a method, specified via a <see cref="Delegate"/> instance.
/// </summary>
public static new AIFunctionMcpServerTool Create(
MethodInfo method,
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType,
string? name = null,
string? description = null,
IServiceProvider? services = null)
{
Throw.IfNull(method);

return Create(TemporaryAIFunctionFactory.Create(method, targetType, CreateAIFunctionFactoryOptions(method, name, description, services)));
}

private static TemporaryAIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
MethodInfo method, string? name, string? description, IServiceProvider? services) =>
new TemporaryAIFunctionFactoryOptions()
{
Name = name ?? method.GetCustomAttribute<McpServerToolAttribute>()?.Name,
Description = description,
Expand Down Expand Up @@ -115,8 +136,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool
return null;
}
},
}));
}
};

/// <summary>Creates an <see cref="McpServerTool"/> that wraps the specified <see cref="AIFunction"/>.</summary>
public static new AIFunctionMcpServerTool Create(AIFunction function)
Expand Down
42 changes: 40 additions & 2 deletions src/ModelContextProtocol/Server/McpServerTool.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Extensions.AI;
using ModelContextProtocol.Protocol.Types;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;

namespace ModelContextProtocol.Server;
Expand Down Expand Up @@ -40,7 +41,7 @@ public abstract Task<CallToolResponse> InvokeAsync(
/// </param>
/// <param name="services">
/// Optional services used in the construction of the <see cref="McpServerTool"/>. These services will be
/// used to determine which parameters should be satisifed from dependency injection, and so what services
/// used to determine which parameters should be satisifed from dependency injection; what services
/// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time.
/// </param>
/// <returns>The created <see cref="McpServerTool"/> for invoking <paramref name="method"/>.</returns>
Expand Down Expand Up @@ -68,7 +69,7 @@ public static McpServerTool Create(
/// </param>
/// <param name="services">
/// Optional services used in the construction of the <see cref="McpServerTool"/>. These services will be
/// used to determine which parameters should be satisifed from dependency injection, and so what services
/// used to determine which parameters should be satisifed from dependency injection; what services
/// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time.
/// </param>
/// <returns>The created <see cref="McpServerTool"/> for invoking <paramref name="method"/>.</returns>
Expand All @@ -82,6 +83,43 @@ public static McpServerTool Create(
IServiceProvider? services = null) =>
AIFunctionMcpServerTool.Create(method, target, name, description, services);

/// <summary>
/// Creates an <see cref="McpServerTool"/> instance for a method, specified via an <see cref="MethodInfo"/> for
/// and instance method, along with a <see cref="Type"/> representing the type of the target object to
/// instantiate each time the method is invoked.
/// </summary>
/// <param name="method">The instance method to be represented via the created <see cref="AIFunction"/>.</param>
/// <param name="targetType">
/// The <see cref="Type"/> to construct an instance of on which to invoke <paramref name="method"/> when
/// the resulting <see cref="AIFunction"/> is invoked. If services are provided,
/// ActivatorUtilities.CreateInstance will be used to construct the instance using those services; otherwise,
/// <see cref="Activator.CreateInstance(Type)"/> is used, utilizing the type's public parameterless constructor.
/// If an instance can't be constructed, an exception is thrown during the function's invocation.
/// </param>
/// <param name="name">
/// The name to use for the <see cref="McpServerTool"/>. If <see langword="null"/>, but an <see cref="McpServerToolAttribute"/>
/// is applied to <paramref name="method"/>, the name from the attribute will be used. If that's not present, the name based
/// on <paramref name="method"/>'s name will be used.
/// </param>
/// <param name="description">
/// The description to use for the <see cref="McpServerTool"/>. If <see langword="null"/>, but a <see cref="DescriptionAttribute"/>
/// is applied to <paramref name="method"/>, the description from that attribute will be used.
/// </param>
/// <param name="services">
/// Optional services used in the construction of the <see cref="McpServerTool"/>. These services will be
/// used to determine which parameters should be satisifed from dependency injection; what services
/// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time.
/// </param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="method"/> is <see langword="null"/>.</exception>
public static McpServerTool Create(
MethodInfo method,
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType,
string? name = null,
string? description = null,
IServiceProvider? services = null) =>
AIFunctionMcpServerTool.Create(method, targetType, name, description, services);

/// <summary>Creates an <see cref="McpServerTool"/> that wraps the specified <see cref="AIFunction"/>.</summary>
/// <param name="function">The function to wrap.</param>
/// <exception cref="ArgumentNullException"><paramref name="function"/> is <see langword="null"/>.</exception>
Expand Down
Loading