github.com/wzzhu/tensor@v0.9.24/testutils_test.go (about) 1 package tensor 2 3 import ( 4 "bytes" 5 "errors" 6 "math" 7 "math/cmplx" 8 "math/rand" 9 "reflect" 10 "testing" 11 "testing/quick" 12 "time" 13 "unsafe" 14 15 "github.com/chewxy/math32" 16 "github.com/wzzhu/tensor/internal/storage" 17 ) 18 19 func randomBool() bool { 20 i := rand.Intn(11) 21 return i > 5 22 } 23 24 // from : https://stackoverflow.com/a/31832326/3426066 25 const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 26 const ( 27 letterIdxBits = 6 // 6 bits to represent a letter index 28 letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits 29 letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits 30 ) 31 32 const quickchecks = 1000 33 34 func newRand() *rand.Rand { 35 return rand.New(rand.NewSource(time.Now().UnixNano())) 36 } 37 38 func randomString() string { 39 n := rand.Intn(10) 40 b := make([]byte, n) 41 src := newRand() 42 // A src.Int63() generates 63 random bits, enough for letterIdxMax characters! 43 for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; { 44 if remain == 0 { 45 cache, remain = src.Int63(), letterIdxMax 46 } 47 if idx := int(cache & letterIdxMask); idx < len(letterBytes) { 48 b[i] = letterBytes[idx] 49 i-- 50 } 51 cache >>= letterIdxBits 52 remain-- 53 } 54 55 return string(b) 56 } 57 58 // taken from the Go Stdlib package math 59 func tolerancef64(a, b, e float64) bool { 60 d := a - b 61 if d < 0 { 62 d = -d 63 } 64 65 // note: b is correct (expected) value, a is actual value. 66 // make error tolerance a fraction of b, not a. 67 if b != 0 { 68 e = e * b 69 if e < 0 { 70 e = -e 71 } 72 } 73 return d < e 74 } 75 func closeenoughf64(a, b float64) bool { return tolerancef64(a, b, 1e-8) } 76 func closef64(a, b float64) bool { return tolerancef64(a, b, 1e-14) } 77 func veryclosef64(a, b float64) bool { return tolerancef64(a, b, 4e-16) } 78 func soclosef64(a, b, e float64) bool { return tolerancef64(a, b, e) } 79 func alikef64(a, b float64) bool { 80 switch { 81 case math.IsNaN(a) && math.IsNaN(b): 82 return true 83 case a == b: 84 return math.Signbit(a) == math.Signbit(b) 85 } 86 return false 87 } 88 89 // taken from math32, which was taken from the Go std lib 90 func tolerancef32(a, b, e float32) bool { 91 d := a - b 92 if d < 0 { 93 d = -d 94 } 95 96 // note: b is correct (expected) value, a is actual value. 97 // make error tolerance a fraction of b, not a. 98 if b != 0 { 99 e = e * b 100 if e < 0 { 101 e = -e 102 } 103 } 104 return d < e 105 } 106 func closef32(a, b float32) bool { return tolerancef32(a, b, 1e-5) } // the number gotten from the cfloat standard. Haskell's Linear package uses 1e-6 for floats 107 func veryclosef32(a, b float32) bool { return tolerancef32(a, b, 1e-6) } // from wiki 108 func soclosef32(a, b, e float32) bool { return tolerancef32(a, b, e) } 109 func alikef32(a, b float32) bool { 110 switch { 111 case math32.IsNaN(a) && math32.IsNaN(b): 112 return true 113 case a == b: 114 return math32.Signbit(a) == math32.Signbit(b) 115 } 116 return false 117 } 118 119 // taken from math/cmplx test 120 func cTolerance(a, b complex128, e float64) bool { 121 d := cmplx.Abs(a - b) 122 if b != 0 { 123 e = e * cmplx.Abs(b) 124 if e < 0 { 125 e = -e 126 } 127 } 128 return d < e 129 } 130 131 func cClose(a, b complex128) bool { return cTolerance(a, b, 1e-14) } 132 func cSoclose(a, b complex128, e float64) bool { return cTolerance(a, b, e) } 133 func cVeryclose(a, b complex128) bool { return cTolerance(a, b, 4e-16) } 134 func cAlike(a, b complex128) bool { 135 switch { 136 case cmplx.IsNaN(a) && cmplx.IsNaN(b): 137 return true 138 case a == b: 139 return math.Signbit(real(a)) == math.Signbit(real(b)) && math.Signbit(imag(a)) == math.Signbit(imag(b)) 140 } 141 return false 142 } 143 144 func allClose(a, b interface{}, approxFn ...interface{}) bool { 145 switch at := a.(type) { 146 case []float64: 147 closeness := closef64 148 var ok bool 149 if len(approxFn) > 0 { 150 if closeness, ok = approxFn[0].(func(a, b float64) bool); !ok { 151 closeness = closef64 152 } 153 } 154 bt := b.([]float64) 155 for i, v := range at { 156 if math.IsNaN(v) { 157 if !math.IsNaN(bt[i]) { 158 return false 159 } 160 continue 161 } 162 if math.IsInf(v, 0) { 163 if !math.IsInf(bt[i], 0) { 164 return false 165 } 166 continue 167 } 168 if !closeness(v, bt[i]) { 169 return false 170 } 171 } 172 return true 173 case []float32: 174 closeness := closef32 175 var ok bool 176 if len(approxFn) > 0 { 177 if closeness, ok = approxFn[0].(func(a, b float32) bool); !ok { 178 closeness = closef32 179 } 180 } 181 bt := b.([]float32) 182 for i, v := range at { 183 if math32.IsNaN(v) { 184 if !math32.IsNaN(bt[i]) { 185 return false 186 } 187 continue 188 } 189 if math32.IsInf(v, 0) { 190 if !math32.IsInf(bt[i], 0) { 191 return false 192 } 193 continue 194 } 195 if !closeness(v, bt[i]) { 196 return false 197 } 198 } 199 return true 200 case []complex64: 201 bt := b.([]complex64) 202 for i, v := range at { 203 if cmplx.IsNaN(complex128(v)) { 204 if !cmplx.IsNaN(complex128(bt[i])) { 205 return false 206 } 207 continue 208 } 209 if cmplx.IsInf(complex128(v)) { 210 if !cmplx.IsInf(complex128(bt[i])) { 211 return false 212 } 213 continue 214 } 215 if !cSoclose(complex128(v), complex128(bt[i]), 1e-5) { 216 return false 217 } 218 } 219 return true 220 case []complex128: 221 bt := b.([]complex128) 222 for i, v := range at { 223 if cmplx.IsNaN(v) { 224 if !cmplx.IsNaN(bt[i]) { 225 return false 226 } 227 continue 228 } 229 if cmplx.IsInf(v) { 230 if !cmplx.IsInf(bt[i]) { 231 return false 232 } 233 continue 234 } 235 if !cClose(v, bt[i]) { 236 return false 237 } 238 } 239 return true 240 default: 241 return reflect.DeepEqual(a, b) 242 } 243 } 244 245 func checkErr(t *testing.T, expected bool, err error, name string, id interface{}) (cont bool) { 246 switch { 247 case expected: 248 if err == nil { 249 t.Errorf("Expected error in test %v (%v)", name, id) 250 } 251 return true 252 case !expected && err != nil: 253 t.Errorf("Test %v (%v) errored: %+v", name, id, err) 254 return true 255 } 256 return false 257 } 258 259 func sliceApproxf64(a, b []float64, fn func(a, b float64) bool) bool { 260 if len(a) != len(b) { 261 return false 262 } 263 264 for i, v := range a { 265 if math.IsNaN(v) { 266 if !alikef64(v, b[i]) { 267 return false 268 } 269 } 270 if !fn(v, b[i]) { 271 return false 272 } 273 } 274 return true 275 } 276 277 func RandomFloat64(size int) []float64 { 278 r := make([]float64, size) 279 for i := range r { 280 r[i] = rand.NormFloat64() 281 } 282 return r 283 } 284 285 func factorize(a int) []int { 286 if a <= 0 { 287 return nil 288 } 289 // all numbers are divisible by at least 1 290 retVal := make([]int, 1) 291 retVal[0] = 1 292 293 fill := func(a int, e int) { 294 n := len(retVal) 295 for i, p := 0, a; i < e; i, p = i+1, p*a { 296 for j := 0; j < n; j++ { 297 retVal = append(retVal, retVal[j]*p) 298 } 299 } 300 } 301 // find factors of 2 302 // rightshift by 1 = division by 2 303 var e int 304 for ; a&1 == 0; e++ { 305 a >>= 1 306 } 307 fill(2, e) 308 309 // find factors of 3 and up 310 for next := 3; a > 1; next += 2 { 311 if next*next > a { 312 next = a 313 } 314 for e = 0; a%next == 0; e++ { 315 a /= next 316 } 317 if e > 0 { 318 fill(next, e) 319 } 320 } 321 return retVal 322 } 323 324 func shuffleInts(a []int, r *rand.Rand) { 325 for i := range a { 326 j := r.Intn(i + 1) 327 a[i], a[j] = a[j], a[i] 328 } 329 } 330 331 type TensorGenerator struct { 332 ShapeConstraint Shape 333 DtypeConstraint Dtype 334 } 335 336 func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { 337 var retVal Tensor 338 // generate type of tensor 339 340 return reflect.ValueOf(retVal) 341 } 342 343 func (t *Dense) Generate(r *rand.Rand, size int) reflect.Value { 344 // generate type 345 ri := r.Intn(len(specializedTypes.set)) 346 of := specializedTypes.set[ri] 347 datatyp := reflect.SliceOf(of.Type) 348 gendat, _ := quick.Value(datatyp, r) 349 // generate dims 350 var scalar bool 351 var s Shape 352 dims := r.Intn(5) // dims4 is the max we'll generate even though we can handle much more 353 l := gendat.Len() 354 355 // generate shape based on inputs 356 switch { 357 case dims == 0 || l == 0: 358 scalar = true 359 gendat, _ = quick.Value(of.Type, r) 360 case dims == 1: 361 s = Shape{gendat.Len()} 362 default: 363 factors := factorize(l) 364 s = Shape(BorrowInts(dims)) 365 // fill with 1s so that we can get a non-zero TotalSize 366 for i := 0; i < len(s); i++ { 367 s[i] = 1 368 } 369 370 for i := 0; i < dims; i++ { 371 j := rand.Intn(len(factors)) 372 s[i] = factors[j] 373 size := s.TotalSize() 374 if q, r := divmod(l, size); r != 0 { 375 factors = factorize(r) 376 } else if size != l { 377 if i < dims-2 { 378 factors = factorize(q) 379 } else if i == dims-2 { 380 s[i+1] = q 381 break 382 } 383 } else { 384 break 385 } 386 } 387 shuffleInts(s, r) 388 } 389 390 // generate flags 391 flag := MemoryFlag(r.Intn(4)) 392 393 // generate order 394 order := DataOrder(r.Intn(4)) 395 396 var v *Dense 397 if scalar { 398 v = New(FromScalar(gendat.Interface())) 399 } else { 400 v = New(Of(of), WithShape(s...), WithBacking(gendat.Interface())) 401 } 402 403 v.flag = flag 404 v.AP.o = order 405 406 // generate engine 407 oeint := r.Intn(2) 408 eint := r.Intn(4) 409 switch eint { 410 case 0: 411 v.e = StdEng{} 412 if oeint == 0 { 413 v.oe = StdEng{} 414 } else { 415 v.oe = nil 416 } 417 case 1: 418 // check is to prevent panics which Float64Engine will do if asked to allocate memory for non float64s 419 if of == Float64 { 420 v.e = Float64Engine{} 421 if oeint == 0 { 422 v.oe = Float64Engine{} 423 } else { 424 v.oe = nil 425 } 426 } else { 427 v.e = StdEng{} 428 if oeint == 0 { 429 v.oe = StdEng{} 430 } else { 431 v.oe = nil 432 } 433 } 434 case 2: 435 // check is to prevent panics which Float64Engine will do if asked to allocate memory for non float64s 436 if of == Float32 { 437 v.e = Float32Engine{} 438 if oeint == 0 { 439 v.oe = Float32Engine{} 440 } else { 441 v.oe = nil 442 } 443 } else { 444 v.e = StdEng{} 445 if oeint == 0 { 446 v.oe = StdEng{} 447 } else { 448 v.oe = nil 449 } 450 } 451 case 3: 452 v.e = dummyEngine(true) 453 v.oe = nil 454 } 455 456 return reflect.ValueOf(v) 457 } 458 459 // fakemem is a byteslice, while making it a Memory 460 type fakemem []byte 461 462 func (m fakemem) Uintptr() uintptr { return uintptr(unsafe.Pointer(&m[0])) } 463 func (m fakemem) MemSize() uintptr { return uintptr(len(m)) } 464 func (m fakemem) Pointer() unsafe.Pointer { return unsafe.Pointer(&m[0]) } 465 466 // dummyEngine implements Engine. The bool indicates whether the data is native-accessible 467 type dummyEngine bool 468 469 func (e dummyEngine) AllocAccessible() bool { return bool(e) } 470 func (e dummyEngine) Alloc(size int64) (Memory, error) { 471 ps := make(fakemem, int(size)) 472 return ps, nil 473 } 474 func (e dummyEngine) Free(mem Memory, size int64) error { return nil } 475 func (e dummyEngine) Memset(mem Memory, val interface{}) error { return nil } 476 func (e dummyEngine) Memclr(mem Memory) {} 477 func (e dummyEngine) Memcpy(dst, src Memory) error { 478 if e { 479 var a, b storage.Header 480 a.Raw = storage.FromMemory(src.Uintptr(), src.MemSize()) 481 b.Raw = storage.FromMemory(dst.Uintptr(), dst.MemSize()) 482 483 copy(b.Raw, a.Raw) 484 return nil 485 } 486 return errors.New("Unable to copy ") 487 } 488 func (e dummyEngine) Accessible(mem Memory) (Memory, error) { return mem, nil } 489 func (e dummyEngine) WorksWith(order DataOrder) bool { return true } 490 491 // dummyEngine2 is used for testing additional methods that may not be provided in the stdeng 492 type dummyEngine2 struct { 493 e StdEng 494 } 495 496 func (e dummyEngine2) AllocAccessible() bool { return e.e.AllocAccessible() } 497 func (e dummyEngine2) Alloc(size int64) (Memory, error) { return e.e.Alloc(size) } 498 func (e dummyEngine2) Free(mem Memory, size int64) error { return e.e.Free(mem, size) } 499 func (e dummyEngine2) Memset(mem Memory, val interface{}) error { return e.e.Memset(mem, val) } 500 func (e dummyEngine2) Memclr(mem Memory) { e.e.Memclr(mem) } 501 func (e dummyEngine2) Memcpy(dst, src Memory) error { return e.e.Memcpy(dst, src) } 502 func (e dummyEngine2) Accessible(mem Memory) (Memory, error) { return e.e.Accessible(mem) } 503 func (e dummyEngine2) WorksWith(order DataOrder) bool { return e.e.WorksWith(order) } 504 505 func (e dummyEngine2) Argmax(t Tensor, axis int) (Tensor, error) { return e.e.Argmax(t, axis) } 506 func (e dummyEngine2) Argmin(t Tensor, axis int) (Tensor, error) { return e.e.Argmin(t, axis) } 507 508 func willerr(a *Dense, tc, eqtc *typeclass) (retVal, willFailEq bool) { 509 if err := typeclassCheck(a.Dtype(), eqtc); err == nil { 510 willFailEq = true 511 } 512 if err := typeclassCheck(a.Dtype(), tc); err != nil { 513 return true, willFailEq 514 } 515 516 retVal = retVal || !a.IsNativelyAccessible() 517 return 518 } 519 520 func qcErrCheck(t *testing.T, name string, a Dtyper, b interface{}, we bool, err error) (e error, retEarly bool) { 521 switch { 522 case !we && err != nil: 523 t.Errorf("Tests for %v (%v) was unable to proceed: %v", name, a.Dtype(), err) 524 return err, true 525 case we && err == nil: 526 if b == nil { 527 t.Errorf("Expected error when performing %v on %T of %v ", name, a, a.Dtype()) 528 return errors.New("Error"), true 529 } 530 if bd, ok := b.(Dtyper); ok { 531 t.Errorf("Expected error when performing %v on %T of %v and %T of %v", name, a, a.Dtype(), b, bd.Dtype()) 532 } else { 533 t.Errorf("Expected error when performing %v on %T of %v and %v of %T", name, a, a.Dtype(), b, b) 534 } 535 return errors.New("Error"), true 536 case we && err != nil: 537 return nil, true 538 } 539 return nil, false 540 } 541 542 func qcIsFloat(dt Dtype) bool { 543 if err := typeclassCheck(dt, floatcmplxTypes); err == nil { 544 return true 545 } 546 return false 547 } 548 549 func qcEqCheck(t *testing.T, dt Dtype, willFailEq bool, correct, got interface{}) bool { 550 isFloatTypes := qcIsFloat(dt) 551 if !willFailEq && (isFloatTypes && !allClose(correct, got) || (!isFloatTypes && !reflect.DeepEqual(correct, got))) { 552 t.Errorf("q.Dtype: %v", dt) 553 t.Errorf("correct\n%v", correct) 554 t.Errorf("got\n%v", got) 555 return false 556 } 557 return true 558 } 559 560 // DummyState is a dummy fmt.State, used to debug things 561 type DummyState struct { 562 *bytes.Buffer 563 } 564 565 func (d *DummyState) Width() (int, bool) { return 0, false } 566 func (d *DummyState) Precision() (int, bool) { return 0, false } 567 func (d *DummyState) Flag(c int) bool { return false }