github.com/bytedance/go-tagexpr/v2@v2.9.8/binding/bind.go (about) 1 package binding 2 3 import ( 4 jsonpkg "encoding/json" 5 "mime/multipart" 6 "net/http" 7 "reflect" 8 "strings" 9 "sync" 10 11 "github.com/andeya/ameda" 12 "github.com/andeya/goutil" 13 14 "github.com/bytedance/go-tagexpr/v2" 15 "github.com/bytedance/go-tagexpr/v2/validator" 16 ) 17 18 // Binding binding and verification tool for http request 19 type Binding struct { 20 vd *validator.Validator 21 recvs map[uintptr]*receiver 22 lock sync.RWMutex 23 bindErrFactory func(failField, msg string) error 24 config Config 25 jsonUnmarshalFunc func(data []byte, v interface{}) error 26 } 27 28 // New creates a binding tool. 29 // NOTE: 30 // 31 // Use default tag name for config fields that are empty 32 func New(config *Config) *Binding { 33 if config == nil { 34 config = new(Config) 35 } 36 b := &Binding{ 37 recvs: make(map[uintptr]*receiver, 1024), 38 config: *config, 39 } 40 b.config.init() 41 b.vd = validator.New(b.config.Validator) 42 return b.SetErrorFactory(nil, nil) 43 } 44 45 // SetLooseZeroMode if set to true, 46 // the empty string request parameter is bound to the zero value of parameter. 47 // NOTE: 48 // 49 // The default is false; 50 // Suitable for these parameter types: query/header/cookie/form . 51 func (b *Binding) SetLooseZeroMode(enable bool) *Binding { 52 b.config.LooseZeroMode = enable 53 for k := range b.recvs { 54 delete(b.recvs, k) 55 } 56 return b 57 } 58 59 var defaultValidatingErrFactory = newDefaultErrorFactory("validating") 60 var defaultBindErrFactory = newDefaultErrorFactory("binding") 61 62 // SetErrorFactory customizes the factory of validation error. 63 // NOTE: 64 // 65 // If errFactory==nil, the default is used 66 func (b *Binding) SetErrorFactory(bindErrFactory, validatingErrFactory func(failField, msg string) error) *Binding { 67 if bindErrFactory == nil { 68 bindErrFactory = defaultBindErrFactory 69 } 70 if validatingErrFactory == nil { 71 validatingErrFactory = defaultValidatingErrFactory 72 } 73 b.bindErrFactory = bindErrFactory 74 b.vd.SetErrorFactory(validatingErrFactory) 75 return b 76 } 77 78 // BindAndValidate binds the request parameters and validates them if needed. 79 func (b *Binding) BindAndValidate(recvPointer interface{}, req *http.Request, pathParams PathParams) error { 80 return b.IBindAndValidate(recvPointer, wrapRequest(req), pathParams) 81 } 82 83 // Bind binds the request parameters. 84 func (b *Binding) Bind(recvPointer interface{}, req *http.Request, pathParams PathParams) error { 85 return b.IBind(recvPointer, wrapRequest(req), pathParams) 86 } 87 88 // IBindAndValidate binds the request parameters and validates them if needed. 89 func (b *Binding) IBindAndValidate(recvPointer interface{}, req Request, pathParams PathParams) error { 90 v, hasVd, err := b.bind(recvPointer, req, pathParams) 91 if err != nil { 92 return err 93 } 94 if hasVd { 95 return b.vd.Validate(v) 96 } 97 return nil 98 } 99 100 // IBind binds the request parameters. 101 func (b *Binding) IBind(recvPointer interface{}, req Request, pathParams PathParams) error { 102 _, _, err := b.bind(recvPointer, req, pathParams) 103 return err 104 } 105 106 // Validate validates whether the fields of value is valid. 107 func (b *Binding) Validate(value interface{}) error { 108 return b.vd.Validate(value) 109 } 110 111 func (b *Binding) bind(pointer interface{}, req Request, pathParams PathParams) (elemValue reflect.Value, hasVd bool, err error) { 112 elemValue, err = b.receiverValueOf(pointer) 113 if err != nil { 114 return 115 } 116 if elemValue.Kind() == reflect.Struct { 117 hasVd, err = b.bindStruct(pointer, elemValue, req, pathParams) 118 } else { 119 hasVd, err = b.bindNonstruct(pointer, elemValue, req, pathParams) 120 } 121 return 122 } 123 124 func (b *Binding) bindNonstruct(pointer interface{}, _ reflect.Value, req Request, _ PathParams) (hasVd bool, err error) { 125 bodyCodec := getBodyCodec(req) 126 switch bodyCodec { 127 case bodyJSON: 128 hasVd = true 129 bodyBytes, err := req.GetBody() 130 if err != nil { 131 return hasVd, err 132 } 133 err = b.bindJSON(pointer, bodyBytes) 134 case bodyProtobuf: 135 hasVd = true 136 bodyBytes, err := req.GetBody() 137 if err != nil { 138 return hasVd, err 139 } 140 err = bindProtobuf(pointer, bodyBytes) 141 case bodyForm: 142 postForm, err := req.GetPostForm() 143 if err != nil { 144 return false, err 145 } 146 b, _ := jsonpkg.Marshal(postForm) 147 err = jsonpkg.Unmarshal(b, pointer) 148 default: 149 // query and form 150 form, err := req.GetForm() 151 if err != nil { 152 return false, err 153 } 154 b, _ := jsonpkg.Marshal(form) 155 err = jsonpkg.Unmarshal(b, pointer) 156 } 157 return 158 } 159 160 func (b *Binding) bindStruct(structPointer interface{}, structValue reflect.Value, req Request, pathParams PathParams) (hasVd bool, err error) { 161 recv, err := b.getOrPrepareReceiver(structValue) 162 if err != nil { 163 return 164 } 165 166 expr, err := b.vd.VM().Run(structValue) 167 if err != nil { 168 return 169 } 170 171 bodyCodec, bodyBytes, err := recv.getBodyInfo(req) 172 if len(bodyBytes) > 0 { 173 err = b.prebindBody(structPointer, structValue, bodyCodec, bodyBytes) 174 } 175 if err != nil { 176 return 177 } 178 bodyString := ameda.UnsafeBytesToString(bodyBytes) 179 postForm, err := req.GetPostForm() 180 if err != nil { 181 return 182 } 183 var fileHeaders map[string][]*multipart.FileHeader 184 if _req, ok := req.(requestWithFileHeader); ok { 185 fileHeaders, err = _req.GetFileHeaders() 186 if err != nil { 187 return 188 } 189 } 190 queryValues := recv.getQuery(req) 191 cookies := recv.getCookies(req) 192 193 for _, param := range recv.params { 194 for i, info := range param.tagInfos { 195 var found bool 196 switch info.paramIn { 197 case raw_body: 198 err = param.bindRawBody(info, expr, bodyBytes) 199 found = err == nil 200 case path: 201 found, err = param.bindPath(info, expr, pathParams) 202 case query: 203 found, err = param.bindQuery(info, expr, queryValues) 204 case cookie: 205 found, err = param.bindCookie(info, expr, cookies) 206 case header: 207 found, err = param.bindHeader(info, expr, req.GetHeader()) 208 case form, json, protobuf: 209 if info.paramIn == in(bodyCodec) { 210 found, err = param.bindOrRequireBody(info, expr, bodyCodec, bodyString, postForm, fileHeaders, 211 recv.hasDefaultVal) 212 } else if info.required { 213 found = false 214 err = info.requiredError 215 } 216 case default_val: 217 found, err = param.bindDefaultVal(expr, param.defaultVal) 218 } 219 if found && err == nil { 220 break 221 } 222 if (found || i == len(param.tagInfos)-1) && err != nil { 223 return recv.hasVd, err 224 } 225 } 226 } 227 return recv.hasVd, nil 228 } 229 230 func (b *Binding) receiverValueOf(receiver interface{}) (reflect.Value, error) { 231 v := reflect.ValueOf(receiver) 232 if v.Kind() == reflect.Ptr { 233 v = ameda.DereferencePtrValue(v) 234 if v.IsValid() && v.CanAddr() { 235 return v, nil 236 } 237 } 238 return v, b.bindErrFactory("", "receiver must be a non-nil pointer") 239 } 240 241 func (b *Binding) getOrPrepareReceiver(value reflect.Value) (*receiver, error) { 242 runtimeTypeID := ameda.ValueFrom(value).RuntimeTypeID() 243 b.lock.RLock() 244 recv, ok := b.recvs[runtimeTypeID] 245 b.lock.RUnlock() 246 if ok { 247 return recv, nil 248 } 249 t := value.Type() 250 expr, err := b.vd.VM().Run(reflect.New(t).Elem()) 251 if err != nil { 252 return nil, err 253 } 254 recv = &receiver{ 255 params: make([]*paramInfo, 0, 16), 256 looseZeroMode: b.config.LooseZeroMode, 257 } 258 var errExprSelector tagexpr.ExprSelector 259 var errMsg string 260 var fieldsWithValidTag = make(map[string]bool) 261 expr.RangeFields(func(fh *tagexpr.FieldHandler) bool { 262 if !fh.Value(true).CanSet() { 263 selector := fh.StringSelector() 264 errMsg = "field cannot be set: " + selector 265 errExprSelector = tagexpr.ExprSelector(selector) 266 return true 267 } 268 269 tagKVs := b.config.parse(fh.StructField()) 270 p := recv.getOrAddParam(fh, b.bindErrFactory) 271 tagInfos := [maxIn]*tagInfo{} 272 L: 273 for _, tagKV := range tagKVs { 274 paramIn := undefined 275 switch tagKV.name { 276 case b.config.Validator: 277 recv.hasVd = true 278 continue L 279 case b.config.PathParam: 280 paramIn = path 281 case b.config.FormBody: 282 paramIn = form 283 case b.config.Query: 284 paramIn = query 285 case b.config.Cookie: 286 paramIn = cookie 287 case b.config.Header: 288 paramIn = header 289 case b.config.protobufBody: 290 paramIn = protobuf 291 case b.config.jsonBody: 292 paramIn = json 293 case b.config.RawBody: 294 paramIn = raw_body 295 case b.config.defaultVal: 296 paramIn = default_val 297 default: 298 continue L 299 } 300 if paramIn == default_val { 301 tagInfos[paramIn] = &tagInfo{paramIn: default_val, paramName: tagKV.value} 302 } else { 303 tagInfos[paramIn] = tagKV.toInfo(paramIn == header) 304 } 305 } 306 307 for i, info := range tagInfos { 308 if info != nil { 309 if info.paramIn != default_val && info.paramName == "-" { 310 p.omitIns[in(i)] = true 311 recv.assginIn(in(i), false) 312 } else { 313 info.paramIn = in(i) 314 p.tagInfos = append(p.tagInfos, info) 315 recv.assginIn(in(i), true) 316 } 317 } 318 } 319 fs := string(fh.FieldSelector()) 320 switch len(p.tagInfos) { 321 case 0: 322 var canDefault = true 323 for s := range fieldsWithValidTag { 324 if strings.HasPrefix(fs, s) { 325 canDefault = false 326 break 327 } 328 } 329 if canDefault { 330 if !goutil.IsExportedName(p.structField.Name) { 331 canDefault = false 332 } 333 } 334 // Supports the default binding order when there is no valid tag in the superior field of the exportable field 335 if canDefault { 336 for _, i := range sortedDefaultIn { 337 if p.omitIns[i] { 338 recv.assginIn(i, false) 339 continue 340 } 341 p.tagInfos = append(p.tagInfos, &tagInfo{ 342 paramIn: i, 343 paramName: p.structField.Name, 344 }) 345 recv.assginIn(i, true) 346 } 347 } 348 case 1: 349 if p.tagInfos[0].paramIn == default_val { 350 last := p.tagInfos[0] 351 p.tagInfos = make([]*tagInfo, 0, len(sortedDefaultIn)+1) 352 for _, i := range sortedDefaultIn { 353 if p.omitIns[i] { 354 recv.assginIn(i, false) 355 continue 356 } 357 p.tagInfos = append(p.tagInfos, &tagInfo{ 358 paramIn: i, 359 paramName: p.structField.Name, 360 }) 361 recv.assginIn(i, true) 362 } 363 p.tagInfos = append(p.tagInfos, last) 364 } 365 fallthrough 366 default: 367 fieldsWithValidTag[fs+tagexpr.FieldSeparator] = true 368 } 369 if !recv.hasVd { 370 _, recv.hasVd = tagKVs.lookup(b.config.Validator) 371 } 372 return true 373 }) 374 375 if errMsg != "" { 376 return nil, b.bindErrFactory(errExprSelector.String(), errMsg) 377 } 378 if !recv.hasVd { 379 recv.hasVd, _ = b.findVdTag(ameda.DereferenceType(t), false, 20, map[reflect.Type]bool{}) 380 } 381 recv.initParams() 382 383 b.lock.Lock() 384 b.recvs[runtimeTypeID] = recv 385 b.lock.Unlock() 386 387 return recv, nil 388 } 389 390 func (b *Binding) findVdTag(t reflect.Type, inMapOrSlice bool, depth int, exist map[reflect.Type]bool) (hasVd bool, err error) { 391 if depth <= 0 || exist[t] { 392 return 393 } 394 depth-- 395 switch t.Kind() { 396 case reflect.Struct: 397 exist[t] = true 398 for i := t.NumField() - 1; i >= 0; i-- { 399 field := t.Field(i) 400 if inMapOrSlice { 401 tagKVs := b.config.parse(field) 402 for _, tagKV := range tagKVs { 403 if tagKV.name == b.config.Validator { 404 return true, nil 405 } 406 } 407 } 408 hasVd, _ = b.findVdTag(ameda.DereferenceType(field.Type), inMapOrSlice, depth, exist) 409 if hasVd { 410 return true, nil 411 } 412 } 413 return false, nil 414 case reflect.Slice, reflect.Array, reflect.Map: 415 return b.findVdTag(ameda.DereferenceType(t.Elem()), true, depth, exist) 416 default: 417 return false, nil 418 } 419 } 420 421 func (b *Binding) bindJSON(pointer interface{}, bodyBytes []byte) error { 422 if b.jsonUnmarshalFunc != nil { 423 return b.jsonUnmarshalFunc(bodyBytes, pointer) 424 } else { 425 return jsonpkg.Unmarshal(bodyBytes, pointer) 426 } 427 } 428 429 func (b *Binding) ResetJSONUnmarshaler(fn JSONUnmarshaler) { 430 b.jsonUnmarshalFunc = fn 431 }