
Commit 27b8eeb

Merge pull request GoogleCloudPlatform#3644 from justinsb/llm_completion_method
codebot: add support for completion method
2 parents: 384834c + 11db6f9

6 files changed: +315 -23 lines

dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go (+2 -1)

@@ -139,11 +139,12 @@ func RunPrompt(ctx context.Context, o *PromptOptions) error {
     }
 
     dataPoint := dataPoints[0]
+    dataPoint.Output = ""
 
     log.Info("built data point", "dataPoint", dataPoint)
 
     out := &bytes.Buffer{}
-    if err := x.RunGemini(ctx, dataPoint, out); err != nil {
+    if err := x.InferOutput_WithCompletion(ctx, dataPoint, out); err != nil {
        return fmt.Errorf("running LLM inference: %w", err)
     }
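
The call site now clears dataPoint.Output before inference, so the model must produce the output rather than being shown the expected answer, and it swaps the chat-based RunGemini for a completion-based helper. InferOutput_WithCompletion itself is not part of this diff; the following is a minimal sketch of what such a helper could look like on top of the new llm.Client.GenerateCompletion API, where the receiver type, the llmClient field, and the buildPrompt helper are assumptions, not the actual implementation:

// Hypothetical sketch; the real InferOutput_WithCompletion lives outside this diff.
func (x *Executor) InferOutput_WithCompletion(ctx context.Context, dataPoint *DataPoint, out io.Writer) error {
    // Render the data point (with Output cleared) into a single prompt string.
    prompt := buildPrompt(dataPoint) // hypothetical prompt-rendering helper

    response, err := x.llmClient.GenerateCompletion(ctx, &llm.CompletionRequest{Prompt: prompt})
    if err != nil {
        return fmt.Errorf("generating completion: %w", err)
    }

    // Write the raw completion text to the supplied writer.
    _, err = io.WriteString(out, response.Response())
    return err
}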

dev/tools/controllerbuilder/pkg/llm/gemini.go (+7 -1)

@@ -39,19 +39,25 @@ func BuildGeminiClient(ctx context.Context) (Client, error) {
 
     return &GeminiClient{
        client: client,
+       model:  "gemini-2.0-pro-exp-02-05",
     }, nil
 }
 
 type GeminiClient struct {
     client *genai.Client
+    model  string
 }
 
 func (c *GeminiClient) Close() error {
     return c.client.Close()
 }
 
+func (c *GeminiClient) GenerateCompletion(ctx context.Context, request *CompletionRequest) (CompletionResponse, error) {
+    return nil, fmt.Errorf("GeminiClient::GenerateCompletion not implemented")
+}
+
 func (c *GeminiClient) StartChat(systemPrompt string) Chat {
-    model := c.client.GenerativeModel("gemini-2.0-flash-exp")
+    model := c.client.GenerativeModel(c.model)
     // model := c.client.GenerativeModel("gemini-1.5-pro-002")
 
     // Some values that are recommended by aistudio
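
Hoisting the model name into a struct field set at construction replaces the hard-coded "gemini-2.0-flash-exp" in StartChat, and the stubbed GenerateCompletion keeps GeminiClient satisfying the widened Client interface. Because BuildGeminiClient already returns the Client interface type, the compiler checks this implicitly; an explicit compile-time assertion, a pattern this PR uses elsewhere (for example on OllamaCompletionResponse), would make the intent visible at the type itself:

// Compile-time check that GeminiClient implements the widened Client
// interface; the build fails if GenerateCompletion (or any method) is missing.
var _ Client = &GeminiClient{}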

dev/tools/controllerbuilder/pkg/llm/interfaces.go (+11)

@@ -22,6 +22,8 @@ import (
 type Client interface {
     io.Closer
     StartChat(systemPrompt string) Chat
+
+    GenerateCompletion(ctx context.Context, req *CompletionRequest) (CompletionResponse, error)
 }
 
 type Chat interface {
@@ -44,3 +46,12 @@ type Part interface {
     AsText() (string, bool)
     AsFunctionCalls() ([]FunctionCall, bool)
 }
+
+type CompletionRequest struct {
+    Prompt string
+}
+
+type CompletionResponse interface {
+    Response() string
+    UsageMetadata() any
+}
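
CompletionRequest is a plain struct, leaving room to add knobs such as temperature later, while CompletionResponse is an interface so each backend can expose its own raw response through UsageMetadata. A minimal caller-side sketch, assuming the Ollama backend (the only fully implemented one in this PR) and an illustrative prompt:

ctx := context.Background()

client, err := BuildOllamaClient(ctx) // VertexAI also implements this; Gemini is stubbed
if err != nil {
    klog.Fatalf("building client: %v", err)
}
defer client.Close()

resp, err := client.GenerateCompletion(ctx, &CompletionRequest{
    Prompt: "Write a one-line description of a Kubernetes controller.",
})
if err != nil {
    klog.Fatalf("generating completion: %v", err)
}
fmt.Println(resp.Response())              // the completion text
fmt.Printf("%+v\n", resp.UsageMetadata()) // backend-specific stats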

dev/tools/controllerbuilder/pkg/llm/ollama.go (+133 -7)

@@ -41,29 +41,32 @@ func BuildOllamaClient(ctx context.Context) (*OllamaClient, error) {
     }
     klog.Infof("using ollama with base url %v", baseURL.String())
 
+    model := os.Getenv("OLLAMA_MODEL")
+    if model == "" {
+       klog.Fatalf("OLLAMA_MODEL not set")
+    }
+
     return &OllamaClient{
        baseURL:    baseURL,
        httpClient: http.DefaultClient,
+       model:      model,
     }, nil
 }
 
 type OllamaClient struct {
     baseURL    *url.URL
     httpClient *http.Client
+    model      string
 }
 
 func (c *OllamaClient) Close() error {
     return nil
 }
 
 func (c *OllamaClient) StartChat(systemPrompt string) Chat {
-    session := &chatRequest{}
-
-    model := os.Getenv("OLLAMA_MODEL")
-    if model == "" {
-       klog.Fatalf("OLLAMA_MODEL not set")
+    session := &chatRequest{
+       Model: c.model,
     }
-    session.Model = model
 
     // HACK: Setting the system prompt seems to really mess up some ollama models
     // session.Messages = append(session.Messages, chatMessage{
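
Reading OLLAMA_MODEL once in BuildOllamaClient makes a missing variable fail at construction time instead of on the first StartChat call, and the chat and completion paths now share the same model field. Configuration stays environment-driven; illustratively (the model tag is an example, not a project default):

os.Setenv("OLLAMA_MODEL", "llama3.1:8b") // normally set in the shell, shown here for completeness

client, err := BuildOllamaClient(ctx)
if err != nil {
    klog.Fatalf("building ollama client: %v", err)
}
defer client.Close()
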
@@ -114,6 +117,51 @@ type chatResponse struct {
     EvalDuration int64 `json:"eval_duration"`
 }
 
+type completionRequest struct {
+    // model: (required) the model name
+    Model string `json:"model,omitempty"`
+    // prompt: the prompt to generate a response for
+    Prompt string `json:"prompt,omitempty"`
+
+    // suffix: the text after the model response
+
+    // images: (optional) a list of base64-encoded images (for multimodal models such as llava)
+
+    // format: the format to return a response in. Format can be json or a JSON schema
+
+    // options: additional model parameters listed in the documentation for the Modelfile such as temperature
+    Options map[string]any `json:"options,omitempty"`
+
+    // system: system message to (overrides what is defined in the Modelfile)
+
+    // template: the prompt template to use (overrides what is defined in the Modelfile)
+
+    // stream: if false the response will be returned as a single response object, rather than a stream of objects
+    Stream *bool `json:"stream,omitempty"`
+
+    // raw: if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
+
+    // keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m)
+
+    // context (deprecated): the context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory
+}
+
+type completionResponse struct {
+    Model     string `json:"model"`
+    CreatedAt string `json:"created_at"`
+    Response  string `json:"response"`
+    Done      bool   `json:"done"`
+
+    // "context": [1, 2, 3],
+
+    TotalDuration      int64 `json:"total_duration"`
+    LoadDuration       int64 `json:"load_duration"`
+    PromptEvalCount    int64 `json:"prompt_eval_count"`
+    PromptEvalDuration int64 `json:"prompt_eval_duration"`
+    EvalCount          int64 `json:"eval_count"`
+    EvalDuration       int64 `json:"eval_duration"`
+}
+
 type chatMessage struct {
     // role: the role of the message, either system, user, assistant, or tool
     Role string `json:"role,omitempty"`
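
These two types mirror the request and response fields documented for Ollama's /api/generate endpoint, keeping only the fields this client uses and preserving the rest of the documentation as commented placeholders. What a marshalled request looks like on the wire (values illustrative; note omitempty drops unset fields):

// In-package sketch: completionRequest is unexported.
stream := false
req := &completionRequest{
    Model:   "llama3.1:8b",
    Prompt:  "Say hello.",
    Options: map[string]any{"num_ctx": 128 * 1024},
    Stream:  &stream,
}
body, _ := json.Marshal(req)
fmt.Println(string(body))
// {"model":"llama3.1:8b","prompt":"Say hello.","options":{"num_ctx":131072},"stream":false}
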
@@ -196,7 +244,9 @@ func (c *OllamaChat) SendMessage(ctx context.Context, parts ...string) (Response, error) {
           Role:    "user",
           Content: part,
        })
+       klog.Infof("sending user:\n%v", part)
     }
+
     ollamaResponse, err := c.client.doChat(ctx, c.session)
     if err != nil {
        return nil, err
@@ -213,6 +263,68 @@ func (c *OllamaChat) SendMessage(ctx context.Context, parts ...string) (Response, error) {
     return response, nil
 }
 
+func (c *OllamaClient) GenerateCompletion(ctx context.Context, request *CompletionRequest) (CompletionResponse, error) {
+    ollamaRequest := &completionRequest{
+       Model:  c.model,
+       Prompt: request.Prompt,
+       Options: map[string]any{
+          "num_ctx": 128 * 1024,
+       },
+    }
+
+    ollamaResponse, err := c.doCompletion(ctx, ollamaRequest)
+    if err != nil {
+       return nil, err
+    }
+
+    if ollamaResponse.Response == "" {
+       return nil, fmt.Errorf("no response returned from ollama")
+    }
+
+    response := &OllamaCompletionResponse{ollamaResponse: ollamaResponse}
+    return response, nil
+}
+
+func (c *OllamaClient) doCompletion(ctx context.Context, req *completionRequest) (*completionResponse, error) {
+    stream := false
+    req.Stream = &stream
+
+    body, err := json.Marshal(req)
+    if err != nil {
+       return nil, fmt.Errorf("building json body: %w", err)
+    }
+    u := c.baseURL.JoinPath("api", "generate")
+    klog.V(2).Infof("sending POST request to %v: %v", u.String(), string(body))
+    httpRequest, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body))
+    if err != nil {
+       return nil, fmt.Errorf("building http request: %w", err)
+    }
+    httpRequest.Header.Set("Content-Type", "application/json")
+
+    httpResponse, err := c.httpClient.Do(httpRequest)
+    if err != nil {
+       return nil, fmt.Errorf("performing http request: %w", err)
+    }
+    defer httpResponse.Body.Close()
+
+    b, err := io.ReadAll(httpResponse.Body)
+    if err != nil {
+       return nil, fmt.Errorf("reading response body: %w", err)
+    }
+
+    klog.Infof("response is: %v", string(b))
+
+    if httpResponse.StatusCode != 200 {
+       return nil, fmt.Errorf("unexpected http status: %q with response %q", httpResponse.Status, string(b))
+    }
+
+    completionResponse := &completionResponse{}
+    if err := json.Unmarshal(b, completionResponse); err != nil {
+       return nil, fmt.Errorf("unmarshalling json response: %w", err)
+    }
+    return completionResponse, nil
+}
+
 func (c *OllamaClient) doChat(ctx context.Context, req *chatRequest) (*chatResponse, error) {
     stream := false
     req.Stream = &stream
@@ -222,7 +334,7 @@ func (c *OllamaClient) doChat(ctx context.Context, req *chatRequest) (*chatResponse, error) {
        return nil, fmt.Errorf("building json body: %w", err)
     }
     u := c.baseURL.JoinPath("api", "chat")
-    klog.Infof("sending POST request to %v: %v", u.String(), string(body))
+    klog.V(2).Infof("sending POST request to %v: %v", u.String(), string(body))
     httpRequest, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body))
     if err != nil {
        return nil, fmt.Errorf("building http request: %w", err)
@@ -323,3 +435,17 @@ func (p *OllamaPart) AsFunctionCalls() ([]FunctionCall, bool) {
     }
     return functionCalls, true
 }
+
+type OllamaCompletionResponse struct {
+    ollamaResponse *completionResponse
+}
+
+var _ CompletionResponse = &OllamaCompletionResponse{}
+
+func (r *OllamaCompletionResponse) Response() string {
+    return r.ollamaResponse.Response
+}
+
+func (r *OllamaCompletionResponse) UsageMetadata() any {
+    return r.ollamaResponse
+}
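
doCompletion mirrors doChat almost line for line, posting to /api/generate instead of /api/chat; the two could plausibly share a helper later. UsageMetadata hands back the entire raw completionResponse; since that type is unexported, callers outside the llm package can only format it (e.g. with %+v), while in-package code can read the counters directly. A sketch of deriving throughput, assuming in-package access (per the Ollama API, durations are in nanoseconds):

resp, err := client.GenerateCompletion(ctx, &CompletionRequest{Prompt: prompt})
if err != nil {
    return err
}
if raw, ok := resp.UsageMetadata().(*completionResponse); ok && raw.EvalDuration > 0 {
    // eval_count tokens generated over eval_duration nanoseconds
    tokensPerSec := float64(raw.EvalCount) / (float64(raw.EvalDuration) / 1e9)
    klog.Infof("generated %d tokens (%.1f tokens/sec)", raw.EvalCount, tokensPerSec)
}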

dev/tools/controllerbuilder/pkg/llm/vertexai.go (+61 -7)

@@ -60,24 +60,24 @@ func BuildVertexAIClient(ctx context.Context) (*VertexAIClient, error) {
     if err != nil {
        return nil, fmt.Errorf("building vertexai client: %w", err)
     }
-    return &VertexAIClient{client: client}, nil
+    model := "gemini-2.0-pro-exp-02-05"
+    return &VertexAIClient{
+       client: client,
+       model:  model,
+    }, nil
 }
 
 type VertexAIClient struct {
     client *genai.Client
+    model  string
 }
 
 func (c *VertexAIClient) Close() error {
     return c.client.Close()
 }
 
 func (c *VertexAIClient) StartChat(systemPrompt string) Chat {
-    // model := c.client.GenerativeModel("vertexai-1.5-flash")
-    // model := c.client.GenerativeModel("vertexai-exp-1206")
-    // model := c.client.GenerativeModel("gemini-2.0-flash-exp")
-    model := c.client.GenerativeModel("gemini-2.0-pro-exp-02-05")
-    // model := c.client.GenerativeModel("gemma-2-27b-it")
-    // model := c.client.GenerativeModel("gemini-1.5-pro-002")
+    model := c.client.GenerativeModel(c.model)
 
     // Some values that are recommended by aistudio
     model.SetTemperature(1)
@@ -173,6 +173,45 @@ func toVertexAISchema(schema *Schema) (*genai.Schema, error) {
     // })
     // }
 
+func (c *VertexAIClient) GenerateCompletion(ctx context.Context, request *CompletionRequest) (CompletionResponse, error) {
+    log := klog.FromContext(ctx)
+
+    model := c.client.GenerativeModel(c.model)
+
+    var vertexaiParts []genai.Part
+
+    vertexaiParts = append(vertexaiParts, genai.Text(request.Prompt))
+
+    log.Info("sending GenerateContent request to vertexai", "parts", vertexaiParts)
+    vertexaiResponse, err := model.GenerateContent(ctx, vertexaiParts...)
+    if err != nil {
+       return nil, err
+    }
+
+    if len(vertexaiResponse.Candidates) > 1 {
+       klog.Infof("only considering first candidate")
+       for i := 1; i < len(vertexaiResponse.Candidates); i++ {
+          candidate := vertexaiResponse.Candidates[i]
+          klog.Infof("ignoring candidate: %q", candidate.Content)
+       }
+    }
+    var response strings.Builder
+    candidate := vertexaiResponse.Candidates[0]
+    for _, part := range candidate.Content.Parts {
+       switch part := part.(type) {
+       case genai.Text:
+          if response.Len() != 0 {
+             response.WriteString("\n")
+          }
+          response.WriteString(string(part))
+       default:
+          return nil, fmt.Errorf("unexpected type of content part: %T", part)
+       }
+    }
+
+    return &VertexAICompletionResponse{vertexaiResponse: vertexaiResponse, text: response.String()}, nil
+}
+
 func (c *VertexAIChat) SendMessage(ctx context.Context, parts ...string) (Response, error) {
     log := klog.FromContext(ctx)
     var vertexaiParts []genai.Part
@@ -256,3 +295,18 @@ func (p *VertexAIPart) AsFunctionCalls() ([]FunctionCall, bool) {
     }
     return nil, false
 }
+
+type VertexAICompletionResponse struct {
+    vertexaiResponse *genai.GenerateContentResponse
+    text             string
+}
+
+var _ CompletionResponse = &VertexAICompletionResponse{}
+
+func (r *VertexAICompletionResponse) Response() string {
+    return r.text
+}
+
+func (r *VertexAICompletionResponse) UsageMetadata() any {
+    return r.vertexaiResponse.UsageMetadata
+}
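
VertexAI gets the one fully functional GenerateCompletion in this PR: it wraps the prompt as a single genai.Text part, keeps only the first candidate, and joins that candidate's text parts with newlines. One sharp edge worth noting: candidate := vertexaiResponse.Candidates[0] is unguarded, so a response with zero candidates (for example, one entirely blocked by safety filters) would panic. A guard along these lines, not present in the commit, would turn that into an error:

// Hypothetical addition before indexing Candidates[0]:
if len(vertexaiResponse.Candidates) == 0 {
    return nil, fmt.Errorf("no candidates returned from vertexai")
}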
