diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
index 667a956a2f7..3d5023d3a0a 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
using System.Collections.Generic;
using System.Reflection;
using System.Text.Json;
@@ -53,22 +54,26 @@ public abstract class AIFunction : AITool
/// Invokes the and returns its result.
/// The arguments to pass to the function's invocation.
+ /// The optionally associated with this invocation.
/// The to monitor for cancellation requests. The default is .
/// The result of the function's execution.
public Task InvokeAsync(
IEnumerable>? arguments = null,
+ IServiceProvider? services = null,
CancellationToken cancellationToken = default)
{
arguments ??= EmptyReadOnlyDictionary.Instance;
- return InvokeCoreAsync(arguments, cancellationToken);
+ return InvokeCoreAsync(arguments, services, cancellationToken);
}
/// Invokes the and returns its result.
/// The arguments to pass to the function's invocation.
+ /// The optionally associated with this invocation.
/// The to monitor for cancellation requests.
/// The result of the function's execution.
protected abstract Task InvokeCoreAsync(
IEnumerable> arguments,
+ IServiceProvider? services,
CancellationToken cancellationToken);
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs
index faf9bb71d7c..a58efb44bbc 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs
@@ -79,11 +79,15 @@ public static JsonElement CreateFunctionJsonSchema(
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
}
- if (parameter.ParameterType == typeof(CancellationToken))
+ if (parameter.ParameterType == typeof(CancellationToken) ||
+ parameter.ParameterType == typeof(IServiceProvider))
{
// CancellationToken is a special case that, by convention, we don't want to include in the schema.
// Invocations of methods that include a CancellationToken argument should also special-case CancellationToken
// to pass along what relevant token into the method's invocation.
+
+ // IServiceProvider is a special because it's directly handled by AIFunction.
+
continue;
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs
index d74505e64f8..082b7da710f 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs
@@ -107,7 +107,7 @@ public static async Task HandleToolCallsAsync(
try
{
- var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false);
+ var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, null, cancellationToken).ConfigureAwait(false);
var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object)));
return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson);
}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
index cf0d25b3f17..ba91fff3225 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
@@ -9,6 +9,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Shared.Diagnostics;
@@ -63,11 +64,13 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
///
/// The underlying , or the next instance in a chain of clients.
/// An to use for logging information about function invocation.
- public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
+ /// An optional to use for resolving services required by the instances being invoked.
+ public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null, IServiceProvider? services = null)
: base(innerClient)
{
- _logger = logger ?? NullLogger.Instance;
+ _logger = logger ?? (ILogger?)services?.GetService>() ?? NullLogger.Instance;
_activitySource = innerClient.GetService();
+ Services = services;
}
///
@@ -82,6 +85,9 @@ public static FunctionInvocationContext? CurrentContext
protected set => _currentContext.Value = value;
}
+ /// Gets the used for resolving services required by the instances being invoked.
+ public IServiceProvider? Services { get; }
+
///
/// Gets or sets a value indicating whether to handle exceptions that occur during function calls.
///
@@ -722,7 +728,10 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
try
{
CurrentContext = context;
- result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
+ result = await context.Function.InvokeAsync(
+ context.CallContent.Arguments,
+ Services,
+ cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs
index f2a60718ea9..ae0cb3e316b 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs
@@ -33,7 +33,7 @@ public static ChatClientBuilder UseFunctionInvocation(
{
loggerFactory ??= services.GetService();
- var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
+ var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)), services);
configure?.Invoke(chatClient);
return chatClient;
});
diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
index f81ee89fb6d..698f65787d6 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
@@ -17,6 +17,9 @@
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;
+#pragma warning disable CA1031 // Do not catch general exception types
+#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
+
namespace Microsoft.Extensions.AI;
/// Provides factory methods for creating commonly used implementations of .
@@ -190,14 +193,15 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor,
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;
protected override Task InvokeCoreAsync(
IEnumerable>? arguments,
+ IServiceProvider? services,
CancellationToken cancellationToken)
{
var paramMarshallers = FunctionDescriptor.ParameterMarshallers;
object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : [];
IReadOnlyDictionary argDict =
- arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance :
- arguments as IReadOnlyDictionary ??
+ arguments is null ? EmptyReadOnlyDictionary.Instance :
+ arguments as IReadOnlyDictionary ?? // if arguments is an AIFunctionArguments, which is an IROD, use it as-is
arguments.
#if NET8_0_OR_GREATER
ToDictionary();
@@ -206,7 +210,7 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor,
#endif
for (int i = 0; i < args.Length; i++)
{
- args[i] = paramMarshallers[i](argDict, cancellationToken);
+ args[i] = paramMarshallers[i](argDict, services, cancellationToken);
}
return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken);
@@ -248,9 +252,31 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu
private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions)
{
+ AIJsonSchemaCreateOptions schemaOptions = new()
+ {
+ // This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions.
+ TransformSchemaNode = key.SchemaOptions.TransformSchemaNode,
+ IncludeParameter = parameterInfo =>
+ {
+ // Explicitly exclude IServiceProvider. It'll be satisifed via AIFunctionArguments.
+ if (parameterInfo.ParameterType == typeof(IServiceProvider))
+ {
+ return false;
+ }
+
+ // For all other parameters, delegate to whatever behavior is specified in the options.
+ // If none is specified, include the parameter.
+ return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true;
+ },
+ IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas,
+ DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties,
+ IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword,
+ RequireAllProperties = key.SchemaOptions.RequireAllProperties,
+ };
+
// Get marshaling delegates for parameters.
ParameterInfo[] parameters = key.Method.GetParameters();
- ParameterMarshallers = new Func, CancellationToken, object?>[parameters.Length];
+ ParameterMarshallers = new Func, IServiceProvider?, CancellationToken, object?>[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i]);
@@ -268,7 +294,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
Name,
Description,
serializerOptions,
- key.SchemaOptions);
+ schemaOptions);
}
public string Name { get; }
@@ -276,7 +302,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
public MethodInfo Method { get; }
public JsonSerializerOptions JsonSerializerOptions { get; }
public JsonElement JsonSchema { get; }
- public Func, CancellationToken, object?>[] ParameterMarshallers { get; }
+ public Func, IServiceProvider?, CancellationToken, object?>[] ParameterMarshallers { get; }
public Func> ReturnParameterMarshaller { get; }
public ReflectionAIFunction? CachedDefaultInstance { get; set; }
@@ -320,7 +346,7 @@ static bool IsAsyncMethod(MethodInfo method)
///
/// Gets a delegate for handling the marshaling of a parameter.
///
- private static Func, CancellationToken, object?> GetParameterMarshaller(
+ private static Func, IServiceProvider?, CancellationToken, object?> GetParameterMarshaller(
JsonSerializerOptions serializerOptions,
ParameterInfo parameter)
{
@@ -336,13 +362,33 @@ static bool IsAsyncMethod(MethodInfo method)
// For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync.
if (parameterType == typeof(CancellationToken))
{
- return static (_, cancellationToken) =>
+ return static (_, _, cancellationToken) =>
cancellationToken == default ? _boxedDefaultCancellationToken : // optimize common case of a default CT to avoid boxing
cancellationToken;
}
+ // For IServiceProvider parameters, we always bind to the services passed directly to InvokeAsync.
+ if (parameterType == typeof(IServiceProvider))
+ {
+ return (arguments, services, _) =>
+ {
+ if (services is not null)
+ {
+ return services;
+ }
+
+ if (!parameter.HasDefaultValue)
+ {
+ Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter.");
+ }
+
+ // The IServiceProvider parameter was optional. Return the default value.
+ return null;
+ };
+ }
+
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
- return (arguments, _) =>
+ return (arguments, _, _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
if (arguments.TryGetValue(parameter.Name, out object? value))
@@ -359,7 +405,6 @@ static bool IsAsyncMethod(MethodInfo method)
object? MarshallViaJsonRoundtrip(object value)
{
-#pragma warning disable CA1031 // Do not catch general exception types
try
{
string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType()));
@@ -370,7 +415,6 @@ static bool IsAsyncMethod(MethodInfo method)
// Eat any exceptions and fall back to the original value to force a cast exception later on.
return value;
}
-#pragma warning restore CA1031
}
}
@@ -482,9 +526,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
#if NET
return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition);
#else
-#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
-#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
#endif
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs
index 103bc884022..a1e41dbeebf 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs
@@ -251,7 +251,10 @@ private sealed class NetTypelessAIFunction : AIFunction
public override string Name => "NetTypeless";
public override string Description => "AIFunction with parameters that lack .NET types";
- protected override Task InvokeCoreAsync(IEnumerable>? arguments, CancellationToken cancellationToken) =>
+ protected override Task InvokeCoreAsync(
+ IEnumerable>? arguments,
+ IServiceProvider? services,
+ CancellationToken cancellationToken) =>
Task.FromResult(arguments);
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs
index 1ced6ae3185..51325badcb6 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
@@ -16,13 +17,13 @@ public async Task InvokeAsync_UsesDefaultEmptyCollectionForNullArgsAsync()
DerivedAIFunction f = new();
using CancellationTokenSource cts = new();
- var result1 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!;
+ var result1 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, null, cts.Token))!;
Assert.NotNull(result1.Item1);
Assert.Empty(result1.Item1);
Assert.Equal(cts.Token, result1.Item2);
- var result2 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!;
+ var result2 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, null, cts.Token))!;
Assert.Same(result1.Item1, result2.Item1);
}
@@ -38,7 +39,10 @@ private sealed class DerivedAIFunction : AIFunction
public override string Name => "name";
public override string Description => "";
- protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken)
+ protected override Task InvokeCoreAsync(
+ IEnumerable> arguments,
+ IServiceProvider? services,
+ CancellationToken cancellationToken)
{
Assert.NotNull(arguments);
return Task.FromResult((arguments, cancellationToken));
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
index 8d069034e15..acc3147ffbb 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
@@ -606,6 +606,32 @@ public async Task PropagatesResponseChatThreadIdToOptions()
Assert.Equal("done!", (await service.GetStreamingResponseAsync("hey", options).ToChatResponseAsync()).ToString());
}
+ [Fact]
+ public async Task FunctionInvocations_PassesServices()
+ {
+ List plan =
+ [
+ new ChatMessage(ChatRole.User, "hello"),
+ new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]),
+ new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
+ new ChatMessage(ChatRole.Assistant, "world"),
+ ];
+
+ ServiceCollection c = new();
+ IServiceProvider expected = c.BuildServiceProvider();
+
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create((IServiceProvider actual) =>
+ {
+ Assert.Same(expected, actual);
+ return "Result 1";
+ }, "Func1")]
+ };
+
+ await InvokeAndAssertAsync(options, plan, services: expected);
+ }
+
private static async Task> InvokeAndAssertAsync(
ChatOptions options,
List plan,
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
index dc104ea6be6..22b2fa319e7 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
using Xunit;
namespace Microsoft.Extensions.AI;
@@ -216,4 +217,55 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
Assert.NotNull(result);
Assert.Contains("test42", result.ToString());
}
+
+ [Fact]
+ public async Task AIFunctionArguments_ServicesSatisfyParameters()
+ {
+ ServiceCollection sc = new();
+ IServiceProvider sp = sc.BuildServiceProvider();
+
+ AIFunction func = AIFunctionFactory.Create((
+ int myInteger,
+ IServiceProvider services1,
+ IServiceProvider services2,
+ IServiceProvider? services3,
+ IServiceProvider? services4 = null) =>
+ {
+ Assert.Same(sp, services1);
+ Assert.Same(sp, services2);
+ Assert.Same(sp, services3);
+ Assert.Same(sp, services4);
+ return myInteger;
+ });
+
+ Assert.Contains("myInteger", func.JsonSchema.ToString());
+ Assert.DoesNotContain("services", func.JsonSchema.ToString());
+
+ await Assert.ThrowsAsync("arguments", () => func.InvokeAsync([new KeyValuePair("myInteger", 42)]));
+ var result = await func.InvokeAsync([new KeyValuePair("myInteger", 42)], sp);
+
+ Assert.Contains("42", result?.ToString());
+ }
+
+ [Fact]
+ public async Task AIFunctionArguments_MissingServicesMayBeOptional()
+ {
+ ServiceCollection sc = new();
+ IServiceProvider sp = sc.BuildServiceProvider();
+
+ AIFunction func = AIFunctionFactory.Create((
+ int myInteger,
+ IServiceProvider? services = null) =>
+ {
+ Assert.Null(services);
+ return myInteger;
+ });
+
+ Assert.Contains("myInteger", func.JsonSchema.ToString());
+ Assert.DoesNotContain("services", func.JsonSchema.ToString());
+
+ var result = await func.InvokeAsync([new KeyValuePair("myInteger", 42)]);
+
+ Assert.Contains("42", result?.ToString());
+ }
}