github.com/henvic/wedeploycli@v1.7.6-0.20200319005353-3630f582f284/command/exec/args/args.go (about)

     1  package execargs
     2  
     3  import (
     4  	"strings"
     5  
     6  	"github.com/spf13/cobra"
     7  	"github.com/spf13/pflag"
     8  )
     9  
    10  // MaybeRewrite cobra arguments.
    11  func MaybeRewrite(cmd *cobra.Command, args []string) ([]string, bool) {
    12  	pos, rewrite := getPosition(cmd, args)
    13  
    14  	if !rewrite {
    15  		return []string{}, false
    16  	}
    17  
    18  	stringFlags := getStringFlags(cmd.Flags())
    19  	skip := false
    20  
    21  	found := -1
    22  
    23  	for index, a := range args[pos:] {
    24  		if a == "--" {
    25  			return []string{}, false
    26  		}
    27  
    28  		if skip {
    29  			skip = false
    30  			continue
    31  		}
    32  
    33  		if strings.HasPrefix(a, "-") {
    34  			// remember flags might be --foo, --foo=value, and --foo value.
    35  			if stringFlags[a] {
    36  				skip = true
    37  			}
    38  
    39  			continue
    40  		}
    41  
    42  		found = index
    43  		break
    44  	}
    45  
    46  	if found == -1 {
    47  		return []string{}, false
    48  	}
    49  
    50  	na := append(args[:found+pos],
    51  		append([]string{"--"}, args[found+pos:]...)...,
    52  	)
    53  
    54  	return na, true
    55  }
    56  
    57  func getPosition(cmd *cobra.Command, args []string) (int, bool) {
    58  	name := cmd.Name()
    59  
    60  	stringFlags := getStringFlags(cmd.Flags())
    61  	skip := false
    62  
    63  	for index, a := range args {
    64  		if a == "--" {
    65  			break
    66  		}
    67  
    68  		if skip {
    69  			skip = false
    70  			continue
    71  		}
    72  
    73  		if strings.HasPrefix(a, "-") {
    74  			// remember flags might be --foo, --foo=value, and --foo value.
    75  			if stringFlags[a] {
    76  				skip = true
    77  			}
    78  
    79  			continue
    80  		}
    81  
    82  		if a == name {
    83  			return index + 1, true
    84  		}
    85  
    86  		break
    87  	}
    88  
    89  	return -1, false
    90  }
    91  
    92  func getStringFlags(all *pflag.FlagSet) map[string]bool {
    93  	var flags = map[string]bool{}
    94  
    95  	all.VisitAll(func(f *pflag.Flag) {
    96  		if f.Value.Type() != "string" {
    97  			return
    98  		}
    99  
   100  		if f.Name != "" {
   101  			flags["--"+f.Name] = true
   102  		}
   103  
   104  		if f.Deprecated != "" {
   105  			flags["--"+f.Deprecated] = true
   106  		}
   107  
   108  		if f.Shorthand != "" {
   109  			flags["-"+f.Shorthand] = true
   110  		}
   111  
   112  		if f.ShorthandDeprecated != "" {
   113  			flags["-"+f.ShorthandDeprecated] = true
   114  		}
   115  	})
   116  
   117  	return flags
   118  }