Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable AIFunctions to be passed an IServiceProvider, Alternate 3 #6146

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Reflection;
using System.Text.Json;
@@ -53,22 +54,26 @@ public abstract class AIFunction : AITool

/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
/// <param name="services">The <see cref="IServiceProvider"/> optionally associated with this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The result of the function's execution.</returns>
public Task<object?> InvokeAsync(
IEnumerable<KeyValuePair<string, object?>>? arguments = null,
IServiceProvider? services = null,
Copy link
Member

@halter73 halter73 Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of all the alternative proposals, I think I like this one the best.

This makes the IServiceProvider way more discoverable compared to putting it in the IEnumerable<KeyValuePair<string, object?>>? param or the FunctionInvokingChatClient.CurrentContext async local.

I'd be more okay with putting the IServiceProvider in the arguments parameter if it was typed as a AIFunctionArguments rather than IEnumerable<KeyValuePair<string, object?>>?, since then the IServiceProvider would be a little more discoverable, but still I think the IServiceProvider should be a first-class concept similar to the CancellationToken. It's very different than all the other non-CancellationToken arguments that are deserialized from the tool invocation.

I also like that it doesn't introduce a FromServicesAttribute in the Microsoft.Extensions.AI namespace. I get the other proposals don't require that, but I think it's a good decision. I think wanting to use attributes to identify service parameters is a common enough scenario to move into Microst.Extensions.DependencyInjection.Abstractions alongside the FromKeyedServicesAttribute.

However, it might be interesting to have an overload of AIFunction.Create that took an IServiceProvider and used IServiceProviderIsService to determine which parameters should come from the service provider. Than I think it would be reasonable to respect the existing FromKeyedServicesAttribute without adding a new FromServicesAttribute which would no longer be necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, it might be interesting to have an overload of AIFunction.Create that took an IServiceProvider and used IServiceProviderIsService to determine which parameters should come from the service provider.

It still seems really strange to me that we'd base decisions for the returned AIFunction on an IServiceProvider instance that is possibly not the same IServiceProvider instance that's used later for invocation.

I think wanting to use attributes to identify service parameters is a common enough scenario to move into Microst.Extensions.DependencyInjection.Abstractions alongside the FromKeyedServicesAttribute.

We'd need to introduce a new attribute, presumably, in order to fix the namespace. And ASP.NET would then need to support both (the interface it uses wouldn't be implementable on the M.E.DI.Abstractions one).

Seems like for now the best answer is to just special-case IServerProvider in the signature of AIFunctionFactory.Create methods, and we could later support the convenience attribute mechanism.

That is separate from how the IServiceProvider finds its way into the AIFunction invocation. Choices are basically:

  1. Introduce an AIFunctionArguments type, leave InvokeAsync signature as it is, AIFunction implementations can type test for AIFunctionArguments, callers of InvokeAsync that have an IServiceProvider instantiate an AIFunctionArguments.
  2. Introduce an AIFunctionArguments type, change InvokeAsync's arguments parameter to be strongly typed as AIFunctionArguments.
  3. Change InvokeAsync's signature to take the arguments, the service provider, and the cancellation token, all as peers.

(1) is the only one we could do if we were doing this in a month. As we're doing it now, we can choose to take the break for (2) or (3).

(1) is hacky and relies on callers knowing they should instantiate the special type and implementers knowing they should check for it.

(2) makes it clear exactly what type should be used, and doesn't promote IServiceProvider to the same importance conceptually as the nominal arguments or cancellation token, while still allowing for future expansion should there be other state we want to allow to flow in. But it also forces all inputs to be wrapped in this special collection, which has some overhead and inconvenience (though most code won't directly invoke AIFunctions).

