Skip to content

Commit 96c4701

Browse files
.Net: Feature/gemini function parts format (#13258)
### Motivation and Context 1. Why is this change required? It appears that there's a quirk to how Gemini handles tool call results in it's API. The tool responses need to be present in the same chat item as the tool requests - rather than adding them individually as separate chat history items. 2. What problem does it solve? This _appears_ to fix (at least in my use case) the error whereby you receive: ``` { "error": { "code": 400, "message": "Please ensure that the number of function response parts is equal to the number of function call parts of the function call turn.", "status": "INVALID_ARGUMENT" } } ``` 4. Which appears to happen when an SK response auto invokes > 1 tool. When it's a single tool result the API appears to handle the response gracefully. 5. #12823 6. #12528 --> ### Description Google Gemini expects the function responses to be grouped together in a single message rather than split across multiple separate messages with role "function". By adding a batch tool result overload to the GeminiChatMessage we can correctly parse the multiple parts of the function call into their respective function parts. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄 FYI @stephentoub (As I see you're working in the Gemini connector recently) --------- Co-authored-by: Roger Barreto <[email protected]>
1 parent 3434190 commit 96c4701

File tree

7 files changed

+368
-15
lines changed

7 files changed

+368
-15
lines changed

dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationFunctionCallingTests.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,70 @@ public async Task IfAutoInvokeMaximumAutoInvokeAttemptsReachedShouldStopInvoking
376376
c is GeminiChatMessageContent gm && gm.Role == AuthorRole.Tool && gm.CalledToolResult is not null);
377377
}
378378

379+
[Fact]
380+
public async Task ShouldBatchMultipleToolResponsesIntoSingleMessageAsync()
381+
{
382+
// Arrange
383+
var responseContentWithMultipleFunctions = File.ReadAllText("./TestData/chat_multiple_function_calls_response.json")
384+
.Replace("%nameSeparator%", GeminiFunction.NameSeparator, StringComparison.Ordinal);
385+
386+
using var handlerStub = new MultipleHttpMessageHandlerStub();
387+
handlerStub.AddJsonResponse(responseContentWithMultipleFunctions);
388+
handlerStub.AddJsonResponse(this._responseContent); // Final response after tool execution
389+
390+
#pragma warning disable CA2000
391+
var client = this.CreateChatCompletionClient(httpClient: handlerStub.CreateHttpClient());
392+
#pragma warning restore CA2000
393+
var chatHistory = CreateSampleChatHistory();
394+
var executionSettings = new GeminiPromptExecutionSettings
395+
{
396+
ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions
397+
};
398+
399+
// Act
400+
await client.GenerateChatMessageAsync(chatHistory, executionSettings: executionSettings, kernel: this._kernelWithFunctions);
401+
402+
// Assert
403+
// Find the tool response message that should be batched
404+
var toolResponseMessage = chatHistory.OfType<GeminiChatMessageContent>()
405+
.FirstOrDefault(m => m.Role == AuthorRole.Tool && m.CalledToolResults != null);
406+
407+
Assert.NotNull(toolResponseMessage);
408+
Assert.NotNull(toolResponseMessage.CalledToolResults);
409+
410+
// Verify that multiple tool results are batched into a single message
411+
Assert.Equal(2, toolResponseMessage.CalledToolResults.Count);
412+
413+
// Verify the specific tool calls that were batched
414+
var toolNames = toolResponseMessage.CalledToolResults.Select(tr => tr.FullyQualifiedName).ToArray();
415+
Assert.Contains(this._timePluginNow.FullyQualifiedName, toolNames);
416+
Assert.Contains(this._timePluginDate.FullyQualifiedName, toolNames);
417+
418+
// Verify backward compatibility - CalledToolResult property should return the first result
419+
Assert.NotNull(toolResponseMessage.CalledToolResult);
420+
Assert.Equal(toolResponseMessage.CalledToolResults[0], toolResponseMessage.CalledToolResult);
421+
422+
// Verify the request that would be sent to Gemini contains the correct structure
423+
var requestJson = handlerStub.GetRequestContentAsString(1); // Get the second request (after tool execution)
424+
Assert.NotNull(requestJson);
425+
var request = JsonSerializer.Deserialize<GeminiRequest>(requestJson);
426+
Assert.NotNull(request);
427+
428+
// Find the content that represents the batched tool responses
429+
var toolResponseContent = request.Contents.FirstOrDefault(c => c.Role == AuthorRole.Tool);
430+
Assert.NotNull(toolResponseContent);
431+
Assert.NotNull(toolResponseContent.Parts);
432+
433+
// Verify that all function responses are included as separate parts in the single message
434+
var functionResponseParts = toolResponseContent.Parts.Where(p => p.FunctionResponse != null).ToArray();
435+
Assert.Equal(2, functionResponseParts.Length);
436+
437+
// Verify each function response part corresponds to the tool calls
438+
var functionNames = functionResponseParts.Select(p => p.FunctionResponse!.FunctionName).ToArray();
439+
Assert.Contains(this._timePluginNow.FullyQualifiedName, functionNames);
440+
Assert.Contains(this._timePluginDate.FullyQualifiedName, functionNames);
441+
}
442+
379443
private static ChatHistory CreateSampleChatHistory()
380444
{
381445
var chatHistory = new ChatHistory();

dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ public void AddChatMessageToRequest()
491491
// Arrange
492492
ChatHistory chat = [];
493493
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chat, new GeminiPromptExecutionSettings());
494-
var message = new GeminiChatMessageContent(AuthorRole.User, "user-message", "model-id");
494+
var message = new GeminiChatMessageContent(AuthorRole.User, "user-message", "model-id", calledToolResults: null);
495495

