github.com/wfusion/gofusion@v1.1.14/common/utils/clone/wrapper.go (about)

     1  // Copyright 2019 Huan Du. All rights reserved.
     2  // Licensed under the MIT license that can be found in the LICENSE file.
     3  
     4  package clone
     5  
     6  import (
     7  	"encoding/binary"
     8  	"hash/crc64"
     9  	"reflect"
    10  	"sync"
    11  	"unsafe"
    12  )
    13  
    14  var (
    15  	sizeOfChecksum = unsafe.Sizeof(uint64(0))
    16  
    17  	crc64Table = crc64.MakeTable(crc64.ECMA)
    18  
    19  	cachedWrapperTypes sync.Map
    20  )
    21  
    22  // Wrap creates a wrapper of v, which must be a pointer.
    23  // If v is not a pointer, Wrap simply returns v and do nothing.
    24  //
    25  // The wrapper is a deep clone of v's value. It holds a shadow copy to v internally.
    26  //
    27  //	t := &T{Foo: 123}
    28  //	v := Wrap(t).(*T)               // v is a clone of t.
    29  //	reflect.DeepEqual(t, v) == true // v equals t.
    30  //	v.Foo = 456                     // v.Foo is changed, but t.Foo doesn't change.
    31  //	orig := Unwrap(v)               // Use `Unwrap` to discard wrapper and return original value, which is t.
    32  //	orig.(*T) == t                  // orig and t is exactly the same.
    33  //	Undo(v)                         // Use `Undo` to discard any change on v.
    34  //	v.Foo == t.Foo                  // Now, the value of v and t are the same again.
    35  func wrap(v interface{}) interface{} {
    36  	if v == nil {
    37  		return v
    38  	}
    39  
    40  	val := reflect.ValueOf(v)
    41  	pt := val.Type()
    42  
    43  	if val.Kind() != reflect.Ptr {
    44  		return v
    45  	}
    46  
    47  	t := pt.Elem()
    48  	elem := val.Elem()
    49  	ptr := unsafe.Pointer(val.Pointer())
    50  	cache, ok := cachedWrapperTypes.Load(t)
    51  
    52  	if !ok {
    53  		cache = reflect.StructOf([]reflect.StructField{
    54  			{
    55  				Name:      "T",
    56  				Type:      t,
    57  				Anonymous: true,
    58  			},
    59  			{
    60  				Name: "Checksum",
    61  				Type: reflect.TypeOf(uint64(0)),
    62  			},
    63  			{
    64  				Name: "Origin",
    65  				Type: pt,
    66  			},
    67  		})
    68  		cachedWrapperTypes.Store(t, cache)
    69  	}
    70  
    71  	wrapperType := cache.(reflect.Type)
    72  	pw := defaultAllocator.New(wrapperType)
    73  
    74  	wrapperPtr := unsafe.Pointer(pw.Pointer())
    75  	wrapper := pw.Elem()
    76  
    77  	// Equivalent code: wrapper.T = Clone(v)
    78  	field := wrapper.Field(0)
    79  	field.Set(heapCloneState.clone(elem))
    80  
    81  	// Equivalent code: wrapper.Checksum = makeChecksum(v)
    82  	checksumPtr := unsafe.Pointer(uintptr(wrapperPtr) + t.Size())
    83  	*(*uint64)(checksumPtr) = makeChecksum(t, uintptr(wrapperPtr), uintptr(ptr))
    84  
    85  	// Equivalent code: wrapper.Origin = v
    86  	originPtr := unsafe.Pointer(uintptr(wrapperPtr) + t.Size() + sizeOfChecksum)
    87  	*(*uintptr)(originPtr) = uintptr(ptr)
    88  
    89  	return field.Addr().Interface()
    90  }
    91  
    92  func validateChecksum(t reflect.Type, ptr unsafe.Pointer) bool {
    93  	pw := uintptr(ptr)
    94  	orig := uintptr(getOrigin(t, ptr))
    95  	checksum := *(*uint64)(unsafe.Pointer(uintptr(ptr) + t.Size()))
    96  	expected := makeChecksum(t, pw, orig)
    97  
    98  	return checksum == expected
    99  }
   100  
   101  func makeChecksum(t reflect.Type, pw uintptr, orig uintptr) uint64 {
   102  	var data [binary.MaxVarintLen64 * 2]byte
   103  	binary.PutUvarint(data[:binary.MaxVarintLen64], uint64(pw))
   104  	binary.PutUvarint(data[binary.MaxVarintLen64:], uint64(orig))
   105  	return crc64.Checksum(data[:], crc64Table)
   106  }
   107  
   108  func getOrigin(t reflect.Type, ptr unsafe.Pointer) unsafe.Pointer {
   109  	return *(*unsafe.Pointer)(unsafe.Pointer(uintptr(ptr) + t.Size() + sizeOfChecksum))
   110  }
   111  
   112  // Unwrap returns v's original value if v is a wrapped value.
   113  // Otherwise, simply returns v itself.
   114  func unwrap(v interface{}) interface{} {
   115  	if v == nil {
   116  		return v
   117  	}
   118  
   119  	val := reflect.ValueOf(v)
   120  
   121  	if !isWrapped(val) {
   122  		return v
   123  	}
   124  
   125  	origVal := origin(val)
   126  	return origVal.Interface()
   127  }
   128  
   129  func origin(val reflect.Value) reflect.Value {
   130  	pt := val.Type()
   131  	t := pt.Elem()
   132  	ptr := unsafe.Pointer(val.Pointer())
   133  	orig := getOrigin(t, ptr)
   134  	origVal := reflect.NewAt(t, orig)
   135  	return origVal
   136  }
   137  
   138  // Undo discards any change made in wrapped value.
   139  // If v is not a wrapped value, nothing happens.
   140  func undo(v interface{}) {
   141  	if v == nil {
   142  		return
   143  	}
   144  
   145  	val := reflect.ValueOf(v)
   146  
   147  	if !isWrapped(val) {
   148  		return
   149  	}
   150  
   151  	origVal := origin(val)
   152  	elem := val.Elem()
   153  	elem.Set(heapCloneState.clone(origVal.Elem()))
   154  }
   155  
   156  func isWrapped(val reflect.Value) bool {
   157  	pt := val.Type()
   158  
   159  	if pt.Kind() != reflect.Ptr {
   160  		return false
   161  	}
   162  
   163  	t := pt.Elem()
   164  	ptr := unsafe.Pointer(val.Pointer())
   165  	return validateChecksum(t, ptr)
   166  }