github.com/wzzhu/tensor@v0.9.24/api_unary_test.go (about) 1 package tensor 2 3 import ( 4 "math/rand" 5 "testing" 6 "testing/quick" 7 "time" 8 "math" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/chewxy/math32" 12 ) 13 14 /* 15 GENERATED FILE BY Genlib V1. DO NOT EDIT 16 */ 17 18 var clampTests = []struct { 19 a, reuse interface{} 20 min, max interface{} 21 correct interface{} 22 correctSliced interface{} 23 }{ 24 {[]int{1, 2, 3, 4}, []int{10, 20, 30, 40}, int(2), int(3), []int{2, 2, 3, 3}, []int{2, 2, 3}}, 25 {[]int8{1, 2, 3, 4}, []int8{10, 20, 30, 40}, int8(2), int8(3), []int8{2, 2, 3, 3}, []int8{2, 2, 3}}, 26 {[]int16{1, 2, 3, 4}, []int16{10, 20, 30, 40}, int16(2), int16(3), []int16{2, 2, 3, 3}, []int16{2, 2, 3}}, 27 {[]int32{1, 2, 3, 4}, []int32{10, 20, 30, 40}, int32(2), int32(3), []int32{2, 2, 3, 3}, []int32{2, 2, 3}}, 28 {[]int64{1, 2, 3, 4}, []int64{10, 20, 30, 40}, int64(2), int64(3), []int64{2, 2, 3, 3}, []int64{2, 2, 3}}, 29 {[]uint{1, 2, 3, 4}, []uint{10, 20, 30, 40}, uint(2), uint(3), []uint{2, 2, 3, 3}, []uint{2, 2, 3}}, 30 {[]uint8{1, 2, 3, 4}, []uint8{10, 20, 30, 40}, uint8(2), uint8(3), []uint8{2, 2, 3, 3}, []uint8{2, 2, 3}}, 31 {[]uint16{1, 2, 3, 4}, []uint16{10, 20, 30, 40}, uint16(2), uint16(3), []uint16{2, 2, 3, 3}, []uint16{2, 2, 3}}, 32 {[]uint32{1, 2, 3, 4}, []uint32{10, 20, 30, 40}, uint32(2), uint32(3), []uint32{2, 2, 3, 3}, []uint32{2, 2, 3}}, 33 {[]uint64{1, 2, 3, 4}, []uint64{10, 20, 30, 40}, uint64(2), uint64(3), []uint64{2, 2, 3, 3}, []uint64{2, 2, 3}}, 34 {[]float32{1, 2, 3, 4}, []float32{10, 20, 30, 40}, float32(2), float32(3), []float32{2, 2, 3, 3}, []float32{2, 2, 3}}, 35 {[]float64{1, 2, 3, 4}, []float64{10, 20, 30, 40}, float64(2), float64(3), []float64{2, 2, 3, 3}, []float64{2, 2, 3}}, 36 } 37 38 var clampTestsMasked = []struct { 39 a, reuse interface{} 40 min, max interface{} 41 correct interface{} 42 correctSliced interface{} 43 }{ 44 {[]int{1, 2, 3, 4}, []int{1, 20, 30, 40}, int(2), int(3), []int{1, 2, 3, 3}, []int{1, 2, 3}}, 45 {[]int8{1, 2, 3, 4}, []int8{1, 20, 30, 40}, int8(2), int8(3), []int8{1, 2, 3, 3}, []int8{1, 2, 3}}, 46 {[]int16{1, 2, 3, 4}, []int16{1, 20, 30, 40}, int16(2), int16(3), []int16{1, 2, 3, 3}, []int16{1, 2, 3}}, 47 {[]int32{1, 2, 3, 4}, []int32{1, 20, 30, 40}, int32(2), int32(3), []int32{1, 2, 3, 3}, []int32{1, 2, 3}}, 48 {[]int64{1, 2, 3, 4}, []int64{1, 20, 30, 40}, int64(2), int64(3), []int64{1, 2, 3, 3}, []int64{1, 2, 3}}, 49 {[]uint{1, 2, 3, 4}, []uint{1, 20, 30, 40}, uint(2), uint(3), []uint{1, 2, 3, 3}, []uint{1, 2, 3}}, 50 {[]uint8{1, 2, 3, 4}, []uint8{1, 20, 30, 40}, uint8(2), uint8(3), []uint8{1, 2, 3, 3}, []uint8{1, 2, 3}}, 51 {[]uint16{1, 2, 3, 4}, []uint16{1, 20, 30, 40}, uint16(2), uint16(3), []uint16{1, 2, 3, 3}, []uint16{1, 2, 3}}, 52 {[]uint32{1, 2, 3, 4}, []uint32{1, 20, 30, 40}, uint32(2), uint32(3), []uint32{1, 2, 3, 3}, []uint32{1, 2, 3}}, 53 {[]uint64{1, 2, 3, 4}, []uint64{1, 20, 30, 40}, uint64(2), uint64(3), []uint64{1, 2, 3, 3}, []uint64{1, 2, 3}}, 54 {[]float32{1, 2, 3, 4}, []float32{1, 20, 30, 40}, float32(2), float32(3), []float32{1, 2, 3, 3}, []float32{1, 2, 3}}, 55 {[]float64{1, 2, 3, 4}, []float64{1, 20, 30, 40}, float64(2), float64(3), []float64{1, 2, 3, 3}, []float64{1, 2, 3}}, 56 } 57 58 func TestClamp(t *testing.T) { 59 assert := assert.New(t) 60 var got, sliced Tensor 61 var T, reuse *Dense 62 var err error 63 for _, ct := range clampTests { 64 T = New(WithBacking(ct.a)) 65 // safe 66 if got, err = Clamp(T, ct.min, ct.max); err != nil { 67 t.Error(err) 68 continue 69 } 70 if got == T { 71 t.Error("expected got != T") 72 continue 73 } 74 assert.Equal(ct.correct, got.Data()) 75 76 // sliced safe 77 if sliced, err = T.Slice(makeRS(0, 3)); err != nil { 78 t.Error("Unable to slice T") 79 continue 80 } 81 if got, err = Clamp(sliced, ct.min, ct.max); err != nil { 82 t.Error(err) 83 continue 84 } 85 86 // reuse 87 reuse = New(WithBacking(ct.reuse)) 88 if got, err = Clamp(T, ct.min, ct.max, WithReuse(reuse)); err != nil { 89 t.Error(err) 90 continue 91 } 92 if got != reuse { 93 t.Error("expected got == reuse") 94 continue 95 } 96 assert.Equal(ct.correct, got.Data()) 97 98 // unsafe 99 if got, err = Clamp(T, ct.min, ct.max, UseUnsafe()); err != nil { 100 t.Error(err) 101 continue 102 } 103 if got != T { 104 t.Error("expected got == T") 105 continue 106 } 107 assert.Equal(ct.correct, got.Data()) 108 } 109 } 110 111 func TestClampMasked(t *testing.T) { 112 assert := assert.New(t) 113 var got, sliced Tensor 114 var T, reuse *Dense 115 var err error 116 for _, ct := range clampTestsMasked { 117 T = New(WithBacking(ct.a, []bool{true, false, false, false})) 118 // safe 119 if got, err = Clamp(T, ct.min, ct.max); err != nil { 120 t.Error(err) 121 continue 122 } 123 if got == T { 124 t.Error("expected got != T") 125 continue 126 } 127 assert.Equal(ct.correct, got.Data()) 128 129 // sliced safe 130 if sliced, err = T.Slice(makeRS(0, 3)); err != nil { 131 t.Error("Unable to slice T") 132 continue 133 } 134 if got, err = Clamp(sliced, ct.min, ct.max); err != nil { 135 t.Error(err) 136 continue 137 } 138 139 // reuse 140 reuse = New(WithBacking(ct.reuse, []bool{true, false, false, false})) 141 if got, err = Clamp(T, ct.min, ct.max, WithReuse(reuse)); err != nil { 142 t.Error(err) 143 continue 144 } 145 if got != reuse { 146 t.Error("expected got == reuse") 147 continue 148 } 149 assert.Equal(ct.correct, got.Data()) 150 151 // unsafe 152 if got, err = Clamp(T, ct.min, ct.max, UseUnsafe()); err != nil { 153 t.Error(err) 154 continue 155 } 156 if got != T { 157 t.Error("expected got == T") 158 continue 159 } 160 assert.Equal(ct.correct, got.Data()) 161 } 162 } 163 164 var signTests = []struct { 165 a, reuse interface{} 166 correct interface{} 167 correctSliced interface{} 168 }{ 169 {[]int{0, 1, 2, -2, -1}, []int{100, 10, 20, 30, 40}, []int{0, 1, 1, -1, -1}, []int{0, 1, 1, -1}}, 170 {[]int8{0, 1, 2, -2, -1}, []int8{100, 10, 20, 30, 40}, []int8{0, 1, 1, -1, -1}, []int8{0, 1, 1, -1}}, 171 {[]int16{0, 1, 2, -2, -1}, []int16{100, 10, 20, 30, 40}, []int16{0, 1, 1, -1, -1}, []int16{0, 1, 1, -1}}, 172 {[]int32{0, 1, 2, -2, -1}, []int32{100, 10, 20, 30, 40}, []int32{0, 1, 1, -1, -1}, []int32{0, 1, 1, -1}}, 173 {[]int64{0, 1, 2, -2, -1}, []int64{100, 10, 20, 30, 40}, []int64{0, 1, 1, -1, -1}, []int64{0, 1, 1, -1}}, 174 {[]float32{0, 1, 2, -2, -1}, []float32{100, 10, 20, 30, 40}, []float32{0, 1, 1, -1, -1}, []float32{0, 1, 1, -1}}, 175 {[]float64{0, 1, 2, -2, -1}, []float64{100, 10, 20, 30, 40}, []float64{0, 1, 1, -1, -1}, []float64{0, 1, 1, -1}}, 176 } 177 178 var signTestsMasked = []struct { 179 a, reuse interface{} 180 correct interface{} 181 // correctSliced interface{} 182 }{ 183 {[]int{1, 2, -2, -1}, []int{10, 20, 30, 40}, []int{1, 1, -2, -1}}, 184 {[]int8{1, 2, -2, -1}, []int8{10, 20, 30, 40}, []int8{1, 1, -2, -1}}, 185 {[]int16{1, 2, -2, -1}, []int16{10, 20, 30, 40}, []int16{1, 1, -2, -1}}, 186 {[]int32{1, 2, -2, -1}, []int32{10, 20, 30, 40}, []int32{1, 1, -2, -1}}, 187 {[]int64{1, 2, -2, -1}, []int64{10, 20, 30, 40}, []int64{1, 1, -2, -1}}, 188 {[]float32{1, 2, -2, -1}, []float32{10, 20, 30, 40}, []float32{1, 1, -2, -1}}, 189 {[]float64{1, 2, -2, -1}, []float64{10, 20, 30, 40}, []float64{1, 1, -2, -1}}, 190 } 191 192 func TestSign(t *testing.T) { 193 assert := assert.New(t) 194 var got, sliced Tensor 195 var T, reuse *Dense 196 var err error 197 for _, st := range signTests { 198 T = New(WithBacking(st.a)) 199 // safe 200 if got, err = Sign(T); err != nil { 201 t.Error(err) 202 continue 203 } 204 205 if got == T { 206 t.Error("expected got != T") 207 continue 208 } 209 assert.Equal(st.correct, got.Data()) 210 211 // sliced safe 212 if sliced, err = T.Slice(makeRS(0, 4)); err != nil { 213 t.Error("Unable to slice T") 214 continue 215 } 216 if got, err = Sign(sliced); err != nil { 217 t.Error(err) 218 continue 219 } 220 assert.Equal(st.correctSliced, got.Data()) 221 222 // reuse 223 reuse = New(WithBacking(st.reuse)) 224 if got, err = Sign(T, WithReuse(reuse)); err != nil { 225 t.Error(err) 226 continue 227 } 228 229 if got != reuse { 230 t.Error("expected got == reuse") 231 continue 232 } 233 assert.Equal(st.correct, got.Data()) 234 235 // unsafe 236 if got, err = Sign(T, UseUnsafe()); err != nil { 237 t.Error(err) 238 continue 239 } 240 if got != T { 241 t.Error("expected got == T") 242 continue 243 } 244 assert.Equal(st.correct, got.Data()) 245 } 246 } 247 248 func TestSignMasked(t *testing.T) { 249 assert := assert.New(t) 250 var got Tensor 251 var T, reuse *Dense 252 var err error 253 for _, st := range signTestsMasked { 254 T = New(WithBacking(st.a, []bool{false, false, true, false})) 255 // safe 256 if got, err = Sign(T); err != nil { 257 t.Error(err) 258 continue 259 } 260 261 if got == T { 262 t.Error("expected got != T") 263 continue 264 } 265 assert.Equal(st.correct, got.Data()) 266 267 // reuse 268 reuse = New(WithBacking(st.reuse, []bool{false, false, true, false})) 269 if got, err = Sign(T, WithReuse(reuse)); err != nil { 270 t.Error(err) 271 continue 272 } 273 274 if got != reuse { 275 t.Error("expected got == reuse") 276 continue 277 } 278 assert.Equal(st.correct, got.Data()) 279 280 // unsafe 281 if got, err = Sign(T, UseUnsafe()); err != nil { 282 t.Error(err) 283 continue 284 } 285 if got != T { 286 t.Error("expected got == T") 287 continue 288 } 289 assert.Equal(st.correct, got.Data()) 290 } 291 } 292 293 var negTestsMasked = []struct { 294 a, reuse interface{} 295 correct interface{} 296 }{ 297 {[]int{1, 2, -2, -1}, []int{10, 20, 30, 40}, []int{-1, -2, -2, 1}}, 298 {[]int8{1, 2, -2, -1}, []int8{10, 20, 30, 40}, []int8{-1, -2, -2, 1}}, 299 {[]int16{1, 2, -2, -1}, []int16{10, 20, 30, 40}, []int16{-1, -2, -2, 1}}, 300 {[]int32{1, 2, -2, -1}, []int32{10, 20, 30, 40}, []int32{-1, -2, -2, 1}}, 301 {[]int64{1, 2, -2, -1}, []int64{10, 20, 30, 40}, []int64{-1, -2, -2, 1}}, 302 {[]float32{1, 2, -2, -1}, []float32{10, 20, 30, 40}, []float32{-1, -2, -2, 1}}, 303 {[]float64{1, 2, -2, -1}, []float64{10, 20, 30, 40}, []float64{-1, -2, -2, 1}}, 304 } 305 306 func TestNegMasked(t *testing.T) { 307 assert := assert.New(t) 308 var got Tensor 309 var T, reuse *Dense 310 var err error 311 for _, st := range negTestsMasked { 312 T = New(WithBacking(st.a, []bool{false, false, true, false})) 313 // safe 314 if got, err = Neg(T); err != nil { 315 t.Error(err) 316 continue 317 } 318 319 if got == T { 320 t.Error("expected got != T") 321 continue 322 } 323 assert.Equal(st.correct, got.Data()) 324 325 // reuse 326 reuse = New(WithBacking(st.reuse, []bool{false, false, true, false})) 327 if got, err = Neg(T, WithReuse(reuse)); err != nil { 328 t.Error(err) 329 continue 330 } 331 332 if got != reuse { 333 t.Error("expected got == reuse") 334 continue 335 } 336 assert.Equal(st.correct, got.Data()) 337 338 // unsafe 339 if got, err = Neg(T, UseUnsafe()); err != nil { 340 t.Error(err) 341 continue 342 } 343 if got != T { 344 t.Error("expected got == T") 345 continue 346 } 347 assert.Equal(st.correct, got.Data()) 348 } 349 } 350 351 func TestInvSqrt(t *testing.T) { 352 var r *rand.Rand 353 invFn := func(q *Dense) bool { 354 a := q.Clone().(*Dense) 355 b := q.Clone().(*Dense) 356 correct := a.Clone().(*Dense) 357 we, willFailEq := willerr(a, floatTypes, nil) 358 _, ok := q.Engine().(InvSqrter) 359 we = we || !ok 360 361 // we'll exclude everything other than floats 362 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 363 return true 364 } 365 ret, err := InvSqrt(a) 366 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 367 if err != nil { 368 return false 369 } 370 return true 371 } 372 Sqrt(b, UseUnsafe()) 373 Mul(ret, b, UseUnsafe()) 374 if !qcEqCheck(t, b.Dtype(), willFailEq, correct.Data(), ret.Data()) { 375 return false 376 } 377 return true 378 } 379 380 r = rand.New(rand.NewSource(time.Now().UnixNano())) 381 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 382 t.Errorf("Inv tests for InvSqrt failed: %v", err) 383 } 384 385 // unsafe 386 invFn = func(q *Dense) bool { 387 a := q.Clone().(*Dense) 388 b := q.Clone().(*Dense) 389 correct := a.Clone().(*Dense) 390 we, willFailEq := willerr(a, floatTypes, nil) 391 _, ok := q.Engine().(InvSqrter) 392 we = we || !ok 393 394 // we'll exclude everything other than floats 395 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 396 return true 397 } 398 ret, err := InvSqrt(a, UseUnsafe()) 399 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 400 if err != nil { 401 return false 402 } 403 return true 404 } 405 Sqrt(b, UseUnsafe()) 406 Mul(ret, b, UseUnsafe()) 407 if !qcEqCheck(t, b.Dtype(), willFailEq, correct.Data(), ret.Data()) { 408 return false 409 } 410 if ret != a { 411 t.Errorf("Expected ret to be the same as a") 412 return false 413 } 414 return true 415 } 416 417 r = rand.New(rand.NewSource(time.Now().UnixNano())) 418 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 419 t.Errorf("Inv tests using unsafe for InvSqrt failed: %v", err) 420 } 421 422 // reuse 423 invFn = func(q *Dense) bool { 424 a := q.Clone().(*Dense) 425 b := q.Clone().(*Dense) 426 reuse := q.Clone().(*Dense) 427 reuse.Zero() 428 correct := a.Clone().(*Dense) 429 we, willFailEq := willerr(a, floatTypes, nil) 430 _, ok := q.Engine().(InvSqrter) 431 we = we || !ok 432 433 // we'll exclude everything other than floats 434 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 435 return true 436 } 437 ret, err := InvSqrt(a, WithReuse(reuse)) 438 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 439 if err != nil { 440 return false 441 } 442 return true 443 } 444 Sqrt(b, UseUnsafe()) 445 Mul(ret, b, UseUnsafe()) 446 if !qcEqCheck(t, b.Dtype(), willFailEq, correct.Data(), ret.Data()) { 447 return false 448 } 449 if ret != reuse { 450 t.Errorf("Expected ret to be the same as reuse") 451 return false 452 } 453 return true 454 } 455 r = rand.New(rand.NewSource(time.Now().UnixNano())) 456 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 457 t.Errorf("Inv tests with reuse for InvSqrt failed: %v", err) 458 } 459 460 // incr 461 invFn = func(q *Dense) bool { 462 a := q.Clone().(*Dense) 463 b := q.Clone().(*Dense) 464 incr := New(Of(a.t), WithShape(a.Shape().Clone()...)) 465 correct := a.Clone().(*Dense) 466 incr.Memset(identityVal(100, a.t)) 467 correct.Add(incr, UseUnsafe()) 468 469 we, willFailEq := willerr(a, floatTypes, nil) 470 _, ok := q.Engine().(InvSqrter) 471 we = we || !ok 472 473 // we'll exclude everything other than floats 474 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 475 return true 476 } 477 ret, err := InvSqrt(a, WithIncr(incr)) 478 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 479 if err != nil { 480 return false 481 } 482 return true 483 } 484 if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { 485 t.Errorf("err while subtracting incr: %v", err) 486 return false 487 } 488 Sqrt(b, UseUnsafe()) 489 Mul(ret, b, UseUnsafe()) 490 if !qcEqCheck(t, b.Dtype(), willFailEq, correct.Data(), ret.Data()) { 491 return false 492 } 493 if ret != incr { 494 t.Errorf("Expected ret to be the same as incr") 495 return false 496 } 497 return true 498 } 499 500 r = rand.New(rand.NewSource(time.Now().UnixNano())) 501 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 502 t.Errorf("Inv tests with incr for InvSqrt failed: %v", err) 503 } 504 505 } 506 507 func TestInv(t *testing.T) { 508 var r *rand.Rand 509 invFn := func(q *Dense) bool { 510 a := q.Clone().(*Dense) 511 correct := a.Clone().(*Dense) 512 we, willFailEq := willerr(a, floatTypes, nil) 513 _, ok := q.Engine().(Inver) 514 we = we || !ok 515 516 // we'll exclude everything other than floats 517 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 518 return true 519 } 520 ret, err := Inv(a) 521 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 522 if err != nil { 523 return false 524 } 525 return true 526 } 527 Mul(ret, a, UseUnsafe()) 528 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 529 return false 530 } 531 return true 532 } 533 534 r = rand.New(rand.NewSource(time.Now().UnixNano())) 535 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 536 t.Errorf("Inv tests for Inv failed: %v", err) 537 } 538 539 // unsafe 540 invFn = func(q *Dense) bool { 541 a := q.Clone().(*Dense) 542 b := q.Clone().(*Dense) 543 correct := a.Clone().(*Dense) 544 we, willFailEq := willerr(a, floatTypes, nil) 545 _, ok := q.Engine().(Inver) 546 we = we || !ok 547 548 // we'll exclude everything other than floats 549 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 550 return true 551 } 552 ret, err := Inv(a, UseUnsafe()) 553 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 554 if err != nil { 555 return false 556 } 557 return true 558 } 559 Mul(ret, b, UseUnsafe()) 560 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 561 return false 562 } 563 if ret != a { 564 t.Errorf("Expected ret to be the same as a") 565 return false 566 } 567 return true 568 } 569 r = rand.New(rand.NewSource(time.Now().UnixNano())) 570 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 571 t.Errorf("Inv tests using unsafe for Inv failed: %v", err) 572 } 573 574 // reuse 575 invFn = func(q *Dense) bool { 576 a := q.Clone().(*Dense) 577 correct := a.Clone().(*Dense) 578 reuse := a.Clone().(*Dense) 579 reuse.Zero() 580 we, willFailEq := willerr(a, floatTypes, nil) 581 _, ok := q.Engine().(Inver) 582 we = we || !ok 583 584 // we'll exclude everything other than floats 585 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 586 return true 587 } 588 ret, err := Inv(a, WithReuse(reuse)) 589 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 590 if err != nil { 591 return false 592 } 593 return true 594 } 595 Mul(ret, a, UseUnsafe()) 596 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 597 return false 598 } 599 if ret != reuse { 600 t.Errorf("Expected ret to be the same as reuse") 601 } 602 return true 603 } 604 r = rand.New(rand.NewSource(time.Now().UnixNano())) 605 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 606 t.Errorf("Inv tests using unsafe for Inv failed: %v", err) 607 } 608 609 // incr 610 invFn = func(q *Dense) bool { 611 a := q.Clone().(*Dense) 612 incr := New(Of(a.t), WithShape(a.Shape().Clone()...)) 613 correct := a.Clone().(*Dense) 614 incr.Memset(identityVal(100, a.t)) 615 correct.Add(incr, UseUnsafe()) 616 we, willFailEq := willerr(a, floatTypes, nil) 617 _, ok := q.Engine().(Inver) 618 we = we || !ok 619 620 // we'll exclude everything other than floats 621 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 622 return true 623 } 624 ret, err := Inv(a, WithIncr(incr)) 625 if err, retEarly := qcErrCheck(t, "Inv", a, nil, we, err); retEarly { 626 if err != nil { 627 return false 628 } 629 return true 630 } 631 if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { 632 t.Errorf("err while subtracting incr: %v", err) 633 return false 634 } 635 Mul(ret, a, UseUnsafe()) 636 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 637 return false 638 } 639 if ret != incr { 640 t.Errorf("Expected ret to be the same as incr") 641 } 642 return true 643 } 644 r = rand.New(rand.NewSource(time.Now().UnixNano())) 645 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 646 t.Errorf("Inv tests using unsafe for Inv failed: %v", err) 647 } 648 } 649 650 func TestLog10(t *testing.T) { 651 var r *rand.Rand 652 653 // default 654 invFn := func(q *Dense) bool { 655 a := q.Clone().(*Dense) 656 correct := a.Clone().(*Dense) 657 we, willFailEq := willerr(a, floatTypes, nil) 658 _, ok := q.Engine().(Log10er) 659 we = we || !ok 660 661 // we'll exclude everything other than floats 662 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 663 return true 664 } 665 ret, err := Log10(a) 666 if err, retEarly := qcErrCheck(t, "Log10", a, nil, we, err); retEarly { 667 if err != nil { 668 return false 669 } 670 return true 671 } 672 673 ten := identityVal(10, a.Dtype()) 674 Pow(ten, ret, UseUnsafe()) 675 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 676 return false 677 } 678 return true 679 } 680 681 r = rand.New(rand.NewSource(time.Now().UnixNano())) 682 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 683 t.Errorf("Inv tests for Log10 failed: %v", err) 684 } 685 686 687 // unsafe 688 invFn = func(q *Dense) bool { 689 a := q.Clone().(*Dense) 690 b := q.Clone().(*Dense) 691 correct := a.Clone().(*Dense) 692 we, willFailEq := willerr(a, floatTypes, nil) 693 _, ok := q.Engine().(Log10er) 694 we = we || !ok 695 696 // we'll exclude everything other than floats 697 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 698 return true 699 } 700 ret, err := Log10(a, UseUnsafe()) 701 if err, retEarly := qcErrCheck(t, "Log10", a, nil, we, err); retEarly { 702 if err != nil { 703 return false 704 } 705 return true 706 } 707 ten := identityVal(10, a.Dtype()) 708 Pow(ten, b, UseUnsafe()) 709 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 710 return false 711 } 712 if ret != a { 713 t.Errorf("Expected ret to be the same as a") 714 return false 715 } 716 return true 717 } 718 r = rand.New(rand.NewSource(time.Now().UnixNano())) 719 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 720 t.Errorf("Inv tests using unsafe for Log10 failed: %v", err) 721 } 722 723 724 // reuse 725 invFn = func(q *Dense) bool { 726 a := q.Clone().(*Dense) 727 correct := a.Clone().(*Dense) 728 reuse := a.Clone().(*Dense) 729 reuse.Zero() 730 we, willFailEq := willerr(a, floatTypes, nil) 731 _, ok := q.Engine().(Log10er) 732 we = we || !ok 733 734 // we'll exclude everything other than floats 735 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 736 return true 737 } 738 ret, err := Log10(a, WithReuse(reuse)) 739 if err, retEarly := qcErrCheck(t, "Log10", a, nil, we, err); retEarly { 740 if err != nil { 741 return false 742 } 743 return true 744 } 745 ten := identityVal(10, a.Dtype()) 746 Pow(ten, ret, UseUnsafe()) 747 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 748 return false 749 } 750 if ret != reuse { 751 t.Errorf("Expected ret to be the same as reuse") 752 } 753 return true 754 } 755 r = rand.New(rand.NewSource(time.Now().UnixNano())) 756 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 757 t.Errorf("Inv tests using unsafe for Log10 failed: %v", err) 758 } 759 760 // incr 761 invFn = func(q *Dense) bool { 762 a := q.Clone().(*Dense) 763 incr := New(Of(a.t), WithShape(a.Shape().Clone()...)) 764 correct := a.Clone().(*Dense) 765 incr.Memset(identityVal(100, a.t)) 766 correct.Add(incr, UseUnsafe()) 767 we, willFailEq := willerr(a, floatTypes, nil) 768 _, ok := q.Engine().(Log10er) 769 we = we || !ok 770 771 // we'll exclude everything other than floats 772 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 773 return true 774 } 775 ret, err := Log10(a, WithIncr(incr)) 776 if err, retEarly := qcErrCheck(t, "Log10", a, nil, we, err); retEarly { 777 if err != nil { 778 return false 779 } 780 return true 781 } 782 if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { 783 t.Errorf("err while subtracting incr: %v", err) 784 return false 785 } 786 ten := identityVal(10, a.Dtype()) 787 Pow(ten, ret, UseUnsafe()) 788 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 789 return false 790 } 791 if ret != incr { 792 t.Errorf("Expected ret to be the same as incr") 793 } 794 return true 795 } 796 r = rand.New(rand.NewSource(time.Now().UnixNano())) 797 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 798 t.Errorf("Inv tests using unsafe for Log10 failed: %v", err) 799 } 800 801 } 802 803 func TestAbs(t *testing.T) { 804 var r *rand.Rand 805 absFn := func(q *Dense) bool { 806 a := q.Clone().(*Dense) 807 zeros := New(Of(q.Dtype()), WithShape(q.Shape().Clone()...)) 808 correct := New(Of(Bool), WithShape(q.Shape().Clone()...)) 809 correct.Memset(true) 810 // we'll exclude everything other than ordtypes because complex numbers cannot be abs'd 811 if err := typeclassCheck(a.Dtype(), ordTypes); err != nil { 812 return true 813 } 814 we, willFailEq := willerr(a, signedTypes, nil) 815 _, ok := q.Engine().(Abser) 816 we = we || !ok 817 818 ret, err := Abs(a) 819 if err, retEarly := qcErrCheck(t, "Abs", a, nil, we, err); retEarly { 820 if err != nil { 821 return false 822 } 823 return true 824 } 825 826 check, _ := Gte(ret, zeros) 827 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), check.Data()) { 828 return false 829 } 830 return true 831 } 832 833 r = rand.New(rand.NewSource(time.Now().UnixNano())) 834 if err := quick.Check(absFn, &quick.Config{Rand: r}); err != nil { 835 t.Errorf("Inv tests for Abs failed: %v", err) 836 } 837 } 838 839 840 func TestTanh(t *testing.T) { 841 var r *rand.Rand 842 // default 843 invFn := func(q *Dense) bool { 844 a := q.Clone().(*Dense) 845 correct := a.Clone().(*Dense) 846 we, willFailEq := willerr(a, floatTypes, nil) 847 _, ok := q.Engine().(Tanher) 848 we = we || !ok 849 850 // we'll exclude everything other than floats 851 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 852 return true 853 } 854 ret, err := Tanh(a) 855 if err, retEarly := qcErrCheck(t, "Tanh", a, nil, we, err); retEarly { 856 if err != nil { 857 return false 858 } 859 return true 860 } 861 switch a.Dtype() { 862 case Float64: 863 if ret, err = ret.Apply(math.Atan, UseUnsafe()); err != nil { 864 t.Error(err) 865 return false 866 } 867 case Float32: 868 if ret, err = ret.Apply(math32.Atan, UseUnsafe()); err != nil { 869 t.Error(err) 870 return false 871 } 872 } 873 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 874 return false 875 } 876 return true 877 } 878 879 r = rand.New(rand.NewSource(time.Now().UnixNano())) 880 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 881 t.Errorf("Inv tests for Tanh failed: %v", err) 882 } 883 884 // unsafe 885 invFn = func(q *Dense) bool { 886 a := q.Clone().(*Dense) 887 correct := a.Clone().(*Dense) 888 we, willFailEq := willerr(a, floatTypes, nil) 889 _, ok := q.Engine().(Tanher) 890 we = we || !ok 891 892 // we'll exclude everything other than floats 893 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 894 return true 895 } 896 ret, err := Tanh(a, UseUnsafe()) 897 if err, retEarly := qcErrCheck(t, "Tanh", a, nil, we, err); retEarly { 898 if err != nil { 899 return false 900 } 901 return true 902 } 903 switch a.Dtype() { 904 case Float64: 905 if ret, err = ret.Apply(math.Atan, UseUnsafe()); err != nil { 906 t.Error(err) 907 return false 908 } 909 case Float32: 910 if ret, err = ret.Apply(math32.Atan, UseUnsafe()); err != nil { 911 t.Error(err) 912 return false 913 } 914 } 915 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 916 return false 917 } 918 if ret != a { 919 t.Errorf("Expected ret to be the same as a") 920 return false 921 } 922 return true 923 } 924 r = rand.New(rand.NewSource(time.Now().UnixNano())) 925 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 926 t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) 927 } 928 929 930 // reuse 931 invFn = func(q *Dense) bool { 932 a := q.Clone().(*Dense) 933 correct := a.Clone().(*Dense) 934 reuse := a.Clone().(*Dense) 935 reuse.Zero() 936 we, willFailEq := willerr(a, floatTypes, nil) 937 _, ok := q.Engine().(Tanher) 938 we = we || !ok 939 940 // we'll exclude everything other than floats 941 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 942 return true 943 } 944 ret, err := Tanh(a, WithReuse(reuse)) 945 if err, retEarly := qcErrCheck(t, "Tanh", a, nil, we, err); retEarly { 946 if err != nil { 947 return false 948 } 949 return true 950 } 951 switch a.Dtype() { 952 case Float64: 953 if ret, err = ret.Apply(math.Atan, UseUnsafe()); err != nil { 954 t.Error(err) 955 return false 956 } 957 case Float32: 958 if ret, err = ret.Apply(math32.Atan, UseUnsafe()); err != nil { 959 t.Error(err) 960 return false 961 } 962 } 963 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 964 return false 965 } 966 if ret != reuse { 967 t.Errorf("Expected ret to be the same as reuse") 968 } 969 return true 970 } 971 r = rand.New(rand.NewSource(time.Now().UnixNano())) 972 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 973 t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) 974 } 975 976 977 // incr 978 invFn = func(q *Dense) bool { 979 a := q.Clone().(*Dense) 980 incr := New(Of(a.t), WithShape(a.Shape().Clone()...)) 981 correct := a.Clone().(*Dense) 982 incr.Memset(identityVal(100, a.t)) 983 correct.Add(incr, UseUnsafe()) 984 we, willFailEq := willerr(a, floatTypes, nil) 985 _, ok := q.Engine().(Tanher) 986 we = we || !ok 987 988 // we'll exclude everything other than floats 989 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 990 return true 991 } 992 ret, err := Tanh(a, WithIncr(incr)) 993 if err, retEarly := qcErrCheck(t, "Tanh", a, nil, we, err); retEarly { 994 if err != nil { 995 return false 996 } 997 return true 998 } 999 if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { 1000 t.Errorf("err while subtracting incr: %v", err) 1001 return false 1002 } 1003 switch a.Dtype() { 1004 case Float64: 1005 if ret, err = ret.Apply(math.Atan, UseUnsafe()); err != nil { 1006 t.Error(err) 1007 return false 1008 } 1009 case Float32: 1010 if ret, err = ret.Apply(math32.Atan, UseUnsafe()); err != nil { 1011 t.Error(err) 1012 return false 1013 } 1014 } 1015 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 1016 return false 1017 } 1018 if ret != incr { 1019 t.Errorf("Expected ret to be the same as incr") 1020 } 1021 return true 1022 } 1023 r = rand.New(rand.NewSource(time.Now().UnixNano())) 1024 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 1025 t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) 1026 } 1027 } 1028 1029 func TestLog2(t *testing.T) { 1030 var r *rand.Rand 1031 1032 // default 1033 invFn := func(q *Dense) bool { 1034 a := q.Clone().(*Dense) 1035 correct := a.Clone().(*Dense) 1036 we, willFailEq := willerr(a, floatTypes, nil) 1037 _, ok := q.Engine().(Log2er) 1038 we = we || !ok 1039 1040 // we'll exclude everything other than floats 1041 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 1042 return true 1043 } 1044 ret, err := Log2(a) 1045 if err, retEarly := qcErrCheck(t, "Log2", a, nil, we, err); retEarly { 1046 if err != nil { 1047 return false 1048 } 1049 return true 1050 } 1051 1052 two := identityVal(2, a.Dtype()) 1053 Pow(two, ret, UseUnsafe()) 1054 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 1055 return false 1056 } 1057 return true 1058 } 1059 1060 r = rand.New(rand.NewSource(time.Now().UnixNano())) 1061 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 1062 t.Errorf("Inv tests for Log2 failed: %v", err) 1063 } 1064 1065 1066 // unsafe 1067 invFn = func(q *Dense) bool { 1068 a := q.Clone().(*Dense) 1069 b := q.Clone().(*Dense) 1070 correct := a.Clone().(*Dense) 1071 we, willFailEq := willerr(a, floatTypes, nil) 1072 _, ok := q.Engine().(Log2er) 1073 we = we || !ok 1074 1075 // we'll exclude everything other than floats 1076 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 1077 return true 1078 } 1079 ret, err := Log2(a, UseUnsafe()) 1080 if err, retEarly := qcErrCheck(t, "Log2", a, nil, we, err); retEarly { 1081 if err != nil { 1082 return false 1083 } 1084 return true 1085 } 1086 two := identityVal(2, a.Dtype()) 1087 Pow(two, b, UseUnsafe()) 1088 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 1089 return false 1090 } 1091 if ret != a { 1092 t.Errorf("Expected ret to be the same as a") 1093 return false 1094 } 1095 return true 1096 } 1097 r = rand.New(rand.NewSource(time.Now().UnixNano())) 1098 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 1099 t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) 1100 } 1101 1102 1103 // reuse 1104 invFn = func(q *Dense) bool { 1105 a := q.Clone().(*Dense) 1106 correct := a.Clone().(*Dense) 1107 reuse := a.Clone().(*Dense) 1108 reuse.Zero() 1109 we, willFailEq := willerr(a, floatTypes, nil) 1110 _, ok := q.Engine().(Log2er) 1111 we = we || !ok 1112 1113 // we'll exclude everything other than floats 1114 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 1115 return true 1116 } 1117 ret, err := Log2(a, WithReuse(reuse)) 1118 if err, retEarly := qcErrCheck(t, "Log2", a, nil, we, err); retEarly { 1119 if err != nil { 1120 return false 1121 } 1122 return true 1123 } 1124 two := identityVal(2, a.Dtype()) 1125 Pow(two, ret, UseUnsafe()) 1126 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 1127 return false 1128 } 1129 if ret != reuse { 1130 t.Errorf("Expected ret to be the same as reuse") 1131 } 1132 return true 1133 } 1134 r = rand.New(rand.NewSource(time.Now().UnixNano())) 1135 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 1136 t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) 1137 } 1138 1139 // incr 1140 invFn = func(q *Dense) bool { 1141 a := q.Clone().(*Dense) 1142 incr := New(Of(a.t), WithShape(a.Shape().Clone()...)) 1143 correct := a.Clone().(*Dense) 1144 incr.Memset(identityVal(100, a.t)) 1145 correct.Add(incr, UseUnsafe()) 1146 we, willFailEq := willerr(a, floatTypes, nil) 1147 _, ok := q.Engine().(Log2er) 1148 we = we || !ok 1149 1150 // we'll exclude everything other than floats 1151 if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { 1152 return true 1153 } 1154 ret, err := Log2(a, WithIncr(incr)) 1155 if err, retEarly := qcErrCheck(t, "Log2", a, nil, we, err); retEarly { 1156 if err != nil { 1157 return false 1158 } 1159 return true 1160 } 1161 if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { 1162 t.Errorf("err while subtracting incr: %v", err) 1163 return false 1164 } 1165 two := identityVal(2, a.Dtype()) 1166 Pow(two, ret, UseUnsafe()) 1167 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 1168 return false 1169 } 1170 if ret != incr { 1171 t.Errorf("Expected ret to be the same as incr") 1172 } 1173 return true 1174 } 1175 r = rand.New(rand.NewSource(time.Now().UnixNano())) 1176 if err := quick.Check(invFn, &quick.Config{Rand: r}); err != nil { 1177 t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) 1178 } 1179 1180 }