Skip to content

Commit 4db732a

Browse files
Merge pull request GoogleCloudPlatform#3337 from jingyih/refactor
tool: correctly insert field into Spec or ObservedState.
2 parents 05b63d0 + cf7727e commit 4db732a

File tree

8 files changed

+325
-232
lines changed

8 files changed

+325
-232
lines changed

dev/tools/controllerbuilder/pkg/commands/updatetypes/updatetypescommand.go

+24-170
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,24 @@
1515
package updatetypes
1616

1717
import (
18-
"bytes"
18+
"context"
1919
"fmt"
2020
"os"
21-
"strings"
2221

23-
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/codegen"
24-
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/gocode"
2522
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options"
23+
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/typeupdater"
2624

2725
"github.com/spf13/cobra"
28-
"google.golang.org/protobuf/proto"
29-
"google.golang.org/protobuf/reflect/protodesc"
30-
"google.golang.org/protobuf/reflect/protoreflect"
31-
"google.golang.org/protobuf/types/descriptorpb"
32-
"k8s.io/apimachinery/pkg/util/sets"
33-
"k8s.io/klog/v2"
3426
)
3527

36-
const kccProtoPrefix = "+kcc:proto="
37-
3828
type UpdateTypeOptions struct {
3929
*options.GenerateOptions
4030

41-
parentMessageFullName string
42-
newField string
43-
ignoredFields string // TODO: could be part of GenerateOptions
44-
apiDirectory string
45-
goPackagePath string
31+
parentNessage string // The fully qualified name of the parent prroto message of the field to be inserted
32+
fieldToInsert string
33+
ignoredFields string // TODO: could be part of GenerateOptions
34+
apiDirectory string
35+
goPackagePath string
4636
}
4737

