github.com/Psiphon-Labs/goarista@v0.0.0-20160825065156-d002785f4c67/test/deepequal.go (about)

     1  // Copyright (C) 2014  Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package test
     6  
     7  import (
     8  	"bytes"
     9  	"math"
    10  	"reflect"
    11  
    12  	"github.com/aristanetworks/goarista/areflect"
    13  	"github.com/aristanetworks/goarista/key"
    14  )
    15  
    16  var comparableType = reflect.TypeOf((*key.Comparable)(nil)).Elem()
    17  
    18  // DeepEqual is a faster implementation of reflect.DeepEqual that:
    19  //   - Has a reflection-free fast-path for all the common types we use.
    20  //   - Gives data types the ability to exclude some of their fields from the
    21  //     consideration of DeepEqual by tagging them with `deepequal:"ignore"`.
    22  //   - Gives data types the ability to define their own comparison method by
    23  //     implementing the comparable interface.
    24  //   - Supports "composite" (or "complex") keys in maps that are pointers.
    25  func DeepEqual(a, b interface{}) bool {
    26  	return deepEqual(a, b, nil)
    27  }
    28  
    29  func deepEqual(a, b interface{}, seen map[edge]struct{}) bool {
    30  	if a == nil || b == nil {
    31  		return a == b
    32  	}
    33  	switch a := a.(type) {
    34  	// Short circuit fast-path for common built-in types.
    35  	// Note: the cases are listed by frequency.
    36  	case bool:
    37  		return a == b
    38  
    39  	case map[string]interface{}:
    40  		v, ok := b.(map[string]interface{})
    41  		if !ok || len(a) != len(v) {
    42  			return false
    43  		}
    44  		for key, value := range a {
    45  			if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
    46  				return false
    47  			}
    48  		}
    49  		return true
    50  
    51  	case string, uint32, uint64, int32,
    52  		uint16, int16, uint8, int8, int64:
    53  		return a == b
    54  
    55  	case *map[string]interface{}:
    56  		v, ok := b.(*map[string]interface{})
    57  		if !ok || a == nil || v == nil {
    58  			return ok && a == v
    59  		}
    60  		return deepEqual(*a, *v, seen)
    61  
    62  	case map[interface{}]interface{}:
    63  		v, ok := b.(map[interface{}]interface{})
    64  		if !ok {
    65  			return false
    66  		}
    67  		// We compare in both directions to catch keys that are in b but not
    68  		// in a.  It sucks to have to do another O(N^2) for this, but oh well.
    69  		return mapEqual(a, v) && mapEqual(v, a)
    70  
    71  	case float32:
    72  		v, ok := b.(float32)
    73  		return ok && (a == b || (math.IsNaN(float64(a)) && math.IsNaN(float64(v))))
    74  	case float64:
    75  		v, ok := b.(float64)
    76  		return ok && (a == b || (math.IsNaN(a) && math.IsNaN(v)))
    77  
    78  	case []string:
    79  		v, ok := b.([]string)
    80  		if !ok || len(a) != len(v) {
    81  			return false
    82  		}
    83  		for i, s := range a {
    84  			if s != v[i] {
    85  				return false
    86  			}
    87  		}
    88  		return true
    89  	case []byte:
    90  		v, ok := b.([]byte)
    91  		return ok && bytes.Equal(a, v)
    92  
    93  	case map[uint64]interface{}:
    94  		v, ok := b.(map[uint64]interface{})
    95  		if !ok || len(a) != len(v) {
    96  			return false
    97  		}
    98  		for key, value := range a {
    99  			if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
   100  				return false
   101  			}
   102  		}
   103  		return true
   104  
   105  	case *map[interface{}]interface{}:
   106  		v, ok := b.(*map[interface{}]interface{})
   107  		if !ok || a == nil || v == nil {
   108  			return ok && a == v
   109  		}
   110  		return deepEqual(*a, *v, seen)
   111  	case key.Comparable:
   112  		return a.Equal(b)
   113  
   114  	case []uint32:
   115  		v, ok := b.([]uint32)
   116  		if !ok || len(a) != len(v) {
   117  			return false
   118  		}
   119  		for i, s := range a {
   120  			if s != v[i] {
   121  				return false
   122  			}
   123  		}
   124  		return true
   125  	case []uint64:
   126  		v, ok := b.([]uint64)
   127  		if !ok || len(a) != len(v) {
   128  			return false
   129  		}
   130  		for i, s := range a {
   131  			if s != v[i] {
   132  				return false
   133  			}
   134  		}
   135  		return true
   136  	case []interface{}:
   137  		v, ok := b.([]interface{})
   138  		if !ok || len(a) != len(v) {
   139  			return false
   140  		}
   141  		for i, s := range a {
   142  			if !deepEqual(s, v[i], seen) {
   143  				return false
   144  			}
   145  		}
   146  		return true
   147  	case *[]string:
   148  		v, ok := b.(*[]string)
   149  		if !ok || a == nil || v == nil {
   150  			return ok && a == v
   151  		}
   152  		return deepEqual(*a, *v, seen)
   153  	case *[]interface{}:
   154  		v, ok := b.(*[]interface{})
   155  		if !ok || a == nil || v == nil {
   156  			return ok && a == v
   157  		}
   158  		return deepEqual(*a, *v, seen)
   159  
   160  	default:
   161  		// Handle other kinds of non-comparable objects.
   162  		return genericDeepEqual(a, b, seen)
   163  	}
   164  }
   165  
   166  type edge struct {
   167  	from uintptr
   168  	to   uintptr
   169  }
   170  
   171  func genericDeepEqual(a, b interface{}, seen map[edge]struct{}) bool {
   172  	av := reflect.ValueOf(a)
   173  	bv := reflect.ValueOf(b)
   174  	if avalid, bvalid := av.IsValid(), bv.IsValid(); !avalid || !bvalid {
   175  		return avalid == bvalid
   176  	}
   177  	if bv.Type() != av.Type() {
   178  		return false
   179  	}
   180  
   181  	switch av.Kind() {
   182  	case reflect.Ptr:
   183  		if av.IsNil() || bv.IsNil() {
   184  			return a == b
   185  		}
   186  
   187  		av = av.Elem()
   188  		bv = bv.Elem()
   189  		if av.CanAddr() && bv.CanAddr() {
   190  			e := edge{from: av.UnsafeAddr(), to: bv.UnsafeAddr()}
   191  			// Detect and prevent cycles.
   192  			if seen == nil {
   193  				seen = make(map[edge]struct{})
   194  			} else if _, ok := seen[e]; ok {
   195  				return true
   196  			}
   197  			seen[e] = struct{}{}
   198  		}
   199  
   200  		return deepEqual(av.Interface(), bv.Interface(), seen)
   201  	case reflect.Slice, reflect.Array:
   202  		l := av.Len()
   203  		if l != bv.Len() {
   204  			return false
   205  		}
   206  		for i := 0; i < l; i++ {
   207  			if !deepEqual(av.Index(i).Interface(), bv.Index(i).Interface(), seen) {
   208  				return false
   209  			}
   210  		}
   211  		return true
   212  	case reflect.Map:
   213  		if av.IsNil() != bv.IsNil() {
   214  			return false
   215  		}
   216  		if av.Len() != bv.Len() {
   217  			return false
   218  		}
   219  		if av.Pointer() == bv.Pointer() {
   220  			return true
   221  		}
   222  		for _, k := range av.MapKeys() {
   223  			// Upon finding the first key that's a pointer, we bail out and do
   224  			// a O(N^2) comparison.
   225  			if kk := k.Kind(); kk == reflect.Ptr || kk == reflect.Interface {
   226  				ok, _, _ := complexKeyMapEqual(av, bv, seen)
   227  				return ok
   228  			}
   229  			ea := av.MapIndex(k)
   230  			eb := bv.MapIndex(k)
   231  			if !eb.IsValid() {
   232  				return false
   233  			}
   234  			if !deepEqual(ea.Interface(), eb.Interface(), seen) {
   235  				return false
   236  			}
   237  		}
   238  		return true
   239  	case reflect.Struct:
   240  		typ := av.Type()
   241  		if typ.Implements(comparableType) {
   242  			return av.Interface().(key.Comparable).Equal(bv.Interface())
   243  		}
   244  		for i, n := 0, av.NumField(); i < n; i++ {
   245  			if typ.Field(i).Tag.Get("deepequal") == "ignore" {
   246  				continue
   247  			}
   248  			af := areflect.ForceExport(av.Field(i))
   249  			bf := areflect.ForceExport(bv.Field(i))
   250  			if !deepEqual(af.Interface(), bf.Interface(), seen) {
   251  				return false
   252  			}
   253  		}
   254  		return true
   255  	default:
   256  		// Other the basic types.
   257  		return a == b
   258  	}
   259  }
   260  
   261  // Compares two maps with complex keys (that are pointers).  This assumes the
   262  // maps have already been checked to have the same sizes.  The cost of this
   263  // function is O(N^2) in the size of the input maps.
   264  //
   265  // The return is to be interpreted this way:
   266  //    true, _, _            =>   av == bv
   267  //    false, key, invalid   =>   the given key wasn't found in bv
   268  //    false, key, value     =>   the given key had the given value in bv,
   269  //                               which is different in av
   270  func complexKeyMapEqual(av, bv reflect.Value,
   271  	seen map[edge]struct{}) (bool, reflect.Value, reflect.Value) {
   272  	for _, ka := range av.MapKeys() {
   273  		var eb reflect.Value // The entry in bv with a key equal to ka
   274  		for _, kb := range bv.MapKeys() {
   275  			if deepEqual(ka.Elem().Interface(), kb.Elem().Interface(), seen) {
   276  				// Found the corresponding entry in bv.
   277  				eb = bv.MapIndex(kb)
   278  				break
   279  			}
   280  		}
   281  		if !eb.IsValid() { // We didn't find a key equal to `ka' in 'bv'.
   282  			return false, ka, reflect.Value{}
   283  		}
   284  		ea := av.MapIndex(ka)
   285  		if !deepEqual(ea.Interface(), eb.Interface(), seen) {
   286  			return false, ka, eb
   287  		}
   288  	}
   289  	return true, reflect.Value{}, reflect.Value{}
   290  }
   291  
   292  // mapEqual does O(N^2) comparisons to check that all the keys present in the
   293  // first map are also present in the second map and have identical values.
   294  func mapEqual(a, b map[interface{}]interface{}) bool {
   295  	if len(a) != len(b) {
   296  		return false
   297  	}
   298  	for akey, avalue := range a {
   299  		found := false
   300  		for bkey, bvalue := range b {
   301  			if DeepEqual(akey, bkey) {
   302  				if !DeepEqual(avalue, bvalue) {
   303  					return false
   304  				}
   305  				found = true
   306  				break
   307  			}
   308  		}
   309  		if !found {
   310  			return false
   311  		}
   312  	}
   313  	return true
   314  }