(3) doesn't require such wrapping and makes it obvious that an IServiceProvider can be passed in, but effectively makes DI primary in the API and doesn't afford us the ability to pass in more state in the future (we'd end up falling back to (1) if we need to be able to do that).

I'm leaning towards (2). @halter73 is leaning towards (3). @eiriktsarpalis? @SteveSandersonMS?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It still seems really strange to me that we'd base decisions for the returned AIFunction on an IServiceProvider instance that is possibly not the same IServiceProvider instance that's used later for invocation.It still seems really strange to me that we'd base decisions for the returned AIFunction on an IServiceProvider instance that is possibly not the same IServiceProvider instance that's used later for invocation.

I don't think it's that strange. It's something we do a lot in ASP.NET Core. Here's an example from minimal APIs. We do the same thing for MVC and SignalR to avoid regenerating the method invocation code every request.

In theory, someone could replace HttpContext.RequestServices to not match IHost.Services, but that hasn't been a problem. In practice, RequestServices usually represents a scope created from IHost.Services.

It is really convenient to get services automatically injected without needing to attribute them or rely on the service locator pattern, so it'd be sad to give that up to avoid the small risk of someone getting confused because they used multiple service providers with different sets of services. However, I suppose we could always add support for this later, since this would involve new overloads to AIFunction.Create.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference in my mind is how integrated the usage is into the DI-rooted programming model in the ASP..NET cases. The IServiceProvider ends up largely being ambient implicit context. That is not the case for the largely independent AIFunctionFactory, where you'd be very explicitly providing it with a specific IServiceProvider instance rather than it being picked up automatically from the environment.

we could always add support for this later, since this would involve new overloads to AIFunction.Create

Yes, though I don't think it would need new overloads. We'd just change the behavior if an IServiceProvider was set into a new property on the creation options.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm leaning towards (2). @halter73 is leaning towards (3). @eiriktsarpalis? @SteveSandersonMS?

Barring other's input, I'm going to go with (2). I think we'll regret not having the mechanism to pass in additional state.

CancellationToken cancellationToken = default)
{
arguments ??= EmptyReadOnlyDictionary<string, object?>.Instance;

return InvokeCoreAsync(arguments, cancellationToken);
return InvokeCoreAsync(arguments, services, cancellationToken);
}

/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
/// <param name="services">The <see cref="IServiceProvider"/> optionally associated with this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
/// <returns>The result of the function's execution.</returns>
protected abstract Task<object?> InvokeCoreAsync(
IEnumerable<KeyValuePair<string, object?>> arguments,
IServiceProvider? services,
CancellationToken cancellationToken);
}
Original file line number Diff line number Diff line change
@@ -79,11 +79,15 @@ public static JsonElement CreateFunctionJsonSchema(
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
}

if (parameter.ParameterType == typeof(CancellationToken))
if (parameter.ParameterType == typeof(CancellationToken) ||
parameter.ParameterType == typeof(IServiceProvider))
{
// CancellationToken is a special case that, by convention, we don't want to include in the schema.
// Invocations of methods that include a CancellationToken argument should also special-case CancellationToken
// to pass along what relevant token into the method's invocation.

// IServiceProvider is a special because it's directly handled by AIFunction.

continue;
}

Original file line number Diff line number Diff line change
@@ -107,7 +107,7 @@ public static async Task HandleToolCallsAsync(

try
{
var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false);
var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, null, cancellationToken).ConfigureAwait(false);
var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object)));
return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson);
}
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Shared.Diagnostics;
@@ -63,11 +64,13 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
/// </summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
/// <param name="logger">An <see cref="ILogger"/> to use for logging information about function invocation.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
/// <param name="services">An optional <see cref="IServiceProvider"/> to use for resolving services required by the <see cref="AIFunction"/> instances being invoked.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null, IServiceProvider? services = null)
: base(innerClient)
{
_logger = logger ?? NullLogger.Instance;
_logger = logger ?? (ILogger?)services?.GetService<ILogger<FunctionInvokingChatClient>>() ?? NullLogger.Instance;
_activitySource = innerClient.GetService<ActivitySource>();
Services = services;
}

/// <summary>
@@ -82,6 +85,9 @@ public static FunctionInvocationContext? CurrentContext
protected set => _currentContext.Value = value;
}

/// <summary>Gets the <see cref="IServiceProvider"/> used for resolving services required by the <see cref="AIFunction"/> instances being invoked.</summary>
public IServiceProvider? Services { get; }

