github.com/bytedance/go-tagexpr@v2.7.5-0.20210114074101-de5b8743ad85+incompatible/binding/param_info.go (about)

     1  package binding
     2  
     3  import (
     4  	jsonpkg "encoding/json"
     5  	"errors"
     6  	"net/http"
     7  	"net/url"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/henrylee2cn/ameda"
    13  	"github.com/henrylee2cn/goutil"
    14  	"github.com/tidwall/gjson"
    15  
    16  	"github.com/bytedance/go-tagexpr"
    17  )
    18  
const (
	// specialChar is the ASCII BEL character. setDefaultVal uses it as a
	// temporary placeholder for escaped single quotes (\') while it rewrites
	// the remaining single quotes of a default value into double quotes.
	specialChar = "\x07"
)
    22  
// paramInfo describes one bindable struct field together with the tag
// information that controls how request data is bound to it.
type paramInfo struct {
	// fieldSelector is the selector passed to tagexpr.TagExpr.Field to locate
	// the field inside the struct being bound.
	fieldSelector string
	// structField is the reflected field; its Name is the fallback parameter
	// name and its Type drives default-value preprocessing.
	structField reflect.StructField
	// tagInfos holds the parsed tag entries of the field.
	tagInfos []*tagInfo
	// omitIns is not used in this file.
	// NOTE(review): presumably marks parameter sources to skip; confirm at the call sites.
	omitIns map[in]bool
	// bindErrFactory is not used in this file.
	// NOTE(review): presumably builds binding errors from a field name and message; confirm.
	bindErrFactory func(failField, msg string) error
	// looseZeroMode, when true, lets an empty string value bind as the zero
	// value of the target type instead of producing a type error.
	looseZeroMode bool
	// defaultVal is the preprocessed default value written by setDefaultVal;
	// bindDefaultVal decodes such bytes as JSON.
	defaultVal []byte
}
    32  
    33  func (p *paramInfo) name(_ in) string {
    34  	var name string
    35  	for _, info := range p.tagInfos {
    36  		if info.paramIn == json {
    37  			name = info.paramName
    38  			break
    39  		}
    40  	}
    41  	if name == "" {
    42  		return p.structField.Name
    43  	}
    44  	return name
    45  }
    46  
    47  func (p *paramInfo) getField(expr *tagexpr.TagExpr, initZero bool) (reflect.Value, error) {
    48  	fh, found := expr.Field(p.fieldSelector)
    49  	if found {
    50  		v := fh.Value(initZero)
    51  		if v.IsValid() {
    52  			return v, nil
    53  		}
    54  	}
    55  	return reflect.Value{}, nil
    56  }
    57  
    58  func (p *paramInfo) bindRawBody(info *tagInfo, expr *tagexpr.TagExpr, bodyBytes []byte) error {
    59  	if len(bodyBytes) == 0 {
    60  		if info.required {
    61  			return info.requiredError
    62  		}
    63  		return nil
    64  	}
    65  	v, err := p.getField(expr, true)
    66  	if err != nil || !v.IsValid() {
    67  		return err
    68  	}
    69  	v = goutil.DereferenceValue(v)
    70  	switch v.Kind() {
    71  	case reflect.Slice:
    72  		if v.Type().Elem().Kind() != reflect.Uint8 {
    73  			return info.typeError
    74  		}
    75  		v.Set(reflect.ValueOf(bodyBytes))
    76  		return nil
    77  	case reflect.String:
    78  		v.Set(reflect.ValueOf(goutil.BytesToString(bodyBytes)))
    79  		return nil
    80  	default:
    81  		return info.typeError
    82  	}
    83  }
    84  
    85  func (p *paramInfo) bindPath(info *tagInfo, expr *tagexpr.TagExpr, pathParams PathParams) (bool, error) {
    86  	if pathParams == nil {
    87  		return false, nil
    88  	}
    89  	r, found := pathParams.Get(info.paramName)
    90  	if !found {
    91  		if info.required {
    92  			return false, info.requiredError
    93  		}
    94  		return false, nil
    95  	}
    96  	return true, p.bindStringSlice(info, expr, []string{r})
    97  }
    98  
    99  func (p *paramInfo) bindQuery(info *tagInfo, expr *tagexpr.TagExpr, queryValues url.Values) (bool, error) {
   100  	return p.bindMapStrings(info, expr, queryValues)
   101  }
   102  
   103  func (p *paramInfo) bindHeader(info *tagInfo, expr *tagexpr.TagExpr, header http.Header) (bool, error) {
   104  	return p.bindMapStrings(info, expr, header)
   105  }
   106  
   107  func (p *paramInfo) bindCookie(info *tagInfo, expr *tagexpr.TagExpr, cookies []*http.Cookie) error {
   108  	var r []string
   109  	for _, c := range cookies {
   110  		if c.Name == info.paramName {
   111  			r = append(r, c.Value)
   112  		}
   113  	}
   114  	if len(r) == 0 {
   115  		if info.required {
   116  			return info.requiredError
   117  		}
   118  		return nil
   119  	}
   120  	return p.bindStringSlice(info, expr, r)
   121  }
   122  
   123  func (p *paramInfo) bindOrRequireBody(info *tagInfo, expr *tagexpr.TagExpr, bodyCodec codec, bodyString string, postForm map[string][]string) (bool, error) {
   124  	switch bodyCodec {
   125  	case bodyForm:
   126  		return p.bindMapStrings(info, expr, postForm)
   127  	case bodyJSON:
   128  		return p.checkRequireJSON(info, expr, bodyString, false)
   129  	case bodyProtobuf:
   130  		err := p.checkRequireProtobuf(info, expr, false)
   131  		return err == nil, err
   132  	default:
   133  		return false, info.contentTypeError
   134  	}
   135  }
   136  
   137  func (p *paramInfo) checkRequireProtobuf(info *tagInfo, expr *tagexpr.TagExpr, checkOpt bool) error {
   138  	if checkOpt && !info.required {
   139  		v, err := p.getField(expr, false)
   140  		if err != nil || !v.IsValid() {
   141  			return info.requiredError
   142  		}
   143  	}
   144  	return nil
   145  }
   146  
   147  func (p *paramInfo) checkRequireJSON(info *tagInfo, expr *tagexpr.TagExpr, bodyString string, checkOpt bool) (bool, error) {
   148  	var requiredError error
   149  	if checkOpt || info.required { // only return error if it's a required field
   150  		requiredError = info.requiredError
   151  	}
   152  
   153  	if !gjson.Get(bodyString, info.namePath).Exists() {
   154  		idx := strings.LastIndex(info.namePath, ".")
   155  		// There should be a superior but it is empty, no error is reported
   156  		if idx > 0 && !gjson.Get(bodyString, info.namePath[:idx]).Exists() {
   157  			return true, nil
   158  		}
   159  		return false, requiredError
   160  	}
   161  	v, err := p.getField(expr, false)
   162  	if err != nil || !v.IsValid() {
   163  		return false, requiredError
   164  	}
   165  	return true, nil
   166  }
   167  
   168  func (p *paramInfo) bindMapStrings(info *tagInfo, expr *tagexpr.TagExpr, values map[string][]string) (bool, error) {
   169  	r, ok := values[info.paramName]
   170  	if !ok || len(r) == 0 {
   171  		if info.required {
   172  			return false, info.requiredError
   173  		}
   174  		return false, nil
   175  	}
   176  	return true, p.bindStringSlice(info, expr, r)
   177  }
   178  
   179  // NOTE: len(a)>0
   180  func (p *paramInfo) bindStringSlice(info *tagInfo, expr *tagexpr.TagExpr, a []string) error {
   181  	v, err := p.getField(expr, true)
   182  	if err != nil || !v.IsValid() {
   183  		return err
   184  	}
   185  
   186  	v = goutil.DereferenceValue(v)
   187  	switch v.Kind() {
   188  	case reflect.String:
   189  		v.SetString(a[0])
   190  		return nil
   191  
   192  	case reflect.Bool:
   193  		var bol bool
   194  		bol, err = strconv.ParseBool(a[0])
   195  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   196  			v.SetBool(bol)
   197  			return nil
   198  		}
   199  	case reflect.Float32:
   200  		var f float64
   201  		f, err = strconv.ParseFloat(a[0], 32)
   202  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   203  			v.SetFloat(f)
   204  			return nil
   205  		}
   206  	case reflect.Float64:
   207  		var f float64
   208  		f, err = strconv.ParseFloat(a[0], 64)
   209  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   210  			v.SetFloat(f)
   211  			return nil
   212  		}
   213  	case reflect.Int64, reflect.Int:
   214  		var i int64
   215  		i, err = strconv.ParseInt(a[0], 10, 64)
   216  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   217  			v.SetInt(i)
   218  			return nil
   219  		}
   220  	case reflect.Int32:
   221  		var i int64
   222  		i, err = strconv.ParseInt(a[0], 10, 32)
   223  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   224  			v.SetInt(i)
   225  			return nil
   226  		}
   227  	case reflect.Int16:
   228  		var i int64
   229  		i, err = strconv.ParseInt(a[0], 10, 16)
   230  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   231  			v.SetInt(i)
   232  			return nil
   233  		}
   234  	case reflect.Int8:
   235  		var i int64
   236  		i, err = strconv.ParseInt(a[0], 10, 8)
   237  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   238  			v.SetInt(i)
   239  			return nil
   240  		}
   241  	case reflect.Uint64, reflect.Uint:
   242  		var u uint64
   243  		u, err = strconv.ParseUint(a[0], 10, 64)
   244  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   245  			v.SetUint(u)
   246  			return nil
   247  		}
   248  	case reflect.Uint32:
   249  		var u uint64
   250  		u, err = strconv.ParseUint(a[0], 10, 32)
   251  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   252  			v.SetUint(u)
   253  			return nil
   254  		}
   255  	case reflect.Uint16:
   256  		var u uint64
   257  		u, err = strconv.ParseUint(a[0], 10, 16)
   258  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   259  			v.SetUint(u)
   260  			return nil
   261  		}
   262  	case reflect.Uint8:
   263  		var u uint64
   264  		u, err = strconv.ParseUint(a[0], 10, 8)
   265  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   266  			v.SetUint(u)
   267  			return nil
   268  		}
   269  	case reflect.Slice:
   270  		vv, err := stringsToValue(v.Type().Elem(), a, p.looseZeroMode)
   271  		if err == nil {
   272  			v.Set(vv)
   273  			return nil
   274  		}
   275  		fallthrough
   276  	default:
   277  		fn := typeUnmarshalFuncs[v.Type()]
   278  		if fn != nil {
   279  			vv, err := fn(a[0], p.looseZeroMode)
   280  			if err == nil {
   281  				v.Set(vv)
   282  				return nil
   283  			}
   284  		}
   285  	}
   286  	return info.typeError
   287  }
   288  
   289  func (p *paramInfo) bindDefaultVal(expr *tagexpr.TagExpr, defaultValue []byte) (bool, error) {
   290  	if defaultValue == nil {
   291  		return false, nil
   292  	}
   293  	v, err := p.getField(expr, true)
   294  	if err != nil || !v.IsValid() {
   295  		return false, err
   296  	}
   297  	return true, jsonpkg.Unmarshal(defaultValue, v.Addr().Interface())
   298  }
   299  
   300  // setDefaultVal preprocess the default tags and store the parsed value
   301  func (p *paramInfo) setDefaultVal() error {
   302  	for _, info := range p.tagInfos {
   303  		if info.paramIn != default_val {
   304  			continue
   305  		}
   306  
   307  		defaultVal := info.paramName
   308  		st := ameda.DereferenceType(p.structField.Type)
   309  		switch st.Kind() {
   310  		case reflect.String:
   311  			p.defaultVal, _ = jsonpkg.Marshal(defaultVal)
   312  			continue
   313  		case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct:
   314  			// escape single quote and double quote, replace single quote with double quote
   315  			defaultVal = strings.Replace(defaultVal, `"`, `\"`, -1)
   316  			defaultVal = strings.Replace(defaultVal, `\'`, specialChar, -1)
   317  			defaultVal = strings.Replace(defaultVal, `'`, `"`, -1)
   318  			defaultVal = strings.Replace(defaultVal, specialChar, `'`, -1)
   319  		}
   320  		p.defaultVal = ameda.UnsafeStringToBytes(defaultVal)
   321  	}
   322  	return nil
   323  }
   324  
