Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,70 @@ public async Task IfAutoInvokeMaximumAutoInvokeAttemptsReachedShouldStopInvoking
c is GeminiChatMessageContent gm && gm.Role == AuthorRole.Tool && gm.CalledToolResult is not null);
}

[Fact]
public async Task ShouldBatchMultipleToolResponsesIntoSingleMessageAsync()
{
// Arrange
var responseContentWithMultipleFunctions = File.ReadAllText("./TestData/chat_multiple_function_calls_response.json")
.Replace("%nameSeparator%", GeminiFunction.NameSeparator, StringComparison.Ordinal);

using var handlerStub = new MultipleHttpMessageHandlerStub();
handlerStub.AddJsonResponse(responseContentWithMultipleFunctions);
handlerStub.AddJsonResponse(this._responseContent); // Final response after tool execution

#pragma warning disable CA2000
var client = this.CreateChatCompletionClient(httpClient: handlerStub.CreateHttpClient());
#pragma warning restore CA2000
var chatHistory = CreateSampleChatHistory();
var executionSettings = new GeminiPromptExecutionSettings
{
ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions
};

// Act
await client.GenerateChatMessageAsync(chatHistory, executionSettings: executionSettings, kernel: this._kernelWithFunctions);

// Assert
// Find the tool response message that should be batched
var toolResponseMessage = chatHistory.OfType<GeminiChatMessageContent>()
.FirstOrDefault(m => m.Role == AuthorRole.Tool && m.CalledToolResults != null);

Assert.NotNull(toolResponseMessage);
Assert.NotNull(toolResponseMessage.CalledToolResults);

// Verify that multiple tool results are batched into a single message
Assert.Equal(2, toolResponseMessage.CalledToolResults.Count);

// Verify the specific tool calls that were batched
var toolNames = toolResponseMessage.CalledToolResults.Select(tr => tr.FullyQualifiedName).ToArray();
Assert.Contains(this._timePluginNow.FullyQualifiedName, toolNames);
Assert.Contains(this._timePluginDate.FullyQualifiedName, toolNames);

// Verify backward compatibility - CalledToolResult property should return the first result
Assert.NotNull(toolResponseMessage.CalledToolResult);
Assert.Equal(toolResponseMessage.CalledToolResults[0], toolResponseMessage.CalledToolResult);

// Verify the request that would be sent to Gemini contains the correct structure
var requestJson = handlerStub.GetRequestContentAsString(1); // Get the second request (after tool execution)
Assert.NotNull(requestJson);
var request = JsonSerializer.Deserialize<GeminiRequest>(requestJson);
Assert.NotNull(request);

// Find the content that represents the batched tool responses
var toolResponseContent = request.Contents.FirstOrDefault(c => c.Role == AuthorRole.Tool);
Assert.NotNull(toolResponseContent);
Assert.NotNull(toolResponseContent.Parts);

// Verify that all function responses are included as separate parts in the single message
var functionResponseParts = toolResponseContent.Parts.Where(p => p.FunctionResponse != null).ToArray();
Assert.Equal(2, functionResponseParts.Length);

// Verify each function response part corresponds to the tool calls
var functionNames = functionResponseParts.Select(p => p.FunctionResponse!.FunctionName).ToArray();
Assert.Contains(this._timePluginNow.FullyQualifiedName, functionNames);
Assert.Contains(this._timePluginDate.FullyQualifiedName, functionNames);
}

private static ChatHistory CreateSampleChatHistory()
{
var chatHistory = new ChatHistory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ public void AddChatMessageToRequest()
// Arrange
ChatHistory chat = [];
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chat, new GeminiPromptExecutionSettings());
var message = new GeminiChatMessageContent(AuthorRole.User, "user-message", "model-id");
var message = new GeminiChatMessageContent(AuthorRole.User, "user-message", "model-id", calledToolResults: null);

