github.com/bytedance/go-tagexpr@v2.7.5-0.20210114074101-de5b8743ad85+incompatible/binding/bind.go (about)

     1  package binding
     2  
     3  import (
     4  	jsonpkg "encoding/json"
     5  	"net/http"
     6  	"reflect"
     7  	"strings"
     8  	"sync"
     9  
    10  	"github.com/henrylee2cn/ameda"
    11  	"github.com/henrylee2cn/goutil"
    12  	"github.com/henrylee2cn/goutil/tpack"
    13  
    14  	"github.com/bytedance/go-tagexpr"
    15  	"github.com/bytedance/go-tagexpr/validator"
    16  )
    17  
    18  // Binding binding and verification tool for http request
    19  type Binding struct {
    20  	vd             *validator.Validator
    21  	recvs          map[uintptr]*receiver
    22  	lock           sync.RWMutex
    23  	bindErrFactory func(failField, msg string) error
    24  	config         Config
    25  }
    26  
    27  // New creates a binding tool.
    28  // NOTE:
    29  //  Use default tag name for config fields that are empty
    30  func New(config *Config) *Binding {
    31  	if config == nil {
    32  		config = new(Config)
    33  	}
    34  	b := &Binding{
    35  		recvs:  make(map[uintptr]*receiver, 1024),
    36  		config: *config,
    37  	}
    38  	b.config.init()
    39  	b.vd = validator.New(b.config.Validator)
    40  	return b.SetErrorFactory(nil, nil)
    41  }
    42  
    43  // SetLooseZeroMode if set to true,
    44  // the empty string request parameter is bound to the zero value of parameter.
    45  // NOTE:
    46  //  The default is false;
    47  //  Suitable for these parameter types: query/header/cookie/form .
    48  func (b *Binding) SetLooseZeroMode(enable bool) *Binding {
    49  	b.config.LooseZeroMode = enable
    50  	for k := range b.recvs {
    51  		delete(b.recvs, k)
    52  	}
    53  	return b
    54  }
    55  
    56  var defaultValidatingErrFactory = newDefaultErrorFactory("validating")
    57  var defaultBindErrFactory = newDefaultErrorFactory("binding")
    58  
    59  // SetErrorFactory customizes the factory of validation error.
    60  // NOTE:
    61  //  If errFactory==nil, the default is used
    62  func (b *Binding) SetErrorFactory(bindErrFactory, validatingErrFactory func(failField, msg string) error) *Binding {
    63  	if bindErrFactory == nil {
    64  		bindErrFactory = defaultBindErrFactory
    65  	}
    66  	if validatingErrFactory == nil {
    67  		validatingErrFactory = defaultValidatingErrFactory
    68  	}
    69  	b.bindErrFactory = bindErrFactory
    70  	b.vd.SetErrorFactory(validatingErrFactory)
    71  	return b
    72  }
    73  
    74  // BindAndValidate binds the request parameters and validates them if needed.
    75  func (b *Binding) BindAndValidate(recvPointer interface{}, req *http.Request, pathParams PathParams) error {
    76  	return b.IBindAndValidate(recvPointer, wrapRequest(req), pathParams)
    77  }
    78  
    79  // Bind binds the request parameters.
    80  func (b *Binding) Bind(recvPointer interface{}, req *http.Request, pathParams PathParams) error {
    81  	return b.IBind(recvPointer, wrapRequest(req), pathParams)
    82  }
    83  
    84  // IBindAndValidate binds the request parameters and validates them if needed.
    85  func (b *Binding) IBindAndValidate(recvPointer interface{}, req Request, pathParams PathParams) error {
    86  	v, hasVd, err := b.bind(recvPointer, req, pathParams)
    87  	if err != nil {
    88  		return err
    89  	}
    90  	if hasVd {
    91  		return b.vd.Validate(v)
    92  	}
    93  	return nil
    94  }
    95  
    96  // IBind binds the request parameters.
    97  func (b *Binding) IBind(recvPointer interface{}, req Request, pathParams PathParams) error {
    98  	_, _, err := b.bind(recvPointer, req, pathParams)
    99  	return err
   100  }
   101  
   102  // Validate validates whether the fields of value is valid.
   103  func (b *Binding) Validate(value interface{}) error {
   104  	return b.vd.Validate(value)
   105  }
   106  
   107  func (b *Binding) bind(pointer interface{}, req Request, pathParams PathParams) (elemValue reflect.Value, hasVd bool, err error) {
   108  	elemValue, err = b.receiverValueOf(pointer)
   109  	if err != nil {
   110  		return
   111  	}
   112  	if elemValue.Kind() == reflect.Struct {
   113  		hasVd, err = b.bindStruct(pointer, elemValue, req, pathParams)
   114  	} else {
   115  		hasVd, err = b.bindNonstruct(pointer, elemValue, req, pathParams)
   116  	}
   117  	return
   118  }
   119  
   120  func (b *Binding) bindNonstruct(pointer interface{}, _ reflect.Value, req Request, _ PathParams) (hasVd bool, err error) {
   121  	bodyCodec := getBodyCodec(req)
   122  	switch bodyCodec {
   123  	case bodyJSON:
   124  		hasVd = true
   125  		bodyBytes, err := req.GetBody()
   126  		if err != nil {
   127  			return hasVd, err
   128  		}
   129  		err = bindJSON(pointer, bodyBytes)
   130  	case bodyProtobuf:
   131  		hasVd = true
   132  		bodyBytes, err := req.GetBody()
   133  		if err != nil {
   134  			return hasVd, err
   135  		}
   136  		err = bindProtobuf(pointer, bodyBytes)
   137  	case bodyForm:
   138  		postForm, err := req.GetPostForm()
   139  		if err != nil {
   140  			return false, err
   141  		}
   142  		b, _ := jsonpkg.Marshal(postForm)
   143  		err = jsonpkg.Unmarshal(b, pointer)
   144  	default:
   145  		// query and form
   146  		form, err := req.GetForm()
   147  		if err != nil {
   148  			return false, err
   149  		}
   150  		b, _ := jsonpkg.Marshal(form)
   151  		err = jsonpkg.Unmarshal(b, pointer)
   152  	}
   153  	return
   154  }
   155  
