-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add MarkIfFlagPresentThenOthersRequired #2200
base: main
Are you sure you want to change the base?
Changes from all commits
ca3972e
4944350
23cadf7
ad5e8c3
cb45935
1c44c42
dd7bc42
a9bed3a
578b074
8d0e71e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,9 +23,10 @@ import ( | |
) | ||
|
||
const ( | ||
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set" | ||
oneRequiredAnnotation = "cobra_annotation_one_required" | ||
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive" | ||
annotationGroupRequired = "cobra_annotation_required_if_others_set" | ||
annotationRequiredOne = "cobra_annotation_one_required" | ||
annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive" | ||
annotationGroupDependent = "cobra_annotation_if_present_then_others_required" | ||
) | ||
|
||
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors | ||
|
@@ -37,7 +38,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { | |
if f == nil { | ||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) | ||
} | ||
if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil { | ||
if err := c.Flags().SetAnnotation(v, annotationGroupRequired, append(f.Annotations[annotationGroupRequired], strings.Join(flagNames, " "))); err != nil { | ||
// Only errs if the flag isn't found. | ||
panic(err) | ||
} | ||
|
@@ -53,7 +54,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) { | |
if f == nil { | ||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) | ||
} | ||
if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { | ||
if err := c.Flags().SetAnnotation(v, annotationRequiredOne, append(f.Annotations[annotationRequiredOne], strings.Join(flagNames, " "))); err != nil { | ||
// Only errs if the flag isn't found. | ||
panic(err) | ||
} | ||
|
@@ -70,7 +71,26 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { | |
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) | ||
} | ||
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. | ||
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil { | ||
if err := c.Flags().SetAnnotation(v, annotationMutuallyExclusive, append(f.Annotations[annotationMutuallyExclusive], strings.Join(flagNames, " "))); err != nil { | ||
panic(err) | ||
} | ||
} | ||
} | ||
|
||
// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set, | ||
// all the other flags become required. | ||
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) { | ||
if len(flagNames) < 2 { | ||
panic("MarkIfFlagPresentThenRequired requires at least two flags") | ||
} | ||
c.mergePersistentFlags() | ||
for _, v := range flagNames { | ||
f := c.Flags().Lookup(v) | ||
if f == nil { | ||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't you mention the method name as you do in the other panic? Or maybe, both shouldn't mention it |
||
} | ||
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. | ||
faizan-siddiqui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if err := c.Flags().SetAnnotation(v, annotationGroupDependent, append(f.Annotations[annotationGroupDependent], strings.Join(flagNames, " "))); err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
@@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error { | |
groupStatus := map[string]map[string]bool{} | ||
oneRequiredGroupStatus := map[string]map[string]bool{} | ||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{} | ||
ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{} | ||
flags.VisitAll(func(pflag *flag.Flag) { | ||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenOthersRequiredGroupStatus) | ||
}) | ||
|
||
if err := validateRequiredFlagGroups(groupStatus); err != nil { | ||
|
@@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error { | |
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { | ||
return err | ||
} | ||
if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
|
@@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { | |
return nil | ||
} | ||
|
||
func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error { | ||
for flagList, flagnameAndStatus := range data { | ||
flags := strings.Split(flagList, " ") | ||
primaryFlag := flags[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Split spec explains it can return nil, so here flags[0] could panic Maybe it's not possible, maybe it's already checked somewhere else before reaching this line of code, but I'm always suspicious when I see such optimistic code |
||
remainingFlags := flags[1:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we sure this won't panic if the flags is limited to one element? |
||
|
||
// Handle missing primary flag entry | ||
if _, exists := flagnameAndStatus[primaryFlag]; !exists { | ||
flagnameAndStatus[primaryFlag] = false | ||
} | ||
|
||
// Check if the primary flag is set | ||
if flagnameAndStatus[primaryFlag] { | ||
var unset []string | ||
for _, flag := range remainingFlags { | ||
if !flagnameAndStatus[flag] { | ||
unset = append(unset, flag) | ||
} | ||
} | ||
|
||
// If any dependent flags are unset, trigger an error | ||
if len(unset) > 0 { | ||
return fmt.Errorf( | ||
"%v is set, the following flags must be provided: %v", | ||
primaryFlag, unset, | ||
) | ||
} | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func sortedKeys(m map[string]map[string]bool) []string { | ||
keys := make([]string, len(m)) | ||
i := 0 | ||
|
@@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string { | |
// - when a flag in a group is present, other flags in the group will be marked required | ||
// - when none of the flags in a one-required group are present, all flags in the group will be marked required | ||
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden | ||
// - when the first flag in an if-present-then-required group is present, the other flags will be marked as required | ||
// This allows the standard completion logic to behave appropriately for flag groups | ||
func (c *Command) enforceFlagGroupsForCompletion() { | ||
if c.DisableFlagParsing { | ||
|
@@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() { | |
groupStatus := map[string]map[string]bool{} | ||
oneRequiredGroupStatus := map[string]map[string]bool{} | ||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{} | ||
ifPresentThenRequiredGroupStatus := map[string]map[string]bool{} | ||
c.Flags().VisitAll(func(pflag *flag.Flag) { | ||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenRequiredGroupStatus) | ||
}) | ||
|
||
// If a flag that is part of a group is present, we make all the other flags | ||
|
@@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() { | |
} | ||
} | ||
} | ||
|
||
// If a flag that is marked as if-present-then-required is present, make other flags in the group required | ||
for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus { | ||
flags := strings.Split(flagList, " ") | ||
primaryFlag := flags[0] | ||
remainingFlags := flags[1:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question about flags length being 1 or 0 |
||
|
||
if flagnameAndStatus[primaryFlag] { | ||
for _, fName := range remainingFlags { | ||
_ = c.MarkFlagRequired(fName) | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎨 Nitpicking
I'm unsure about this wording here 🤔
Shouldn't it be, something like this