496496
// Act
497497
request.AddChatMessage(message);
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
{
2+
"candidates": [
3+
{
4+
"content": {
5+
"parts": [
6+
{
7+
"text": "I'll help you get the current time and date. Let me call both functions for you."
8+
},
9+
{
10+
"functionCall": {
11+
"name": "TimePlugin%nameSeparator%Now",
12+
"args": {
13+
"param1": "current time"
14+
}
15+
}
16+
},
17+
{
18+
"functionCall": {
19+
"name": "TimePlugin%nameSeparator%Date",
20+
"args": {
21+
"format": "yyyy-MM-dd"
22+
}
23+
}
24+
}
25+
],
26+
"role": "model"
27+
},
28+
"finishReason": "STOP",
29+
"index": 0,
30+
"safetyRatings": [
31+
{
32+
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
33+
"probability": "NEGLIGIBLE"
34+
},
35+
{
36+
"category": "HARM_CATEGORY_HATE_SPEECH",
37+
"probability": "NEGLIGIBLE"
38+
},
39+
{
40+
"category": "HARM_CATEGORY_HARASSMENT",
41+
"probability": "NEGLIGIBLE"
42+
},
43+
{
44+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
45+
"probability": "NEGLIGIBLE"
46+
}
47+
]
48+
}
49+
],
50+
"promptFeedback": {
51+
"safetyRatings": [
52+
{
53+
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
54+
"probability": "NEGLIGIBLE"
55+
},
56+
{
57+
"category": "HARM_CATEGORY_HATE_SPEECH",
58+
"probability": "NEGLIGIBLE"
59+
},
60+
{
61+
"category": "HARM_CATEGORY_HARASSMENT",
62+
"probability": "NEGLIGIBLE"
63+
},
64+
{
65+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
66+
"probability": "NEGLIGIBLE"
67+
}
68+
]
69+
},
70+
"usageMetadata": {
71+
"promptTokenCount": 50,
72+
"candidatesTokenCount": 25,
73+
"totalTokenCount": 75
74+
}
75+
}

dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,18 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
395395

396396
// We must send back a response for every tool call, regardless of whether we successfully executed it or not.
397397
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
398+
// Collect all tool responses before adding to chat history
399+
var toolResponses = new List<GeminiChatMessageContent>();
400+
398401
foreach (var toolCall in state.LastMessage!.ToolCalls!)
399402
{
400-
await this.ProcessSingleToolCallAsync(state, toolCall, cancellationToken).ConfigureAwait(false);
403+
var toolResponse = await this.ProcessSingleToolCallAndReturnResponseAsync(state, toolCall, cancellationToken).ConfigureAwait(false);
404+
toolResponses.Add(toolResponse);
401405
}
402406

407+
// Add all tool responses as a single batched message
408+
this.AddBatchedToolResponseMessage(state.ChatHistory, state.GeminiRequest, toolResponses);
409+
403410
// Clear the tools. If we end up wanting to use tools, we'll reset it to the desired value.
404411
state.GeminiRequest.Tools = null;
405412

@@ -431,6 +438,46 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
431438
}
432439
}
433440

441+
private void AddBatchedToolResponseMessage(
442+
ChatHistory chat,
443+
GeminiRequest request,
444+
List<GeminiChatMessageContent> toolResponses)
445+
{
446+
if (toolResponses.Count == 0)
447+
{
448+
return;
449+
}
450+
451+
// Extract all tool results and combine content
452+
var allToolResults = toolResponses
453+
.Where(tr => tr.CalledToolResults != null)
454+
.SelectMany(tr => tr.CalledToolResults!)
455+
.ToList();
456+
457+
// Combine tool response content as a JSON array for better structure
458+
var combinedContentList = toolResponses
459+
.Select(tr => tr.Content)
460+
.Where(c => !string.IsNullOrEmpty(c))
461+
.ToList();
462+
463+
var combinedContent = combinedContentList.Count switch
464+
{
465+
0 => string.Empty,
466+
1 => combinedContentList[0],
467+
_ => JsonSerializer.Serialize(combinedContentList)
468+
};
469+
470+
// Create a single message with all function response parts using the new constructor
471+
var batchedMessage = new GeminiChatMessageContent(
472+
AuthorRole.Tool,
473+
combinedContent,
474+
this._modelId,
475+
calledToolResults: allToolResults);
476+
477+
chat.Add(batchedMessage);
478+
request.AddChatMessage(batchedMessage);
479+
}
480+
434481
private async Task ProcessSingleToolCallAsync(ChatCompletionState state, GeminiFunctionToolCall toolCall, CancellationToken cancellationToken)
435482
{
436483
// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
@@ -480,6 +527,65 @@ private async Task ProcessSingleToolCallAsync(ChatCompletionState state, GeminiF
480527
functionResponse: functionResult, errorMessage: null);
481528
}
482529

