github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/flagparse/parse.go (about)

     1  package flagparse
     2  
     3  import (
     4  	"embed"
     5  	"fmt"
     6  	"log"
     7  	"net/http"
     8  	_ "net/http/pprof"
     9  	"os"
    10  	"reflect"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/bingoohuang/gg/pkg/cast"
    15  	"github.com/bingoohuang/gg/pkg/ctl"
    16  	flag "github.com/bingoohuang/gg/pkg/fla9"
    17  	"github.com/bingoohuang/gg/pkg/ss"
    18  	"github.com/bingoohuang/gg/pkg/v"
    19  )
    20  
// PostProcessor can be implemented by the struct passed to Parse/ParseArgs.
// PostProcess is called after command-line flags have been parsed and the
// required-flag check has passed.
type PostProcessor interface {
	PostProcess()
}
    24  
// VersionShower can be implemented by the struct passed to Parse/ParseArgs.
// When the struct's bool field named "Version" is set, VersionInfo's result
// is printed (instead of the default version text) and the process exits.
type VersionShower interface {
	VersionInfo() string
}
    28  
// UsageShower can be implemented by the struct passed to Parse/ParseArgs to
// replace the flag set's usage output with Usage's text (whitespace-trimmed,
// printed to stdout).
type UsageShower interface {
	Usage() string
}
    32  
