github.com/hduhelp/go-zero@v1.4.3/core/mapping/utils.go (about)

     1  package mapping
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"math"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/hduhelp/go-zero/core/stringx"
    14  )
    15  
    16  const (
    17  	defaultOption      = "default"
    18  	stringOption       = "string"
    19  	optionalOption     = "optional"
    20  	optionsOption      = "options"
    21  	rangeOption        = "range"
    22  	optionSeparator    = "|"
    23  	equalToken         = "="
    24  	escapeChar         = '\\'
    25  	leftBracket        = '('
    26  	rightBracket       = ')'
    27  	leftSquareBracket  = '['
    28  	rightSquareBracket = ']'
    29  	segmentSeparator   = ','
    30  )
    31  
    32  var (
    33  	errUnsupportedType  = errors.New("unsupported type on setting field value")
    34  	errNumberRange      = errors.New("wrong number range setting")
    35  	optionsCache        = make(map[string]optionsCacheValue)
    36  	cacheLock           sync.RWMutex
    37  	structRequiredCache = make(map[reflect.Type]requiredCacheValue)
    38  	structCacheLock     sync.RWMutex
    39  )
    40  
    41  type (
    42  	optionsCacheValue struct {
    43  		key     string
    44  		options *fieldOptions
    45  		err     error
    46  	}
    47  
    48  	requiredCacheValue struct {
    49  		required bool
    50  		err      error
    51  	}
    52  )
    53  
    54  // Deref dereferences a type, if pointer type, returns its element type.
    55  func Deref(t reflect.Type) reflect.Type {
    56  	if t.Kind() == reflect.Ptr {
    57  		t = t.Elem()
    58  	}
    59  
    60  	return t
    61  }
    62  
    63  // Repr returns the string representation of v.
    64  func Repr(v interface{}) string {
    65  	if v == nil {
    66  		return ""
    67  	}
    68  
    69  	// if func (v *Type) String() string, we can't use Elem()
    70  	switch vt := v.(type) {
    71  	case fmt.Stringer:
    72  		return vt.String()
    73  	}
    74  
    75  	val := reflect.ValueOf(v)
    76  	if val.Kind() == reflect.Ptr && !val.IsNil() {
    77  		val = val.Elem()
    78  	}
    79  
    80  	return reprOfValue(val)
    81  }
    82  
    83  // ValidatePtr validates v if it's a valid pointer.
    84  func ValidatePtr(v *reflect.Value) error {
    85  	// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
    86  	// panic otherwise
    87  	if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() {
    88  		return fmt.Errorf("not a valid pointer: %v", v)
    89  	}
    90  
    91  	return nil
    92  }
    93  
    94  func convertType(kind reflect.Kind, str string) (interface{}, error) {
    95  	switch kind {
    96  	case reflect.Bool:
    97  		return str == "1" || strings.ToLower(str) == "true", nil
    98  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    99  		intValue, err := strconv.ParseInt(str, 10, 64)
   100  		if err != nil {
   101  			return 0, fmt.Errorf("the value %q cannot parsed as int", str)
   102  		}
   103  
   104  		return intValue, nil
   105  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   106  		uintValue, err := strconv.ParseUint(str, 10, 64)
   107  		if err != nil {
   108  			return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
   109  		}
   110  
   111  		return uintValue, nil
   112  	case reflect.Float32, reflect.Float64:
   113  		floatValue, err := strconv.ParseFloat(str, 64)
   114  		if err != nil {
   115  			return 0, fmt.Errorf("the value %q cannot parsed as float", str)
   116  		}
   117  
   118  		return floatValue, nil
   119  	case reflect.String:
   120  		return str, nil
   121  	default:
   122  		return nil, errUnsupportedType
   123  	}
   124  }
   125  
   126  func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
   127  	segments := parseSegments(value)
   128  	key := strings.TrimSpace(segments[0])
   129  	options := segments[1:]
   130  
   131  	if len(options) == 0 {
   132  		return key, nil, nil
   133  	}
   134  
   135  	var fieldOpts fieldOptions
   136  	for _, segment := range options {
   137  		option := strings.TrimSpace(segment)
   138  		if err := parseOption(&fieldOpts, field.Name, option); err != nil {
   139  			return "", nil, err
   140  		}
   141  	}
   142  
   143  	return key, &fieldOpts, nil
   144  }
   145  
   146  // ensureValue ensures nested members not to be nil.
   147  // If pointer value is nil, set to a new value.
   148  func ensureValue(v reflect.Value) reflect.Value {
   149  	for {
   150  		if v.Kind() != reflect.Ptr {
   151  			break
   152  		}
   153  
   154  		if v.IsNil() {
   155  			v.Set(reflect.New(v.Type().Elem()))
   156  		}
   157  		v = v.Elem()
   158  	}
   159  
   160  	return v
   161  }
   162  
   163  func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
   164  	numFields := tp.NumField()
   165  	for i := 0; i < numFields; i++ {
   166  		childField := tp.Field(i)
   167  		if usingDifferentKeys(tag, childField) {
   168  			return true, nil
   169  		}
   170  
   171  		_, opts, err := parseKeyAndOptions(tag, childField)
   172  		if err != nil {
   173  			return false, err
   174  		}
   175  
   176  		if opts == nil {
   177  			if childField.Type.Kind() != reflect.Struct {
   178  				return true, nil
   179  			}
   180  
   181  			if required, err := implicitValueRequiredStruct(tag, childField.Type); err != nil {
   182  				return false, err
   183  			} else if required {
   184  				return true, nil
   185  			}
   186  		} else if !opts.Optional && len(opts.Default) == 0 {
   187  			return true, nil
   188  		} else if len(opts.OptionalDep) > 0 && opts.OptionalDep[0] == notSymbol {
   189  			return true, nil
   190  		}
   191  	}
   192  
   193  	return false, nil
   194  }
   195  
   196  func isLeftInclude(b byte) (bool, error) {
   197  	switch b {
   198  	case '[':
   199  		return true, nil
   200  	case '(':
   201  		return false, nil
   202  	default:
   203  		return false, errNumberRange
   204  	}
   205  }
   206  
   207  func isRightInclude(b byte) (bool, error) {
   208  	switch b {
   209  	case ']':
   210  		return true, nil
   211  	case ')':
   212  		return false, nil
   213  	default:
   214  		return false, errNumberRange
   215  	}
   216  }
   217  
   218  func maybeNewValue(field reflect.StructField, value reflect.Value) {
   219  	if field.Type.Kind() == reflect.Ptr && value.IsNil() {
   220  		value.Set(reflect.New(value.Type().Elem()))
   221  	}
   222  }
   223  
   224  func parseGroupedSegments(val string) []string {
   225  	val = strings.TrimLeftFunc(val, func(r rune) bool {
   226  		return r == leftBracket || r == leftSquareBracket
   227  	})
   228  	val = strings.TrimRightFunc(val, func(r rune) bool {
   229  		return r == rightBracket || r == rightSquareBracket
   230  	})
   231  	return parseSegments(val)
   232  }
   233  
   234  // don't modify returned fieldOptions, it's cached and shared among different calls.
   235  func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fieldOptions, error) {
   236  	value := field.Tag.Get(tagName)
   237  	if len(value) == 0 {
   238  		return field.Name, nil, nil
   239  	}
   240  
   241  	cacheLock.RLock()
   242  	cache, ok := optionsCache[value]
   243  	cacheLock.RUnlock()
   244  	if ok {
   245  		return stringx.TakeOne(cache.key, field.Name), cache.options, cache.err
   246  	}
   247  
   248  	key, options, err := doParseKeyAndOptions(field, value)
   249  	cacheLock.Lock()
   250  	optionsCache[value] = optionsCacheValue{
   251  		key:     key,
   252  		options: options,
   253  		err:     err,
   254  	}
   255  	cacheLock.Unlock()
   256  
   257  	return stringx.TakeOne(key, field.Name), options, err
   258  }
   259  
   260  // support below notations:
   261  // [:5] (:5] [:5) (:5)
   262  // [1:] [1:) (1:] (1:)
   263  // [1:5] [1:5) (1:5] (1:5)
   264  func parseNumberRange(str string) (*numberRange, error) {
   265  	if len(str) == 0 {
   266  		return nil, errNumberRange
   267  	}
   268  
   269  	leftInclude, err := isLeftInclude(str[0])
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	str = str[1:]
   275  	if len(str) == 0 {
   276  		return nil, errNumberRange
   277  	}
   278  
   279  	rightInclude, err := isRightInclude(str[len(str)-1])
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  
   284  	str = str[:len(str)-1]
   285  	fields := strings.Split(str, ":")
   286  	if len(fields) != 2 {
   287  		return nil, errNumberRange
   288  	}
   289  
   290  	if len(fields[0]) == 0 && len(fields[1]) == 0 {
   291  		return nil, errNumberRange
   292  	}
   293  
   294  	var left float64
   295  	if len(fields[0]) > 0 {
   296  		var err error
   297  		if left, err = strconv.ParseFloat(fields[0], 64); err != nil {
   298  			return nil, err
   299  		}
   300  	} else {
   301  		left = -math.MaxFloat64
   302  	}
   303  
   304  	var right float64
   305  	if len(fields[1]) > 0 {
   306  		var err error
   307  		if right, err = strconv.ParseFloat(fields[1], 64); err != nil {
   308  			return nil, err
   309  		}
   310  	} else {
   311  		right = math.MaxFloat64
   312  	}
   313  
   314  	if left > right {
   315  		return nil, errNumberRange
   316  	}
   317  
   318  	// [2:2] valid
   319  	// [2:2) invalid
   320  	// (2:2] invalid
   321  	// (2:2) invalid
   322  	if left == right {
   323  		if !leftInclude || !rightInclude {
   324  			return nil, errNumberRange
   325  		}
   326  	}
   327  
   328  	return &numberRange{
   329  		left:         left,
   330  		leftInclude:  leftInclude,
   331  		right:        right,
   332  		rightInclude: rightInclude,
   333  	}, nil
   334  }
   335  
   336  func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
   337  	switch {
   338  	case option == stringOption:
   339  		fieldOpts.FromString = true
   340  	case strings.HasPrefix(option, optionalOption):
   341  		segs := strings.Split(option, equalToken)
   342  		switch len(segs) {
   343  		case 1:
   344  			fieldOpts.Optional = true
   345  		case 2:
   346  			fieldOpts.Optional = true
   347  			fieldOpts.OptionalDep = segs[1]
   348  		default:
   349  			return fmt.Errorf("field %s has wrong optional", fieldName)
   350  		}
   351  	case option == optionalOption:
   352  		fieldOpts.Optional = true
   353  	case strings.HasPrefix(option, optionsOption):
   354  		segs := strings.Split(option, equalToken)
   355  		if len(segs) != 2 {
   356  			return fmt.Errorf("field %s has wrong options", fieldName)
   357  		}
   358  
   359  		fieldOpts.Options = parseOptions(segs[1])
   360  	case strings.HasPrefix(option, defaultOption):
   361  		segs := strings.Split(option, equalToken)
   362  		if len(segs) != 2 {
   363  			return fmt.Errorf("field %s has wrong default option", fieldName)
   364  		}
   365  
   366  		fieldOpts.Default = strings.TrimSpace(segs[1])
   367  	case strings.HasPrefix(option, rangeOption):
   368  		segs := strings.Split(option, equalToken)
   369  		if len(segs) != 2 {
   370  			return fmt.Errorf("field %s has wrong range", fieldName)
   371  		}
   372  
   373  		nr, err := parseNumberRange(segs[1])
   374  		if err != nil {
   375  			return err
   376  		}
   377  
   378  		fieldOpts.Range = nr
   379  	}
   380  
   381  	return nil
   382  }
   383  
   384  // parseOptions parses the given options in tag.
   385  // for example: `json:"name,options=foo|bar"` or `json:"name,options=[foo,bar]"`
   386  func parseOptions(val string) []string {
   387  	if len(val) == 0 {
   388  		return nil
   389  	}
   390  
   391  	if val[0] == leftSquareBracket {
   392  		return parseGroupedSegments(val)
   393  	}
   394  
   395  	return strings.Split(val, optionSeparator)
   396  }
   397  
   398  func parseSegments(val string) []string {
   399  	var segments []string
   400  	var escaped, grouped bool
   401  	var buf strings.Builder
   402  
   403  	for _, ch := range val {
   404  		if escaped {
   405  			buf.WriteRune(ch)
   406  			escaped = false
   407  			continue
   408  		}
   409  
   410  		switch ch {
   411  		case segmentSeparator:
   412  			if grouped {
   413  				buf.WriteRune(ch)
   414  			} else {
   415  				// need to trim spaces, but we cannot ignore empty string,
   416  				// because the first segment stands for the key might be empty.
   417  				// if ignored, the later tag will be used as the key.
   418  				segments = append(segments, strings.TrimSpace(buf.String()))
   419  				buf.Reset()
   420  			}
   421  		case escapeChar:
   422  			if grouped {
   423  				buf.WriteRune(ch)
   424  			} else {
   425  				escaped = true
   426  			}
   427  		case leftBracket, leftSquareBracket:
   428  			buf.WriteRune(ch)
   429  			grouped = true
   430  		case rightBracket, rightSquareBracket:
   431  			buf.WriteRune(ch)
   432  			grouped = false
   433  		default:
   434  			buf.WriteRune(ch)
   435  		}
   436  	}
   437  
   438  	last := strings.TrimSpace(buf.String())
   439  	// ignore last empty string
   440  	if len(last) > 0 {
   441  		segments = append(segments, last)
   442  	}
   443  
   444  	return segments
   445  }
   446  
   447  func reprOfValue(val reflect.Value) string {
   448  	switch vt := val.Interface().(type) {
   449  	case bool:
   450  		return strconv.FormatBool(vt)
   451  	case error:
   452  		return vt.Error()
   453  	case float32:
   454  		return strconv.FormatFloat(float64(vt), 'f', -1, 32)
   455  	case float64:
   456  		return strconv.FormatFloat(vt, 'f', -1, 64)
   457  	case fmt.Stringer:
   458  		return vt.String()
   459  	case int:
   460  		return strconv.Itoa(vt)
   461  	case int8:
   462  		return strconv.Itoa(int(vt))
   463  	case int16:
   464  		return strconv.Itoa(int(vt))
   465  	case int32:
   466  		return strconv.Itoa(int(vt))
   467  	case int64:
   468  		return strconv.FormatInt(vt, 10)
   469  	case string:
   470  		return vt
   471  	case uint:
   472  		return strconv.FormatUint(uint64(vt), 10)
   473  	case uint8:
   474  		return strconv.FormatUint(uint64(vt), 10)
   475  	case uint16:
   476  		return strconv.FormatUint(uint64(vt), 10)
   477  	case uint32:
   478  		return strconv.FormatUint(uint64(vt), 10)
   479  	case uint64:
   480  		return strconv.FormatUint(vt, 10)
   481  	case []byte:
   482  		return string(vt)
   483  	default:
   484  		return fmt.Sprint(val.Interface())
   485  	}
   486  }
   487  
   488  func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error {
   489  	switch kind {
   490  	case reflect.Bool:
   491  		value.SetBool(v.(bool))
   492  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   493  		value.SetInt(v.(int64))
   494  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   495  		value.SetUint(v.(uint64))
   496  	case reflect.Float32, reflect.Float64:
   497  		value.SetFloat(v.(float64))
   498  	case reflect.String:
   499  		value.SetString(v.(string))
   500  	default:
   501  		return errUnsupportedType
   502  	}
   503  
   504  	return nil
   505  }
   506  
   507  func setValue(kind reflect.Kind, value reflect.Value, str string) error {
   508  	if !value.CanSet() {
   509  		return errValueNotSettable
   510  	}
   511  
   512  	value = ensureValue(value)
   513  	v, err := convertType(kind, str)
   514  	if err != nil {
   515  		return err
   516  	}
   517  
   518  	return setMatchedPrimitiveValue(kind, value, v)
   519  }
   520  
   521  func structValueRequired(tag string, tp reflect.Type) (bool, error) {
   522  	structCacheLock.RLock()
   523  	val, ok := structRequiredCache[tp]
   524  	structCacheLock.RUnlock()
   525  	if ok {
   526  		return val.required, val.err
   527  	}
   528  
   529  	required, err := implicitValueRequiredStruct(tag, tp)
   530  	structCacheLock.Lock()
   531  	structRequiredCache[tp] = requiredCacheValue{
   532  		required: required,
   533  		err:      err,
   534  	}
   535  	structCacheLock.Unlock()
   536  
   537  	return required, err
   538  }
   539  
   540  func toFloat64(v interface{}) (float64, bool) {
   541  	switch val := v.(type) {
   542  	case int:
   543  		return float64(val), true
   544  	case int8:
   545  		return float64(val), true
   546  	case int16:
   547  		return float64(val), true
   548  	case int32:
   549  		return float64(val), true
   550  	case int64:
   551  		return float64(val), true
   552  	case uint:
   553  		return float64(val), true
   554  	case uint8:
   555  		return float64(val), true
   556  	case uint16:
   557  		return float64(val), true
   558  	case uint32:
   559  		return float64(val), true
   560  	case uint64:
   561  		return float64(val), true
   562  	case float32:
   563  		return float64(val), true
   564  	case float64:
   565  		return val, true
   566  	default:
   567  		return 0, false
   568  	}
   569  }
   570  
   571  func usingDifferentKeys(key string, field reflect.StructField) bool {
   572  	if len(field.Tag) > 0 {
   573  		if _, ok := field.Tag.Lookup(key); !ok {
   574  			return true
   575  		}
   576  	}
   577  
   578  	return false
   579  }
   580  
   581  func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opts *fieldOptionsWithContext) error {
   582  	if !value.CanSet() {
   583  		return errValueNotSettable
   584  	}
   585  
   586  	v, err := convertType(kind, str)
   587  	if err != nil {
   588  		return err
   589  	}
   590  
   591  	if err := validateValueRange(v, opts); err != nil {
   592  		return err
   593  	}
   594  
   595  	return setMatchedPrimitiveValue(kind, value, v)
   596  }
   597  
   598  func validateJsonNumberRange(v json.Number, opts *fieldOptionsWithContext) error {
   599  	if opts == nil || opts.Range == nil {
   600  		return nil
   601  	}
   602  
   603  	fv, err := v.Float64()
   604  	if err != nil {
   605  		return err
   606  	}
   607  
   608  	return validateNumberRange(fv, opts.Range)
   609  }
   610  
   611  func validateNumberRange(fv float64, nr *numberRange) error {
   612  	if nr == nil {
   613  		return nil
   614  	}
   615  
   616  	if (nr.leftInclude && fv < nr.left) || (!nr.leftInclude && fv <= nr.left) {
   617  		return errNumberRange
   618  	}
   619  
   620  	if (nr.rightInclude && fv > nr.right) || (!nr.rightInclude && fv >= nr.right) {
   621  		return errNumberRange
   622  	}
   623  
   624  	return nil
   625  }
   626  
   627  func validateValueInOptions(val interface{}, options []string) error {
   628  	if len(options) > 0 {
   629  		switch v := val.(type) {
   630  		case string:
   631  			if !stringx.Contains(options, v) {
   632  				return fmt.Errorf(`error: value "%s" is not defined in options "%v"`, v, options)
   633  			}
   634  		default:
   635  			if !stringx.Contains(options, Repr(v)) {
   636  				return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
   637  			}
   638  		}
   639  	}
   640  
   641  	return nil
   642  }
   643  
   644  func validateValueRange(mapValue interface{}, opts *fieldOptionsWithContext) error {
   645  	if opts == nil || opts.Range == nil {
   646  		return nil
   647  	}
   648  
   649  	fv, ok := toFloat64(mapValue)
   650  	if !ok {
   651  		return errNumberRange
   652  	}
   653  
   654  	return validateNumberRange(fv, opts.Range)
   655  }