github.com/bytedance/go-tagexpr@v2.7.5-0.20210114074101-de5b8743ad85+incompatible/binding/bind.go (about) 1 package binding 2 3 import ( 4 jsonpkg "encoding/json" 5 "net/http" 6 "reflect" 7 "strings" 8 "sync" 9 10 "github.com/henrylee2cn/ameda" 11 "github.com/henrylee2cn/goutil" 12 "github.com/henrylee2cn/goutil/tpack" 13 14 "github.com/bytedance/go-tagexpr" 15 "github.com/bytedance/go-tagexpr/validator" 16 ) 17 18 // Binding binding and verification tool for http request 19 type Binding struct { 20 vd *validator.Validator 21 recvs map[uintptr]*receiver 22 lock sync.RWMutex 23 bindErrFactory func(failField, msg string) error 24 config Config 25 } 26 27 // New creates a binding tool. 28 // NOTE: 29 // Use default tag name for config fields that are empty 30 func New(config *Config) *Binding { 31 if config == nil { 32 config = new(Config) 33 } 34 b := &Binding{ 35 recvs: make(map[uintptr]*receiver, 1024), 36 config: *config, 37 } 38 b.config.init() 39 b.vd = validator.New(b.config.Validator) 40 return b.SetErrorFactory(nil, nil) 41 } 42 43 // SetLooseZeroMode if set to true, 44 // the empty string request parameter is bound to the zero value of parameter. 45 // NOTE: 46 // The default is false; 47 // Suitable for these parameter types: query/header/cookie/form . 48 func (b *Binding) SetLooseZeroMode(enable bool) *Binding { 49 b.config.LooseZeroMode = enable 50 for k := range b.recvs { 51 delete(b.recvs, k) 52 } 53 return b 54 } 55 56 var defaultValidatingErrFactory = newDefaultErrorFactory("validating") 57 var defaultBindErrFactory = newDefaultErrorFactory("binding") 58 59 // SetErrorFactory customizes the factory of validation error. 60 // NOTE: 61 // If errFactory==nil, the default is used 62 func (b *Binding) SetErrorFactory(bindErrFactory, validatingErrFactory func(failField, msg string) error) *Binding { 63 if bindErrFactory == nil { 64 bindErrFactory = defaultBindErrFactory 65 } 66 if validatingErrFactory == nil { 67 validatingErrFactory = defaultValidatingErrFactory 68 } 69 b.bindErrFactory = bindErrFactory 70 b.vd.SetErrorFactory(validatingErrFactory) 71 return b 72 } 73 74 // BindAndValidate binds the request parameters and validates them if needed. 75 func (b *Binding) BindAndValidate(recvPointer interface{}, req *http.Request, pathParams PathParams) error { 76 return b.IBindAndValidate(recvPointer, wrapRequest(req), pathParams) 77 } 78 79 // Bind binds the request parameters. 80 func (b *Binding) Bind(recvPointer interface{}, req *http.Request, pathParams PathParams) error { 81 return b.IBind(recvPointer, wrapRequest(req), pathParams) 82 } 83 84 // IBindAndValidate binds the request parameters and validates them if needed. 85 func (b *Binding) IBindAndValidate(recvPointer interface{}, req Request, pathParams PathParams) error { 86 v, hasVd, err := b.bind(recvPointer, req, pathParams) 87 if err != nil { 88 return err 89 } 90 if hasVd { 91 return b.vd.Validate(v) 92 } 93 return nil 94 } 95 96 // IBind binds the request parameters. 97 func (b *Binding) IBind(recvPointer interface{}, req Request, pathParams PathParams) error { 98 _, _, err := b.bind(recvPointer, req, pathParams) 99 return err 100 } 101 102 // Validate validates whether the fields of value is valid. 103 func (b *Binding) Validate(value interface{}) error { 104 return b.vd.Validate(value) 105 } 106 107 func (b *Binding) bind(pointer interface{}, req Request, pathParams PathParams) (elemValue reflect.Value, hasVd bool, err error) { 108 elemValue, err = b.receiverValueOf(pointer) 109 if err != nil { 110 return 111 } 112 if elemValue.Kind() == reflect.Struct { 113 hasVd, err = b.bindStruct(pointer, elemValue, req, pathParams) 114 } else { 115 hasVd, err = b.bindNonstruct(pointer, elemValue, req, pathParams) 116 } 117 return 118 } 119 120 func (b *Binding) bindNonstruct(pointer interface{}, _ reflect.Value, req Request, _ PathParams) (hasVd bool, err error) { 121 bodyCodec := getBodyCodec(req) 122 switch bodyCodec { 123 case bodyJSON: 124 hasVd = true 125 bodyBytes, err := req.GetBody() 126 if err != nil { 127 return hasVd, err 128 } 129 err = bindJSON(pointer, bodyBytes) 130 case bodyProtobuf: 131 hasVd = true 132 bodyBytes, err := req.GetBody() 133 if err != nil { 134 return hasVd, err 135 } 136 err = bindProtobuf(pointer, bodyBytes) 137 case bodyForm: 138 postForm, err := req.GetPostForm() 139 if err != nil { 140 return false, err 141 } 142 b, _ := jsonpkg.Marshal(postForm) 143 err = jsonpkg.Unmarshal(b, pointer) 144 default: 145 // query and form 146 form, err := req.GetForm() 147 if err != nil { 148 return false, err 149 } 150 b, _ := jsonpkg.Marshal(form) 151 err = jsonpkg.Unmarshal(b, pointer) 152 } 153 return 154 } 155 156 func (b *Binding) bindStruct(structPointer interface{}, structValue reflect.Value, req Request, pathParams PathParams) (hasVd bool, err error) { 157 recv, err := b.getOrPrepareReceiver(structValue) 158 if err != nil { 159 return 160 } 161 162 expr, err := b.vd.VM().Run(structValue) 163 if err != nil { 164 return 165 } 166 167 bodyCodec, bodyBytes, err := recv.getBodyInfo(req) 168 if len(bodyBytes) > 0 { 169 err = recv.prebindBody(structPointer, structValue, bodyCodec, bodyBytes) 170 } 171 if err != nil { 172 return 173 } 174 bodyString := ameda.UnsafeBytesToString(bodyBytes) 175 postForm, err := req.GetPostForm() 176 if err != nil { 177 return 178 } 179 queryValues := recv.getQuery(req) 180 cookies := recv.getCookies(req) 181 182 for _, param := range recv.params { 183 for i, info := range param.tagInfos { 184 var found bool 185 switch info.paramIn { 186 case path: 187 found, err = param.bindPath(info, expr, pathParams) 188 case query: 189 found, err = param.bindQuery(info, expr, queryValues) 190 case cookie: 191 err = param.bindCookie(info, expr, cookies) 192 found = err == nil 193 case header: 194 found, err = param.bindHeader(info, expr, req.GetHeader()) 195 case form, json, protobuf: 196 if info.paramIn == in(bodyCodec) { 197 found, err = param.bindOrRequireBody(info, expr, bodyCodec, bodyString, postForm) 198 } else if info.required { 199 found = false 200 err = info.requiredError 201 } 202 case raw_body: 203 err = param.bindRawBody(info, expr, bodyBytes) 204 found = err == nil 205 case default_val: 206 found, err = param.bindDefaultVal(expr, param.defaultVal) 207 } 208 if found && err == nil { 209 break 210 } 211 if (found || i == len(param.tagInfos)-1) && err != nil { 212 return recv.hasVd, err 213 } 214 } 215 } 216 return recv.hasVd, nil 217 } 218 219 func (b *Binding) receiverValueOf(receiver interface{}) (reflect.Value, error) { 220 v := reflect.ValueOf(receiver) 221 if v.Kind() == reflect.Ptr { 222 v = goutil.DereferencePtrValue(v) 223 if v.IsValid() && v.CanAddr() { 224 return v, nil 225 } 226 } 227 return v, b.bindErrFactory("", "receiver must be a non-nil pointer") 228 } 229 230 func (b *Binding) getOrPrepareReceiver(value reflect.Value) (*receiver, error) { 231 runtimeTypeID := tpack.From(value).RuntimeTypeID() 232 b.lock.RLock() 233 recv, ok := b.recvs[runtimeTypeID] 234 b.lock.RUnlock() 235 if ok { 236 return recv, nil 237 } 238 239 expr, err := b.vd.VM().Run(reflect.New(value.Type()).Elem()) 240 if err != nil { 241 return nil, err 242 } 243 recv = &receiver{ 244 params: make([]*paramInfo, 0, 16), 245 looseZeroMode: b.config.LooseZeroMode, 246 } 247 var errExprSelector tagexpr.ExprSelector 248 var errMsg string 249 var fieldsWithValidTag = make(map[string]bool) 250 expr.RangeFields(func(fh *tagexpr.FieldHandler) bool { 251 if !fh.Value(true).CanSet() { 252 selector := fh.StringSelector() 253 errMsg = "field cannot be set: " + selector 254 errExprSelector = tagexpr.ExprSelector(selector) 255 return true 256 } 257 258 tagKVs := b.config.parse(fh.StructField()) 259 p := recv.getOrAddParam(fh, b.bindErrFactory) 260 tagInfos := [maxIn]*tagInfo{} 261 L: 262 for _, tagKV := range tagKVs { 263 paramIn := undefined 264 switch tagKV.name { 265 case b.config.Validator: 266 recv.hasVd = true 267 continue L 268 case b.config.PathParam: 269 paramIn = path 270 case b.config.FormBody: 271 paramIn = form 272 case b.config.Query: 273 paramIn = query 274 case b.config.Cookie: 275 paramIn = cookie 276 case b.config.Header: 277 paramIn = header 278 case b.config.protobufBody: 279 paramIn = protobuf 280 case b.config.jsonBody: 281 paramIn = json 282 case b.config.RawBody: 283 paramIn = raw_body 284 case b.config.defaultVal: 285 paramIn = default_val 286 default: 287 continue L 288 } 289 if paramIn == default_val { 290 tagInfos[paramIn] = &tagInfo{paramIn: default_val, paramName: tagKV.value} 291 } else { 292 tagInfos[paramIn] = tagKV.toInfo(paramIn == header) 293 } 294 } 295 296 for i, info := range tagInfos { 297 if info != nil { 298 if info.paramIn != default_val && info.paramName == "-" { 299 p.omitIns[in(i)] = true 300 recv.assginIn(in(i), false) 301 } else { 302 info.paramIn = in(i) 303 p.tagInfos = append(p.tagInfos, info) 304 recv.assginIn(in(i), true) 305 } 306 } 307 } 308 fs := string(fh.FieldSelector()) 309 if len(p.tagInfos) == 0 { 310 var canDefault = true 311 for s := range fieldsWithValidTag { 312 if strings.HasPrefix(fs, s) { 313 canDefault = false 314 break 315 } 316 } 317 if canDefault { 318 if !goutil.IsExportedName(p.structField.Name) { 319 canDefault = false 320 } 321 } 322 // Supports the default binding order when there is no valid tag in the superior field of the exportable field 323 if canDefault { 324 for _, i := range sortedDefaultIn { 325 if p.omitIns[i] { 326 recv.assginIn(i, false) 327 continue 328 } 329 p.tagInfos = append(p.tagInfos, &tagInfo{ 330 paramIn: i, 331 paramName: p.structField.Name, 332 }) 333 recv.assginIn(i, true) 334 } 335 } 336 } else { 337 fieldsWithValidTag[fs+tagexpr.FieldSeparator] = true 338 } 339 if !recv.hasVd { 340 _, recv.hasVd = tagKVs.lookup(b.config.Validator) 341 } 342 return true 343 }) 344 345 if errMsg != "" { 346 return nil, b.bindErrFactory(errExprSelector.String(), errMsg) 347 } 348 349 recv.initParams() 350 351 b.lock.Lock() 352 b.recvs[runtimeTypeID] = recv 353 b.lock.Unlock() 354 355 return recv, nil 356 }