github.com/wzzhu/tensor@v0.9.24/api_arith_test.go (about) 1 package tensor 2 3 import ( 4 "log" 5 "math/rand" 6 "testing" 7 "testing/quick" 8 "time" 9 10 "github.com/stretchr/testify/assert" 11 ) 12 13 // This file contains the tests for API functions that aren't generated by genlib 14 15 func TestMod(t *testing.T) { 16 a := New(WithBacking([]float64{1, 2, 3, 4})) 17 b := New(WithBacking([]float64{1, 1, 1, 1})) 18 var correct interface{} = []float64{0, 0, 0, 0} 19 20 // vec-vec 21 res, err := Mod(a, b) 22 if err != nil { 23 t.Fatalf("Error: %v", err) 24 } 25 assert.Equal(t, correct, res.Data()) 26 27 // scalar 28 if res, err = Mod(a, 1.0); err != nil { 29 t.Fatalf("Error: %v", err) 30 } 31 assert.Equal(t, correct, res.Data()) 32 } 33 34 func TestFMA(t *testing.T) { 35 same := func(q *Dense) bool { 36 a := q.Clone().(*Dense) 37 x := q.Clone().(*Dense) 38 y := New(Of(q.Dtype()), WithShape(q.Shape().Clone()...)) 39 y.Memset(identityVal(100, q.Dtype())) 40 WithEngine(q.Engine())(y) 41 y2 := y.Clone().(*Dense) 42 43 we, willFailEq := willerr(a, numberTypes, nil) 44 _, ok1 := q.Engine().(FMAer) 45 _, ok2 := q.Engine().(Muler) 46 _, ok3 := q.Engine().(Adder) 47 we = we || (!ok1 && (!ok2 || !ok3)) 48 49 f, err := FMA(a, x, y) 50 if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly { 51 if err != nil { 52 log.Printf("q.Engine() %T", q.Engine()) 53 return false 54 } 55 return true 56 } 57 58 we, _ = willerr(a, numberTypes, nil) 59 _, ok := a.Engine().(Muler) 60 we = we || !ok 61 wi, err := Mul(a, x, WithIncr(y2)) 62 if err, retEarly := qcErrCheck(t, "FMA#2", a, x, we, err); retEarly { 63 if err != nil { 64 return false 65 } 66 return true 67 } 68 return qcEqCheck(t, q.Dtype(), willFailEq, wi, f) 69 } 70 r := rand.New(rand.NewSource(time.Now().UnixNano())) 71 if err := quick.Check(same, &quick.Config{Rand: r}); err != nil { 72 t.Error(err) 73 } 74 75 // specific engines 76 var eng Engine 77 78 // FLOAT64 ENGINE 79 80 // vec-vec 81 eng = Float64Engine{} 82 a := New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) 83 x := New(WithBacking(Range(Float64, 1, 101)), WithEngine(eng)) 84 y := New(Of(Float64), WithShape(100), WithEngine(eng)) 85 86 f, err := FMA(a, x, y) 87 if err != nil { 88 t.Fatal(err) 89 } 90 91 a2 := New(WithBacking(Range(Float64, 0, 100))) 92 x2 := New(WithBacking(Range(Float64, 1, 101))) 93 y2 := New(Of(Float64), WithShape(100)) 94 f2, err := Mul(a2, x2, WithIncr(y2)) 95 if err != nil { 96 t.Fatal(err) 97 } 98 99 assert.Equal(t, f.Data(), f2.Data()) 100 101 // vec-scalar 102 a = New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng)) 103 y = New(Of(Float64), WithShape(100)) 104 105 if f, err = FMA(a, 2.0, y); err != nil { 106 t.Fatal(err) 107 } 108 109 a2 = New(WithBacking(Range(Float64, 0, 100))) 110 y2 = New(Of(Float64), WithShape(100)) 111 if f2, err = Mul(a2, 2.0, WithIncr(y2)); err != nil { 112 t.Fatal(err) 113 } 114 115 assert.Equal(t, f.Data(), f2.Data()) 116 117 // FLOAT32 engine 118 eng = Float32Engine{} 119 a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) 120 x = New(WithBacking(Range(Float32, 1, 101)), WithEngine(eng)) 121 y = New(Of(Float32), WithShape(100), WithEngine(eng)) 122 123 f, err = FMA(a, x, y) 124 if err != nil { 125 t.Fatal(err) 126 } 127 128 a2 = New(WithBacking(Range(Float32, 0, 100))) 129 x2 = New(WithBacking(Range(Float32, 1, 101))) 130 y2 = New(Of(Float32), WithShape(100)) 131 f2, err = Mul(a2, x2, WithIncr(y2)) 132 if err != nil { 133 t.Fatal(err) 134 } 135 136 assert.Equal(t, f.Data(), f2.Data()) 137 138 // vec-scalar 139 a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng)) 140 y = New(Of(Float32), WithShape(100)) 141 142 if f, err = FMA(a, float32(2), y); err != nil { 143 t.Fatal(err) 144 } 145 146 a2 = New(WithBacking(Range(Float32, 0, 100))) 147 y2 = New(Of(Float32), WithShape(100)) 148 if f2, err = Mul(a2, float32(2), WithIncr(y2)); err != nil { 149 t.Fatal(err) 150 } 151 152 assert.Equal(t, f.Data(), f2.Data()) 153 154 } 155 156 func TestMulScalarScalar(t *testing.T) { 157 // scalar-scalar 158 a := New(WithBacking([]float64{2})) 159 b := New(WithBacking([]float64{3})) 160 var correct interface{} = 6.0 161 162 res, err := Mul(a, b) 163 if err != nil { 164 t.Fatalf("Error: %v", err) 165 } 166 assert.Equal(t, correct, res.Data()) 167 168 // Test commutativity 169 res, err = Mul(b, a) 170 if err != nil { 171 t.Fatalf("Error: %v", err) 172 } 173 assert.Equal(t, correct, res.Data()) 174 175 // scalar-tensor 176 a = New(WithBacking([]float64{3, 2})) 177 b = New(WithBacking([]float64{2})) 178 correct = []float64{6, 4} 179 180 res, err = Mul(a, b) 181 if err != nil { 182 t.Fatalf("Error: %v", err) 183 } 184 assert.Equal(t, correct, res.Data()) 185 186 // Test commutativity 187 res, err = Mul(b, a) 188 if err != nil { 189 t.Fatalf("Error: %v", err) 190 } 191 assert.Equal(t, correct, res.Data()) 192 193 // tensor - tensor 194 a = New(WithBacking([]float64{3, 5})) 195 b = New(WithBacking([]float64{7, 2})) 196 correct = []float64{21, 10} 197 198 res, err = Mul(a, b) 199 if err != nil { 200 t.Fatalf("Error: %v", err) 201 } 202 assert.Equal(t, correct, res.Data()) 203 204 // Test commutativity 205 res, err = Mul(b, a) 206 if err != nil { 207 t.Fatalf("Error: %v", err) 208 } 209 assert.Equal(t, correct, res.Data()) 210 211 // Interface - tensor 212 ai := 2.0 213 b = NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3})) 214 correct = []float64{6.0} 215 216 res, err = Mul(ai, b) 217 if err != nil { 218 t.Fatalf("Error: %v", err) 219 } 220 assert.Equal(t, correct, res.Data()) 221 222 // Commutativity 223 res, err = Mul(b, ai) 224 if err != nil { 225 t.Fatalf("Error: %v", err) 226 } 227 assert.Equal(t, correct, res.Data()) 228 } 229 230 func TestDivScalarScalar(t *testing.T) { 231 // scalar-scalar 232 a := New(WithBacking([]float64{6})) 233 b := New(WithBacking([]float64{2})) 234 var correct interface{} = 3.0 235 236 res, err := Div(a, b) 237 if err != nil { 238 t.Fatalf("Error: %v", err) 239 } 240 assert.Equal(t, correct, res.Data()) 241 242 // scalar-tensor 243 a = New(WithBacking([]float64{6, 4})) 244 b = New(WithBacking([]float64{2})) 245 correct = []float64{3, 2} 246 247 res, err = Div(a, b) 248 if err != nil { 249 t.Fatalf("Error: %v", err) 250 } 251 assert.Equal(t, correct, res.Data()) 252 253 // tensor-scalar 254 a = New(WithBacking([]float64{6})) 255 b = New(WithBacking([]float64{3, 2})) 256 correct = []float64{2, 3} 257 258 res, err = Div(a, b) 259 if err != nil { 260 t.Fatalf("Error: %v", err) 261 } 262 assert.Equal(t, correct, res.Data()) 263 264 // tensor - tensor 265 a = New(WithBacking([]float64{21, 10})) 266 b = New(WithBacking([]float64{7, 2})) 267 correct = []float64{3, 5} 268 269 res, err = Div(a, b) 270 if err != nil { 271 t.Fatalf("Error: %v", err) 272 } 273 assert.Equal(t, correct, res.Data()) 274 275 // interface-scalar 276 ai := 6.0 277 b = New(WithBacking([]float64{2})) 278 correct = 3.0 279 280 res, err = Div(ai, b) 281 if err != nil { 282 t.Fatalf("Error: %v", err) 283 } 284 assert.Equal(t, correct, res.Data()) 285 286 // scalar-interface 287 a = New(WithBacking([]float64{6})) 288 bi := 2.0 289 correct = 3.0 290 291 res, err = Div(a, bi) 292 if err != nil { 293 t.Fatalf("Error: %v", err) 294 } 295 assert.Equal(t, correct, res.Data()) 296 } 297 298 func TestAddScalarScalar(t *testing.T) { 299 // scalar-scalar 300 a := New(WithBacking([]float64{2})) 301 b := New(WithBacking([]float64{3})) 302 var correct interface{} = 5.0 303 304 res, err := Add(a, b) 305 if err != nil { 306 t.Fatalf("Error: %v", err) 307 } 308 assert.Equal(t, correct, res.Data()) 309 310 // Test commutativity 311 res, err = Add(b, a) 312 if err != nil { 313 t.Fatalf("Error: %v", err) 314 } 315 assert.Equal(t, correct, res.Data()) 316 317 // scalar-tensor 318 a = New(WithBacking([]float64{3, 2})) 319 b = New(WithBacking([]float64{2})) 320 correct = []float64{5, 4} 321 322 res, err = Add(a, b) 323 if err != nil { 324 t.Fatalf("Error: %v", err) 325 } 326 assert.Equal(t, correct, res.Data()) 327 328 // Test commutativity 329 res, err = Add(b, a) 330 if err != nil { 331 t.Fatalf("Error: %v", err) 332 } 333 assert.Equal(t, correct, res.Data()) 334 335 // tensor - tensor 336 a = New(WithBacking([]float64{3, 5})) 337 b = New(WithBacking([]float64{7, 2})) 338 correct = []float64{10, 7} 339 340 res, err = Add(a, b) 341 if err != nil { 342 t.Fatalf("Error: %v", err) 343 } 344 assert.Equal(t, correct, res.Data()) 345 346 // Test commutativity 347 res, err = Add(b, a) 348 if err != nil { 349 t.Fatalf("Error: %v", err) 350 } 351 assert.Equal(t, correct, res.Data()) 352 353 // interface-scalar 354 ai := 2.0 355 b = New(WithBacking([]float64{3})) 356 correct = 5.0 357 358 res, err = Add(ai, b) 359 if err != nil { 360 t.Fatalf("Error: %v", err) 361 } 362 assert.Equal(t, correct, res.Data()) 363 364 // Test commutativity 365 res, err = Add(b, ai) 366 if err != nil { 367 t.Fatalf("Error: %v", err) 368 } 369 assert.Equal(t, correct, res.Data()) 370 } 371 372 func TestSubScalarScalar(t *testing.T) { 373 // scalar-scalar 374 a := New(WithBacking([]float64{6})) 375 b := New(WithBacking([]float64{2})) 376 var correct interface{} = 4.0 377 378 res, err := Sub(a, b) 379 if err != nil { 380 t.Fatalf("Error: %v", err) 381 } 382 assert.Equal(t, correct, res.Data()) 383 384 // scalar-tensor 385 a = New(WithBacking([]float64{6, 4})) 386 b = New(WithBacking([]float64{2})) 387 correct = []float64{4, 2} 388 389 res, err = Sub(a, b) 390 if err != nil { 391 t.Fatalf("Error: %v", err) 392 } 393 assert.Equal(t, correct, res.Data()) 394 395 // tensor-scalar 396 a = New(WithBacking([]float64{6})) 397 b = New(WithBacking([]float64{3, 2})) 398 correct = []float64{3, 4} 399 400 res, err = Sub(a, b) 401 if err != nil { 402 t.Fatalf("Error: %v", err) 403 } 404 assert.Equal(t, correct, res.Data()) 405 406 // tensor - tensor 407 a = New(WithBacking([]float64{21, 10})) 408 b = New(WithBacking([]float64{7, 2})) 409 correct = []float64{14, 8} 410 411 res, err = Sub(a, b) 412 if err != nil { 413 t.Fatalf("Error: %v", err) 414 } 415 assert.Equal(t, correct, res.Data()) 416 417 // interface-scalar 418 ai := 6.0 419 b = New(WithBacking([]float64{2})) 420 correct = 4.0 421 422 res, err = Sub(ai, b) 423 if err != nil { 424 t.Fatalf("Error: %v", err) 425 } 426 assert.Equal(t, correct, res.Data()) 427 428 // scalar-interface 429 a = New(WithBacking([]float64{6})) 430 bi := 2.0 431 correct = 4.0 432 433 res, err = Sub(a, bi) 434 if err != nil { 435 t.Fatalf("Error: %v", err) 436 } 437 assert.Equal(t, correct, res.Data()) 438 } 439 440 func TestModScalarScalar(t *testing.T) { 441 // scalar-scalar 442 a := New(WithBacking([]float64{5})) 443 b := New(WithBacking([]float64{2})) 444 var correct interface{} = 1.0 445 446 res, err := Mod(a, b) 447 if err != nil { 448 t.Fatalf("Error: %v", err) 449 } 450 assert.Equal(t, correct, res.Data()) 451 452 // scalar-tensor 453 a = New(WithBacking([]float64{5, 4})) 454 b = New(WithBacking([]float64{2})) 455 correct = []float64{1, 0} 456 457 res, err = Mod(a, b) 458 if err != nil { 459 t.Fatalf("Error: %v", err) 460 } 461 assert.Equal(t, correct, res.Data()) 462 463 // tensor-scalar 464 a = New(WithBacking([]float64{5})) 465 b = New(WithBacking([]float64{3, 2})) 466 correct = []float64{2, 1} 467 468 res, err = Mod(a, b) 469 if err != nil { 470 t.Fatalf("Error: %v", err) 471 } 472 assert.Equal(t, correct, res.Data()) 473 474 // tensor - tensor 475 a = New(WithBacking([]float64{22, 10})) 476 b = New(WithBacking([]float64{7, 2})) 477 correct = []float64{1, 0} 478 479 res, err = Mod(a, b) 480 if err != nil { 481 t.Fatalf("Error: %v", err) 482 } 483 assert.Equal(t, correct, res.Data()) 484 485 // interface-scalar 486 ai := 5.0 487 b = New(WithBacking([]float64{2})) 488 correct = 1.0 489 490 res, err = Mod(ai, b) 491 if err != nil { 492 t.Fatalf("Error: %v", err) 493 } 494 assert.Equal(t, correct, res.Data()) 495 496 // scalar-interface 497 a = New(WithBacking([]float64{5})) 498 bi := 2.0 499 correct = 1.0 500 501 res, err = Mod(a, bi) 502 if err != nil { 503 t.Fatalf("Error: %v", err) 504 } 505 assert.Equal(t, correct, res.Data()) 506 } 507 508 func TestPowScalarScalar(t *testing.T) { 509 // scalar-scalar 510 a := New(WithBacking([]float64{6})) 511 b := New(WithBacking([]float64{2})) 512 var correct interface{} = 36.0 513 514 res, err := Pow(a, b) 515 if err != nil { 516 t.Fatalf("Error: %v", err) 517 } 518 assert.Equal(t, correct, res.Data()) 519 520 // scalar-tensor 521 a = New(WithBacking([]float64{6, 4})) 522 b = New(WithBacking([]float64{2})) 523 correct = []float64{36, 16} 524 525 res, err = Pow(a, b) 526 if err != nil { 527 t.Fatalf("Error: %v", err) 528 } 529 assert.Equal(t, correct, res.Data()) 530 531 // tensor-scalar 532 a = New(WithBacking([]float64{6})) 533 b = New(WithBacking([]float64{3, 2})) 534 correct = []float64{216, 36} 535 536 res, err = Pow(a, b) 537 if err != nil { 538 t.Fatalf("Error: %v", err) 539 } 540 assert.Equal(t, correct, res.Data()) 541 542 // tensor - tensor 543 a = New(WithBacking([]float64{3, 10})) 544 b = New(WithBacking([]float64{7, 2})) 545 correct = []float64{2187, 100} 546 547 res, err = Pow(a, b) 548 if err != nil { 549 t.Fatalf("Error: %v", err) 550 } 551 assert.Equal(t, correct, res.Data()) 552 553 // interface-scalar 554 ai := 6.0 555 b = New(WithBacking([]float64{2})) 556 correct = 36.0 557 558 res, err = Pow(ai, b) 559 if err != nil { 560 t.Fatalf("Error: %v", err) 561 } 562 assert.Equal(t, correct, res.Data()) 563 564 // scalar-interface 565 a = New(WithBacking([]float64{6})) 566 bi := 2.0 567 correct = 36.0 568 569 res, err = Pow(a, bi) 570 if err != nil { 571 t.Fatalf("Error: %v", err) 572 } 573 assert.Equal(t, correct, res.Data()) 574 }