diff --git a/internal/extproc/translator/gemini_helper.go b/internal/extproc/translator/gemini_helper.go index 325d891bb..46e1fe8e3 100644 --- a/internal/extproc/translator/gemini_helper.go +++ b/internal/extproc/translator/gemini_helper.go @@ -29,6 +29,17 @@ const ( httpHeaderKeyContentLength = "Content-Length" ) +// geminiResponseMode represents the type of response mode for Gemini requests +type geminiResponseMode string + +const ( + responseModeNone geminiResponseMode = "NONE" + responseModeText geminiResponseMode = "TEXT" + responseModeJSON geminiResponseMode = "JSON" + responseModeEnum geminiResponseMode = "ENUM" + responseModeRegex geminiResponseMode = "REGEX" +) + // ------------------------------------------------------------- // Request Conversion Helper for OpenAI to GCP Gemini Translator // -------------------------------------------------------------. @@ -382,7 +393,8 @@ func openAIToolChoiceToGeminiToolConfig(toolChoice *openai.ChatCompletionToolCho } // openAIReqToGeminiGenerationConfig converts OpenAI request to Gemini GenerationConfig. -func openAIReqToGeminiGenerationConfig(openAIReq *openai.ChatCompletionRequest) (*genai.GenerationConfig, error) { +func openAIReqToGeminiGenerationConfig(openAIReq *openai.ChatCompletionRequest) (*genai.GenerationConfig, geminiResponseMode, error) { + responseMode := responseModeNone gc := &genai.GenerationConfig{} if openAIReq.Temperature != nil { f := float32(*openAIReq.Temperature) @@ -407,46 +419,66 @@ func openAIReqToGeminiGenerationConfig(openAIReq *openai.ChatCompletionRequest) gc.ResponseLogprobs = *openAIReq.LogProbs } + formatSpecifiedCount := 0 + if openAIReq.ResponseFormat != nil { + formatSpecifiedCount++ switch { case openAIReq.ResponseFormat.OfText != nil: + responseMode = responseModeText gc.ResponseMIMEType = mimeTypeTextPlain case openAIReq.ResponseFormat.OfJSONObject != nil: + responseMode = responseModeJSON gc.ResponseMIMEType = mimeTypeApplicationJSON case openAIReq.ResponseFormat.OfJSONSchema != nil: var schemaMap map[string]any if err := json.Unmarshal([]byte(openAIReq.ResponseFormat.OfJSONSchema.JSONSchema.Schema), &schemaMap); err != nil { - return nil, fmt.Errorf("invalid JSON schema: %w", err) + return nil, responseMode, fmt.Errorf("invalid JSON schema: %w", err) } + responseMode = responseModeJSON + gc.ResponseMIMEType = mimeTypeApplicationJSON gc.ResponseJsonSchema = schemaMap } } if openAIReq.GuidedChoice != nil { + formatSpecifiedCount++ if existSchema := gc.ResponseSchema != nil || gc.ResponseJsonSchema != nil; existSchema { - return nil, fmt.Errorf("duplicate json scheme specifications") + return nil, responseMode, fmt.Errorf("duplicate json scheme specifications") } + responseMode = responseModeEnum gc.ResponseMIMEType = mimeTypeApplicationEnum gc.ResponseSchema = &genai.Schema{Type: "STRING", Enum: openAIReq.GuidedChoice} } if openAIReq.GuidedRegex != "" { + formatSpecifiedCount++ if existSchema := gc.ResponseSchema != nil || gc.ResponseJsonSchema != nil; existSchema { - return nil, fmt.Errorf("duplicate json scheme specifications") + return nil, responseMode, fmt.Errorf("duplicate json scheme specifications") } + responseMode = responseModeRegex gc.ResponseMIMEType = mimeTypeApplicationJSON gc.ResponseSchema = &genai.Schema{Type: "STRING", Pattern: openAIReq.GuidedRegex} } if openAIReq.GuidedJSON != nil { + formatSpecifiedCount++ if existSchema := gc.ResponseSchema != nil || gc.ResponseJsonSchema != nil; existSchema { - return nil, fmt.Errorf("duplicate json scheme specifications") + return nil, responseMode, fmt.Errorf("duplicate json scheme specifications") } + responseMode = responseModeJSON + gc.ResponseMIMEType = mimeTypeApplicationJSON gc.ResponseJsonSchema = openAIReq.GuidedJSON } + // ResponseFormat and guidedJSON/guidedChoice/guidedRegex are mutually exclusive. + // Verify only one is specified. + if formatSpecifiedCount > 1 { + return nil, responseMode, fmt.Errorf("multiple format specifiers specified. only one of responseFormat, guidedChoice, guidedRegex, guidedJSON can be specified") + } + if openAIReq.N != nil { gc.CandidateCount = int32(*openAIReq.N) // nolint:gosec } @@ -464,7 +496,7 @@ func openAIReqToGeminiGenerationConfig(openAIReq *openai.ChatCompletionRequest) } else if openAIReq.Stop.OfStringArray != nil { gc.StopSequences = openAIReq.Stop.OfStringArray } - return gc, nil + return gc, responseMode, nil } // -------------------------------------------------------------- @@ -472,7 +504,7 @@ func openAIReqToGeminiGenerationConfig(openAIReq *openai.ChatCompletionRequest) // --------------------------------------------------------------. // geminiCandidatesToOpenAIChoices converts Gemini candidates to OpenAI choices. -func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate) ([]openai.ChatCompletionResponseChoice, error) { +func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode geminiResponseMode) ([]openai.ChatCompletionResponseChoice, error) { choices := make([]openai.ChatCompletionResponseChoice, 0, len(candidates)) for idx, candidate := range candidates { @@ -491,7 +523,7 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate) ([]openai.Ch Role: openai.ChatMessageRoleAssistant, } // Extract text from parts. - content := extractTextFromGeminiParts(candidate.Content.Parts) + content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode) message.Content = &content // Extract tool calls if any. @@ -545,10 +577,18 @@ func geminiFinishReasonToOpenAI(reason genai.FinishReason) openai.ChatCompletion } // extractTextFromGeminiParts extracts text from Gemini parts. -func extractTextFromGeminiParts(parts []*genai.Part) string { +func extractTextFromGeminiParts(parts []*genai.Part, responseMode geminiResponseMode) string { var text string for _, part := range parts { if part != nil && part.Text != "" { + if responseMode == responseModeRegex { + // GCP doesn't natively support REGEX response modes, so we instead express them as json schema. + // This causes the response to be wrapped in double-quotes. + // E.g. `"positive"` (the double-quotes at the start and end are unwanted) + // Here we remove the wrapping double-quotes. + part.Text = strings.TrimPrefix(part.Text, "\"") + part.Text = strings.TrimSuffix(part.Text, "\"") + } text += part.Text } } @@ -665,7 +705,7 @@ func buildGCPModelPathSuffix(publisher, model, gcpMethod string, queryParams ... } // geminiCandidatesToOpenAIStreamingChoices converts Gemini candidates to OpenAI streaming choices. -func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate) ([]openai.ChatCompletionResponseChunkChoice, error) { +func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate, responseMode geminiResponseMode) ([]openai.ChatCompletionResponseChunkChoice, error) { choices := make([]openai.ChatCompletionResponseChunkChoice, 0, len(candidates)) for _, candidate := range candidates { @@ -685,7 +725,7 @@ func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate) ([] } // Extract text from parts for streaming (delta). - content := extractTextFromGeminiParts(candidate.Content.Parts) + content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode) if content != "" { delta.Content = &content } diff --git a/internal/extproc/translator/gemini_helper_test.go b/internal/extproc/translator/gemini_helper_test.go index 3263e2005..8619c97cd 100644 --- a/internal/extproc/translator/gemini_helper_test.go +++ b/internal/extproc/translator/gemini_helper_test.go @@ -725,6 +725,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { name string input *openai.ChatCompletionRequest expectedGenerationConfig *genai.GenerationConfig + expectedResponseMode geminiResponseMode expectedErrMsg string }{ { @@ -755,11 +756,13 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { FrequencyPenalty: ptr.To(float32(0.5)), StopSequences: []string{"stop1", "stop2"}, }, + expectedResponseMode: responseModeNone, }, { name: "minimal fields", input: &openai.ChatCompletionRequest{}, expectedGenerationConfig: &genai.GenerationConfig{}, + expectedResponseMode: responseModeNone, }, { name: "stop sequences", @@ -771,6 +774,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { expectedGenerationConfig: &genai.GenerationConfig{ StopSequences: []string{"stop1"}, }, + expectedResponseMode: responseModeNone, }, { name: "text", @@ -782,6 +786,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { }, }, expectedGenerationConfig: &genai.GenerationConfig{ResponseMIMEType: "text/plain"}, + expectedResponseMode: responseModeText, }, { name: "json object", @@ -793,6 +798,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { }, }, expectedGenerationConfig: &genai.GenerationConfig{ResponseMIMEType: "application/json"}, + expectedResponseMode: responseModeJSON, }, { name: "json schema (map)", @@ -810,6 +816,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { ResponseMIMEType: "application/json", ResponseJsonSchema: map[string]any{"type": "string"}, }, + expectedResponseMode: responseModeJSON, }, { name: "json schema (string)", @@ -827,6 +834,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { ResponseMIMEType: "application/json", ResponseJsonSchema: map[string]any{"type": "string"}, }, + expectedResponseMode: responseModeJSON, }, { name: "json schema (invalid string)", @@ -851,6 +859,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { ResponseMIMEType: "text/x.enum", ResponseSchema: &genai.Schema{Type: "STRING", Enum: []string{"Positive", "Negative"}}, }, + expectedResponseMode: responseModeEnum, }, { name: "guided regex", @@ -861,6 +870,7 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { ResponseMIMEType: "application/json", ResponseSchema: &genai.Schema{Type: "STRING", Pattern: "\\w+@\\w+\\.com\\n"}, }, + expectedResponseMode: responseModeRegex, }, { name: "guided json", @@ -871,12 +881,25 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { ResponseMIMEType: "application/json", ResponseJsonSchema: json.RawMessage(`{"type": "string"}`), }, + expectedResponseMode: responseModeJSON, + }, + { + name: "multiple format specifiers - ResponseFormat and GuidedChoice", + input: &openai.ChatCompletionRequest{ + ResponseFormat: &openai.ChatCompletionResponseFormatUnion{ + OfText: &openai.ChatCompletionResponseFormatTextParam{ + Type: openai.ChatCompletionResponseFormatTypeText, + }, + }, + GuidedChoice: []string{"A", "B"}, + }, + expectedErrMsg: "multiple format specifiers specified", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := openAIReqToGeminiGenerationConfig(tc.input) + got, responseMode, err := openAIReqToGeminiGenerationConfig(tc.input) if tc.expectedErrMsg != "" { require.ErrorContains(t, err, tc.expectedErrMsg) } else { @@ -885,6 +908,10 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { if diff := cmp.Diff(tc.expectedGenerationConfig, got, cmpopts.IgnoreUnexported(genai.GenerationConfig{})); diff != "" { t.Errorf("GenerationConfig mismatch (-want +got):\n%s", diff) } + + if responseMode != tc.expectedResponseMode { + t.Errorf("geminiResponseMode mismatch: got %v, want %v", responseMode, tc.expectedResponseMode) + } } }) } @@ -1388,3 +1415,77 @@ func TestGeminiFinishReasonToOpenAI(t *testing.T) { }) } } + +func TestExtractTextFromGeminiParts(t *testing.T) { + tests := []struct { + name string + parts []*genai.Part + responseMode geminiResponseMode + expected string + }{ + { + name: "nil parts", + parts: nil, + responseMode: responseModeNone, + expected: "", + }, + { + name: "empty parts", + parts: []*genai.Part{}, + responseMode: responseModeNone, + expected: "", + }, + { + name: "multiple text parts without regex mode", + parts: []*genai.Part{ + {Text: "Hello, "}, + {Text: "world!"}, + }, + responseMode: responseModeJSON, + expected: "Hello, world!", + }, + { + name: "regex mode with mixed quoted and unquoted text", + parts: []*genai.Part{ + {Text: `"positive"`}, + {Text: `unquoted`}, + {Text: `"negative"`}, + }, + responseMode: responseModeRegex, + expected: "positiveunquotednegative", + }, + { + name: "regex mode with only double-quoted first and last words", + parts: []*genai.Part{ + {Text: "\"\"ERROR\" Unable to connect to database \"DatabaseModule\"\""}, + }, + responseMode: responseModeRegex, + expected: "\"ERROR\" Unable to connect to database \"DatabaseModule\"", + }, + { + name: "non-regex mode with double-quoted text (should not remove quotes)", + parts: []*genai.Part{ + {Text: `"positive"`}, + }, + responseMode: responseModeJSON, + expected: `"positive"`, + }, + { + name: "regex mode with text containing internal quotes", + parts: []*genai.Part{ + {Text: `"He said \"hello\" to me"`}, + }, + responseMode: responseModeRegex, + expected: `He said \"hello\" to me`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := extractTextFromGeminiParts(tc.parts, tc.responseMode) + if result != tc.expected { + t.Errorf("extractTextFromGeminiParts() = %q, want %q", result, tc.expected) + } + }) + } +} diff --git a/internal/extproc/translator/openai_gcpvertexai.go b/internal/extproc/translator/openai_gcpvertexai.go index 35708e745..7e1fb8731 100644 --- a/internal/extproc/translator/openai_gcpvertexai.go +++ b/internal/extproc/translator/openai_gcpvertexai.go @@ -48,6 +48,7 @@ func NewChatCompletionOpenAIToGCPVertexAITranslator(modelNameOverride internalap // Note: This uses the Gemini native API directly, not Vertex AI's OpenAI-compatible API: // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference type openAIToGCPVertexAITranslatorV1ChatCompletion struct { + responseMode geminiResponseMode modelNameOverride internalapi.ModelNameOverride stream bool // Track if this is a streaming request. bufferedBody []byte // Buffer for incomplete JSON chunks. @@ -244,7 +245,7 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) parseGCPStreamingChunks( // convertGCPChunkToOpenAI converts a GCP streaming chunk to OpenAI streaming format. func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) convertGCPChunkToOpenAI(chunk genai.GenerateContentResponse) *openai.ChatCompletionResponseChunk { // Convert candidates to OpenAI choices for streaming. - choices, err := geminiCandidatesToOpenAIStreamingChoices(chunk.Candidates) + choices, err := geminiCandidatesToOpenAIStreamingChoices(chunk.Candidates, o.responseMode) if err != nil { // For now, create empty choices on error to prevent breaking the stream. choices = []openai.ChatCompletionResponseChunkChoice{} @@ -284,10 +285,11 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) openAIMessageToGeminiMes } // Convert generation config. - generationConfig, err := openAIReqToGeminiGenerationConfig(openAIReq) + generationConfig, responseMode, err := openAIReqToGeminiGenerationConfig(openAIReq) if err != nil { return nil, fmt.Errorf("error converting generation config: %w", err) } + o.responseMode = responseMode gcr := gcp.GenerateContentRequest{ Contents: contents, @@ -330,7 +332,7 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) applyVendorSpecificField func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) geminiResponseToOpenAIMessage(gcr genai.GenerateContentResponse, responseModel string) (*openai.ChatCompletionResponse, error) { // Convert candidates to OpenAI choices. - choices, err := geminiCandidatesToOpenAIChoices(gcr.Candidates) + choices, err := geminiCandidatesToOpenAIChoices(gcr.Candidates, o.responseMode) if err != nil { return nil, fmt.Errorf("error converting choices: %w", err) }