/// <summary>
/// Gets or sets a value indicating whether to handle exceptions that occur during function calls.
/// </summary>
@@ -722,7 +728,10 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
try
{
CurrentContext = context;
result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
result = await context.Function.InvokeAsync(
context.CallContent.Arguments,
Services,
cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ public static ChatClientBuilder UseFunctionInvocation(
{
loggerFactory ??= services.GetService<ILoggerFactory>();

var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)), services);
configure?.Invoke(chatClient);
return chatClient;
});
Original file line number Diff line number Diff line change
@@ -17,6 +17,9 @@
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;

#pragma warning disable CA1031 // Do not catch general exception types
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields

namespace Microsoft.Extensions.AI;

/// <summary>Provides factory methods for creating commonly used implementations of <see cref="AIFunction"/>.</summary>
@@ -190,14 +193,15 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor,
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;
protected override Task<object?> InvokeCoreAsync(
IEnumerable<KeyValuePair<string, object?>>? arguments,
IServiceProvider? services,
CancellationToken cancellationToken)
{
var paramMarshallers = FunctionDescriptor.ParameterMarshallers;
object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : [];

IReadOnlyDictionary<string, object?> argDict =
arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary<string, object?>.Instance :
arguments as IReadOnlyDictionary<string, object?> ??
arguments is null ? EmptyReadOnlyDictionary<string, object?>.Instance :
arguments as IReadOnlyDictionary<string, object?> ?? // if arguments is an AIFunctionArguments, which is an IROD, use it as-is
arguments.
#if NET8_0_OR_GREATER
ToDictionary();
@@ -206,7 +210,7 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor,
#endif
for (int i = 0; i < args.Length; i++)
{
args[i] = paramMarshallers[i](argDict, cancellationToken);
args[i] = paramMarshallers[i](argDict, services, cancellationToken);
}

return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken);
@@ -248,9 +252,31 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu

private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions)
{
AIJsonSchemaCreateOptions schemaOptions = new()
{
// This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions.
TransformSchemaNode = key.SchemaOptions.TransformSchemaNode,
IncludeParameter = parameterInfo =>
{
// Explicitly exclude IServiceProvider. It'll be satisifed via AIFunctionArguments.
if (parameterInfo.ParameterType == typeof(IServiceProvider))
{
return false;
}

// For all other parameters, delegate to whatever behavior is specified in the options.
// If none is specified, include the parameter.
return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true;
},
IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas,
DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties,
IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword,
RequireAllProperties = key.SchemaOptions.RequireAllProperties,
};

// Get marshaling delegates for parameters.
ParameterInfo[] parameters = key.Method.GetParameters();
ParameterMarshallers = new Func<IReadOnlyDictionary<string, object?>, CancellationToken, object?>[parameters.Length];
ParameterMarshallers = new Func<IReadOnlyDictionary<string, object?>, IServiceProvider?, CancellationToken, object?>[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i]);
@@ -268,15 +294,15 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
Name,
Description,
serializerOptions,
key.SchemaOptions);
schemaOptions);
}

public string Name { get; }
public string Description { get; }
public MethodInfo Method { get; }
public JsonSerializerOptions JsonSerializerOptions { get; }
public JsonElement JsonSchema { get; }
public Func<IReadOnlyDictionary<string, object?>, CancellationToken, object?>[] ParameterMarshallers { get; }
public Func<IReadOnlyDictionary<string, object?>, IServiceProvider?, CancellationToken, object?>[] ParameterMarshallers { get; }
public Func<object?, CancellationToken, Task<object?>> ReturnParameterMarshaller { get; }
public ReflectionAIFunction? CachedDefaultInstance { get; set; }

