github.com/bytedance/go-tagexpr/v2@v2.9.8/binding/bind.go (about)

     1  package binding
     2  
     3  import (
     4  	jsonpkg "encoding/json"
     5  	"mime/multipart"
     6  	"net/http"
     7  	"reflect"
     8  	"strings"
     9  	"sync"
    10  
    11  	"github.com/andeya/ameda"
    12  	"github.com/andeya/goutil"
    13  
    14  	"github.com/bytedance/go-tagexpr/v2"
    15  	"github.com/bytedance/go-tagexpr/v2/validator"
    16  )
    17  
// Binding is a binding and verification tool for HTTP requests.
//
// Use New to create one: the zero value is not usable because vd and the
// recvs cache are only initialized by New.
type Binding struct {
	// vd validates bound values; built in New from config.Validator.
	vd                *validator.Validator
	// recvs caches the prepared receiver of each bound struct type, keyed by
	// the runtime type ID computed in getOrPrepareReceiver.
	recvs             map[uintptr]*receiver
	// lock guards recvs (RLock for lookups, Lock for inserts in getOrPrepareReceiver).
	lock              sync.RWMutex
	// bindErrFactory builds binding errors; installed by SetErrorFactory.
	bindErrFactory    func(failField, msg string) error
	// config is New's copy of the caller's Config, after init().
	config            Config
	// jsonUnmarshalFunc optionally overrides encoding/json.Unmarshal in bindJSON;
	// set by ResetJSONUnmarshaler, nil means the default decoder.
	jsonUnmarshalFunc func(data []byte, v interface{}) error
}
    27  
    28  // New creates a binding tool.
    29  // NOTE:
    30  //
    31  //	Use default tag name for config fields that are empty
    32  func New(config *Config) *Binding {
    33  	if config == nil {
    34  		config = new(Config)
    35  	}
    36  	b := &Binding{
    37  		recvs:  make(map[uintptr]*receiver, 1024),
    38  		config: *config,
    39  	}
    40  	b.config.init()
    41  	b.vd = validator.New(b.config.Validator)
    42  	return b.SetErrorFactory(nil, nil)
    43  }
    44  
    45  // SetLooseZeroMode if set to true,
    46  // the empty string request parameter is bound to the zero value of parameter.
    47  // NOTE:
    48  //
    49  //	The default is false;
    50  //	Suitable for these parameter types: query/header/cookie/form .
    51  func (b *Binding) SetLooseZeroMode(enable bool) *Binding {
    52  	b.config.LooseZeroMode = enable
    53  	for k := range b.recvs {
    54  		delete(b.recvs, k)
    55  	}
    56  	return b
    57  }
    58  
    59  var defaultValidatingErrFactory = newDefaultErrorFactory("validating")
    60  var defaultBindErrFactory = newDefaultErrorFactory("binding")
    61  
    62  // SetErrorFactory customizes the factory of validation error.
    63  // NOTE:
    64  //
    65  //	If errFactory==nil, the default is used
    66  func (b *Binding) SetErrorFactory(bindErrFactory, validatingErrFactory func(failField, msg string) error) *Binding {
    67  	if bindErrFactory == nil {
    68  		bindErrFactory = defaultBindErrFactory
    69  	}
    70  	if validatingErrFactory == nil {
    71  		validatingErrFactory = defaultValidatingErrFactory
    72  	}
    73  	b.bindErrFactory = bindErrFactory
    74  	b.vd.SetErrorFactory(validatingErrFactory)
    75  	return b
    76  }
    77  
    78  // BindAndValidate binds the request parameters and validates them if needed.
    79  func (b *Binding) BindAndValidate(recvPointer interface{}, req *http.Request, pathParams PathParams) error {
    80  	return b.IBindAndValidate(recvPointer, wrapRequest(req), pathParams)
    81  }
    82  
    83  // Bind binds the request parameters.
    84  func (b *Binding) Bind(recvPointer interface{}, req *http.Request, pathParams PathParams) error {
    85  	return b.IBind(recvPointer, wrapRequest(req), pathParams)
    86  }
    87  
    88  // IBindAndValidate binds the request parameters and validates them if needed.
    89  func (b *Binding) IBindAndValidate(recvPointer interface{}, req Request, pathParams PathParams) error {
    90  	v, hasVd, err := b.bind(recvPointer, req, pathParams)
    91  	if err != nil {
    92  		return err
    93  	}
    94  	if hasVd {
    95  		return b.vd.Validate(v)
    96  	}
    97  	return nil
    98  }
    99  
   100  // IBind binds the request parameters.
   101  func (b *Binding) IBind(recvPointer interface{}, req Request, pathParams PathParams) error {
   102  	_, _, err := b.bind(recvPointer, req, pathParams)
   103  	return err
   104  }
   105  
   106  // Validate validates whether the fields of value is valid.
   107  func (b *Binding) Validate(value interface{}) error {
   108  	return b.vd.Validate(value)
   109  }
   110  
   111  func (b *Binding) bind(pointer interface{}, req Request, pathParams PathParams) (elemValue reflect.Value, hasVd bool, err error) {
   112  	elemValue, err = b.receiverValueOf(pointer)
   113  	if err != nil {
   114  		return
   115  	}
   116  	if elemValue.Kind() == reflect.Struct {
   117  		hasVd, err = b.bindStruct(pointer, elemValue, req, pathParams)
   118  	} else {
   119  		hasVd, err = b.bindNonstruct(pointer, elemValue, req, pathParams)
   120  	}
   121  	return
   122  }
   123  
   124  func (b *Binding) bindNonstruct(pointer interface{}, _ reflect.Value, req Request, _ PathParams) (hasVd bool, err error) {
   125  	bodyCodec := getBodyCodec(req)
   126  	switch bodyCodec {
   127  	case bodyJSON:
   128  		hasVd = true
   129  		bodyBytes, err := req.GetBody()
   130  		if err != nil {
   131  			return hasVd, err
   132  		}
   133  		err = b.bindJSON(pointer, bodyBytes)
   134  	case bodyProtobuf:
   135  		hasVd = true
   136  		bodyBytes, err := req.GetBody()
   137  		if err != nil {
   138  			return hasVd, err
   139  		}
   140  		err = bindProtobuf(pointer, bodyBytes)
   141  	case bodyForm:
   142  		postForm, err := req.GetPostForm()
   143  		if err != nil {
   144  			return false, err
   145  		}
   146  		b, _ := jsonpkg.Marshal(postForm)
   147  		err = jsonpkg.Unmarshal(b, pointer)
   148  	default:
   149  		// query and form
   150  		form, err := req.GetForm()
   151  		if err != nil {
   152  			return false, err
   153  		}
   154  		b, _ := jsonpkg.Marshal(form)
   155  		err = jsonpkg.Unmarshal(b, pointer)
   156  	}
   157  	return
   158  }
   159  
   160  func (b *Binding) bindStruct(structPointer interface{}, structValue reflect.Value, req Request, pathParams PathParams) (hasVd bool, err error) {
   161  	recv, err := b.getOrPrepareReceiver(structValue)
   162  	if err != nil {
   163  		return
   164  	}
   165  
   166  	expr, err := b.vd.VM().Run(structValue)
   167  	if err != nil {
   168  		return
   169  	}
   170  
   171  	bodyCodec, bodyBytes, err := recv.getBodyInfo(req)
   172  	if len(bodyBytes) > 0 {
   173  		err = b.prebindBody(structPointer, structValue, bodyCodec, bodyBytes)
   174  	}
   175  	if err != nil {
   176  		return
   177  	}
   178  	bodyString := ameda.UnsafeBytesToString(bodyBytes)
   179  	postForm, err := req.GetPostForm()
   180  	if err != nil {
   181  		return
   182  	}
   183  	var fileHeaders map[string][]*multipart.FileHeader
   184  	if _req, ok := req.(requestWithFileHeader); ok {
   185  		fileHeaders, err = _req.GetFileHeaders()
   186  		if err != nil {
   187  			return
   188  		}
   189  	}
   190  	queryValues := recv.getQuery(req)
   191  	cookies := recv.getCookies(req)
   192  
   193  	for _, param := range recv.params {
   194  		for i, info := range param.tagInfos {
   195  			var found bool
   196  			switch info.paramIn {
   197  			case raw_body:
   198  				err = param.bindRawBody(info, expr, bodyBytes)
   199  				found = err == nil
   200  			case path:
   201  				found, err = param.bindPath(info, expr, pathParams)
   202  			case query:
   203  				found, err = param.bindQuery(info, expr, queryValues)
   204  			case cookie:
   205  				found, err = param.bindCookie(info, expr, cookies)
   206  			case header:
   207  				found, err = param.bindHeader(info, expr, req.GetHeader())
   208  			case form, json, protobuf:
   209  				if info.paramIn == in(bodyCodec) {
   210  					found, err = param.bindOrRequireBody(info, expr, bodyCodec, bodyString, postForm, fileHeaders,
   211  						recv.hasDefaultVal)
   212  				} else if info.required {
   213  					found = false
   214  					err = info.requiredError
   215  				}
   216  			case default_val:
   217  				found, err = param.bindDefaultVal(expr, param.defaultVal)
   218  			}
   219  			if found && err == nil {
   220  				break
   221  			}
   222  			if (found || i == len(param.tagInfos)-1) && err != nil {
   223  				return recv.hasVd, err
   224  			}
   225  		}
   226  	}
   227  	return recv.hasVd, nil
   228  }
   229  
   230  func (b *Binding) receiverValueOf(receiver interface{}) (reflect.Value, error) {
   231  	v := reflect.ValueOf(receiver)
   232  	if v.Kind() == reflect.Ptr {
   233  		v = ameda.DereferencePtrValue(v)
   234  		if v.IsValid() && v.CanAddr() {
   235  			return v, nil
   236  		}
   237  	}
   238  	return v, b.bindErrFactory("", "receiver must be a non-nil pointer")
   239  }
   240  
