github.com/motomux/pretty@v0.0.0-20161209205251-b2aad2c9a95d/diff.go (about)

     1  package pretty
     2  
import (
	"fmt"
	"io"
	"reflect"
	"strings"
)
     8  
     9  type sbuf []string
    10  
    11  func (p *sbuf) Printf(format string, a ...interface{}) {
    12  	s := fmt.Sprintf(format, a...)
    13  	*p = append(*p, s)
    14  }
    15  
    16  // Diff returns a slice where each element describes
    17  // a difference between a and b.
    18  func Diff(a, b interface{}) (desc []string) {
    19  	Pdiff((*sbuf)(&desc), a, b)
    20  	return desc
    21  }
    22  
    23  // wprintfer calls Fprintf on w for each Printf call
    24  // with a trailing newline.
    25  type wprintfer struct{ w io.Writer }
    26  
    27  func (p *wprintfer) Printf(format string, a ...interface{}) {
    28  	fmt.Fprintf(p.w, format+"\n", a...)
    29  }
    30  
    31  // Fdiff writes to w a description of the differences between a and b.
    32  func Fdiff(w io.Writer, a, b interface{}) {
    33  	Pdiff(&wprintfer{w}, a, b)
    34  }
    35  
    36  type Printfer interface {
    37  	Printf(format string, a ...interface{})
    38  }
    39  
    40  // Pdiff prints to p a description of the differences between a and b.
    41  // It calls Printf once for each difference, with no trailing newline.
    42  // The standard library log.Logger is a Printfer.
    43  func Pdiff(p Printfer, a, b interface{}) {
    44  	diffPrinter{w: p}.diff(reflect.ValueOf(a), reflect.ValueOf(b))
    45  }
    46  
    47  type Logfer interface {
    48  	Logf(format string, a ...interface{})
    49  }
    50  
    51  // logprintfer calls Fprintf on w for each Printf call
    52  // with a trailing newline.
    53  type logprintfer struct{ l Logfer }
    54  
    55  func (p *logprintfer) Printf(format string, a ...interface{}) {
    56  	p.l.Logf(format, a...)
    57  }
    58  
    59  // Ldiff prints to l a description of the differences between a and b.
    60  // It calls Logf once for each difference, with no trailing newline.
    61  // The standard library testing.T and testing.B are Logfers.
    62  func Ldiff(l Logfer, a, b interface{}) {
    63  	Pdiff(&logprintfer{l}, a, b)
    64  }
    65  
    66  type diffPrinter struct {
    67  	w Printfer
    68  	l string // label
    69  }
    70  
    71  func (w diffPrinter) printf(f string, a ...interface{}) {
    72  	var l string
    73  	if w.l != "" {
    74  		l = w.l + ": "
    75  	}
    76  	w.w.Printf(l+f, a...)
    77  }
    78  
    79  func (w diffPrinter) diff(av, bv reflect.Value) {
    80  	if !av.IsValid() && bv.IsValid() {
    81  		w.printf("nil != %# v", formatter{v: bv, quote: true})
    82  		return
    83  	}
    84  	if av.IsValid() && !bv.IsValid() {
    85  		w.printf("%# v != nil", formatter{v: av, quote: true})
    86  		return
    87  	}
    88  	if !av.IsValid() && !bv.IsValid() {
    89  		return
    90  	}
    91  
    92  	at := av.Type()
    93  	bt := bv.Type()
    94  	if at != bt {
    95  		w.printf("%v != %v", at, bt)
    96  		return
    97  	}
    98  
    99  	switch kind := at.Kind(); kind {
   100  	case reflect.Bool:
   101  		if a, b := av.Bool(), bv.Bool(); a != b {
   102  			w.printf("%v != %v", a, b)
   103  		}
   104  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   105  		if a, b := av.Int(), bv.Int(); a != b {
   106  			w.printf("%d != %d", a, b)
   107  		}
   108  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   109  		if a, b := av.Uint(), bv.Uint(); a != b {
   110  			w.printf("%d != %d", a, b)
   111  		}
   112  	case reflect.Float32, reflect.Float64:
   113  		if a, b := av.Float(), bv.Float(); a != b {
   114  			w.printf("%v != %v", a, b)
   115  		}
   116  	case reflect.Complex64, reflect.Complex128:
   117  		if a, b := av.Complex(), bv.Complex(); a != b {
   118  			w.printf("%v != %v", a, b)
   119  		}
   120  	case reflect.Array:
   121  		n := av.Len()
   122  		for i := 0; i < n; i++ {
   123  			w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
   124  		}
   125  	case reflect.Chan, reflect.Func, reflect.UnsafePointer:
   126  		if a, b := av.Pointer(), bv.Pointer(); a != b {
   127  			w.printf("%#x != %#x", a, b)
   128  		}
   129  	case reflect.Interface:
   130  		w.diff(av.Elem(), bv.Elem())
   131  	case reflect.Map:
   132  		ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys())
   133  		for _, k := range ak {
   134  			w := w.relabel(fmt.Sprintf("[%#v]", k))
   135  			w.printf("%q != (missing)", av.MapIndex(k))
   136  		}
   137  		for _, k := range both {
   138  			w := w.relabel(fmt.Sprintf("[%#v]", k))
   139  			w.diff(av.MapIndex(k), bv.MapIndex(k))
   140  		}
   141  		for _, k := range bk {
   142  			w := w.relabel(fmt.Sprintf("[%#v]", k))
   143  			w.printf("(missing) != %q", bv.MapIndex(k))
   144  		}
   145  		if av.IsNil() != bv.IsNil() {
   146  			w.printf("%#v != %#v", av, bv)
   147  			break
   148  		}
   149  	case reflect.Ptr:
   150  		switch {
   151  		case av.IsNil() && !bv.IsNil():
   152  			w.printf("nil != %# v", formatter{v: bv, quote: true})
   153  		case !av.IsNil() && bv.IsNil():
   154  			w.printf("%# v != nil", formatter{v: av, quote: true})
   155  		case !av.IsNil() && !bv.IsNil():
   156  			w.diff(av.Elem(), bv.Elem())
   157  		}
   158  	case reflect.Slice:
   159  		lenA := av.Len()
   160  		lenB := bv.Len()
   161  		if lenA != lenB {
   162  			w.printf("%s[%d] != %s[%d]", av.Type(), lenA, bv.Type(), lenB)
   163  			break
   164  		}
   165  		for i := 0; i < lenA; i++ {
   166  			w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
   167  		}
   168  		if av.IsNil() != bv.IsNil() {
   169  			w.printf("%#v != %#v", av, bv)
   170  			break
   171  		}
   172  	case reflect.String:
   173  		if a, b := av.String(), bv.String(); a != b {
   174  			w.printf("%q != %q", a, b)
   175  		}
   176  	case reflect.Struct:
   177  		for i := 0; i < av.NumField(); i++ {
   178  			w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i))
   179  		}
   180  	default:
   181  		panic("unknown reflect Kind: " + kind.String())
   182  	}
   183  }
   184  
   185  func (d diffPrinter) relabel(name string) (d1 diffPrinter) {
   186  	d1 = d
   187  	if d.l != "" && name[0] != '[' {
   188  		d1.l += "."
   189  	}
   190  	d1.l += name
   191  	return d1
   192  }
   193  
   194  // keyEqual compares a and b for equality.
   195  // Both a and b must be valid map keys.
   196  func keyEqual(av, bv reflect.Value) bool {
   197  	if !av.IsValid() && !bv.IsValid() {
   198  		return true
   199  	}
   200  	if !av.IsValid() || !bv.IsValid() || av.Type() != bv.Type() {
   201  		return false
   202  	}
   203  	switch kind := av.Kind(); kind {
   204  	case reflect.Bool:
   205  		a, b := av.Bool(), bv.Bool()
   206  		return a == b
   207  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   208  		a, b := av.Int(), bv.Int()
   209  		return a == b
   210  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   211  		a, b := av.Uint(), bv.Uint()
   212  		return a == b
   213  	case reflect.Float32, reflect.Float64:
   214  		a, b := av.Float(), bv.Float()
   215  		return a == b
   216  	case reflect.Complex64, reflect.Complex128:
   217  		a, b := av.Complex(), bv.Complex()
   218  		return a == b
   219  	case reflect.Array:
   220  		for i := 0; i < av.Len(); i++ {
   221  			if !keyEqual(av.Index(i), bv.Index(i)) {
   222  				return false
   223  			}
   224  		}
   225  		return true
   226  	case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
   227  		a, b := av.Pointer(), bv.Pointer()
   228  		return a == b
   229  	case reflect.Interface:
   230  		return keyEqual(av.Elem(), bv.Elem())
   231  	case reflect.String:
   232  		a, b := av.String(), bv.String()
   233  		return a == b
   234  	case reflect.Struct:
   235  		for i := 0; i < av.NumField(); i++ {
   236  			if !keyEqual(av.Field(i), bv.Field(i)) {
   237  				return false
   238  			}
   239  		}
   240  		return true
   241  	default:
   242  		panic("invalid map key type " + av.Type().String())
   243  	}
   244  }
   245  
   246  func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) {
   247  	for _, av := range a {
   248  		inBoth := false
   249  		for _, bv := range b {
   250  			if keyEqual(av, bv) {
   251  				inBoth = true
   252  				both = append(both, av)
   253  				break
   254  			}
   255  		}
   256  		if !inBoth {
   257  			ak = append(ak, av)
   258  		}
   259  	}
   260  	for _, bv := range b {
   261  		inBoth := false
   262  		for _, av := range a {
   263  			if keyEqual(av, bv) {
   264  				inBoth = true
   265  				break
   266  			}
   267  		}
   268  		if !inBoth {
   269  			bk = append(bk, bv)
   270  		}
   271  	}
   272  	return
   273  }