github.com/go-maxhub/gremlins@v1.0.1-0.20231227222204-b03a6a1e3e09/cmd/flags/flags.go (about) 1 /* 2 * Copyright 2022 The Gremlins Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package flags 18 19 import ( 20 "github.com/spf13/cobra" 21 "github.com/spf13/pflag" 22 "github.com/spf13/viper" 23 ) 24 25 // Flag is the core representation of a command flag. It is used to set 26 // flags in a more generic way. 27 type Flag struct { 28 Name string 29 CfgKey string 30 Shorthand string 31 DefaultV any 32 Usage string 33 } 34 35 // Set is a "generic" function used to set flags on cobra.Command and bind 36 // them to viper.Viper. 37 func Set(cmd *cobra.Command, flag *Flag) error { 38 flagSet := cmd.Flags() 39 40 return setFlags(flag, flagSet) 41 } 42 43 // SetPersistent is a "generic" function used to set persistent flags 44 // on cobra.Command and bind them to viper.Viper. 45 func SetPersistent(cmd *cobra.Command, flag *Flag) error { 46 flagSet := cmd.PersistentFlags() 47 48 return setFlags(flag, flagSet) 49 } 50 51 func setFlags(flag *Flag, fs *pflag.FlagSet) error { 52 switch dv := flag.DefaultV.(type) { 53 // TODO: add a case for all the supported types 54 case bool: 55 setBool(flag, fs, dv) 56 case string: 57 setString(flag, fs, dv) 58 case int: 59 setInt(flag, fs, dv) 60 case float64: 61 setFloat64(flag, fs, dv) 62 } 63 err := viper.BindPFlag(flag.CfgKey, fs.Lookup(flag.Name)) 64 if err != nil { 65 return err 66 } 67 68 return nil 69 } 70 71 func setInt(flag *Flag, flags *pflag.FlagSet, dv int) { 72 if flag.Shorthand != "" { 73 flags.IntP(flag.Name, flag.Shorthand, dv, flag.Usage) 74 } else { 75 flags.Int(flag.Name, dv, flag.Usage) 76 } 77 } 78 79 func setFloat64(flag *Flag, flags *pflag.FlagSet, dv float64) { 80 if flag.Shorthand != "" { 81 flags.Float64P(flag.Name, flag.Shorthand, dv, flag.Usage) 82 } else { 83 flags.Float64(flag.Name, dv, flag.Usage) 84 } 85 } 86 87 func setString(flag *Flag, flags *pflag.FlagSet, dv string) { 88 if flag.Shorthand != "" { 89 flags.StringP(flag.Name, flag.Shorthand, dv, flag.Usage) 90 } else { 91 flags.String(flag.Name, dv, flag.Usage) 92 } 93 } 94 95 func setBool(flag *Flag, flags *pflag.FlagSet, dv bool) { 96 if flag.Shorthand != "" { 97 flags.BoolP(flag.Name, flag.Shorthand, dv, flag.Usage) 98 } else { 99 flags.Bool(flag.Name, dv, flag.Usage) 100 } 101 }