github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/test/deepequal.go (about) 1 // Copyright (c) 2014 Arista Networks, Inc. 2 // Use of this source code is governed by the Apache License 2.0 3 // that can be found in the COPYING file. 4 5 package test 6 7 import ( 8 "bytes" 9 "math" 10 "reflect" 11 "runtime" 12 "time" 13 14 "github.com/aristanetworks/goarista/areflect" 15 "github.com/aristanetworks/goarista/key" 16 17 "google.golang.org/protobuf/proto" 18 ) 19 20 var comparableType = reflect.TypeOf((*key.Comparable)(nil)).Elem() 21 22 // DeepEqual is a faster implementation of reflect.DeepEqual that: 23 // - Has a reflection-free fast-path for all the common types we use. 24 // - Gives data types the ability to exclude some of their fields from the 25 // consideration of DeepEqual by tagging them with `deepequal:"ignore"`. 26 // - Gives data types the ability to define their own comparison method by 27 // implementing the comparable interface. 28 // - Supports "composite" (or "complex") keys in maps that are pointers. 29 func DeepEqual(a, b interface{}) bool { 30 return deepEqual(a, b, nil) 31 } 32 33 // DeepEqualer allows collection types to compare their values using a 34 // given comparator. 35 type DeepEqualer interface { 36 DeepEqual(other interface{}, comparer func(a, b interface{}) bool) bool 37 } 38 39 func deepEqual(a, b interface{}, seen map[edge]struct{}) bool { 40 if a == nil || b == nil { 41 return a == b 42 } 43 switch a := a.(type) { 44 // Short circuit fast-path for common built-in types. 45 // Note: the cases are listed by frequency. 46 case bool: 47 return a == b 48 49 case map[string]interface{}: 50 v, ok := b.(map[string]interface{}) 51 if !ok || len(a) != len(v) { 52 return false 53 } 54 for key, value := range a { 55 if other, ok := v[key]; !ok || !deepEqual(value, other, seen) { 56 return false 57 } 58 } 59 return true 60 61 case string, uint32, uint64, int32, 62 uint16, int16, uint8, int8, int64: 63 return a == b 64 65 case *map[string]interface{}: 66 v, ok := b.(*map[string]interface{}) 67 if !ok || a == nil || v == nil { 68 return ok && a == v 69 } 70 return deepEqual(*a, *v, seen) 71 72 case map[interface{}]interface{}: 73 v, ok := b.(map[interface{}]interface{}) 74 if !ok { 75 return false 76 } 77 // We compare in both directions to catch keys that are in b but not 78 // in a. It sucks to have to do another O(N^2) for this, but oh well. 79 return mapEqual(a, v) && mapEqual(v, a) 80 81 case float32: 82 v, ok := b.(float32) 83 return ok && (a == b || (math.IsNaN(float64(a)) && math.IsNaN(float64(v)))) 84 case float64: 85 v, ok := b.(float64) 86 return ok && (a == b || (math.IsNaN(a) && math.IsNaN(v))) 87 88 case []string: 89 v, ok := b.([]string) 90 if !ok || len(a) != len(v) { 91 return false 92 } 93 for i, s := range a { 94 if s != v[i] { 95 return false 96 } 97 } 98 return true 99 case []byte: 100 v, ok := b.([]byte) 101 return ok && bytes.Equal(a, v) 102 103 case map[uint64]interface{}: 104 v, ok := b.(map[uint64]interface{}) 105 if !ok || len(a) != len(v) { 106 return false 107 } 108 for key, value := range a { 109 if other, ok := v[key]; !ok || !deepEqual(value, other, seen) { 110 return false 111 } 112 } 113 return true 114 115 case *map[interface{}]interface{}: 116 v, ok := b.(*map[interface{}]interface{}) 117 if !ok || a == nil || v == nil { 118 return ok && a == v 119 } 120 return deepEqual(*a, *v, seen) 121 122 case error: 123 if sb, ok := b.(error); ok { 124 return errorStringFromError(a) == errorStringFromError(sb) 125 } 126 return false 127 128 case DeepEqualer: 129 return a.DeepEqual(b, func(x, y interface{}) bool { 130 return deepEqual(x, y, seen) 131 }) 132 133 case key.Comparable: 134 return a.Equal(b) 135 136 case time.Time: 137 bt, ok := b.(time.Time) 138 return ok && a.Equal(bt) 139 140 case proto.Message: 141 bMsg, ok := b.(proto.Message) 142 if !ok || a == nil || bMsg == nil { 143 return ok && a == bMsg 144 } 145 return proto.Equal(a, bMsg) 146 case []uint32: 147 v, ok := b.([]uint32) 148 if !ok || len(a) != len(v) { 149 return false 150 } 151 for i, s := range a { 152 if s != v[i] { 153 return false 154 } 155 } 156 return true 157 case []uint64: 158 v, ok := b.([]uint64) 159 if !ok || len(a) != len(v) { 160 return false 161 } 162 for i, s := range a { 163 if s != v[i] { 164 return false 165 } 166 } 167 return true 168 case []interface{}: 169 v, ok := b.([]interface{}) 170 if !ok || len(a) != len(v) { 171 return false 172 } 173 for i, s := range a { 174 if !deepEqual(s, v[i], seen) { 175 return false 176 } 177 } 178 return true 179 case *[]string: 180 v, ok := b.(*[]string) 181 if !ok || a == nil || v == nil { 182 return ok && a == v 183 } 184 return deepEqual(*a, *v, seen) 185 case *[]interface{}: 186 v, ok := b.(*[]interface{}) 187 if !ok || a == nil || v == nil { 188 return ok && a == v 189 } 190 return deepEqual(*a, *v, seen) 191 192 default: 193 // Handle other kinds of non-comparable objects. 194 return genericDeepEqual(a, b, seen) 195 } 196 } 197 198 func errorStringFromError(err error) string { 199 if v := reflect.ValueOf(err); v.Kind() == reflect.Ptr && v.IsNil() { 200 return "" 201 } 202 return err.Error() 203 } 204 205 type edge struct { 206 from uintptr 207 to uintptr 208 } 209 210 func genericDeepEqual(a, b interface{}, seen map[edge]struct{}) bool { 211 av := reflect.ValueOf(a) 212 bv := reflect.ValueOf(b) 213 if avalid, bvalid := av.IsValid(), bv.IsValid(); !avalid || !bvalid { 214 return avalid == bvalid 215 } 216 if bv.Type() != av.Type() { 217 return false 218 } 219 220 switch av.Kind() { 221 case reflect.Ptr: 222 if av.IsNil() || bv.IsNil() { 223 return a == b 224 } 225 226 av = av.Elem() 227 bv = bv.Elem() 228 if av.CanAddr() && bv.CanAddr() { 229 e := edge{from: av.UnsafeAddr(), to: bv.UnsafeAddr()} 230 // Detect and prevent cycles. 231 if seen == nil { 232 seen = make(map[edge]struct{}) 233 } else if _, ok := seen[e]; ok { 234 return true 235 } 236 seen[e] = struct{}{} 237 } 238 239 return deepEqual(av.Interface(), bv.Interface(), seen) 240 case reflect.Slice, reflect.Array: 241 l := av.Len() 242 if l != bv.Len() { 243 return false 244 } 245 for i := 0; i < l; i++ { 246 if !deepEqual(av.Index(i).Interface(), bv.Index(i).Interface(), seen) { 247 return false 248 } 249 } 250 return true 251 case reflect.Map: 252 if av.IsNil() != bv.IsNil() { 253 return false 254 } 255 if av.Len() != bv.Len() { 256 return false 257 } 258 if av.Pointer() == bv.Pointer() { 259 return true 260 } 261 for _, k := range av.MapKeys() { 262 // Upon finding the first key that's a pointer, we bail out and do 263 // a O(N^2) comparison. 264 if kk := k.Kind(); kk == reflect.Ptr || kk == reflect.Interface { 265 ok, _, _ := complexKeyMapEqual(av, bv, seen) 266 return ok 267 } 268 ea := av.MapIndex(k) 269 eb := bv.MapIndex(k) 270 if !eb.IsValid() { 271 return false 272 } 273 if !deepEqual(ea.Interface(), eb.Interface(), seen) { 274 return false 275 } 276 } 277 return true 278 case reflect.Struct: 279 typ := av.Type() 280 if typ.Implements(comparableType) { 281 return av.Interface().(key.Comparable).Equal(bv.Interface()) 282 } 283 for i, n := 0, av.NumField(); i < n; i++ { 284 if typ.Field(i).Tag.Get("deepequal") == "ignore" { 285 continue 286 } 287 af := areflect.ForceExport(av.Field(i)) 288 bf := areflect.ForceExport(bv.Field(i)) 289 if !deepEqual(af.Interface(), bf.Interface(), seen) { 290 return false 291 } 292 } 293 return true 294 case reflect.Func: 295 if av.IsNil() != bv.IsNil() { 296 return false 297 } 298 if av.IsNil() { 299 return true 300 } 301 // if both of these are functions, compare their full names 302 return runtime.FuncForPC(av.Pointer()).Name() == 303 runtime.FuncForPC(bv.Pointer()).Name() 304 default: 305 // Incomparable types cannot be compared with DeepEqual. 306 // To skip comparisons, tag the type with `deepequal:"ignore"` 307 if !av.Type().Comparable() || !bv.Type().Comparable() { 308 return false 309 } 310 // Other the basic types. 311 return a == b 312 } 313 } 314 315 // Compares two maps with complex keys (that are pointers). This assumes the 316 // maps have already been checked to have the same sizes. The cost of this 317 // function is O(N^2) in the size of the input maps. 318 // 319 // The return is to be interpreted this way: 320 // 321 // true, _, _ => av == bv 322 // false, key, invalid => the given key wasn't found in bv 323 // false, key, value => the given key had the given value in bv, 324 // which is different in av 325 func complexKeyMapEqual(av, bv reflect.Value, 326 seen map[edge]struct{}) (bool, reflect.Value, reflect.Value) { 327 for _, ka := range av.MapKeys() { 328 var eb reflect.Value // The entry in bv with a key equal to ka 329 for _, kb := range bv.MapKeys() { 330 if deepEqual(ka.Elem().Interface(), kb.Elem().Interface(), seen) { 331 // Found the corresponding entry in bv. 332 eb = bv.MapIndex(kb) 333 break 334 } 335 } 336 if !eb.IsValid() { // We didn't find a key equal to `ka' in 'bv'. 337 return false, ka, reflect.Value{} 338 } 339 ea := av.MapIndex(ka) 340 if !deepEqual(ea.Interface(), eb.Interface(), seen) { 341 return false, ka, eb 342 } 343 } 344 return true, reflect.Value{}, reflect.Value{} 345 } 346 347 // mapEqual does O(N^2) comparisons to check that all the keys present in the 348 // first map are also present in the second map and have identical values. 349 func mapEqual(a, b map[interface{}]interface{}) bool { 350 if len(a) != len(b) { 351 return false 352 } 353 for akey, avalue := range a { 354 found := false 355 for bkey, bvalue := range b { 356 if DeepEqual(akey, bkey) { 357 if !DeepEqual(avalue, bvalue) { 358 return false 359 } 360 found = true 361 break 362 } 363 } 364 if !found { 365 return false 366 } 367 } 368 return true 369 }