launchpad.net/~rogpeppe/juju-core/500-errgo-fix@v0.0.0-20140213181702-000000002356/testing/checkers/deepequal.go (about)

     1  // Copied with small adaptations from the reflect package in the
     2  // Go source tree.
     3  
     4  // Copyright 2009 The Go Authors. All rights reserved.
     5  // Use of this source code is governed by a BSD-style
     6  // license that can be found in the LICENSE file.
     7  
     8  package checkers
     9  
    10  import (
    11  	"fmt"
    12  	"reflect"
    13  	"unsafe"
    14  )
    15  
    16  // During deepValueEqual, must keep track of checks that are
    17  // in progress.  The comparison algorithm assumes that all
    18  // checks in progress are true when it reencounters them.
    19  // Visited comparisons are stored in a map indexed by visit.
    20  type visit struct {
    21  	a1  uintptr
    22  	a2  uintptr
    23  	typ reflect.Type
    24  }
    25  
    26  type mismatchError struct {
    27  	v1, v2 reflect.Value
    28  	path   string
    29  	how    string
    30  }
    31  
    32  func (err *mismatchError) Error() string {
    33  	path := err.path
    34  	if path == "" {
    35  		path = "top level"
    36  	}
    37  	return fmt.Sprintf("mismatch at %s: %s; obtained %#v; expected %#v", path, err.how, interfaceOf(err.v1), interfaceOf(err.v2))
    38  }
    39  
    40  // Tests for deep equality using reflected types. The map argument tracks
    41  // comparisons that have already been seen, which allows short circuiting on
    42  // recursive types.
    43  func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int) (ok bool, err error) {
    44  	errorf := func(f string, a ...interface{}) error {
    45  		return &mismatchError{
    46  			v1:   v1,
    47  			v2:   v2,
    48  			path: path,
    49  			how:  fmt.Sprintf(f, a...),
    50  		}
    51  	}
    52  	if !v1.IsValid() || !v2.IsValid() {
    53  		if v1.IsValid() == v2.IsValid() {
    54  			return true, nil
    55  		}
    56  		return false, errorf("validity mismatch")
    57  	}
    58  	if v1.Type() != v2.Type() {
    59  		return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
    60  	}
    61  
    62  	// if depth > 10 { panic("deepValueEqual") }	// for debugging
    63  	hard := func(k reflect.Kind) bool {
    64  		switch k {
    65  		case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
    66  			return true
    67  		}
    68  		return false
    69  	}
    70  
    71  	if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
    72  		addr1 := v1.UnsafeAddr()
    73  		addr2 := v2.UnsafeAddr()
    74  		if addr1 > addr2 {
    75  			// Canonicalize order to reduce number of entries in visited.
    76  			addr1, addr2 = addr2, addr1
    77  		}
    78  
    79  		// Short circuit if references are identical ...
    80  		if addr1 == addr2 {
    81  			return true, nil
    82  		}
    83  
    84  		// ... or already seen
    85  		typ := v1.Type()
    86  		v := visit{addr1, addr2, typ}
    87  		if visited[v] {
    88  			return true, nil
    89  		}
    90  
    91  		// Remember for later.
    92  		visited[v] = true
    93  	}
    94  
    95  	switch v1.Kind() {
    96  	case reflect.Array:
    97  		if v1.Len() != v2.Len() {
    98  			// can't happen!
    99  			return false, errorf("length mismatch, %d vs %d", v1.Len(), v2.Len())
   100  		}
   101  		for i := 0; i < v1.Len(); i++ {
   102  			if ok, err := deepValueEqual(
   103  				fmt.Sprintf("%s[%d]", path, i),
   104  				v1.Index(i), v2.Index(i), visited, depth+1); !ok {
   105  				return false, err
   106  			}
   107  		}
   108  		return true, nil
   109  	case reflect.Slice:
   110  		// We treat a nil slice the same as an empty slice.
   111  		if v1.Len() != v2.Len() {
   112  			return false, errorf("length mismatch, %d vs %d", v1.Len(), v2.Len())
   113  		}
   114  		if v1.Pointer() == v2.Pointer() {
   115  			return true, nil
   116  		}
   117  		for i := 0; i < v1.Len(); i++ {
   118  			if ok, err := deepValueEqual(
   119  				fmt.Sprintf("%s[%d]", path, i),
   120  				v1.Index(i), v2.Index(i), visited, depth+1); !ok {
   121  				return false, err
   122  			}
   123  		}
   124  		return true, nil
   125  	case reflect.Interface:
   126  		if v1.IsNil() || v2.IsNil() {
   127  			if v1.IsNil() != v2.IsNil() {
   128  				return false, fmt.Errorf("nil vs non-nil interface mismatch")
   129  			}
   130  			return true, nil
   131  		}
   132  		return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1)
   133  	case reflect.Ptr:
   134  		return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1)
   135  	case reflect.Struct:
   136  		for i, n := 0, v1.NumField(); i < n; i++ {
   137  			path := path + "." + v1.Type().Field(i).Name
   138  			if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1); !ok {
   139  				return false, err
   140  			}
   141  		}
   142  		return true, nil
   143  	case reflect.Map:
   144  		if v1.IsNil() != v2.IsNil() {
   145  			return false, errorf("nil vs non-nil mismatch")
   146  		}
   147  		if v1.Len() != v2.Len() {
   148  			return false, errorf("length mismatch, %d vs %d", v1.Len(), v2.Len())
   149  		}
   150  		if v1.Pointer() == v2.Pointer() {
   151  			return true, nil
   152  		}
   153  		for _, k := range v1.MapKeys() {
   154  			var p string
   155  			if k.CanInterface() {
   156  				p = path + "[" + fmt.Sprintf("%#v", k.Interface()) + "]"
   157  			} else {
   158  				p = path + "[someKey]"
   159  			}
   160  			if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1); !ok {
   161  				return false, err
   162  			}
   163  		}
   164  		return true, nil
   165  	case reflect.Func:
   166  		if v1.IsNil() && v2.IsNil() {
   167  			return true, nil
   168  		}
   169  		// Can't do better than this:
   170  		return false, errorf("non-nil functions")
   171  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   172  		if v1.Int() != v2.Int() {
   173  			return false, errorf("unequal")
   174  		}
   175  		return true, nil
   176  	case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   177  		if v1.Uint() != v2.Uint() {
   178  			return false, errorf("unequal")
   179  		}
   180  		return true, nil
   181  	case reflect.Float32, reflect.Float64:
   182  		if v1.Float() != v2.Float() {
   183  			return false, errorf("unequal")
   184  		}
   185  		return true, nil
   186  	case reflect.Complex64, reflect.Complex128:
   187  		if v1.Complex() != v2.Complex() {
   188  			return false, errorf("unequal")
   189  		}
   190  		return true, nil
   191  	case reflect.Bool:
   192  		if v1.Bool() != v2.Bool() {
   193  			return false, errorf("unequal")
   194  		}
   195  		return true, nil
   196  	case reflect.String:
   197  		if v1.String() != v2.String() {
   198  			return false, errorf("unequal")
   199  		}
   200  		return true, nil
   201  	case reflect.Chan, reflect.UnsafePointer:
   202  		if v1.Pointer() != v2.Pointer() {
   203  			return false, errorf("unequal")
   204  		}
   205  		return true, nil
   206  	default:
   207  		panic("unexpected type " + v1.Type().String())
   208  	}
   209  }
   210  
   211  // DeepEqual tests for deep equality. It uses normal == equality where
   212  // possible but will scan elements of arrays, slices, maps, and fields
   213  // of structs. In maps, keys are compared with == but elements use deep
   214  // equality. DeepEqual correctly handles recursive types. Functions are
   215  // equal only if they are both nil.
   216  //
   217  // DeepEqual differs from reflect.DeepEqual in that an empty slice is
   218  // equal to a nil slice. If the two values compare unequal, the
   219  // resulting error holds the first difference encountered.
   220  func DeepEqual(a1, a2 interface{}) (bool, error) {
   221  	errorf := func(f string, a ...interface{}) error {
   222  		return &mismatchError{
   223  			v1:   reflect.ValueOf(a1),
   224  			v2:   reflect.ValueOf(a2),
   225  			path: "",
   226  			how:  fmt.Sprintf(f, a...),
   227  		}
   228  	}
   229  	if a1 == nil || a2 == nil {
   230  		if a1 == a2 {
   231  			return true, nil
   232  		}
   233  		return false, errorf("nil vs non-nil mismatch")
   234  	}
   235  	v1 := reflect.ValueOf(a1)
   236  	v2 := reflect.ValueOf(a2)
   237  	if v1.Type() != v2.Type() {
   238  		return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
   239  	}
   240  	return deepValueEqual("", v1, v2, make(map[visit]bool), 0)
   241  }
   242  
   243  // interfaceOf returns v.Interface() even if v.CanInterface() == false.
   244  // This enables us to call fmt.Printf on a value even if it's derived
   245  // from inside an unexported field.
   246  func interfaceOf(v reflect.Value) interface{} {
   247  	if !v.IsValid() {
   248  		return nil
   249  	}
   250  	return bypassCanInterface(v).Interface()
   251  }
   252  
   253  type flag uintptr
   254  
   255  // copied from reflect/value.go
   256  const (
   257  	flagRO flag = 1 << iota
   258  )
   259  
   260  var flagValOffset = func() uintptr {
   261  	field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
   262  	if !ok {
   263  		panic("reflect.Value has no flag field")
   264  	}
   265  	return field.Offset
   266  }()
   267  
   268  func flagField(v *reflect.Value) *flag {
   269  	return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset))
   270  }
   271  
   272  // bypassCanInterface returns a version of v that
   273  // bypasses the CanInterface check.
   274  func bypassCanInterface(v reflect.Value) reflect.Value {
   275  	if !v.IsValid() || v.CanInterface() {
   276  		return v
   277  	}
   278  	*flagField(&v) &^= flagRO
   279  	return v
   280  }
   281  
   282  // Sanity checks against future reflect package changes
   283  // to the type or semantics of the Value.flag field.
   284  func init() {
   285  	field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
   286  	if !ok {
   287  		panic("reflect.Value has no flag field")
   288  	}
   289  	if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() {
   290  		panic("reflect.Value flag field has changed kind")
   291  	}
   292  	var t struct {
   293  		a int
   294  		A int
   295  	}
   296  	vA := reflect.ValueOf(t).FieldByName("A")
   297  	va := reflect.ValueOf(t).FieldByName("a")
   298  	flagA := *flagField(&vA)
   299  	flaga := *flagField(&va)
   300  	if flagA&flagRO != 0 || flaga&flagRO == 0 {
   301  		panic("reflect.Value read-only flag has changed value")
   302  	}
   303  }