Skip to content

Commit f5e7c7b

Browse files
committed
Change AIFunction.InvokeAsync to accept AIFunctionArguments
- Adds a new AIFunctionArguments type. - Changes AIFunction.InvokeAsync to accept an AIFunctionArguments instead of an arbitrary enumerable. - Changes FunctionInvokingChatClient to accept an IServiceProvider and expose it as a Services property, and to then pass that IServiceProvider into the AIFunction via AIFunctionArguments.Services. - Augments FunctionInvocationContext with an AIFunctionArguments property. - Changes AIFunctionFactory to special-case parameters of type IServiceProvider and AIFunctionArguments, sourcing from AIFunctionArguments. - Makes AIJsonSchemaCreateOptions a record.
1 parent a70dcfa commit f5e7c7b

File tree

17 files changed

+519
-112
lines changed

17 files changed

+519
-112
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs

+4-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System.Collections.Generic;
54
using System.Reflection;
65
using System.Text.Json;
76
using System.Threading;
87
using System.Threading.Tasks;
9-
using Microsoft.Shared.Collections;
108

119
namespace Microsoft.Extensions.AI;
1210

@@ -56,19 +54,15 @@ public abstract class AIFunction : AITool
5654
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
5755
/// <returns>The result of the function's execution.</returns>
5856
public Task<object?> InvokeAsync(
59-
IEnumerable<KeyValuePair<string, object?>>? arguments = null,
60-
CancellationToken cancellationToken = default)
61-
{
62-
arguments ??= EmptyReadOnlyDictionary<string, object?>.Instance;
63-
64-
return InvokeCoreAsync(arguments, cancellationToken);
65-
}
57+
AIFunctionArguments? arguments = null,
58+
CancellationToken cancellationToken = default) =>
59+
InvokeCoreAsync(arguments ?? [], cancellationToken);
6660

6761
/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
6862
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
6963
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
7064
/// <returns>The result of the function's execution.</returns>
7165
protected abstract Task<object?> InvokeCoreAsync(
72-
IEnumerable<KeyValuePair<string, object?>> arguments,
66+
AIFunctionArguments arguments,
7367
CancellationToken cancellationToken);
7468
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections;
6+
using System.Collections.Generic;
7+
8+
#pragma warning disable SA1111 // Closing parenthesis should be on line of last parameter
9+
#pragma warning disable SA1112 // Closing parenthesis should be on line of opening parenthesis
10+
#pragma warning disable SA1114 // Parameter list should follow declaration
11+
#pragma warning disable CA1710 // Identifiers should have correct suffix
12+
13+
namespace Microsoft.Extensions.AI;
14+
15+
/// <summary>Represents arguments to be used with <see cref="AIFunction.InvokeAsync"/>.</summary>
16+
/// <remarks>
17+
/// <see cref="AIFunctionArguments"/> is a dictionary of name/value pairs that are used
18+
/// as inputs to an <see cref="AIFunction"/>. However, an instance carries additional non-nominal
19+
/// information, such as an optional <see cref="IServiceProvider"/> that can be used by
20+
/// an <see cref="AIFunction"/> if it needs to resolve any services from a dependency injection
21+
/// container.
22+
/// </remarks>
23+
public sealed class AIFunctionArguments : IDictionary<string, object?>, IReadOnlyDictionary<string, object?>
24+
{
25+
/// <summary>The nominal arguments.</summary>
26+
private readonly Dictionary<string, object?> _arguments;
27+
28+
/// <summary>Initializes a new instance of the <see cref="AIFunctionArguments"/> class.</summary>
29+
public AIFunctionArguments()
30+
{
31+
_arguments = [];
32+
}
33+
34+
/// <summary>
35+
/// Initializes a new instance of the <see cref="AIFunctionArguments"/> class containing
36+
/// the specified <paramref name="arguments"/>.
37+
/// </summary>
38+
/// <param name="arguments">The arguments represented by this instance.</param>
39+
/// <remarks>
40+
/// The <paramref name="arguments"/> reference will be stored if the instance is
41+
/// already a <see cref="Dictionary{TKey, TValue}"/>, in which case all dictionary
42+
/// operations on this instance will be routed directly to that instance. If <paramref name="arguments"/>
43+
/// is not a dictionary, a shallow clone of its data will be used to populate this
44+
/// instance. A <see langword="null"/> <paramref name="arguments"/> is treated as an
45+
/// empty dictionary.
46+
/// </remarks>
47+
public AIFunctionArguments(IDictionary<string, object?>? arguments)
48+
{
49+
_arguments =
50+
arguments is null ? [] :
51+
arguments as Dictionary<string, object?> ??
52+
new Dictionary<string, object?>(arguments);
53+
}
54+
55+
/// <summary>Gets or sets services optionally associated with these arguments.</summary>
56+
public IServiceProvider? Services { get; set; }
57+
58+
/// <inheritdoc />
59+
public object? this[string key]
60+
{
61+
get => _arguments[key];
62+
set => _arguments[key] = value;
63+
}
64+
65+
/// <inheritdoc />
66+
public ICollection<string> Keys => _arguments.Keys;
67+
68+
/// <inheritdoc />
69+
public ICollection<object?> Values => _arguments.Values;
70+
71+
/// <inheritdoc />
72+
public int Count => _arguments.Count;
73+
74+
/// <inheritdoc />
75+
bool ICollection<KeyValuePair<string, object?>>.IsReadOnly => false;
76+
77+
/// <inheritdoc />
78+
IEnumerable<string> IReadOnlyDictionary<string, object?>.Keys => Keys;
79+
80+
/// <inheritdoc />
81+
IEnumerable<object?> IReadOnlyDictionary<string, object?>.Values => Values;
82+
83+
/// <inheritdoc />
84+
public void Add(string key, object? value) => _arguments.Add(key, value);
85+
86+
/// <inheritdoc />
87+
void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item) =>
88+
((ICollection<KeyValuePair<string, object?>>)_arguments).Add(item);
89+
90+
/// <inheritdoc />
91+
public void Clear() => _arguments.Clear();
92+
93+
/// <inheritdoc />
94+
bool ICollection<KeyValuePair<string, object?>>.Contains(KeyValuePair<string, object?> item) =>
95+
((ICollection<KeyValuePair<string, object?>>)_arguments).Contains(item);
96+
97+
/// <inheritdoc />
98+
public bool ContainsKey(string key) => _arguments.ContainsKey(key);
99+
100+
/// <inheritdoc />
101+
public void CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) =>
102+
((ICollection<KeyValuePair<string, object?>>)_arguments).CopyTo(array, arrayIndex);
103+
104+
/// <inheritdoc />
105+
public IEnumerator<KeyValuePair<string, object?>> GetEnumerator() => _arguments.GetEnumerator();
106+
107+
/// <inheritdoc />
108+
public bool Remove(string key) => _arguments.Remove(key);
109+
110+
/// <inheritdoc />
111+
bool ICollection<KeyValuePair<string, object?>>.Remove(KeyValuePair<string, object?> item) =>
112+
((ICollection<KeyValuePair<string, object?>>)_arguments).Remove(item);
113+
114+
/// <inheritdoc />
115+
public bool TryGetValue(string key, out object? value) => _arguments.TryGetValue(key, out value);
116+
117+
/// <inheritdoc />
118+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
119+
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs

