github.com/qiniu/dyn@v1.3.0/flag/parser.go (about)

     1  package flag
     2  
     3  import (
     4  	"errors"
     5  	"flag"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  	"syscall"
    10  
    11  	"github.com/qiniu/dyn/cmdarg"
    12  	"github.com/qiniu/x/jsonutil"
    13  	"github.com/qiniu/x/log"
    14  
    15  	. "github.com/qiniu/dyn/proto"
    16  )
    17  
    18  var (
    19  	ErrParamsNotEnough     = errors.New("params not enough")
    20  	ErrTooMuchParams       = errors.New("too much params")
    21  	ErrUnsupportedFlagType = errors.New("unsupported flag type")
    22  	ErrUnsupportedArgType  = errors.New("unsupported argument type")
    23  )
    24  
    25  // ---------------------------------------------------------------------------
    26  
    27  type Context interface {
    28  	Subst(vexp interface{}, ft int) (vres interface{}, err error)
    29  	SubstText(exprvar string, ft int) (v string, err error)
    30  }
    31  
    32  // ---------------------------------------------------------------------------
    33  
    34  type nilContextImpl struct{}
    35  
    36  var nilContext Context = nilContextImpl{}
    37  
    38  func (p nilContextImpl) Subst(vexp interface{}, ft int) (vres interface{}, err error) {
    39  	return vexp, nil
    40  }
    41  
    42  func (p nilContextImpl) SubstText(exprvar string, ft int) (v string, err error) {
    43  	return exprvar, nil
    44  }
    45  
    46  // ---------------------------------------------------------------------------
    47  
    48  func hasOption(tag string, opt string) bool {
    49  
    50  	for i := 0; i < len(tag); i++ {
    51  		switch tag[i] {
    52  		case ',':
    53  			if strings.HasPrefix(tag[i+1:], opt) {
    54  				tagLeft := tag[i+1+len(opt):]
    55  				if tagLeft == "" || tagLeft[0] == ',' || tagLeft[0] == ' ' {
    56  					return true
    57  				}
    58  			}
    59  		case ' ':
    60  			return false
    61  		}
    62  	}
    63  	return false
    64  }
    65  
    66  func parseFlagArg(fv reflect.Value, f *flag.FlagSet, tag string) (err error) {
    67  
    68  	n := 0
    69  	for n < len(tag) {
    70  		if tag[n] == ',' || tag[n] == ' ' {
    71  			break
    72  		}
    73  		n++
    74  	}
    75  
    76  	name := tag[:n]
    77  	usage := ""
    78  	pos := strings.Index(tag[n:], " - ")
    79  	if pos >= 0 {
    80  		usage = tag[n+pos+3:]
    81  	}
    82  
    83  	switch fv.Kind() {
    84  	case reflect.Ptr:
    85  		switch fv.Elem().Kind() {
    86  		case reflect.Bool:
    87  			fv.Set(reflect.ValueOf(f.Bool(name, false, usage)))
    88  		case reflect.Int:
    89  			fv.Set(reflect.ValueOf(f.Int(name, 0, usage)))
    90  		case reflect.Uint:
    91  			fv.Set(reflect.ValueOf(f.Uint(name, 0, usage)))
    92  		case reflect.Uint64:
    93  			fv.Set(reflect.ValueOf(f.Uint64(name, 0, usage)))
    94  		case reflect.String:
    95  			fv.Set(reflect.ValueOf(f.String(name, "", usage)))
    96  		default:
    97  			return ErrUnsupportedFlagType
    98  		}
    99  	case reflect.Bool:
   100  		f.BoolVar(fv.Addr().Interface().(*bool), name, false, usage)
   101  	case reflect.Int:
   102  		f.IntVar(fv.Addr().Interface().(*int), name, 0, usage)
   103  	case reflect.Uint:
   104  		f.UintVar(fv.Addr().Interface().(*uint), name, 0, usage)
   105  	case reflect.Uint64:
   106  		f.Uint64Var(fv.Addr().Interface().(*uint64), name, 0, usage)
   107  	case reflect.String:
   108  		f.StringVar(fv.Addr().Interface().(*string), name, "", usage)
   109  	default:
   110  		return ErrUnsupportedFlagType
   111  	}
   112  	return nil
   113  }
   114  
   115  type parseArgOpts struct {
   116  	ft       string
   117  	keep     bool
   118  	optional bool
   119  }
   120  
   121  func getFmttype(ft string, ftDefault int) int {
   122  
   123  	switch ft {
   124  	case "form":
   125  		return Fmttype_Form
   126  	case "text":
   127  		return Fmttype_Text
   128  	case "json":
   129  		return Fmttype_Json
   130  	}
   131  	return ftDefault
   132  }
   133  
   134  func parseArgTag(tag string) (opts parseArgOpts, err error) {
   135  
   136  	pos := strings.Index(tag, " - ")
   137  	if pos >= 0 {
   138  		tag = tag[:pos]
   139  	}
   140  
   141  	parts := strings.Split(tag, ",")
   142  	for i := 1; i < len(parts); i++ {
   143  		switch parts[i] {
   144  		case "keep":
   145  			opts.keep = true
   146  		case "form", "text", "json":
   147  			opts.ft = parts[i]
   148  		case "opt":
   149  			opts.optional = true
   150  		default:
   151  			err = errors.New("Unknown tag option: " + parts[i])
   152  			return
   153  		}
   154  	}
   155  	return
   156  }
   157  
   158  func parseArg(
   159  	ctx Context, fv reflect.Value, arg string, opts parseArgOpts) (err error) {
   160  
   161  	kind := fv.Kind()
   162  	switch kind {
   163  	case reflect.String:
   164  		if opts.keep { // 保留 $(var) 不要自动展开
   165  			fv.SetString(arg)
   166  			return
   167  		}
   168  		arg, err = ctx.SubstText(arg, getFmttype(opts.ft, Fmttype_Text))
   169  		if err != nil {
   170  			return
   171  		}
   172  		fv.SetString(arg)
   173  
   174  	case reflect.Interface:
   175  		if ctx == nilContext {
   176  			var argObj interface{}
   177  			err1 := jsonutil.Unmarshal(arg, &argObj)
   178  			if err1 != nil {
   179  				return err1
   180  			}
   181  			fv.Set(reflect.ValueOf(argObj))
   182  		} else {
   183  			argObj, err1 := cmdarg.Unmarshal(arg)
   184  			if err1 != nil {
   185  				return err1
   186  			}
   187  			if opts.keep { // 保留 $(var) 不要做 Subst
   188  				fv.Set(reflect.ValueOf(argObj))
   189  				return
   190  			}
   191  			argObj, err2 := ctx.Subst(argObj, getFmttype(opts.ft, Fmttype_Text))
   192  			if err2 != nil {
   193  				return err2
   194  			}
   195  			fv.Set(reflect.ValueOf(argObj))
   196  		}
   197  
   198  	default:
   199  		if kind >= reflect.Int && kind <= reflect.Int64 {
   200  			arg, err = ctx.SubstText(arg, Fmttype_Text)
   201  			if err != nil {
   202  				return
   203  			}
   204  			intVal, err2 := strconv.ParseInt(arg, 10, 64)
   205  			if err2 != nil {
   206  				return err2
   207  			}
   208  			fv.SetInt(intVal)
   209  			return nil
   210  		}
   211  		if kind >= reflect.Uint && kind <= reflect.Uintptr {
   212  			arg, err = ctx.SubstText(arg, Fmttype_Text)
   213  			if err != nil {
   214  				return
   215  			}
   216  			uintVal, err2 := strconv.ParseUint(arg, 10, 64)
   217  			if err2 != nil {
   218  				return err2
   219  			}
   220  			fv.SetUint(uintVal)
   221  			return nil
   222  		}
   223  		arg, err = ctx.SubstText(arg, getFmttype(opts.ft, Fmttype_Json))
   224  		if err != nil {
   225  			return
   226  		}
   227  		err = jsonutil.Unmarshal(arg, fv.Addr().Interface())
   228  		if err != nil {
   229  			log.Debug("parseCmdArgs failed:", err, "arg:", arg)
   230  			return
   231  		}
   232  	}
   233  	return nil
   234  }
   235  
   236  func parseVargs(
   237  	ctx Context, fv reflect.Value, args []string, opts parseArgOpts) (err error) {
   238  
   239  	sliceType := fv.Type()
   240  	n := len(args)
   241  	sliceValue := reflect.MakeSlice(sliceType, n, n)
   242  	for i, arg := range args {
   243  		err = parseArg(ctx, sliceValue.Index(i), arg, opts)
   244  		if err != nil {
   245  			return
   246  		}
   247  	}
   248  	fv.Set(sliceValue)
   249  	return
   250  }
   251  
   252  func parseStructArgs(
   253  	ctx Context, strucType reflect.Type, cmd []string) (args reflect.Value, err error) {
   254  
   255  	nField := strucType.NumField()
   256  
   257  	hasFlag := false
   258  	for i := 0; i < nField; i++ {
   259  		sf := strucType.Field(i)
   260  		if strings.HasPrefix(string(sf.Tag), "flag:") {
   261  			hasFlag = true
   262  			break
   263  		}
   264  	}
   265  
   266  	args = reflect.New(strucType)
   267  	argsRef := args.Elem()
   268  
   269  	if hasFlag {
   270  		f := flag.NewFlagSet(cmd[0], 0)
   271  		for i := 0; i < nField; i++ {
   272  			sf := strucType.Field(i)
   273  			if strings.HasPrefix(string(sf.Tag), "f") {
   274  				err = parseFlagArg(argsRef.Field(i), f, sf.Tag.Get("flag"))
   275  				if err != nil {
   276  					return
   277  				}
   278  			}
   279  		}
   280  		err = f.Parse(cmd[1:])
   281  		if err != nil {
   282  			return
   283  		}
   284  		cmd = f.Args()
   285  	} else {
   286  		cmd = cmd[1:]
   287  	}
   288  
   289  	icmd := 0
   290  	for i := 0; i < nField; i++ {
   291  		sf := strucType.Field(i)
   292  		if strings.HasPrefix(string(sf.Tag), "arg:") {
   293  			tag := sf.Tag.Get("arg")
   294  			opts, err2 := parseArgTag(tag)
   295  			if err2 != nil {
   296  				err = err2
   297  				return
   298  			}
   299  			fv := argsRef.Field(i)
   300  			if fv.Kind() == reflect.Slice { // 不定参数
   301  				err = parseVargs(ctx, fv, cmd[icmd:], opts)
   302  				return
   303  			}
   304  			if icmd >= len(cmd) {
   305  				if opts.optional { // 可选参数
   306  					return
   307  				}
   308  				err = ErrParamsNotEnough
   309  				return
   310  			}
   311  			err = parseArg(ctx, fv, cmd[icmd], opts)
   312  			if err != nil {
   313  				return
   314  			}
   315  			icmd++
   316  		}
   317  	}
   318  	if icmd != len(cmd) {
   319  		err = ErrTooMuchParams
   320  	}
   321  	return
   322  }
   323  
   324  func Parse(
   325  	ctx Context, argsType reflect.Type, cmd []string) (args reflect.Value, err error) {
   326  
   327  	if ctx == nil {
   328  		ctx = nilContext
   329  	}
   330  
   331  	switch argsType.Kind() {
   332  	case reflect.Ptr: // may be args *xxxArgs
   333  		strucType := argsType.Elem()
   334  		if strucType.Kind() == reflect.Struct {
   335  			return parseStructArgs(ctx, strucType, cmd)
   336  		}
   337  	case reflect.Slice: // may be args []string
   338  		if argsType.Elem().Kind() == reflect.String {
   339  			return reflect.ValueOf(cmd), nil
   340  		}
   341  	}
   342  	err = syscall.EINVAL
   343  	return
   344  }
   345  
   346  // ---------------------------------------------------------------------------
   347  
   348  var (
   349  	ErrMethodNotFound         = errors.New("method not found")
   350  	ErrInvalidMethodPrototype = errors.New("invalid method prototype: method input argument count != 2")
   351  )
   352  
   353  func ExecMethod(
   354  	ctx Context, method, methodCtx reflect.Value, cmd []string) (out []reflect.Value, err error) {
   355  
   356  	mtype := method.Type()
   357  	narg := mtype.NumIn()
   358  	if narg > 2 {
   359  		err = ErrInvalidMethodPrototype
   360  		return
   361  	}
   362  	if narg == 0 {
   363  		return method.Call(nil), nil
   364  	}
   365  
   366  	argsType := mtype.In(narg - 1)
   367  	args, err := Parse(ctx, argsType, cmd)
   368  	if err != nil {
   369  		err = errors.New("ParseArgs failed: " + err.Error())
   370  		return
   371  	}
   372  
   373  	var in []reflect.Value
   374  	if narg == 1 {
   375  		in = []reflect.Value{args}
   376  	} else {
   377  		in = []reflect.Value{
   378  			methodCtx,
   379  			args,
   380  		}
   381  	}
   382  	return method.Call(in), nil
   383  }
   384  
   385  func Exec(v reflect.Value, cmd []string) (out []reflect.Value, err error) {
   386  
   387  	method := v.MethodByName("Cmd_" + cmd[0])
   388  	if !method.IsValid() {
   389  		return nil, ErrMethodNotFound
   390  	}
   391  
   392  	return ExecMethod(nilContext, method, v, cmd)
   393  }
   394  
   395  // ---------------------------------------------------------------------------
   396  
   397  type Cmd struct{}
   398  
   399  func HelpMethod(app string, method reflect.Value) {
   400  
   401  	// TODO
   402  }
   403  
   404  func HelpCmd(app string, v reflect.Value, cmd string) {
   405  
   406  	method := v.MethodByName("Cmd_" + cmd)
   407  	if !method.IsValid() {
   408  		return
   409  	}
   410  
   411  	HelpMethod(app, method)
   412  }
   413  
   414  func Help(app string, v reflect.Value) {
   415  
   416  	// TODO
   417  }
   418  
   419  // ---------------------------------------------------------------------------