Skip to content

Commit 11db6f9

Browse files
committed
codebot: use completion API instead of chat API
1 parent a2bbf6f commit 11db6f9

File tree

2 files changed

+102
-8
lines changed

2 files changed

+102
-8
lines changed

dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func RunPrompt(ctx context.Context, o *PromptOptions) error {
144144
log.Info("built data point", "dataPoint", dataPoint)
145145

146146
out := &bytes.Buffer{}
147-
if err := x.RunGemini(ctx, dataPoint, out); err != nil {
147+
if err := x.InferOutput_WithCompletion(ctx, dataPoint, out); err != nil {
148148
return fmt.Errorf("running LLM inference: %w", err)
149149

150150
}

dev/tools/controllerbuilder/pkg/toolbot/csv.go

+101-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"io"
2323
"os"
2424
"path/filepath"
25+
"strings"
2526

2627
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/llm"
2728
"k8s.io/apimachinery/pkg/util/sets"
@@ -189,8 +190,21 @@ func (x *CSVExporter) BuildDataPoints(ctx context.Context, description string, s
189190
return dataPoints, nil
190191
}
191192

192-
// RunGemini runs a prompt against Gemini, generating context based on the source code.
193-
func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Writer) error {
193+
// pickExamples returns the examples we should feed into the promp
194+
func (x *CSVExporter) pickExamples(input *DataPoint) []*DataPoint {
195+
var examples []*DataPoint
196+
// We only include data points for the same tool as the input.
197+
for _, dataPoint := range x.dataPoints {
198+
if dataPoint.Type != input.Type {
199+
continue
200+
}
201+
examples = append(examples, dataPoint)
202+
}
203+
return examples
204+
}
205+
206+
// InferOutput_WithChat tries to infer an output value, using the Chat LLM APIs.
207+
func (x *CSVExporter) InferOutput_WithChat(ctx context.Context, input *DataPoint, out io.Writer) error {
194208
log := klog.FromContext(ctx)
195209

196210
client, err := llm.BuildVertexAIClient(ctx)
@@ -202,14 +216,12 @@ func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Wr
202216
systemPrompt := "" // TODO
203217
chat := client.StartChat(systemPrompt)
204218

219+
examples := x.pickExamples(input)
220+
205221
var userParts []string
206222

207223
// We only include data points for the same tool as the input.
208-
for _, dataPoint := range x.dataPoints {
209-
if dataPoint.Type != input.Type {
210-
continue
211-
}
212-
224+
for _, dataPoint := range examples {
213225
inputColumnKeys := dataPoint.InputColumnKeys()
214226
if x.StrictInputColumnKeys != nil && !x.StrictInputColumnKeys.Equal(inputColumnKeys) {
215227
return fmt.Errorf("unexpected input columns for %v; got %v, want %v", dataPoint.Description, inputColumnKeys, x.StrictInputColumnKeys)
@@ -248,3 +260,85 @@ func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Wr
248260

249261
return nil
250262
}
263+
264+
// InferOutput_WithCompletion tries to infer an output value, using the Completion LLM APIs.
265+
func (x *CSVExporter) InferOutput_WithCompletion(ctx context.Context, input *DataPoint, out io.Writer) error {
266+
log := klog.FromContext(ctx)
267+
268+
client, err := llm.BuildVertexAIClient(ctx)
269+
if err != nil {
270+
return fmt.Errorf("building gemini client: %w", err)
271+
}
272+
defer client.Close()
273+
274+
var prompt strings.Builder
275+
276+
fmt.Fprintf(&prompt, "I'm implementing a mock for a proto API. I need to implement go code that implements the proto service. Here are some examples:\n")
277+
278+
examples := x.pickExamples(input)
279+
280+
for _, dataPoint := range examples {
281+
inputColumnKeys := dataPoint.InputColumnKeys()
282+
if x.StrictInputColumnKeys != nil && !x.StrictInputColumnKeys.Equal(inputColumnKeys) {
283+
return fmt.Errorf("unexpected input columns for %v; got %v, want %v", dataPoint.Description, inputColumnKeys, x.StrictInputColumnKeys)
284+
}
285+
286+
s := dataPoint.ToGenAIFormat()
287+
s = "<example>\n" + s + "\n</example>\n\n"
288+
fmt.Fprintf(&prompt, "\n%s\n\n", s)
289+
}
290+
291+
{
292+
// Prompt with the input data point.
293+
s := input.ToGenAIFormat()
294+
// We also include the beginning of the output for Gemini to fill in.
295+
s += "<out>\n```go\n"
296+
s = "<example>\n" + s
297+
fmt.Fprintf(&prompt, "\nCan you complete the item? Don't output any additional commentary.\n\n%s", s)
298+
}
299+
300+
log.Info("sending completion request", "prompt", prompt.String())
301+
302+
resp, err := client.GenerateCompletion(ctx, &llm.CompletionRequest{
303+
Prompt: prompt.String(),
304+
})
305+
if err != nil {
306+
return fmt.Errorf("generating content with gemini: %w", err)
307+
}
308+
309+
// Print the usage metadata (includes token count i.e. cost)
310+
klog.Infof("UsageMetadata: %+v", resp.UsageMetadata())
311+
312+
text := resp.Response()
313+
314+
lines := strings.Split(strings.TrimSpace(text), "\n")
315+
316+
// Remove some of the decoration
317+
for len(lines) > 1 {
318+
if lines[0] == "```go" {
319+
lines = lines[1:]
320+
continue
321+
}
322+
323+
if lines[len(lines)-1] == "```" {
324+
lines = lines[:len(lines)-1]
325+
continue
326+
}
327+
328+
if lines[len(lines)-1] == "</out>" {
329+
lines = lines[:len(lines)-1]
330+
continue
331+
}
332+
333+
if lines[len(lines)-1] == "</example>" {
334+
lines = lines[:len(lines)-1]
335+
continue
336+
}
337+
break
338+
}
339+
340+
text = strings.Join(lines, "\n")
341+
out.Write([]byte(text + "\n"))
342+
343+
return nil
344+
}

0 commit comments

Comments
 (0)