github.com/metacubex/mihomo@v1.18.5/common/structure/structure.go (about)

     1  package structure
     2  
     3  // references: https://github.com/mitchellh/mapstructure
     4  
     5  import (
     6  	"encoding/base64"
     7  	"fmt"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  )
    12  
// Option is the configuration that is used to create a new decoder.
type Option struct {
	// TagName is the struct tag consulted for each field's map key.
	// NewDecoder substitutes "structure" when it is left empty.
	TagName          string
	// WeaklyTypedInput enables lenient conversions between scalar kinds,
	// e.g. parsing a string into an integer field or formatting a number
	// into a string field. Without it only matching kinds are accepted.
	WeaklyTypedInput bool
	// KeyReplacer, when non-nil, rewrites both the field key and the map
	// keys before the case-insensitive fallback comparison, so differently
	// spelled keys (see DefaultKeyReplacer) can still match.
	KeyReplacer      *strings.Replacer
}
    19  
// DefaultKeyReplacer rewrites underscores as hyphens. Used as
// Option.KeyReplacer it lets a field keyed "foo_bar" match a map key
// "foo-bar" (and vice versa), because the replacer is applied to both sides
// before the case-insensitive key comparison.
var DefaultKeyReplacer = strings.NewReplacer("_", "-")
    21  
