github.com/shuguocloud/go-zero@v1.3.0/core/mapping/utils.go (about)

     1  package mapping
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"math"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/shuguocloud/go-zero/core/stringx"
    14  )
    15  
// Keywords and delimiter characters recognized while parsing struct tag values.
//
// The *Option strings are the option names matched by parseOption.
// optionSeparator splits the plain `options=a|b` form in parseOptions, and
// equalToken separates an option name from its value. The remaining rune
// constants drive parseSegments and parseGroupedSegments: escapeChar escapes
// the next character, the bracket runes open and close a group, and
// segmentSeparator splits segments outside of a group.
const (
	defaultOption      = "default"
	stringOption       = "string"
	optionalOption     = "optional"
	optionsOption      = "options"
	rangeOption        = "range"
	optionSeparator    = "|"
	equalToken         = "="
	escapeChar         = '\\'
	leftBracket        = '('
	rightBracket       = ')'
	leftSquareBracket  = '['
	rightSquareBracket = ']'
	segmentSeparator   = ','
)
    31  
// Sentinel errors and the package-level parse caches.
//
// optionsCache maps a raw tag value to the result of doParseKeyAndOptions and
// is accessed under cacheLock. structRequiredCache maps a struct type to the
// result of implicitValueRequiredStruct and is accessed under structCacheLock.
var (
	errUnsupportedType  = errors.New("unsupported type on setting field value")
	errNumberRange      = errors.New("wrong number range setting")
	optionsCache        = make(map[string]optionsCacheValue)
	cacheLock           sync.RWMutex
	structRequiredCache = make(map[reflect.Type]requiredCacheValue)
	structCacheLock     sync.RWMutex
)
    40  
