git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/cobra/flag_groups.go (about)

     1  // Copyright 2013-2022 The Cobra Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cobra
    16  
    17  import (
    18  	"fmt"
    19  	"sort"
    20  	"strings"
    21  
    22  	flag "github.com/spf13/pflag"
    23  )
    24  
    25  const (
    26  	requiredAsGroup   = "cobra_annotation_required_if_others_set"
    27  	mutuallyExclusive = "cobra_annotation_mutually_exclusive"
    28  )
    29  
    30  // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
    31  // if the command is invoked with a subset (but not all) of the given flags.
    32  func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
    33  	c.mergePersistentFlags()
    34  	for _, v := range flagNames {
    35  		f := c.Flags().Lookup(v)
    36  		if f == nil {
    37  			panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
    38  		}
    39  		if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
    40  			// Only errs if the flag isn't found.
    41  			panic(err)
    42  		}
    43  	}
    44  }
    45  
    46  // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
    47  // if the command is invoked with more than one flag from the given set of flags.
    48  func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
    49  	c.mergePersistentFlags()
    50  	for _, v := range flagNames {
    51  		f := c.Flags().Lookup(v)
    52  		if f == nil {
    53  			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
    54  		}
    55  		// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
    56  		if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
    57  			panic(err)
    58  		}
    59  	}
    60  }
    61  
    62  // ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
    63  // first error encountered.
    64  func (c *Command) ValidateFlagGroups() error {
    65  	if c.DisableFlagParsing {
    66  		return nil
    67  	}
    68  
    69  	flags := c.Flags()
    70  
    71  	// groupStatus format is the list of flags as a unique ID,
    72  	// then a map of each flag name and whether it is set or not.
    73  	groupStatus := map[string]map[string]bool{}
    74  	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
    75  	flags.VisitAll(func(pflag *flag.Flag) {
    76  		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
    77  		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
    78  	})
    79  
    80  	if err := validateRequiredFlagGroups(groupStatus); err != nil {
    81  		return err
    82  	}
    83  	if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
    84  		return err
    85  	}
    86  	return nil
    87  }
    88  
    89  func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
    90  	for _, fname := range flagnames {
    91  		f := fs.Lookup(fname)
    92  		if f == nil {
    93  			return false
    94  		}
    95  	}
    96  	return true
    97  }
    98  
    99  func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
   100  	groupInfo, found := pflag.Annotations[annotation]
   101  	if found {
   102  		for _, group := range groupInfo {
   103  			if groupStatus[group] == nil {
   104  				flagnames := strings.Split(group, " ")
   105  
   106  				// Only consider this flag group at all if all the flags are defined.
   107  				if !hasAllFlags(flags, flagnames...) {
   108  					continue
   109  				}
   110  
   111  				groupStatus[group] = map[string]bool{}
   112  				for _, name := range flagnames {
   113  					groupStatus[group][name] = false
   114  				}
   115  			}
   116  
   117  			groupStatus[group][pflag.Name] = pflag.Changed
   118  		}
   119  	}
   120  }
   121  
   122  func validateRequiredFlagGroups(data map[string]map[string]bool) error {
   123  	keys := sortedKeys(data)
   124  	for _, flagList := range keys {
   125  		flagnameAndStatus := data[flagList]
   126  
   127  		unset := []string{}
   128  		for flagname, isSet := range flagnameAndStatus {
   129  			if !isSet {
   130  				unset = append(unset, flagname)
   131  			}
   132  		}
   133  		if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
   134  			continue
   135  		}
   136  
   137  		// Sort values, so they can be tested/scripted against consistently.
   138  		sort.Strings(unset)
   139  		return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
   146  	keys := sortedKeys(data)
   147  	for _, flagList := range keys {
   148  		flagnameAndStatus := data[flagList]
   149  		var set []string
   150  		for flagname, isSet := range flagnameAndStatus {
   151  			if isSet {
   152  				set = append(set, flagname)
   153  			}
   154  		}
   155  		if len(set) == 0 || len(set) == 1 {
   156  			continue
   157  		}
   158  
   159  		// Sort values, so they can be tested/scripted against consistently.
   160  		sort.Strings(set)
   161  		return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
   162  	}
   163  	return nil
   164  }
   165  
   166  func sortedKeys(m map[string]map[string]bool) []string {
   167  	keys := make([]string, len(m))
   168  	i := 0
   169  	for k := range m {
   170  		keys[i] = k
   171  		i++
   172  	}
   173  	sort.Strings(keys)
   174  	return keys
   175  }
   176  
   177  // enforceFlagGroupsForCompletion will do the following:
   178  // - when a flag in a group is present, other flags in the group will be marked required
   179  // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
   180  // This allows the standard completion logic to behave appropriately for flag groups
   181  func (c *Command) enforceFlagGroupsForCompletion() {
   182  	if c.DisableFlagParsing {
   183  		return
   184  	}
   185  
   186  	flags := c.Flags()
   187  	groupStatus := map[string]map[string]bool{}
   188  	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
   189  	c.Flags().VisitAll(func(pflag *flag.Flag) {
   190  		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
   191  		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
   192  	})
   193  
   194  	// If a flag that is part of a group is present, we make all the other flags
   195  	// of that group required so that the shell completion suggests them automatically
   196  	for flagList, flagnameAndStatus := range groupStatus {
   197  		for _, isSet := range flagnameAndStatus {
   198  			if isSet {
   199  				// One of the flags of the group is set, mark the other ones as required
   200  				for _, fName := range strings.Split(flagList, " ") {
   201  					_ = c.MarkFlagRequired(fName)
   202  				}
   203  			}
   204  		}
   205  	}
   206  
   207  	// If a flag that is mutually exclusive to others is present, we hide the other
   208  	// flags of that group so the shell completion does not suggest them
   209  	for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
   210  		for flagName, isSet := range flagnameAndStatus {
   211  			if isSet {
   212  				// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
   213  				// Don't mark the flag that is already set as hidden because it may be an
   214  				// array or slice flag and therefore must continue being suggested
   215  				for _, fName := range strings.Split(flagList, " ") {
   216  					if fName != flagName {
   217  						flag := c.Flags().Lookup(fName)
   218  						flag.Hidden = true
   219  					}
   220  				}
   221  			}
   222  		}
   223  	}
   224  }