@@ -22,6 +22,7 @@ import (
 	"io"
 	"os"
 	"path/filepath"
+	"strings"
 
 	"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/llm"
 	"k8s.io/apimachinery/pkg/util/sets"
@@ -189,8 +190,21 @@ func (x *CSVExporter) BuildDataPoints(ctx context.Context, description string, s
 	return dataPoints, nil
 }
 
-// RunGemini runs a prompt against Gemini, generating context based on the source code.
-func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Writer) error {
+// pickExamples returns the examples we should feed into the prompt.
+func (x *CSVExporter) pickExamples(input *DataPoint) []*DataPoint {
+	var examples []*DataPoint
+	// We only include data points for the same tool as the input.
+	for _, dataPoint := range x.dataPoints {
+		if dataPoint.Type != input.Type {
+			continue
+		}
+		examples = append(examples, dataPoint)
+	}
+	return examples
+}
+
+// InferOutput_WithChat tries to infer an output value, using the Chat LLM APIs.
+func (x *CSVExporter) InferOutput_WithChat(ctx context.Context, input *DataPoint, out io.Writer) error {
 	log := klog.FromContext(ctx)
 
 	client, err := llm.BuildVertexAIClient(ctx)
@@ -202,14 +216,12 @@ func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Wr
 	systemPrompt := "" // TODO
 	chat := client.StartChat(systemPrompt)
 
+	examples := x.pickExamples(input)
+
 	var userParts []string
 
 	// We only include data points for the same tool as the input.
-	for _, dataPoint := range x.dataPoints {
-		if dataPoint.Type != input.Type {
-			continue
-		}
-
+	for _, dataPoint := range examples {
 		inputColumnKeys := dataPoint.InputColumnKeys()
 		if x.StrictInputColumnKeys != nil && !x.StrictInputColumnKeys.Equal(inputColumnKeys) {
 			return fmt.Errorf("unexpected input columns for %v; got %v, want %v", dataPoint.Description, inputColumnKeys, x.StrictInputColumnKeys)
@@ -248,3 +260,85 @@ func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Wr
 
 	return nil
 }
+
+// InferOutput_WithCompletion tries to infer an output value, using the Completion LLM APIs.
+func (x *CSVExporter) InferOutput_WithCompletion(ctx context.Context, input *DataPoint, out io.Writer) error {
+	log := klog.FromContext(ctx)
+
+	client, err := llm.BuildVertexAIClient(ctx)
+	if err != nil {
+		return fmt.Errorf("building gemini client: %w", err)
+	}
+	defer client.Close()
+
+	var prompt strings.Builder
+
+	fmt.Fprintf(&prompt, "I'm implementing a mock for a proto API. I need to implement go code that implements the proto service. Here are some examples:\n")
+
+	examples := x.pickExamples(input)
+
+	for _, dataPoint := range examples {
+		inputColumnKeys := dataPoint.InputColumnKeys()
+		if x.StrictInputColumnKeys != nil && !x.StrictInputColumnKeys.Equal(inputColumnKeys) {
+			return fmt.Errorf("unexpected input columns for %v; got %v, want %v", dataPoint.Description, inputColumnKeys, x.StrictInputColumnKeys)
+		}
+
+		s := dataPoint.ToGenAIFormat()
+		s = "<example>\n" + s + "\n</example>\n\n"
+		fmt.Fprintf(&prompt, "\n%s\n\n", s)
+	}
+
+	{
+		// Prompt with the input data point.
+		s := input.ToGenAIFormat()
+		// We also include the beginning of the output for Gemini to fill in.
+		s += "<out>\n```go\n"
+		s = "<example>\n" + s
+		fmt.Fprintf(&prompt, "\nCan you complete the item? Don't output any additional commentary.\n\n%s", s)
+	}
+
+	log.Info("sending completion request", "prompt", prompt.String())
+
+	resp, err := client.GenerateCompletion(ctx, &llm.CompletionRequest{
+		Prompt: prompt.String(),
+	})
+	if err != nil {
+		return fmt.Errorf("generating content with gemini: %w", err)
+	}
+
+	// Print the usage metadata (includes token count i.e. cost)
+	klog.Infof("UsageMetadata: %+v", resp.UsageMetadata())
+
+	text := resp.Response()
+
+	lines := strings.Split(strings.TrimSpace(text), "\n")
+
+	// Remove some of the decoration
+	for len(lines) > 1 {
+		if lines[0] == "```go" {
+			lines = lines[1:]
+			continue
+		}
+
+		if lines[len(lines)-1] == "```" {
+			lines = lines[:len(lines)-1]
+			continue
+		}
+
+		if lines[len(lines)-1] == "</out>" {
+			lines = lines[:len(lines)-1]
+			continue
+		}
+
+		if lines[len(lines)-1] == "</example>" {
+			lines = lines[:len(lines)-1]
+			continue
+		}
+		break
+	}
+
+	text = strings.Join(lines, "\n")
+	out.Write([]byte(text + "\n"))
+
+	return nil
+}
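Usage note (not part of the diff above): a minimal sketch of how a caller might drive the new completion path, assuming the exporter iterates its own `x.dataPoints` (as the code above does) and writes each inferred output to a per-data-point file. The driver name, output directory, and file naming are hypothetical.

```go
// inferAllWithCompletion is a hypothetical driver, for illustration only: the
// outDir parameter and the "%v.out" naming are assumptions, not part of the commit.
func (x *CSVExporter) inferAllWithCompletion(ctx context.Context, outDir string) error {
	for _, dataPoint := range x.dataPoints {
		outPath := filepath.Join(outDir, fmt.Sprintf("%v.out", dataPoint.Description))
		f, err := os.Create(outPath)
		if err != nil {
			return fmt.Errorf("creating %q: %w", outPath, err)
		}
		if err := x.InferOutput_WithCompletion(ctx, dataPoint, f); err != nil {
			f.Close()
			return fmt.Errorf("inferring output for %v: %w", dataPoint.Description, err)
		}
		if err := f.Close(); err != nil {
			return fmt.Errorf("closing %q: %w", outPath, err)
		}
	}
	return nil
}
```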