type (
	// optionsCacheValue is the entry stored in optionsCache: the key, the
	// parsed options and the error produced for one tag value.
	optionsCacheValue struct {
		key     string
		options *fieldOptions
		err     error
	}

	// requiredCacheValue is the entry stored in structRequiredCache: whether
	// the struct type was found to require a value, plus any error hit while
	// working that out.
	requiredCacheValue struct {
		required bool
		err      error
	}
)
    53  
    54  // Deref dereferences a type, if pointer type, returns its element type.
    55  func Deref(t reflect.Type) reflect.Type {
    56  	if t.Kind() == reflect.Ptr {
    57  		t = t.Elem()
    58  	}
    59  
    60  	return t
    61  }
    62  
    63  // Repr returns the string representation of v.
    64  func Repr(v interface{}) string {
    65  	if v == nil {
    66  		return ""
    67  	}
    68  
    69  	// if func (v *Type) String() string, we can't use Elem()
    70  	switch vt := v.(type) {
    71  	case fmt.Stringer:
    72  		return vt.String()
    73  	}
    74  
    75  	val := reflect.ValueOf(v)
    76  	if val.Kind() == reflect.Ptr && !val.IsNil() {
    77  		val = val.Elem()
    78  	}
    79  
    80  	return reprOfValue(val)
    81  }
    82  
    83  // ValidatePtr validates v if it's a valid pointer.
    84  func ValidatePtr(v *reflect.Value) error {
    85  	// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
    86  	// panic otherwise
    87  	if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() {
    88  		return fmt.Errorf("not a valid pointer: %v", v)
    89  	}
    90  
    91  	return nil
    92  }
    93  
    94  func convertType(kind reflect.Kind, str string) (interface{}, error) {
    95  	switch kind {
    96  	case reflect.Bool:
    97  		return str == "1" || strings.ToLower(str) == "true", nil
    98  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    99  		intValue, err := strconv.ParseInt(str, 10, 64)
   100  		if err != nil {
   101  			return 0, fmt.Errorf("the value %q cannot parsed as int", str)
   102  		}
   103  
   104  		return intValue, nil
   105  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   106  		uintValue, err := strconv.ParseUint(str, 10, 64)
   107  		if err != nil {
   108  			return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
   109  		}
   110  
   111  		return uintValue, nil
   112  	case reflect.Float32, reflect.Float64:
   113  		floatValue, err := strconv.ParseFloat(str, 64)
   114  		if err != nil {
   115  			return 0, fmt.Errorf("the value %q cannot parsed as float", str)
   116  		}
   117  
   118  		return floatValue, nil
   119  	case reflect.String:
   120  		return str, nil
   121  	default:
   122  		return nil, errUnsupportedType
   123  	}
   124  }
   125  
   126  func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
   127  	segments := parseSegments(value)
   128  	key := strings.TrimSpace(segments[0])
   129  	options := segments[1:]
   130  
   131  	if len(options) == 0 {
   132  		return key, nil, nil
   133  	}
   134  
   135  	var fieldOpts fieldOptions
   136  	for _, segment := range options {
   137  		option := strings.TrimSpace(segment)
   138  		if err := parseOption(&fieldOpts, field.Name, option); err != nil {
   139  			return "", nil, err
   140  		}
   141  	}
   142  
   143  	return key, &fieldOpts, nil
   144  }
   145  
   146  func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
   147  	numFields := tp.NumField()
   148  	for i := 0; i < numFields; i++ {
   149  		childField := tp.Field(i)
   150  		if usingDifferentKeys(tag, childField) {
   151  			return true, nil
   152  		}
   153  
   154  		_, opts, err := parseKeyAndOptions(tag, childField)
   155  		if err != nil {
   156  			return false, err
   157  		}
   158  
   159  		if opts == nil {
   160  			if childField.Type.Kind() != reflect.Struct {
   161  				return true, nil
   162  			}
   163  
   164  			if required, err := implicitValueRequiredStruct(tag, childField.Type); err != nil {
   165  				return false, err
   166  			} else if required {
   167  				return true, nil
   168  			}
   169  		} else if !opts.Optional && len(opts.Default) == 0 {
   170  			return true, nil
   171  		} else if len(opts.OptionalDep) > 0 && opts.OptionalDep[0] == notSymbol {
   172  			return true, nil
   173  		}
   174  	}
   175  
   176  	return false, nil
   177  }
   178  
   179  func isLeftInclude(b byte) (bool, error) {
   180  	switch b {
   181  	case '[':
   182  		return true, nil
   183  	case '(':
   184  		return false, nil
   185  	default:
   186  		return false, errNumberRange
   187  	}
   188  }
   189  
   190  func isRightInclude(b byte) (bool, error) {
   191  	switch b {
   192  	case ']':
   193  		return true, nil
   194  	case ')':
   195  		return false, nil
   196  	default:
   197  		return false, errNumberRange
   198  	}
   199  }
   200  
   201  func maybeNewValue(field reflect.StructField, value reflect.Value) {
   202  	if field.Type.Kind() == reflect.Ptr && value.IsNil() {
   203  		value.Set(reflect.New(value.Type().Elem()))
   204  	}
   205  }
   206  
   207  func parseGroupedSegments(val string) []string {
   208  	val = strings.TrimLeftFunc(val, func(r rune) bool {
   209  		return r == leftBracket || r == leftSquareBracket
   210  	})
   211  	val = strings.TrimRightFunc(val, func(r rune) bool {
   212  		return r == rightBracket || r == rightSquareBracket
   213  	})
   214  	return parseSegments(val)
   215  }
   216  
   217  // don't modify returned fieldOptions, it's cached and shared among different calls.
   218  func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fieldOptions, error) {
   219  	value := field.Tag.Get(tagName)
   220  	if len(value) == 0 {
   221  		return field.Name, nil, nil
   222  	}
   223  
   224  	cacheLock.RLock()
   225  	cache, ok := optionsCache[value]
   226  	cacheLock.RUnlock()
   227  	if ok {
   228  		return stringx.TakeOne(cache.key, field.Name), cache.options, cache.err
   229  	}
   230  
   231  	key, options, err := doParseKeyAndOptions(field, value)
   232  	cacheLock.Lock()
   233  	optionsCache[value] = optionsCacheValue{
   234  		key:     key,
   235  		options: options,
   236  		err:     err,
   237  	}
   238  	cacheLock.Unlock()
   239  
   240  	return stringx.TakeOne(key, field.Name), options, err
   241  }
   242  
   243  // support below notations:
   244  // [:5] (:5] [:5) (:5)
   245  // [1:] [1:) (1:] (1:)
   246  // [1:5] [1:5) (1:5] (1:5)
   247  func parseNumberRange(str string) (*numberRange, error) {
   248  	if len(str) == 0 {
   249  		return nil, errNumberRange
   250  	}
   251  
   252  	leftInclude, err := isLeftInclude(str[0])
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  
   257  	str = str[1:]
   258  	if len(str) == 0 {
   259  		return nil, errNumberRange
   260  	}
   261  
   262  	rightInclude, err := isRightInclude(str[len(str)-1])
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	str = str[:len(str)-1]
   268  	fields := strings.Split(str, ":")
   269  	if len(fields) != 2 {
   270  		return nil, errNumberRange
   271  	}
   272  
   273  	if len(fields[0]) == 0 && len(fields[1]) == 0 {
   274  		return nil, errNumberRange
   275  	}
   276  
   277  	var left float64
   278  	if len(fields[0]) > 0 {
   279  		var err error
   280  		if left, err = strconv.ParseFloat(fields[0], 64); err != nil {
   281  			return nil, err
   282  		}
   283  	} else {
   284  		left = -math.MaxFloat64
   285  	}
   286  
   287  	var right float64
   288  	if len(fields[1]) > 0 {
   289  		var err error
   290  		if right, err = strconv.ParseFloat(fields[1], 64); err != nil {
   291  			return nil, err
   292  		}
   293  	} else {
   294  		right = math.MaxFloat64
   295  	}
   296  
   297  	return &numberRange{
   298  		left:         left,
   299  		leftInclude:  leftInclude,
   300  		right:        right,
   301  		rightInclude: rightInclude,
   302  	}, nil
   303  }
   304  
   305  func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
   306  	switch {
   307  	case option == stringOption:
   308  		fieldOpts.FromString = true
   309  	case strings.HasPrefix(option, optionalOption):
   310  		segs := strings.Split(option, equalToken)
   311  		switch len(segs) {
   312  		case 1:
   313  			fieldOpts.Optional = true
   314  		case 2:
   315  			fieldOpts.Optional = true
   316  			fieldOpts.OptionalDep = segs[1]
   317  		default:
   318  			return fmt.Errorf("field %s has wrong optional", fieldName)
   319  		}
   320  	case option == optionalOption:
   321  		fieldOpts.Optional = true
   322  	case strings.HasPrefix(option, optionsOption):
   323  		segs := strings.Split(option, equalToken)
   324  		if len(segs) != 2 {
   325  			return fmt.Errorf("field %s has wrong options", fieldName)
   326  		}
   327  
   328  		fieldOpts.Options = parseOptions(segs[1])
   329  	case strings.HasPrefix(option, defaultOption):
   330  		segs := strings.Split(option, equalToken)
   331  		if len(segs) != 2 {
   332  			return fmt.Errorf("field %s has wrong default option", fieldName)
   333  		}
   334  
   335  		fieldOpts.Default = strings.TrimSpace(segs[1])
   336  	case strings.HasPrefix(option, rangeOption):
   337  		segs := strings.Split(option, equalToken)
   338  		if len(segs) != 2 {
   339  			return fmt.Errorf("field %s has wrong range", fieldName)
   340  		}
   341  
   342  		nr, err := parseNumberRange(segs[1])
   343  		if err != nil {
   344  			return err
   345  		}
   346  
   347  		fieldOpts.Range = nr
   348  	}
   349  
   350  	return nil
   351  }
   352  
   353  // parseOptions parses the given options in tag.
   354  // for example: `json:"name,options=foo|bar"` or `json:"name,options=[foo,bar]"`
   355  func parseOptions(val string) []string {
   356  	if len(val) == 0 {
   357  		return nil
   358  	}
   359  
   360  	if val[0] == leftSquareBracket {
   361  		return parseGroupedSegments(val)
   362  	}
   363  
   364  	return strings.Split(val, optionSeparator)
   365  }
   366  
   367  func parseSegments(val string) []string {
   368  	var segments []string
   369  	var escaped, grouped bool
   370  	var buf strings.Builder
   371  
   372  	for _, ch := range val {
   373  		if escaped {
   374  			buf.WriteRune(ch)
   375  			escaped = false
   376  			continue
   377  		}
   378  
   379  		switch ch {
   380  		case segmentSeparator:
   381  			if grouped {
   382  				buf.WriteRune(ch)
   383  			} else {
   384  				// need to trim spaces, but we cannot ignore empty string,
   385  				// because the first segment stands for the key might be empty.
   386  				// if ignored, the later tag will be used as the key.
   387  				segments = append(segments, strings.TrimSpace(buf.String()))
   388  				buf.Reset()
   389  			}
   390  		case escapeChar:
   391  			if grouped {
   392  				buf.WriteRune(ch)
   393  			} else {
   394  				escaped = true
   395  			}
   396  		case leftBracket, leftSquareBracket:
   397  			buf.WriteRune(ch)
   398  			grouped = true
   399  		case rightBracket, rightSquareBracket:
   400  			buf.WriteRune(ch)
   401  			grouped = false
   402  		default:
   403  			buf.WriteRune(ch)
   404  		}
   405  	}
   406  
   407  	last := strings.TrimSpace(buf.String())
   408  	// ignore last empty string
   409  	if len(last) > 0 {
   410  		segments = append(segments, last)
   411  	}
   412  
   413  	return segments
   414  }
   415  
   416  func reprOfValue(val reflect.Value) string {
   417  	switch vt := val.Interface().(type) {
   418  	case bool:
   419  		return strconv.FormatBool(vt)
   420  	case error:
   421  		return vt.Error()
   422  	case float32:
   423  		return strconv.FormatFloat(float64(vt), 'f', -1, 32)
   424  	case float64:
   425  		return strconv.FormatFloat(vt, 'f', -1, 64)
   426  	case fmt.Stringer:
   427  		return vt.String()
   428  	case int:
   429  		return strconv.Itoa(vt)
   430  	case int8:
   431  		return strconv.Itoa(int(vt))
   432  	case int16:
   433  		return strconv.Itoa(int(vt))
   434  	case int32:
   435  		return strconv.Itoa(int(vt))
   436  	case int64:
   437  		return strconv.FormatInt(vt, 10)
   438  	case string:
   439  		return vt
   440  	case uint:
   441  		return strconv.FormatUint(uint64(vt), 10)
   442  	case uint8:
   443  		return strconv.FormatUint(uint64(vt), 10)
   444  	case uint16:
   445  		return strconv.FormatUint(uint64(vt), 10)
   446  	case uint32:
   447  		return strconv.FormatUint(uint64(vt), 10)
   448  	case uint64:
   449  		return strconv.FormatUint(vt, 10)
   450  	case []byte:
   451  		return string(vt)
   452  	default:
   453  		return fmt.Sprint(val.Interface())
   454  	}
   455  }
   456  
   457  func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error {
   458  	switch kind {
   459  	case reflect.Bool:
   460  		value.SetBool(v.(bool))
   461  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   462  		value.SetInt(v.(int64))
   463  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   464  		value.SetUint(v.(uint64))
   465  	case reflect.Float32, reflect.Float64:
   466  		value.SetFloat(v.(float64))
   467  	case reflect.String:
   468  		value.SetString(v.(string))
   469  	default:
   470  		return errUnsupportedType
   471  	}
   472  
   473  	return nil
   474  }
   475  
   476  func setValue(kind reflect.Kind, value reflect.Value, str string) error {
   477  	if !value.CanSet() {
   478  		return errValueNotSettable
   479  	}
   480  
   481  	v, err := convertType(kind, str)
   482  	if err != nil {
   483  		return err
   484  	}
   485  
   486  	return setMatchedPrimitiveValue(kind, value, v)
   487  }
   488  
   489  func structValueRequired(tag string, tp reflect.Type) (bool, error) {
   490  	structCacheLock.RLock()
   491  	val, ok := structRequiredCache[tp]
   492  	structCacheLock.RUnlock()
   493  	if ok {
   494  		return val.required, val.err
   495  	}
   496  
   497  	required, err := implicitValueRequiredStruct(tag, tp)
   498  	structCacheLock.Lock()
   499  	structRequiredCache[tp] = requiredCacheValue{
   500  		required: required,
   501  		err:      err,
   502  	}
   503  	structCacheLock.Unlock()
   504  
   505  	return required, err
   506  }
   507  
   508  func toFloat64(v interface{}) (float64, bool) {
   509  	switch val := v.(type) {
   510  	case int:
   511  		return float64(val), true
   512  	case int8:
   513  		return float64(val), true
   514  	case int16:
   515  		return float64(val), true
   516  	case int32:
   517  		return float64(val), true
   518  	case int64:
   519  		return float64(val), true
   520  	case uint:
   521  		return float64(val), true
   522  	case uint8:
   523  		return float64(val), true
   524  	case uint16:
   525  		return float64(val), true
   526  	case uint32:
   527  		return float64(val), true
   528  	case uint64:
   529  		return float64(val), true
   530  	case float32:
   531  		return float64(val), true
   532  	case float64:
   533  		return val, true
   534  	default:
   535  		return 0, false
   536  	}
   537  }
   538  
   539  func usingDifferentKeys(key string, field reflect.StructField) bool {
   540  	if len(field.Tag) > 0 {
   541  		if _, ok := field.Tag.Lookup(key); !ok {
   542  			return true
   543  		}
   544  	}
   545  
   546  	return false
   547  }
   548  
   549  func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opts *fieldOptionsWithContext) error {
   550  	if !value.CanSet() {
   551  		return errValueNotSettable
   552  	}
   553  
   554  	v, err := convertType(kind, str)
   555  	if err != nil {
   556  		return err
   557  	}
   558  
   559  	if err := validateValueRange(v, opts); err != nil {
   560  		return err
   561  	}
   562  
   563  	return setMatchedPrimitiveValue(kind, value, v)
   564  }
   565  
   566  func validateJsonNumberRange(v json.Number, opts *fieldOptionsWithContext) error {
   567  	if opts == nil || opts.Range == nil {
   568  		return nil
   569  	}
   570  
   571  	fv, err := v.Float64()
   572  	if err != nil {
   573  		return err
   574  	}
   575  
   576  	return validateNumberRange(fv, opts.Range)
   577  }
   578  
   579  func validateNumberRange(fv float64, nr *numberRange) error {
   580  	if nr == nil {
   581  		return nil
   582  	}
   583  
   584  	if (nr.leftInclude && fv < nr.left) || (!nr.leftInclude && fv <= nr.left) {
   585  		return errNumberRange
   586  	}
   587  
   588  	if (nr.rightInclude && fv > nr.right) || (!nr.rightInclude && fv >= nr.right) {
   589  		return errNumberRange
   590  	}
   591  
   592  	return nil
   593  }
   594  
   595  func validateValueInOptions(options []string, value interface{}) error {
   596  	if len(options) > 0 {
   597  		switch v := value.(type) {
   598  		case string:
   599  			if !stringx.Contains(options, v) {
   600  				return fmt.Errorf(`error: value "%s" is not defined in options "%v"`, v, options)
   601  			}
   602  		default:
   603  			if !stringx.Contains(options, Repr(v)) {
   604  				return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, value, options)
   605  			}
   606  		}
   607  	}
   608  
   609  	return nil
   610  }
   611  
   612  func validateValueRange(mapValue interface{}, opts *fieldOptionsWithContext) error {
   613  	if opts == nil || opts.Range == nil {
   614  		return nil
   615  	}
   616  
   617  	fv, ok := toFloat64(mapValue)
   618  	if !ok {
   619  		return errNumberRange
   620  	}
   621  
   622  	return validateNumberRange(fv, opts.Range)
   623  }