github.com/cloudwego/hertz@v0.9.3/pkg/app/server/binding/default.go (about) 1 /* 2 * Copyright 2023 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * The MIT License 16 * 17 * Copyright (c) 2019-present Fenny and Contributors 18 * 19 * Permission is hereby granted, free of charge, to any person obtaining a copy 20 * of this software and associated documentation files (the "Software"), to deal 21 * in the Software without restriction, including without limitation the rights 22 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 * copies of the Software, and to permit persons to whom the Software is 24 * furnished to do so, subject to the following conditions: 25 * 26 * The above copyright notice and this permission notice shall be included in all 27 * copies or substantial portions of the Software. 28 * 29 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 * SOFTWARE. 
36 * 37 * Copyright (c) 2014 Manuel Martínez-Almeida 38 * 39 * Permission is hereby granted, free of charge, to any person obtaining a copy 40 * of this software and associated documentation files (the "Software"), to deal 41 * in the Software without restriction, including without limitation the rights 42 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 43 * copies of the Software, and to permit persons to whom the Software is 44 * furnished to do so, subject to the following conditions: 45 * 46 * The above copyright notice and this permission notice shall be included in 47 * all copies or substantial portions of the Software. 48 * 49 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 50 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 51 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 52 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 53 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 54 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 55 * THE SOFTWARE. 56 * 57 * This file may have been modified by CloudWeGo authors.
All CloudWeGo 58 * Modifications are Copyright 2023 CloudWeGo Authors 59 */ 60 61 package binding 62 63 import ( 64 "bytes" 65 stdJson "encoding/json" 66 "fmt" 67 "io" 68 "net/url" 69 "reflect" 70 "strings" 71 "sync" 72 73 exprValidator "github.com/bytedance/go-tagexpr/v2/validator" 74 "github.com/cloudwego/hertz/internal/bytesconv" 75 inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" 76 hJson "github.com/cloudwego/hertz/pkg/common/json" 77 "github.com/cloudwego/hertz/pkg/common/utils" 78 "github.com/cloudwego/hertz/pkg/protocol" 79 "github.com/cloudwego/hertz/pkg/protocol/consts" 80 "github.com/cloudwego/hertz/pkg/route/param" 81 "google.golang.org/protobuf/proto" 82 ) 83 84 const ( 85 queryTag = "query" 86 headerTag = "header" 87 formTag = "form" 88 pathTag = "path" 89 defaultValidateTag = "vd" 90 ) 91 92 type decoderInfo struct { 93 decoder inDecoder.Decoder 94 needValidate bool 95 } 96 97 var defaultBind = NewDefaultBinder(nil) 98 99 func DefaultBinder() Binder { 100 return defaultBind 101 } 102 103 type defaultBinder struct { 104 config *BindConfig 105 decoderCache sync.Map 106 queryDecoderCache sync.Map 107 formDecoderCache sync.Map 108 headerDecoderCache sync.Map 109 pathDecoderCache sync.Map 110 } 111 112 func NewDefaultBinder(config *BindConfig) Binder { 113 if config == nil { 114 bindConfig := NewBindConfig() 115 bindConfig.initTypeUnmarshal() 116 return &defaultBinder{ 117 config: bindConfig, 118 } 119 } 120 config.initTypeUnmarshal() 121 if config.Validator == nil { 122 config.Validator = DefaultValidator() 123 } 124 return &defaultBinder{ 125 config: config, 126 } 127 } 128 129 // BindAndValidate binds data from *protocol.Request to obj and validates them if needed. 130 // NOTE: 131 // 132 // obj should be a pointer. 
133 func BindAndValidate(req *protocol.Request, obj interface{}, pathParams param.Params) error { 134 return DefaultBinder().BindAndValidate(req, obj, pathParams) 135 } 136 137 // Bind binds data from *protocol.Request to obj. 138 // NOTE: 139 // 140 // obj should be a pointer. 141 func Bind(req *protocol.Request, obj interface{}, pathParams param.Params) error { 142 return DefaultBinder().Bind(req, obj, pathParams) 143 } 144 145 // Validate validates obj with "vd" tag 146 // NOTE: 147 // 148 // obj should be a pointer. 149 // Validate should be called after Bind. 150 func Validate(obj interface{}) error { 151 return DefaultValidator().ValidateStruct(obj) 152 } 153 154 func (b *defaultBinder) tagCache(tag string) *sync.Map { 155 switch tag { 156 case queryTag: 157 return &b.queryDecoderCache 158 case headerTag: 159 return &b.headerDecoderCache 160 case formTag: 161 return &b.formDecoderCache 162 case pathTag: 163 return &b.pathDecoderCache 164 default: 165 return &b.decoderCache 166 } 167 } 168 169 func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params param.Params, tag string) error { 170 rv, typeID := valueAndTypeID(v) 171 if err := checkPointer(rv); err != nil { 172 return err 173 } 174 rt := dereferPointer(rv) 175 if rt.Kind() != reflect.Struct { 176 return b.bindNonStruct(req, v) 177 } 178 179 if len(tag) == 0 { 180 err := b.preBindBody(req, v) 181 if err != nil { 182 return fmt.Errorf("bind body failed, err=%v", err) 183 } 184 } 185 cache := b.tagCache(tag) 186 cached, ok := cache.Load(typeID) 187 if ok { 188 // cached fieldDecoder, fast path 189 decoder := cached.(decoderInfo) 190 return decoder.decoder(req, params, rv.Elem()) 191 } 192 validateTag := defaultValidateTag 193 if len(b.config.Validator.ValidateTag()) != 0 { 194 validateTag = b.config.Validator.ValidateTag() 195 } 196 decodeConfig := &inDecoder.DecodeConfig{ 197 LooseZeroMode: b.config.LooseZeroMode, 198 DisableDefaultTag: b.config.DisableDefaultTag, 199 
DisableStructFieldResolve: b.config.DisableStructFieldResolve, 200 EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, 201 EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, 202 ValidateTag: validateTag, 203 TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, 204 } 205 decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) 206 if err != nil { 207 return err 208 } 209 210 cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) 211 return decoder(req, params, rv.Elem()) 212 } 213 214 func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{}, params param.Params, tag string) error { 215 rv, typeID := valueAndTypeID(v) 216 if err := checkPointer(rv); err != nil { 217 return err 218 } 219 rt := dereferPointer(rv) 220 if rt.Kind() != reflect.Struct { 221 return b.bindNonStruct(req, v) 222 } 223 224 err := b.preBindBody(req, v) 225 if err != nil { 226 return fmt.Errorf("bind body failed, err=%v", err) 227 } 228 cache := b.tagCache(tag) 229 cached, ok := cache.Load(typeID) 230 if ok { 231 // cached fieldDecoder, fast path 232 decoder := cached.(decoderInfo) 233 err = decoder.decoder(req, params, rv.Elem()) 234 if err != nil { 235 return err 236 } 237 if decoder.needValidate { 238 err = b.config.Validator.ValidateStruct(rv.Elem()) 239 } 240 return err 241 } 242 validateTag := defaultValidateTag 243 if len(b.config.Validator.ValidateTag()) != 0 { 244 validateTag = b.config.Validator.ValidateTag() 245 } 246 decodeConfig := &inDecoder.DecodeConfig{ 247 LooseZeroMode: b.config.LooseZeroMode, 248 DisableDefaultTag: b.config.DisableDefaultTag, 249 DisableStructFieldResolve: b.config.DisableStructFieldResolve, 250 EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, 251 EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, 252 ValidateTag: validateTag, 253 TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, 254 } 255 decoder, needValidate, err := 
inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) 256 if err != nil { 257 return err 258 } 259 260 cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) 261 err = decoder(req, params, rv.Elem()) 262 if err != nil { 263 return err 264 } 265 if needValidate { 266 err = b.config.Validator.ValidateStruct(rv.Elem()) 267 } 268 return err 269 } 270 271 func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { 272 return b.bindTag(req, v, nil, queryTag) 273 } 274 275 func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { 276 return b.bindTag(req, v, nil, headerTag) 277 } 278 279 func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params param.Params) error { 280 return b.bindTag(req, v, params, pathTag) 281 } 282 283 func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { 284 return b.bindTag(req, v, nil, formTag) 285 } 286 287 func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error { 288 return b.decodeJSON(bytes.NewReader(req.Body()), v) 289 } 290 291 func (b *defaultBinder) decodeJSON(r io.Reader, obj interface{}) error { 292 decoder := hJson.NewDecoder(r) 293 if b.config.EnableDecoderUseNumber { 294 decoder.UseNumber() 295 } 296 if b.config.EnableDecoderDisallowUnknownFields { 297 decoder.DisallowUnknownFields() 298 } 299 300 return decoder.Decode(obj) 301 } 302 303 func (b *defaultBinder) BindProtobuf(req *protocol.Request, v interface{}) error { 304 msg, ok := v.(proto.Message) 305 if !ok { 306 return fmt.Errorf("%s does not implement 'proto.Message'", v) 307 } 308 return proto.Unmarshal(req.Body(), msg) 309 } 310 311 func (b *defaultBinder) Name() string { 312 return "hertz" 313 } 314 315 func (b *defaultBinder) BindAndValidate(req *protocol.Request, v interface{}, params param.Params) error { 316 return b.bindTagWithValidate(req, v, params, "") 317 } 318 319 func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, 
params param.Params) error { 320 return b.bindTag(req, v, params, "") 321 } 322 323 // best effort binding 324 func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error { 325 if req.Header.ContentLength() <= 0 { 326 return nil 327 } 328 ct := bytesconv.B2s(req.Header.ContentType()) 329 switch strings.ToLower(utils.FilterContentType(ct)) { 330 case consts.MIMEApplicationJSON: 331 return hJson.Unmarshal(req.Body(), v) 332 case consts.MIMEPROTOBUF: 333 msg, ok := v.(proto.Message) 334 if !ok { 335 return fmt.Errorf("%s can not implement 'proto.Message'", v) 336 } 337 return proto.Unmarshal(req.Body(), msg) 338 default: 339 return nil 340 } 341 } 342 343 func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err error) { 344 ct := bytesconv.B2s(req.Header.ContentType()) 345 switch strings.ToLower(utils.FilterContentType(ct)) { 346 case consts.MIMEApplicationJSON: 347 err = hJson.Unmarshal(req.Body(), v) 348 case consts.MIMEPROTOBUF: 349 msg, ok := v.(proto.Message) 350 if !ok { 351 return fmt.Errorf("%s can not implement 'proto.Message'", v) 352 } 353 err = proto.Unmarshal(req.Body(), msg) 354 case consts.MIMEMultipartPOSTForm: 355 form := make(url.Values) 356 mf, err1 := req.MultipartForm() 357 if err1 == nil && mf.Value != nil { 358 for k, v := range mf.Value { 359 for _, vv := range v { 360 form.Add(k, vv) 361 } 362 } 363 } 364 b, _ := stdJson.Marshal(form) 365 err = hJson.Unmarshal(b, v) 366 case consts.MIMEApplicationHTMLForm: 367 form := make(url.Values) 368 req.PostArgs().VisitAll(func(formKey, value []byte) { 369 form.Add(string(formKey), string(value)) 370 }) 371 b, _ := stdJson.Marshal(form) 372 err = hJson.Unmarshal(b, v) 373 default: 374 // using query to decode 375 query := make(url.Values) 376 req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { 377 query.Add(string(queryKey), string(value)) 378 }) 379 b, _ := stdJson.Marshal(query) 380 err = hJson.Unmarshal(b, v) 381 } 382 return 383 } 384 385 var _ 
StructValidator = (*validator)(nil) 386 387 type validator struct { 388 validateTag string 389 validate *exprValidator.Validator 390 } 391 392 func NewValidator(config *ValidateConfig) StructValidator { 393 validateTag := defaultValidateTag 394 if config != nil && len(config.ValidateTag) != 0 { 395 validateTag = config.ValidateTag 396 } 397 vd := exprValidator.New(validateTag).SetErrorFactory(defaultValidateErrorFactory) 398 if config != nil && config.ErrFactory != nil { 399 vd.SetErrorFactory(config.ErrFactory) 400 } 401 return &validator{ 402 validateTag: validateTag, 403 validate: vd, 404 } 405 } 406 407 // Error validate error 408 type validateError struct { 409 FailPath, Msg string 410 } 411 412 // Error implements error interface. 413 func (e *validateError) Error() string { 414 if e.Msg != "" { 415 return e.Msg 416 } 417 return "invalid parameter: " + e.FailPath 418 } 419 420 func defaultValidateErrorFactory(failPath, msg string) error { 421 return &validateError{ 422 FailPath: failPath, 423 Msg: msg, 424 } 425 } 426 427 // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. 428 func (v *validator) ValidateStruct(obj interface{}) error { 429 if obj == nil { 430 return nil 431 } 432 return v.validate.Validate(obj) 433 } 434 435 // Engine returns the underlying validator 436 func (v *validator) Engine() interface{} { 437 return v.validate 438 } 439 440 func (v *validator) ValidateTag() string { 441 return v.validateTag 442 } 443 444 var defaultValidate = NewValidator(NewValidateConfig()) 445 446 func DefaultValidator() StructValidator { 447 return defaultValidate 448 }