github.com/influxdata/influxdb/v2@v2.7.6/kit/cli/viper.go (about)

     1  package cli
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"path"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/influxdata/influxdb/v2/kit/platform"
    11  	"github.com/spf13/cast"
    12  	"github.com/spf13/cobra"
    13  	"github.com/spf13/pflag"
    14  	"github.com/spf13/viper"
    15  	"go.uber.org/zap/zapcore"
    16  )
    17  
// Opt is a single command-line option.
//
// BindOptions registers each Opt as a flag on a cobra command and also looks
// the option up in viper, so its value may come from the command line, the
// environment, or a loaded config file.
type Opt struct {
	// DestP is a pointer to the destination the option's value is written to.
	// BindOptions supports *string, *int, *int32, *int64, *bool,
	// *time.Duration, *[]string, *map[string]string, any pflag.Value,
	// *platform.ID and *zapcore.Level; any other type makes it return an error.
	DestP interface{} // pointer to the destination

	// EnvVar, when non-empty, is the viper key used to look up an existing
	// value for this option instead of Flag.
	EnvVar string
	// Flag is the long flag name; it is also the viper key the flag is bound to.
	Flag string
	// Hidden hides the flag from help output.
	Hidden bool
	// Persistent registers the flag on cmd.PersistentFlags() instead of cmd.Flags().
	Persistent bool
	// Required marks the flag as required, but only when viper did not already
	// hold a value for the option.
	Required bool
	// Short is the optional one-letter shorthand; 0 means no shorthand.
	Short rune // using rune b/c it guarantees correctness. a short must always be a string of length 1

	// Default is the flag's default value. Its dynamic type must match DestP's
	// element type (an int is also accepted for int32/int64 destinations, and
	// pflag.Value destinations take a string); a mismatched type panics.
	Default interface{}
	// Desc is the usage text shown in help output.
	Desc string
}
    32  
