github.com/slackhq/nebula@v1.9.0/test/assert.go (about)

     1  package test
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"testing"
     7  	"time"
     8  	"unsafe"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
    14  // There is currently a special case for `time.loc` (as this code traverses into unexported fields)
    15  func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
    16  	v1 := reflect.ValueOf(a)
    17  	v2 := reflect.ValueOf(b)
    18  
    19  	if !assert.Equal(t, v1.Type(), v2.Type()) {
    20  		return
    21  	}
    22  
    23  	traverseDeepCopy(t, v1, v2, v1.Type().String())
    24  }
    25  
    26  func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
    27  	switch v1.Kind() {
    28  	case reflect.Array:
    29  		for i := 0; i < v1.Len(); i++ {
    30  			if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
    31  				return false
    32  			}
    33  		}
    34  		return true
    35  
    36  	case reflect.Slice:
    37  		if v1.IsNil() || v2.IsNil() {
    38  			return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
    39  		}
    40  
    41  		if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
    42  			return false
    43  		}
    44  
    45  		// A slice with cap 0
    46  		if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
    47  			return false
    48  		}
    49  
    50  		v1c := v1.Cap()
    51  		v2c := v2.Cap()
    52  		if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
    53  			return assert.Fail(t, "", "%s share some underlying memory", name)
    54  		}
    55  
    56  		for i := 0; i < v1.Len(); i++ {
    57  			if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
    58  				return false
    59  			}
    60  		}
    61  		return true
    62  
    63  	case reflect.Interface:
    64  		if v1.IsNil() || v2.IsNil() {
    65  			return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
    66  		}
    67  		return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
    68  
    69  	case reflect.Ptr:
    70  		local := reflect.ValueOf(time.Local).Pointer()
    71  		if local == v1.Pointer() && local == v2.Pointer() {
    72  			return true
    73  		}
    74  
    75  		if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
    76  			return false
    77  		}
    78  
    79  		return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
    80  
    81  	case reflect.Struct:
    82  		for i, n := 0, v1.NumField(); i < n; i++ {
    83  			if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
    84  				return false
    85  			}
    86  		}
    87  		return true
    88  
    89  	case reflect.Map:
    90  		if v1.IsNil() || v2.IsNil() {
    91  			return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
    92  		}
    93  
    94  		if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
    95  			return false
    96  		}
    97  
    98  		if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
    99  			return false
   100  		}
   101  
   102  		for _, k := range v1.MapKeys() {
   103  			val1 := v1.MapIndex(k)
   104  			val2 := v2.MapIndex(k)
   105  			if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
   106  				return false
   107  			}
   108  
   109  			if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
   110  				return false
   111  			}
   112  
   113  			if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
   114  				return false
   115  			}
   116  		}
   117  
   118  		return true
   119  
   120  	default:
   121  		if v1.CanInterface() && v2.CanInterface() {
   122  			return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
   123  		}
   124  
   125  		e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
   126  		e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
   127  
   128  		return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
   129  	}
   130  }