github.com/bytedance/go-tagexpr/v2@v2.9.8/binding/param_info.go (about)

     1  package binding
     2  
     3  import (
     4  	jsonpkg "encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"mime/multipart"
     8  	"net/http"
     9  	"net/url"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/andeya/ameda"
    15  	"github.com/bytedance/go-tagexpr/v2"
    16  	gjson "github.com/bytedance/go-tagexpr/v2/binding/tidwall_gjson"
    17  )
    18  
    19  const (
    20  	specialChar = "\x07"
    21  )
    22  
// paramInfo describes how a single struct field is bound from a request.
type paramInfo struct {
	// fieldSelector locates the field inside a *tagexpr.TagExpr (see getField).
	fieldSelector string
	// structField is the reflected field; its Name is the fallback parameter
	// name and its (dereferenced) Type drives default-value preprocessing.
	structField reflect.StructField
	// tagInfos holds one entry per parameter source declared for the field;
	// paramIn selects the source (e.g. json, default_val).
	tagInfos []*tagInfo
	// omitIns — not used in this part of the file; presumably the set of
	// sources to skip for this field. TODO confirm against callers.
	omitIns map[in]bool
	// bindErrFactory — not used in this part of the file; presumably builds
	// binding errors from a field name and message. TODO confirm.
	bindErrFactory func(failField, msg string) error
	// looseZeroMode lets an empty string bind the zero value of scalar
	// fields instead of producing a type error.
	looseZeroMode bool
	// defaultVal holds the JSON-encoded default value prepared by
	// setDefaultVal; nil when the field has no default tag.
	defaultVal []byte
}
    32  
    33  func (p *paramInfo) name(_ in) string {
    34  	var name string
    35  	for _, info := range p.tagInfos {
    36  		if info.paramIn == json {
    37  			name = info.paramName
    38  			break
    39  		}
    40  	}
    41  	if name == "" {
    42  		return p.structField.Name
    43  	}
    44  	return name
    45  }
    46  
    47  func (p *paramInfo) getField(expr *tagexpr.TagExpr, initZero bool) (reflect.Value, error) {
    48  	fh, found := expr.Field(p.fieldSelector)
    49  	if found {
    50  		v := fh.Value(initZero)
    51  		if v.IsValid() {
    52  			return v, nil
    53  		}
    54  	}
    55  	return reflect.Value{}, nil
    56  }
    57  
    58  func (p *paramInfo) bindRawBody(info *tagInfo, expr *tagexpr.TagExpr, bodyBytes []byte) error {
    59  	if len(bodyBytes) == 0 {
    60  		if info.required {
    61  			return info.requiredError
    62  		}
    63  		return nil
    64  	}
    65  	v, err := p.getField(expr, true)
    66  	if err != nil || !v.IsValid() {
    67  		return err
    68  	}
    69  	v = ameda.DereferenceValue(v)
    70  	switch v.Kind() {
    71  	case reflect.Slice:
    72  		if v.Type().Elem().Kind() != reflect.Uint8 {
    73  			return info.typeError
    74  		}
    75  		v.Set(reflect.ValueOf(bodyBytes))
    76  		return nil
    77  	case reflect.String:
    78  		v.Set(reflect.ValueOf(ameda.UnsafeBytesToString(bodyBytes)))
    79  		return nil
    80  	default:
    81  		return info.typeError
    82  	}
    83  }
    84  
    85  func (p *paramInfo) bindPath(info *tagInfo, expr *tagexpr.TagExpr, pathParams PathParams) (bool, error) {
    86  	if pathParams == nil {
    87  		return false, nil
    88  	}
    89  	r, found := pathParams.Get(info.paramName)
    90  	if !found {
    91  		if info.required {
    92  			return false, info.requiredError
    93  		}
    94  		return false, nil
    95  	}
    96  	return true, p.bindStringSlice(info, expr, []string{r})
    97  }
    98  
    99  func (p *paramInfo) bindQuery(info *tagInfo, expr *tagexpr.TagExpr, queryValues url.Values) (bool, error) {
   100  	return p.bindMapStrings(info, expr, queryValues)
   101  }
   102  
   103  func (p *paramInfo) bindHeader(info *tagInfo, expr *tagexpr.TagExpr, header http.Header) (bool, error) {
   104  	return p.bindMapStrings(info, expr, header)
   105  }
   106  
   107  func (p *paramInfo) bindCookie(info *tagInfo, expr *tagexpr.TagExpr, cookies []*http.Cookie) (bool, error) {
   108  	var r []string
   109  	for _, c := range cookies {
   110  		if c.Name == info.paramName {
   111  			r = append(r, c.Value)
   112  		}
   113  	}
   114  	if len(r) == 0 {
   115  		if info.required {
   116  			return false, info.requiredError
   117  		}
   118  		return false, nil
   119  	}
   120  	return true, p.bindStringSlice(info, expr, r)
   121  }
   122  
   123  func (p *paramInfo) bindOrRequireBody(
   124  	info *tagInfo, expr *tagexpr.TagExpr, bodyCodec codec, bodyString string,
   125  	postForm map[string][]string, fileHeaders map[string][]*multipart.FileHeader, hasDefaultVal bool) (bool, error) {
   126  	switch bodyCodec {
   127  	case bodyForm:
   128  		found, err := p.bindMapStrings(info, expr, postForm)
   129  		if !found {
   130  			return p.bindFileHeaders(info, expr, fileHeaders)
   131  		}
   132  		return found, err
   133  	case bodyJSON:
   134  		return p.checkRequireJSON(info, expr, bodyString, hasDefaultVal)
   135  	case bodyProtobuf:
   136  		// It has been checked when binding, no need to check now
   137  		return true, nil
   138  		// err := p.checkRequireProtobuf(info, expr, false)
   139  		// return err == nil, err
   140  	default:
   141  		return false, info.contentTypeError
   142  	}
   143  }
   144  
   145  func (p *paramInfo) checkRequireProtobuf(info *tagInfo, expr *tagexpr.TagExpr, checkOpt bool) error {
   146  	if checkOpt && !info.required {
   147  		v, err := p.getField(expr, false)
   148  		if err != nil || !v.IsValid() {
   149  			return info.requiredError
   150  		}
   151  	}
   152  	return nil
   153  }
   154  
   155  func (p *paramInfo) checkRequireJSON(info *tagInfo, expr *tagexpr.TagExpr, bodyString string, hasDefaultVal bool) (bool, error) {
   156  	var requiredError error
   157  	if info.required { // only return error if it's a required field
   158  		requiredError = info.requiredError
   159  	} else if !hasDefaultVal {
   160  		return true, nil
   161  	}
   162  	if !gjson.Get(bodyString, info.namePath).Exists() {
   163  		idx := strings.LastIndex(info.namePath, ".")
   164  		// There should be a superior but it is empty, no error is reported
   165  		if idx > 0 && !gjson.Get(bodyString, info.namePath[:idx]).Exists() {
   166  			return true, nil
   167  		}
   168  		return false, requiredError
   169  	}
   170  	v, err := p.getField(expr, false)
   171  	if err != nil || !v.IsValid() {
   172  		return false, requiredError
   173  	}
   174  	return true, nil
   175  }
   176  
   177  var fileHeaderType = reflect.TypeOf(multipart.FileHeader{})
   178  
   179  func (p *paramInfo) bindFileHeaders(info *tagInfo, expr *tagexpr.TagExpr, fileHeaders map[string][]*multipart.FileHeader) (bool, error) {
   180  	r, ok := fileHeaders[info.paramName]
   181  	if !ok || len(r) == 0 {
   182  		if info.required {
   183  			return false, info.requiredError
   184  		}
   185  		return false, nil
   186  	}
   187  	v, err := p.getField(expr, true)
   188  	if err != nil || !v.IsValid() {
   189  		return true, err
   190  	}
   191  	v = ameda.DereferenceValue(v)
   192  	var elemType reflect.Type
   193  	isSlice := v.Kind() == reflect.Slice
   194  	if isSlice {
   195  		elemType = v.Type().Elem()
   196  	} else {
   197  		elemType = v.Type()
   198  	}
   199  	var ptrDepth int
   200  	for elemType.Kind() == reflect.Ptr {
   201  		elemType = elemType.Elem()
   202  		ptrDepth++
   203  	}
   204  	if elemType != fileHeaderType {
   205  		return true, errors.New("parameter type is not (*)multipart.FileHeader struct or slice")
   206  	}
   207  	if len(r) == 0 || r[0] == nil {
   208  		return true, nil
   209  	}
   210  	if !isSlice {
   211  		v.Set(reflect.ValueOf(*r[0]))
   212  		return true, nil
   213  	}
   214  	for _, fileHeader := range r {
   215  		v.Set(reflect.Append(v, ameda.ReferenceValue(reflect.ValueOf(fileHeader), ptrDepth-1)))
   216  	}
   217  	return true, nil
   218  }
   219  
   220  func (p *paramInfo) bindMapStrings(info *tagInfo, expr *tagexpr.TagExpr, values map[string][]string) (bool, error) {
   221  	r, ok := values[info.paramName]
   222  	if !ok || len(r) == 0 {
   223  		if info.required {
   224  			return false, info.requiredError
   225  		}
   226  		return false, nil
   227  	}
   228  	return true, p.bindStringSlice(info, expr, r)
   229  }
   230  
   231  // NOTE: len(a)>0
   232  func (p *paramInfo) bindStringSlice(info *tagInfo, expr *tagexpr.TagExpr, a []string) error {
   233  	v, err := p.getField(expr, true)
   234  	if err != nil || !v.IsValid() {
   235  		return err
   236  	}
   237  
   238  	v = ameda.DereferenceValue(v)
   239  
   240  	// we have customized unmarshal defined, we should use it firstly
   241  	if fn, exist := typeUnmarshalFuncs[v.Type()]; exist {
   242  		vv, err := fn(a[0], p.looseZeroMode)
   243  		if err == nil {
   244  			v.Set(vv)
   245  			return nil
   246  		}
   247  		return info.typeError
   248  	}
   249  
   250  	switch v.Kind() {
   251  	case reflect.String:
   252  		v.SetString(a[0])
   253  		return nil
   254  
   255  	case reflect.Bool:
   256  		var bol bool
   257  		bol, err = strconv.ParseBool(a[0])
   258  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   259  			v.SetBool(bol)
   260  			return nil
   261  		}
   262  	case reflect.Float32:
   263  		var f float64
   264  		f, err = strconv.ParseFloat(a[0], 32)
   265  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   266  			v.SetFloat(f)
   267  			return nil
   268  		}
   269  	case reflect.Float64:
   270  		var f float64
   271  		f, err = strconv.ParseFloat(a[0], 64)
   272  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   273  			v.SetFloat(f)
   274  			return nil
   275  		}
   276  	case reflect.Int64, reflect.Int:
   277  		var i int64
   278  		i, err = strconv.ParseInt(a[0], 10, 64)
   279  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   280  			v.SetInt(i)
   281  			return nil
   282  		}
   283  	case reflect.Int32:
   284  		var i int64
   285  		i, err = strconv.ParseInt(a[0], 10, 32)
   286  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   287  			v.SetInt(i)
   288  			return nil
   289  		}
   290  	case reflect.Int16:
   291  		var i int64
   292  		i, err = strconv.ParseInt(a[0], 10, 16)
   293  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   294  			v.SetInt(i)
   295  			return nil
   296  		}
   297  	case reflect.Int8:
   298  		var i int64
   299  		i, err = strconv.ParseInt(a[0], 10, 8)
   300  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   301  			v.SetInt(i)
   302  			return nil
   303  		}
   304  	case reflect.Uint64, reflect.Uint:
   305  		var u uint64
   306  		u, err = strconv.ParseUint(a[0], 10, 64)
   307  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   308  			v.SetUint(u)
   309  			return nil
   310  		}
   311  	case reflect.Uint32:
   312  		var u uint64
   313  		u, err = strconv.ParseUint(a[0], 10, 32)
   314  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   315  			v.SetUint(u)
   316  			return nil
   317  		}
   318  	case reflect.Uint16:
   319  		var u uint64
   320  		u, err = strconv.ParseUint(a[0], 10, 16)
   321  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   322  			v.SetUint(u)
   323  			return nil
   324  		}
   325  	case reflect.Uint8:
   326  		var u uint64
   327  		u, err = strconv.ParseUint(a[0], 10, 8)
   328  		if err == nil || (a[0] == "" && p.looseZeroMode) {
   329  			v.SetUint(u)
   330  			return nil
   331  		}
   332  	case reflect.Slice:
   333  		var ptrDepth int
   334  		t := v.Type().Elem()
   335  		elemKind := t.Kind()
   336  		for elemKind == reflect.Ptr {
   337  			t = t.Elem()
   338  			elemKind = t.Kind()
   339  			ptrDepth++
   340  		}
   341  		val := reflect.New(v.Type()).Elem()
   342  		for _, s := range a {
   343  			var vv reflect.Value
   344  			vv, err = stringToValue(t, s, p.looseZeroMode)
   345  			if err != nil {
   346  				break
   347  			}
   348  			val = reflect.Append(val, ameda.ReferenceValue(vv, ptrDepth))
   349  		}
   350  		if err == nil {
   351  			v.Set(val)
   352  			return nil
   353  		}
   354  		fallthrough
   355  	default:
   356  		// no customized unmarshal defined
   357  		err = unmarshal(ameda.UnsafeStringToBytes(a[0]), v.Addr().Interface())
   358  		if err == nil {
   359  			return nil
   360  		}
   361  	}
   362  	return info.typeError
   363  }
   364  
   365  func (p *paramInfo) bindDefaultVal(expr *tagexpr.TagExpr, defaultValue []byte) (bool, error) {
   366  	if defaultValue == nil {
   367  		return false, nil
   368  	}
   369  	v, err := p.getField(expr, true)
   370  	if err != nil || !v.IsValid() {
   371  		return false, err
   372  	}
   373  	return true, jsonpkg.Unmarshal(defaultValue, v.Addr().Interface())
   374  }
   375  
   376  // setDefaultVal preprocess the default tags and store the parsed value
   377  func (p *paramInfo) setDefaultVal() error {
   378  	for _, info := range p.tagInfos {
   379  		if info.paramIn != default_val {
   380  			continue
   381  		}
   382  
   383  		defaultVal := info.paramName
   384  		st := ameda.DereferenceType(p.structField.Type)
   385  		switch st.Kind() {
   386  		case reflect.String:
   387  			p.defaultVal, _ = jsonpkg.Marshal(defaultVal)
   388  			continue
   389  		case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct:
   390  			// escape single quote and double quote, replace single quote with double quote
   391  			defaultVal = strings.Replace(defaultVal, `"`, `\"`, -1)
   392  			defaultVal = strings.Replace(defaultVal, `\'`, specialChar, -1)
   393  			defaultVal = strings.Replace(defaultVal, `'`, `"`, -1)
   394  			defaultVal = strings.Replace(defaultVal, specialChar, `'`, -1)
   395  		}
   396  		p.defaultVal = ameda.UnsafeStringToBytes(defaultVal)
   397  	}
   398  	return nil
   399  }
   400  
   401  func stringToValue(elemType reflect.Type, s string, emptyAsZero bool) (v reflect.Value, err error) {
   402  	v = reflect.New(elemType).Elem()
   403  
   404  	// we have customized unmarshal defined, we should use it firstly
   405  	if fn, exist := typeUnmarshalFuncs[elemType]; exist {
   406  		vv, err := fn(s, emptyAsZero)
   407  		if err == nil {
   408  			v.Set(vv)
   409  		}
   410  		return v, err
   411  	}
   412  
   413  	switch elemType.Kind() {
   414  	case reflect.String:
   415  		v.SetString(s)
   416  	case reflect.Bool:
   417  		var i bool
   418  		i, err = ameda.StringToBool(s, emptyAsZero)
   419  		if err == nil {
   420  			v.SetBool(i)
   421  		}
   422  	case reflect.Float32, reflect.Float64:
   423  		var i float64
   424  		i, err = ameda.StringToFloat64(s, emptyAsZero)
   425  		if err == nil {
   426  			v.SetFloat(i)
   427  		}
   428  	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
   429  		var i int64
   430  		i, err = ameda.StringToInt64(s, emptyAsZero)
   431  		if err == nil {
   432  			v.SetInt(i)
   433  		}
   434  	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
   435  		var i uint64
   436  		i, err = ameda.StringToUint64(s, emptyAsZero)
   437  		if err == nil {
   438  			v.SetUint(i)
   439  		}
   440  	default:
   441  		// no customized unmarshal defined
   442  		err = unmarshal(ameda.UnsafeStringToBytes(s), v.Addr().Interface())
   443  	}
   444  	if err != nil {
   445  		return reflect.Value{}, fmt.Errorf("type mismatch, error=%v", err)
   446  	}
   447  	return v, nil
   448  }