530+
private async Task<GeminiChatMessageContent> ProcessSingleToolCallAndReturnResponseAsync(ChatCompletionState state, GeminiFunctionToolCall toolCall, CancellationToken cancellationToken)
531+
{
532+
// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
533+
// then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able
534+
// to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check.
535+
if (state.ExecutionSettings.ToolCallBehavior?.AllowAnyRequestedKernelFunction is not true &&
536+
!IsRequestableTool(state.GeminiRequest.Tools![0].Functions, toolCall))
537+
{
538+
return this.CreateToolResponseMessage(toolCall, functionResponse: null, "Error: Function call request for a function that wasn't defined.");
539+
}
540+
541+
// Ensure the provided function exists for calling
542+
if (!state.Kernel!.Plugins.TryGetFunctionAndArguments(toolCall, out KernelFunction? function, out KernelArguments? functionArgs))
543+
{
544+
return this.CreateToolResponseMessage(toolCall, functionResponse: null, "Error: Requested function could not be found.");
545+
}
546+
547+
// Now, invoke the function, and create the resulting tool call message.
548+
s_inflightAutoInvokes.Value++;
549+
FunctionResult? functionResult;
550+
try
551+
{
552+
// Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
553+
// further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
554+
// as the called function could in turn telling the model about itself as a possible candidate for invocation.
555+
functionResult = await function.InvokeAsync(state.Kernel, functionArgs, cancellationToken: cancellationToken)
556+
.ConfigureAwait(false);
557+
}
558+
#pragma warning disable CA1031 // Do not catch general exception types
559+
catch (Exception e)
560+
#pragma warning restore CA1031
561+
{
562+
return this.CreateToolResponseMessage(toolCall, functionResponse: null, $"Error: Exception while invoking function. {e.Message}");
563+
}
564+
finally
565+
{
566+
s_inflightAutoInvokes.Value--;
567+
}
568+
569+
return this.CreateToolResponseMessage(toolCall, functionResponse: functionResult, errorMessage: null);
570+
}
571+
572+
private GeminiChatMessageContent CreateToolResponseMessage(
573+
GeminiFunctionToolCall tool,
574+
FunctionResult? functionResponse,
575+
string? errorMessage)
576+
{
577+
if (errorMessage is not null && this.Logger.IsEnabled(LogLevel.Debug))
578+
{
579+
this.Logger.LogDebug("Failed to handle tool request ({ToolName}). {Error}", tool.FullyQualifiedName, errorMessage);
580+
}
581+
582+
return new GeminiChatMessageContent(AuthorRole.Tool,
583+
content: errorMessage ?? string.Empty,
584+
modelId: this._modelId,
585+
calledToolResult: functionResponse is not null ? new GeminiFunctionToolResult(tool, functionResponse) : null,
586+
metadata: null);
587+
}
588+
483589
private async Task<GeminiResponse> SendRequestAndReturnValidGeminiResponseAsync(
484590
Uri endpoint,
485591
GeminiRequest geminiRequest,
@@ -604,7 +710,7 @@ private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
604710

605711
private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
606712
=> geminiResponse.Candidates == null ?
607-
[new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
713+
[new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId, functionsToolCalls: null)]
608714
: geminiResponse.Candidates.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();
609715

610716
private GeminiChatMessageContent GetChatMessageContentFromCandidate(GeminiResponse geminiResponse, GeminiResponseCandidate candidate)

dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,17 @@ private static List<GeminiPart> CreateGeminiParts(ChatMessageContent content)
183183
List<GeminiPart> parts = [];
184184
switch (content)
185185
{
186-
case GeminiChatMessageContent { CalledToolResult: not null } contentWithCalledTool:
187-
parts.Add(new GeminiPart
188-
{
189-
FunctionResponse = new GeminiPart.FunctionResponsePart
186+
case GeminiChatMessageContent { CalledToolResults: not null } contentWithCalledTools:
187+
// Add all function responses as separate parts in a single message
188+
parts.AddRange(contentWithCalledTools.CalledToolResults.Select(toolResult =>
189+
new GeminiPart
190190
{
191-
FunctionName = contentWithCalledTool.CalledToolResult.FullyQualifiedName,
192-
Response = new(contentWithCalledTool.CalledToolResult.FunctionResult.GetValue<object>())
193-
}
194-
});
191+
FunctionResponse = new GeminiPart.FunctionResponsePart
192+
{
193+
FunctionName = toolResult.FullyQualifiedName,
194+
Response = new(toolResult.FunctionResult.GetValue<object>())
195+
}
196+
}));
195197
break;
196198
case GeminiChatMessageContent { ToolCalls: not null } contentWithToolCalls:
197199
parts.AddRange(contentWithToolCalls.ToolCalls.Select(toolCall =>

0 commit comments

Comments
 (0)