+1-23
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace Microsoft.Extensions.AI;
1313
/// <summary>
1414
/// Provides options for configuring the behavior of <see cref="AIJsonUtilities"/> JSON schema creation functionality.
1515
/// </summary>
16-
public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOptions>
16+
public sealed record class AIJsonSchemaCreateOptions
1717
{
1818
/// <summary>
1919
/// Gets the default options instance.
@@ -56,26 +56,4 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
5656
/// Gets a value indicating whether to mark all properties as required in the schema.
5757
/// </summary>
5858
public bool RequireAllProperties { get; init; } = true;
59-
60-
/// <inheritdoc/>
61-
public bool Equals(AIJsonSchemaCreateOptions? other) =>
62-
other is not null &&
63-
TransformSchemaNode == other.TransformSchemaNode &&
64-
IncludeParameter == other.IncludeParameter &&
65-
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
66-
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
67-
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
68-
RequireAllProperties == other.RequireAllProperties;
69-
70-
/// <inheritdoc />
71-
public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other);
72-
73-
/// <inheritdoc />
74-
public override int GetHashCode() =>
75-
(TransformSchemaNode,
76-
IncludeParameter,
77-
IncludeTypeInEnumSchemas,
78-
DisallowAdditionalProperties,
79-
IncludeSchemaKeyword,
80-
RequireAllProperties).GetHashCode();
8159
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ private static JsonSerializerOptions CreateDefaultOptions()
108108
[JsonSerializable(typeof(Embedding<float>))]
109109
[JsonSerializable(typeof(Embedding<double>))]
110110
[JsonSerializable(typeof(AIContent))]
111+
[JsonSerializable(typeof(AIFunctionArguments))]
111112
[EditorBrowsable(EditorBrowsableState.Never)] // Never use JsonContext directly, use DefaultOptions instead.
112113
private sealed partial class JsonContext : JsonSerializerContext;
113114
}

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ private sealed class MetadataOnlyAIFunction(string name, string description, Jso
437437
public override string Description => description;
438438
public override JsonElement JsonSchema => schema;
439439
public override IReadOnlyDictionary<string, object?> AdditionalProperties => additionalProps;
440-
protected override Task<object?> InvokeCoreAsync(IEnumerable<KeyValuePair<string, object?>> arguments, CancellationToken cancellationToken) =>
440+
protected override Task<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) =>
441441
throw new InvalidOperationException($"The AI function '{Name}' does not support being invoked.");
442442
}
443443

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public static async Task HandleToolCallsAsync(
107107

108108
try
109109
{
110-
var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false);
110+
var result = await aiFunction.InvokeAsync(new(functionCallContent.Arguments), cancellationToken).ConfigureAwait(false);
111111
var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object)));
112112
return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson);
113113
}

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs

