gitee.com/quant1x/gox@v1.21.2/api/copier.go (about)

     1  package api
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"errors"
     7  	"fmt"
     8  	"reflect"
     9  	"strings"
    10  	"unicode"
    11  )
    12  
// These flags define options for tag handling. Each constant occupies its own
// bit, so several of them can be combined in a single uint8.
const (
	// Denotes that a destination field must be copied to. If copying fails then a panic will ensue.
	tagMust uint8 = 1 << iota

	// Denotes that the program should not panic when the must flag is on and
	// value is not copied. The program will return an error instead.
	tagNoPanic

	// Ignore a destination field from being copied to.
	tagIgnore

	// Denotes that the value has been copied
	hasCopied
)
    28  
// Option sets copy options
type Option struct {
	// IgnoreEmpty: setting this value to true will ignore copying zero values of all the fields, including
	// bools, as well as a struct having all its fields set to their zero values respectively (see IsZero()
	// in reflect/value.go).
	//
	// DeepCopy: setting this value to true stops struct, map and slice values from being assigned
	// directly; they are copied through a recursive call instead.
	IgnoreEmpty bool
	DeepCopy    bool
}
    36  
// flags holds the information parsed from the `copier` struct tags of the
// destination and source types for one copy operation.
type flags struct {
	BitFlags  map[string]uint8 // destination field name -> tag bit flags (tagMust, tagNoPanic, ...)
	SrcNames  tagNameMapping   // names declared in the source type's tags
	DestNames tagNameMapping   // names declared in the destination type's tags
}
    43  