// bindStruct binds req (plus pathParams) into the struct behind structPointer.
//
// As implemented below it:
//  1. fetches (or builds and caches) the receiver describing the struct's
//     bindable fields via getOrPrepareReceiver;
//  2. runs the validator's tag-expression VM over structValue to obtain the
//     expr handle passed to every per-parameter binder;
//  3. reads the body with recv.getBodyInfo and, when it is non-empty,
//     pre-binds it into the whole struct with recv.prebindBody;
//  4. for every parameter, tries its tag sources in order until one binds
//     without error.
//
// hasVd is recv.hasVd on success and for errors raised inside the parameter
// loop. Errors from steps 1-3 (and from req.GetPostForm) return hasVd=false.
func (b *Binding) bindStruct(structPointer interface{}, structValue reflect.Value, req Request, pathParams PathParams) (hasVd bool, err error) {
	recv, err := b.getOrPrepareReceiver(structValue)
	if err != nil {
		return
	}

	expr, err := b.vd.VM().Run(structValue)
	if err != nil {
		return
	}

	// NOTE(review): if getBodyInfo ever returns an error together with a
	// non-empty body, prebindBody overwrites that error. Confirm getBodyInfo
	// only returns an error with an empty body.
	bodyCodec, bodyBytes, err := recv.getBodyInfo(req)
	if len(bodyBytes) > 0 {
		err = recv.prebindBody(structPointer, structValue, bodyCodec, bodyBytes)
	}
	if err != nil {
		return
	}
	// Presumably a zero-copy view of bodyBytes (per the helper's name), so
	// bodyBytes must not be mutated while bodyString is in use. Confirm.
	bodyString := ameda.UnsafeBytesToString(bodyBytes)
	postForm, err := req.GetPostForm()
	if err != nil {
		return
	}
	queryValues := recv.getQuery(req)
	cookies := recv.getCookies(req)

	for _, param := range recv.params {
		// Try each tag source of this parameter in declaration order.
		for i, info := range param.tagInfos {
			var found bool
			switch info.paramIn {
			case path:
				found, err = param.bindPath(info, expr, pathParams)
			case query:
				found, err = param.bindQuery(info, expr, queryValues)
			case cookie:
				err = param.bindCookie(info, expr, cookies)
				found = err == nil
			case header:
				found, err = param.bindHeader(info, expr, req.GetHeader())
			case form, json, protobuf:
				// Body sources only apply when they match the request's body
				// codec; a non-matching source fails only if it is required.
				if info.paramIn == in(bodyCodec) {
					found, err = param.bindOrRequireBody(info, expr, bodyCodec, bodyString, postForm)
				} else if info.required {
					found = false
					err = info.requiredError
				}
			case raw_body:
				err = param.bindRawBody(info, expr, bodyBytes)
				found = err == nil
			case default_val:
				found, err = param.bindDefaultVal(expr, param.defaultVal)
			}
			// First source that binds cleanly wins.
			if found && err == nil {
				break
			}
			// A source that was found but failed is fatal. Otherwise the error
			// is only reported once the last source has been tried.
			// NOTE(review): branches that do not assign err (a non-matching,
			// non-required body source) keep the error from the previous
			// source, so the last source can report an earlier source's
			// error. This looks intentional but is worth confirming.
			if (found || i == len(param.tagInfos)-1) && err != nil {
				return recv.hasVd, err
			}
		}
	}
	return recv.hasVd, nil
}
   218  
   219  func (b *Binding) receiverValueOf(receiver interface{}) (reflect.Value, error) {
   220  	v := reflect.ValueOf(receiver)
   221  	if v.Kind() == reflect.Ptr {
   222  		v = goutil.DereferencePtrValue(v)
   223  		if v.IsValid() && v.CanAddr() {
   224  			return v, nil
   225  		}
   226  	}
   227  	return v, b.bindErrFactory("", "receiver must be a non-nil pointer")
   228  }
   229  
