github.com/yaling888/clash@v1.53.0/common/structure/structure.go (about)

     1  package structure
     2  
     3  // references: https://github.com/mitchellh/mapstructure
     4  
     5  import (
     6  	"errors"
     7  	"fmt"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/yaling888/clash/common/errors2"
    14  )
    15  
    16  var durationType = reflect.TypeOf(time.Duration(0))
    17  
    18  // Option is the configuration that is used to create a new decoder
    19  type Option struct {
    20  	TagName          string
    21  	WeaklyTypedInput bool
    22  }
    23  
    24  // Decoder is the core of structure
    25  type Decoder struct {
    26  	option *Option
    27  }
    28  
    29  // NewDecoder return a Decoder by Option
    30  func NewDecoder(option Option) *Decoder {
    31  	if option.TagName == "" {
    32  		option.TagName = "structure"
    33  	}
    34  	return &Decoder{option: &option}
    35  }
    36  
    37  // Decode transform a map[string]any to a struct
    38  func (d *Decoder) Decode(src map[string]any, dst any) error {
    39  	if reflect.TypeOf(dst).Kind() != reflect.Ptr {
    40  		return fmt.Errorf("decode must recive a ptr struct")
    41  	}
    42  	t := reflect.TypeOf(dst).Elem()
    43  	v := reflect.ValueOf(dst).Elem()
    44  	for idx := 0; idx < v.NumField(); idx++ {
    45  		field := t.Field(idx)
    46  		if field.Anonymous {
    47  			if err := d.decodeStruct(field.Name, src, v.Field(idx)); err != nil {
    48  				return err
    49  			}
    50  			continue
    51  		}
    52  
    53  		tag := field.Tag.Get(d.option.TagName)
    54  		key, omitKey, found := strings.Cut(tag, ",")
    55  		omitempty := found && omitKey == "omitempty"
    56  
    57  		value, ok := src[key]
    58  		if !ok || value == nil {
    59  			if omitempty {
    60  				continue
    61  			}
    62  			return fmt.Errorf("key '%s' missing", key)
    63  		}
    64  
    65  		err := d.decode(key, value, v.Field(idx))
    66  		if err != nil {
    67  			return err
    68  		}
    69  	}
    70  	return nil
    71  }
    72  
    73  func (d *Decoder) decode(name string, data any, val reflect.Value) error {
    74  	switch val.Kind() {
    75  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    76  		return d.decodeInt(name, data, val)
    77  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
    78  		return d.decodeUint(name, data, val)
    79  	case reflect.String:
    80  		return d.decodeString(name, data, val)
    81  	case reflect.Bool:
    82  		return d.decodeBool(name, data, val)
    83  	case reflect.Slice:
    84  		return d.decodeSlice(name, data, val)
    85  	case reflect.Map:
    86  		return d.decodeMap(name, data, val)
    87  	case reflect.Interface:
    88  		return d.setInterface(name, data, val)
    89  	case reflect.Struct:
    90  		return d.decodeStruct(name, data, val)
    91  	default:
    92  		return fmt.Errorf("type %s not support", val.Kind().String())
    93  	}
    94  }
    95  
    96  func (d *Decoder) decodeInt(name string, data any, val reflect.Value) (err error) {
    97  	dataVal := reflect.ValueOf(data)
    98  	switch {
    99  	case dataVal.CanInt():
   100  		resolved := dataVal.Int()
   101  		if val.Type() == durationType {
   102  			resolved *= 1e9
   103  		}
   104  		val.SetInt(resolved)
   105  	case dataVal.CanUint():
   106  		resolved := dataVal.Uint()
   107  		if val.Type() == durationType {
   108  			resolved *= 1e9
   109  		}
   110  		val.SetInt(int64(resolved))
   111  	case dataVal.CanFloat() && d.option.WeaklyTypedInput:
   112  		resolved := dataVal.Float()
   113  		if val.Type() == durationType {
   114  			resolved *= 1e9
   115  		}
   116  		val.SetInt(int64(resolved))
   117  	case dataVal.Kind() == reflect.String && d.option.WeaklyTypedInput:
   118  		var (
   119  			rs       int64
   120  			valType  = val.Type()
   121  			resolved = dataVal.String()
   122  		)
   123  
   124  		rs, err = strconv.ParseInt(resolved, 0, valType.Bits())
   125  		if err == nil {
   126  			if valType == durationType {
   127  				rs *= 1e9
   128  			}
   129  			val.SetInt(rs)
   130  		} else if valType == durationType {
   131  			dur, err1 := time.ParseDuration(resolved)
   132  			if err1 == nil {
   133  				val.SetInt(int64(dur))
   134  			}
   135  			err = err1
   136  		}
   137  
   138  		if err != nil {
   139  			err = fmt.Errorf("cannot parse '%s' as int: %w", name, err)
   140  		}
   141  	default:
   142  		err = fmt.Errorf(
   143  			"'%s' expected type '%s', got unconvertible type '%s'",
   144  			name, val.Type(), dataVal.Type(),
   145  		)
   146  	}
   147  	return err
   148  }
   149  
   150  func (d *Decoder) decodeUint(name string, data any, val reflect.Value) (err error) {
   151  	dataVal := reflect.ValueOf(data)
   152  	switch {
   153  	case dataVal.CanInt():
   154  		val.SetUint(uint64(dataVal.Int()))
   155  	case dataVal.CanUint():
   156  		val.SetUint(dataVal.Uint())
   157  	case dataVal.CanFloat() && d.option.WeaklyTypedInput:
   158  		val.SetUint(uint64(dataVal.Float()))
   159  	case dataVal.Kind() == reflect.String && d.option.WeaklyTypedInput:
   160  		var i uint64
   161  		i, err = strconv.ParseUint(dataVal.String(), 0, val.Type().Bits())
   162  		if err == nil {
   163  			val.SetUint(i)
   164  		} else {
   165  			err = fmt.Errorf("cannot parse '%s' as uint: %w", name, err)
   166  		}
   167  	default:
   168  		err = fmt.Errorf(
   169  			"'%s' expected type '%s', got unconvertible type '%s'",
   170  			name, val.Type(), dataVal.Type(),
   171  		)
   172  	}
   173  	return err
   174  }
   175  
   176  func (d *Decoder) decodeString(name string, data any, val reflect.Value) (err error) {
   177  	dataVal := reflect.ValueOf(data)
   178  	kind := dataVal.Kind()
   179  	switch {
   180  	case kind == reflect.String:
   181  		val.SetString(dataVal.String())
   182  	case kind == reflect.Bool && d.option.WeaklyTypedInput:
   183  		val.SetString(strconv.FormatBool(dataVal.Bool()))
   184  	case dataVal.CanInt() && d.option.WeaklyTypedInput:
   185  		val.SetString(strconv.FormatInt(dataVal.Int(), 10))
   186  	default:
   187  		err = fmt.Errorf(
   188  			"'%s' expected type '%s', got unconvertible type '%s'",
   189  			name, val.Type(), dataVal.Type(),
   190  		)
   191  	}
   192  	return err
   193  }
   194  
   195  func (d *Decoder) decodeBool(name string, data any, val reflect.Value) (err error) {
   196  	dataVal := reflect.ValueOf(data)
   197  	kind := dataVal.Kind()
   198  	switch {
   199  	case kind == reflect.Bool:
   200  		val.SetBool(dataVal.Bool())
   201  	case dataVal.CanInt() && d.option.WeaklyTypedInput:
   202  		val.SetBool(dataVal.Int() != 0)
   203  	case kind == reflect.String && d.option.WeaklyTypedInput:
   204  		v, _ := strconv.ParseBool(dataVal.String())
   205  		val.SetBool(v)
   206  	default:
   207  		err = fmt.Errorf(
   208  			"'%s' expected type '%s', got unconvertible type '%s'",
   209  			name, val.Type(), dataVal.Type(),
   210  		)
   211  	}
   212  	return err
   213  }
   214  
   215  func (d *Decoder) decodeSlice(name string, data any, val reflect.Value) error {
   216  	dataVal := reflect.Indirect(reflect.ValueOf(data))
   217  	valType := val.Type()
   218  	valElemType := valType.Elem()
   219  
   220  	if dataVal.Kind() != reflect.Slice {
   221  		return fmt.Errorf("'%s' is not a slice", name)
   222  	}
   223  
   224  	valSlice := val
   225  	for i := 0; i < dataVal.Len(); i++ {
   226  		currentData := dataVal.Index(i).Interface()
   227  		for valSlice.Len() <= i {
   228  			valSlice = reflect.Append(valSlice, reflect.Zero(valElemType))
   229  		}
   230  		fieldName := fmt.Sprintf("%s[%d]", name, i)
   231  		if currentData == nil {
   232  			// in weakly type mode, null will convert to zero value
   233  			if d.option.WeaklyTypedInput {
   234  				continue
   235  			}
   236  			// in non-weakly type mode, null will convert to nil if element's zero value is nil
   237  			// otherwise return an error
   238  			if elemKind := valElemType.Kind(); elemKind == reflect.Map || elemKind == reflect.Slice {
   239  				continue
   240  			}
   241  			return fmt.Errorf("'%s' can not be null", fieldName)
   242  		}
   243  		currentField := valSlice.Index(i)
   244  		if err := d.decode(fieldName, currentData, currentField); err != nil {
   245  			return err
   246  		}
   247  	}
   248  
   249  	val.Set(valSlice)
   250  	return nil
   251  }
   252  
   253  func (d *Decoder) decodeMap(name string, data any, val reflect.Value) error {
   254  	valType := val.Type()
   255  	valKeyType := valType.Key()
   256  	valElemType := valType.Elem()
   257  
   258  	valMap := val
   259  
   260  	if valMap.IsNil() {
   261  		mapType := reflect.MapOf(valKeyType, valElemType)
   262  		valMap = reflect.MakeMap(mapType)
   263  	}
   264  
   265  	dataVal := reflect.Indirect(reflect.ValueOf(data))
   266  	if dataVal.Kind() != reflect.Map {
   267  		return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
   268  	}
   269  
   270  	return d.decodeMapFromMap(name, dataVal, val, valMap)
   271  }
   272  
   273  func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
   274  	valType := val.Type()
   275  	valKeyType := valType.Key()
   276  	valElemType := valType.Elem()
   277  
   278  	if dataVal.Len() == 0 {
   279  		if dataVal.IsNil() {
   280  			if !val.IsNil() {
   281  				val.Set(dataVal)
   282  			}
   283  		} else {
   284  			val.Set(valMap)
   285  		}
   286  
   287  		return nil
   288  	}
   289  
   290  	var errs error
   291  	for _, k := range dataVal.MapKeys() {
   292  		fieldName := fmt.Sprintf("%s[%s]", name, k)
   293  
   294  		currentKey := reflect.Indirect(reflect.New(valKeyType))
   295  		if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
   296  			errs = errors.Join(errs, err)
   297  			continue
   298  		}
   299  
   300  		v := dataVal.MapIndex(k).Interface()
   301  		if v == nil {
   302  			errs = errors.Join(errs, fmt.Errorf("filed %s invalid", fieldName))
   303  			continue
   304  		}
   305  
   306  		currentVal := reflect.Indirect(reflect.New(valElemType))
   307  		if err := d.decode(fieldName, v, currentVal); err != nil {
   308  			errs = errors.Join(errs, err)
   309  			continue
   310  		}
   311  
   312  		valMap.SetMapIndex(currentKey, currentVal)
   313  	}
   314  
   315  	val.Set(valMap)
   316  
   317  	if errs != nil {
   318  		return errors2.New(errs)
   319  	}
   320  
   321  	return nil
   322  }
   323  
   324  func (d *Decoder) decodeStruct(name string, data any, val reflect.Value) error {
   325  	dataVal := reflect.Indirect(reflect.ValueOf(data))
   326  
   327  	// If the type of the value to write to and the data match directly,
   328  	// then we just set it directly instead of recursing into the structure.
   329  	if dataVal.Type() == val.Type() {
   330  		val.Set(dataVal)
   331  		return nil
   332  	}
   333  
   334  	dataValKind := dataVal.Kind()
   335  	switch dataValKind {
   336  	case reflect.Map:
   337  		return d.decodeStructFromMap(name, dataVal, val)
   338  	default:
   339  		return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
   340  	}
   341  }
   342  
   343  func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error {
   344  	dataValType := dataVal.Type()
   345  	if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface {
   346  		return fmt.Errorf(
   347  			"'%s' needs a map with string keys, has '%s' keys",
   348  			name, dataValType.Key().Kind())
   349  	}
   350  
   351  	dataValKeys := make(map[reflect.Value]struct{})
   352  	dataValKeysUnused := make(map[any]struct{})
   353  	for _, dataValKey := range dataVal.MapKeys() {
   354  		dataValKeys[dataValKey] = struct{}{}
   355  		dataValKeysUnused[dataValKey.Interface()] = struct{}{}
   356  	}
   357  
   358  	// This slice will keep track of all the structs we'll be decoding.
   359  	// There can be more than one struct if there are embedded structs
   360  	// that are squashed.
   361  	structs := make([]reflect.Value, 1, 5)
   362  	structs[0] = val
   363  
   364  	// Compile the list of all the fields that we're going to be decoding
   365  	// from all the structs.
   366  	type field struct {
   367  		field reflect.StructField
   368  		val   reflect.Value
   369  	}
   370  
   371  	var (
   372  		fields []field
   373  		errs   error
   374  	)
   375  	for len(structs) > 0 {
   376  		structVal := structs[0]
   377  		structs = structs[1:]
   378  
   379  		structType := structVal.Type()
   380  
   381  		for i := 0; i < structType.NumField(); i++ {
   382  			fieldType := structType.Field(i)
   383  			fieldKind := fieldType.Type.Kind()
   384  
   385  			// If "squash" is specified in the tag, we squash the field down.
   386  			squash := false
   387  			tagParts := strings.Split(fieldType.Tag.Get(d.option.TagName), ",")
   388  			for _, tag := range tagParts[1:] {
   389  				if tag == "squash" {
   390  					squash = true
   391  					break
   392  				}
   393  			}
   394  
   395  			if squash {
   396  				if fieldKind != reflect.Struct {
   397  					errs = errors.Join(
   398  						errs,
   399  						fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldKind),
   400  					)
   401  				} else {
   402  					structs = append(structs, structVal.FieldByName(fieldType.Name))
   403  				}
   404  				continue
   405  			}
   406  
   407  			// Normal struct field, store it away
   408  			fields = append(fields, field{fieldType, structVal.Field(i)})
   409  		}
   410  	}
   411  
   412  	// for fieldType, field := range fields {
   413  	for _, f := range fields {
   414  		fieldM, fieldValue := f.field, f.val
   415  		fieldName := fieldM.Name
   416  
   417  		tagValue := fieldM.Tag.Get(d.option.TagName)
   418  		tagValue = strings.SplitN(tagValue, ",", 2)[0]
   419  		if tagValue != "" {
   420  			fieldName = tagValue
   421  		}
   422  
   423  		rawMapKey := reflect.ValueOf(fieldName)
   424  		rawMapVal := dataVal.MapIndex(rawMapKey)
   425  		if !rawMapVal.IsValid() {
   426  			// Do a slower search by iterating over each key and
   427  			// doing case-insensitive search.
   428  			for dataValKey := range dataValKeys {
   429  				mK, ok := dataValKey.Interface().(string)
   430  				if !ok {
   431  					// Not a string key
   432  					continue
   433  				}
   434  
   435  				if strings.EqualFold(mK, fieldName) {
   436  					rawMapKey = dataValKey
   437  					rawMapVal = dataVal.MapIndex(dataValKey)
   438  					break
   439  				}
   440  			}
   441  
   442  			if !rawMapVal.IsValid() {
   443  				// There was no matching key in the map for the value in
   444  				// the struct. Just ignore.
   445  				continue
   446  			}
   447  		}
   448  
   449  		// Delete the key we're using from the unused map so stop tracking
   450  		delete(dataValKeysUnused, rawMapKey.Interface())
   451  
   452  		if !fieldValue.IsValid() {
   453  			// This should never happen
   454  			panic("field is not valid")
   455  		}
   456  
   457  		// If we can't set the field, then it is unexported or something,
   458  		// and we just continue onwards.
   459  		if !fieldValue.CanSet() {
   460  			continue
   461  		}
   462  
   463  		// If the name is empty string, then we're at the root, and we
   464  		// don't dot-join the fields.
   465  		if name != "" {
   466  			fieldName = fmt.Sprintf("%s.%s", name, fieldName)
   467  		}
   468  
   469  		if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
   470  			errs = errors.Join(errs, err)
   471  		}
   472  	}
   473  
   474  	if errs != nil {
   475  		return errors2.New(errs)
   476  	}
   477  
   478  	return nil
   479  }
   480  
   481  func (d *Decoder) setInterface(_ string, data any, val reflect.Value) (err error) {
   482  	dataVal := reflect.ValueOf(data)
   483  	val.Set(dataVal)
   484  	return nil
   485  }