github.com/wzzhu/tensor@v0.9.24/testutils_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"math"
     7  	"math/cmplx"
     8  	"math/rand"
     9  	"reflect"
    10  	"testing"
    11  	"testing/quick"
    12  	"time"
    13  	"unsafe"
    14  
    15  	"github.com/chewxy/math32"
    16  	"github.com/wzzhu/tensor/internal/storage"
    17  )
    18  
    19  func randomBool() bool {
    20  	i := rand.Intn(11)
    21  	return i > 5
    22  }
    23  
    24  // from : https://stackoverflow.com/a/31832326/3426066
    25  const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    26  const (
    27  	letterIdxBits = 6                    // 6 bits to represent a letter index
    28  	letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
    29  	letterIdxMax  = 63 / letterIdxBits   // # of letter indices fitting in 63 bits
    30  )
    31  
    32  const quickchecks = 1000
    33  
    34  func newRand() *rand.Rand {
    35  	return rand.New(rand.NewSource(time.Now().UnixNano()))
    36  }
    37  
    38  func randomString() string {
    39  	n := rand.Intn(10)
    40  	b := make([]byte, n)
    41  	src := newRand()
    42  	// A src.Int63() generates 63 random bits, enough for letterIdxMax characters!
    43  	for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
    44  		if remain == 0 {
    45  			cache, remain = src.Int63(), letterIdxMax
    46  		}
    47  		if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
    48  			b[i] = letterBytes[idx]
    49  			i--
    50  		}
    51  		cache >>= letterIdxBits
    52  		remain--
    53  	}
    54  
    55  	return string(b)
    56  }
    57  
    58  // taken from the Go Stdlib package math
    59  func tolerancef64(a, b, e float64) bool {
    60  	d := a - b
    61  	if d < 0 {
    62  		d = -d
    63  	}
    64  
    65  	// note: b is correct (expected) value, a is actual value.
    66  	// make error tolerance a fraction of b, not a.
    67  	if b != 0 {
    68  		e = e * b
    69  		if e < 0 {
    70  			e = -e
    71  		}
    72  	}
    73  	return d < e
    74  }
    75  func closeenoughf64(a, b float64) bool { return tolerancef64(a, b, 1e-8) }
    76  func closef64(a, b float64) bool       { return tolerancef64(a, b, 1e-14) }
    77  func veryclosef64(a, b float64) bool   { return tolerancef64(a, b, 4e-16) }
    78  func soclosef64(a, b, e float64) bool  { return tolerancef64(a, b, e) }
    79  func alikef64(a, b float64) bool {
    80  	switch {
    81  	case math.IsNaN(a) && math.IsNaN(b):
    82  		return true
    83  	case a == b:
    84  		return math.Signbit(a) == math.Signbit(b)
    85  	}
    86  	return false
    87  }
    88  
    89  // taken from math32, which was taken from the Go std lib
    90  func tolerancef32(a, b, e float32) bool {
    91  	d := a - b
    92  	if d < 0 {
    93  		d = -d
    94  	}
    95  
    96  	// note: b is correct (expected) value, a is actual value.
    97  	// make error tolerance a fraction of b, not a.
    98  	if b != 0 {
    99  		e = e * b
   100  		if e < 0 {
   101  			e = -e
   102  		}
   103  	}
   104  	return d < e
   105  }
   106  func closef32(a, b float32) bool      { return tolerancef32(a, b, 1e-5) } // the number gotten from the cfloat standard. Haskell's Linear package uses 1e-6 for floats
   107  func veryclosef32(a, b float32) bool  { return tolerancef32(a, b, 1e-6) } // from wiki
   108  func soclosef32(a, b, e float32) bool { return tolerancef32(a, b, e) }
   109  func alikef32(a, b float32) bool {
   110  	switch {
   111  	case math32.IsNaN(a) && math32.IsNaN(b):
   112  		return true
   113  	case a == b:
   114  		return math32.Signbit(a) == math32.Signbit(b)
   115  	}
   116  	return false
   117  }
   118  
   119  // taken from math/cmplx test
   120  func cTolerance(a, b complex128, e float64) bool {
   121  	d := cmplx.Abs(a - b)
   122  	if b != 0 {
   123  		e = e * cmplx.Abs(b)
   124  		if e < 0 {
   125  			e = -e
   126  		}
   127  	}
   128  	return d < e
   129  }
   130  
   131  func cClose(a, b complex128) bool              { return cTolerance(a, b, 1e-14) }
   132  func cSoclose(a, b complex128, e float64) bool { return cTolerance(a, b, e) }
   133  func cVeryclose(a, b complex128) bool          { return cTolerance(a, b, 4e-16) }
   134  func cAlike(a, b complex128) bool {
   135  	switch {
   136  	case cmplx.IsNaN(a) && cmplx.IsNaN(b):
   137  		return true
   138  	case a == b:
   139  		return math.Signbit(real(a)) == math.Signbit(real(b)) && math.Signbit(imag(a)) == math.Signbit(imag(b))
   140  	}
   141  	return false
   142  }
   143  
   144  func allClose(a, b interface{}, approxFn ...interface{}) bool {
   145  	switch at := a.(type) {
   146  	case []float64:
   147  		closeness := closef64
   148  		var ok bool
   149  		if len(approxFn) > 0 {
   150  			if closeness, ok = approxFn[0].(func(a, b float64) bool); !ok {
   151  				closeness = closef64
   152  			}
   153  		}
   154  		bt := b.([]float64)
   155  		for i, v := range at {
   156  			if math.IsNaN(v) {
   157  				if !math.IsNaN(bt[i]) {
   158  					return false
   159  				}
   160  				continue
   161  			}
   162  			if math.IsInf(v, 0) {
   163  				if !math.IsInf(bt[i], 0) {
   164  					return false
   165  				}
   166  				continue
   167  			}
   168  			if !closeness(v, bt[i]) {
   169  				return false
   170  			}
   171  		}
   172  		return true
   173  	case []float32:
   174  		closeness := closef32
   175  		var ok bool
   176  		if len(approxFn) > 0 {
   177  			if closeness, ok = approxFn[0].(func(a, b float32) bool); !ok {
   178  				closeness = closef32
   179  			}
   180  		}
   181  		bt := b.([]float32)
   182  		for i, v := range at {
   183  			if math32.IsNaN(v) {
   184  				if !math32.IsNaN(bt[i]) {
   185  					return false
   186  				}
   187  				continue
   188  			}
   189  			if math32.IsInf(v, 0) {
   190  				if !math32.IsInf(bt[i], 0) {
   191  					return false
   192  				}
   193  				continue
   194  			}
   195  			if !closeness(v, bt[i]) {
   196  				return false
   197  			}
   198  		}
   199  		return true
   200  	case []complex64:
   201  		bt := b.([]complex64)
   202  		for i, v := range at {
   203  			if cmplx.IsNaN(complex128(v)) {
   204  				if !cmplx.IsNaN(complex128(bt[i])) {
   205  					return false
   206  				}
   207  				continue
   208  			}
   209  			if cmplx.IsInf(complex128(v)) {
   210  				if !cmplx.IsInf(complex128(bt[i])) {
   211  					return false
   212  				}
   213  				continue
   214  			}
   215  			if !cSoclose(complex128(v), complex128(bt[i]), 1e-5) {
   216  				return false
   217  			}
   218  		}
   219  		return true
   220  	case []complex128:
   221  		bt := b.([]complex128)
   222  		for i, v := range at {
   223  			if cmplx.IsNaN(v) {
   224  				if !cmplx.IsNaN(bt[i]) {
   225  					return false
   226  				}
   227  				continue
   228  			}
   229  			if cmplx.IsInf(v) {
   230  				if !cmplx.IsInf(bt[i]) {
   231  					return false
   232  				}
   233  				continue
   234  			}
   235  			if !cClose(v, bt[i]) {
   236  				return false
   237  			}
   238  		}
   239  		return true
   240  	default:
   241  		return reflect.DeepEqual(a, b)
   242  	}
   243  }
   244  
   245  func checkErr(t *testing.T, expected bool, err error, name string, id interface{}) (cont bool) {
   246  	switch {
   247  	case expected:
   248  		if err == nil {
   249  			t.Errorf("Expected error in test %v (%v)", name, id)
   250  		}
   251  		return true
   252  	case !expected && err != nil:
   253  		t.Errorf("Test %v (%v) errored: %+v", name, id, err)
   254  		return true
   255  	}
   256  	return false
   257  }
   258  
   259  func sliceApproxf64(a, b []float64, fn func(a, b float64) bool) bool {
   260  	if len(a) != len(b) {
   261  		return false
   262  	}
   263  
   264  	for i, v := range a {
   265  		if math.IsNaN(v) {
   266  			if !alikef64(v, b[i]) {
   267  				return false
   268  			}
   269  		}
   270  		if !fn(v, b[i]) {
   271  			return false
   272  		}
   273  	}
   274  	return true
   275  }
   276  
   277  func RandomFloat64(size int) []float64 {
   278  	r := make([]float64, size)
   279  	for i := range r {
   280  		r[i] = rand.NormFloat64()
   281  	}
   282  	return r
   283  }
   284  
   285  func factorize(a int) []int {
   286  	if a <= 0 {
   287  		return nil
   288  	}
   289  	// all numbers are divisible by at least 1
   290  	retVal := make([]int, 1)
   291  	retVal[0] = 1
   292  
   293  	fill := func(a int, e int) {
   294  		n := len(retVal)
   295  		for i, p := 0, a; i < e; i, p = i+1, p*a {
   296  			for j := 0; j < n; j++ {
   297  				retVal = append(retVal, retVal[j]*p)
   298  			}
   299  		}
   300  	}
   301  	// find factors of 2
   302  	// rightshift by 1 = division by 2
   303  	var e int
   304  	for ; a&1 == 0; e++ {
   305  		a >>= 1
   306  	}
   307  	fill(2, e)
   308  
   309  	// find factors of 3 and up
   310  	for next := 3; a > 1; next += 2 {
   311  		if next*next > a {
   312  			next = a
   313  		}
   314  		for e = 0; a%next == 0; e++ {
   315  			a /= next
   316  		}
   317  		if e > 0 {
   318  			fill(next, e)
   319  		}
   320  	}
   321  	return retVal
   322  }
   323  
   324  func shuffleInts(a []int, r *rand.Rand) {
   325  	for i := range a {
   326  		j := r.Intn(i + 1)
   327  		a[i], a[j] = a[j], a[i]
   328  	}
   329  }
   330  
   331  type TensorGenerator struct {
   332  	ShapeConstraint Shape
   333  	DtypeConstraint Dtype
   334  }
   335  
   336  func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value {
   337  	var retVal Tensor
   338  	// generate type of tensor
   339  
   340  	return reflect.ValueOf(retVal)
   341  }
   342  
   343  func (t *Dense) Generate(r *rand.Rand, size int) reflect.Value {
   344  	// generate type
   345  	ri := r.Intn(len(specializedTypes.set))
   346  	of := specializedTypes.set[ri]
   347  	datatyp := reflect.SliceOf(of.Type)
   348  	gendat, _ := quick.Value(datatyp, r)
   349  	// generate dims
   350  	var scalar bool
   351  	var s Shape
   352  	dims := r.Intn(5) // dims4 is the max we'll generate even though we can handle much more
   353  	l := gendat.Len()
   354  
   355  	// generate shape based on inputs
   356  	switch {
   357  	case dims == 0 || l == 0:
   358  		scalar = true
   359  		gendat, _ = quick.Value(of.Type, r)
   360  	case dims == 1:
   361  		s = Shape{gendat.Len()}
   362  	default:
   363  		factors := factorize(l)
   364  		s = Shape(BorrowInts(dims))
   365  		// fill with 1s so that we can get a non-zero TotalSize
   366  		for i := 0; i < len(s); i++ {
   367  			s[i] = 1
   368  		}
   369  
   370  		for i := 0; i < dims; i++ {
   371  			j := rand.Intn(len(factors))
   372  			s[i] = factors[j]
   373  			size := s.TotalSize()
   374  			if q, r := divmod(l, size); r != 0 {
   375  				factors = factorize(r)
   376  			} else if size != l {
   377  				if i < dims-2 {
   378  					factors = factorize(q)
   379  				} else if i == dims-2 {
   380  					s[i+1] = q
   381  					break
   382  				}
   383  			} else {
   384  				break
   385  			}
   386  		}
   387  		shuffleInts(s, r)
   388  	}
   389  
   390  	// generate flags
   391  	flag := MemoryFlag(r.Intn(4))
   392  
   393  	// generate order
   394  	order := DataOrder(r.Intn(4))
   395  
   396  	var v *Dense
   397  	if scalar {
   398  		v = New(FromScalar(gendat.Interface()))
   399  	} else {
   400  		v = New(Of(of), WithShape(s...), WithBacking(gendat.Interface()))
   401  	}
   402  
   403  	v.flag = flag
   404  	v.AP.o = order
   405  
   406  	// generate engine
   407  	oeint := r.Intn(2)
   408  	eint := r.Intn(4)
   409  	switch eint {
   410  	case 0:
   411  		v.e = StdEng{}
   412  		if oeint == 0 {
   413  			v.oe = StdEng{}
   414  		} else {
   415  			v.oe = nil
   416  		}
   417  	case 1:
   418  		// check is to prevent panics which Float64Engine will do if asked to allocate memory for non float64s
   419  		if of == Float64 {
   420  			v.e = Float64Engine{}
   421  			if oeint == 0 {
   422  				v.oe = Float64Engine{}
   423  			} else {
   424  				v.oe = nil
   425  			}
   426  		} else {
   427  			v.e = StdEng{}
   428  			if oeint == 0 {
   429  				v.oe = StdEng{}
   430  			} else {
   431  				v.oe = nil
   432  			}
   433  		}
   434  	case 2:
   435  		// check is to prevent panics which Float64Engine will do if asked to allocate memory for non float64s
   436  		if of == Float32 {
   437  			v.e = Float32Engine{}
   438  			if oeint == 0 {
   439  				v.oe = Float32Engine{}
   440  			} else {
   441  				v.oe = nil
   442  			}
   443  		} else {
   444  			v.e = StdEng{}
   445  			if oeint == 0 {
   446  				v.oe = StdEng{}
   447  			} else {
   448  				v.oe = nil
   449  			}
   450  		}
   451  	case 3:
   452  		v.e = dummyEngine(true)
   453  		v.oe = nil
   454  	}
   455  
   456  	return reflect.ValueOf(v)
   457  }
   458  
   459  // fakemem is a byteslice, while making it a Memory
   460  type fakemem []byte
   461  
   462  func (m fakemem) Uintptr() uintptr        { return uintptr(unsafe.Pointer(&m[0])) }
   463  func (m fakemem) MemSize() uintptr        { return uintptr(len(m)) }
   464  func (m fakemem) Pointer() unsafe.Pointer { return unsafe.Pointer(&m[0]) }
   465  
   466  // dummyEngine implements Engine. The bool indicates whether the data is native-accessible
   467  type dummyEngine bool
   468  
   469  func (e dummyEngine) AllocAccessible() bool { return bool(e) }
   470  func (e dummyEngine) Alloc(size int64) (Memory, error) {
   471  	ps := make(fakemem, int(size))
   472  	return ps, nil
   473  }
   474  func (e dummyEngine) Free(mem Memory, size int64) error        { return nil }
   475  func (e dummyEngine) Memset(mem Memory, val interface{}) error { return nil }
   476  func (e dummyEngine) Memclr(mem Memory)                        {}
   477  func (e dummyEngine) Memcpy(dst, src Memory) error {
   478  	if e {
   479  		var a, b storage.Header
   480  		a.Raw = storage.FromMemory(src.Uintptr(), src.MemSize())
   481  		b.Raw = storage.FromMemory(dst.Uintptr(), dst.MemSize())
   482  
   483  		copy(b.Raw, a.Raw)
   484  		return nil
   485  	}
   486  	return errors.New("Unable to copy ")
   487  }
   488  func (e dummyEngine) Accessible(mem Memory) (Memory, error) { return mem, nil }
   489  func (e dummyEngine) WorksWith(order DataOrder) bool        { return true }
   490  
   491  // dummyEngine2 is used for testing additional methods that may not be provided in the stdeng
   492  type dummyEngine2 struct {
   493  	e StdEng
   494  }
   495  
   496  func (e dummyEngine2) AllocAccessible() bool                    { return e.e.AllocAccessible() }
   497  func (e dummyEngine2) Alloc(size int64) (Memory, error)         { return e.e.Alloc(size) }
   498  func (e dummyEngine2) Free(mem Memory, size int64) error        { return e.e.Free(mem, size) }
   499  func (e dummyEngine2) Memset(mem Memory, val interface{}) error { return e.e.Memset(mem, val) }
   500  func (e dummyEngine2) Memclr(mem Memory)                        { e.e.Memclr(mem) }
   501  func (e dummyEngine2) Memcpy(dst, src Memory) error             { return e.e.Memcpy(dst, src) }
   502  func (e dummyEngine2) Accessible(mem Memory) (Memory, error)    { return e.e.Accessible(mem) }
   503  func (e dummyEngine2) WorksWith(order DataOrder) bool           { return e.e.WorksWith(order) }
   504  
   505  func (e dummyEngine2) Argmax(t Tensor, axis int) (Tensor, error) { return e.e.Argmax(t, axis) }
   506  func (e dummyEngine2) Argmin(t Tensor, axis int) (Tensor, error) { return e.e.Argmin(t, axis) }
   507  
   508  func willerr(a *Dense, tc, eqtc *typeclass) (retVal, willFailEq bool) {
   509  	if err := typeclassCheck(a.Dtype(), eqtc); err == nil {
   510  		willFailEq = true
   511  	}
   512  	if err := typeclassCheck(a.Dtype(), tc); err != nil {
   513  		return true, willFailEq
   514  	}
   515  
   516  	retVal = retVal || !a.IsNativelyAccessible()
   517  	return
   518  }
   519  
   520  func qcErrCheck(t *testing.T, name string, a Dtyper, b interface{}, we bool, err error) (e error, retEarly bool) {
   521  	switch {
   522  	case !we && err != nil:
   523  		t.Errorf("Tests for %v (%v) was unable to proceed: %v", name, a.Dtype(), err)
   524  		return err, true
   525  	case we && err == nil:
   526  		if b == nil {
   527  			t.Errorf("Expected error when performing %v on %T of %v ", name, a, a.Dtype())
   528  			return errors.New("Error"), true
   529  		}
   530  		if bd, ok := b.(Dtyper); ok {
   531  			t.Errorf("Expected error when performing %v on %T of %v and %T of %v", name, a, a.Dtype(), b, bd.Dtype())
   532  		} else {
   533  			t.Errorf("Expected error when performing %v on %T of %v and %v of %T", name, a, a.Dtype(), b, b)
   534  		}
   535  		return errors.New("Error"), true
   536  	case we && err != nil:
   537  		return nil, true
   538  	}
   539  	return nil, false
   540  }
   541  
   542  func qcIsFloat(dt Dtype) bool {
   543  	if err := typeclassCheck(dt, floatcmplxTypes); err == nil {
   544  		return true
   545  	}
   546  	return false
   547  }
   548  
   549  func qcEqCheck(t *testing.T, dt Dtype, willFailEq bool, correct, got interface{}) bool {
   550  	isFloatTypes := qcIsFloat(dt)
   551  	if !willFailEq && (isFloatTypes && !allClose(correct, got) || (!isFloatTypes && !reflect.DeepEqual(correct, got))) {
   552  		t.Errorf("q.Dtype: %v", dt)
   553  		t.Errorf("correct\n%v", correct)
   554  		t.Errorf("got\n%v", got)
   555  		return false
   556  	}
   557  	return true
   558  }
   559  
   560  // DummyState is a dummy fmt.State, used to debug things
   561  type DummyState struct {
   562  	*bytes.Buffer
   563  }
   564  
   565  func (d *DummyState) Width() (int, bool)     { return 0, false }
   566  func (d *DummyState) Precision() (int, bool) { return 0, false }
   567  func (d *DummyState) Flag(c int) bool        { return false }