github.com/cloudwego/hertz@v0.9.3/pkg/app/server/binding/default.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   * The MIT License
    16   *
    17   * Copyright (c) 2019-present Fenny and Contributors
    18   *
    19   * Permission is hereby granted, free of charge, to any person obtaining a copy
    20   * of this software and associated documentation files (the "Software"), to deal
    21   * in the Software without restriction, including without limitation the rights
    22   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    23   * copies of the Software, and to permit persons to whom the Software is
    24   * furnished to do so, subject to the following conditions:
    25   *
    26   * The above copyright notice and this permission notice shall be included in all
    27   * copies or substantial portions of the Software.
    28   *
    29   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    30   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    31   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    32   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    33   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    34   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    35   * SOFTWARE.
    36   *
    37   * Copyright (c) 2014 Manuel Martínez-Almeida
    38   *
    39   * Permission is hereby granted, free of charge, to any person obtaining a copy
    40   * of this software and associated documentation files (the "Software"), to deal
    41   * in the Software without restriction, including without limitation the rights
    42   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    43   * copies of the Software, and to permit persons to whom the Software is
    44   * furnished to do so, subject to the following conditions:
    45   *
    46   * The above copyright notice and this permission notice shall be included in
    47   * all copies or substantial portions of the Software.
    48   *
    49   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    50   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    51   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    52   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    53   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    54   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    55   * THE SOFTWARE.
    56   *
    57   * This file may have been modified by CloudWeGo authors. All CloudWeGo
    58   * Modifications are Copyright 2023 CloudWeGo Authors
    59   */
    60  
    61  package binding
    62  
    63  import (
    64  	"bytes"
    65  	stdJson "encoding/json"
    66  	"fmt"
    67  	"io"
    68  	"net/url"
    69  	"reflect"
    70  	"strings"
    71  	"sync"
    72  
    73  	exprValidator "github.com/bytedance/go-tagexpr/v2/validator"
    74  	"github.com/cloudwego/hertz/internal/bytesconv"
    75  	inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder"
    76  	hJson "github.com/cloudwego/hertz/pkg/common/json"
    77  	"github.com/cloudwego/hertz/pkg/common/utils"
    78  	"github.com/cloudwego/hertz/pkg/protocol"
    79  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    80  	"github.com/cloudwego/hertz/pkg/route/param"
    81  	"google.golang.org/protobuf/proto"
    82  )
    83  
    84  const (
    85  	queryTag           = "query"
    86  	headerTag          = "header"
    87  	formTag            = "form"
    88  	pathTag            = "path"
    89  	defaultValidateTag = "vd"
    90  )
    91  
    92  type decoderInfo struct {
    93  	decoder      inDecoder.Decoder
    94  	needValidate bool
    95  }
    96  
    97  var defaultBind = NewDefaultBinder(nil)
    98  
    99  func DefaultBinder() Binder {
   100  	return defaultBind
   101  }
   102  
   103  type defaultBinder struct {
   104  	config             *BindConfig
   105  	decoderCache       sync.Map
   106  	queryDecoderCache  sync.Map
   107  	formDecoderCache   sync.Map
   108  	headerDecoderCache sync.Map
   109  	pathDecoderCache   sync.Map
   110  }
   111  
   112  func NewDefaultBinder(config *BindConfig) Binder {
   113  	if config == nil {
   114  		bindConfig := NewBindConfig()
   115  		bindConfig.initTypeUnmarshal()
   116  		return &defaultBinder{
   117  			config: bindConfig,
   118  		}
   119  	}
   120  	config.initTypeUnmarshal()
   121  	if config.Validator == nil {
   122  		config.Validator = DefaultValidator()
   123  	}
   124  	return &defaultBinder{
   125  		config: config,
   126  	}
   127  }
   128  
   129  // BindAndValidate binds data from *protocol.Request to obj and validates them if needed.
   130  // NOTE:
   131  //
   132  //	obj should be a pointer.
   133  func BindAndValidate(req *protocol.Request, obj interface{}, pathParams param.Params) error {
   134  	return DefaultBinder().BindAndValidate(req, obj, pathParams)
   135  }
   136  
   137  // Bind binds data from *protocol.Request to obj.
   138  // NOTE:
   139  //
   140  //	obj should be a pointer.
   141  func Bind(req *protocol.Request, obj interface{}, pathParams param.Params) error {
   142  	return DefaultBinder().Bind(req, obj, pathParams)
   143  }
   144  
   145  // Validate validates obj with "vd" tag
   146  // NOTE:
   147  //
   148  //	obj should be a pointer.
   149  //	Validate should be called after Bind.
   150  func Validate(obj interface{}) error {
   151  	return DefaultValidator().ValidateStruct(obj)
   152  }
   153  
   154  func (b *defaultBinder) tagCache(tag string) *sync.Map {
   155  	switch tag {
   156  	case queryTag:
   157  		return &b.queryDecoderCache
   158  	case headerTag:
   159  		return &b.headerDecoderCache
   160  	case formTag:
   161  		return &b.formDecoderCache
   162  	case pathTag:
   163  		return &b.pathDecoderCache
   164  	default:
   165  		return &b.decoderCache
   166  	}
   167  }
   168  
   169  func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params param.Params, tag string) error {
   170  	rv, typeID := valueAndTypeID(v)
   171  	if err := checkPointer(rv); err != nil {
   172  		return err
   173  	}
   174  	rt := dereferPointer(rv)
   175  	if rt.Kind() != reflect.Struct {
   176  		return b.bindNonStruct(req, v)
   177  	}
   178  
   179  	if len(tag) == 0 {
   180  		err := b.preBindBody(req, v)
   181  		if err != nil {
   182  			return fmt.Errorf("bind body failed, err=%v", err)
   183  		}
   184  	}
   185  	cache := b.tagCache(tag)
   186  	cached, ok := cache.Load(typeID)
   187  	if ok {
   188  		// cached fieldDecoder, fast path
   189  		decoder := cached.(decoderInfo)
   190  		return decoder.decoder(req, params, rv.Elem())
   191  	}
   192  	validateTag := defaultValidateTag
   193  	if len(b.config.Validator.ValidateTag()) != 0 {
   194  		validateTag = b.config.Validator.ValidateTag()
   195  	}
   196  	decodeConfig := &inDecoder.DecodeConfig{
   197  		LooseZeroMode:                      b.config.LooseZeroMode,
   198  		DisableDefaultTag:                  b.config.DisableDefaultTag,
   199  		DisableStructFieldResolve:          b.config.DisableStructFieldResolve,
   200  		EnableDecoderUseNumber:             b.config.EnableDecoderUseNumber,
   201  		EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
   202  		ValidateTag:                        validateTag,
   203  		TypeUnmarshalFuncs:                 b.config.TypeUnmarshalFuncs,
   204  	}
   205  	decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate})
   211  	return decoder(req, params, rv.Elem())
   212  }
   213  
   214  func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{}, params param.Params, tag string) error {
   215  	rv, typeID := valueAndTypeID(v)
   216  	if err := checkPointer(rv); err != nil {
   217  		return err
   218  	}
   219  	rt := dereferPointer(rv)
   220  	if rt.Kind() != reflect.Struct {
   221  		return b.bindNonStruct(req, v)
   222  	}
   223  
   224  	err := b.preBindBody(req, v)
   225  	if err != nil {
   226  		return fmt.Errorf("bind body failed, err=%v", err)
   227  	}
   228  	cache := b.tagCache(tag)
   229  	cached, ok := cache.Load(typeID)
   230  	if ok {
   231  		// cached fieldDecoder, fast path
   232  		decoder := cached.(decoderInfo)
   233  		err = decoder.decoder(req, params, rv.Elem())
   234  		if err != nil {
   235  			return err
   236  		}
   237  		if decoder.needValidate {
   238  			err = b.config.Validator.ValidateStruct(rv.Elem())
   239  		}
   240  		return err
   241  	}
   242  	validateTag := defaultValidateTag
   243  	if len(b.config.Validator.ValidateTag()) != 0 {
   244  		validateTag = b.config.Validator.ValidateTag()
   245  	}
   246  	decodeConfig := &inDecoder.DecodeConfig{
   247  		LooseZeroMode:                      b.config.LooseZeroMode,
   248  		DisableDefaultTag:                  b.config.DisableDefaultTag,
   249  		DisableStructFieldResolve:          b.config.DisableStructFieldResolve,
   250  		EnableDecoderUseNumber:             b.config.EnableDecoderUseNumber,
   251  		EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
   252  		ValidateTag:                        validateTag,
   253  		TypeUnmarshalFuncs:                 b.config.TypeUnmarshalFuncs,
   254  	}
   255  	decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
   256  	if err != nil {
   257  		return err
   258  	}
   259  
   260  	cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate})
   261  	err = decoder(req, params, rv.Elem())
   262  	if err != nil {
   263  		return err
   264  	}
   265  	if needValidate {
   266  		err = b.config.Validator.ValidateStruct(rv.Elem())
   267  	}
   268  	return err
   269  }
   270  
   271  func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error {
   272  	return b.bindTag(req, v, nil, queryTag)
   273  }
   274  
   275  func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error {
   276  	return b.bindTag(req, v, nil, headerTag)
   277  }
   278  
   279  func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params param.Params) error {
   280  	return b.bindTag(req, v, params, pathTag)
   281  }
   282  
   283  func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error {
   284  	return b.bindTag(req, v, nil, formTag)
   285  }
   286  
   287  func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error {
   288  	return b.decodeJSON(bytes.NewReader(req.Body()), v)
   289  }
   290  
   291  func (b *defaultBinder) decodeJSON(r io.Reader, obj interface{}) error {
   292  	decoder := hJson.NewDecoder(r)
   293  	if b.config.EnableDecoderUseNumber {
   294  		decoder.UseNumber()
   295  	}
   296  	if b.config.EnableDecoderDisallowUnknownFields {
   297  		decoder.DisallowUnknownFields()
   298  	}
   299  
   300  	return decoder.Decode(obj)
   301  }
   302  
   303  func (b *defaultBinder) BindProtobuf(req *protocol.Request, v interface{}) error {
   304  	msg, ok := v.(proto.Message)
   305  	if !ok {
   306  		return fmt.Errorf("%s does not implement 'proto.Message'", v)
   307  	}
   308  	return proto.Unmarshal(req.Body(), msg)
   309  }
   310  
   311  func (b *defaultBinder) Name() string {
   312  	return "hertz"
   313  }
   314  
   315  func (b *defaultBinder) BindAndValidate(req *protocol.Request, v interface{}, params param.Params) error {
   316  	return b.bindTagWithValidate(req, v, params, "")
   317  }
   318  
   319  func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param.Params) error {
   320  	return b.bindTag(req, v, params, "")
   321  }
   322  
   323  // best effort binding
   324  func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error {
   325  	if req.Header.ContentLength() <= 0 {
   326  		return nil
   327  	}
   328  	ct := bytesconv.B2s(req.Header.ContentType())
   329  	switch strings.ToLower(utils.FilterContentType(ct)) {
   330  	case consts.MIMEApplicationJSON:
   331  		return hJson.Unmarshal(req.Body(), v)
   332  	case consts.MIMEPROTOBUF:
   333  		msg, ok := v.(proto.Message)
   334  		if !ok {
   335  			return fmt.Errorf("%s can not implement 'proto.Message'", v)
   336  		}
   337  		return proto.Unmarshal(req.Body(), msg)
   338  	default:
   339  		return nil
   340  	}
   341  }
   342  
   343  func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err error) {
   344  	ct := bytesconv.B2s(req.Header.ContentType())
   345  	switch strings.ToLower(utils.FilterContentType(ct)) {
   346  	case consts.MIMEApplicationJSON:
   347  		err = hJson.Unmarshal(req.Body(), v)
   348  	case consts.MIMEPROTOBUF:
   349  		msg, ok := v.(proto.Message)
   350  		if !ok {
   351  			return fmt.Errorf("%s can not implement 'proto.Message'", v)
   352  		}
   353  		err = proto.Unmarshal(req.Body(), msg)
   354  	case consts.MIMEMultipartPOSTForm:
   355  		form := make(url.Values)
   356  		mf, err1 := req.MultipartForm()
   357  		if err1 == nil && mf.Value != nil {
   358  			for k, v := range mf.Value {
   359  				for _, vv := range v {
   360  					form.Add(k, vv)
   361  				}
   362  			}
   363  		}
   364  		b, _ := stdJson.Marshal(form)
   365  		err = hJson.Unmarshal(b, v)
   366  	case consts.MIMEApplicationHTMLForm:
   367  		form := make(url.Values)
   368  		req.PostArgs().VisitAll(func(formKey, value []byte) {
   369  			form.Add(string(formKey), string(value))
   370  		})
   371  		b, _ := stdJson.Marshal(form)
   372  		err = hJson.Unmarshal(b, v)
   373  	default:
   374  		// using query to decode
   375  		query := make(url.Values)
   376  		req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) {
   377  			query.Add(string(queryKey), string(value))
   378  		})
   379  		b, _ := stdJson.Marshal(query)
   380  		err = hJson.Unmarshal(b, v)
   381  	}
   382  	return
   383  }
   384  
   385  var _ StructValidator = (*validator)(nil)
   386  
   387  type validator struct {
   388  	validateTag string
   389  	validate    *exprValidator.Validator
   390  }
   391  
   392  func NewValidator(config *ValidateConfig) StructValidator {
   393  	validateTag := defaultValidateTag
   394  	if config != nil && len(config.ValidateTag) != 0 {
   395  		validateTag = config.ValidateTag
   396  	}
   397  	vd := exprValidator.New(validateTag).SetErrorFactory(defaultValidateErrorFactory)
   398  	if config != nil && config.ErrFactory != nil {
   399  		vd.SetErrorFactory(config.ErrFactory)
   400  	}
   401  	return &validator{
   402  		validateTag: validateTag,
   403  		validate:    vd,
   404  	}
   405  }
   406  
   407  // Error validate error
   408  type validateError struct {
   409  	FailPath, Msg string
   410  }
   411  
   412  // Error implements error interface.
   413  func (e *validateError) Error() string {
   414  	if e.Msg != "" {
   415  		return e.Msg
   416  	}
   417  	return "invalid parameter: " + e.FailPath
   418  }
   419  
   420  func defaultValidateErrorFactory(failPath, msg string) error {
   421  	return &validateError{
   422  		FailPath: failPath,
   423  		Msg:      msg,
   424  	}
   425  }
   426  
   427  // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type.
   428  func (v *validator) ValidateStruct(obj interface{}) error {
   429  	if obj == nil {
   430  		return nil
   431  	}
   432  	return v.validate.Validate(obj)
   433  }
   434  
   435  // Engine returns the underlying validator
   436  func (v *validator) Engine() interface{} {
   437  	return v.validate
   438  }
   439  
   440  func (v *validator) ValidateTag() string {
   441  	return v.validateTag
   442  }
   443  
   444  var defaultValidate = NewValidator(NewValidateConfig())
   445  
   446  func DefaultValidator() StructValidator {
   447  	return defaultValidate
   448  }