github.com/ecadlabs/pretty@v0.0.0-20230412123216-0f3d25fb750b/diff.go (about)

     1  package pretty
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"reflect"
     7  )
     8  
     9  type sbuf []string
    10  
    11  func (p *sbuf) Printf(format string, a ...interface{}) {
    12  	s := fmt.Sprintf(format, a...)
    13  	*p = append(*p, s)
    14  }
    15  
    16  // Diff returns a slice where each element describes
    17  // a difference between a and b.
    18  func Diff(a, b interface{}) (desc []string) {
    19  	Pdiff((*sbuf)(&desc), a, b)
    20  	return desc
    21  }
    22  
    23  // wprintfer calls Fprintf on w for each Printf call
    24  // with a trailing newline.
    25  type wprintfer struct{ w io.Writer }
    26  
    27  func (p *wprintfer) Printf(format string, a ...interface{}) {
    28  	fmt.Fprintf(p.w, format+"\n", a...)
    29  }
    30  
    31  // Fdiff writes to w a description of the differences between a and b.
    32  func Fdiff(w io.Writer, a, b interface{}) {
    33  	Pdiff(&wprintfer{w}, a, b)
    34  }
    35  
    36  type Printfer interface {
    37  	Printf(format string, a ...interface{})
    38  }
    39  
    40  // Pdiff prints to p a description of the differences between a and b.
    41  // It calls Printf once for each difference, with no trailing newline.
    42  // The standard library log.Logger is a Printfer.
    43  func Pdiff(p Printfer, a, b interface{}) {
    44  	d := diffPrinter{
    45  		w:        p,
    46  		aVisited: make(map[visit]visit),
    47  		bVisited: make(map[visit]visit),
    48  	}
    49  	d.diff(reflect.ValueOf(a), reflect.ValueOf(b))
    50  }
    51  
    52  type Logfer interface {
    53  	Logf(format string, a ...interface{})
    54  }
    55  
    56  // logprintfer calls Fprintf on w for each Printf call
    57  // with a trailing newline.
    58  type logprintfer struct{ l Logfer }
    59  
    60  func (p *logprintfer) Printf(format string, a ...interface{}) {
    61  	p.l.Logf(format, a...)
    62  }
    63  
    64  // Ldiff prints to l a description of the differences between a and b.
    65  // It calls Logf once for each difference, with no trailing newline.
    66  // The standard library testing.T and testing.B are Logfers.
    67  func Ldiff(l Logfer, a, b interface{}) {
    68  	Pdiff(&logprintfer{l}, a, b)
    69  }
    70  
    71  type diffPrinter struct {
    72  	w Printfer
    73  	l string // label
    74  
    75  	aVisited map[visit]visit
    76  	bVisited map[visit]visit
    77  }
    78  
    79  func (w diffPrinter) printf(f string, a ...interface{}) {
    80  	var l string
    81  	if w.l != "" {
    82  		l = w.l + ": "
    83  	}
    84  	w.w.Printf(l+f, a...)
    85  }
    86  
    87  func newFormatter(v reflect.Value, quote bool) formatter {
    88  	return formatter{v: v, quote: quote, opt: defaultOptions}
    89  }
    90  
    91  func (w diffPrinter) diff(av, bv reflect.Value) {
    92  	if !av.IsValid() && bv.IsValid() {
    93  		w.printf("nil != %# v", newFormatter(bv, true))
    94  		return
    95  	}
    96  	if av.IsValid() && !bv.IsValid() {
    97  		w.printf("%# v != nil", newFormatter(av, true))
    98  		return
    99  	}
   100  	if !av.IsValid() && !bv.IsValid() {
   101  		return
   102  	}
   103  
   104  	at := av.Type()
   105  	bt := bv.Type()
   106  	if at != bt {
   107  		w.printf("%v != %v", at, bt)
   108  		return
   109  	}
   110  
   111  	if av.CanAddr() && bv.CanAddr() {
   112  		avis := visit{av.UnsafeAddr(), at}
   113  		bvis := visit{bv.UnsafeAddr(), bt}
   114  		var cycle bool
   115  
   116  		// Have we seen this value before?
   117  		if vis, ok := w.aVisited[avis]; ok {
   118  			cycle = true
   119  			if vis != bvis {
   120  				w.printf("%# v (previously visited) != %# v", newFormatter(av, true), newFormatter(bv, true))
   121  			}
   122  		} else if _, ok := w.bVisited[bvis]; ok {
   123  			cycle = true
   124  			w.printf("%# v != %# v (previously visited)", newFormatter(av, true), newFormatter(bv, true))
   125  		}
   126  		w.aVisited[avis] = bvis
   127  		w.bVisited[bvis] = avis
   128  		if cycle {
   129  			return
   130  		}
   131  	}
   132  
   133  	switch kind := at.Kind(); kind {
   134  	case reflect.Bool:
   135  		if a, b := av.Bool(), bv.Bool(); a != b {
   136  			w.printf("%v != %v", a, b)
   137  		}
   138  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   139  		if a, b := av.Int(), bv.Int(); a != b {
   140  			w.printf("%d != %d", a, b)
   141  		}
   142  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   143  		if a, b := av.Uint(), bv.Uint(); a != b {
   144  			w.printf("%d != %d", a, b)
   145  		}
   146  	case reflect.Float32, reflect.Float64:
   147  		if a, b := av.Float(), bv.Float(); a != b {
   148  			w.printf("%v != %v", a, b)
   149  		}
   150  	case reflect.Complex64, reflect.Complex128:
   151  		if a, b := av.Complex(), bv.Complex(); a != b {
   152  			w.printf("%v != %v", a, b)
   153  		}
   154  	case reflect.Array:
   155  		n := av.Len()
   156  		for i := 0; i < n; i++ {
   157  			w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
   158  		}
   159  	case reflect.Chan, reflect.Func, reflect.UnsafePointer:
   160  		if a, b := av.Pointer(), bv.Pointer(); a != b {
   161  			w.printf("%#x != %#x", a, b)
   162  		}
   163  	case reflect.Interface:
   164  		w.diff(av.Elem(), bv.Elem())
   165  	case reflect.Map:
   166  		ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys())
   167  		for _, k := range ak {
   168  			w := w.relabel(fmt.Sprintf("[%#v]", k))
   169  			w.printf("%q != (missing)", av.MapIndex(k))
   170  		}
   171  		for _, k := range both {
   172  			w := w.relabel(fmt.Sprintf("[%#v]", k))
   173  			w.diff(av.MapIndex(k), bv.MapIndex(k))
   174  		}
   175  		for _, k := range bk {
   176  			w := w.relabel(fmt.Sprintf("[%#v]", k))
   177  			w.printf("(missing) != %q", bv.MapIndex(k))
   178  		}
   179  	case reflect.Ptr:
   180  		switch {
   181  		case av.IsNil() && !bv.IsNil():
   182  			w.printf("nil != %# v", newFormatter(bv, true))
   183  		case !av.IsNil() && bv.IsNil():
   184  			w.printf("%# v != nil", newFormatter(av, true))
   185  		case !av.IsNil() && !bv.IsNil():
   186  			w.diff(av.Elem(), bv.Elem())
   187  		}
   188  	case reflect.Slice:
   189  		lenA := av.Len()
   190  		lenB := bv.Len()
   191  		if lenA != lenB {
   192  			w.printf("%s[%d] != %s[%d]", av.Type(), lenA, bv.Type(), lenB)
   193  			break
   194  		}
   195  		for i := 0; i < lenA; i++ {
   196  			w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
   197  		}
   198  	case reflect.String:
   199  		if a, b := av.String(), bv.String(); a != b {
   200  			w.printf("%q != %q", a, b)
   201  		}
   202  	case reflect.Struct:
   203  		for i := 0; i < av.NumField(); i++ {
   204  			w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i))
   205  		}
   206  	default:
   207  		panic("unknown reflect Kind: " + kind.String())
   208  	}
   209  }
   210  
   211  func (d diffPrinter) relabel(name string) (d1 diffPrinter) {
   212  	d1 = d
   213  	if d.l != "" && name[0] != '[' {
   214  		d1.l += "."
   215  	}
   216  	d1.l += name
   217  	return d1
   218  }
   219  
   220  // keyEqual compares a and b for equality.
   221  // Both a and b must be valid map keys.
   222  func keyEqual(av, bv reflect.Value) bool {
   223  	if !av.IsValid() && !bv.IsValid() {
   224  		return true
   225  	}
   226  	if !av.IsValid() || !bv.IsValid() || av.Type() != bv.Type() {
   227  		return false
   228  	}
   229  	switch kind := av.Kind(); kind {
   230  	case reflect.Bool:
   231  		a, b := av.Bool(), bv.Bool()
   232  		return a == b
   233  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   234  		a, b := av.Int(), bv.Int()
   235  		return a == b
   236  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   237  		a, b := av.Uint(), bv.Uint()
   238  		return a == b
   239  	case reflect.Float32, reflect.Float64:
   240  		a, b := av.Float(), bv.Float()
   241  		return a == b
   242  	case reflect.Complex64, reflect.Complex128:
   243  		a, b := av.Complex(), bv.Complex()
   244  		return a == b
   245  	case reflect.Array:
   246  		for i := 0; i < av.Len(); i++ {
   247  			if !keyEqual(av.Index(i), bv.Index(i)) {
   248  				return false
   249  			}
   250  		}
   251  		return true
   252  	case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
   253  		a, b := av.Pointer(), bv.Pointer()
   254  		return a == b
   255  	case reflect.Interface:
   256  		return keyEqual(av.Elem(), bv.Elem())
   257  	case reflect.String:
   258  		a, b := av.String(), bv.String()
   259  		return a == b
   260  	case reflect.Struct:
   261  		for i := 0; i < av.NumField(); i++ {
   262  			if !keyEqual(av.Field(i), bv.Field(i)) {
   263  				return false
   264  			}
   265  		}
   266  		return true
   267  	default:
   268  		panic("invalid map key type " + av.Type().String())
   269  	}
   270  }
   271  
   272  func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) {
   273  	for _, av := range a {
   274  		inBoth := false
   275  		for _, bv := range b {
   276  			if keyEqual(av, bv) {
   277  				inBoth = true
   278  				both = append(both, av)
   279  				break
   280  			}
   281  		}
   282  		if !inBoth {
   283  			ak = append(ak, av)
   284  		}
   285  	}
   286  	for _, bv := range b {
   287  		inBoth := false
   288  		for _, av := range a {
   289  			if keyEqual(av, bv) {
   290  				inBoth = true
   291  				break
   292  			}
   293  		}
   294  		if !inBoth {
   295  			bk = append(bk, bv)
   296  		}
   297  	}
   298  	return
   299  }