gitee.com/quant1x/gox@v1.21.2/api/copier.go (about) 1 package api 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "errors" 7 "fmt" 8 "reflect" 9 "strings" 10 "unicode" 11 ) 12 13 // These flags define options for tag handling 14 const ( 15 // Denotes that a destination field must be copied to. If copying fails then a panic will ensue. 16 tagMust uint8 = 1 << iota 17 18 // Denotes that the program should not panic when the must flag is on and 19 // value is not copied. The program will return an error instead. 20 tagNoPanic 21 22 // Ignore a destination field from being copied to. 23 tagIgnore 24 25 // Denotes that the value as been copied 26 hasCopied 27 ) 28 29 // Option sets copy options 30 type Option struct { 31 // setting this value to true will ignore copying zero values of all the fields, including bools, as well as a 32 // struct having all it's fields set to their zero values respectively (see IsZero() in reflect/value.go) 33 IgnoreEmpty bool 34 DeepCopy bool 35 } 36 37 // Tag Flags 38 type flags struct { 39 BitFlags map[string]uint8 40 SrcNames tagNameMapping 41 DestNames tagNameMapping 42 } 43 44 // Field Tag name mapping 45 type tagNameMapping struct { 46 FieldNameToTag map[string]string 47 TagToFieldName map[string]string 48 } 49 50 // Copy copy things 51 func Copy[T any, S any](to *T, from *S) (err error) { 52 return copier(to, from, Option{}) 53 } 54 55 // CopyWithOption copy with option 56 func CopyWithOption(toValue any, fromValue any, opt Option) (err error) { 57 return copier(toValue, fromValue, opt) 58 } 59 60 func copier(toValue any, fromValue any, opt Option) (err error) { 61 var ( 62 isSlice bool 63 amount = 1 64 from = indirect(reflect.ValueOf(fromValue)) 65 to = indirect(reflect.ValueOf(toValue)) 66 ) 67 68 if !to.CanAddr() { 69 return ErrInvalidCopyDestination 70 } 71 72 // Return is from value is invalid 73 if !from.IsValid() { 74 return ErrInvalidCopyFrom 75 } 76 77 fromType, isPtrFrom := indirectType(from.Type()) 78 toType, _ := indirectType(to.Type()) 79 80 if fromType.Kind() == reflect.Interface { 81 fromType = reflect.TypeOf(from.Interface()) 82 } 83 84 if toType.Kind() == reflect.Interface { 85 toType, _ = indirectType(reflect.TypeOf(to.Interface())) 86 oldTo := to 87 to = reflect.New(reflect.TypeOf(to.Interface())).Elem() 88 defer func() { 89 oldTo.Set(to) 90 }() 91 } 92 93 // Just set it if possible to assign for normal types 94 if from.Kind() != reflect.Slice && from.Kind() != reflect.Struct && from.Kind() != reflect.Map && (from.Type().AssignableTo(to.Type()) || from.Type().ConvertibleTo(to.Type())) { 95 if !isPtrFrom || !opt.DeepCopy { 96 to.Set(from.Convert(to.Type())) 97 } else { 98 fromCopy := reflect.New(from.Type()) 99 fromCopy.Set(from.Elem()) 100 to.Set(fromCopy.Convert(to.Type())) 101 } 102 return 103 } 104 105 if from.Kind() != reflect.Slice && fromType.Kind() == reflect.Map && toType.Kind() == reflect.Map { 106 if !fromType.Key().ConvertibleTo(toType.Key()) { 107 return ErrMapKeyNotMatch 108 } 109 110 if to.IsNil() { 111 to.Set(reflect.MakeMapWithSize(toType, from.Len())) 112 } 113 114 for _, k := range from.MapKeys() { 115 toKey := indirect(reflect.New(toType.Key())) 116 if !set(toKey, k, opt.DeepCopy) { 117 return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key()) 118 } 119 120 elemType, _ := indirectType(toType.Elem()) 121 toValue := indirect(reflect.New(elemType)) 122 if !set(toValue, from.MapIndex(k), opt.DeepCopy) { 123 if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil { 124 return err 125 } 126 } 127 128 for { 129 if elemType == toType.Elem() { 130 to.SetMapIndex(toKey, toValue) 131 break 132 } 133 elemType = reflect.PtrTo(elemType) 134 toValue = toValue.Addr() 135 } 136 } 137 return 138 } 139 140 if from.Kind() == reflect.Slice && to.Kind() == reflect.Slice && fromType.ConvertibleTo(toType) { 141 if to.IsNil() { 142 slice := reflect.MakeSlice(reflect.SliceOf(to.Type().Elem()), from.Len(), from.Cap()) 143 to.Set(slice) 144 } 145 146 for i := 0; i < from.Len(); i++ { 147 if to.Len() < i+1 { 148 to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem())) 149 } 150 151 if !set(to.Index(i), from.Index(i), opt.DeepCopy) { 152 // ignore error while copy slice element 153 err = copier(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt) 154 if err != nil { 155 continue 156 } 157 } 158 } 159 return 160 } 161 162 if fromType.Kind() != reflect.Struct || toType.Kind() != reflect.Struct { 163 // skip not supported type 164 return 165 } 166 167 if from.Kind() == reflect.Slice || to.Kind() == reflect.Slice { 168 isSlice = true 169 if from.Kind() == reflect.Slice { 170 amount = from.Len() 171 } 172 } 173 174 for i := 0; i < amount; i++ { 175 var dest, source reflect.Value 176 177 if isSlice { 178 // source 179 if from.Kind() == reflect.Slice { 180 source = indirect(from.Index(i)) 181 } else { 182 source = indirect(from) 183 } 184 // dest 185 dest = indirect(reflect.New(toType).Elem()) 186 } else { 187 source = indirect(from) 188 dest = indirect(to) 189 } 190 191 destKind := dest.Kind() 192 initDest := false 193 if destKind == reflect.Interface { 194 initDest = true 195 dest = indirect(reflect.New(toType)) 196 } 197 198 // Get tag options 199 flgs, err := getFlags(dest, source, toType, fromType) 200 if err != nil { 201 return err 202 } 203 204 // check source 205 if source.IsValid() { 206 // Copy from source field to dest field or method 207 fromTypeFields := deepFields(fromType) 208 for _, field := range fromTypeFields { 209 name := field.Name 210 211 // Get bit flags for field 212 fieldFlags, _ := flgs.BitFlags[name] 213 214 // Check if we should ignore copying 215 if (fieldFlags & tagIgnore) != 0 { 216 continue 217 } 218 219 srcFieldName, destFieldName := getFieldName(name, flgs) 220 if fromField := source.FieldByName(srcFieldName); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) { 221 // process for nested anonymous field 222 destFieldNotSet := false 223 if f, ok := dest.Type().FieldByName(destFieldName); ok { 224 for idx := range f.Index { 225 destField := dest.FieldByIndex(f.Index[:idx+1]) 226 227 if destField.Kind() != reflect.Ptr { 228 continue 229 } 230 231 if !destField.IsNil() { 232 continue 233 } 234 if !destField.CanSet() { 235 destFieldNotSet = true 236 break 237 } 238 239 // destField is a nil pointer that can be set 240 newValue := reflect.New(destField.Type().Elem()) 241 destField.Set(newValue) 242 } 243 } 244 245 if destFieldNotSet { 246 break 247 } 248 249 toField := dest.FieldByName(destFieldName) 250 if toField.IsValid() { 251 if toField.CanSet() { 252 if !set(toField, fromField, opt.DeepCopy) { 253 if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil { 254 return err 255 } 256 } 257 if fieldFlags != 0 { 258 // Note that a copy was made 259 flgs.BitFlags[name] = fieldFlags | hasCopied 260 } 261 } 262 } else { 263 // try to set to method 264 var toMethod reflect.Value 265 if dest.CanAddr() { 266 toMethod = dest.Addr().MethodByName(destFieldName) 267 } else { 268 toMethod = dest.MethodByName(destFieldName) 269 } 270 271 if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) { 272 toMethod.Call([]reflect.Value{fromField}) 273 } 274 } 275 } 276 } 277 278 // Copy from from method to dest field 279 for _, field := range deepFields(toType) { 280 name := field.Name 281 srcFieldName, destFieldName := getFieldName(name, flgs) 282 283 var fromMethod reflect.Value 284 if source.CanAddr() { 285 fromMethod = source.Addr().MethodByName(srcFieldName) 286 } else { 287 fromMethod = source.MethodByName(srcFieldName) 288 } 289 290 if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) { 291 if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() { 292 values := fromMethod.Call([]reflect.Value{}) 293 if len(values) >= 1 { 294 set(toField, values[0], opt.DeepCopy) 295 } 296 } 297 } 298 } 299 } 300 301 if isSlice && to.Kind() == reflect.Slice { 302 if dest.Addr().Type().AssignableTo(to.Type().Elem()) { 303 if to.Len() < i+1 { 304 to.Set(reflect.Append(to, dest.Addr())) 305 } else { 306 if !set(to.Index(i), dest.Addr(), opt.DeepCopy) { 307 // ignore error while copy slice element 308 err = copier(to.Index(i).Addr().Interface(), dest.Addr().Interface(), opt) 309 if err != nil { 310 continue 311 } 312 } 313 } 314 } else if dest.Type().AssignableTo(to.Type().Elem()) { 315 if to.Len() < i+1 { 316 to.Set(reflect.Append(to, dest)) 317 } else { 318 if !set(to.Index(i), dest, opt.DeepCopy) { 319 // ignore error while copy slice element 320 err = copier(to.Index(i).Addr().Interface(), dest.Interface(), opt) 321 if err != nil { 322 continue 323 } 324 } 325 } 326 } 327 } else if initDest { 328 to.Set(dest) 329 } 330 331 err = checkBitFlags(flgs.BitFlags) 332 } 333 334 return 335 } 336 337 func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool { 338 if !ignoreEmpty { 339 return false 340 } 341 342 return v.IsZero() 343 } 344 345 func deepFields(reflectType reflect.Type) []reflect.StructField { 346 if reflectType, _ = indirectType(reflectType); reflectType.Kind() == reflect.Struct { 347 fields := make([]reflect.StructField, 0, reflectType.NumField()) 348 349 for i := 0; i < reflectType.NumField(); i++ { 350 v := reflectType.Field(i) 351 if v.Anonymous { 352 fields = append(fields, deepFields(v.Type)...) 353 } else { 354 fields = append(fields, v) 355 } 356 } 357 358 return fields 359 } 360 361 return nil 362 } 363 364 func indirect(reflectValue reflect.Value) reflect.Value { 365 for reflectValue.Kind() == reflect.Ptr { 366 reflectValue = reflectValue.Elem() 367 } 368 return reflectValue 369 } 370 371 func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) { 372 for reflectType.Kind() == reflect.Ptr || reflectType.Kind() == reflect.Slice { 373 reflectType = reflectType.Elem() 374 isPtr = true 375 } 376 return reflectType, isPtr 377 } 378 379 func set(to, from reflect.Value, deepCopy bool) bool { 380 if from.IsValid() { 381 if to.Kind() == reflect.Ptr { 382 // set `to` to nil if from is nil 383 if from.Kind() == reflect.Ptr && from.IsNil() { 384 to.Set(reflect.Zero(to.Type())) 385 return true 386 } else if to.IsNil() { 387 // `from` -> `to` 388 // sql.NullString -> *string 389 if fromValuer, ok := driverValuer(from); ok { 390 v, err := fromValuer.Value() 391 if err != nil { 392 return false 393 } 394 // if `from` is not valid do nothing with `to` 395 if v == nil { 396 return true 397 } 398 } 399 // allocate new `to` variable with default value (eg. *string -> new(string)) 400 to.Set(reflect.New(to.Type().Elem())) 401 } 402 // depointer `to` 403 to = to.Elem() 404 } 405 406 if deepCopy { 407 toKind := to.Kind() 408 if toKind == reflect.Interface && to.IsNil() { 409 if reflect.TypeOf(from.Interface()) != nil { 410 to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem()) 411 toKind = reflect.TypeOf(to.Interface()).Kind() 412 } 413 } 414 if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice { 415 return false 416 } 417 } 418 419 if from.Type().ConvertibleTo(to.Type()) { 420 to.Set(from.Convert(to.Type())) 421 } else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok { 422 // `from` -> `to` 423 // *string -> sql.NullString 424 if from.Kind() == reflect.Ptr { 425 // if `from` is nil do nothing with `to` 426 if from.IsNil() { 427 return true 428 } 429 // depointer `from` 430 from = indirect(from) 431 } 432 // `from` -> `to` 433 // string -> sql.NullString 434 // set `to` by invoking method Scan(`from`) 435 err := toScanner.Scan(from.Interface()) 436 if err != nil { 437 return false 438 } 439 } else if fromValuer, ok := driverValuer(from); ok { 440 // `from` -> `to` 441 // sql.NullString -> string 442 v, err := fromValuer.Value() 443 if err != nil { 444 return false 445 } 446 // if `from` is not valid do nothing with `to` 447 if v == nil { 448 return true 449 } 450 rv := reflect.ValueOf(v) 451 if rv.Type().AssignableTo(to.Type()) { 452 to.Set(rv) 453 } 454 } else if from.Kind() == reflect.Ptr { 455 return set(to, from.Elem(), deepCopy) 456 } else { 457 return false 458 } 459 } 460 461 return true 462 } 463 464 // parseTags Parses struct tags and returns uint8 bit flags. 465 func parseTags(tag string) (flg uint8, name string, err error) { 466 for _, t := range strings.Split(tag, ",") { 467 switch t { 468 case "-": 469 flg = tagIgnore 470 return 471 case "must": 472 flg = flg | tagMust 473 case "nopanic": 474 flg = flg | tagNoPanic 475 default: 476 if unicode.IsUpper([]rune(t)[0]) { 477 name = strings.TrimSpace(t) 478 } else { 479 err = errors.New("copier field name tag must be start upper case") 480 } 481 } 482 } 483 return 484 } 485 486 // getTagFlags Parses struct tags for bit flags, field name. 487 func getFlags(dest, src reflect.Value, toType, fromType reflect.Type) (flags, error) { 488 flgs := flags{ 489 BitFlags: map[string]uint8{}, 490 SrcNames: tagNameMapping{ 491 FieldNameToTag: map[string]string{}, 492 TagToFieldName: map[string]string{}, 493 }, 494 DestNames: tagNameMapping{ 495 FieldNameToTag: map[string]string{}, 496 TagToFieldName: map[string]string{}, 497 }, 498 } 499 var toTypeFields, fromTypeFields []reflect.StructField 500 if dest.IsValid() { 501 toTypeFields = deepFields(toType) 502 } 503 if src.IsValid() { 504 fromTypeFields = deepFields(fromType) 505 } 506 507 // Get a list dest of tags 508 for _, field := range toTypeFields { 509 tags := field.Tag.Get("copier") 510 if tags != "" { 511 var name string 512 var err error 513 if flgs.BitFlags[field.Name], name, err = parseTags(tags); err != nil { 514 return flags{}, err 515 } else if name != "" { 516 flgs.DestNames.FieldNameToTag[field.Name] = name 517 flgs.DestNames.TagToFieldName[name] = field.Name 518 } 519 } 520 } 521 522 // Get a list source of tags 523 for _, field := range fromTypeFields { 524 tags := field.Tag.Get("copier") 525 if tags != "" { 526 var name string 527 var err error 528 if _, name, err = parseTags(tags); err != nil { 529 return flags{}, err 530 } else if name != "" { 531 flgs.SrcNames.FieldNameToTag[field.Name] = name 532 flgs.SrcNames.TagToFieldName[name] = field.Name 533 } 534 } 535 } 536 return flgs, nil 537 } 538 539 // checkBitFlags Checks flags for error or panic conditions. 540 func checkBitFlags(flagsList map[string]uint8) (err error) { 541 // Check flag conditions were met 542 for name, flgs := range flagsList { 543 if flgs&hasCopied == 0 { 544 switch { 545 case flgs&tagMust != 0 && flgs&tagNoPanic != 0: 546 err = fmt.Errorf("field %s has must tag but was not copied", name) 547 return 548 case flgs&(tagMust) != 0: 549 panic(fmt.Sprintf("Field %s has must tag but was not copied", name)) 550 } 551 } 552 } 553 return 554 } 555 556 func getFieldName(fieldName string, flgs flags) (srcFieldName string, destFieldName string) { 557 // get dest field name 558 if srcTagName, ok := flgs.SrcNames.FieldNameToTag[fieldName]; ok { 559 destFieldName = srcTagName 560 if destTagName, ok := flgs.DestNames.TagToFieldName[srcTagName]; ok { 561 destFieldName = destTagName 562 } 563 } else { 564 if destTagName, ok := flgs.DestNames.TagToFieldName[fieldName]; ok { 565 destFieldName = destTagName 566 } 567 } 568 if destFieldName == "" { 569 destFieldName = fieldName 570 } 571 572 // get source field name 573 if destTagName, ok := flgs.DestNames.FieldNameToTag[fieldName]; ok { 574 srcFieldName = destTagName 575 if srcField, ok := flgs.SrcNames.TagToFieldName[destTagName]; ok { 576 srcFieldName = srcField 577 } 578 } else { 579 if srcField, ok := flgs.SrcNames.TagToFieldName[fieldName]; ok { 580 srcFieldName = srcField 581 } 582 } 583 584 if srcFieldName == "" { 585 srcFieldName = fieldName 586 } 587 return 588 } 589 590 func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) { 591 592 if !v.CanAddr() { 593 i, ok = v.Interface().(driver.Valuer) 594 return 595 } 596 597 i, ok = v.Addr().Interface().(driver.Valuer) 598 return 599 }