github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/test/diff.go (about)

     1  // Copyright (c) 2015 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package test
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"reflect"
    11  	"runtime"
    12  	"sort"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/aristanetworks/goarista/areflect"
    17  	"github.com/aristanetworks/goarista/key"
    18  )
    19  
    20  // diffable types have a method that returns the diff
    21  // of two objects
    22  type diffable interface {
    23  	// Diff returns a human readable string of the diff of the two objects
    24  	// an empty string means that the two objects are equal
    25  	Diff(other interface{}) string
    26  }
    27  
    28  // Diff returns the difference of two objects in a human readable format.
    29  // An empty string is returned when there is no difference.
    30  // To avoid confusing diffs, make sure you pass the expected value first.
    31  func Diff(expected, actual interface{}) string {
    32  	if DeepEqual(expected, actual) {
    33  		return ""
    34  	}
    35  
    36  	return diffImpl(expected, actual, nil)
    37  }
    38  
    39  func diffImpl(a, b interface{}, seen map[edge]struct{}) string {
    40  	av := reflect.ValueOf(a)
    41  	bv := reflect.ValueOf(b)
    42  	// Check if nil
    43  	if !av.IsValid() {
    44  		if !bv.IsValid() {
    45  			return "" // Both are "nil" with no type
    46  		}
    47  		return fmt.Sprintf("expected nil but got a %T: %#v", b, b)
    48  	} else if !bv.IsValid() {
    49  		return fmt.Sprintf("expected a %T (%#v) but got nil", a, a)
    50  	}
    51  	if av.Type() != bv.Type() {
    52  		return fmt.Sprintf("expected a %T but got a %T", a, b)
    53  	}
    54  
    55  	switch a := a.(type) {
    56  	case string, bool,
    57  		int8, int16, int32, int64,
    58  		uint8, uint16, uint32, uint64,
    59  		float32, float64,
    60  		complex64, complex128,
    61  		int, uint, uintptr:
    62  		if a != b {
    63  			typ := reflect.TypeOf(a).Name()
    64  			return fmt.Sprintf("%s(%v) != %s(%v)", typ, a, typ, b)
    65  		}
    66  		return ""
    67  	case []byte:
    68  		if !bytes.Equal(a, b.([]byte)) {
    69  			return fmt.Sprintf("[]byte(%q) != []byte(%q)", a, b)
    70  		}
    71  	}
    72  
    73  	if ac, ok := a.(diffable); ok {
    74  		return ac.Diff(b.(diffable))
    75  	}
    76  
    77  	if ad, ok := a.(DeepEqualer); ok {
    78  		if ad.DeepEqual(b.(DeepEqualer), func(x, y interface{}) bool {
    79  			return deepEqual(x, y, seen)
    80  		}) {
    81  			return ""
    82  		}
    83  		return fmt.Sprintf("DeepEqualer types are different: %v vs %v", a, b)
    84  	}
    85  
    86  	if ac, ok := a.(key.Comparable); ok {
    87  		if ac.Equal(b.(key.Comparable)) {
    88  			return ""
    89  		}
    90  		return fmt.Sprintf("Comparable types are different: %s vs %s",
    91  			PrettyPrint(a), PrettyPrint(b))
    92  	}
    93  
    94  	if at, ok := a.(time.Time); ok {
    95  		if at.Equal(b.(time.Time)) {
    96  			return ""
    97  		}
    98  		return fmt.Sprintf("time.Time values are different: %s vs %s", at, b.(time.Time))
    99  	}
   100  
   101  	switch av.Kind() {
   102  	case reflect.Array, reflect.Slice:
   103  		l := av.Len()
   104  		if l != bv.Len() {
   105  			return fmt.Sprintf("Expected an array of size %d but got %d",
   106  				l, bv.Len())
   107  		}
   108  		for i := 0; i < l; i++ {
   109  			diff := diffImpl(av.Index(i).Interface(), bv.Index(i).Interface(),
   110  				seen)
   111  			if len(diff) > 0 {
   112  				return fmt.Sprintf("In arrays, values are different at index %d: %s", i, diff)
   113  			}
   114  		}
   115  
   116  	case reflect.Map:
   117  		if c, d := isNilCheck(av, bv); c {
   118  			return d
   119  		}
   120  		if av.Len() != bv.Len() {
   121  			return fmt.Sprintf("Maps have different size: %d != %d (%s)",
   122  				av.Len(), bv.Len(), diffMapKeys(av, bv))
   123  		}
   124  		for _, ka := range av.MapKeys() {
   125  			ae := av.MapIndex(ka)
   126  			if k := ka.Kind(); k == reflect.Ptr || k == reflect.Interface {
   127  				return diffComplexKeyMap(av, bv, seen)
   128  			}
   129  			be := bv.MapIndex(ka)
   130  			if !be.IsValid() {
   131  				return fmt.Sprintf(
   132  					"key %s in map is missing in the actual map",
   133  					prettyPrint(ka, ptrSet{}, prettyPrintDepth))
   134  			}
   135  			if !ae.CanInterface() {
   136  				return fmt.Sprintf(
   137  					"for key %s in map, value can't become an interface: %s",
   138  					prettyPrint(ka, ptrSet{}, prettyPrintDepth),
   139  					prettyPrint(ae, ptrSet{}, prettyPrintDepth))
   140  			}
   141  			if !be.CanInterface() {
   142  				return fmt.Sprintf(
   143  					"for key %s in map, value can't become an interface: %s",
   144  					prettyPrint(ka, ptrSet{}, prettyPrintDepth),
   145  					prettyPrint(be, ptrSet{}, prettyPrintDepth))
   146  			}
   147  			if diff := diffImpl(ae.Interface(), be.Interface(), seen); len(diff) > 0 {
   148  				return fmt.Sprintf(
   149  					"for key %s in map, values are different: %s",
   150  					prettyPrint(ka, ptrSet{}, prettyPrintDepth), diff)
   151  			}
   152  		}
   153  
   154  	case reflect.Ptr, reflect.Interface:
   155  		if c, d := isNilCheck(av, bv); c {
   156  			return d
   157  		}
   158  		av = av.Elem()
   159  		bv = bv.Elem()
   160  
   161  		if av.CanAddr() && bv.CanAddr() {
   162  			e := edge{from: av.UnsafeAddr(), to: bv.UnsafeAddr()}
   163  			// Detect and prevent cycles.
   164  			if seen == nil {
   165  				seen = make(map[edge]struct{})
   166  			} else if _, ok := seen[e]; ok {
   167  				return ""
   168  			}
   169  			seen[e] = struct{}{}
   170  		}
   171  		return diffImpl(av.Interface(), bv.Interface(), seen)
   172  
   173  	case reflect.Struct:
   174  		typ := av.Type()
   175  		for i, n := 0, av.NumField(); i < n; i++ {
   176  			if typ.Field(i).Tag.Get("deepequal") == "ignore" {
   177  				continue
   178  			}
   179  			af := areflect.ForceExport(av.Field(i))
   180  			bf := areflect.ForceExport(bv.Field(i))
   181  			if diff := diffImpl(af.Interface(), bf.Interface(), seen); len(diff) > 0 {
   182  				return fmt.Sprintf("attributes %q are different: %s",
   183  					av.Type().Field(i).Name, diff)
   184  			}
   185  		}
   186  
   187  		// The following cases are here to handle named types (aka type aliases).
   188  	case reflect.String:
   189  		if as, bs := av.String(), bv.String(); as != bs {
   190  			return fmt.Sprintf("%s(%q) != %s(%q)", av.Type().Name(), as, bv.Type().Name(), bs)
   191  		}
   192  	case reflect.Bool:
   193  		if ab, bb := av.Bool(), bv.Bool(); ab != bb {
   194  			return fmt.Sprintf("%s(%t) != %s(%t)", av.Type().Name(), ab, bv.Type().Name(), bb)
   195  		}
   196  	case reflect.Uint, reflect.Uintptr,
   197  		reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   198  		if ai, bi := av.Uint(), bv.Uint(); ai != bi {
   199  			return fmt.Sprintf("%s(%d) != %s(%d)", av.Type().Name(), ai, bv.Type().Name(), bi)
   200  		}
   201  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   202  		if ai, bi := av.Int(), bv.Int(); ai != bi {
   203  			return fmt.Sprintf("%s(%d) != %s(%d)", av.Type().Name(), ai, bv.Type().Name(), bi)
   204  		}
   205  	case reflect.Float32, reflect.Float64:
   206  		if af, bf := av.Float(), bv.Float(); af != bf {
   207  			return fmt.Sprintf("%s(%f) != %s(%f)", av.Type().Name(), af, bv.Type().Name(), bf)
   208  		}
   209  	case reflect.Complex64, reflect.Complex128:
   210  		if ac, bc := av.Complex(), bv.Complex(); ac != bc {
   211  			return fmt.Sprintf("%s(%f) != %s(%f)", av.Type().Name(), ac, bv.Type().Name(), bc)
   212  		}
   213  	case reflect.Func:
   214  		return fmt.Sprintf("type %T: %#[1]v with name %q cannot"+
   215  			" be compared to %#[3]v with name %q, functions must be exactly equal or nil",
   216  			a, runtime.FuncForPC(av.Pointer()).Name(), b, runtime.FuncForPC(bv.Pointer()).Name())
   217  	default:
   218  		return fmt.Sprintf("Unknown or unsupported type: %T: %#[1]v", a)
   219  
   220  	}
   221  
   222  	return ""
   223  }
   224  
   225  func diffComplexKeyMap(av, bv reflect.Value, seen map[edge]struct{}) string {
   226  	ok, ka, be := complexKeyMapEqual(av, bv, seen)
   227  	if ok {
   228  		return ""
   229  	} else if be.IsValid() {
   230  		return fmt.Sprintf("for complex key %s in map, values are different: %s",
   231  			prettyPrint(ka, ptrSet{}, prettyPrintDepth),
   232  			diffImpl(av.MapIndex(ka).Interface(), be.Interface(), seen))
   233  	}
   234  	return fmt.Sprintf("complex key %s in map is missing in the actual map",
   235  		prettyPrint(ka, ptrSet{}, prettyPrintDepth))
   236  }
   237  
   238  func diffMapKeys(av, bv reflect.Value) string {
   239  	var diffs []string
   240  	// TODO: We produce extraneous diffs for composite keys.
   241  	for _, ka := range av.MapKeys() {
   242  		be := bv.MapIndex(ka)
   243  		if !be.IsValid() {
   244  			diffs = append(diffs, fmt.Sprintf("missing key: %s",
   245  				PrettyPrint(ka.Interface())))
   246  		}
   247  	}
   248  	for _, kb := range bv.MapKeys() {
   249  		ae := av.MapIndex(kb)
   250  		if !ae.IsValid() {
   251  			diffs = append(diffs, fmt.Sprintf("extra key: %s",
   252  				PrettyPrint(kb.Interface())))
   253  		}
   254  	}
   255  	sort.Strings(diffs)
   256  	return strings.Join(diffs, ", ")
   257  }
   258  
   259  func isNilCheck(a, b reflect.Value) (bool /*checked*/, string) {
   260  	if a.IsNil() {
   261  		if b.IsNil() {
   262  			return true, ""
   263  		}
   264  		return true, fmt.Sprintf("expected nil but got %s",
   265  			prettyPrint(b, ptrSet{}, prettyPrintDepth))
   266  	} else if b.IsNil() {
   267  		return true, fmt.Sprintf("got nil instead of %s",
   268  			prettyPrint(a, ptrSet{}, prettyPrintDepth))
   269  	}
   270  	return false, ""
   271  }
   272  
   273  type mapEntry struct {
   274  	k, v string
   275  }
   276  
   277  type mapEntries struct {
   278  	entries []*mapEntry
   279  }
   280  
   281  func (t *mapEntries) Len() int {
   282  	return len(t.entries)
   283  }
   284  func (t *mapEntries) Less(i, j int) bool {
   285  	if t.entries[i].k > t.entries[j].k {
   286  		return false
   287  	} else if t.entries[i].k < t.entries[j].k {
   288  		return true
   289  	}
   290  	return t.entries[i].v <= t.entries[j].v
   291  }
   292  func (t *mapEntries) Swap(i, j int) {
   293  	t.entries[i], t.entries[j] = t.entries[j], t.entries[i]
   294  }