github.com/nikandfor/assert@v0.0.0-20231112165957-bf2ce0a3555a/deep/deep.go (about)

     1  package deep
     2  
     3  import (
     4  	"fmt"
     5  	"hash/crc32"
     6  	"io"
     7  	"math/big"
     8  	"os"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  	"unicode"
    14  	"unsafe"
    15  
    16  	"tlog.app/go/errors"
    17  )
    18  
type (
	// prefixWriter forwards writes to the embedded Writer, emitting pref
	// before a chunk of output whenever add is set (Write sets add after
	// each '\n').
	prefixWriter struct {
		io.Writer
		pref []byte // prefix written at the start of a line
		add  bool   // when true, the next chunk is preceded by pref
	}

	// visit identifies a pair of references (and their type) that equal has
	// already started comparing. It is the key of the visited set used to
	// stop on cyclic data, as in reflect.DeepEqual.
	visit struct {
		a, b unsafe.Pointer
		typ  reflect.Type
	}

	// rtype mirrors the Go runtime's type descriptor (runtime._type; moved to
	// internal/abi.Type in newer Go releases) so that fields can be read from
	// the type pointer of a reflect.Value. This package only reads ptrdata
	// (see ptrdata).
	//
	// NOTE(review): the layout is copied from the runtime of the toolchain in
	// use and must be re-checked on every Go upgrade.
	rtype struct {
		size       uintptr
		ptrdata    uintptr // number of bytes in the type that can contain pointers
		hash       uint32  // hash of type; avoids computation in hash tables
		tflag      tflag   // extra type information flags
		align      uint8   // alignment of variable with this type
		fieldAlign uint8   // alignment of struct field with this type
		kind       uint8   // enumeration for C
		// function for comparing objects of this type
		// (ptr to object A, ptr to object B) -> ==?
		equal     func(unsafe.Pointer, unsafe.Pointer) bool
		gcdata    *byte   // garbage collection data
		str       nameOff // string form
		ptrToThis typeOff // type for pointer to this type, may be zero
	}

	tflag uint8

	nameOff int32
	typeOff int32

	// value mirrors the memory layout of reflect.Value (type pointer, data
	// pointer, flags) so that its unexported fields can be read through
	// unsafe pointer casts, as ptrdata and equal do.
	value struct {
		typ  *rtype
		ptr  unsafe.Pointer
		flag uintptr
	}

	// formatter is the io.Writer used by Fprint. It wraps the destination
	// writer and records in notnl whether the last non-empty write did not
	// end with '\n'. ident uses that to indent only at the start of a line.
	formatter struct {
		io.Writer
		notnl bool
	}
)
    67  
    68  var spaces = "                                                                          "
    69  
    70  var stop = map[reflect.Type]struct{}{
    71  	reflect.TypeOf(time.Time{}):      struct{}{},
    72  	reflect.TypeOf(&time.Location{}): struct{}{},
    73  	reflect.TypeOf(&big.Int{}):       struct{}{},
    74  	reflect.TypeOf(&os.File{}):       struct{}{},
    75  }
    76  
    77  func Equal(a, b interface{}) bool {
    78  	av := reflect.ValueOf(a)
    79  	bv := reflect.ValueOf(b)
    80  
    81  	return equal(av, bv, nil)
    82  }
    83  
    84  func Diff(w io.Writer, a, b interface{}) bool {
    85  	av := reflect.ValueOf(a)
    86  	bv := reflect.ValueOf(b)
    87  
    88  	return equal(av, bv, nil)
    89  }
    90  
    91  func equal(a, b reflect.Value, visited map[visit]struct{}) bool {
    92  	if !a.IsValid() || !b.IsValid() {
    93  		return a.IsValid() == b.IsValid()
    94  	}
    95  	if a.Type() != b.Type() {
    96  		return false
    97  	}
    98  
    99  	// The hard part is taken from reflect.DeepEqual
   100  
   101  	// We want to avoid putting more in the visited map than we need to.
   102  	// For any possible reference cycle that might be encountered,
   103  	// hard(v1, v2) needs to return true for at least one of the types in the cycle,
   104  	// and it's safe and valid to get Value's internal pointer.
   105  	hard := func(v1, v2 reflect.Value) bool {
   106  		switch v1.Kind() {
   107  		case reflect.Ptr:
   108  			if ptrdata(v1) == 0 {
   109  				// go:notinheap pointers can't be cyclic.
   110  				// At least, all of our current uses of go:notinheap have
   111  				// that property. The runtime ones aren't cyclic (and we don't use
   112  				// DeepEqual on them anyway), and the cgo-generated ones are
   113  				// all empty structs.
   114  				return false
   115  			}
   116  
   117  			fallthrough
   118  		case reflect.Map, reflect.Slice, reflect.Interface:
   119  			// Nil pointers cannot be cyclic. Avoid putting them in the visited map.
   120  			return !v1.IsNil() && !v2.IsNil()
   121  		}
   122  
   123  		return false
   124  	}
   125  
   126  	if hard(a, b) {
   127  		// For a Ptr or Map value, we need to check flagIndir,
   128  		// which we do by calling the pointer method.
   129  		// For Slice or Interface, flagIndir is always set,
   130  		// and using v.ptr suffices.
   131  		ptrval := func(v reflect.Value) unsafe.Pointer {
   132  			switch v.Kind() {
   133  			case reflect.Ptr, reflect.Map:
   134  				return valuePointer(v)
   135  			default:
   136  				return (*value)(unsafe.Pointer(&v)).ptr
   137  			}
   138  		}
   139  
   140  		addr1 := ptrval(a)
   141  		addr2 := ptrval(b)
   142  		if uintptr(addr1) > uintptr(addr2) {
   143  			// Canonicalize order to reduce number of entries in visited.
   144  			// Assumes non-moving garbage collector.
   145  			addr1, addr2 = addr2, addr1
   146  		}
   147  
   148  		// Short circuit if references are already seen.
   149  		typ := a.Type()
   150  		v := visit{a: addr1, b: addr2, typ: typ}
   151  		if _, ok := visited[v]; ok {
   152  			return true
   153  		}
   154  
   155  		if visited == nil {
   156  			visited = make(map[visit]struct{})
   157  		}
   158  
   159  		// Remember for later.
   160  		visited[v] = struct{}{}
   161  	}
   162  
   163  	for a.Kind() == reflect.Ptr {
   164  		if a.IsNil() != b.IsNil() {
   165  			return false
   166  		}
   167  
   168  		if a.IsNil() {
   169  			return true
   170  		}
   171  
   172  		a = a.Elem()
   173  		b = b.Elem()
   174  	}
   175  
   176  	switch a.Kind() {
   177  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   178  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   179  		reflect.Uintptr, reflect.UnsafePointer,
   180  		reflect.Float64, reflect.Float32,
   181  		reflect.Complex128, reflect.Complex64,
   182  		reflect.String,
   183  		reflect.Chan,
   184  		reflect.Bool:
   185  
   186  		return eface(a) == eface(b)
   187  
   188  	case reflect.Interface:
   189  		ai := a.InterfaceData()
   190  		bi := b.InterfaceData()
   191  
   192  		if ai[0] != bi[0] {
   193  			return false
   194  		}
   195  
   196  		return equal(a.Elem(), b.Elem(), visited)
   197  	case reflect.Slice, reflect.Array:
   198  		return equalSlice(a, b, visited)
   199  
   200  	case reflect.Struct:
   201  		return equalStructFields(a, b, visited)
   202  
   203  	case reflect.Map:
   204  		return equalMap(a, b, visited)
   205  
   206  	case reflect.Func:
   207  		return equalFunc(a, b, visited)
   208  
   209  	default:
   210  		panic(fmt.Sprintf("cannot compare %v", a.Kind()))
   211  	}
   212  }
   213  
   214  func equalStructFields(a, b reflect.Value, visited map[visit]struct{}) bool {
   215  	t := a.Type()
   216  
   217  	for i := 0; i < t.NumField(); i++ {
   218  		ft := t.Field(i)
   219  		if ft.Tag.Get("deep") == "-" {
   220  			continue
   221  		}
   222  
   223  		f, ok := getTag(ft, "deep", "compare")
   224  		switch {
   225  		case ok && f == "false":
   226  			continue
   227  		case ok && (f == "nil" || f == "isnil"):
   228  			if a.Field(i).IsNil() != b.Field(i).IsNil() {
   229  				return false
   230  			}
   231  
   232  			continue
   233  		case ok && (f == "pointer" || f == "ptr"):
   234  			if a.Field(i).Pointer() != b.Field(i).Pointer() {
   235  				return false
   236  			}
   237  
   238  			continue
   239  		}
   240  
   241  		if !equal(a.Field(i), b.Field(i), visited) {
   242  			return false
   243  		}
   244  	}
   245  
   246  	return true
   247  }
   248  
   249  func equalSlice(a, b reflect.Value, visited map[visit]struct{}) bool {
   250  	if a.Len() != b.Len() {
   251  		return false
   252  	}
   253  
   254  	for i := 0; i < a.Len(); i++ {
   255  		if !equal(a.Index(i), b.Index(i), visited) {
   256  			return false
   257  		}
   258  	}
   259  
   260  	return true
   261  }
   262  
   263  func equalMap(a, b reflect.Value, visited map[visit]struct{}) bool {
   264  	if a.Len() != b.Len() {
   265  		return false
   266  	}
   267  
   268  	it := a.MapRange()
   269  
   270  	for it.Next() {
   271  		v := b.MapIndex(it.Key())
   272  
   273  		if !equal(it.Value(), v, visited) {
   274  			return false
   275  		}
   276  	}
   277  
   278  	return true
   279  }
   280  
   281  func equalFunc(a, b reflect.Value, visited map[visit]struct{}) bool {
   282  	if a.IsNil() && b.IsNil() {
   283  		return true
   284  	}
   285  
   286  	panic("can't compare funcs")
   287  }
   288  
   289  func Fprint(w io.Writer, x ...interface{}) (n int, err error) {
   290  	f := formatter{
   291  		Writer: w,
   292  	}
   293  
   294  	for i, x := range x {
   295  		n, err = f.print(n, reflect.ValueOf(x), 0, 10)
   296  		if err != nil {
   297  			return n, errors.Wrap(err, "%d", i)
   298  		}
   299  	}
   300  
   301  	return
   302  }
   303  
   304  func (f *formatter) print(n int, x reflect.Value, d, maxdepth int) (m int, err error) {
   305  	//	defer func() {
   306  	//		fmt.Fprintf(os.Stderr, "print: n:%v  x:%v  from %v\n", m, x, loc.Caller(1))
   307  	//	}()
   308  
   309  	if x == (reflect.Value{}) {
   310  		return f.writef(n, "nil")
   311  	}
   312  
   313  	tp := x.Type()
   314  
   315  	if _, ok := stop[tp]; ok {
   316  		return f.writef(n, "%#v", x)
   317  	}
   318  
   319  	if d == maxdepth {
   320  		return f.writef(n, "(%v)(omitted)", x.Type())
   321  	}
   322  
   323  	for x.Kind() == reflect.Ptr {
   324  		if x.IsNil() {
   325  			return f.writef(n, "(%v)(nil)", x.Type())
   326  		}
   327  
   328  		n, err = f.writef(n, "&")
   329  		if err != nil {
   330  			return
   331  		}
   332  
   333  		x = x.Elem()
   334  	}
   335  
   336  	if _, ok := stop[tp]; ok {
   337  		return f.writef(n, "%#v", x)
   338  	}
   339  
   340  	named := x.Type().Name() != x.Kind().String()
   341  
   342  	switch x.Kind() {
   343  	case reflect.Bool:
   344  		if named {
   345  			n, err = f.writef(n, "%v(%v)", x.Type(), x.Bool())
   346  			break
   347  		}
   348  
   349  		n, err = f.writef(n, "%v", x.Bool())
   350  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   351  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   352  		reflect.Uintptr, reflect.UnsafePointer:
   353  
   354  		n, err = f.writef(n, "%v(0x%x)", x.Type(), x)
   355  	case reflect.String:
   356  		vf := "%q"
   357  		if x.Len() > 40 {
   358  			vf = "%-.40q"
   359  		}
   360  
   361  		if named {
   362  			n, err = f.writef(n, "%v("+vf+")", x.Type(), x.String())
   363  			break
   364  		}
   365  
   366  		n, err = f.writef(n, vf, x.String())
   367  	case reflect.Slice, reflect.Array:
   368  		if x.Kind() == reflect.Slice && x.IsNil() {
   369  			return f.writef(n, `%v(nil)`, tp)
   370  		}
   371  
   372  		if tp := x.Type(); tp.Elem().Kind() == reflect.Uint8 {
   373  			if x.Len() > 20 {
   374  				format := `unhex("%x", "total_len=%d,hash=%x")`
   375  				if isPrintable(x.Slice(0, 20).Bytes()) {
   376  					format = `%q, "total_len=%d,hash=%x"`
   377  				}
   378  
   379  				return f.writef(n, `%v(`+format+`)`, tp, x.Slice(0, 20).Bytes(), x.Len(), hashBytes(x.Slice(0, x.Len()).Bytes()))
   380  			}
   381  
   382  			format := `unhex("%x")`
   383  			if isPrintable(x.Slice(0, x.Len()).Bytes()) {
   384  				format = "%q"
   385  			}
   386  
   387  			return f.writef(n, `%v(`+format+`)`, tp, x.Slice(0, x.Len()).Bytes())
   388  		}
   389  
   390  		n, err = f.writef(n, "%v", x.Type())
   391  		if err != nil {
   392  			return
   393  		}
   394  
   395  		n, err = f.printSlice(n, x, d+1, maxdepth)
   396  		if err != nil {
   397  			return
   398  		}
   399  	case reflect.Struct:
   400  		n, err = f.writef(n, "%v{\n", x.Type())
   401  		if err != nil {
   402  			return
   403  		}
   404  
   405  		n, err = f.printStructFields(n, x, d+1, maxdepth)
   406  		if err != nil {
   407  			return
   408  		}
   409  
   410  		n, err = f.ident(n, d, "}")
   411  	case reflect.Interface:
   412  		n, err = f.writef(n, "(%v)(", x.Type())
   413  		if err != nil {
   414  			return
   415  		}
   416  
   417  		n, err = f.print(n, x.Elem(), d+1, maxdepth)
   418  		if err != nil {
   419  			return
   420  		}
   421  
   422  		n, err = f.ident(n, d, ")")
   423  	default:
   424  		n, err = f.writef(n, "%v", x.Type())
   425  		if err != nil {
   426  			return
   427  		}
   428  
   429  		n, err = f.writef(n, " (kind: %v)", x.Kind())
   430  	}
   431  
   432  	if err != nil {
   433  		return
   434  	}
   435  
   436  	return n, nil
   437  }
   438  
   439  func (f *formatter) printStructFields(n int, x reflect.Value, d, maxdepth int) (_ int, err error) {
   440  	t := x.Type()
   441  
   442  	for i := 0; i < t.NumField(); i++ {
   443  		ft := t.Field(i)
   444  		if ft.Tag.Get("deep") == "-" {
   445  			continue
   446  		}
   447  
   448  		fmaxdepth := maxdepth
   449  
   450  		v, ok := getTag(ft, "deep", "print")
   451  		switch {
   452  		case ok && v == "omit":
   453  			continue
   454  		case ok && strings.HasPrefix(v, "maxdepth="):
   455  			v, err := strconv.Atoi(v[len("maxdepth="):])
   456  			if err == nil && fmaxdepth > d+v {
   457  				fmaxdepth = d + v
   458  			}
   459  		}
   460  
   461  		n, err = f.ident(n, d, "")
   462  		if err != nil {
   463  			return
   464  		}
   465  
   466  		n, err = f.writef(n, "%v: ", ft.Name)
   467  		if err != nil {
   468  			return
   469  		}
   470  
   471  		if l := len(ft.Name); l < 14 {
   472  			n, err = f.writef(n, "%v", spaces[:14-l])
   473  			if err != nil {
   474  				return
   475  			}
   476  		}
   477  
   478  		n, err = f.print(n, x.Field(i), d, fmaxdepth)
   479  		if err != nil {
   480  			return
   481  		}
   482  
   483  		n, err = f.writef(n, "\n")
   484  		if err != nil {
   485  			return
   486  		}
   487  	}
   488  
   489  	return n, nil
   490  }
   491  
   492  func (f *formatter) printSlice(n int, x reflect.Value, d, maxdepth int) (m int, err error) {
   493  	t := x.Type().Elem()
   494  	k := t.Kind()
   495  
   496  	if x.IsNil() {
   497  		return f.writef(n, "(nil)")
   498  	}
   499  
   500  	if k == reflect.Uint8 {
   501  		ok := 0
   502  		for _, c := range x.Bytes() {
   503  			if c >= 0x20 && c < 0x80 {
   504  				ok++
   505  			}
   506  		}
   507  
   508  		if ok*5/4 >= x.Len() {
   509  			return f.writef(n, "(%q)", x.Bytes())
   510  		}
   511  	}
   512  
   513  	switch k {
   514  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   515  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   516  		reflect.Uintptr, reflect.UnsafePointer:
   517  
   518  		n, err = f.writef(n, "{")
   519  		if err != nil {
   520  			return
   521  		}
   522  
   523  		for i := 0; i < x.Len(); i++ {
   524  			if i != 0 {
   525  				n, err = f.writef(n, ", ")
   526  				if err != nil {
   527  					return
   528  				}
   529  			}
   530  
   531  			if i == 10 {
   532  				n, err = f.writef(n, "... %d elements", x.Len()-i)
   533  				if err != nil {
   534  					return
   535  				}
   536  
   537  				break
   538  			}
   539  
   540  			xx := x.Index(i)
   541  
   542  			if k == reflect.UnsafePointer {
   543  				n, err = f.writef(n, "0x%x", xx.Pointer())
   544  				if err != nil {
   545  					return
   546  				}
   547  
   548  				continue
   549  			}
   550  
   551  			var val interface{}
   552  
   553  			switch k {
   554  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   555  				val = xx.Int()
   556  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   557  				reflect.Uintptr:
   558  				val = xx.Uint()
   559  			}
   560  
   561  			n, err = f.writef(n, "%v", val)
   562  			if err != nil {
   563  				return
   564  			}
   565  		}
   566  
   567  		n, err = f.writef(n, "}")
   568  	default:
   569  		n, err = f.writef(n, "{")
   570  		if err != nil {
   571  			return
   572  		}
   573  
   574  		for i := 0; i < x.Len(); i++ {
   575  			if i != 0 {
   576  				n, err = f.writef(n, ", ")
   577  				if err != nil {
   578  					return
   579  				}
   580  			}
   581  
   582  			xx := x.Index(i)
   583  
   584  			n, err = f.print(n, xx, d+1, maxdepth)
   585  			if err != nil {
   586  				return
   587  			}
   588  		}
   589  
   590  		n, err = f.writef(n, "}")
   591  	}
   592  	if err != nil {
   593  		return
   594  	}
   595  
   596  	return n, nil
   597  }
   598  
   599  func (f *formatter) ident(n, d int, fmt string, args ...interface{}) (_ int, err error) {
   600  	if !f.notnl {
   601  		n, err = f.writef(n, "%s", spaces[:4*d])
   602  		if err != nil {
   603  			return
   604  		}
   605  	}
   606  
   607  	if fmt == "" && len(args) == 0 {
   608  		return n, err
   609  	}
   610  
   611  	return f.writef(n, fmt, args...)
   612  }
   613  
   614  func (f *formatter) writef(i int, format string, args ...interface{}) (n int, err error) {
   615  	n, err = fmt.Fprintf(f, format, args...)
   616  	return i + n, err
   617  }
   618  
   619  func (f *formatter) Write(p []byte) (n int, err error) {
   620  	if len(p) != 0 {
   621  		f.notnl = p[len(p)-1] != '\n'
   622  	}
   623  
   624  	return f.Writer.Write(p)
   625  }
   626  
   627  func getTag(x reflect.StructField, t, k string) (string, bool) {
   628  	tags := strings.Split(x.Tag.Get(t), ",")
   629  
   630  	for _, tag := range tags {
   631  		kv := strings.SplitN(tag, "=", 2)
   632  		if kv[0] == k {
   633  			if len(kv) == 1 {
   634  				return "", true
   635  			}
   636  
   637  			return kv[1], true
   638  		}
   639  	}
   640  
   641  	return "", false
   642  }
   643  
   644  func (w *prefixWriter) Write(p []byte) (n int, err error) {
   645  	i := 0
   646  
   647  	for i < len(p) {
   648  		if w.add {
   649  			_, err = w.Writer.Write(w.pref)
   650  			if err != nil {
   651  				return
   652  			}
   653  		}
   654  
   655  		st := i
   656  
   657  		for i < len(p) && p[i] != '\n' {
   658  			i++
   659  		}
   660  
   661  		if i < len(p) && p[i] == '\n' {
   662  			i++
   663  
   664  			w.add = true
   665  		}
   666  
   667  		var m int
   668  		m, err = w.Writer.Write(p[st:i])
   669  		n += m
   670  		if err != nil {
   671  			return
   672  		}
   673  	}
   674  
   675  	return
   676  }
   677  
