github.com/go-maxhub/gremlins@v1.0.1-0.20231227222204-b03a6a1e3e09/cmd/flags/flags.go (about)

     1  /*
     2   * Copyright 2022 The Gremlins Authors
     3   *
     4   *    Licensed under the Apache License, Version 2.0 (the "License");
     5   *    you may not use this file except in compliance with the License.
     6   *    You may obtain a copy of the License at
     7   *
     8   *        http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   *    Unless required by applicable law or agreed to in writing, software
    11   *    distributed under the License is distributed on an "AS IS" BASIS,
    12   *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   *    See the License for the specific language governing permissions and
    14   *    limitations under the License.
    15   */
    16  
    17  package flags
    18  
import (
	"fmt"

	"github.com/spf13/cobra"
	"github.com/spf13/pflag"
	"github.com/spf13/viper"
)
    24  
    25  // Flag is the core representation of a command flag. It is used to set
    26  // flags in a more generic way.
    27  type Flag struct {
    28  	Name      string
    29  	CfgKey    string
    30  	Shorthand string
    31  	DefaultV  any
    32  	Usage     string
    33  }
    34  
    35  // Set is a "generic" function used to set flags on cobra.Command and bind
    36  // them to viper.Viper.
    37  func Set(cmd *cobra.Command, flag *Flag) error {
    38  	flagSet := cmd.Flags()
    39  
    40  	return setFlags(flag, flagSet)
    41  }
    42  
    43  // SetPersistent is a "generic" function used to set persistent flags
    44  // on cobra.Command and bind them to viper.Viper.
    45  func SetPersistent(cmd *cobra.Command, flag *Flag) error {
    46  	flagSet := cmd.PersistentFlags()
    47  
    48  	return setFlags(flag, flagSet)
    49  }
    50  
    51  func setFlags(flag *Flag, fs *pflag.FlagSet) error {
    52  	switch dv := flag.DefaultV.(type) {
    53  	// TODO: add a case for all the supported types
    54  	case bool:
    55  		setBool(flag, fs, dv)
    56  	case string:
    57  		setString(flag, fs, dv)
    58  	case int:
    59  		setInt(flag, fs, dv)
    60  	case float64:
    61  		setFloat64(flag, fs, dv)
    62  	}
    63  	err := viper.BindPFlag(flag.CfgKey, fs.Lookup(flag.Name))
    64  	if err != nil {
    65  		return err
    66  	}
    67  
    68  	return nil
    69  }
    70  
    71  func setInt(flag *Flag, flags *pflag.FlagSet, dv int) {
    72  	if flag.Shorthand != "" {
    73  		flags.IntP(flag.Name, flag.Shorthand, dv, flag.Usage)
    74  	} else {
    75  		flags.Int(flag.Name, dv, flag.Usage)
    76  	}
    77  }
    78  
    79  func setFloat64(flag *Flag, flags *pflag.FlagSet, dv float64) {
    80  	if flag.Shorthand != "" {
    81  		flags.Float64P(flag.Name, flag.Shorthand, dv, flag.Usage)
    82  	} else {
    83  		flags.Float64(flag.Name, dv, flag.Usage)
    84  	}
    85  }
    86  
    87  func setString(flag *Flag, flags *pflag.FlagSet, dv string) {
    88  	if flag.Shorthand != "" {
    89  		flags.StringP(flag.Name, flag.Shorthand, dv, flag.Usage)
    90  	} else {
    91  		flags.String(flag.Name, dv, flag.Usage)
    92  	}
    93  }
    94  
    95  func setBool(flag *Flag, flags *pflag.FlagSet, dv bool) {
    96  	if flag.Shorthand != "" {
    97  		flags.BoolP(flag.Name, flag.Shorthand, dv, flag.Usage)
    98  	} else {
    99  		flags.Bool(flag.Name, dv, flag.Usage)
   100  	}
   101  }