@@ -320,7 +346,7 @@ static bool IsAsyncMethod(MethodInfo method)
/// <summary>
/// Gets a delegate for handling the marshaling of a parameter.
/// </summary>
private static Func<IReadOnlyDictionary<string, object?>, CancellationToken, object?> GetParameterMarshaller(
private static Func<IReadOnlyDictionary<string, object?>, IServiceProvider?, CancellationToken, object?> GetParameterMarshaller(
JsonSerializerOptions serializerOptions,
ParameterInfo parameter)
{
@@ -336,13 +362,33 @@ static bool IsAsyncMethod(MethodInfo method)
// For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync.
if (parameterType == typeof(CancellationToken))
{
return static (_, cancellationToken) =>
return static (_, _, cancellationToken) =>
cancellationToken == default ? _boxedDefaultCancellationToken : // optimize common case of a default CT to avoid boxing
cancellationToken;
}

// For IServiceProvider parameters, we always bind to the services passed directly to InvokeAsync.
if (parameterType == typeof(IServiceProvider))
{
return (arguments, services, _) =>
{
if (services is not null)
{
return services;
}

if (!parameter.HasDefaultValue)
{
Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter.");
}

// The IServiceProvider parameter was optional. Return the default value.
return null;
};
}

// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
return (arguments, _) =>
return (arguments, _, _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
if (arguments.TryGetValue(parameter.Name, out object? value))
@@ -359,7 +405,6 @@ static bool IsAsyncMethod(MethodInfo method)

object? MarshallViaJsonRoundtrip(object value)
{
#pragma warning disable CA1031 // Do not catch general exception types
try
{
string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType()));
@@ -370,7 +415,6 @@ static bool IsAsyncMethod(MethodInfo method)
// Eat any exceptions and fall back to the original value to force a cast exception later on.
return value;
}
#pragma warning restore CA1031
}
}

@@ -482,9 +526,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
#if NET
return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition);
#else
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
#endif
}
Original file line number Diff line number Diff line change
@@ -251,7 +251,10 @@ private sealed class NetTypelessAIFunction : AIFunction

public override string Name => "NetTypeless";
public override string Description => "AIFunction with parameters that lack .NET types";
protected override Task<object?> InvokeCoreAsync(IEnumerable<KeyValuePair<string, object?>>? arguments, CancellationToken cancellationToken) =>
protected override Task<object?> InvokeCoreAsync(
IEnumerable<KeyValuePair<string, object?>>? arguments,
IServiceProvider? services,
CancellationToken cancellationToken) =>
Task.FromResult<object?>(arguments);
}

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
@@ -16,13 +17,13 @@ public async Task InvokeAsync_UsesDefaultEmptyCollectionForNullArgsAsync()
DerivedAIFunction f = new();

using CancellationTokenSource cts = new();
var result1 = ((IEnumerable<KeyValuePair<string, object?>>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!;
var result1 = ((IEnumerable<KeyValuePair<string, object?>>, CancellationToken))(await f.InvokeAsync(null, null, cts.Token))!;

Assert.NotNull(result1.Item1);
Assert.Empty(result1.Item1);
Assert.Equal(cts.Token, result1.Item2);

var result2 = ((IEnumerable<KeyValuePair<string, object?>>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!;
var result2 = ((IEnumerable<KeyValuePair<string, object?>>, CancellationToken))(await f.InvokeAsync(null, null, cts.Token))!;
Assert.Same(result1.Item1, result2.Item1);
}

@@ -38,7 +39,10 @@ private sealed class DerivedAIFunction : AIFunction
public override string Name => "name";
public override string Description => "";

protected override Task<object?> InvokeCoreAsync(IEnumerable<KeyValuePair<string, object?>> arguments, CancellationToken cancellationToken)
protected override Task<object?> InvokeCoreAsync(
IEnumerable<KeyValuePair<string, object?>> arguments,
IServiceProvider? services,
CancellationToken cancellationToken)
{
Assert.NotNull(arguments);
return Task.FromResult<object?>((arguments, cancellationToken));
Original file line number Diff line number Diff line change
@@ -606,6 +606,32 @@ public async Task PropagatesResponseChatThreadIdToOptions()
Assert.Equal("done!", (await service.GetStreamingResponseAsync("hey", options).ToChatResponseAsync()).ToString());
}

[Fact]
public async Task FunctionInvocations_PassesServices()
{
List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["arg1"] = "value1" })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
new ChatMessage(ChatRole.Assistant, "world"),
];

ServiceCollection c = new();
IServiceProvider expected = c.BuildServiceProvider();

var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create((IServiceProvider actual) =>
{
Assert.Same(expected, actual);
return "Result 1";
}, "Func1")]
};

await InvokeAndAssertAsync(options, plan, services: expected);
}

private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
ChatOptions options,
List<ChatMessage> plan,
Loading
Loading