Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MarkIfFlagPresentThenOthersRequired #2200

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
97 changes: 85 additions & 12 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ import (
)

const (
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
annotationGroupRequired = "cobra_annotation_required_if_others_set"
annotationRequiredOne = "cobra_annotation_one_required"
annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive"
annotationGroupDependent = "cobra_annotation_if_present_then_others_required"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand All @@ -37,7 +38,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
}
if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, annotationGroupRequired, append(f.Annotations[annotationGroupRequired], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
Expand All @@ -53,7 +54,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
}
if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, annotationRequiredOne, append(f.Annotations[annotationRequiredOne], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
Expand All @@ -70,7 +71,26 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, annotationMutuallyExclusive, append(f.Annotations[annotationMutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
}

// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set,
// all the other flags become required.
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
if len(flagNames) < 2 {
panic("MarkIfFlagPresentThenRequired requires at least two flags")
}
Comment on lines +80 to +85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎨 Nitpicking

I'm unsure about this wording here 🤔

Shouldn't it be, something like this

Suggested change
// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set,
// all the other flags become required.
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
if len(flagNames) < 2 {
panic("MarkIfFlagPresentThenRequired requires at least two flags")
}
// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set,
// all the other flags become required.
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
if len(flagNames) < 2 {
panic("MarkIfFlagPresentThenOthersRequired requires at least two flags")
}

c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you mention the method name as you do in the other panic?

Or maybe, both shouldn't mention it

}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, annotationGroupDependent, append(f.Annotations[annotationGroupDependent], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
Expand All @@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenOthersRequiredGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
Expand All @@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error {
for flagList, flagnameAndStatus := range data {
flags := strings.Split(flagList, " ")
primaryFlag := flags[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split spec explains it can return nil, so here flags[0] could panic

Maybe it's not possible, maybe it's already checked somewhere else before reaching this line of code, but I'm always suspicious when I see such optimistic code

remainingFlags := flags[1:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure this won't panic if the flags is limited to one element?


// Handle missing primary flag entry
if _, exists := flagnameAndStatus[primaryFlag]; !exists {
flagnameAndStatus[primaryFlag] = false
}

// Check if the primary flag is set
if flagnameAndStatus[primaryFlag] {
var unset []string
for _, flag := range remainingFlags {
if !flagnameAndStatus[flag] {
unset = append(unset, flag)
}
}

// If any dependent flags are unset, trigger an error
if len(unset) > 0 {
return fmt.Errorf(
"%v is set, the following flags must be provided: %v",
primaryFlag, unset,
)
}
}
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand All @@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string {
// - when a flag in a group is present, other flags in the group will be marked required
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
// - when the first flag in an if-present-then-required group is present, the other flags will be marked as required
// This allows the standard completion logic to behave appropriately for flag groups
func (c *Command) enforceFlagGroupsForCompletion() {
if c.DisableFlagParsing {
Expand All @@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenRequiredGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenRequiredGroupStatus)
})

// If a flag that is part of a group is present, we make all the other flags
Expand Down Expand Up @@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() {
}
}
}

// If a flag that is marked as if-present-then-required is present, make other flags in the group required
for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus {
flags := strings.Split(flagList, " ")
primaryFlag := flags[0]
remainingFlags := flags[1:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question about flags length being 1 or 0


if flagnameAndStatus[primaryFlag] {
for _, fName := range remainingFlags {
_ = c.MarkFlagRequired(fName)
}
}
}
}
113 changes: 101 additions & 12 deletions flag_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,25 @@ func TestValidateFlagGroups(t *testing.T) {

// Each test case uses a unique command from the function above.
testcases := []struct {
desc string
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
args []string
expectErr string
desc string
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
flagGroupsIfPresentThenRequired []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
subCmdFlagGroupsIfPresentThenRequired []string
args []string
expectErr string
}{
{
desc: "No flags no problem",
}, {
desc: "No flags no problem even with conflicting groups",
flagGroupsRequired: []string{"a b"},
flagGroupsExclusive: []string{"a b"},
desc: "No flags no problem even with conflicting groups",
flagGroupsRequired: []string{"a b"},
flagGroupsExclusive: []string{"a b"},
flagGroupsIfPresentThenRequired: []string{"a b", "b a"},
}, {
desc: "Required flag group not satisfied",
flagGroupsRequired: []string{"a b c"},
Expand All @@ -74,6 +77,11 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c"},
args: []string{"--a=foo", "--b=foo"},
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set",
}, {
desc: "If present then others required flag group not satisfied",
flagGroupsIfPresentThenRequired: []string{"a b"},
args: []string{"--a=foo"},
expectErr: "a is set, the following flags must be provided: [b]",
}, {
desc: "Multiple required flag group not satisfied returns first error",
flagGroupsRequired: []string{"a b c", "a d"},
Expand All @@ -89,6 +97,12 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c", "a d"},
args: []string{"--a=foo", "--c=foo", "--d=foo"},
expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`,
},
{
desc: "Multiple if present then others required flag group not satisfied returns first error",
flagGroupsIfPresentThenRequired: []string{"a b", "d e"},
args: []string{"--a=foo", "--f=foo"},
expectErr: `a is set, the following flags must be provided: [b]`,
}, {
desc: "Validation of required groups occurs on groups in sorted order",
flagGroupsRequired: []string{"a d", "a b", "a c"},
Expand Down Expand Up @@ -182,6 +196,12 @@ func TestValidateFlagGroups(t *testing.T) {
for _, flagGroup := range tc.subCmdFlagGroupsExclusive {
sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.flagGroupsIfPresentThenRequired {
c.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.subCmdFlagGroupsIfPresentThenRequired {
sub.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
}
c.SetArgs(tc.args)
err := c.Execute()
switch {
Expand All @@ -193,3 +213,72 @@ func TestValidateFlagGroups(t *testing.T) {
})
}
}

func TestMarkIfFlagPresentThenOthersRequiredAnnotations(t *testing.T) {
// Create a new command with some flags.
cmd := &Command{
Use: "testcmd",
}
f := cmd.Flags()
f.String("a", "", "flag a")
f.String("b", "", "flag b")
f.String("c", "", "flag c")

// Call the function with one group: ["a", "b"].
cmd.MarkIfFlagPresentThenOthersRequired("a", "b")

// Check that flag "a" has the correct annotation.
aFlag := f.Lookup("a")
if aFlag == nil {
t.Fatal("Flag 'a' not found")
}
annA := aFlag.Annotations[annotationGroupDependent]
expected1 := "a b" // since strings.Join(["a","b"], " ") yields "a b"
if len(annA) != 1 || annA[0] != expected1 {
t.Errorf("Expected flag 'a' annotation to be [%q], got %v", expected1, annA)
}

// Also check that flag "b" has the correct annotation.
bFlag := f.Lookup("b")
if bFlag == nil {
t.Fatal("Flag 'b' not found")
}
annB := bFlag.Annotations[annotationGroupDependent]
if len(annB) != 1 || annB[0] != expected1 {
t.Errorf("Expected flag 'b' annotation to be [%q], got %v", expected1, annB)
}

// Now, call MarkIfFlagPresentThenOthersRequired again with a different group involving "a" and "c".
cmd.MarkIfFlagPresentThenOthersRequired("a", "c")

// The annotation for flag "a" should now have both groups: "a b" and "a c"
annA = aFlag.Annotations[annotationGroupDependent]
expectedAnnotations := []string{"a b", "a c"}
if len(annA) != 2 {
t.Errorf("Expected 2 annotations on flag 'a', got %v", annA)
}
// Check that both expected annotation strings are present.
for _, expected := range expectedAnnotations {
found := false
for _, ann := range annA {
if ann == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected annotation %q not found on flag 'a': %v", expected, annA)
}
}

// Similarly, check that flag "c" now has the annotation "a c".
cFlag := f.Lookup("c")
if cFlag == nil {
t.Fatal("Flag 'c' not found")
}
annC := cFlag.Annotations[annotationGroupDependent]
expected2 := "a c"
if len(annC) != 1 || annC[0] != expected2 {
t.Errorf("Expected flag 'c' annotation to be [%q], got %v", expected2, annC)
}
}