github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/defaults/defaults.go (about)

     1  package defaults
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"reflect"
     7  	"strconv"
     8  	"time"
     9  
    10  	"github.com/bingoohuang/gg/pkg/reflector"
    11  )
    12  
// ErrInvalidType is returned by Set when its argument is not a pointer to a
// struct.
var ErrInvalidType = errors.New("not a struct pointer")
    15  
    16  // Option is the options for Validate.
    17  type Option struct {
    18  	TagName string
    19  }
    20  
    21  // OptionFn is the function prototype to apply option
    22  type OptionFn func(*Option)
    23  
    24  // TagName defines the tag name for validate.
    25  func TagName(tagName string) OptionFn { return func(o *Option) { o.TagName = tagName } }
    26  
    27  // MustSet function is a wrapper of Set function
    28  // It will call Set and panic if err not equals nil.
    29  func MustSet(ptr interface{}) {
    30  	if err := Set(ptr); err != nil {
    31  		panic(err)
    32  	}
    33  }
    34  
    35  // Set initializes members in a struct referenced by a pointer.
    36  // Maps and slices are initialized by `make` and other primitive types are set with default values.
    37  // `ptr` should be a struct pointer
    38  func Set(ptr interface{}, optionFns ...OptionFn) error {
    39  	if reflect.TypeOf(ptr).Kind() != reflect.Ptr {
    40  		return ErrInvalidType
    41  	}
    42  
    43  	option := createOption(optionFns)
    44  
    45  	v := reflect.ValueOf(ptr).Elem()
    46  	t := v.Type()
    47  
    48  	if t.Kind() != reflect.Struct {
    49  		return ErrInvalidType
    50  	}
    51  
    52  	for i := 0; i < t.NumField(); i++ {
    53  		defaultVal := t.Field(i).Tag.Get(option.TagName)
    54  		if defaultVal == "-" {
    55  			continue
    56  		}
    57  
    58  		if err := setField(v.Field(i), defaultVal); err != nil {
    59  			return err
    60  		}
    61  	}
    62  
    63  	return nil
    64  }
    65  
    66  func createOption(optionFns []OptionFn) *Option {
    67  	option := &Option{}
    68  
    69  	for _, fn := range optionFns {
    70  		fn(option)
    71  	}
    72  
    73  	if option.TagName == "" {
    74  		option.TagName = "default"
    75  	}
    76  
    77  	return option
    78  }
    79  
    80  func setField(field reflect.Value, v string) error {
    81  	if !field.CanSet() {
    82  		return nil
    83  	}
    84  
    85  	if !shouldInitializeField(field, v) {
    86  		return nil
    87  	}
    88  
    89  	if reflector.IsEmpty(field) {
    90  		if err := setZeroField(field, v); err != nil {
    91  			return err
    92  		}
    93  	}
    94  
    95  	switch field.Kind() {
    96  	case reflect.Ptr:
    97  		if err := setField(field.Elem(), v); err != nil {
    98  			return err
    99  		}
   100  
   101  		callSetter(field.Interface())
   102  	case reflect.Struct:
   103  		ref := reflect.New(field.Type())
   104  		ref.Elem().Set(field)
   105  
   106  		if err := Set(ref.Interface()); err != nil {
   107  			return err
   108  		}
   109  
   110  		callSetter(ref.Interface())
   111  		field.Set(ref.Elem())
   112  	case reflect.Slice:
   113  		for j := 0; j < field.Len(); j++ {
   114  			if err := setField(field.Index(j), v); err != nil {
   115  				return err
   116  			}
   117  		}
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  func setZeroField(field reflect.Value, v string) error {
   124  	m := map[reflect.Kind]converterFn{
   125  		reflect.Bool:    convertBool,
   126  		reflect.Int:     convertInt,
   127  		reflect.Int8:    convertInt8,
   128  		reflect.Int16:   convertInt16,
   129  		reflect.Int32:   convertInt32,
   130  		reflect.Int64:   convertInt64,
   131  		reflect.Uint:    convertUInt,
   132  		reflect.Uint8:   convertUInt8,
   133  		reflect.Uint16:  convertUInt16,
   134  		reflect.Uint32:  convertUInt32,
   135  		reflect.Uint64:  convertUInt64,
   136  		reflect.Uintptr: convertUintptr,
   137  		reflect.Float32: convertFloat32,
   138  		reflect.Float64: convertFloat64,
   139  		reflect.String:  convertString,
   140  		reflect.Slice:   convertSlice,
   141  		reflect.Map:     convertMap,
   142  		reflect.Struct:  convertStruct,
   143  		reflect.Ptr:     convertPtr,
   144  	}
   145  
   146  	f, ok := m[field.Kind()]
   147  	if !ok {
   148  		return nil
   149  	}
   150  
   151  	val, err := f(field.Type(), v)
   152  
   153  	if err == nil {
   154  		field.Set(val)
   155  		return nil
   156  	}
   157  
   158  	if werr, ok := err.(*wrapError); ok {
   159  		return werr.error
   160  	}
   161  
   162  	return nil
   163  }
   164  
   165  // Setter is an interface for setting default values
   166  type Setter interface {
   167  	SetDefaults()
   168  }
   169  
   170  func callSetter(v interface{}) {
   171  	if ds, ok := v.(Setter); ok {
   172  		ds.SetDefaults()
   173  	}
   174  }
   175  
// converterFn turns the textual default v into a reflect.Value for a field of
// type t. setZeroField selects one converterFn per reflect.Kind.
type converterFn func(t reflect.Type, v string) (reflect.Value, error)
   177  
   178  func convertBool(t reflect.Type, v string) (reflect.Value, error) {
   179  	val, err := strconv.ParseBool(v)
   180  	if err != nil {
   181  		return reflect.Value{}, err
   182  	}
   183  
   184  	return reflect.ValueOf(val).Convert(t), nil
   185  }
   186  
   187  func convertInt(t reflect.Type, v string) (reflect.Value, error) {
   188  	val, err := strconv.ParseInt(v, 10, 64)
   189  	if err != nil {
   190  		return reflect.Value{}, err
   191  	}
   192  
   193  	return reflect.ValueOf(int(val)).Convert(t), nil
   194  }
   195  
   196  func convertInt8(t reflect.Type, v string) (reflect.Value, error) {
   197  	val, err := strconv.ParseInt(v, 10, 8)
   198  	if err != nil {
   199  		return reflect.Value{}, err
   200  	}
   201  
   202  	return reflect.ValueOf(int8(val)).Convert(t), nil
   203  }
   204  
   205  func convertInt16(t reflect.Type, v string) (reflect.Value, error) {
   206  	val, err := strconv.ParseInt(v, 10, 16)
   207  	if err != nil {
   208  		return reflect.Value{}, err
   209  	}
   210  
   211  	return reflect.ValueOf(int16(val)).Convert(t), nil
   212  }
   213  
   214  func convertInt32(t reflect.Type, v string) (reflect.Value, error) {
   215  	val, err := strconv.ParseInt(v, 10, 32)
   216  	if err != nil {
   217  		return reflect.Value{}, err
   218  	}
   219  
   220  	return reflect.ValueOf(int32(val)).Convert(t), nil
   221  }
   222  
   223  func convertInt64(t reflect.Type, v string) (reflect.Value, error) {
   224  	d, err := time.ParseDuration(v)
   225  	if err == nil {
   226  		return reflect.ValueOf(d).Convert(t), nil
   227  	}
   228  
   229  	val, err := strconv.ParseInt(v, 10, 64)
   230  	if err != nil {
   231  		return reflect.Value{}, err
   232  	}
   233  
   234  	return reflect.ValueOf(val).Convert(t), nil
   235  }
   236  
   237  func convertUInt(t reflect.Type, v string) (reflect.Value, error) {
   238  	val, err := strconv.ParseUint(v, 10, 64)
   239  	if err != nil {
   240  		return reflect.Value{}, err
   241  	}
   242  
   243  	return reflect.ValueOf(uint(val)).Convert(t), nil
   244  }
   245  
   246  func convertUInt8(t reflect.Type, v string) (reflect.Value, error) {
   247  	val, err := strconv.ParseUint(v, 10, 8)
   248  	if err != nil {
   249  		return reflect.Value{}, err
   250  	}
   251  
   252  	return reflect.ValueOf(uint8(val)).Convert(t), nil
   253  }
   254  
   255  func convertUInt16(t reflect.Type, v string) (reflect.Value, error) {
   256  	val, err := strconv.ParseUint(v, 10, 16)
   257  	if err != nil {
   258  		return reflect.Value{}, err
   259  	}
   260  
   261  	return reflect.ValueOf(uint16(val)).Convert(t), nil
   262  }
   263  
   264  func convertUInt32(t reflect.Type, v string) (reflect.Value, error) {
   265  	val, err := strconv.ParseUint(v, 10, 32)
   266  	if err != nil {
   267  		return reflect.Value{}, err
   268  	}
   269  
   270  	return reflect.ValueOf(uint32(val)).Convert(t), nil
   271  }
   272  
   273  func convertUInt64(t reflect.Type, v string) (reflect.Value, error) {
   274  	val, err := strconv.ParseUint(v, 10, 64)
   275  	if err != nil {
   276  		return reflect.Value{}, err
   277  	}
   278  
   279  	return reflect.ValueOf(val).Convert(t), nil
   280  }
   281  
   282  func convertUintptr(t reflect.Type, v string) (reflect.Value, error) {
   283  	val, err := strconv.ParseUint(v, 10, 64)
   284  	if err != nil {
   285  		return reflect.Value{}, err
   286  	}
   287  
   288  	return reflect.ValueOf(uintptr(val)).Convert(t), nil
   289  }
   290  
   291  func convertFloat32(t reflect.Type, v string) (reflect.Value, error) {
   292  	val, err := strconv.ParseFloat(v, 32)
   293  	if err != nil {
   294  		return reflect.Value{}, err
   295  	}
   296  
   297  	return reflect.ValueOf(float32(val)).Convert(t), nil
   298  }
   299  
   300  func convertFloat64(t reflect.Type, v string) (reflect.Value, error) {
   301  	val, err := strconv.ParseFloat(v, 64)
   302  	if err != nil {
   303  		return reflect.Value{}, err
   304  	}
   305  
   306  	return reflect.ValueOf(val).Convert(t), nil
   307  }
   308  
// wrapError marks an error that setZeroField must report to its caller.
// convertSlice, convertMap and convertStruct wrap JSON decoding failures in
// it; conversion errors that are not wrapped are dropped by setZeroField.
type wrapError struct {
	error
}
   312  
   313  func convertString(t reflect.Type, v string) (reflect.Value, error) {
   314  	return reflect.ValueOf(v).Convert(t), nil
   315  }
   316  
   317  func convertSlice(t reflect.Type, v string) (reflect.Value, error) {
   318  	ref := reflect.New(t)
   319  	ref.Elem().Set(reflect.MakeSlice(t, 0, 0))
   320  
   321  	if v != "" && v != "[]" {
   322  		if err := json.Unmarshal([]byte(v), ref.Interface()); err != nil {
   323  			return reflect.Value{}, &wrapError{err}
   324  		}
   325  	}
   326  
   327  	return ref.Elem().Convert(t), nil
   328  }
   329  
   330  func convertMap(t reflect.Type, v string) (reflect.Value, error) {
   331  	ref := reflect.New(t)
   332  	ref.Elem().Set(reflect.MakeMap(t))
   333  
   334  	if v != "" && v != "{}" {
   335  		if err := json.Unmarshal([]byte(v), ref.Interface()); err != nil {
   336  			return reflect.Value{}, &wrapError{err}
   337  		}
   338  	}
   339  
   340  	return ref.Elem().Convert(t), nil
   341  }
   342  
   343  func convertPtr(t reflect.Type, v string) (reflect.Value, error) {
   344  	return reflect.New(t.Elem()), nil
   345  }
   346  
   347  func convertStruct(t reflect.Type, v string) (reflect.Value, error) {
   348  	ref := reflect.New(t)
   349  
   350  	if v != "" && v != "{}" {
   351  		if err := json.Unmarshal([]byte(v), ref.Interface()); err != nil {
   352  			return reflect.Value{}, &wrapError{err}
   353  		}
   354  	}
   355  
   356  	return ref.Elem(), nil
   357  }
   358  
   359  func shouldInitializeField(field reflect.Value, defaultVal string) bool {
   360  	switch field.Kind() {
   361  	case reflect.Struct:
   362  		return true
   363  	case reflect.Slice:
   364  		return field.Len() > 0 || defaultVal != ""
   365  	}
   366  
   367  	return defaultVal != ""
   368  }
   369  
// CanUpdate reports the result of reflector.IsEmpty for v.
// NOTE(review): presumably this is true when v still holds the initial (zero)
// value of its type — confirm against reflector.IsEmpty.
func CanUpdate(v interface{}) bool {
	return reflector.IsEmpty(v)
}