// errMismatch is returned by stringsToValue when the element type is not
// supported or when any of the input strings cannot be converted to it.
var errMismatch = errors.New("type mismatch")
   326  
   327  func stringsToValue(t reflect.Type, a []string, emptyAsZero bool) (reflect.Value, error) {
   328  	var i interface{}
   329  	var err error
   330  	var ptrDepth int
   331  	elemKind := t.Kind()
   332  	for elemKind == reflect.Ptr {
   333  		t = t.Elem()
   334  		elemKind = t.Kind()
   335  		ptrDepth++
   336  	}
   337  	switch elemKind {
   338  	case reflect.String:
   339  		i = a
   340  	case reflect.Bool:
   341  		i, err = goutil.StringsToBools(a, emptyAsZero)
   342  	case reflect.Float32:
   343  		i, err = goutil.StringsToFloat32s(a, emptyAsZero)
   344  	case reflect.Float64:
   345  		i, err = goutil.StringsToFloat64s(a, emptyAsZero)
   346  	case reflect.Int:
   347  		i, err = goutil.StringsToInts(a, emptyAsZero)
   348  	case reflect.Int64:
   349  		i, err = goutil.StringsToInt64s(a, emptyAsZero)
   350  	case reflect.Int32:
   351  		i, err = goutil.StringsToInt32s(a, emptyAsZero)
   352  	case reflect.Int16:
   353  		i, err = goutil.StringsToInt16s(a, emptyAsZero)
   354  	case reflect.Int8:
   355  		i, err = goutil.StringsToInt8s(a, emptyAsZero)
   356  	case reflect.Uint:
   357  		i, err = goutil.StringsToUints(a, emptyAsZero)
   358  	case reflect.Uint64:
   359  		i, err = goutil.StringsToUint64s(a, emptyAsZero)
   360  	case reflect.Uint32:
   361  		i, err = goutil.StringsToUint32s(a, emptyAsZero)
   362  	case reflect.Uint16:
   363  		i, err = goutil.StringsToUint16s(a, emptyAsZero)
   364  	case reflect.Uint8:
   365  		i, err = goutil.StringsToUint8s(a, emptyAsZero)
   366  	default:
   367  		fn := typeUnmarshalFuncs[t]
   368  		if fn == nil {
   369  			return reflect.Value{}, errMismatch
   370  		}
   371  		v := reflect.New(reflect.SliceOf(t)).Elem()
   372  		for _, s := range a {
   373  			vv, err := fn(s, emptyAsZero)
   374  			if err != nil {
   375  				return reflect.Value{}, errMismatch
   376  			}
   377  			v = reflect.Append(v, vv)
   378  		}
   379  		return goutil.ReferenceSlice(v, ptrDepth), nil
   380  	}
   381  	if err != nil {
   382  		return reflect.Value{}, errMismatch
   383  	}
   384  	return goutil.ReferenceSlice(reflect.ValueOf(i), ptrDepth), nil
   385  }