github.com/psiphon-labs/goarista@v0.0.0-20160825065156-d002785f4c67/test/deepequal.go (about) 1 // Copyright (C) 2014 Arista Networks, Inc. 2 // Use of this source code is governed by the Apache License 2.0 3 // that can be found in the COPYING file. 4 5 package test 6 7 import ( 8 "bytes" 9 "math" 10 "reflect" 11 12 "github.com/aristanetworks/goarista/areflect" 13 "github.com/aristanetworks/goarista/key" 14 ) 15 16 var comparableType = reflect.TypeOf((*key.Comparable)(nil)).Elem() 17 18 // DeepEqual is a faster implementation of reflect.DeepEqual that: 19 // - Has a reflection-free fast-path for all the common types we use. 20 // - Gives data types the ability to exclude some of their fields from the 21 // consideration of DeepEqual by tagging them with `deepequal:"ignore"`. 22 // - Gives data types the ability to define their own comparison method by 23 // implementing the comparable interface. 24 // - Supports "composite" (or "complex") keys in maps that are pointers. 25 func DeepEqual(a, b interface{}) bool { 26 return deepEqual(a, b, nil) 27 } 28 29 func deepEqual(a, b interface{}, seen map[edge]struct{}) bool { 30 if a == nil || b == nil { 31 return a == b 32 } 33 switch a := a.(type) { 34 // Short circuit fast-path for common built-in types. 35 // Note: the cases are listed by frequency. 36 case bool: 37 return a == b 38 39 case map[string]interface{}: 40 v, ok := b.(map[string]interface{}) 41 if !ok || len(a) != len(v) { 42 return false 43 } 44 for key, value := range a { 45 if other, ok := v[key]; !ok || !deepEqual(value, other, seen) { 46 return false 47 } 48 } 49 return true 50 51 case string, uint32, uint64, int32, 52 uint16, int16, uint8, int8, int64: 53 return a == b 54 55 case *map[string]interface{}: 56 v, ok := b.(*map[string]interface{}) 57 if !ok || a == nil || v == nil { 58 return ok && a == v 59 } 60 return deepEqual(*a, *v, seen) 61 62 case map[interface{}]interface{}: 63 v, ok := b.(map[interface{}]interface{}) 64 if !ok { 65 return false 66 } 67 // We compare in both directions to catch keys that are in b but not 68 // in a. It sucks to have to do another O(N^2) for this, but oh well. 69 return mapEqual(a, v) && mapEqual(v, a) 70 71 case float32: 72 v, ok := b.(float32) 73 return ok && (a == b || (math.IsNaN(float64(a)) && math.IsNaN(float64(v)))) 74 case float64: 75 v, ok := b.(float64) 76 return ok && (a == b || (math.IsNaN(a) && math.IsNaN(v))) 77 78 case []string: 79 v, ok := b.([]string) 80 if !ok || len(a) != len(v) { 81 return false 82 } 83 for i, s := range a { 84 if s != v[i] { 85 return false 86 } 87 } 88 return true 89 case []byte: 90 v, ok := b.([]byte) 91 return ok && bytes.Equal(a, v) 92 93 case map[uint64]interface{}: 94 v, ok := b.(map[uint64]interface{}) 95 if !ok || len(a) != len(v) { 96 return false 97 } 98 for key, value := range a { 99 if other, ok := v[key]; !ok || !deepEqual(value, other, seen) { 100 return false 101 } 102 } 103 return true 104 105 case *map[interface{}]interface{}: 106 v, ok := b.(*map[interface{}]interface{}) 107 if !ok || a == nil || v == nil { 108 return ok && a == v 109 } 110 return deepEqual(*a, *v, seen) 111 case key.Comparable: 112 return a.Equal(b) 113 114 case []uint32: 115 v, ok := b.([]uint32) 116 if !ok || len(a) != len(v) { 117 return false 118 } 119 for i, s := range a { 120 if s != v[i] { 121 return false 122 } 123 } 124 return true 125 case []uint64: 126 v, ok := b.([]uint64) 127 if !ok || len(a) != len(v) { 128 return false 129 } 130 for i, s := range a { 131 if s != v[i] { 132 return false 133 } 134 } 135 return true 136 case []interface{}: 137 v, ok := b.([]interface{}) 138 if !ok || len(a) != len(v) { 139 return false 140 } 141 for i, s := range a { 142 if !deepEqual(s, v[i], seen) { 143 return false 144 } 145 } 146 return true 147 case *[]string: 148 v, ok := b.(*[]string) 149 if !ok || a == nil || v == nil { 150 return ok && a == v 151 } 152 return deepEqual(*a, *v, seen) 153 case *[]interface{}: 154 v, ok := b.(*[]interface{}) 155 if !ok || a == nil || v == nil { 156 return ok && a == v 157 } 158 return deepEqual(*a, *v, seen) 159 160 default: 161 // Handle other kinds of non-comparable objects. 162 return genericDeepEqual(a, b, seen) 163 } 164 } 165 166 type edge struct { 167 from uintptr 168 to uintptr 169 } 170 171 func genericDeepEqual(a, b interface{}, seen map[edge]struct{}) bool { 172 av := reflect.ValueOf(a) 173 bv := reflect.ValueOf(b) 174 if avalid, bvalid := av.IsValid(), bv.IsValid(); !avalid || !bvalid { 175 return avalid == bvalid 176 } 177 if bv.Type() != av.Type() { 178 return false 179 } 180 181 switch av.Kind() { 182 case reflect.Ptr: 183 if av.IsNil() || bv.IsNil() { 184 return a == b 185 } 186 187 av = av.Elem() 188 bv = bv.Elem() 189 if av.CanAddr() && bv.CanAddr() { 190 e := edge{from: av.UnsafeAddr(), to: bv.UnsafeAddr()} 191 // Detect and prevent cycles. 192 if seen == nil { 193 seen = make(map[edge]struct{}) 194 } else if _, ok := seen[e]; ok { 195 return true 196 } 197 seen[e] = struct{}{} 198 } 199 200 return deepEqual(av.Interface(), bv.Interface(), seen) 201 case reflect.Slice, reflect.Array: 202 l := av.Len() 203 if l != bv.Len() { 204 return false 205 } 206 for i := 0; i < l; i++ { 207 if !deepEqual(av.Index(i).Interface(), bv.Index(i).Interface(), seen) { 208 return false 209 } 210 } 211 return true 212 case reflect.Map: 213 if av.IsNil() != bv.IsNil() { 214 return false 215 } 216 if av.Len() != bv.Len() { 217 return false 218 } 219 if av.Pointer() == bv.Pointer() { 220 return true 221 } 222 for _, k := range av.MapKeys() { 223 // Upon finding the first key that's a pointer, we bail out and do 224 // a O(N^2) comparison. 225 if kk := k.Kind(); kk == reflect.Ptr || kk == reflect.Interface { 226 ok, _, _ := complexKeyMapEqual(av, bv, seen) 227 return ok 228 } 229 ea := av.MapIndex(k) 230 eb := bv.MapIndex(k) 231 if !eb.IsValid() { 232 return false 233 } 234 if !deepEqual(ea.Interface(), eb.Interface(), seen) { 235 return false 236 } 237 } 238 return true 239 case reflect.Struct: 240 typ := av.Type() 241 if typ.Implements(comparableType) { 242 return av.Interface().(key.Comparable).Equal(bv.Interface()) 243 } 244 for i, n := 0, av.NumField(); i < n; i++ { 245 if typ.Field(i).Tag.Get("deepequal") == "ignore" { 246 continue 247 } 248 af := areflect.ForceExport(av.Field(i)) 249 bf := areflect.ForceExport(bv.Field(i)) 250 if !deepEqual(af.Interface(), bf.Interface(), seen) { 251 return false 252 } 253 } 254 return true 255 default: 256 // Other the basic types. 257 return a == b 258 } 259 } 260 261 // Compares two maps with complex keys (that are pointers). This assumes the 262 // maps have already been checked to have the same sizes. The cost of this 263 // function is O(N^2) in the size of the input maps. 264 // 265 // The return is to be interpreted this way: 266 // true, _, _ => av == bv 267 // false, key, invalid => the given key wasn't found in bv 268 // false, key, value => the given key had the given value in bv, 269 // which is different in av 270 func complexKeyMapEqual(av, bv reflect.Value, 271 seen map[edge]struct{}) (bool, reflect.Value, reflect.Value) { 272 for _, ka := range av.MapKeys() { 273 var eb reflect.Value // The entry in bv with a key equal to ka 274 for _, kb := range bv.MapKeys() { 275 if deepEqual(ka.Elem().Interface(), kb.Elem().Interface(), seen) { 276 // Found the corresponding entry in bv. 277 eb = bv.MapIndex(kb) 278 break 279 } 280 } 281 if !eb.IsValid() { // We didn't find a key equal to `ka' in 'bv'. 282 return false, ka, reflect.Value{} 283 } 284 ea := av.MapIndex(ka) 285 if !deepEqual(ea.Interface(), eb.Interface(), seen) { 286 return false, ka, eb 287 } 288 } 289 return true, reflect.Value{}, reflect.Value{} 290 } 291 292 // mapEqual does O(N^2) comparisons to check that all the keys present in the 293 // first map are also present in the second map and have identical values. 294 func mapEqual(a, b map[interface{}]interface{}) bool { 295 if len(a) != len(b) { 296 return false 297 } 298 for akey, avalue := range a { 299 found := false 300 for bkey, bvalue := range b { 301 if DeepEqual(akey, bkey) { 302 if !DeepEqual(avalue, bvalue) { 303 return false 304 } 305 found = true 306 break 307 } 308 } 309 if !found { 310 return false 311 } 312 } 313 return true 314 }