// getOrPrepareReceiver returns the cached receiver for value's type, building
// and caching it on first use.
//
// Building walks every field with the tag-expression VM and records, per
// field, the ordered list of sources (tagInfos) it may be bound from. It also
// records whether the type uses validator tags (recv.hasVd).
//
// The cache lookup and the insert take b.lock separately. Two goroutines may
// therefore build the same receiver concurrently; the last insert wins.
func (b *Binding) getOrPrepareReceiver(value reflect.Value) (*receiver, error) {
	runtimeTypeID := ameda.ValueFrom(value).RuntimeTypeID()
	b.lock.RLock()
	recv, ok := b.recvs[runtimeTypeID]
	b.lock.RUnlock()
	if ok {
		return recv, nil
	}
	t := value.Type()
	// Run the tag expressions on a fresh zero value of the type: only field
	// metadata is needed here, not the caller's data.
	expr, err := b.vd.VM().Run(reflect.New(t).Elem())
	if err != nil {
		return nil, err
	}
	recv = &receiver{
		params:        make([]*paramInfo, 0, 16),
		looseZeroMode: b.config.LooseZeroMode,
	}
	var errExprSelector tagexpr.ExprSelector
	var errMsg string
	// Selector prefixes (selector + FieldSeparator) of fields that declared
	// explicit binding tags; fields nested under them get no default sources.
	var fieldsWithValidTag = make(map[string]bool)
	expr.RangeFields(func(fh *tagexpr.FieldHandler) bool {
		if !fh.Value(true).CanSet() {
			// Record the problem and keep walking; only the last unsettable
			// field's message survives and is reported after the walk.
			selector := fh.StringSelector()
			errMsg = "field cannot be set: " + selector
			errExprSelector = tagexpr.ExprSelector(selector)
			return true
		}

		tagKVs := b.config.parse(fh.StructField())
		p := recv.getOrAddParam(fh, b.bindErrFactory)
		// One slot per source kind, indexed by `in`; a later tag for the same
		// source overwrites an earlier one.
		tagInfos := [maxIn]*tagInfo{}
	L:
		for _, tagKV := range tagKVs {
			paramIn := undefined
			switch tagKV.name {
			case b.config.Validator:
				recv.hasVd = true
				continue L
			case b.config.PathParam:
				paramIn = path
			case b.config.FormBody:
				paramIn = form
			case b.config.Query:
				paramIn = query
			case b.config.Cookie:
				paramIn = cookie
			case b.config.Header:
				paramIn = header
			case b.config.protobufBody:
				paramIn = protobuf
			case b.config.jsonBody:
				paramIn = json
			case b.config.RawBody:
				paramIn = raw_body
			case b.config.defaultVal:
				paramIn = default_val
			default:
				// Not a binding-related tag.
				continue L
			}
			if paramIn == default_val {
				// The tag value is kept as-is in paramName for the default source.
				tagInfos[paramIn] = &tagInfo{paramIn: default_val, paramName: tagKV.value}
			} else {
				tagInfos[paramIn] = tagKV.toInfo(paramIn == header)
			}
		}

		// Flatten the slots in `in` order. A name of "-" opts the field out of
		// that source (except for default_val).
		for i, info := range tagInfos {
			if info != nil {
				if info.paramIn != default_val && info.paramName == "-" {
					p.omitIns[in(i)] = true
					recv.assginIn(in(i), false)
				} else {
					info.paramIn = in(i)
					p.tagInfos = append(p.tagInfos, info)
					recv.assginIn(in(i), true)
				}
			}
		}
		fs := string(fh.FieldSelector())
		switch len(p.tagInfos) {
		case 0:
			// No explicit sources: fall back to sortedDefaultIn (named after the
			// struct field) unless an ancestor field declared tags or the field
			// is unexported.
			var canDefault = true
			for s := range fieldsWithValidTag {
				if strings.HasPrefix(fs, s) {
					canDefault = false
					break
				}
			}
			if canDefault {
				if !goutil.IsExportedName(p.structField.Name) {
					canDefault = false
				}
			}
			// Supports the default binding order when there is no valid tag in the superior field of the exportable field
			if canDefault {
				for _, i := range sortedDefaultIn {
					if p.omitIns[i] {
						recv.assginIn(i, false)
						continue
					}
					p.tagInfos = append(p.tagInfos, &tagInfo{
						paramIn:   i,
						paramName: p.structField.Name,
					})
					recv.assginIn(i, true)
				}
			}
		case 1:
			// Only a default-value tag: try the default sources first and keep
			// the default value as the final fallback.
			if p.tagInfos[0].paramIn == default_val {
				last := p.tagInfos[0]
				p.tagInfos = make([]*tagInfo, 0, len(sortedDefaultIn)+1)
				for _, i := range sortedDefaultIn {
					if p.omitIns[i] {
						recv.assginIn(i, false)
						continue
					}
					p.tagInfos = append(p.tagInfos, &tagInfo{
						paramIn:   i,
						paramName: p.structField.Name,
					})
					recv.assginIn(i, true)
				}
				p.tagInfos = append(p.tagInfos, last)
			}
			fallthrough
		default:
			// The field declared tags: suppress default sources for its children.
			fieldsWithValidTag[fs+tagexpr.FieldSeparator] = true
		}
		// NOTE(review): appears redundant with the Validator case in the tag
		// switch above; kept unchanged.
		if !recv.hasVd {
			_, recv.hasVd = tagKVs.lookup(b.config.Validator)
		}
		return true
	})

	if errMsg != "" {
		return nil, b.bindErrFactory(errExprSelector.String(), errMsg)
	}
	// No validator tag found on the walked fields: search element types of
	// slices/arrays/maps as well (depth-limited to 20 levels).
	if !recv.hasVd {
		recv.hasVd, _ = b.findVdTag(ameda.DereferenceType(t), false, 20, map[reflect.Type]bool{})
	}
	recv.initParams()

	b.lock.Lock()
	b.recvs[runtimeTypeID] = recv
	b.lock.Unlock()

	return recv, nil
}
   389  
   390  func (b *Binding) findVdTag(t reflect.Type, inMapOrSlice bool, depth int, exist map[reflect.Type]bool) (hasVd bool, err error) {
   391  	if depth <= 0 || exist[t] {
   392  		return
   393  	}
   394  	depth--
   395  	switch t.Kind() {
   396  	case reflect.Struct:
   397  		exist[t] = true
   398  		for i := t.NumField() - 1; i >= 0; i-- {
   399  			field := t.Field(i)
   400  			if inMapOrSlice {
   401  				tagKVs := b.config.parse(field)
   402  				for _, tagKV := range tagKVs {
   403  					if tagKV.name == b.config.Validator {
   404  						return true, nil
   405  					}
   406  				}
   407  			}
   408  			hasVd, _ = b.findVdTag(ameda.DereferenceType(field.Type), inMapOrSlice, depth, exist)
   409  			if hasVd {
   410  				return true, nil
   411  			}
   412  		}
   413  		return false, nil
   414  	case reflect.Slice, reflect.Array, reflect.Map:
   415  		return b.findVdTag(ameda.DereferenceType(t.Elem()), true, depth, exist)
   416  	default:
   417  		return false, nil
   418  	}
   419  }
   420  
   421  func (b *Binding) bindJSON(pointer interface{}, bodyBytes []byte) error {
   422  	if b.jsonUnmarshalFunc != nil {
   423  		return b.jsonUnmarshalFunc(bodyBytes, pointer)
   424  	} else {
   425  		return jsonpkg.Unmarshal(bodyBytes, pointer)
   426  	}
   427  }
   428  
   429  func (b *Binding) ResetJSONUnmarshaler(fn JSONUnmarshaler) {
   430  	b.jsonUnmarshalFunc = fn
   431  }