github.com/pyroscope-io/pyroscope@v0.37.3-0.20230725203016-5f6947968bd0/pkg/cli/flags.go (about)

     1  package cli
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/iancoleman/strcase"
    12  	"github.com/mitchellh/mapstructure"
    13  	"github.com/sirupsen/logrus"
    14  	"github.com/spf13/pflag"
    15  	"github.com/spf13/viper"
    16  
    17  	"github.com/pyroscope-io/pyroscope/pkg/adhoc/util"
    18  	"github.com/pyroscope-io/pyroscope/pkg/agent/spy"
    19  	"github.com/pyroscope-io/pyroscope/pkg/config"
    20  	"github.com/pyroscope-io/pyroscope/pkg/util/bytesize"
    21  	"github.com/pyroscope-io/pyroscope/pkg/util/duration"
    22  	"github.com/pyroscope-io/pyroscope/pkg/util/slices"
    23  )
    24  
    25  const timeFormat = "2006-01-02T15:04:05Z0700"
    26  
    27  type arrayFlags []string
    28  
    29  func (i *arrayFlags) String() string {
    30  	if len(*i) == 0 {
    31  		return "[]"
    32  	}
    33  	return strings.Join(*i, ",")
    34  }
    35  
    36  func (i *arrayFlags) Set(value string) error {
    37  	*i = append(*i, value)
    38  	return nil
    39  }
    40  
    41  func (*arrayFlags) Type() string {
    42  	t := reflect.TypeOf([]string{})
    43  	return t.String()
    44  }
    45  
    46  type timeFlag time.Time
    47  
    48  func (tf *timeFlag) String() string {
    49  	v := time.Time(*tf)
    50  	return v.Format(timeFormat)
    51  }
    52  
    53  func (tf *timeFlag) Set(value string) error {
    54  	t2, err := time.Parse(timeFormat, value)
    55  	if err != nil {
    56  		var i int
    57  		i, err = strconv.Atoi(value)
    58  		if err != nil {
    59  			return err
    60  		}
    61  		t2 = time.Unix(int64(i), 0)
    62  	}
    63  
    64  	t := (*time.Time)(tf)
    65  	b, _ := t2.MarshalBinary()
    66  	t.UnmarshalBinary(b)
    67  
    68  	return nil
    69  }
    70  
    71  func (tf *timeFlag) Type() string {
    72  	v := time.Time(*tf)
    73  	t := reflect.TypeOf(v)
    74  	return t.String()
    75  }
    76  
    77  type mapFlags map[string]string
    78  
    79  func (m mapFlags) String() string {
    80  	if len(m) == 0 {
    81  		return "{}"
    82  	}
    83  	// Cast to map to avoid recursion.
    84  	return fmt.Sprint((map[string]string)(m))
    85  }
    86  
    87  func (m *mapFlags) Set(s string) error {
    88  	if len(s) == 0 {
    89  		return nil
    90  	}
    91  	v := strings.Split(s, "=")
    92  	if len(v) != 2 {
    93  		return fmt.Errorf("invalid flag %s: should be in key=value format", s)
    94  	}
    95  	if *m == nil {
    96  		*m = map[string]string{v[0]: v[1]}
    97  	} else {
    98  		(*m)[v[0]] = v[1]
    99  	}
   100  	return nil
   101  }
   102  
   103  func (*mapFlags) Type() string {
   104  	t := reflect.TypeOf(map[string]string{})
   105  	return t.String()
   106  }
   107  
   108  func Unmarshal(vpr *viper.Viper, cfg interface{}) error {
   109  	return vpr.Unmarshal(cfg, viper.DecodeHook(
   110  		mapstructure.ComposeDecodeHookFunc(
   111  			// Function to add a special type for «env. mode»
   112  			stringToByteSize,
   113  			stringToSameSite,
   114  			// Function to support net.IP
   115  			mapstructure.StringToIPHookFunc(),
   116  			// Appended by the two default functions
   117  			mapstructure.StringToTimeDurationHookFunc(),
   118  			mapstructure.StringToSliceHookFunc(","),
   119  		),
   120  	))
   121  }
   122  
   123  func stringToByteSize(_, t reflect.Type, data interface{}) (interface{}, error) {
   124  	if t != reflect.TypeOf(bytesize.Byte) {
   125  		return data, nil
   126  	}
   127  	stringData, ok := data.(string)
   128  	if !ok {
   129  		return data, nil
   130  	}
   131  	return bytesize.Parse(stringData)
   132  }
   133  
   134  func stringToSameSite(_, t reflect.Type, data interface{}) (interface{}, error) {
   135  	if t != reflect.TypeOf(http.SameSiteStrictMode) {
   136  		return data, nil
   137  	}
   138  	stringData, ok := data.(string)
   139  	if !ok {
   140  		return data, nil
   141  	}
   142  	return parseSameSite(stringData)
   143  }
   144  
   145  type options struct {
   146  	replacements   map[string]string
   147  	skip           []string
   148  	skipDeprecated bool
   149  }
   150  
   151  type FlagOption func(*options)
   152  
   153  func WithSkip(n ...string) FlagOption {
   154  	return func(o *options) {
   155  		o.skip = append(o.skip, n...)
   156  	}
   157  }
   158  
   159  // WithSkipDeprecated specifies that fields marked as deprecated won't be parsed.
   160  // By default PopulateFlagSet parses them but not shows in Usage; setting this
   161  // option to true causes PopulateFlagSet to skip parsing.
   162  func WithSkipDeprecated(ok bool) FlagOption {
   163  	return func(o *options) {
   164  		o.skipDeprecated = ok
   165  	}
   166  }
   167  
   168  func WithReplacement(k, v string) FlagOption {
   169  	return func(o *options) {
   170  		o.replacements[k] = v
   171  	}
   172  }
   173  
   174  type durFlag time.Duration
   175  
   176  func (df *durFlag) String() string {
   177  	v := time.Duration(*df)
   178  	return v.String()
   179  }
   180  
   181  func (df *durFlag) Set(value string) error {
   182  	d, err := duration.ParseDuration(value)
   183  	if err != nil {
   184  		return err
   185  	}
   186  
   187  	*df = durFlag(d)
   188  
   189  	return nil
   190  }
   191  
   192  func (df *durFlag) Type() string {
   193  	v := time.Duration(*df)
   194  	t := reflect.TypeOf(v)
   195  	return t.String()
   196  }
   197  
   198  type sameSiteFlag http.SameSite
   199  
   200  func (sf *sameSiteFlag) String() string {
   201  	switch http.SameSite(*sf) {
   202  	case http.SameSiteLaxMode:
   203  		return "Lax"
   204  	case http.SameSiteStrictMode:
   205  		return "Strict"
   206  	default:
   207  		return "None"
   208  	}
   209  }
   210  
   211  func (sf *sameSiteFlag) Set(value string) error {
   212  	v, err := parseSameSite(value)
   213  	if err != nil {
   214  		return err
   215  	}
   216  	*sf = sameSiteFlag(v)
   217  	return nil
   218  }
   219  
   220  func parseSameSite(s string) (http.SameSite, error) {
   221  	switch strings.ToLower(s) {
   222  	case "lax":
   223  		return http.SameSiteLaxMode, nil
   224  	case "strict":
   225  		return http.SameSiteStrictMode, nil
   226  	case "none":
   227  		return http.SameSiteNoneMode, nil
   228  	default:
   229  		return http.SameSiteDefaultMode, fmt.Errorf("unknown SameSite ")
   230  	}
   231  }
   232  
   233  func (sf *sameSiteFlag) Type() string {
   234  	return reflect.TypeOf(http.SameSite(*sf)).String()
   235  }
   236  
   237  type byteSizeFlag bytesize.ByteSize
   238  
   239  func (bs *byteSizeFlag) String() string {
   240  	v := bytesize.ByteSize(*bs)
   241  	return v.String()
   242  }
   243  
   244  func (bs *byteSizeFlag) Set(value string) error {
   245  	d, err := bytesize.Parse(value)
   246  	if err != nil {
   247  		return err
   248  	}
   249  
   250  	*bs = byteSizeFlag(d)
   251  
   252  	return nil
   253  }
   254  
   255  func (bs *byteSizeFlag) Type() string {
   256  	v := bytesize.ByteSize(*bs)
   257  	t := reflect.TypeOf(v)
   258  	return t.String()
   259  }
   260  
   261  func PopulateFlagSet(obj interface{}, flagSet *pflag.FlagSet, vpr *viper.Viper, opts ...FlagOption) *pflag.FlagSet {
   262  	v := reflect.ValueOf(obj).Elem()
   263  	t := reflect.TypeOf(v.Interface())
   264  
   265  	o := &options{
   266  		replacements: map[string]string{
   267  			"<installPrefix>":           getInstallPrefix(),
   268  			"<defaultAdhocDataPath>":    util.DataDirectory(),
   269  			"<defaultAgentConfigPath>":  defaultAgentConfigPath(),
   270  			"<defaultAgentLogFilePath>": defaultAgentLogFilePath(),
   271  			"<supportedProfilers>":      strings.Join(spy.SupportedExecSpies(), ", "),
   272  		},
   273  	}
   274  	for _, option := range opts {
   275  		option(o)
   276  	}
   277  
   278  	visitFields(flagSet, vpr, "", t, v, o)
   279  
   280  	return flagSet
   281  }
   282  
   283  //revive:disable-next-line:argument-limit,cognitive-complexity necessary complexity
   284  func visitFields(flagSet *pflag.FlagSet, vpr *viper.Viper, prefix string, t reflect.Type, v reflect.Value, o *options) {
   285  	num := t.NumField()
   286  	for i := 0; i < num; i++ {
   287  		field := t.Field(i)
   288  		fieldV := v.Field(i)
   289  
   290  		if !(fieldV.IsValid() && fieldV.CanSet()) {
   291  			continue
   292  		}
   293  
   294  		defaultValStr := field.Tag.Get("def")
   295  		descVal := field.Tag.Get("desc")
   296  		skipVal := field.Tag.Get("skip")
   297  		deprecatedVal := field.Tag.Get("deprecated")
   298  		nameVal := field.Tag.Get("name")
   299  		if nameVal == "" {
   300  			nameVal = strcase.ToKebab(field.Name)
   301  		}
   302  		if prefix != "" {
   303  			nameVal = prefix + "." + nameVal
   304  		}
   305  
   306  		if skipVal == "true" || slices.StringContains(o.skip, nameVal) {
   307  			continue
   308  		}
   309  
   310  		for old, n := range o.replacements {
   311  			descVal = strings.ReplaceAll(descVal, old, n)
   312  		}
   313  
   314  		if fieldV.Kind() == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct {
   315  			flagSet.Var(new(arrayFlags), nameVal, descVal)
   316  			continue
   317  		}
   318  
   319  		switch field.Type {
   320  		case reflect.TypeOf(http.SameSiteStrictMode):
   321  			valP := fieldV.Addr().Interface().(*http.SameSite)
   322  			val := (*sameSiteFlag)(valP)
   323  			var defaultVal http.SameSite
   324  			if defaultValStr != "" {
   325  				var err error
   326  				defaultVal, err = parseSameSite(defaultValStr)
   327  				if err != nil {
   328  					logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal)
   329  				}
   330  			}
   331  			*val = (sameSiteFlag)(defaultVal)
   332  			flagSet.Var(val, nameVal, descVal)
   333  			vpr.SetDefault(nameVal, defaultVal)
   334  		case reflect.TypeOf([]string{}):
   335  			val := fieldV.Addr().Interface().(*[]string)
   336  			val2 := (*arrayFlags)(val)
   337  			flagSet.Var(val2, nameVal, descVal)
   338  			// setting empty defaults to allow vpr.Unmarshal to recognize this field
   339  			vpr.SetDefault(nameVal, []string{})
   340  		case reflect.TypeOf(map[string]string{}):
   341  			val := fieldV.Addr().Interface().(*map[string]string)
   342  			val2 := (*mapFlags)(val)
   343  			flagSet.Var(val2, nameVal, descVal)
   344  			// setting empty defaults to allow vpr.Unmarshal to recognize this field
   345  			vpr.SetDefault(nameVal, map[string]string{})
   346  		case reflect.TypeOf(""):
   347  			val := fieldV.Addr().Interface().(*string)
   348  			for old, n := range o.replacements {
   349  				defaultValStr = strings.ReplaceAll(defaultValStr, old, n)
   350  			}
   351  			flagSet.StringVar(val, nameVal, defaultValStr, descVal)
   352  			vpr.SetDefault(nameVal, defaultValStr)
   353  		case reflect.TypeOf(true):
   354  			val := fieldV.Addr().Interface().(*bool)
   355  			flagSet.BoolVar(val, nameVal, defaultValStr == "true", descVal)
   356  			vpr.SetDefault(nameVal, defaultValStr == "true")
   357  		case reflect.TypeOf(time.Time{}):
   358  			valTime := fieldV.Addr().Interface().(*time.Time)
   359  			val := (*timeFlag)(valTime)
   360  			flagSet.Var(val, nameVal, descVal)
   361  			// setting empty defaults to allow vpr.Unmarshal to recognize this field
   362  			vpr.SetDefault(nameVal, time.Time{})
   363  		case reflect.TypeOf(time.Second):
   364  			valDur := fieldV.Addr().Interface().(*time.Duration)
   365  			val := (*durFlag)(valDur)
   366  
   367  			var defaultVal time.Duration
   368  			if defaultValStr != "" {
   369  				var err error
   370  				defaultVal, err = duration.ParseDuration(defaultValStr)
   371  				if err != nil {
   372  					logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal)
   373  				}
   374  			}
   375  			*val = (durFlag)(defaultVal)
   376  
   377  			flagSet.Var(val, nameVal, descVal)
   378  			vpr.SetDefault(nameVal, defaultVal)
   379  		case reflect.TypeOf(bytesize.Byte):
   380  			valByteSize := fieldV.Addr().Interface().(*bytesize.ByteSize)
   381  			val := (*byteSizeFlag)(valByteSize)
   382  			var defaultVal bytesize.ByteSize
   383  			if defaultValStr != "" {
   384  				var err error
   385  				defaultVal, err = bytesize.Parse(defaultValStr)
   386  				if err != nil {
   387  					logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal)
   388  				}
   389  			}
   390  
   391  			*val = (byteSizeFlag)(defaultVal)
   392  			flagSet.Var(val, nameVal, descVal)
   393  			vpr.SetDefault(nameVal, defaultVal)
   394  		case reflect.TypeOf(1):
   395  			val := fieldV.Addr().Interface().(*int)
   396  			var defaultVal int
   397  			if defaultValStr == "" {
   398  				defaultVal = 0
   399  			} else {
   400  				var err error
   401  				defaultVal, err = strconv.Atoi(defaultValStr)
   402  				if err != nil {
   403  					logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal)
   404  				}
   405  			}
   406  			flagSet.IntVar(val, nameVal, defaultVal, descVal)
   407  			vpr.SetDefault(nameVal, defaultVal)
   408  		case reflect.TypeOf(1.00):
   409  			val := fieldV.Addr().Interface().(*float64)
   410  			var defaultVal float64
   411  			if defaultValStr == "" {
   412  				defaultVal = 0.00
   413  			} else {
   414  				var err error
   415  				defaultVal, err = strconv.ParseFloat(defaultValStr, 64)
   416  				if err != nil {
   417  					logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal)
   418  				}
   419  			}
   420  			flagSet.Float64Var(val, nameVal, defaultVal, descVal)
   421  			vpr.SetDefault(nameVal, defaultVal)
   422  		case reflect.TypeOf(uint64(1)):
   423  			val := fieldV.Addr().Interface().(*uint64)
   424  			var defaultVal uint64
   425  			if defaultValStr == "" {
   426  				defaultVal = uint64(0)
   427  			} else {
   428  				var err error
   429  				defaultVal, err = strconv.ParseUint(defaultValStr, 10, 64)
   430  				if err != nil {
   431  					logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal)
   432  				}
   433  			}
   434  			flagSet.Uint64Var(val, nameVal, defaultVal, descVal)
   435  			vpr.SetDefault(nameVal, defaultVal)
   436  		case reflect.TypeOf(uint(1)):
   437  			val := fieldV.Addr().Interface().(*uint)
   438  			var defaultVal uint
   439  			if defaultValStr == "" {
   440  				defaultVal = uint(0)
   441  			} else {
   442  				out, err := strconv.ParseUint(defaultValStr, 10, 64)
   443  				if err != nil {
   444  					logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal)
   445  				}
   446  				defaultVal = uint(out)
   447  			}
   448  			flagSet.UintVar(val, nameVal, defaultVal, descVal)
   449  			vpr.SetDefault(nameVal, defaultVal)
   450  		case reflect.TypeOf(config.MetricsExportRules{}):
   451  			flagSet.Var(new(mapFlags), nameVal, descVal)
   452  			vpr.SetDefault(nameVal, config.MetricsExportRules{})
   453  		case reflect.TypeOf([]config.Target{}):
   454  			flagSet.Var(new(arrayFlags), nameVal, descVal)
   455  			vpr.SetDefault(nameVal, []config.Target{})
   456  		default:
   457  			if field.Type.Kind() == reflect.Struct {
   458  				visitFields(flagSet, vpr, nameVal, field.Type, fieldV, o)
   459  				continue
   460  			}
   461  
   462  			// A stub for unknown types. This is required for generated configs and
   463  			// documentation (when a parameter can not be set via flag but present
   464  			// in the configuration). Empty value is shown as '{}'.
   465  			flagSet.Var(new(mapFlags), nameVal, descVal)
   466  			vpr.SetDefault(nameVal, nil)
   467  		}
   468  
   469  		if deprecatedVal == "true" {
   470  			// TODO: We could specify which flag to use instead but would add code complexity
   471  			flagSet.MarkDeprecated(nameVal, "replace this flag as it will be removed in future versions")
   472  		}
   473  	}
   474  }