gonum.org/v1/gonum@v0.14.0/blas/testblas/common.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "math" 9 "math/cmplx" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/blas" 15 "gonum.org/v1/gonum/floats/scalar" 16 ) 17 18 // throwPanic will throw unexpected panics if true, or will just report them as errors if false 19 const throwPanic = true 20 21 var znan = cmplx.NaN() 22 23 func dTolEqual(a, b float64) bool { 24 if math.IsNaN(a) && math.IsNaN(b) { 25 return true 26 } 27 if a == b { 28 return true 29 } 30 m := math.Max(math.Abs(a), math.Abs(b)) 31 if m > 1 { 32 a /= m 33 b /= m 34 } 35 if math.Abs(a-b) < 1e-14 { 36 return true 37 } 38 return false 39 } 40 41 func dSliceTolEqual(a, b []float64) bool { 42 if len(a) != len(b) { 43 return false 44 } 45 for i := range a { 46 if !dTolEqual(a[i], b[i]) { 47 return false 48 } 49 } 50 return true 51 } 52 53 func dStridedSliceTolEqual(n int, a []float64, inca int, b []float64, incb int) bool { 54 ia := 0 55 ib := 0 56 if inca <= 0 { 57 ia = -(n - 1) * inca 58 } 59 if incb <= 0 { 60 ib = -(n - 1) * incb 61 } 62 for i := 0; i < n; i++ { 63 if !dTolEqual(a[ia], b[ib]) { 64 return false 65 } 66 ia += inca 67 ib += incb 68 } 69 return true 70 } 71 72 func dSliceEqual(a, b []float64) bool { 73 if len(a) != len(b) { 74 return false 75 } 76 for i := range a { 77 if !dTolEqual(a[i], b[i]) { 78 return false 79 } 80 } 81 return true 82 } 83 84 func dCopyTwoTmp(x, xTmp, y, yTmp []float64) { 85 if len(x) != len(xTmp) { 86 panic("x size mismatch") 87 } 88 if len(y) != len(yTmp) { 89 panic("y size mismatch") 90 } 91 copy(xTmp, x) 92 copy(yTmp, y) 93 } 94 95 // returns true if the function panics 96 func panics(f func()) (b bool) { 97 defer func() { 98 err := recover() 99 if err != nil { 100 b = true 101 } 102 }() 103 f() 104 return 105 } 106 107 func testpanics(f func(), name string, t *testing.T) { 108 b := panics(f) 109 if !b { 110 t.Errorf("%v should panic and does not", name) 111 } 112 } 113 114 func sliceOfSliceCopy(a [][]float64) [][]float64 { 115 n := make([][]float64, len(a)) 116 for i := range a { 117 n[i] = make([]float64, len(a[i])) 118 copy(n[i], a[i]) 119 } 120 return n 121 } 122 123 func sliceCopy(a []float64) []float64 { 124 n := make([]float64, len(a)) 125 copy(n, a) 126 return n 127 } 128 129 func flatten(a [][]float64) []float64 { 130 if len(a) == 0 { 131 return nil 132 } 133 m := len(a) 134 n := len(a[0]) 135 s := make([]float64, m*n) 136 for i := 0; i < m; i++ { 137 for j := 0; j < n; j++ { 138 s[i*n+j] = a[i][j] 139 } 140 } 141 return s 142 } 143 144 func unflatten(a []float64, m, n int) [][]float64 { 145 s := make([][]float64, m) 146 for i := 0; i < m; i++ { 147 s[i] = make([]float64, n) 148 for j := 0; j < n; j++ { 149 s[i][j] = a[i*n+j] 150 } 151 } 152 return s 153 } 154 155 // flattenTriangular turns the upper or lower triangle of a dense slice of slice 156 // into a single slice with packed storage. a must be a square matrix. 157 func flattenTriangular(a [][]float64, ul blas.Uplo) []float64 { 158 m := len(a) 159 aFlat := make([]float64, m*(m+1)/2) 160 var k int 161 if ul == blas.Upper { 162 for i := 0; i < m; i++ { 163 k += copy(aFlat[k:], a[i][i:]) 164 } 165 return aFlat 166 } 167 for i := 0; i < m; i++ { 168 k += copy(aFlat[k:], a[i][:i+1]) 169 } 170 return aFlat 171 } 172 173 // flattenBanded turns a dense banded slice of slice into the compact banded matrix format 174 func flattenBanded(a [][]float64, ku, kl int) []float64 { 175 m := len(a) 176 n := len(a[0]) 177 if ku < 0 || kl < 0 { 178 panic("testblas: negative band length") 179 } 180 nRows := m 181 nCols := (ku + kl + 1) 182 aflat := make([]float64, nRows*nCols) 183 for i := range aflat { 184 aflat[i] = math.NaN() 185 } 186 // loop over the rows, and then the bands 187 // elements in the ith row stay in the ith row 188 // order in bands is kept 189 for i := 0; i < nRows; i++ { 190 min := -kl 191 if i-kl < 0 { 192 min = -i 193 } 194 max := ku 195 if i+ku >= n { 196 max = n - i - 1 197 } 198 for j := min; j <= max; j++ { 199 col := kl + j 200 aflat[i*nCols+col] = a[i][i+j] 201 } 202 } 203 return aflat 204 } 205 206 // makeIncremented takes a float64 slice with inc == 1 and makes an incremented version 207 // and adds extra values on the end 208 func makeIncremented(x []float64, inc int, extra int) []float64 { 209 if inc == 0 { 210 panic("zero inc") 211 } 212 absinc := inc 213 if absinc < 0 { 214 absinc = -inc 215 } 216 xcopy := make([]float64, len(x)) 217 if inc > 0 { 218 copy(xcopy, x) 219 } else { 220 for i := 0; i < len(x); i++ { 221 xcopy[i] = x[len(x)-i-1] 222 } 223 } 224 225 // don't use NaN because it makes comparison hard 226 // Do use a weird unique value for easier debugging 227 counter := 100.0 228 var xnew []float64 229 for i, v := range xcopy { 230 xnew = append(xnew, v) 231 if i != len(x)-1 { 232 for j := 0; j < absinc-1; j++ { 233 xnew = append(xnew, counter) 234 counter++ 235 } 236 } 237 } 238 for i := 0; i < extra; i++ { 239 xnew = append(xnew, counter) 240 counter++ 241 } 242 return xnew 243 } 244 245 // makeIncremented32 takes a float32 slice with inc == 1 and makes an incremented version 246 // and adds extra values on the end 247 func makeIncremented32(x []float32, inc int, extra int) []float32 { 248 if inc == 0 { 249 panic("zero inc") 250 } 251 absinc := inc 252 if absinc < 0 { 253 absinc = -inc 254 } 255 xcopy := make([]float32, len(x)) 256 if inc > 0 { 257 copy(xcopy, x) 258 } else { 259 for i := 0; i < len(x); i++ { 260 xcopy[i] = x[len(x)-i-1] 261 } 262 } 263 264 // don't use NaN because it makes comparison hard 265 // Do use a weird unique value for easier debugging 266 var counter float32 = 100.0 267 var xnew []float32 268 for i, v := range xcopy { 269 xnew = append(xnew, v) 270 if i != len(x)-1 { 271 for j := 0; j < absinc-1; j++ { 272 xnew = append(xnew, counter) 273 counter++ 274 } 275 } 276 } 277 for i := 0; i < extra; i++ { 278 xnew = append(xnew, counter) 279 counter++ 280 } 281 return xnew 282 } 283 284 func abs(x int) int { 285 if x < 0 { 286 return -x 287 } 288 return x 289 } 290 291 func allPairs(x, y []int) [][2]int { 292 var p [][2]int 293 for _, v0 := range x { 294 for _, v1 := range y { 295 p = append(p, [2]int{v0, v1}) 296 } 297 } 298 return p 299 } 300 301 func sameFloat64(a, b float64) bool { 302 return a == b || math.IsNaN(a) && math.IsNaN(b) 303 } 304 305 func sameComplex128(x, y complex128) bool { 306 return sameFloat64(real(x), real(y)) && sameFloat64(imag(x), imag(y)) 307 } 308 309 func zsame(x, y []complex128) bool { 310 if len(x) != len(y) { 311 return false 312 } 313 for i, v := range x { 314 w := y[i] 315 if !sameComplex128(v, w) { 316 return false 317 } 318 } 319 return true 320 } 321 322 // zSameAtNonstrided returns whether elements at non-stride positions of vectors 323 // x and y are same. 324 func zSameAtNonstrided(x, y []complex128, inc int) bool { 325 if len(x) != len(y) { 326 return false 327 } 328 if inc < 0 { 329 inc = -inc 330 } 331 for i, v := range x { 332 if i%inc == 0 { 333 continue 334 } 335 w := y[i] 336 if !sameComplex128(v, w) { 337 return false 338 } 339 } 340 return true 341 } 342 343 // zEqualApproxAtStrided returns whether elements at stride positions of vectors 344 // x and y are approximately equal within tol. 345 func zEqualApproxAtStrided(x, y []complex128, inc int, tol float64) bool { 346 if len(x) != len(y) { 347 return false 348 } 349 if inc < 0 { 350 inc = -inc 351 } 352 for i := 0; i < len(x); i += inc { 353 v := x[i] 354 w := y[i] 355 if !(cmplx.Abs(v-w) <= tol) { 356 return false 357 } 358 } 359 return true 360 } 361 362 func makeZVector(data []complex128, inc int) []complex128 { 363 if inc == 0 { 364 panic("bad test") 365 } 366 if len(data) == 0 { 367 return nil 368 } 369 inc = abs(inc) 370 x := make([]complex128, (len(data)-1)*inc+1) 371 for i := range x { 372 x[i] = znan 373 } 374 for i, v := range data { 375 x[i*inc] = v 376 } 377 return x 378 } 379 380 func makeZGeneral(data []complex128, m, n int, ld int) []complex128 { 381 if m < 0 || n < 0 { 382 panic("bad test") 383 } 384 if data != nil && len(data) != m*n { 385 panic("bad test") 386 } 387 if ld < max(1, n) { 388 panic("bad test") 389 } 390 if m == 0 || n == 0 { 391 return nil 392 } 393 a := make([]complex128, (m-1)*ld+n) 394 for i := range a { 395 a[i] = znan 396 } 397 if data != nil { 398 for i := 0; i < m; i++ { 399 copy(a[i*ld:i*ld+n], data[i*n:i*n+n]) 400 } 401 } 402 return a 403 } 404 405 func max(a, b int) int { 406 if a < b { 407 return b 408 } 409 return a 410 } 411 412 func min(a, b int) int { 413 if a < b { 414 return a 415 } 416 return b 417 } 418 419 // zPack returns the uplo triangle of an n×n matrix A in packed format. 420 func zPack(uplo blas.Uplo, n int, a []complex128, lda int) []complex128 { 421 if n == 0 { 422 return nil 423 } 424 ap := make([]complex128, n*(n+1)/2) 425 var ii int 426 if uplo == blas.Upper { 427 for i := 0; i < n; i++ { 428 for j := i; j < n; j++ { 429 ap[ii] = a[i*lda+j] 430 ii++ 431 } 432 } 433 } else { 434 for i := 0; i < n; i++ { 435 for j := 0; j <= i; j++ { 436 ap[ii] = a[i*lda+j] 437 ii++ 438 } 439 } 440 } 441 return ap 442 } 443 444 // zUnpackAsHermitian returns an n×n general Hermitian matrix (with stride n) 445 // whose packed uplo triangle is stored on entry in ap. 446 func zUnpackAsHermitian(uplo blas.Uplo, n int, ap []complex128) []complex128 { 447 if n == 0 { 448 return nil 449 } 450 a := make([]complex128, n*n) 451 lda := n 452 var ii int 453 if uplo == blas.Upper { 454 for i := 0; i < n; i++ { 455 for j := i; j < n; j++ { 456 a[i*lda+j] = ap[ii] 457 if i != j { 458 a[j*lda+i] = cmplx.Conj(ap[ii]) 459 } 460 ii++ 461 } 462 } 463 } else { 464 for i := 0; i < n; i++ { 465 for j := 0; j <= i; j++ { 466 a[i*lda+j] = ap[ii] 467 if i != j { 468 a[j*lda+i] = cmplx.Conj(ap[ii]) 469 } 470 ii++ 471 } 472 } 473 } 474 return a 475 } 476 477 // zPackBand returns the (kL+1+kU) band of an m×n general matrix A in band 478 // matrix format with ldab stride. Out-of-range elements are filled with NaN. 479 func zPackBand(kL, kU, ldab int, m, n int, a []complex128, lda int) []complex128 { 480 if m == 0 || n == 0 { 481 return nil 482 } 483 nRow := min(m, n+kL) 484 ab := make([]complex128, (nRow-1)*ldab+kL+1+kU) 485 for i := range ab { 486 ab[i] = znan 487 } 488 for i := 0; i < m; i++ { 489 off := max(0, kL-i) 490 var k int 491 for j := max(0, i-kL); j < min(n, i+kU+1); j++ { 492 ab[i*ldab+off+k] = a[i*lda+j] 493 k++ 494 } 495 } 496 return ab 497 } 498 499 // zPackTriBand returns in band matrix format the (k+1) band in the uplo 500 // triangle of an n×n matrix A. Out-of-range elements are filled with NaN. 501 func zPackTriBand(k, ldab int, uplo blas.Uplo, n int, a []complex128, lda int) []complex128 { 502 if n == 0 { 503 return nil 504 } 505 ab := make([]complex128, (n-1)*ldab+k+1) 506 for i := range ab { 507 ab[i] = znan 508 } 509 if uplo == blas.Upper { 510 for i := 0; i < n; i++ { 511 var k int 512 for j := i; j < min(n, i+k+1); j++ { 513 ab[i*ldab+k] = a[i*lda+j] 514 k++ 515 } 516 } 517 } else { 518 for i := 0; i < n; i++ { 519 off := max(0, k-i) 520 var kk int 521 for j := max(0, i-k); j <= i; j++ { 522 ab[i*ldab+off+kk] = a[i*lda+j] 523 kk++ 524 } 525 } 526 } 527 return ab 528 } 529 530 // zEqualApprox returns whether the slices a and b are approximately equal. 531 func zEqualApprox(a, b []complex128, tol float64) bool { 532 if len(a) != len(b) { 533 panic("mismatched slice length") 534 } 535 for i, ai := range a { 536 if !scalar.EqualWithinAbs(cmplx.Abs(ai), cmplx.Abs(b[i]), tol) { 537 return false 538 } 539 } 540 return true 541 } 542 543 // rndComplex128 returns a complex128 with random components. 544 func rndComplex128(rnd *rand.Rand) complex128 { 545 return complex(rnd.NormFloat64(), rnd.NormFloat64()) 546 } 547 548 // zmm returns the result of one of the matrix-matrix operations 549 // 550 // alpha * op(A) * op(B) + beta * C 551 // 552 // where op(X) is one of 553 // 554 // op(X) = X or op(X) = Xᵀ or op(X) = Xᴴ, 555 // 556 // alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix, 557 // op(B) a k×n matrix and C an m×n matrix. 558 // 559 // The returned slice is newly allocated, has the same length as c and the 560 // matrix it represents has the stride ldc. Out-of-range elements are equal to 561 // those of C to ease comparison of results from BLAS Level 3 functions. 562 func zmm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) []complex128 { 563 r := make([]complex128, len(c)) 564 copy(r, c) 565 for i := 0; i < m; i++ { 566 for j := 0; j < n; j++ { 567 r[i*ldc+j] = 0 568 } 569 } 570 switch tA { 571 case blas.NoTrans: 572 switch tB { 573 case blas.NoTrans: 574 for i := 0; i < m; i++ { 575 for j := 0; j < n; j++ { 576 for l := 0; l < k; l++ { 577 r[i*ldc+j] += a[i*lda+l] * b[l*ldb+j] 578 } 579 } 580 } 581 case blas.Trans: 582 for i := 0; i < m; i++ { 583 for j := 0; j < n; j++ { 584 for l := 0; l < k; l++ { 585 r[i*ldc+j] += a[i*lda+l] * b[j*ldb+l] 586 } 587 } 588 } 589 case blas.ConjTrans: 590 for i := 0; i < m; i++ { 591 for j := 0; j < n; j++ { 592 for l := 0; l < k; l++ { 593 r[i*ldc+j] += a[i*lda+l] * cmplx.Conj(b[j*ldb+l]) 594 } 595 } 596 } 597 } 598 case blas.Trans: 599 switch tB { 600 case blas.NoTrans: 601 for i := 0; i < m; i++ { 602 for j := 0; j < n; j++ { 603 for l := 0; l < k; l++ { 604 r[i*ldc+j] += a[l*lda+i] * b[l*ldb+j] 605 } 606 } 607 } 608 case blas.Trans: 609 for i := 0; i < m; i++ { 610 for j := 0; j < n; j++ { 611 for l := 0; l < k; l++ { 612 r[i*ldc+j] += a[l*lda+i] * b[j*ldb+l] 613 } 614 } 615 } 616 case blas.ConjTrans: 617 for i := 0; i < m; i++ { 618 for j := 0; j < n; j++ { 619 for l := 0; l < k; l++ { 620 r[i*ldc+j] += a[l*lda+i] * cmplx.Conj(b[j*ldb+l]) 621 } 622 } 623 } 624 } 625 case blas.ConjTrans: 626 switch tB { 627 case blas.NoTrans: 628 for i := 0; i < m; i++ { 629 for j := 0; j < n; j++ { 630 for l := 0; l < k; l++ { 631 r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j] 632 } 633 } 634 } 635 case blas.Trans: 636 for i := 0; i < m; i++ { 637 for j := 0; j < n; j++ { 638 for l := 0; l < k; l++ { 639 r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l] 640 } 641 } 642 } 643 case blas.ConjTrans: 644 for i := 0; i < m; i++ { 645 for j := 0; j < n; j++ { 646 for l := 0; l < k; l++ { 647 r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l]) 648 } 649 } 650 } 651 } 652 } 653 for i := 0; i < m; i++ { 654 for j := 0; j < n; j++ { 655 r[i*ldc+j] = alpha * r[i*ldc+j] 656 if beta != 0 { 657 r[i*ldc+j] += beta * c[i*ldc+j] 658 } 659 } 660 } 661 return r 662 } 663 664 // transString returns a string representation of blas.Transpose. 665 func transString(t blas.Transpose) string { 666 switch t { 667 case blas.NoTrans: 668 return "NoTrans" 669 case blas.Trans: 670 return "Trans" 671 case blas.ConjTrans: 672 return "ConjTrans" 673 } 674 return "unknown trans" 675 } 676 677 // uploString returns a string representation of blas.Uplo. 678 func uploString(uplo blas.Uplo) string { 679 switch uplo { 680 case blas.Lower: 681 return "Lower" 682 case blas.Upper: 683 return "Upper" 684 } 685 return "unknown uplo" 686 } 687 688 // sideString returns a string representation of blas.Side. 689 func sideString(side blas.Side) string { 690 switch side { 691 case blas.Left: 692 return "Left" 693 case blas.Right: 694 return "Right" 695 } 696 return "unknown side" 697 } 698 699 // diagString returns a string representation of blas.Diag. 700 func diagString(diag blas.Diag) string { 701 switch diag { 702 case blas.Unit: 703 return "Unit" 704 case blas.NonUnit: 705 return "NonUnit" 706 } 707 return "unknown diag" 708 } 709 710 // zSameLowerTri returns whether n×n matrices A and B are same under the diagonal. 711 func zSameLowerTri(n int, a []complex128, lda int, b []complex128, ldb int) bool { 712 for i := 1; i < n; i++ { 713 for j := 0; j < i; j++ { 714 aij := a[i*lda+j] 715 bij := b[i*ldb+j] 716 if !sameComplex128(aij, bij) { 717 return false 718 } 719 } 720 } 721 return true 722 } 723 724 // zSameUpperTri returns whether n×n matrices A and B are same above the diagonal. 725 func zSameUpperTri(n int, a []complex128, lda int, b []complex128, ldb int) bool { 726 for i := 0; i < n-1; i++ { 727 for j := i + 1; j < n; j++ { 728 aij := a[i*lda+j] 729 bij := b[i*ldb+j] 730 if !sameComplex128(aij, bij) { 731 return false 732 } 733 } 734 } 735 return true 736 }