// Decoder is the core of structure: it converts generic map, slice and
// scalar data into typed Go values via reflection, as configured by its
// Option.
type Decoder struct {
	// option is always non-nil for a Decoder built by NewDecoder.
	option *Option
}
    26  
    27  // NewDecoder return a Decoder by Option
    28  func NewDecoder(option Option) *Decoder {
    29  	if option.TagName == "" {
    30  		option.TagName = "structure"
    31  	}
    32  	return &Decoder{option: &option}
    33  }
    34  
    35  // Decode transform a map[string]any to a struct
    36  func (d *Decoder) Decode(src map[string]any, dst any) error {
    37  	if reflect.TypeOf(dst).Kind() != reflect.Ptr {
    38  		return fmt.Errorf("decode must recive a ptr struct")
    39  	}
    40  	t := reflect.TypeOf(dst).Elem()
    41  	v := reflect.ValueOf(dst).Elem()
    42  	for idx := 0; idx < v.NumField(); idx++ {
    43  		field := t.Field(idx)
    44  		if field.Anonymous {
    45  			if err := d.decodeStruct(field.Name, src, v.Field(idx)); err != nil {
    46  				return err
    47  			}
    48  			continue
    49  		}
    50  
    51  		tag := field.Tag.Get(d.option.TagName)
    52  		key, omitKey, found := strings.Cut(tag, ",")
    53  		omitempty := found && omitKey == "omitempty"
    54  
    55  		value, ok := src[key]
    56  		if !ok {
    57  			if d.option.KeyReplacer != nil {
    58  				key = d.option.KeyReplacer.Replace(key)
    59  			}
    60  
    61  			for _strKey := range src {
    62  				strKey := _strKey
    63  				if d.option.KeyReplacer != nil {
    64  					strKey = d.option.KeyReplacer.Replace(strKey)
    65  				}
    66  				if strings.EqualFold(key, strKey) {
    67  					value = src[_strKey]
    68  					ok = true
    69  					break
    70  				}
    71  			}
    72  		}
    73  		if !ok || value == nil {
    74  			if omitempty {
    75  				continue
    76  			}
    77  			return fmt.Errorf("key '%s' missing", key)
    78  		}
    79  
    80  		err := d.decode(key, value, v.Field(idx))
    81  		if err != nil {
    82  			return err
    83  		}
    84  	}
    85  	return nil
    86  }
    87  
    88  func (d *Decoder) decode(name string, data any, val reflect.Value) error {
    89  	kind := val.Kind()
    90  	switch {
    91  	case isInt(kind):
    92  		return d.decodeInt(name, data, val)
    93  	case isUint(kind):
    94  		return d.decodeUint(name, data, val)
    95  	case isFloat(kind):
    96  		return d.decodeFloat(name, data, val)
    97  	}
    98  	switch kind {
    99  	case reflect.Pointer:
   100  		if val.IsNil() {
   101  			val.Set(reflect.New(val.Type().Elem()))
   102  		}
   103  		return d.decode(name, data, val.Elem())
   104  	case reflect.String:
   105  		return d.decodeString(name, data, val)
   106  	case reflect.Bool:
   107  		return d.decodeBool(name, data, val)
   108  	case reflect.Slice:
   109  		return d.decodeSlice(name, data, val)
   110  	case reflect.Map:
   111  		return d.decodeMap(name, data, val)
   112  	case reflect.Interface:
   113  		return d.setInterface(name, data, val)
   114  	case reflect.Struct:
   115  		return d.decodeStruct(name, data, val)
   116  	default:
   117  		return fmt.Errorf("type %s not support", val.Kind().String())
   118  	}
   119  }
   120  
   121  func isInt(kind reflect.Kind) bool {
   122  	switch kind {
   123  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   124  		return true
   125  	default:
   126  		return false
   127  	}
   128  }
   129  
   130  func isUint(kind reflect.Kind) bool {
   131  	switch kind {
   132  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   133  		return true
   134  	default:
   135  		return false
   136  	}
   137  }
   138  
   139  func isFloat(kind reflect.Kind) bool {
   140  	switch kind {
   141  	case reflect.Float32, reflect.Float64:
   142  		return true
   143  	default:
   144  		return false
   145  	}
   146  }
   147  
   148  func (d *Decoder) decodeInt(name string, data any, val reflect.Value) (err error) {
   149  	dataVal := reflect.ValueOf(data)
   150  	kind := dataVal.Kind()
   151  	switch {
   152  	case isInt(kind):
   153  		val.SetInt(dataVal.Int())
   154  	case isUint(kind) && d.option.WeaklyTypedInput:
   155  		val.SetInt(int64(dataVal.Uint()))
   156  	case isFloat(kind) && d.option.WeaklyTypedInput:
   157  		val.SetInt(int64(dataVal.Float()))
   158  	case kind == reflect.String && d.option.WeaklyTypedInput:
   159  		var i int64
   160  		i, err = strconv.ParseInt(dataVal.String(), 0, val.Type().Bits())
   161  		if err == nil {
   162  			val.SetInt(i)
   163  		} else {
   164  			err = fmt.Errorf("cannot parse '%s' as int: %s", name, err)
   165  		}
   166  	default:
   167  		err = fmt.Errorf(
   168  			"'%s' expected type '%s', got unconvertible type '%s'",
   169  			name, val.Type(), dataVal.Type(),
   170  		)
   171  	}
   172  	return err
   173  }
   174  
   175  func (d *Decoder) decodeUint(name string, data any, val reflect.Value) (err error) {
   176  	dataVal := reflect.ValueOf(data)
   177  	kind := dataVal.Kind()
   178  	switch {
   179  	case isUint(kind):
   180  		val.SetUint(dataVal.Uint())
   181  	case isInt(kind) && d.option.WeaklyTypedInput:
   182  		val.SetUint(uint64(dataVal.Int()))
   183  	case isFloat(kind) && d.option.WeaklyTypedInput:
   184  		val.SetUint(uint64(dataVal.Float()))
   185  	case kind == reflect.String && d.option.WeaklyTypedInput:
   186  		var i uint64
   187  		i, err = strconv.ParseUint(dataVal.String(), 0, val.Type().Bits())
   188  		if err == nil {
   189  			val.SetUint(i)
   190  		} else {
   191  			err = fmt.Errorf("cannot parse '%s' as int: %s", name, err)
   192  		}
   193  	default:
   194  		err = fmt.Errorf(
   195  			"'%s' expected type '%s', got unconvertible type '%s'",
   196  			name, val.Type(), dataVal.Type(),
   197  		)
   198  	}
   199  	return err
   200  }
   201  
   202  func (d *Decoder) decodeFloat(name string, data any, val reflect.Value) (err error) {
   203  	dataVal := reflect.ValueOf(data)
   204  	kind := dataVal.Kind()
   205  	switch {
   206  	case isFloat(kind):
   207  		val.SetFloat(dataVal.Float())
   208  	case isUint(kind):
   209  		val.SetFloat(float64(dataVal.Uint()))
   210  	case isInt(kind) && d.option.WeaklyTypedInput:
   211  		val.SetFloat(float64(dataVal.Int()))
   212  	case kind == reflect.String && d.option.WeaklyTypedInput:
   213  		var i float64
   214  		i, err = strconv.ParseFloat(dataVal.String(), val.Type().Bits())
   215  		if err == nil {
   216  			val.SetFloat(i)
   217  		} else {
   218  			err = fmt.Errorf("cannot parse '%s' as int: %s", name, err)
   219  		}
   220  	default:
   221  		err = fmt.Errorf(
   222  			"'%s' expected type '%s', got unconvertible type '%s'",
   223  			name, val.Type(), dataVal.Type(),
   224  		)
   225  	}
   226  	return err
   227  }
   228  
   229  func (d *Decoder) decodeString(name string, data any, val reflect.Value) (err error) {
   230  	dataVal := reflect.ValueOf(data)
   231  	kind := dataVal.Kind()
   232  	switch {
   233  	case kind == reflect.String:
   234  		val.SetString(dataVal.String())
   235  	case isInt(kind) && d.option.WeaklyTypedInput:
   236  		val.SetString(strconv.FormatInt(dataVal.Int(), 10))
   237  	case isUint(kind) && d.option.WeaklyTypedInput:
   238  		val.SetString(strconv.FormatUint(dataVal.Uint(), 10))
   239  	case isFloat(kind) && d.option.WeaklyTypedInput:
   240  		val.SetString(strconv.FormatFloat(dataVal.Float(), 'E', -1, dataVal.Type().Bits()))
   241  	default:
   242  		err = fmt.Errorf(
   243  			"'%s' expected type '%s', got unconvertible type '%s'",
   244  			name, val.Type(), dataVal.Type(),
   245  		)
   246  	}
   247  	return err
   248  }
   249  
   250  func (d *Decoder) decodeBool(name string, data any, val reflect.Value) (err error) {
   251  	dataVal := reflect.ValueOf(data)
   252  	kind := dataVal.Kind()
   253  	switch {
   254  	case kind == reflect.Bool:
   255  		val.SetBool(dataVal.Bool())
   256  	case isInt(kind) && d.option.WeaklyTypedInput:
   257  		val.SetBool(dataVal.Int() != 0)
   258  	case isUint(kind) && d.option.WeaklyTypedInput:
   259  		val.SetString(strconv.FormatUint(dataVal.Uint(), 10))
   260  	default:
   261  		err = fmt.Errorf(
   262  			"'%s' expected type '%s', got unconvertible type '%s'",
   263  			name, val.Type(), dataVal.Type(),
   264  		)
   265  	}
   266  	return err
   267  }
   268  
   269  func (d *Decoder) decodeSlice(name string, data any, val reflect.Value) error {
   270  	dataVal := reflect.Indirect(reflect.ValueOf(data))
   271  	valType := val.Type()
   272  	valElemType := valType.Elem()
   273  
   274  	if dataVal.Kind() == reflect.String && valElemType.Kind() == reflect.Uint8 { // from encoding/json
   275  		s := []byte(dataVal.String())
   276  		b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
   277  		n, err := base64.StdEncoding.Decode(b, s)
   278  		if err != nil {
   279  			return fmt.Errorf("try decode '%s' by base64 error: %w", name, err)
   280  		}
   281  		val.SetBytes(b[:n])
   282  		return nil
   283  	}
   284  
   285  	if dataVal.Kind() != reflect.Slice {
   286  		return fmt.Errorf("'%s' is not a slice", name)
   287  	}
   288  
   289  	valSlice := val
   290  	// make a new slice with cap(val)==cap(dataVal)
   291  	// the caller can determine whether the original configuration contains this item by judging whether the value is nil.
   292  	valSlice = reflect.MakeSlice(valType, 0, dataVal.Len())
   293  	for i := 0; i < dataVal.Len(); i++ {
   294  		currentData := dataVal.Index(i).Interface()
   295  		for valSlice.Len() <= i {
   296  			valSlice = reflect.Append(valSlice, reflect.Zero(valElemType))
   297  		}
   298  		fieldName := fmt.Sprintf("%s[%d]", name, i)
   299  		if currentData == nil {
   300  			// in weakly type mode, null will convert to zero value
   301  			if d.option.WeaklyTypedInput {
   302  				continue
   303  			}
   304  			// in non-weakly type mode, null will convert to nil if element's zero value is nil, otherwise return an error
   305  			if elemKind := valElemType.Kind(); elemKind == reflect.Map || elemKind == reflect.Slice {
   306  				continue
   307  			}
   308  			return fmt.Errorf("'%s' can not be null", fieldName)
   309  		}
   310  		currentField := valSlice.Index(i)
   311  		if err := d.decode(fieldName, currentData, currentField); err != nil {
   312  			return err
   313  		}
   314  	}
   315  
   316  	val.Set(valSlice)
   317  	return nil
   318  }
   319  
   320  func (d *Decoder) decodeMap(name string, data any, val reflect.Value) error {
   321  	valType := val.Type()
   322  	valKeyType := valType.Key()
   323  	valElemType := valType.Elem()
   324  
   325  	valMap := val
   326  
   327  	if valMap.IsNil() {
   328  		mapType := reflect.MapOf(valKeyType, valElemType)
   329  		valMap = reflect.MakeMap(mapType)
   330  	}
   331  
   332  	dataVal := reflect.Indirect(reflect.ValueOf(data))
   333  	if dataVal.Kind() != reflect.Map {
   334  		return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
   335  	}
   336  
   337  	return d.decodeMapFromMap(name, dataVal, val, valMap)
   338  }
   339  
   340  func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
   341  	valType := val.Type()
   342  	valKeyType := valType.Key()
   343  	valElemType := valType.Elem()
   344  
   345  	errors := make([]string, 0)
   346  
   347  	if dataVal.Len() == 0 {
   348  		if dataVal.IsNil() {
   349  			if !val.IsNil() {
   350  				val.Set(dataVal)
   351  			}
   352  		} else {
   353  			val.Set(valMap)
   354  		}
   355  
   356  		return nil
   357  	}
   358  
   359  	for _, k := range dataVal.MapKeys() {
   360  		fieldName := fmt.Sprintf("%s[%s]", name, k)
   361  
   362  		currentKey := reflect.Indirect(reflect.New(valKeyType))
   363  		if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
   364  			errors = append(errors, err.Error())
   365  			continue
   366  		}
   367  
   368  		v := dataVal.MapIndex(k).Interface()
   369  		if v == nil {
   370  			errors = append(errors, fmt.Sprintf("filed %s invalid", fieldName))
   371  			continue
   372  		}
   373  
   374  		currentVal := reflect.Indirect(reflect.New(valElemType))
   375  		if err := d.decode(fieldName, v, currentVal); err != nil {
   376  			errors = append(errors, err.Error())
   377  			continue
   378  		}
   379  
   380  		valMap.SetMapIndex(currentKey, currentVal)
   381  	}
   382  
   383  	val.Set(valMap)
   384  
   385  	if len(errors) > 0 {
   386  		return fmt.Errorf(strings.Join(errors, ","))
   387  	}
   388  
   389  	return nil
   390  }
   391  
   392  func (d *Decoder) decodeStruct(name string, data any, val reflect.Value) error {
   393  	dataVal := reflect.Indirect(reflect.ValueOf(data))
   394  
   395  	// If the type of the value to write to and the data match directly,
   396  	// then we just set it directly instead of recursing into the structure.
   397  	if dataVal.Type() == val.Type() {
   398  		val.Set(dataVal)
   399  		return nil
   400  	}
   401  
   402  	dataValKind := dataVal.Kind()
   403  	switch dataValKind {
   404  	case reflect.Map:
   405  		return d.decodeStructFromMap(name, dataVal, val)
   406  	default:
   407  		return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
   408  	}
   409  }
   410  
   411  func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error {
   412  	dataValType := dataVal.Type()
   413  	if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface {
   414  		return fmt.Errorf(
   415  			"'%s' needs a map with string keys, has '%s' keys",
   416  			name, dataValType.Key().Kind())
   417  	}
   418  
   419  	dataValKeys := make(map[reflect.Value]struct{})
   420  	dataValKeysUnused := make(map[any]struct{})
   421  	for _, dataValKey := range dataVal.MapKeys() {
   422  		dataValKeys[dataValKey] = struct{}{}
   423  		dataValKeysUnused[dataValKey.Interface()] = struct{}{}
   424  	}
   425  
   426  	errors := make([]string, 0)
   427  
   428  	// This slice will keep track of all the structs we'll be decoding.
   429  	// There can be more than one struct if there are embedded structs
   430  	// that are squashed.
   431  	structs := make([]reflect.Value, 1, 5)
   432  	structs[0] = val
   433  
   434  	// Compile the list of all the fields that we're going to be decoding
   435  	// from all the structs.
   436  	type field struct {
   437  		field reflect.StructField
   438  		val   reflect.Value
   439  	}
   440  	var fields []field
   441  	for len(structs) > 0 {
   442  		structVal := structs[0]
   443  		structs = structs[1:]
   444  
   445  		structType := structVal.Type()
   446  
   447  		for i := 0; i < structType.NumField(); i++ {
   448  			fieldType := structType.Field(i)
   449  			fieldKind := fieldType.Type.Kind()
   450  
   451  			// If "squash" is specified in the tag, we squash the field down.
   452  			squash := false
   453  			tagParts := strings.Split(fieldType.Tag.Get(d.option.TagName), ",")
   454  			for _, tag := range tagParts[1:] {
   455  				if tag == "squash" {
   456  					squash = true
   457  					break
   458  				}
   459  			}
   460  
   461  			if squash {
   462  				if fieldKind != reflect.Struct {
   463  					errors = append(errors,
   464  						fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldKind).Error())
   465  				} else {
   466  					structs = append(structs, structVal.FieldByName(fieldType.Name))
   467  				}
   468  				continue
   469  			}
   470  
   471  			// Normal struct field, store it away
   472  			fields = append(fields, field{fieldType, structVal.Field(i)})
   473  		}
   474  	}
   475  
   476  	// for fieldType, field := range fields {
   477  	for _, f := range fields {
   478  		field, fieldValue := f.field, f.val
   479  		fieldName := field.Name
   480  
   481  		tagValue := field.Tag.Get(d.option.TagName)
   482  		tagValue = strings.SplitN(tagValue, ",", 2)[0]
   483  		if tagValue != "" {
   484  			fieldName = tagValue
   485  		}
   486  
   487  		rawMapKey := reflect.ValueOf(fieldName)
   488  		rawMapVal := dataVal.MapIndex(rawMapKey)
   489  		if !rawMapVal.IsValid() {
   490  			// Do a slower search by iterating over each key and
   491  			// doing case-insensitive search.
   492  			if d.option.KeyReplacer != nil {
   493  				fieldName = d.option.KeyReplacer.Replace(fieldName)
   494  			}
   495  			for dataValKey := range dataValKeys {
   496  				mK, ok := dataValKey.Interface().(string)
   497  				if !ok {
   498  					// Not a string key
   499  					continue
   500  				}
   501  				if d.option.KeyReplacer != nil {
   502  					mK = d.option.KeyReplacer.Replace(mK)
   503  				}
   504  
   505  				if strings.EqualFold(mK, fieldName) {
   506  					rawMapKey = dataValKey
   507  					rawMapVal = dataVal.MapIndex(dataValKey)
   508  					break
   509  				}
   510  			}
   511  
   512  			if !rawMapVal.IsValid() {
   513  				// There was no matching key in the map for the value in
   514  				// the struct. Just ignore.
   515  				continue
   516  			}
   517  		}
   518  
   519  		// Delete the key we're using from the unused map so we stop tracking
   520  		delete(dataValKeysUnused, rawMapKey.Interface())
   521  
   522  		if !fieldValue.IsValid() {
   523  			// This should never happen
   524  			panic("field is not valid")
   525  		}
   526  
   527  		// If we can't set the field, then it is unexported or something,
   528  		// and we just continue onwards.
   529  		if !fieldValue.CanSet() {
   530  			continue
   531  		}
   532  
   533  		// If the name is empty string, then we're at the root, and we
   534  		// don't dot-join the fields.
   535  		if name != "" {
   536  			fieldName = fmt.Sprintf("%s.%s", name, fieldName)
   537  		}
   538  
   539  		if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
   540  			errors = append(errors, err.Error())
   541  		}
   542  	}
   543  
   544  	if len(errors) > 0 {
   545  		return fmt.Errorf(strings.Join(errors, ","))
   546  	}
   547  
   548  	return nil
   549  }
   550  
   551  func (d *Decoder) setInterface(name string, data any, val reflect.Value) (err error) {
   552  	dataVal := reflect.ValueOf(data)
   553  	val.Set(dataVal)
   554  	return nil
   555  }