// Program describes a CLI program: its entry point, its name and the options
// it accepts. NewCommand turns a Program into a cobra command whose options
// are parsed from flags, environment variables and a config file.
type Program struct {
	// Run is invoked by cobra when the command executes; its error is
	// returned from the command.
	Run func() error
	// Name is used as the command's Use string in help output and, upper-cased,
	// as the prefix for the program's environment variables.
	Name string
	// Opts are the command-line / environment-variable options of the program.
	Opts []Opt
}
    42  
    43  // NewCommand creates a new cobra command to be executed that respects env vars.
    44  //
    45  // Uses the upper-case version of the program's name as a prefix
    46  // to all environment variables.
    47  //
    48  // This is to simplify the viper/cobra boilerplate.
    49  func NewCommand(v *viper.Viper, p *Program) (*cobra.Command, error) {
    50  	cmd := &cobra.Command{
    51  		Use:  p.Name,
    52  		Args: cobra.NoArgs,
    53  		RunE: func(_ *cobra.Command, _ []string) error {
    54  			return p.Run()
    55  		},
    56  	}
    57  
    58  	v.SetEnvPrefix(strings.ToUpper(p.Name))
    59  	v.AutomaticEnv()
    60  	// This normalizes "-" to an underscore in env names.
    61  	v.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
    62  
    63  	// done before we bind flags to viper keys.
    64  	// order of precedence (1 highest -> 3 lowest):
    65  	//	1. flags
    66  	//  2. env vars
    67  	//	3. config file
    68  	if err := initializeConfig(v); err != nil {
    69  		return nil, fmt.Errorf("failed to load config file: %w", err)
    70  	}
    71  	if err := BindOptions(v, cmd, p.Opts); err != nil {
    72  		return nil, fmt.Errorf("failed to bind config options: %w", err)
    73  	}
    74  
    75  	return cmd, nil
    76  }
    77  
    78  func initializeConfig(v *viper.Viper) error {
    79  	configPath := v.GetString("CONFIG_PATH")
    80  	if configPath == "" {
    81  		// Default to looking in the working directory of the running process.
    82  		configPath = "."
    83  	}
    84  
    85  	switch strings.ToLower(path.Ext(configPath)) {
    86  	case ".json", ".toml", ".yaml", ".yml":
    87  		v.SetConfigFile(configPath)
    88  	default:
    89  		v.AddConfigPath(configPath)
    90  	}
    91  
    92  	if err := v.ReadInConfig(); err != nil && !os.IsNotExist(err) {
    93  		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
    94  			return err
    95  		}
    96  	}
    97  	return nil
    98  }
    99  
   100  // BindOptions adds opts to the specified command and automatically
   101  // registers those options with viper.
   102  func BindOptions(v *viper.Viper, cmd *cobra.Command, opts []Opt) error {
   103  	for _, o := range opts {
   104  		flagset := cmd.Flags()
   105  		if o.Persistent {
   106  			flagset = cmd.PersistentFlags()
   107  		}
   108  		envVal := lookupEnv(v, &o)
   109  		hasShort := o.Short != 0
   110  
   111  		switch destP := o.DestP.(type) {
   112  		case *string:
   113  			var d string
   114  			if o.Default != nil {
   115  				d = o.Default.(string)
   116  			}
   117  			if hasShort {
   118  				flagset.StringVarP(destP, o.Flag, string(o.Short), d, o.Desc)
   119  			} else {
   120  				flagset.StringVar(destP, o.Flag, d, o.Desc)
   121  			}
   122  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   123  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   124  			}
   125  			if envVal != nil {
   126  				if s, err := cast.ToStringE(envVal); err == nil {
   127  					*destP = s
   128  				}
   129  			}
   130  
   131  		case *int:
   132  			var d int
   133  			if o.Default != nil {
   134  				d = o.Default.(int)
   135  			}
   136  			if hasShort {
   137  				flagset.IntVarP(destP, o.Flag, string(o.Short), d, o.Desc)
   138  			} else {
   139  				flagset.IntVar(destP, o.Flag, d, o.Desc)
   140  			}
   141  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   142  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   143  			}
   144  			if envVal != nil {
   145  				if i, err := cast.ToIntE(envVal); err == nil {
   146  					*destP = i
   147  				}
   148  			}
   149  
   150  		case *int32:
   151  			var d int32
   152  			if o.Default != nil {
   153  				// N.B. since our CLI kit types default values as interface{} and
   154  				// literal numbers get typed as int by default, it's very easy to
   155  				// create an int32 CLI flag with an int default value.
   156  				//
   157  				// The compiler doesn't know to complain in that case, so you end up
   158  				// with a runtime panic when trying to bind the CLI options.
   159  				//
   160  				// To avoid that headache, we support both int32 and int defaults
   161  				// for int32 fields. This introduces a new runtime bomb if somebody
   162  				// specifies an int default > math.MaxInt32, but that's hopefully
   163  				// less likely.
   164  				var ok bool
   165  				d, ok = o.Default.(int32)
   166  				if !ok {
   167  					d = int32(o.Default.(int))
   168  				}
   169  			}
   170  			if hasShort {
   171  				flagset.Int32VarP(destP, o.Flag, string(o.Short), d, o.Desc)
   172  			} else {
   173  				flagset.Int32Var(destP, o.Flag, d, o.Desc)
   174  			}
   175  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   176  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   177  			}
   178  			if envVal != nil {
   179  				if i, err := cast.ToInt32E(envVal); err == nil {
   180  					*destP = i
   181  				}
   182  			}
   183  
   184  		case *int64:
   185  			var d int64
   186  			if o.Default != nil {
   187  				// N.B. since our CLI kit types default values as interface{} and
   188  				// literal numbers get typed as int by default, it's very easy to
   189  				// create an int64 CLI flag with an int default value.
   190  				//
   191  				// The compiler doesn't know to complain in that case, so you end up
   192  				// with a runtime panic when trying to bind the CLI options.
   193  				//
   194  				// To avoid that headache, we support both int64 and int defaults
   195  				// for int64 fields.
   196  				var ok bool
   197  				d, ok = o.Default.(int64)
   198  				if !ok {
   199  					d = int64(o.Default.(int))
   200  				}
   201  			}
   202  			if hasShort {
   203  				flagset.Int64VarP(destP, o.Flag, string(o.Short), d, o.Desc)
   204  			} else {
   205  				flagset.Int64Var(destP, o.Flag, d, o.Desc)
   206  			}
   207  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   208  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   209  			}
   210  			if envVal != nil {
   211  				if i, err := cast.ToInt64E(envVal); err == nil {
   212  					*destP = i
   213  				}
   214  			}
   215  
   216  		case *bool:
   217  			var d bool
   218  			if o.Default != nil {
   219  				d = o.Default.(bool)
   220  			}
   221  			if hasShort {
   222  				flagset.BoolVarP(destP, o.Flag, string(o.Short), d, o.Desc)
   223  			} else {
   224  				flagset.BoolVar(destP, o.Flag, d, o.Desc)
   225  			}
   226  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   227  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   228  			}
   229  			if envVal != nil {
   230  				if b, err := cast.ToBoolE(envVal); err == nil {
   231  					*destP = b
   232  				}
   233  			}
   234  
   235  		case *time.Duration:
   236  			var d time.Duration
   237  			if o.Default != nil {
   238  				d = o.Default.(time.Duration)
   239  			}
   240  			if hasShort {
   241  				flagset.DurationVarP(destP, o.Flag, string(o.Short), d, o.Desc)
   242  			} else {
   243  				flagset.DurationVar(destP, o.Flag, d, o.Desc)
   244  			}
   245  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   246  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   247  			}
   248  			if envVal != nil {
   249  				if d, err := cast.ToDurationE(envVal); err == nil {
   250  					*destP = d
   251  				}
   252  			}
   253  
   254  		case *[]string:
   255  			var d []string
   256  			if o.Default != nil {
   257  				d = o.Default.([]string)
   258  			}
   259  			if hasShort {
   260  				flagset.StringSliceVarP(destP, o.Flag, string(o.Short), d, o.Desc)
   261  			} else {
   262  				flagset.StringSliceVar(destP, o.Flag, d, o.Desc)
   263  			}
   264  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   265  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   266  			}
   267  			if envVal != nil {
   268  				if ss, err := cast.ToStringSliceE(envVal); err == nil {
   269  					*destP = ss
   270  				}
   271  			}
   272  
   273  		case *map[string]string:
   274  			var d map[string]string
   275  			if o.Default != nil {
   276  				d = o.Default.(map[string]string)
   277  			}
   278  			if hasShort {
   279  				flagset.StringToStringVarP(destP, o.Flag, string(o.Short), d, o.Desc)
   280  			} else {
   281  				flagset.StringToStringVar(destP, o.Flag, d, o.Desc)
   282  			}
   283  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   284  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   285  			}
   286  			if envVal != nil {
   287  				if sms, err := cast.ToStringMapStringE(envVal); err == nil {
   288  					*destP = sms
   289  				}
   290  			}
   291  
   292  		case pflag.Value:
   293  			if hasShort {
   294  				flagset.VarP(destP, o.Flag, string(o.Short), o.Desc)
   295  			} else {
   296  				flagset.Var(destP, o.Flag, o.Desc)
   297  			}
   298  			if o.Default != nil {
   299  				_ = destP.Set(o.Default.(string))
   300  			}
   301  			if err := v.BindPFlag(o.Flag, flagset.Lookup(o.Flag)); err != nil {
   302  				return fmt.Errorf("failed to bind flag %q: %w", o.Flag, err)
   303  			}
   304  			if envVal != nil {
   305  				if s, err := cast.ToStringE(envVal); err == nil {
   306  					_ = destP.Set(s)
   307  				}
   308  			}
   309  
   310  		case *platform.ID:
   311  			var d platform.ID
   312  			if o.Default != nil {
   313  				d = o.Default.(platform.ID)
   314  			}
   315  			if hasShort {
   316  				IDVarP(flagset, destP, o.Flag, string(o.Short), d, o.Desc)
   317  			} else {
   318  				IDVar(flagset, destP, o.Flag, d, o.Desc)
   319  			}
   320  			if envVal != nil {
   321  				if s, err := cast.ToStringE(envVal); err == nil {
   322  					_ = (*destP).DecodeFromString(s)
   323  				}
   324  			}
   325  
   326  		case *zapcore.Level:
   327  			var l zapcore.Level
   328  			if o.Default != nil {
   329  				l = o.Default.(zapcore.Level)
   330  			}
   331  			if hasShort {
   332  				LevelVarP(flagset, destP, o.Flag, string(o.Short), l, o.Desc)
   333  			} else {
   334  				LevelVar(flagset, destP, o.Flag, l, o.Desc)
   335  			}
   336  			if envVal != nil {
   337  				if s, err := cast.ToStringE(envVal); err == nil {
   338  					_ = (*destP).Set(s)
   339  				}
   340  			}
   341  
   342  		default:
   343  			// if you get this error, sorry about that!
   344  			// anyway, go ahead and make a PR and add another type.
   345  			return fmt.Errorf("unknown destination type %t", o.DestP)
   346  		}
   347  
   348  		// N.B. these "Mark" calls must run after the block above,
   349  		// otherwise cobra will return a "no such flag" error.
   350  
   351  		// Cobra will complain if a flag marked as required isn't present on the CLI.
   352  		// To support setting required args via config and env variables, we only enforce
   353  		// the required check if we didn't find a value in the viper instance.
   354  		if o.Required && envVal == nil {
   355  			if err := cmd.MarkFlagRequired(o.Flag); err != nil {
   356  				return fmt.Errorf("failed to mark flag %q as required: %w", o.Flag, err)
   357  			}
   358  		}
   359  		if o.Hidden {
   360  			if err := flagset.MarkHidden(o.Flag); err != nil {
   361  				return fmt.Errorf("failed to mark flag %q as hidden: %w", o.Flag, err)
   362  			}
   363  		}
   364  	}
   365  
   366  	return nil
   367  }
   368  
   369  // lookupEnv returns the value for a CLI option found in the environment, if any.
   370  func lookupEnv(v *viper.Viper, o *Opt) interface{} {
   371  	envVar := o.Flag
   372  	if o.EnvVar != "" {
   373  		envVar = o.EnvVar
   374  	}
   375  	return v.Get(envVar)
   376  }