// Act
request.AddChatMessage(message);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
{
"candidates": [
{
"content": {
"parts": [
{
"text": "I'll help you get the current time and date. Let me call both functions for you."
},
{
"functionCall": {
"name": "TimePlugin%nameSeparator%Now",
"args": {
"param1": "current time"
}
}
},
{
"functionCall": {
"name": "TimePlugin%nameSeparator%Date",
"args": {
"format": "yyyy-MM-dd"
}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
}
],
"promptFeedback": {
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
},
"usageMetadata": {
"promptTokenCount": 50,
"candidatesTokenCount": 25,
"totalTokenCount": 75
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,18 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation

// We must send back a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
// Collect all tool responses before adding to chat history
var toolResponses = new List<GeminiChatMessageContent>();

foreach (var toolCall in state.LastMessage!.ToolCalls!)
{
await this.ProcessSingleToolCallAsync(state, toolCall, cancellationToken).ConfigureAwait(false);
var toolResponse = await this.ProcessSingleToolCallAndReturnResponseAsync(state, toolCall, cancellationToken).ConfigureAwait(false);
toolResponses.Add(toolResponse);
}

// Add all tool responses as a single batched message
this.AddBatchedToolResponseMessage(state.ChatHistory, state.GeminiRequest, toolResponses);

// Clear the tools. If we end up wanting to use tools, we'll reset it to the desired value.
state.GeminiRequest.Tools = null;

Expand Down Expand Up @@ -431,6 +438,46 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
}
}

private void AddBatchedToolResponseMessage(
ChatHistory chat,
GeminiRequest request,
List<GeminiChatMessageContent> toolResponses)
{
if (toolResponses.Count == 0)
{
return;
}

// Extract all tool results and combine content
var allToolResults = toolResponses
.Where(tr => tr.CalledToolResults != null)
.SelectMany(tr => tr.CalledToolResults!)
.ToList();

// Combine tool response content as a JSON array for better structure
var combinedContentList = toolResponses
.Select(tr => tr.Content)
.Where(c => !string.IsNullOrEmpty(c))
.ToList();

var combinedContent = combinedContentList.Count switch
{
0 => string.Empty,
1 => combinedContentList[0],
_ => JsonSerializer.Serialize(combinedContentList)
};

// Create a single message with all function response parts using the new constructor
var batchedMessage = new GeminiChatMessageContent(
AuthorRole.Tool,
combinedContent,
this._modelId,
calledToolResults: allToolResults);

chat.Add(batchedMessage);
request.AddChatMessage(batchedMessage);
}

private async Task ProcessSingleToolCallAsync(ChatCompletionState state, GeminiFunctionToolCall toolCall, CancellationToken cancellationToken)
{
// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
Expand Down Expand Up @@ -480,6 +527,65 @@ private async Task ProcessSingleToolCallAsync(ChatCompletionState state, GeminiF
functionResponse: functionResult, errorMessage: null);
}

private async Task<GeminiChatMessageContent> ProcessSingleToolCallAndReturnResponseAsync(ChatCompletionState state, GeminiFunctionToolCall toolCall, CancellationToken cancellationToken)
{
// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
// then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able
// to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check.
if (state.ExecutionSettings.ToolCallBehavior?.AllowAnyRequestedKernelFunction is not true &&
!IsRequestableTool(state.GeminiRequest.Tools![0].Functions, toolCall))
{
return this.CreateToolResponseMessage(toolCall, functionResponse: null, "Error: Function call request for a function that wasn't defined.");
}

// Ensure the provided function exists for calling
if (!state.Kernel!.Plugins.TryGetFunctionAndArguments(toolCall, out KernelFunction? function, out KernelArguments? functionArgs))
{
return this.CreateToolResponseMessage(toolCall, functionResponse: null, "Error: Requested function could not be found.");
}

// Now, invoke the function, and create the resulting tool call message.
s_inflightAutoInvokes.Value++;
FunctionResult? functionResult;
try
{
// Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
// further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
// as the called function could in turn telling the model about itself as a possible candidate for invocation.
functionResult = await function.InvokeAsync(state.Kernel, functionArgs, cancellationToken: cancellationToken)
.ConfigureAwait(false);
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception e)
#pragma warning restore CA1031
{
return this.CreateToolResponseMessage(toolCall, functionResponse: null, $"Error: Exception while invoking function. {e.Message}");
}
finally
{
s_inflightAutoInvokes.Value--;
}

return this.CreateToolResponseMessage(toolCall, functionResponse: functionResult, errorMessage: null);
}

private GeminiChatMessageContent CreateToolResponseMessage(
GeminiFunctionToolCall tool,
FunctionResult? functionResponse,
string? errorMessage)
{
if (errorMessage is not null && this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug("Failed to handle tool request ({ToolName}). {Error}", tool.FullyQualifiedName, errorMessage);
}

return new GeminiChatMessageContent(AuthorRole.Tool,
content: errorMessage ?? string.Empty,
modelId: this._modelId,
calledToolResult: functionResponse is not null ? new GeminiFunctionToolResult(tool, functionResponse) : null,
metadata: null);
}

private async Task<GeminiResponse> SendRequestAndReturnValidGeminiResponseAsync(
Uri endpoint,
GeminiRequest geminiRequest,
Expand Down Expand Up @@ -604,7 +710,7 @@ private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)

private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
=> geminiResponse.Candidates == null ?
[new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
[new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId, functionsToolCalls: null)]
: geminiResponse.Candidates.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();

private GeminiChatMessageContent GetChatMessageContentFromCandidate(GeminiResponse geminiResponse, GeminiResponseCandidate candidate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,17 @@ private static List<GeminiPart> CreateGeminiParts(ChatMessageContent content)
List<GeminiPart> parts = [];
switch (content)
{
case GeminiChatMessageContent { CalledToolResult: not null } contentWithCalledTool:
parts.Add(new GeminiPart
{
FunctionResponse = new GeminiPart.FunctionResponsePart
case GeminiChatMessageContent { CalledToolResults: not null } contentWithCalledTools:
// Add all function responses as separate parts in a single message
parts.AddRange(contentWithCalledTools.CalledToolResults.Select(toolResult =>
new GeminiPart
{
FunctionName = contentWithCalledTool.CalledToolResult.FullyQualifiedName,
Response = new(contentWithCalledTool.CalledToolResult.FunctionResult.GetValue<object>())
}
});
FunctionResponse = new GeminiPart.FunctionResponsePart
{
FunctionName = toolResult.FullyQualifiedName,
Response = new(toolResult.FunctionResult.GetValue<object>())
}
}));
break;
case GeminiChatMessageContent { ToolCalls: not null } contentWithToolCalls:
parts.AddRange(contentWithToolCalls.ToolCalls.Select(toolCall =>
Expand Down
Loading
Loading