// getOrPrepareReceiver returns the cached receiver for value's runtime type,
// building and caching it on first use.
//
// Building runs the tag-expression VM over a fresh zero value of the type
// and walks every field it reports. For each field it:
//   - records an error if the field cannot be set. The walk continues, and
//     the last such field is the one reported;
//   - maps each recognised binding tag name (from b.config) to a parameter
//     source. A tag value of "-" opts the field out of that source (except
//     for the default-value tag);
//   - when no source tag applies, falls back to the default sources in
//     sortedDefaultIn, using the Go field name. This happens only if the
//     field is exported and no enclosing field (matched by selector prefix)
//     already had a binding tag;
//   - marks recv.hasVd when a validator tag is present.
//
// Concurrency: the cache lookup and the final insert each take b.lock, but
// the build in between does not. Two goroutines may therefore build the same
// receiver concurrently; the later insert wins.
func (b *Binding) getOrPrepareReceiver(value reflect.Value) (*receiver, error) {
	runtimeTypeID := tpack.From(value).RuntimeTypeID()
	b.lock.RLock()
	recv, ok := b.recvs[runtimeTypeID]
	b.lock.RUnlock()
	if ok {
		return recv, nil
	}

	// Inspect a zero value of the type, not the caller's value.
	expr, err := b.vd.VM().Run(reflect.New(value.Type()).Elem())
	if err != nil {
		return nil, err
	}
	recv = &receiver{
		params:        make([]*paramInfo, 0, 16),
		looseZeroMode: b.config.LooseZeroMode,
	}
	var errExprSelector tagexpr.ExprSelector
	var errMsg string
	// Selector prefixes (field selector + separator) of fields that have at
	// least one binding tag. Their sub-fields do not get default bindings.
	var fieldsWithValidTag = make(map[string]bool)
	expr.RangeFields(func(fh *tagexpr.FieldHandler) bool {
		if !fh.Value(true).CanSet() {
			selector := fh.StringSelector()
			errMsg = "field cannot be set: " + selector
			errExprSelector = tagexpr.ExprSelector(selector)
			return true
		}

		tagKVs := b.config.parse(fh.StructField())
		p := recv.getOrAddParam(fh, b.bindErrFactory)
		// One slot per parameter source. A later tag of the same source
		// overwrites an earlier one.
		tagInfos := [maxIn]*tagInfo{}
	L:
		for _, tagKV := range tagKVs {
			paramIn := undefined
			switch tagKV.name {
			case b.config.Validator:
				recv.hasVd = true
				continue L
			case b.config.PathParam:
				paramIn = path
			case b.config.FormBody:
				paramIn = form
			case b.config.Query:
				paramIn = query
			case b.config.Cookie:
				paramIn = cookie
			case b.config.Header:
				paramIn = header
			case b.config.protobufBody:
				paramIn = protobuf
			case b.config.jsonBody:
				paramIn = json
			case b.config.RawBody:
				paramIn = raw_body
			case b.config.defaultVal:
				paramIn = default_val
			default:
				continue L
			}
			if paramIn == default_val {
				// The default-value tag's value is the default itself, stored
				// as paramName.
				tagInfos[paramIn] = &tagInfo{paramIn: default_val, paramName: tagKV.value}
			} else {
				// NOTE(review): the bool presumably enables header-specific
				// name handling. Confirm in toInfo.
				tagInfos[paramIn] = tagKV.toInfo(paramIn == header)
			}
		}

		// Slot order (the in constants) determines the order sources are tried
		// at bind time.
		for i, info := range tagInfos {
			if info != nil {
				if info.paramIn != default_val && info.paramName == "-" {
					p.omitIns[in(i)] = true
					recv.assginIn(in(i), false)
				} else {
					info.paramIn = in(i)
					p.tagInfos = append(p.tagInfos, info)
					recv.assginIn(in(i), true)
				}
			}
		}
		fs := string(fh.FieldSelector())
		if len(p.tagInfos) == 0 {
			// NOTE(review): relies on RangeFields visiting a parent field
			// before its sub-fields. Confirm in tagexpr.
			var canDefault = true
			for s := range fieldsWithValidTag {
				if strings.HasPrefix(fs, s) {
					canDefault = false
					break
				}
			}
			if canDefault {
				if !goutil.IsExportedName(p.structField.Name) {
					canDefault = false
				}
			}
			// Fall back to the default source order, keyed by the Go field
			// name, for an exported field whose enclosing fields carry no
			// binding tag. Sources explicitly opted out with "-" are skipped.
			if canDefault {
				for _, i := range sortedDefaultIn {
					if p.omitIns[i] {
						recv.assginIn(i, false)
						continue
					}
					p.tagInfos = append(p.tagInfos, &tagInfo{
						paramIn:   i,
						paramName: p.structField.Name,
					})
					recv.assginIn(i, true)
				}
			}
		} else {
			fieldsWithValidTag[fs+tagexpr.FieldSeparator] = true
		}
		// NOTE(review): looks redundant with the Validator case in the tag loop
		// above, unless lookup matches tags that loop does not. Confirm.
		if !recv.hasVd {
			_, recv.hasVd = tagKVs.lookup(b.config.Validator)
		}
		return true
	})

	if errMsg != "" {
		return nil, b.bindErrFactory(errExprSelector.String(), errMsg)
	}

	recv.initParams()

	b.lock.Lock()
	b.recvs[runtimeTypeID] = recv
	b.lock.Unlock()

	return recv, nil
}