diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs index 667a956a2f7..3d5023d3a0a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Reflection; using System.Text.Json; @@ -53,22 +54,26 @@ public abstract class AIFunction : AITool /// Invokes the and returns its result. /// The arguments to pass to the function's invocation. + /// The optionally associated with this invocation. /// The to monitor for cancellation requests. The default is . /// The result of the function's execution. public Task InvokeAsync( IEnumerable>? arguments = null, + IServiceProvider? services = null, CancellationToken cancellationToken = default) { arguments ??= EmptyReadOnlyDictionary.Instance; - return InvokeCoreAsync(arguments, cancellationToken); + return InvokeCoreAsync(arguments, services, cancellationToken); } /// Invokes the and returns its result. /// The arguments to pass to the function's invocation. + /// The optionally associated with this invocation. /// The to monitor for cancellation requests. /// The result of the function's execution. protected abstract Task InvokeCoreAsync( IEnumerable> arguments, + IServiceProvider? services, CancellationToken cancellationToken); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index faf9bb71d7c..a58efb44bbc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -79,11 +79,15 @@ public static JsonElement CreateFunctionJsonSchema( Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); } - if (parameter.ParameterType == typeof(CancellationToken)) + if (parameter.ParameterType == typeof(CancellationToken) || + parameter.ParameterType == typeof(IServiceProvider)) { // CancellationToken is a special case that, by convention, we don't want to include in the schema. // Invocations of methods that include a CancellationToken argument should also special-case CancellationToken // to pass along what relevant token into the method's invocation. + + // IServiceProvider is a special because it's directly handled by AIFunction. + continue; } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs index d74505e64f8..082b7da710f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs @@ -107,7 +107,7 @@ public static async Task HandleToolCallsAsync( try { - var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false); + var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, null, cancellationToken).ConfigureAwait(false); var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object))); return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson); } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index cf0d25b3f17..ba91fff3225 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -9,6 +9,7 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; @@ -63,11 +64,13 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient /// /// The underlying , or the next instance in a chain of clients. /// An to use for logging information about function invocation. - public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null) + /// An optional to use for resolving services required by the instances being invoked. + public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null, IServiceProvider? services = null) : base(innerClient) { - _logger = logger ?? NullLogger.Instance; + _logger = logger ?? (ILogger?)services?.GetService>() ?? NullLogger.Instance; _activitySource = innerClient.GetService(); + Services = services; } /// @@ -82,6 +85,9 @@ public static FunctionInvocationContext? CurrentContext protected set => _currentContext.Value = value; } + /// Gets the used for resolving services required by the instances being invoked. + public IServiceProvider? Services { get; } + /// /// Gets or sets a value indicating whether to handle exceptions that occur during function calls. /// @@ -722,7 +728,10 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul try { CurrentContext = context; - result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false); + result = await context.Function.InvokeAsync( + context.CallContent.Arguments, + Services, + cancellationToken).ConfigureAwait(false); } catch (Exception e) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs index f2a60718ea9..ae0cb3e316b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs @@ -33,7 +33,7 @@ public static ChatClientBuilder UseFunctionInvocation( { loggerFactory ??= services.GetService(); - var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient))); + var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)), services); configure?.Invoke(chatClient); return chatClient; }); diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index f81ee89fb6d..698f65787d6 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -17,6 +17,9 @@ using Microsoft.Shared.Collections; using Microsoft.Shared.Diagnostics; +#pragma warning disable CA1031 // Do not catch general exception types +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + namespace Microsoft.Extensions.AI; /// Provides factory methods for creating commonly used implementations of . @@ -190,14 +193,15 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; protected override Task InvokeCoreAsync( IEnumerable>? arguments, + IServiceProvider? services, CancellationToken cancellationToken) { var paramMarshallers = FunctionDescriptor.ParameterMarshallers; object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; IReadOnlyDictionary argDict = - arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance : - arguments as IReadOnlyDictionary ?? + arguments is null ? EmptyReadOnlyDictionary.Instance : + arguments as IReadOnlyDictionary ?? // if arguments is an AIFunctionArguments, which is an IROD, use it as-is arguments. #if NET8_0_OR_GREATER ToDictionary(); @@ -206,7 +210,7 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, #endif for (int i = 0; i < args.Length; i++) { - args[i] = paramMarshallers[i](argDict, cancellationToken); + args[i] = paramMarshallers[i](argDict, services, cancellationToken); } return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); @@ -248,9 +252,31 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) { + AIJsonSchemaCreateOptions schemaOptions = new() + { + // This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions. + TransformSchemaNode = key.SchemaOptions.TransformSchemaNode, + IncludeParameter = parameterInfo => + { + // Explicitly exclude IServiceProvider. It'll be satisifed via AIFunctionArguments. + if (parameterInfo.ParameterType == typeof(IServiceProvider)) + { + return false; + } + + // For all other parameters, delegate to whatever behavior is specified in the options. + // If none is specified, include the parameter. + return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true; + }, + IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas, + DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties, + IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword, + RequireAllProperties = key.SchemaOptions.RequireAllProperties, + }; + // Get marshaling delegates for parameters. ParameterInfo[] parameters = key.Method.GetParameters(); - ParameterMarshallers = new Func, CancellationToken, object?>[parameters.Length]; + ParameterMarshallers = new Func, IServiceProvider?, CancellationToken, object?>[parameters.Length]; for (int i = 0; i < parameters.Length; i++) { ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i]); @@ -268,7 +294,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions Name, Description, serializerOptions, - key.SchemaOptions); + schemaOptions); } public string Name { get; } @@ -276,7 +302,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions public MethodInfo Method { get; } public JsonSerializerOptions JsonSerializerOptions { get; } public JsonElement JsonSchema { get; } - public Func, CancellationToken, object?>[] ParameterMarshallers { get; } + public Func, IServiceProvider?, CancellationToken, object?>[] ParameterMarshallers { get; } public Func> ReturnParameterMarshaller { get; } public ReflectionAIFunction? CachedDefaultInstance { get; set; } @@ -320,7 +346,7 @@ static bool IsAsyncMethod(MethodInfo method) /// /// Gets a delegate for handling the marshaling of a parameter. /// - private static Func, CancellationToken, object?> GetParameterMarshaller( + private static Func, IServiceProvider?, CancellationToken, object?> GetParameterMarshaller( JsonSerializerOptions serializerOptions, ParameterInfo parameter) { @@ -336,13 +362,33 @@ static bool IsAsyncMethod(MethodInfo method) // For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync. if (parameterType == typeof(CancellationToken)) { - return static (_, cancellationToken) => + return static (_, _, cancellationToken) => cancellationToken == default ? _boxedDefaultCancellationToken : // optimize common case of a default CT to avoid boxing cancellationToken; } + // For IServiceProvider parameters, we always bind to the services passed directly to InvokeAsync. + if (parameterType == typeof(IServiceProvider)) + { + return (arguments, services, _) => + { + if (services is not null) + { + return services; + } + + if (!parameter.HasDefaultValue) + { + Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter."); + } + + // The IServiceProvider parameter was optional. Return the default value. + return null; + }; + } + // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. - return (arguments, _) => + return (arguments, _, _) => { // If the parameter has an argument specified in the dictionary, return that argument. if (arguments.TryGetValue(parameter.Name, out object? value)) @@ -359,7 +405,6 @@ static bool IsAsyncMethod(MethodInfo method) object? MarshallViaJsonRoundtrip(object value) { -#pragma warning disable CA1031 // Do not catch general exception types try { string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType())); @@ -370,7 +415,6 @@ static bool IsAsyncMethod(MethodInfo method) // Eat any exceptions and fall back to the original value to force a cast exception later on. return value; } -#pragma warning restore CA1031 } } @@ -482,9 +526,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT #if NET return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition); #else -#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; -#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); #endif } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs index 103bc884022..a1e41dbeebf 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs @@ -251,7 +251,10 @@ private sealed class NetTypelessAIFunction : AIFunction public override string Name => "NetTypeless"; public override string Description => "AIFunction with parameters that lack .NET types"; - protected override Task InvokeCoreAsync(IEnumerable>? arguments, CancellationToken cancellationToken) => + protected override Task InvokeCoreAsync( + IEnumerable>? arguments, + IServiceProvider? services, + CancellationToken cancellationToken) => Task.FromResult(arguments); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs index 1ced6ae3185..51325badcb6 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -16,13 +17,13 @@ public async Task InvokeAsync_UsesDefaultEmptyCollectionForNullArgsAsync() DerivedAIFunction f = new(); using CancellationTokenSource cts = new(); - var result1 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!; + var result1 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, null, cts.Token))!; Assert.NotNull(result1.Item1); Assert.Empty(result1.Item1); Assert.Equal(cts.Token, result1.Item2); - var result2 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!; + var result2 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, null, cts.Token))!; Assert.Same(result1.Item1, result2.Item1); } @@ -38,7 +39,10 @@ private sealed class DerivedAIFunction : AIFunction public override string Name => "name"; public override string Description => ""; - protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken) + protected override Task InvokeCoreAsync( + IEnumerable> arguments, + IServiceProvider? services, + CancellationToken cancellationToken) { Assert.NotNull(arguments); return Task.FromResult((arguments, cancellationToken)); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 8d069034e15..acc3147ffbb 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -606,6 +606,32 @@ public async Task PropagatesResponseChatThreadIdToOptions() Assert.Equal("done!", (await service.GetStreamingResponseAsync("hey", options).ToChatResponseAsync()).ToString()); } + [Fact] + public async Task FunctionInvocations_PassesServices() + { + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + ServiceCollection c = new(); + IServiceProvider expected = c.BuildServiceProvider(); + + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create((IServiceProvider actual) => + { + Assert.Same(expected, actual); + return "Result 1"; + }, "Func1")] + }; + + await InvokeAndAssertAsync(options, plan, services: expected); + } + private static async Task> InvokeAndAssertAsync( ChatOptions options, List plan, diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index dc104ea6be6..22b2fa319e7 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -8,6 +8,7 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; using Xunit; namespace Microsoft.Extensions.AI; @@ -216,4 +217,55 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters() Assert.NotNull(result); Assert.Contains("test42", result.ToString()); } + + [Fact] + public async Task AIFunctionArguments_ServicesSatisfyParameters() + { + ServiceCollection sc = new(); + IServiceProvider sp = sc.BuildServiceProvider(); + + AIFunction func = AIFunctionFactory.Create(( + int myInteger, + IServiceProvider services1, + IServiceProvider services2, + IServiceProvider? services3, + IServiceProvider? services4 = null) => + { + Assert.Same(sp, services1); + Assert.Same(sp, services2); + Assert.Same(sp, services3); + Assert.Same(sp, services4); + return myInteger; + }); + + Assert.Contains("myInteger", func.JsonSchema.ToString()); + Assert.DoesNotContain("services", func.JsonSchema.ToString()); + + await Assert.ThrowsAsync("arguments", () => func.InvokeAsync([new KeyValuePair("myInteger", 42)])); + var result = await func.InvokeAsync([new KeyValuePair("myInteger", 42)], sp); + + Assert.Contains("42", result?.ToString()); + } + + [Fact] + public async Task AIFunctionArguments_MissingServicesMayBeOptional() + { + ServiceCollection sc = new(); + IServiceProvider sp = sc.BuildServiceProvider(); + + AIFunction func = AIFunctionFactory.Create(( + int myInteger, + IServiceProvider? services = null) => + { + Assert.Null(services); + return myInteger; + }); + + Assert.Contains("myInteger", func.JsonSchema.ToString()); + Assert.DoesNotContain("services", func.JsonSchema.ToString()); + + var result = await func.InvokeAsync([new KeyValuePair("myInteger", 42)]); + + Assert.Contains("42", result?.ToString()); + } }