Skip to content

Commit bc9897b

Browse files
Merge pull request GoogleCloudPlatform#3096 from justinsb/nothing_but_gemini
feat: add gemini prompting to controllerbuilder
2 parents 3dd1360 + 01aed36 commit bc9897b

16 files changed

+1001
-29
lines changed

dev/tools/controllerbuilder/cmd/root.go

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"os"
2020
"strings"
2121

22+
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/exportcsv"
2223
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/generatemapper"
2324
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/generatetypes"
2425
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/commands/updatetypes"
@@ -94,6 +95,8 @@ func Execute() {
9495
rootCmd.AddCommand(generatetypes.BuildCommand(&generateOptions))
9596
rootCmd.AddCommand(generatemapper.BuildCommand(&generateOptions))
9697
rootCmd.AddCommand(updatetypes.BuildCommand(&generateOptions))
98+
rootCmd.AddCommand(exportcsv.BuildCommand(&generateOptions))
99+
rootCmd.AddCommand(exportcsv.BuildPromptCommand(&generateOptions))
97100

98101
if err := rootCmd.Execute(); err != nil {
99102
fmt.Fprintf(os.Stderr, "%v\n", err)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package exportcsv
16+
17+
import (
18+
"context"
19+
"fmt"
20+
"os"
21+
"strings"
22+
23+
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options"
24+
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/toolbot"
25+
26+
"github.com/spf13/cobra"
27+
)
28+
29+
// ExportCSVOptions are the options for the export-csv command.
30+
type ExportCSVOptions struct {
31+
*options.GenerateOptions
32+
33+
ProtoDir string
34+
SrcDir string
35+
OutputDir string
36+
}
37+
38+
// BindFlags binds the flags to the command.
39+
func (o *ExportCSVOptions) BindFlags(cmd *cobra.Command) {
40+
cmd.Flags().StringVar(&o.ProtoDir, "proto-dir", o.ProtoDir, "base directory for checkout of proto API definitions")
41+
cmd.Flags().StringVar(&o.SrcDir, "src-dir", o.SrcDir, "base directory for source code")
42+
cmd.Flags().StringVar(&o.OutputDir, "output-dir", o.OutputDir, "base directory for writing CSVs")
43+
}
44+
45+
// BuildCommand builds the export-csv command.
46+
func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command {
47+
opt := &ExportCSVOptions{
48+
GenerateOptions: baseOptions,
49+
}
50+
51+
cmd := &cobra.Command{
52+
Use: "export-csv",
53+
Short: "generate CSV from tool annotations",
54+
RunE: func(cmd *cobra.Command, args []string) error {
55+
ctx := cmd.Context()
56+
if err := RunExportCSV(ctx, opt); err != nil {
57+
return err
58+
}
59+
return nil
60+
},
61+
}
62+
63+
opt.BindFlags(cmd)
64+
65+
return cmd
66+
}
67+
68+
// rewriteFilePath rewrites the file path to the user's home directory if it starts with "~".
69+
func rewriteFilePath(p *string) error {
70+
if strings.HasPrefix(*p, "~/") {
71+
homeDir, err := os.UserHomeDir()
72+
if err != nil {
73+
return fmt.Errorf("getting home directory: %w", err)
74+
}
75+
*p = strings.Replace(*p, "~", homeDir, 1)
76+
}
77+
return nil
78+
}
79+
80+
// RunExportCSV runs the export-csv command.
81+
func RunExportCSV(ctx context.Context, o *ExportCSVOptions) error {
82+
if err := rewriteFilePath(&o.ProtoDir); err != nil {
83+
return err
84+
}
85+
86+
if o.ProtoDir == "" {
87+
return fmt.Errorf("--proto-dir is required")
88+
}
89+
if o.SrcDir == "" {
90+
return fmt.Errorf("--src-dir is required")
91+
}
92+
if o.OutputDir == "" {
93+
return fmt.Errorf("--output-dir is required")
94+
}
95+
96+
extractor := &toolbot.ExtractToolMarkers{}
97+
addProtoDefinition, err := toolbot.NewEnhanceWithProtoDefinition(o.ProtoDir)
98+
if err != nil {
99+
return err
100+
}
101+
x, err := toolbot.NewCSVExporter(extractor, addProtoDefinition)
102+
if err != nil {
103+
return err
104+
}
105+
if err := x.VisitCodeDir(ctx, o.SrcDir); err != nil {
106+
return err
107+
}
108+
109+
if err := x.WriteCSVForAllTools(ctx, o.OutputDir); err != nil {
110+
return err
111+
}
112+
113+
return nil
114+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package exportcsv
16+
17+
import (
18+
"context"
19+
"fmt"
20+
"io"
21+
"os"
22+
23+
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options"
24+
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/toolbot"
25+
"k8s.io/klog/v2"
26+
27+
"github.com/spf13/cobra"
28+
)
29+
30+
// PromptOptions are the options for the prompt command.
31+
type PromptOptions struct {
32+
*options.GenerateOptions
33+
34+
ProtoDir string
35+
SrcDir string
36+
}
37+
38+
// BindFlags binds the flags to the command.
39+
func (o *PromptOptions) BindFlags(cmd *cobra.Command) {
40+
cmd.Flags().StringVar(&o.SrcDir, "src-dir", o.SrcDir, "base directory for source code")
41+
cmd.Flags().StringVar(&o.ProtoDir, "proto-dir", o.ProtoDir, "base directory for checkout of proto API definitions")
42+
}
43+
44+
// BuildPromptCommand builds the `prompt` command.
45+
func BuildPromptCommand(baseOptions *options.GenerateOptions) *cobra.Command {
46+
opt := &PromptOptions{
47+
GenerateOptions: baseOptions,
48+
}
49+
50+
cmd := &cobra.Command{
51+
Use: "prompt",
52+
Short: "executes a prompt against Gemini, generating context based on the source code.",
53+
RunE: func(cmd *cobra.Command, args []string) error {
54+
ctx := cmd.Context()
55+
if err := RunPrompt(ctx, opt); err != nil {
56+
return err
57+
}
58+
return nil
59+
},
60+
}
61+
62+
opt.BindFlags(cmd)
63+
64+
return cmd
65+
}
66+
67+
// RunPrompt runs the `prompt` command.
68+
func RunPrompt(ctx context.Context, o *PromptOptions) error {
69+
log := klog.FromContext(ctx)
70+
71+
if err := rewriteFilePath(&o.ProtoDir); err != nil {
72+
return err
73+
}
74+
75+
if o.ProtoDir == "" {
76+
return fmt.Errorf("--proto-dir is required")
77+
}
78+
extractor := &toolbot.ExtractToolMarkers{}
79+
addProtoDefinition, err := toolbot.NewEnhanceWithProtoDefinition(o.ProtoDir)
80+
if err != nil {
81+
return err
82+
}
83+
x, err := toolbot.NewCSVExporter(extractor, addProtoDefinition)
84+
if err != nil {
85+
return err
86+
}
87+
88+
if o.SrcDir != "" {
89+
if err := x.VisitCodeDir(ctx, o.SrcDir); err != nil {
90+
return err
91+
}
92+
}
93+
94+
b, err := io.ReadAll(os.Stdin)
95+
if err != nil {
96+
return fmt.Errorf("reading from stdin: %w", err)
97+
}
98+
99+
dataPoints, err := x.BuildDataPoints(ctx, b)
100+
if err != nil {
101+
return err
102+
}
103+
104+
if len(dataPoints) != 1 {
105+
return fmt.Errorf("expected exactly one data point, got %d", len(dataPoints))
106+
}
107+
108+
dataPoint := dataPoints[0]
109+
110+
log.Info("built data point", "dataPoint", dataPoint)
111+
112+
if err := x.RunGemini(ctx, dataPoint, os.Stdout); err != nil {
113+
return fmt.Errorf("running LLM inference: %w", err)
114+
115+
}
116+
return nil
117+
}

0 commit comments

Comments
 (0)