// tagNameMapping is a two-way lookup between a struct field name and the
// alternative name given in that field's `copier` tag.
type tagNameMapping struct {
	FieldNameToTag map[string]string // field name -> name from the tag
	TagToFieldName map[string]string // name from the tag -> field name
}
    49  
    50  // Copy copy things
    51  func Copy[T any, S any](to *T, from *S) (err error) {
    52  	return copier(to, from, Option{})
    53  }
    54  
    55  // CopyWithOption copy with option
    56  func CopyWithOption(toValue any, fromValue any, opt Option) (err error) {
    57  	return copier(toValue, fromValue, opt)
    58  }
    59  
    60  func copier(toValue any, fromValue any, opt Option) (err error) {
    61  	var (
    62  		isSlice bool
    63  		amount  = 1
    64  		from    = indirect(reflect.ValueOf(fromValue))
    65  		to      = indirect(reflect.ValueOf(toValue))
    66  	)
    67  
    68  	if !to.CanAddr() {
    69  		return ErrInvalidCopyDestination
    70  	}
    71  
    72  	// Return is from value is invalid
    73  	if !from.IsValid() {
    74  		return ErrInvalidCopyFrom
    75  	}
    76  
    77  	fromType, isPtrFrom := indirectType(from.Type())
    78  	toType, _ := indirectType(to.Type())
    79  
    80  	if fromType.Kind() == reflect.Interface {
    81  		fromType = reflect.TypeOf(from.Interface())
    82  	}
    83  
    84  	if toType.Kind() == reflect.Interface {
    85  		toType, _ = indirectType(reflect.TypeOf(to.Interface()))
    86  		oldTo := to
    87  		to = reflect.New(reflect.TypeOf(to.Interface())).Elem()
    88  		defer func() {
    89  			oldTo.Set(to)
    90  		}()
    91  	}
    92  
    93  	// Just set it if possible to assign for normal types
    94  	if from.Kind() != reflect.Slice && from.Kind() != reflect.Struct && from.Kind() != reflect.Map && (from.Type().AssignableTo(to.Type()) || from.Type().ConvertibleTo(to.Type())) {
    95  		if !isPtrFrom || !opt.DeepCopy {
    96  			to.Set(from.Convert(to.Type()))
    97  		} else {
    98  			fromCopy := reflect.New(from.Type())
    99  			fromCopy.Set(from.Elem())
   100  			to.Set(fromCopy.Convert(to.Type()))
   101  		}
   102  		return
   103  	}
   104  
   105  	if from.Kind() != reflect.Slice && fromType.Kind() == reflect.Map && toType.Kind() == reflect.Map {
   106  		if !fromType.Key().ConvertibleTo(toType.Key()) {
   107  			return ErrMapKeyNotMatch
   108  		}
   109  
   110  		if to.IsNil() {
   111  			to.Set(reflect.MakeMapWithSize(toType, from.Len()))
   112  		}
   113  
   114  		for _, k := range from.MapKeys() {
   115  			toKey := indirect(reflect.New(toType.Key()))
   116  			if !set(toKey, k, opt.DeepCopy) {
   117  				return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key())
   118  			}
   119  
   120  			elemType, _ := indirectType(toType.Elem())
   121  			toValue := indirect(reflect.New(elemType))
   122  			if !set(toValue, from.MapIndex(k), opt.DeepCopy) {
   123  				if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil {
   124  					return err
   125  				}
   126  			}
   127  
   128  			for {
   129  				if elemType == toType.Elem() {
   130  					to.SetMapIndex(toKey, toValue)
   131  					break
   132  				}
   133  				elemType = reflect.PtrTo(elemType)
   134  				toValue = toValue.Addr()
   135  			}
   136  		}
   137  		return
   138  	}
   139  
   140  	if from.Kind() == reflect.Slice && to.Kind() == reflect.Slice && fromType.ConvertibleTo(toType) {
   141  		if to.IsNil() {
   142  			slice := reflect.MakeSlice(reflect.SliceOf(to.Type().Elem()), from.Len(), from.Cap())
   143  			to.Set(slice)
   144  		}
   145  
   146  		for i := 0; i < from.Len(); i++ {
   147  			if to.Len() < i+1 {
   148  				to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem()))
   149  			}
   150  
   151  			if !set(to.Index(i), from.Index(i), opt.DeepCopy) {
   152  				// ignore error while copy slice element
   153  				err = copier(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt)
   154  				if err != nil {
   155  					continue
   156  				}
   157  			}
   158  		}
   159  		return
   160  	}
   161  
   162  	if fromType.Kind() != reflect.Struct || toType.Kind() != reflect.Struct {
   163  		// skip not supported type
   164  		return
   165  	}
   166  
   167  	if from.Kind() == reflect.Slice || to.Kind() == reflect.Slice {
   168  		isSlice = true
   169  		if from.Kind() == reflect.Slice {
   170  			amount = from.Len()
   171  		}
   172  	}
   173  
   174  	for i := 0; i < amount; i++ {
   175  		var dest, source reflect.Value
   176  
   177  		if isSlice {
   178  			// source
   179  			if from.Kind() == reflect.Slice {
   180  				source = indirect(from.Index(i))
   181  			} else {
   182  				source = indirect(from)
   183  			}
   184  			// dest
   185  			dest = indirect(reflect.New(toType).Elem())
   186  		} else {
   187  			source = indirect(from)
   188  			dest = indirect(to)
   189  		}
   190  
   191  		destKind := dest.Kind()
   192  		initDest := false
   193  		if destKind == reflect.Interface {
   194  			initDest = true
   195  			dest = indirect(reflect.New(toType))
   196  		}
   197  
   198  		// Get tag options
   199  		flgs, err := getFlags(dest, source, toType, fromType)
   200  		if err != nil {
   201  			return err
   202  		}
   203  
   204  		// check source
   205  		if source.IsValid() {
   206  			// Copy from source field to dest field or method
   207  			fromTypeFields := deepFields(fromType)
   208  			for _, field := range fromTypeFields {
   209  				name := field.Name
   210  
   211  				// Get bit flags for field
   212  				fieldFlags, _ := flgs.BitFlags[name]
   213  
   214  				// Check if we should ignore copying
   215  				if (fieldFlags & tagIgnore) != 0 {
   216  					continue
   217  				}
   218  
   219  				srcFieldName, destFieldName := getFieldName(name, flgs)
   220  				if fromField := source.FieldByName(srcFieldName); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) {
   221  					// process for nested anonymous field
   222  					destFieldNotSet := false
   223  					if f, ok := dest.Type().FieldByName(destFieldName); ok {
   224  						for idx := range f.Index {
   225  							destField := dest.FieldByIndex(f.Index[:idx+1])
   226  
   227  							if destField.Kind() != reflect.Ptr {
   228  								continue
   229  							}
   230  
   231  							if !destField.IsNil() {
   232  								continue
   233  							}
   234  							if !destField.CanSet() {
   235  								destFieldNotSet = true
   236  								break
   237  							}
   238  
   239  							// destField is a nil pointer that can be set
   240  							newValue := reflect.New(destField.Type().Elem())
   241  							destField.Set(newValue)
   242  						}
   243  					}
   244  
   245  					if destFieldNotSet {
   246  						break
   247  					}
   248  
   249  					toField := dest.FieldByName(destFieldName)
   250  					if toField.IsValid() {
   251  						if toField.CanSet() {
   252  							if !set(toField, fromField, opt.DeepCopy) {
   253  								if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil {
   254  									return err
   255  								}
   256  							}
   257  							if fieldFlags != 0 {
   258  								// Note that a copy was made
   259  								flgs.BitFlags[name] = fieldFlags | hasCopied
   260  							}
   261  						}
   262  					} else {
   263  						// try to set to method
   264  						var toMethod reflect.Value
   265  						if dest.CanAddr() {
   266  							toMethod = dest.Addr().MethodByName(destFieldName)
   267  						} else {
   268  							toMethod = dest.MethodByName(destFieldName)
   269  						}
   270  
   271  						if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) {
   272  							toMethod.Call([]reflect.Value{fromField})
   273  						}
   274  					}
   275  				}
   276  			}
   277  
   278  			// Copy from from method to dest field
   279  			for _, field := range deepFields(toType) {
   280  				name := field.Name
   281  				srcFieldName, destFieldName := getFieldName(name, flgs)
   282  
   283  				var fromMethod reflect.Value
   284  				if source.CanAddr() {
   285  					fromMethod = source.Addr().MethodByName(srcFieldName)
   286  				} else {
   287  					fromMethod = source.MethodByName(srcFieldName)
   288  				}
   289  
   290  				if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) {
   291  					if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() {
   292  						values := fromMethod.Call([]reflect.Value{})
   293  						if len(values) >= 1 {
   294  							set(toField, values[0], opt.DeepCopy)
   295  						}
   296  					}
   297  				}
   298  			}
   299  		}
   300  
   301  		if isSlice && to.Kind() == reflect.Slice {
   302  			if dest.Addr().Type().AssignableTo(to.Type().Elem()) {
   303  				if to.Len() < i+1 {
   304  					to.Set(reflect.Append(to, dest.Addr()))
   305  				} else {
   306  					if !set(to.Index(i), dest.Addr(), opt.DeepCopy) {
   307  						// ignore error while copy slice element
   308  						err = copier(to.Index(i).Addr().Interface(), dest.Addr().Interface(), opt)
   309  						if err != nil {
   310  							continue
   311  						}
   312  					}
   313  				}
   314  			} else if dest.Type().AssignableTo(to.Type().Elem()) {
   315  				if to.Len() < i+1 {
   316  					to.Set(reflect.Append(to, dest))
   317  				} else {
   318  					if !set(to.Index(i), dest, opt.DeepCopy) {
   319  						// ignore error while copy slice element
   320  						err = copier(to.Index(i).Addr().Interface(), dest.Interface(), opt)
   321  						if err != nil {
   322  							continue
   323  						}
   324  					}
   325  				}
   326  			}
   327  		} else if initDest {
   328  			to.Set(dest)
   329  		}
   330  
   331  		err = checkBitFlags(flgs.BitFlags)
   332  	}
   333  
   334  	return
   335  }
   336  
   337  func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool {
   338  	if !ignoreEmpty {
   339  		return false
   340  	}
   341  
   342  	return v.IsZero()
   343  }
   344  
   345  func deepFields(reflectType reflect.Type) []reflect.StructField {
   346  	if reflectType, _ = indirectType(reflectType); reflectType.Kind() == reflect.Struct {
   347  		fields := make([]reflect.StructField, 0, reflectType.NumField())
   348  
   349  		for i := 0; i < reflectType.NumField(); i++ {
   350  			v := reflectType.Field(i)
   351  			if v.Anonymous {
   352  				fields = append(fields, deepFields(v.Type)...)
   353  			} else {
   354  				fields = append(fields, v)
   355  			}
   356  		}
   357  
   358  		return fields
   359  	}
   360  
   361  	return nil
   362  }
   363  
   364  func indirect(reflectValue reflect.Value) reflect.Value {
   365  	for reflectValue.Kind() == reflect.Ptr {
   366  		reflectValue = reflectValue.Elem()
   367  	}
   368  	return reflectValue
   369  }
   370  
   371  func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) {
   372  	for reflectType.Kind() == reflect.Ptr || reflectType.Kind() == reflect.Slice {
   373  		reflectType = reflectType.Elem()
   374  		isPtr = true
   375  	}
   376  	return reflectType, isPtr
   377  }
   378  
   379  func set(to, from reflect.Value, deepCopy bool) bool {
   380  	if from.IsValid() {
   381  		if to.Kind() == reflect.Ptr {
   382  			// set `to` to nil if from is nil
   383  			if from.Kind() == reflect.Ptr && from.IsNil() {
   384  				to.Set(reflect.Zero(to.Type()))
   385  				return true
   386  			} else if to.IsNil() {
   387  				// `from`         -> `to`
   388  				// sql.NullString -> *string
   389  				if fromValuer, ok := driverValuer(from); ok {
   390  					v, err := fromValuer.Value()
   391  					if err != nil {
   392  						return false
   393  					}
   394  					// if `from` is not valid do nothing with `to`
   395  					if v == nil {
   396  						return true
   397  					}
   398  				}
   399  				// allocate new `to` variable with default value (eg. *string -> new(string))
   400  				to.Set(reflect.New(to.Type().Elem()))
   401  			}
   402  			// depointer `to`
   403  			to = to.Elem()
   404  		}
   405  
   406  		if deepCopy {
   407  			toKind := to.Kind()
   408  			if toKind == reflect.Interface && to.IsNil() {
   409  				if reflect.TypeOf(from.Interface()) != nil {
   410  					to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem())
   411  					toKind = reflect.TypeOf(to.Interface()).Kind()
   412  				}
   413  			}
   414  			if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice {
   415  				return false
   416  			}
   417  		}
   418  
   419  		if from.Type().ConvertibleTo(to.Type()) {
   420  			to.Set(from.Convert(to.Type()))
   421  		} else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok {
   422  			// `from`  -> `to`
   423  			// *string -> sql.NullString
   424  			if from.Kind() == reflect.Ptr {
   425  				// if `from` is nil do nothing with `to`
   426  				if from.IsNil() {
   427  					return true
   428  				}
   429  				// depointer `from`
   430  				from = indirect(from)
   431  			}
   432  			// `from` -> `to`
   433  			// string -> sql.NullString
   434  			// set `to` by invoking method Scan(`from`)
   435  			err := toScanner.Scan(from.Interface())
   436  			if err != nil {
   437  				return false
   438  			}
   439  		} else if fromValuer, ok := driverValuer(from); ok {
   440  			// `from`         -> `to`
   441  			// sql.NullString -> string
   442  			v, err := fromValuer.Value()
   443  			if err != nil {
   444  				return false
   445  			}
   446  			// if `from` is not valid do nothing with `to`
   447  			if v == nil {
   448  				return true
   449  			}
   450  			rv := reflect.ValueOf(v)
   451  			if rv.Type().AssignableTo(to.Type()) {
   452  				to.Set(rv)
   453  			}
   454  		} else if from.Kind() == reflect.Ptr {
   455  			return set(to, from.Elem(), deepCopy)
   456  		} else {
   457  			return false
   458  		}
   459  	}
   460  
   461  	return true
   462  }
   463  
   464  // parseTags Parses struct tags and returns uint8 bit flags.
   465  func parseTags(tag string) (flg uint8, name string, err error) {
   466  	for _, t := range strings.Split(tag, ",") {
   467  		switch t {
   468  		case "-":
   469  			flg = tagIgnore
   470  			return
   471  		case "must":
   472  			flg = flg | tagMust
   473  		case "nopanic":
   474  			flg = flg | tagNoPanic
   475  		default:
   476  			if unicode.IsUpper([]rune(t)[0]) {
   477  				name = strings.TrimSpace(t)
   478  			} else {
   479  				err = errors.New("copier field name tag must be start upper case")
   480  			}
   481  		}
   482  	}
   483  	return
   484  }
   485  
   486  // getTagFlags Parses struct tags for bit flags, field name.
   487  func getFlags(dest, src reflect.Value, toType, fromType reflect.Type) (flags, error) {
   488  	flgs := flags{
   489  		BitFlags: map[string]uint8{},
   490  		SrcNames: tagNameMapping{
   491  			FieldNameToTag: map[string]string{},
   492  			TagToFieldName: map[string]string{},
   493  		},
   494  		DestNames: tagNameMapping{
   495  			FieldNameToTag: map[string]string{},
   496  			TagToFieldName: map[string]string{},
   497  		},
   498  	}
   499  	var toTypeFields, fromTypeFields []reflect.StructField
   500  	if dest.IsValid() {
   501  		toTypeFields = deepFields(toType)
   502  	}
   503  	if src.IsValid() {
   504  		fromTypeFields = deepFields(fromType)
   505  	}
   506  
   507  	// Get a list dest of tags
   508  	for _, field := range toTypeFields {
   509  		tags := field.Tag.Get("copier")
   510  		if tags != "" {
   511  			var name string
   512  			var err error
   513  			if flgs.BitFlags[field.Name], name, err = parseTags(tags); err != nil {
   514  				return flags{}, err
   515  			} else if name != "" {
   516  				flgs.DestNames.FieldNameToTag[field.Name] = name
   517  				flgs.DestNames.TagToFieldName[name] = field.Name
   518  			}
   519  		}
   520  	}
   521  
   522  	// Get a list source of tags
   523  	for _, field := range fromTypeFields {
   524  		tags := field.Tag.Get("copier")
   525  		if tags != "" {
   526  			var name string
   527  			var err error
   528  			if _, name, err = parseTags(tags); err != nil {
   529  				return flags{}, err
   530  			} else if name != "" {
   531  				flgs.SrcNames.FieldNameToTag[field.Name] = name
   532  				flgs.SrcNames.TagToFieldName[name] = field.Name
   533  			}
   534  		}
   535  	}
   536  	return flgs, nil
   537  }
   538  
   539  // checkBitFlags Checks flags for error or panic conditions.
   540  func checkBitFlags(flagsList map[string]uint8) (err error) {
   541  	// Check flag conditions were met
   542  	for name, flgs := range flagsList {
   543  		if flgs&hasCopied == 0 {
   544  			switch {
   545  			case flgs&tagMust != 0 && flgs&tagNoPanic != 0:
   546  				err = fmt.Errorf("field %s has must tag but was not copied", name)
   547  				return
   548  			case flgs&(tagMust) != 0:
   549  				panic(fmt.Sprintf("Field %s has must tag but was not copied", name))
   550  			}
   551  		}
   552  	}
   553  	return
   554  }
   555  
   556  func getFieldName(fieldName string, flgs flags) (srcFieldName string, destFieldName string) {
   557  	// get dest field name
   558  	if srcTagName, ok := flgs.SrcNames.FieldNameToTag[fieldName]; ok {
   559  		destFieldName = srcTagName
   560  		if destTagName, ok := flgs.DestNames.TagToFieldName[srcTagName]; ok {
   561  			destFieldName = destTagName
   562  		}
   563  	} else {
   564  		if destTagName, ok := flgs.DestNames.TagToFieldName[fieldName]; ok {
   565  			destFieldName = destTagName
   566  		}
   567  	}
   568  	if destFieldName == "" {
   569  		destFieldName = fieldName
   570  	}
   571  
   572  	// get source field name
   573  	if destTagName, ok := flgs.DestNames.FieldNameToTag[fieldName]; ok {
   574  		srcFieldName = destTagName
   575  		if srcField, ok := flgs.SrcNames.TagToFieldName[destTagName]; ok {
   576  			srcFieldName = srcField
   577  		}
   578  	} else {
   579  		if srcField, ok := flgs.SrcNames.TagToFieldName[fieldName]; ok {
   580  			srcFieldName = srcField
   581  		}
   582  	}
   583  
   584  	if srcFieldName == "" {
   585  		srcFieldName = fieldName
   586  	}
   587  	return
   588  }
   589  
   590  func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) {
   591  
   592  	if !v.CanAddr() {
   593  		i, ok = v.Interface().(driver.Valuer)
   594  		return
   595  	}
   596  
   597  	i, ok = v.Addr().Interface().(driver.Valuer)
   598  	return
   599  }