github.com/pyroscope-io/pyroscope@v0.37.3-0.20230725203016-5f6947968bd0/pkg/cli/flags.go (about) 1 package cli 2 3 import ( 4 "fmt" 5 "net/http" 6 "reflect" 7 "strconv" 8 "strings" 9 "time" 10 11 "github.com/iancoleman/strcase" 12 "github.com/mitchellh/mapstructure" 13 "github.com/sirupsen/logrus" 14 "github.com/spf13/pflag" 15 "github.com/spf13/viper" 16 17 "github.com/pyroscope-io/pyroscope/pkg/adhoc/util" 18 "github.com/pyroscope-io/pyroscope/pkg/agent/spy" 19 "github.com/pyroscope-io/pyroscope/pkg/config" 20 "github.com/pyroscope-io/pyroscope/pkg/util/bytesize" 21 "github.com/pyroscope-io/pyroscope/pkg/util/duration" 22 "github.com/pyroscope-io/pyroscope/pkg/util/slices" 23 ) 24 25 const timeFormat = "2006-01-02T15:04:05Z0700" 26 27 type arrayFlags []string 28 29 func (i *arrayFlags) String() string { 30 if len(*i) == 0 { 31 return "[]" 32 } 33 return strings.Join(*i, ",") 34 } 35 36 func (i *arrayFlags) Set(value string) error { 37 *i = append(*i, value) 38 return nil 39 } 40 41 func (*arrayFlags) Type() string { 42 t := reflect.TypeOf([]string{}) 43 return t.String() 44 } 45 46 type timeFlag time.Time 47 48 func (tf *timeFlag) String() string { 49 v := time.Time(*tf) 50 return v.Format(timeFormat) 51 } 52 53 func (tf *timeFlag) Set(value string) error { 54 t2, err := time.Parse(timeFormat, value) 55 if err != nil { 56 var i int 57 i, err = strconv.Atoi(value) 58 if err != nil { 59 return err 60 } 61 t2 = time.Unix(int64(i), 0) 62 } 63 64 t := (*time.Time)(tf) 65 b, _ := t2.MarshalBinary() 66 t.UnmarshalBinary(b) 67 68 return nil 69 } 70 71 func (tf *timeFlag) Type() string { 72 v := time.Time(*tf) 73 t := reflect.TypeOf(v) 74 return t.String() 75 } 76 77 type mapFlags map[string]string 78 79 func (m mapFlags) String() string { 80 if len(m) == 0 { 81 return "{}" 82 } 83 // Cast to map to avoid recursion. 84 return fmt.Sprint((map[string]string)(m)) 85 } 86 87 func (m *mapFlags) Set(s string) error { 88 if len(s) == 0 { 89 return nil 90 } 91 v := strings.Split(s, "=") 92 if len(v) != 2 { 93 return fmt.Errorf("invalid flag %s: should be in key=value format", s) 94 } 95 if *m == nil { 96 *m = map[string]string{v[0]: v[1]} 97 } else { 98 (*m)[v[0]] = v[1] 99 } 100 return nil 101 } 102 103 func (*mapFlags) Type() string { 104 t := reflect.TypeOf(map[string]string{}) 105 return t.String() 106 } 107 108 func Unmarshal(vpr *viper.Viper, cfg interface{}) error { 109 return vpr.Unmarshal(cfg, viper.DecodeHook( 110 mapstructure.ComposeDecodeHookFunc( 111 // Function to add a special type for «env. mode» 112 stringToByteSize, 113 stringToSameSite, 114 // Function to support net.IP 115 mapstructure.StringToIPHookFunc(), 116 // Appended by the two default functions 117 mapstructure.StringToTimeDurationHookFunc(), 118 mapstructure.StringToSliceHookFunc(","), 119 ), 120 )) 121 } 122 123 func stringToByteSize(_, t reflect.Type, data interface{}) (interface{}, error) { 124 if t != reflect.TypeOf(bytesize.Byte) { 125 return data, nil 126 } 127 stringData, ok := data.(string) 128 if !ok { 129 return data, nil 130 } 131 return bytesize.Parse(stringData) 132 } 133 134 func stringToSameSite(_, t reflect.Type, data interface{}) (interface{}, error) { 135 if t != reflect.TypeOf(http.SameSiteStrictMode) { 136 return data, nil 137 } 138 stringData, ok := data.(string) 139 if !ok { 140 return data, nil 141 } 142 return parseSameSite(stringData) 143 } 144 145 type options struct { 146 replacements map[string]string 147 skip []string 148 skipDeprecated bool 149 } 150 151 type FlagOption func(*options) 152 153 func WithSkip(n ...string) FlagOption { 154 return func(o *options) { 155 o.skip = append(o.skip, n...) 156 } 157 } 158 159 // WithSkipDeprecated specifies that fields marked as deprecated won't be parsed. 160 // By default PopulateFlagSet parses them but not shows in Usage; setting this 161 // option to true causes PopulateFlagSet to skip parsing. 162 func WithSkipDeprecated(ok bool) FlagOption { 163 return func(o *options) { 164 o.skipDeprecated = ok 165 } 166 } 167 168 func WithReplacement(k, v string) FlagOption { 169 return func(o *options) { 170 o.replacements[k] = v 171 } 172 } 173 174 type durFlag time.Duration 175 176 func (df *durFlag) String() string { 177 v := time.Duration(*df) 178 return v.String() 179 } 180 181 func (df *durFlag) Set(value string) error { 182 d, err := duration.ParseDuration(value) 183 if err != nil { 184 return err 185 } 186 187 *df = durFlag(d) 188 189 return nil 190 } 191 192 func (df *durFlag) Type() string { 193 v := time.Duration(*df) 194 t := reflect.TypeOf(v) 195 return t.String() 196 } 197 198 type sameSiteFlag http.SameSite 199 200 func (sf *sameSiteFlag) String() string { 201 switch http.SameSite(*sf) { 202 case http.SameSiteLaxMode: 203 return "Lax" 204 case http.SameSiteStrictMode: 205 return "Strict" 206 default: 207 return "None" 208 } 209 } 210 211 func (sf *sameSiteFlag) Set(value string) error { 212 v, err := parseSameSite(value) 213 if err != nil { 214 return err 215 } 216 *sf = sameSiteFlag(v) 217 return nil 218 } 219 220 func parseSameSite(s string) (http.SameSite, error) { 221 switch strings.ToLower(s) { 222 case "lax": 223 return http.SameSiteLaxMode, nil 224 case "strict": 225 return http.SameSiteStrictMode, nil 226 case "none": 227 return http.SameSiteNoneMode, nil 228 default: 229 return http.SameSiteDefaultMode, fmt.Errorf("unknown SameSite ") 230 } 231 } 232 233 func (sf *sameSiteFlag) Type() string { 234 return reflect.TypeOf(http.SameSite(*sf)).String() 235 } 236 237 type byteSizeFlag bytesize.ByteSize 238 239 func (bs *byteSizeFlag) String() string { 240 v := bytesize.ByteSize(*bs) 241 return v.String() 242 } 243 244 func (bs *byteSizeFlag) Set(value string) error { 245 d, err := bytesize.Parse(value) 246 if err != nil { 247 return err 248 } 249 250 *bs = byteSizeFlag(d) 251 252 return nil 253 } 254 255 func (bs *byteSizeFlag) Type() string { 256 v := bytesize.ByteSize(*bs) 257 t := reflect.TypeOf(v) 258 return t.String() 259 } 260 261 func PopulateFlagSet(obj interface{}, flagSet *pflag.FlagSet, vpr *viper.Viper, opts ...FlagOption) *pflag.FlagSet { 262 v := reflect.ValueOf(obj).Elem() 263 t := reflect.TypeOf(v.Interface()) 264 265 o := &options{ 266 replacements: map[string]string{ 267 "<installPrefix>": getInstallPrefix(), 268 "<defaultAdhocDataPath>": util.DataDirectory(), 269 "<defaultAgentConfigPath>": defaultAgentConfigPath(), 270 "<defaultAgentLogFilePath>": defaultAgentLogFilePath(), 271 "<supportedProfilers>": strings.Join(spy.SupportedExecSpies(), ", "), 272 }, 273 } 274 for _, option := range opts { 275 option(o) 276 } 277 278 visitFields(flagSet, vpr, "", t, v, o) 279 280 return flagSet 281 } 282 283 //revive:disable-next-line:argument-limit,cognitive-complexity necessary complexity 284 func visitFields(flagSet *pflag.FlagSet, vpr *viper.Viper, prefix string, t reflect.Type, v reflect.Value, o *options) { 285 num := t.NumField() 286 for i := 0; i < num; i++ { 287 field := t.Field(i) 288 fieldV := v.Field(i) 289 290 if !(fieldV.IsValid() && fieldV.CanSet()) { 291 continue 292 } 293 294 defaultValStr := field.Tag.Get("def") 295 descVal := field.Tag.Get("desc") 296 skipVal := field.Tag.Get("skip") 297 deprecatedVal := field.Tag.Get("deprecated") 298 nameVal := field.Tag.Get("name") 299 if nameVal == "" { 300 nameVal = strcase.ToKebab(field.Name) 301 } 302 if prefix != "" { 303 nameVal = prefix + "." + nameVal 304 } 305 306 if skipVal == "true" || slices.StringContains(o.skip, nameVal) { 307 continue 308 } 309 310 for old, n := range o.replacements { 311 descVal = strings.ReplaceAll(descVal, old, n) 312 } 313 314 if fieldV.Kind() == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct { 315 flagSet.Var(new(arrayFlags), nameVal, descVal) 316 continue 317 } 318 319 switch field.Type { 320 case reflect.TypeOf(http.SameSiteStrictMode): 321 valP := fieldV.Addr().Interface().(*http.SameSite) 322 val := (*sameSiteFlag)(valP) 323 var defaultVal http.SameSite 324 if defaultValStr != "" { 325 var err error 326 defaultVal, err = parseSameSite(defaultValStr) 327 if err != nil { 328 logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal) 329 } 330 } 331 *val = (sameSiteFlag)(defaultVal) 332 flagSet.Var(val, nameVal, descVal) 333 vpr.SetDefault(nameVal, defaultVal) 334 case reflect.TypeOf([]string{}): 335 val := fieldV.Addr().Interface().(*[]string) 336 val2 := (*arrayFlags)(val) 337 flagSet.Var(val2, nameVal, descVal) 338 // setting empty defaults to allow vpr.Unmarshal to recognize this field 339 vpr.SetDefault(nameVal, []string{}) 340 case reflect.TypeOf(map[string]string{}): 341 val := fieldV.Addr().Interface().(*map[string]string) 342 val2 := (*mapFlags)(val) 343 flagSet.Var(val2, nameVal, descVal) 344 // setting empty defaults to allow vpr.Unmarshal to recognize this field 345 vpr.SetDefault(nameVal, map[string]string{}) 346 case reflect.TypeOf(""): 347 val := fieldV.Addr().Interface().(*string) 348 for old, n := range o.replacements { 349 defaultValStr = strings.ReplaceAll(defaultValStr, old, n) 350 } 351 flagSet.StringVar(val, nameVal, defaultValStr, descVal) 352 vpr.SetDefault(nameVal, defaultValStr) 353 case reflect.TypeOf(true): 354 val := fieldV.Addr().Interface().(*bool) 355 flagSet.BoolVar(val, nameVal, defaultValStr == "true", descVal) 356 vpr.SetDefault(nameVal, defaultValStr == "true") 357 case reflect.TypeOf(time.Time{}): 358 valTime := fieldV.Addr().Interface().(*time.Time) 359 val := (*timeFlag)(valTime) 360 flagSet.Var(val, nameVal, descVal) 361 // setting empty defaults to allow vpr.Unmarshal to recognize this field 362 vpr.SetDefault(nameVal, time.Time{}) 363 case reflect.TypeOf(time.Second): 364 valDur := fieldV.Addr().Interface().(*time.Duration) 365 val := (*durFlag)(valDur) 366 367 var defaultVal time.Duration 368 if defaultValStr != "" { 369 var err error 370 defaultVal, err = duration.ParseDuration(defaultValStr) 371 if err != nil { 372 logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal) 373 } 374 } 375 *val = (durFlag)(defaultVal) 376 377 flagSet.Var(val, nameVal, descVal) 378 vpr.SetDefault(nameVal, defaultVal) 379 case reflect.TypeOf(bytesize.Byte): 380 valByteSize := fieldV.Addr().Interface().(*bytesize.ByteSize) 381 val := (*byteSizeFlag)(valByteSize) 382 var defaultVal bytesize.ByteSize 383 if defaultValStr != "" { 384 var err error 385 defaultVal, err = bytesize.Parse(defaultValStr) 386 if err != nil { 387 logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal) 388 } 389 } 390 391 *val = (byteSizeFlag)(defaultVal) 392 flagSet.Var(val, nameVal, descVal) 393 vpr.SetDefault(nameVal, defaultVal) 394 case reflect.TypeOf(1): 395 val := fieldV.Addr().Interface().(*int) 396 var defaultVal int 397 if defaultValStr == "" { 398 defaultVal = 0 399 } else { 400 var err error 401 defaultVal, err = strconv.Atoi(defaultValStr) 402 if err != nil { 403 logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal) 404 } 405 } 406 flagSet.IntVar(val, nameVal, defaultVal, descVal) 407 vpr.SetDefault(nameVal, defaultVal) 408 case reflect.TypeOf(1.00): 409 val := fieldV.Addr().Interface().(*float64) 410 var defaultVal float64 411 if defaultValStr == "" { 412 defaultVal = 0.00 413 } else { 414 var err error 415 defaultVal, err = strconv.ParseFloat(defaultValStr, 64) 416 if err != nil { 417 logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal) 418 } 419 } 420 flagSet.Float64Var(val, nameVal, defaultVal, descVal) 421 vpr.SetDefault(nameVal, defaultVal) 422 case reflect.TypeOf(uint64(1)): 423 val := fieldV.Addr().Interface().(*uint64) 424 var defaultVal uint64 425 if defaultValStr == "" { 426 defaultVal = uint64(0) 427 } else { 428 var err error 429 defaultVal, err = strconv.ParseUint(defaultValStr, 10, 64) 430 if err != nil { 431 logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal) 432 } 433 } 434 flagSet.Uint64Var(val, nameVal, defaultVal, descVal) 435 vpr.SetDefault(nameVal, defaultVal) 436 case reflect.TypeOf(uint(1)): 437 val := fieldV.Addr().Interface().(*uint) 438 var defaultVal uint 439 if defaultValStr == "" { 440 defaultVal = uint(0) 441 } else { 442 out, err := strconv.ParseUint(defaultValStr, 10, 64) 443 if err != nil { 444 logrus.Fatalf("invalid default value: %q (%s)", defaultValStr, nameVal) 445 } 446 defaultVal = uint(out) 447 } 448 flagSet.UintVar(val, nameVal, defaultVal, descVal) 449 vpr.SetDefault(nameVal, defaultVal) 450 case reflect.TypeOf(config.MetricsExportRules{}): 451 flagSet.Var(new(mapFlags), nameVal, descVal) 452 vpr.SetDefault(nameVal, config.MetricsExportRules{}) 453 case reflect.TypeOf([]config.Target{}): 454 flagSet.Var(new(arrayFlags), nameVal, descVal) 455 vpr.SetDefault(nameVal, []config.Target{}) 456 default: 457 if field.Type.Kind() == reflect.Struct { 458 visitFields(flagSet, vpr, nameVal, field.Type, fieldV, o) 459 continue 460 } 461 462 // A stub for unknown types. This is required for generated configs and 463 // documentation (when a parameter can not be set via flag but present 464 // in the configuration). Empty value is shown as '{}'. 465 flagSet.Var(new(mapFlags), nameVal, descVal) 466 vpr.SetDefault(nameVal, nil) 467 } 468 469 if deprecatedVal == "true" { 470 // TODO: We could specify which flag to use instead but would add code complexity 471 flagSet.MarkDeprecated(nameVal, "replace this flag as it will be removed in future versions") 472 } 473 } 474 }