// requiredVar records a flag tagged `required:"true"` so it can be checked
// after parsing. Exactly one of p (string fields) or pp ([]string fields) is set.
type requiredVar struct {
	name string    // flag name as registered, possibly "full,short"
	p    *string   // destination of a string flag
	pp   *[]string // destination of a repeated string flag
}
    38  
    39  type Options struct {
    40  	flagName, defaultCnf string
    41  	cnf                  *string
    42  	initFiles            *embed.FS
    43  }
    44  
    45  type OptionsFn func(*Options)
    46  
    47  func ProcessInit(initFiles *embed.FS) OptionsFn {
    48  	return func(o *Options) {
    49  		o.initFiles = initFiles
    50  	}
    51  }
    52  
    53  func AutoLoadYaml(flagName, defaultCnf string) OptionsFn {
    54  	return func(o *Options) {
    55  		o.flagName = flagName
    56  		o.defaultCnf = defaultCnf
    57  	}
    58  }
    59  
    60  func Parse(a interface{}, optionFns ...OptionsFn) {
    61  	ParseArgs(a, os.Args, optionFns...)
    62  }
    63  
    64  func ParseArgs(a interface{}, args []string, optionFns ...OptionsFn) {
    65  	options := createOptions(optionFns)
    66  
    67  	f := flag.NewFlagSet(args[0], flag.ExitOnError)
    68  	var checkVersionShow func()
    69  	requiredVars := make([]requiredVar, 0)
    70  
    71  	var pprof *string
    72  	initing := false
    73  
    74  	ra := reflect.ValueOf(a).Elem()
    75  	rt := ra.Type()
    76  	for i := 0; i < rt.NumField(); i++ {
    77  		fi, fv := rt.Field(i), ra.Field(i)
    78  		if fi.PkgPath != "" { // ignore unexported
    79  			continue
    80  		}
    81  
    82  		t := fi.Tag.Get
    83  		name := t("flag")
    84  		if name == "-" || !fv.CanAddr() {
    85  			continue
    86  		}
    87  
    88  		if name == "" {
    89  			name = ss.ToLowerKebab(fi.Name)
    90  		} else if strings.HasPrefix(name, ",") { // for shortName
    91  			name = ss.ToLowerKebab(fi.Name) + name
    92  		}
    93  
    94  		val, usage, required, size := t("val"), t("usage"), t("required"), t("size")
    95  		p := fv.Addr().Interface()
    96  		ft := fi.Type
    97  		if reflect.PtrTo(ft).Implements(flagValueType) {
    98  			f.Var(p.(flag.Value), name, usage)
    99  			continue
   100  		} else if ft == timeDurationType {
   101  			f.DurationVar(p.(*time.Duration), name, cast.ToDuration(val), usage)
   102  			continue
   103  		}
   104  
   105  		fullName, shortName := ss.Split2(name, ss.WithSeps(","))
   106  
   107  		switch ft.Kind() {
   108  		case reflect.Slice:
   109  			switch ft.Elem().Kind() {
   110  			case reflect.String:
   111  				pp := p.(*[]string)
   112  				f.Var(&ArrayFlags{pp: pp, Value: val}, name, usage)
   113  				if required == "true" {
   114  					requiredVars = append(requiredVars, requiredVar{name: name, pp: pp})
   115  				}
   116  			}
   117  		case reflect.String:
   118  			pp := p.(*string)
   119  			f.StringVar(pp, name, val, usage)
   120  			if required == "true" {
   121  				requiredVars = append(requiredVars, requiredVar{name: name, p: pp})
   122  			}
   123  
   124  			switch {
   125  			case ss.AnyOf("pprof", fullName, shortName):
   126  				pprof = pp
   127  			case ss.AnyOf(options.flagName, fullName, shortName):
   128  				options.cnf = pp
   129  			}
   130  
   131  		case reflect.Int:
   132  			if count := t("count"); count == "true" {
   133  				f.CountVar(p.(*int), name, cast.ToInt(val), usage)
   134  			} else {
   135  				f.IntVar(p.(*int), name, cast.ToInt(val), usage)
   136  			}
   137  		case reflect.Int32:
   138  			f.Int32Var(p.(*int32), name, cast.ToInt32(val), usage)
   139  		case reflect.Int64:
   140  			f.Int64Var(p.(*int64), name, cast.ToInt64(val), usage)
   141  		case reflect.Uint:
   142  			f.UintVar(p.(*uint), name, cast.ToUint(val), usage)
   143  		case reflect.Uint32:
   144  			f.Uint32Var(p.(*uint32), name, cast.ToUint32(val), usage)
   145  		case reflect.Uint64:
   146  			if size == "true" {
   147  				f.Var(flag.NewSizeFlag(p.(*uint64), val), name, usage)
   148  			} else {
   149  				f.Uint64Var(p.(*uint64), name, cast.ToUint64(val), usage)
   150  			}
   151  		case reflect.Bool:
   152  			if fi.Name == "Init" {
   153  				f.BoolVar(&initing, name, false, usage)
   154  			} else {
   155  				pp := p.(*bool)
   156  				checkVersionShow = checkVersion(checkVersionShow, a, fi.Name, pp)
   157  				f.BoolVar(pp, name, cast.ToBool(val), usage)
   158  			}
   159  		case reflect.Float32:
   160  			f.Float32Var(p.(*float32), name, cast.ToFloat32(val), usage)
   161  		case reflect.Float64:
   162  			f.Float64Var(p.(*float64), name, cast.ToFloat64(val), usage)
   163  		}
   164  	}
   165  
   166  	if u, ok := a.(UsageShower); ok {
   167  		f.Usage = func() {
   168  			fmt.Println(strings.TrimSpace(u.Usage()))
   169  		}
   170  	}
   171  
   172  	if options.cnf != nil {
   173  		fn, sn := ss.Split2(options.flagName, ss.WithSeps(","))
   174  		if value, _ := FindFlag(args, fn, sn); value != "" || options.defaultCnf != "" {
   175  			if err := LoadConfFile(value, options.defaultCnf, a); err != nil {
   176  				fmt.Fprintln(os.Stderr, err.Error())
   177  				os.Exit(-1)
   178  			}
   179  		}
   180  	}
   181  
   182  	// 提前到这里,实际上是为了先解析出 --conf 参数,便于下面从配置文件载入数据
   183  	// 但是,命令行应该优先级,应该比配置文件优先级高,为了解决这个矛盾
   184  	// 需要把 --conf 参数置为第一个参数,并且使用自定义参数的形式,在解析到改参数时,
   185  	// 立即从对应的配置文件加载所有配置,然后再依次处理其它命令行参数
   186  	_ = f.Parse(args[1:])
   187  
   188  	if checkVersionShow != nil {
   189  		checkVersionShow()
   190  	}
   191  	if initing {
   192  		ctl.Config{Initing: true, InitFiles: options.initFiles}.ProcessInit()
   193  	}
   194  
   195  	checkRequired(requiredVars, f)
   196  
   197  	if pp, ok := a.(PostProcessor); ok {
   198  		pp.PostProcess()
   199  	}
   200  
   201  	if pprof != nil && *pprof != "" {
   202  		go startPprof(*pprof)
   203  	}
   204  }
   205  
   206  func FindFlag(args []string, targetNames ...string) (value string, found bool) {
   207  	for i := 1; i < len(args); i++ {
   208  		s := args[i]
   209  		if len(s) < 2 || s[0] != '-' {
   210  			continue
   211  		}
   212  		numMinuses := 1
   213  		if s[1] == '-' {
   214  			numMinuses++
   215  			if len(s) == 2 { // "--" terminates the flags
   216  				break
   217  			}
   218  		}
   219  
   220  		name := s[numMinuses:]
   221  		if len(name) == 0 || name[0] == '-' || name[0] == '=' { // bad flag syntax: %s"
   222  			continue
   223  		}
   224  		if strings.HasPrefix(name, "test.") { // ignore go test flags
   225  			continue
   226  		}
   227  
   228  		// it's a flag. does it have an argument?
   229  		hasValue := false
   230  		for j := 1; j < len(name); j++ { // equals cannot be first
   231  			if name[j] == '=' {
   232  				value = name[j+1:]
   233  				hasValue = true
   234  				name = name[0:j]
   235  				break
   236  			}
   237  		}
   238  
   239  		if !ss.AnyOf(name, targetNames...) {
   240  			continue
   241  		}
   242  
   243  		// It must have a value, which might be the next argument.
   244  		if !hasValue && i+1 < len(args) {
   245  			// value is the next arg
   246  			hasValue = true
   247  			value = args[i+1]
   248  		}
   249  
   250  		return value, true
   251  	}
   252  
   253  	return "", false
   254  }
   255  
   256  func createOptions(fns []OptionsFn) *Options {
   257  	options := &Options{}
   258  	for _, f := range fns {
   259  		f(options)
   260  	}
   261  
   262  	return options
   263  }
   264  
   265  var (
   266  	timeDurationType = reflect.TypeOf(time.Duration(0))
   267  	flagValueType    = reflect.TypeOf((*flag.Value)(nil)).Elem()
   268  )
   269  
   270  func checkRequired(requiredVars []requiredVar, f *flag.FlagSet) {
   271  	requiredMissed := 0
   272  	for _, rv := range requiredVars {
   273  		if rv.p != nil && *rv.p == "" || rv.pp != nil && len(*rv.pp) == 0 {
   274  			requiredMissed++
   275  			fmt.Printf("-%s is required\n", rv.name)
   276  		}
   277  	}
   278  
   279  	if requiredMissed > 0 {
   280  		f.Usage()
   281  		os.Exit(1)
   282  	}
   283  }
   284  
   285  func checkVersion(checker func(), arg interface{}, fiName string, bp *bool) func() {
   286  	if checker == nil && fiName == "Version" {
   287  		if vs, ok := arg.(VersionShower); ok {
   288  			return func() {
   289  				if *bp {
   290  					fmt.Println(vs.VersionInfo())
   291  					os.Exit(0)
   292  				}
   293  			}
   294  		} else {
   295  			return func() {
   296  				if *bp {
   297  					fmt.Println(v.Version())
   298  					os.Exit(0)
   299  				}
   300  			}
   301  		}
   302  	}
   303  
   304  	return checker
   305  }
   306  
   307  type ArrayFlags struct {
   308  	Value string
   309  	pp    *[]string
   310  }
   311  
   312  func (i *ArrayFlags) String() string { return i.Value }
   313  
   314  func (i *ArrayFlags) Set(value string) error {
   315  	*i.pp = append(*i.pp, value)
   316  	return nil
   317  }
   318  
   319  func startPprof(pprofAddr string) {
   320  	pprofHostPort := pprofAddr
   321  	parts := strings.Split(pprofHostPort, ":")
   322  	if len(parts) == 2 && parts[0] == "" {
   323  		pprofHostPort = fmt.Sprintf("localhost:%s", parts[1])
   324  	}
   325  
   326  	log.Printf("I! Starting pprof HTTP server at: http://%s/debug/pprof", pprofHostPort)
   327  	if err := http.ListenAndServe(pprofAddr, nil); err != nil {
   328  		log.Fatal("E! " + err.Error())
   329  	}
   330  }