4838
func (o *UpdateTypeOptions) InitDefaults() error {
@@ -56,8 +46,9 @@ func (o *UpdateTypeOptions) InitDefaults() error {
5646
}
5747

5848
func (o *UpdateTypeOptions) BindFlags(cmd *cobra.Command) {
59-
cmd.Flags().StringVar(&o.parentMessageFullName, "parent-message-full-name", o.parentMessageFullName, "Fully qualified name of the proto message holding the new field")
60-
cmd.Flags().StringVar(&o.newField, "new-field", o.newField, "Name of the new field")
49+
cmd.Flags().StringVar(&o.parentNessage, "parent-message", o.parentNessage, "Fully qualified name of the proto message holding the new field. e.g. `google.cloud.bigquery.datatransfer.v1.TransferConfig`")
50+
cmd.Flags().StringVar(&o.fieldToInsert, "field-to-insert", o.fieldToInsert, "Name of the new field to be inserted, e.g. `schedule_options_v2`")
51+
// TODO: Update this flag to accept a file path pointing to the ignored fields YAML file.
6152
cmd.Flags().StringVar(&o.ignoredFields, "ignored-fields", o.ignoredFields, "Comma-separated list of fields to ignore")
6253
cmd.Flags().StringVar(&o.apiDirectory, "api-dir", o.apiDirectory, "Base directory for APIs")
6354
cmd.Flags().StringVar(&o.goPackagePath, "api-go-package-path", o.goPackagePath, "Package path")
@@ -77,8 +68,8 @@ func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command {
7768
Use: "update-types",
7869
Short: "update KRM types for a proto service",
7970
RunE: func(cmd *cobra.Command, args []string) error {
80-
updater := NewTypeUpdater(opt)
81-
if err := updater.Run(); err != nil {
71+
ctx := cmd.Context()
72+
if err := runTypeUpdater(ctx, opt); err != nil {
8273
return err
8374
}
8475
return nil
@@ -90,159 +81,22 @@ func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command {
9081
return cmd
9182
}
9283

93-
type TypeUpdater struct {
94-
opts *UpdateTypeOptions
95-
newField newProtoField
96-
dependentMessages map[string]protoreflect.MessageDescriptor // key: fully qualified name of proto message
97-
generatedGoField generatedGoField // TODO: support multiple new fields
98-
generatedGoStructs []generatedGoStruct
99-
}
100-
101-
type newProtoField struct {
102-
field protoreflect.FieldDescriptor
103-
parentMessage protoreflect.MessageDescriptor
104-
}
105-
106-
type generatedGoField struct {
107-
parentMessage string // fully qualified name of the parent proto message of this field
108-
content []byte // the content of the generated Go field
109-
}
110-
111-
type generatedGoStruct struct {
112-
name string // fully qualified name of the proto message
113-
content []byte // the content of the generated Go struct
114-
}
115-
116-
func NewTypeUpdater(opts *UpdateTypeOptions) *TypeUpdater {
117-
return &TypeUpdater{
118-
opts: opts,
119-
}
120-
}
121-
122-
func (u *TypeUpdater) Run() error {
123-
// 1. find new field and its dependent proto messages that needs to be generated
124-
if err := u.analyze(); err != nil {
125-
return nil
84+
func runTypeUpdater(ctx context.Context, opt *UpdateTypeOptions) error {
85+
if opt.apiDirectory == "" {
86+
return fmt.Errorf("--api-dir is required")
12687
}
12788

128-
// 2. generate Go types for the new field and its dependent proto messages
129-
if err := u.generate(); err != nil {
130-
return err
131-
}
132-
133-
// 3. insert the generated Go code back to files
134-
if err := u.insertGoField(); err != nil {
135-
return err
89+
typeUpdaterOpts := &typeupdater.UpdaterOptions{
90+
ProtoSourcePath: opt.GenerateOptions.ProtoSourcePath,
91+
ParentMessageFullName: opt.parentNessage,
92+
FieldToInsert: opt.fieldToInsert,
93+
IgnoredFields: opt.ignoredFields,
94+
APIDirectory: opt.apiDirectory,
95+
GoPackagePath: opt.goPackagePath,
13696
}
137-
if err := u.insertGoMessages(); err != nil {
138-
return err
139-
}
140-
141-
return nil
142-
}
143-
144-
// anaylze finds the new field, its parent message, and all dependent messages that need to be generated.
145-
func (u *TypeUpdater) analyze() error {
146-
parentMessage, newField, err := findNewField(u.opts.ProtoSourcePath, u.opts.parentMessageFullName, u.opts.newField)
147-
if err != nil {
148-
return err
149-
}
150-
u.newField = newProtoField{
151-
field: newField,
152-
parentMessage: parentMessage,
153-
}
154-
155-
msgs, err := findDependentMsgs(newField, sets.NewString(strings.Split(u.opts.ignoredFields, ",")...))
156-
if err != nil {
157-
return err
158-
}
159-
160-
codegen.RemoveNotMappedToGoStruct(msgs)
161-
162-
if err := removeAlreadyGenerated(u.opts.goPackagePath, u.opts.apiDirectory, msgs); err != nil {
97+
updater := typeupdater.NewTypeUpdater(typeUpdaterOpts)
98+
if err := updater.Run(); err != nil {
16399
return err
164100
}
165-
u.dependentMessages = msgs
166-
return nil
167-
}
168-
169-
// findNewField locates the parent message and the new field in the proto file
170-
func findNewField(pbSourcePath, parentMsgFullName, newFieldName string) (protoreflect.MessageDescriptor, protoreflect.FieldDescriptor, error) {
171-
fileData, err := os.ReadFile(pbSourcePath)
172-
if err != nil {
173-
return nil, nil, fmt.Errorf("reading %q: %w", pbSourcePath, err)
174-
}
175-
176-
fds := &descriptorpb.FileDescriptorSet{}
177-
if err := proto.Unmarshal(fileData, fds); err != nil {
178-
return nil, nil, fmt.Errorf("unmarshalling %q: %w", pbSourcePath, err)
179-
}
180-
181-
files, err := protodesc.NewFiles(fds)
182-
if err != nil {
183-
return nil, nil, fmt.Errorf("building file description: %w", err)
184-
}
185-
186-
// Find the parent message
187-
messageDesc, err := files.FindDescriptorByName(protoreflect.FullName(parentMsgFullName))
188-
if err != nil {
189-
return nil, nil, err
190-
}
191-
msgDesc, ok := messageDesc.(protoreflect.MessageDescriptor)
192-
if !ok {
193-
return nil, nil, fmt.Errorf("unexpected descriptor type: %T", msgDesc)
194-
}
195-
196-
// Find the new field in parent message
197-
fieldDesc := msgDesc.Fields().ByName(protoreflect.Name(newFieldName))
198-
if fieldDesc == nil {
199-
return nil, nil, fmt.Errorf("field not found in message")
200-
}
201-
202-
return msgDesc, fieldDesc, nil
203-
}
204-
205-
// findDependentMsgs finds all dependent proto messages for the given field, ignoring specified fields
206-
func findDependentMsgs(field protoreflect.FieldDescriptor, ignoredProtoFields sets.String) (map[string]protoreflect.MessageDescriptor, error) {
207-
deps := make(map[string]protoreflect.MessageDescriptor)
208-
codegen.FindDependenciesForField(field, deps, ignoredProtoFields)
209-
return deps, nil
210-
}
211-
212-
// removeAlreadyGenerated removes proto messages that have already been generated (including manually edited)
213-
func removeAlreadyGenerated(goPackagePath, outputAPIDirectory string, targets map[string]protoreflect.MessageDescriptor) error {
214-
packages, err := gocode.LoadPackageTree(goPackagePath, outputAPIDirectory)
215-
if err != nil {
216-
return err
217-
}
218-
for _, p := range packages {
219-
for _, s := range p.Structs {
220-
if annotation := s.GetAnnotation("+kcc:proto"); annotation != "" {
221-
delete(targets, annotation)
222-
}
223-
}
224-
}
225-
return nil
226-
}
227-
228-
func (u *TypeUpdater) generate() error {
229-
var buf bytes.Buffer
230-
klog.Infof("generate Go code for field %s", u.newField.field.Name())
231-
codegen.WriteField(&buf, u.newField.field, u.newField.parentMessage, 0)
232-
u.generatedGoField = generatedGoField{
233-
parentMessage: string(u.newField.parentMessage.FullName()),
234-
content: buf.Bytes(),
235-
}
236-
237-
for _, msg := range u.dependentMessages {
238-
var buf bytes.Buffer
239-
klog.Infof("generate Go code for messge %s", msg.FullName())
240-
codegen.WriteMessage(&buf, msg)
241-
u.generatedGoStructs = append(u.generatedGoStructs,
242-
generatedGoStruct{
243-
name: string(msg.FullName()),
244-
content: buf.Bytes(),
245-
})
246-
}
247101
return nil
248102
}

dev/tools/controllerbuilder/pkg/commands/updatetypes/insertfield-ast.go dev/tools/controllerbuilder/pkg/typeupdater/insertfield-ast.go

+43-20
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
package updatetypes
15+
package typeupdater
1616

1717
import (
1818
"fmt"
@@ -25,16 +25,24 @@ import (
2525
"strings"
2626

2727
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/gocode"
28+
"github.com/GoogleCloudPlatform/k8s-config-connector/pkg/controller/direct/common"
29+
"google.golang.org/genproto/googleapis/api/annotations"
2830

2931
"k8s.io/klog/v2"
3032
)
3133

34+
type target struct {
35+
goName string
36+
endPos int
37+
}
38+
3239
func (u *TypeUpdater) insertGoField() error {
33-
klog.Infof("inserting the generated Go code for field %s", u.newField.field.Name())
40+
klog.Infof("inserting the generated Go code for field %s", u.newField.proto.Name())
3441

35-
targetComment := fmt.Sprintf("+kcc:proto=%s", u.generatedGoField.parentMessage)
42+
targetComment := fmt.Sprintf("+kcc:proto=%s", u.newField.parent.FullName())
43+
outputOnly := common.IsFieldBehavior(u.newField.proto, annotations.FieldBehavior_OUTPUT_ONLY)
3644

37-
filepath.WalkDir(u.opts.apiDirectory, func(path string, d fs.DirEntry, err error) error {
45+
filepath.WalkDir(u.opts.APIDirectory, func(path string, d fs.DirEntry, err error) error {
3846
if err != nil || d.IsDir() || filepath.Ext(path) != ".go" {
3947
return nil
4048
}
@@ -55,13 +63,12 @@ func (u *TypeUpdater) insertGoField() error {
5563
// use a CommentMap to associate comments with nodes
5664
docMap := gocode.NewDocMap(fset, file)
5765

58-
// find the target Go struct and its ending position in the source
59-
var endPos int
66+
// find the target Go struct and the ending position in the source
67+
// there are 2 cases considered.
68+
// - case 1, there is only 1 matching target.
69+
// - case 2, there are two matching targets (Spec and ObservedState).
70+
var targets []target
6071
ast.Inspect(file, func(n ast.Node) bool {
61-
if endPos != 0 {
62-
return false // already found the target
63-
}
64-
6572
ts, ok := n.(*ast.TypeSpec)
6673
if !ok {
6774
return true
@@ -80,19 +87,35 @@ func (u *TypeUpdater) insertGoField() error {
8087
return true // empty struct? this should not happen
8188
}
8289

83-
klog.Infof("found target Go struct %s", ts.Name.Name)
84-
85-
endPos = int(fset.Position(st.End()).Offset)
86-
return false // stop searching, we found the target Go struct
90+
klog.Infof("found potential target Go struct %s", ts.Name.Name)
91+
targets = append(targets, target{
92+
goName: ts.Name.Name,
93+
endPos: int(fset.Position(st.End()).Offset),
94+
})
95+
return true // continue searching for potential target Go struct
8796
})
8897

89-
// if the target Go struct was found, modify the source bytes
90-
if endPos != 0 {
98+
var chosenTarget *target
99+
if len(targets) == 0 { // no target, continue to next file
100+
return nil
101+
} else if len(targets) == 1 { // case 1, one matching Go struct
102+
chosenTarget = &targets[0]
103+
} else if len(targets) == 2 { // case 2, Spec/ObservedState pair
104+
for _, t := range targets {
105+
if !outputOnly && strings.HasSuffix(t.goName, "Spec") ||
106+
outputOnly && strings.HasSuffix(t.goName, "ObservedState") {
107+
chosenTarget = &t
108+
break
109+
}
110+
}
111+
}
112+
113+
if chosenTarget != nil { // target Go struct was found, modify the source bytes
91114
var newSrcBytes []byte
92-
// TODO: ues the same field ordering as in proto message
93-
newSrcBytes = append(newSrcBytes, srcBytes[:endPos-1]...) // up to before '}'
94-
newSrcBytes = append(newSrcBytes, u.generatedGoField.content...) // insert new field
95-
newSrcBytes = append(newSrcBytes, srcBytes[endPos-1:]...) // include the '}'
115+
// TODO: use the same field ordering as in proto message?
116+
newSrcBytes = append(newSrcBytes, srcBytes[:chosenTarget.endPos-1]...) // up to before '}'
117+
newSrcBytes = append(newSrcBytes, u.newField.generatedContent...) // insert new field
118+
newSrcBytes = append(newSrcBytes, srcBytes[chosenTarget.endPos-1:]...) // include the '}'
96119

97120
if err := os.WriteFile(path, newSrcBytes, d.Type()); err != nil {
98121
return err

dev/tools/controllerbuilder/pkg/commands/updatetypes/insertfield-gemini.go dev/tools/controllerbuilder/pkg/typeupdater/insertfield-gemini.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
package updatetypes
15+
package typeupdater
1616

1717
import (
1818
"context"
@@ -28,7 +28,7 @@ import (
2828
)
2929

3030
func (u *TypeUpdater) insertGoFieldGemini() error {
31-
klog.Infof("inserting the generated Go code for field %s", u.newField.field.Name())
31+
klog.Infof("inserting the generated Go code for field %s", u.newField.proto.Name())
3232
ctx := context.Background()
3333
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
3434
if err != nil {
@@ -50,20 +50,20 @@ func (u *TypeUpdater) insertGoFieldGemini() error {
5050
Could you find the Go struct which has comment "+kcc:proto=%s" with no following suffix,
5151
and insert the Go field into the found Go struct.
5252
In your response, only include what is asked for.
53-
`, u.generatedGoField.parentMessage)),
53+
`, u.newField.parent.FullName())),
5454
},
5555
Role: "user",
5656
},
5757
}
5858
// provide the content of the new Go field
5959
session.History = append(session.History, &genai.Content{
6060
Parts: []genai.Part{
61-
genai.Text(fmt.Sprintf("new Go field:\n%s\n\n", u.generatedGoField.content)),
61+
genai.Text(fmt.Sprintf("new Go field:\n%s\n\n", u.newField.generatedContent)),
6262
},
6363
Role: "user",
6464
})
6565
// provide content of the existing Go files
66-
files, err := listFiles(u.opts.apiDirectory)
66+
files, err := listFiles(u.opts.APIDirectory)
6767
if err != nil {
6868
return fmt.Errorf("error listing files: %w", err)
6969
}

0 commit comments

Comments
 (0)