github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/test/deepequal.go (about)

     1  // Copyright (c) 2014 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package test
     6  
     7  import (
     8  	"bytes"
     9  	"math"
    10  	"reflect"
    11  	"runtime"
    12  	"time"
    13  
    14  	"github.com/aristanetworks/goarista/areflect"
    15  	"github.com/aristanetworks/goarista/key"
    16  
    17  	"google.golang.org/protobuf/proto"
    18  )
    19  
    20  var comparableType = reflect.TypeOf((*key.Comparable)(nil)).Elem()
    21  
    22  // DeepEqual is a faster implementation of reflect.DeepEqual that:
    23  //   - Has a reflection-free fast-path for all the common types we use.
    24  //   - Gives data types the ability to exclude some of their fields from the
    25  //     consideration of DeepEqual by tagging them with `deepequal:"ignore"`.
    26  //   - Gives data types the ability to define their own comparison method by
    27  //     implementing the comparable interface.
    28  //   - Supports "composite" (or "complex") keys in maps that are pointers.
    29  func DeepEqual(a, b interface{}) bool {
    30  	return deepEqual(a, b, nil)
    31  }
    32  
    33  // DeepEqualer allows collection types to compare their values using a
    34  // given comparator.
    35  type DeepEqualer interface {
    36  	DeepEqual(other interface{}, comparer func(a, b interface{}) bool) bool
    37  }
    38  
    39  func deepEqual(a, b interface{}, seen map[edge]struct{}) bool {
    40  	if a == nil || b == nil {
    41  		return a == b
    42  	}
    43  	switch a := a.(type) {
    44  	// Short circuit fast-path for common built-in types.
    45  	// Note: the cases are listed by frequency.
    46  	case bool:
    47  		return a == b
    48  
    49  	case map[string]interface{}:
    50  		v, ok := b.(map[string]interface{})
    51  		if !ok || len(a) != len(v) {
    52  			return false
    53  		}
    54  		for key, value := range a {
    55  			if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
    56  				return false
    57  			}
    58  		}
    59  		return true
    60  
    61  	case string, uint32, uint64, int32,
    62  		uint16, int16, uint8, int8, int64:
    63  		return a == b
    64  
    65  	case *map[string]interface{}:
    66  		v, ok := b.(*map[string]interface{})
    67  		if !ok || a == nil || v == nil {
    68  			return ok && a == v
    69  		}
    70  		return deepEqual(*a, *v, seen)
    71  
    72  	case map[interface{}]interface{}:
    73  		v, ok := b.(map[interface{}]interface{})
    74  		if !ok {
    75  			return false
    76  		}
    77  		// We compare in both directions to catch keys that are in b but not
    78  		// in a.  It sucks to have to do another O(N^2) for this, but oh well.
    79  		return mapEqual(a, v) && mapEqual(v, a)
    80  
    81  	case float32:
    82  		v, ok := b.(float32)
    83  		return ok && (a == b || (math.IsNaN(float64(a)) && math.IsNaN(float64(v))))
    84  	case float64:
    85  		v, ok := b.(float64)
    86  		return ok && (a == b || (math.IsNaN(a) && math.IsNaN(v)))
    87  
    88  	case []string:
    89  		v, ok := b.([]string)
    90  		if !ok || len(a) != len(v) {
    91  			return false
    92  		}
    93  		for i, s := range a {
    94  			if s != v[i] {
    95  				return false
    96  			}
    97  		}
    98  		return true
    99  	case []byte:
   100  		v, ok := b.([]byte)
   101  		return ok && bytes.Equal(a, v)
   102  
   103  	case map[uint64]interface{}:
   104  		v, ok := b.(map[uint64]interface{})
   105  		if !ok || len(a) != len(v) {
   106  			return false
   107  		}
   108  		for key, value := range a {
   109  			if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
   110  				return false
   111  			}
   112  		}
   113  		return true
   114  
   115  	case *map[interface{}]interface{}:
   116  		v, ok := b.(*map[interface{}]interface{})
   117  		if !ok || a == nil || v == nil {
   118  			return ok && a == v
   119  		}
   120  		return deepEqual(*a, *v, seen)
   121  
   122  	case error:
   123  		if sb, ok := b.(error); ok {
   124  			return errorStringFromError(a) == errorStringFromError(sb)
   125  		}
   126  		return false
   127  
   128  	case DeepEqualer:
   129  		return a.DeepEqual(b, func(x, y interface{}) bool {
   130  			return deepEqual(x, y, seen)
   131  		})
   132  
   133  	case key.Comparable:
   134  		return a.Equal(b)
   135  
   136  	case time.Time:
   137  		bt, ok := b.(time.Time)
   138  		return ok && a.Equal(bt)
   139  
   140  	case proto.Message:
   141  		bMsg, ok := b.(proto.Message)
   142  		if !ok || a == nil || bMsg == nil {
   143  			return ok && a == bMsg
   144  		}
   145  		return proto.Equal(a, bMsg)
   146  	case []uint32:
   147  		v, ok := b.([]uint32)
   148  		if !ok || len(a) != len(v) {
   149  			return false
   150  		}
   151  		for i, s := range a {
   152  			if s != v[i] {
   153  				return false
   154  			}
   155  		}
   156  		return true
   157  	case []uint64:
   158  		v, ok := b.([]uint64)
   159  		if !ok || len(a) != len(v) {
   160  			return false
   161  		}
   162  		for i, s := range a {
   163  			if s != v[i] {
   164  				return false
   165  			}
   166  		}
   167  		return true
   168  	case []interface{}:
   169  		v, ok := b.([]interface{})
   170  		if !ok || len(a) != len(v) {
   171  			return false
   172  		}
   173  		for i, s := range a {
   174  			if !deepEqual(s, v[i], seen) {
   175  				return false
   176  			}
   177  		}
   178  		return true
   179  	case *[]string:
   180  		v, ok := b.(*[]string)
   181  		if !ok || a == nil || v == nil {
   182  			return ok && a == v
   183  		}
   184  		return deepEqual(*a, *v, seen)
   185  	case *[]interface{}:
   186  		v, ok := b.(*[]interface{})
   187  		if !ok || a == nil || v == nil {
   188  			return ok && a == v
   189  		}
   190  		return deepEqual(*a, *v, seen)
   191  
   192  	default:
   193  		// Handle other kinds of non-comparable objects.
   194  		return genericDeepEqual(a, b, seen)
   195  	}
   196  }
   197  
   198  func errorStringFromError(err error) string {
   199  	if v := reflect.ValueOf(err); v.Kind() == reflect.Ptr && v.IsNil() {
   200  		return ""
   201  	}
   202  	return err.Error()
   203  }
   204  
   205  type edge struct {
   206  	from uintptr
   207  	to   uintptr
   208  }
   209  
   210  func genericDeepEqual(a, b interface{}, seen map[edge]struct{}) bool {
   211  	av := reflect.ValueOf(a)
   212  	bv := reflect.ValueOf(b)
   213  	if avalid, bvalid := av.IsValid(), bv.IsValid(); !avalid || !bvalid {
   214  		return avalid == bvalid
   215  	}
   216  	if bv.Type() != av.Type() {
   217  		return false
   218  	}
   219  
   220  	switch av.Kind() {
   221  	case reflect.Ptr:
   222  		if av.IsNil() || bv.IsNil() {
   223  			return a == b
   224  		}
   225  
   226  		av = av.Elem()
   227  		bv = bv.Elem()
   228  		if av.CanAddr() && bv.CanAddr() {
   229  			e := edge{from: av.UnsafeAddr(), to: bv.UnsafeAddr()}
   230  			// Detect and prevent cycles.
   231  			if seen == nil {
   232  				seen = make(map[edge]struct{})
   233  			} else if _, ok := seen[e]; ok {
   234  				return true
   235  			}
   236  			seen[e] = struct{}{}
   237  		}
   238  
   239  		return deepEqual(av.Interface(), bv.Interface(), seen)
   240  	case reflect.Slice, reflect.Array:
   241  		l := av.Len()
   242  		if l != bv.Len() {
   243  			return false
   244  		}
   245  		for i := 0; i < l; i++ {
   246  			if !deepEqual(av.Index(i).Interface(), bv.Index(i).Interface(), seen) {
   247  				return false
   248  			}
   249  		}
   250  		return true
   251  	case reflect.Map:
   252  		if av.IsNil() != bv.IsNil() {
   253  			return false
   254  		}
   255  		if av.Len() != bv.Len() {
   256  			return false
   257  		}
   258  		if av.Pointer() == bv.Pointer() {
   259  			return true
   260  		}
   261  		for _, k := range av.MapKeys() {
   262  			// Upon finding the first key that's a pointer, we bail out and do
   263  			// a O(N^2) comparison.
   264  			if kk := k.Kind(); kk == reflect.Ptr || kk == reflect.Interface {
   265  				ok, _, _ := complexKeyMapEqual(av, bv, seen)
   266  				return ok
   267  			}
   268  			ea := av.MapIndex(k)
   269  			eb := bv.MapIndex(k)
   270  			if !eb.IsValid() {
   271  				return false
   272  			}
   273  			if !deepEqual(ea.Interface(), eb.Interface(), seen) {
   274  				return false
   275  			}
   276  		}
   277  		return true
   278  	case reflect.Struct:
   279  		typ := av.Type()
   280  		if typ.Implements(comparableType) {
   281  			return av.Interface().(key.Comparable).Equal(bv.Interface())
   282  		}
   283  		for i, n := 0, av.NumField(); i < n; i++ {
   284  			if typ.Field(i).Tag.Get("deepequal") == "ignore" {
   285  				continue
   286  			}
   287  			af := areflect.ForceExport(av.Field(i))
   288  			bf := areflect.ForceExport(bv.Field(i))
   289  			if !deepEqual(af.Interface(), bf.Interface(), seen) {
   290  				return false
   291  			}
   292  		}
   293  		return true
   294  	case reflect.Func:
   295  		if av.IsNil() != bv.IsNil() {
   296  			return false
   297  		}
   298  		if av.IsNil() {
   299  			return true
   300  		}
   301  		// if both of these are functions, compare their full names
   302  		return runtime.FuncForPC(av.Pointer()).Name() ==
   303  			runtime.FuncForPC(bv.Pointer()).Name()
   304  	default:
   305  		// Incomparable types cannot be compared with DeepEqual.
   306  		// To skip comparisons, tag the type with `deepequal:"ignore"`
   307  		if !av.Type().Comparable() || !bv.Type().Comparable() {
   308  			return false
   309  		}
   310  		// Other the basic types.
   311  		return a == b
   312  	}
   313  }
   314  
   315  // Compares two maps with complex keys (that are pointers).  This assumes the
   316  // maps have already been checked to have the same sizes.  The cost of this
   317  // function is O(N^2) in the size of the input maps.
   318  //
   319  // The return is to be interpreted this way:
   320  //
   321  //	true, _, _            =>   av == bv
   322  //	false, key, invalid   =>   the given key wasn't found in bv
   323  //	false, key, value     =>   the given key had the given value in bv,
   324  //	                           which is different in av
   325  func complexKeyMapEqual(av, bv reflect.Value,
   326  	seen map[edge]struct{}) (bool, reflect.Value, reflect.Value) {
   327  	for _, ka := range av.MapKeys() {
   328  		var eb reflect.Value // The entry in bv with a key equal to ka
   329  		for _, kb := range bv.MapKeys() {
   330  			if deepEqual(ka.Elem().Interface(), kb.Elem().Interface(), seen) {
   331  				// Found the corresponding entry in bv.
   332  				eb = bv.MapIndex(kb)
   333  				break
   334  			}
   335  		}
   336  		if !eb.IsValid() { // We didn't find a key equal to `ka' in 'bv'.
   337  			return false, ka, reflect.Value{}
   338  		}
   339  		ea := av.MapIndex(ka)
   340  		if !deepEqual(ea.Interface(), eb.Interface(), seen) {
   341  			return false, ka, eb
   342  		}
   343  	}
   344  	return true, reflect.Value{}, reflect.Value{}
   345  }
   346  
   347  // mapEqual does O(N^2) comparisons to check that all the keys present in the
   348  // first map are also present in the second map and have identical values.
   349  func mapEqual(a, b map[interface{}]interface{}) bool {
   350  	if len(a) != len(b) {
   351  		return false
   352  	}
   353  	for akey, avalue := range a {
   354  		found := false
   355  		for bkey, bvalue := range b {
   356  			if DeepEqual(akey, bkey) {
   357  				if !DeepEqual(avalue, bvalue) {
   358  					return false
   359  				}
   360  				found = true
   361  				break
   362  			}
   363  		}
   364  		if !found {
   365  			return false
   366  		}
   367  	}
   368  	return true
   369  }