Skip to content

Commit 83efb55

Browse files
authored
feat(go/plugins/googlegenai): add image generation native support (#2630)
1 parent 1d9fe19 commit 83efb55

File tree

5 files changed

+202
-4
lines changed

5 files changed

+202
-4
lines changed

go/plugins/googlegenai/gemini.go

+52-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package googlegenai
1818

1919
import (
2020
"context"
21+
"encoding/base64"
2122
"encoding/json"
2223
"fmt"
2324
"net/http"
@@ -63,7 +64,7 @@ var (
6364

6465
// Attribution header
6566
xGoogApiClientHeader = http.CanonicalHeaderKey("x-goog-api-client")
66-
GenkitClientHeader = http.Header{
67+
genkitClientHeader = http.Header{
6768
xGoogApiClientHeader: {fmt.Sprintf("genkit-go/%s", internal.Version)},
6869
}
6970
)
@@ -174,6 +175,15 @@ type SafetySetting struct {
174175
Threshold HarmBlockThreshold `json:"threshold,omitempty"`
175176
}
176177

178+
// Modality names a kind of content that the model can be asked to return
// in its response.
type Modality string

// Response modalities accepted in GeminiConfig.ResponseModalities.
const (
	// TextMode indicates the model should return text.
	TextMode = Modality("TEXT")
	// ImageMode indicates the model should return images.
	ImageMode = Modality("IMAGE")
)
186+
177187
// GeminiConfig mirrors GenerateContentConfig without direct genai dependency
178188
type GeminiConfig struct {
179189
// MaxOutputTokens is the maximum number of tokens to generate.
@@ -192,6 +202,8 @@ type GeminiConfig struct {
192202
SafetySettings []*SafetySetting `json:"safetySettings,omitempty"`
193203
// CodeExecution is whether to allow executing of code generated by the model.
194204
CodeExecution bool `json:"codeExecution,omitempty"`
205+
// Response modalities for returned model messages
206+
ResponseModalities []Modality `json:"responseModalities,omitempty"`
195207
}
196208

197209
// configFromRequest converts any supported config type to [GeminiConfig].
@@ -333,6 +345,23 @@ func generate(
333345
return nil, err
334346
}
335347

348+
if len(config.ResponseModalities) > 0 {
349+
err := validateResponseModalities(model, config.ResponseModalities)
350+
if err != nil {
351+
return nil, err
352+
}
353+
for _, m := range config.ResponseModalities {
354+
gcc.ResponseModalities = append(gcc.ResponseModalities, string(m))
355+
}
356+
357+
// prevent an error in the client where:
358+
// if TEXT modality is not present and the model supports it, the client
359+
// will return an error
360+
if !slices.Contains(gcc.ResponseModalities, string(genai.ModalityText)) {
361+
gcc.ResponseModalities = append(gcc.ResponseModalities, string(genai.ModalityText))
362+
}
363+
}
364+
336365
var contents []*genai.Content
337366
for _, m := range input.Messages {
338367
// system parts are handled separately
@@ -523,6 +552,23 @@ func convertRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai.
523552
return &gcc, nil
524553
}
525554

555+
// validateResponseModalities checks if response modality is valid for the requested model
556+
func validateResponseModalities(model string, modalities []Modality) error {
557+
for _, m := range modalities {
558+
switch m {
559+
case ImageMode:
560+
if !slices.Contains(imageGenModels, model) {
561+
return fmt.Errorf("IMAGE response modality is not supported for model %q", model)
562+
}
563+
case TextMode:
564+
continue
565+
default:
566+
return fmt.Errorf("unknown response modality provided: %q", m)
567+
}
568+
}
569+
return nil
570+
}
571+
526572
// toGeminiTools translates a slice of [ai.ToolDefinition] to a slice of [genai.Tool].
527573
func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) {
528574
var outTools []*genai.Tool
@@ -724,7 +770,11 @@ func translateCandidate(cand *genai.Candidate) *ai.ModelResponse {
724770
}
725771
if part.InlineData != nil {
726772
partFound++
727-
p = ai.NewMediaPart(part.InlineData.MIMEType, string(part.InlineData.Data))
773+
p = ai.NewMediaPart(part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data))
774+
}
775+
if part.FileData != nil {
776+
partFound++
777+
p = ai.NewMediaPart(part.FileData.MIMEType, part.FileData.FileURI)
728778
}
729779
if part.FunctionCall != nil {
730780
partFound++

go/plugins/googlegenai/googleai_live_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,34 @@ func TestGoogleAILive(t *testing.T) {
288288
t.Fatalf("image detection failed, want: Mario Kart, got: %s", resp.Text())
289289
}
290290
})
291+
// Verifies that a model with native image generation returns at least one
// PNG media part when both IMAGE and TEXT response modalities are requested.
t.Run("image generation", func(t *testing.T) {
	m := googlegenai.GoogleAIModel(g, "gemini-2.0-flash-exp")
	resp, err := genkit.Generate(ctx, g,
		ai.WithConfig(googlegenai.GeminiConfig{
			ResponseModalities: []googlegenai.Modality{googlegenai.ImageMode, googlegenai.TextMode},
		}),
		ai.WithMessages(
			ai.NewUserTextMessage("generate an image of a dog wearing a black tejana while playing the accordion"),
		),
		ai.WithModel(m),
	)
	if err != nil {
		t.Fatal(err)
	}
	if resp.Message == nil || len(resp.Message.Content) == 0 {
		t.Fatal("empty response")
	}

	// Both TEXT and IMAGE are requested, so the image is not guaranteed to be
	// the first part; locate the first media part instead of assuming index 0.
	mediaIdx := -1
	for i, p := range resp.Message.Content {
		if p.IsMedia() {
			mediaIdx = i
			break
		}
	}
	if mediaIdx < 0 {
		t.Fatalf("no media part found in %d response part(s)", len(resp.Message.Content))
	}

	part := resp.Message.Content[mediaIdx]
	if part.ContentType != "image/png" {
		t.Errorf("expecting image/png content type but got: %q", part.ContentType)
	}
	if part.Kind != ai.PartMedia {
		// %v rather than %q: %q would render an integer kind as a quoted rune.
		t.Errorf("expecting part to be Media type but got: %v", part.Kind)
	}
	if part.Text == "" {
		t.Errorf("media part has empty data")
	}
})
291319
t.Run("constrained generation", func(t *testing.T) {
292320
type outFormat struct {
293321
Country string

go/plugins/googlegenai/googlegenai.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func (ga *GoogleAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
8585
Backend: genai.BackendGeminiAPI,
8686
APIKey: apiKey,
8787
HTTPOptions: genai.HTTPOptions{
88-
Headers: GenkitClientHeader,
88+
Headers: genkitClientHeader,
8989
},
9090
}
9191

@@ -159,7 +159,7 @@ func (v *VertexAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
159159
Project: v.ProjectID,
160160
Location: v.Location,
161161
HTTPOptions: genai.HTTPOptions{
162-
Headers: GenkitClientHeader,
162+
Headers: genkitClientHeader,
163163
},
164164
}
165165

go/plugins/googlegenai/models.go

+13
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ const (
1616
gemini15Flash8b = "gemini-1.5-flash-8b"
1717

1818
gemini20Flash = "gemini-2.0-flash"
19+
gemini20FlashExp = "gemini-2.0-flash-exp"
1920
gemini20FlashLite = "gemini-2.0-flash-lite"
2021
gemini20FlashLitePrev = "gemini-2.0-flash-lite-preview"
2122
gemini20ProExp0205 = "gemini-2.0-pro-exp-02-05"
@@ -45,13 +46,19 @@ var (
4546
gemini15Pro,
4647
gemini15Flash8b,
4748
gemini20Flash,
49+
gemini20FlashExp,
4850
gemini20FlashLitePrev,
4951
gemini20ProExp0205,
5052
gemini20FlashThinkingExp0121,
5153
gemini25ProExp0325,
5254
gemini25ProPreview0325,
5355
}
5456

57+
// models with native image generation support
58+
imageGenModels = []string{
59+
gemini20FlashExp,
60+
}
61+
5562
supportedGeminiModels = map[string]ai.ModelInfo{
5663
gemini15Flash: {
5764
Label: "Gemini 1.5 Flash",
@@ -90,6 +97,12 @@ var (
9097
Supports: &Multimodal,
9198
Stage: ai.ModelStageStable,
9299
},
100+
gemini20FlashExp: {
101+
Label: "Gemini 2.0 Flash Exp",
102+
Versions: []string{},
103+
Supports: &Multimodal,
104+
Stage: ai.ModelStageUnstable,
105+
},
93106
gemini20FlashLite: {
94107
Label: "Gemini 2.0 Flash Lite",
95108
Versions: []string{

go/samples/imagen-gemini/main.go

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"context"
19+
"encoding/base64"
20+
"errors"
21+
"fmt"
22+
"log"
23+
"os"
24+
25+
"github.com/firebase/genkit/go/ai"
26+
"github.com/firebase/genkit/go/genkit"
27+
"github.com/firebase/genkit/go/plugins/googlegenai"
28+
)
29+
30+
func main() {
31+
ctx := context.Background()
32+
33+
// Initialize Genkit with the Google AI plugin. When you pass nil for the
34+
// Config parameter, the Google AI plugin will get the API key from the
35+
// GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended
36+
// practice.
37+
g, err := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}))
38+
if err != nil {
39+
log.Fatal(err)
40+
}
41+
42+
// Define a simple flow that generates an image of a given topic
43+
genkit.DefineFlow(g, "imageFlow", func(ctx context.Context, input string) (string, error) {
44+
m := googlegenai.GoogleAIModel(g, "gemini-2.0-flash-exp")
45+
if m == nil {
46+
return "", errors.New("imageFlow: failed to find model")
47+
}
48+
49+
if input == "" {
50+
input = `A little blue gopher with big eyes trying to learn Python,
51+
use a cartoon style, the story should be tragic because he
52+
chose the wrong programming language, the proper programing
53+
language for a gopher should be Go`
54+
}
55+
resp, err := genkit.Generate(ctx, g,
56+
ai.WithModel(m),
57+
ai.WithConfig(&googlegenai.GeminiConfig{
58+
Temperature: 0.5,
59+
ResponseModalities: []googlegenai.Modality{
60+
googlegenai.ImageMode,
61+
googlegenai.TextMode,
62+
},
63+
}),
64+
ai.WithPrompt(fmt.Sprintf(`generate a story about %s and for each scene, generate an image for it`, input)))
65+
if err != nil {
66+
return "", err
67+
}
68+
69+
story := ""
70+
scene := 0
71+
for _, p := range resp.Message.Content {
72+
if p.IsMedia() {
73+
scene += 1
74+
err = base64toFile(p.Text, fmt.Sprintf("scene_%d.png", scene))
75+
}
76+
if p.IsText() {
77+
story += p.Text
78+
}
79+
}
80+
if err != nil {
81+
return "", err
82+
}
83+
84+
return story, nil
85+
})
86+
87+
<-ctx.Done()
88+
}
89+
90+
// base64toFile decodes data, which must be standard (padded) base64, and
// writes the decoded bytes to path, creating or truncating the file.
//
// The input is decoded before the file is created, so invalid base64 leaves
// the filesystem untouched. The file is synced before returning, and an error
// from closing the file is reported when no earlier error occurred.
func base64toFile(data, path string) (err error) {
	dec, err := base64.StdEncoding.DecodeString(data)
	if err != nil {
		return fmt.Errorf("decoding base64 data for %q: %w", path, err)
	}

	f, err := os.Create(path)
	if err != nil {
		return err
	}
	// Surface a Close failure instead of dropping it: for a file that was
	// just written, a failed close can mean the data never reached disk.
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if _, err = f.Write(dec); err != nil {
		return err
	}

	return f.Sync()
}

0 commit comments

Comments
 (0)