+19-9
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,34 @@ public sealed class FunctionInvocationContext
2424
private AIFunction _function = _nopFunction;
2525

2626
/// <summary>The function call content information associated with this invocation.</summary>
27-
private FunctionCallContent _callContent = new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary<string, object?>.Instance);
27+
private FunctionCallContent? _callContent;
28+
29+
/// <summary>The arguments used with the function.</summary>
30+
private AIFunctionArguments? _arguments;
2831

2932
/// <summary>Initializes a new instance of the <see cref="FunctionInvocationContext"/> class.</summary>
3033
public FunctionInvocationContext()
3134
{
3235
}
3336

37+
/// <summary>Gets or sets the AI function to be invoked.</summary>
38+
public AIFunction Function
39+
{
40+
get => _function;
41+
set => _function = Throw.IfNull(value);
42+
}
43+
44+
/// <summary>Gets or sets the arguments associated with this invocation.</summary>
45+
public AIFunctionArguments Arguments
46+
{
47+
get => _arguments ??= [];
48+
set => _arguments = Throw.IfNull(value);
49+
}
50+
3451
/// <summary>Gets or sets the function call content information associated with this invocation.</summary>
3552
public FunctionCallContent CallContent
3653
{
37-
get => _callContent;
54+
get => _callContent ??= new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary<string, object?>.Instance);
3855
set => _callContent = Throw.IfNull(value);
3956
}
4057

@@ -48,13 +65,6 @@ public IList<ChatMessage> Messages
4865
/// <summary>Gets or sets the chat options associated with the operation that initiated this function call request.</summary>
4966
public ChatOptions? Options { get; set; }
5067

51-
/// <summary>Gets or sets the AI function to be invoked.</summary>
52-
public AIFunction Function
53-
{
54-
get => _function;
55-
set => _function = Throw.IfNull(value);
56-
}
57-
5868
/// <summary>Gets or sets the number of this iteration with the underlying client.</summary>
5969
/// <remarks>
6070
/// The initial request to the client that passes along the chat contents provided to the <see cref="FunctionInvokingChatClient"/>

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

+15-7
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
6262
/// Initializes a new instance of the <see cref="FunctionInvokingChatClient"/> class.
6363
/// </summary>
6464
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
65-
/// <param name="logger">An <see cref="ILogger"/> to use for logging information about function invocation.</param>
66-
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
65+
/// <param name="loggerFactory">An <see cref="ILoggerFactory"/> to use for logging information about function invocation.</param>
66+
/// <param name="services">An optional <see cref="IServiceProvider"/> to use for resolving services required by the <see cref="AIFunction"/> instances being invoked.</param>
67+
public FunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? services = null)
6768
: base(innerClient)
6869
{
69-
_logger = logger ?? NullLogger.Instance;
70+
_logger = (ILogger?)loggerFactory?.CreateLogger<FunctionInvokingChatClient>() ?? NullLogger.Instance;
7071
_activitySource = innerClient.GetService<ActivitySource>();
72+
Services = services;
7173
}
7274

7375
/// <summary>
@@ -82,6 +84,9 @@ public static FunctionInvocationContext? CurrentContext
8284
protected set => _currentContext.Value = value;
8385
}
8486

87+
/// <summary>Gets the <see cref="IServiceProvider"/> used for resolving services required by the <see cref="AIFunction"/> instances being invoked.</summary>
88+
public IServiceProvider? Services { get; }
89+
8590
/// <summary>
8691
/// Gets or sets a value indicating whether to handle exceptions that occur during function calls.
8792
/// </summary>
@@ -601,10 +606,13 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
601606

602607
FunctionInvocationContext context = new()
603608
{
609+
Function = function,
610+
Arguments = new(callContent.Arguments) { Services = Services },
611+
604612
Messages = messages,
605613
Options = options,
614+
606615
CallContent = callContent,
607-
Function = function,
608616
Iteration = iteration,
609617
FunctionCallIndex = functionCallIndex,
610618
FunctionCount = callContents.Count,
@@ -710,7 +718,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
710718
startingTimestamp = Stopwatch.GetTimestamp();
711719
if (_logger.IsEnabled(LogLevel.Trace))
712720
{
713-
LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.CallContent.Arguments, context.Function.JsonSerializerOptions));
721+
LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions));
714722
}
715723
else
716724
{
@@ -721,8 +729,8 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
721729
object? result = null;
722730
try
723731
{
724-
CurrentContext = context;
725-
result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
732+
CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
733+
result = await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false);
726734
}
727735
catch (Exception e)
728736
{

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public static ChatClientBuilder UseFunctionInvocation(
3333
{
3434
loggerFactory ??= services.GetService<ILoggerFactory>();
3535

36-
var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
36+
var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory, services);
3737
configure?.Invoke(chatClient);
3838
return chatClient;
3939
});

0 commit comments

Comments
 (0)