// eface reinterprets the leading (type pointer, data pointer) words of a
// reflect.Value as an empty interface. Unlike Value.Interface, it does not
// consult the Value's flags, so it also works on unexported fields.
//
// NOTE(review): this appears correct only when the data word has the meaning
// an interface expects. For pointer-shaped kinds (chan, unsafe.Pointer,
// pointers, maps, funcs), reflect presumably stores the address of the value
// when the Value was reached through a struct field, element or pointer.
// The result would then wrap that address rather than the value itself.
// Confirm before relying on it for such kinds. It also depends on the
// reflect.Value layout staying (typ, ptr, flag).
func eface(x reflect.Value) interface{} {
	return *(*interface{})(unsafe.Pointer(&x))
}
   681  
// ptrdata returns the ptrdata field of v's type descriptor, read through the
// value/rtype mirrors of the reflect and runtime layouts. equal uses a zero
// result to skip cycle tracking for pointers that cannot be cyclic.
//
// NOTE(review): v must be valid. For the zero Value the type pointer is nil
// and this dereference would fault.
func ptrdata(v reflect.Value) uintptr {
	return (*value)(unsafe.Pointer(&v)).typ.ptrdata
}
   685  
// valuePointer is reflect's unexported Value.pointer method, pulled in with
// go:linkname. equal uses it to get the pointer held by Ptr and Map values.
//
// NOTE(review): the body comes from package reflect at link time. Newer
// toolchains restrict linkname references to non-exported standard-library
// symbols, so verify this still links with the Go version in use.
//
//go:linkname valuePointer reflect.Value.pointer
func valuePointer(v reflect.Value) unsafe.Pointer
   688  
   689  func hashBytes(d []byte) uint32 {
   690  	return crc32.ChecksumIEEE(d)
   691  }
   692  
   693  func isPrintable(b []byte) bool {
   694  	for _, r := range string(b) {
   695  		if !unicode.IsPrint(r) {
   696  			return false
   697  		}
   698  	}
   699  
   700  	return true
   701  }