gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/blas/testblas/common.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "math" 9 "math/cmplx" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/blas" 15 "gonum.org/v1/gonum/floats/scalar" 16 ) 17 18 // throwPanic will throw unexpected panics if true, or will just report them as errors if false 19 const throwPanic = true 20 21 var znan = cmplx.NaN() 22 23 func dTolEqual(a, b float64) bool { 24 if math.IsNaN(a) && math.IsNaN(b) { 25 return true 26 } 27 if a == b { 28 return true 29 } 30 m := math.Max(math.Abs(a), math.Abs(b)) 31 if m > 1 { 32 a /= m 33 b /= m 34 } 35 if math.Abs(a-b) < 1e-14 { 36 return true 37 } 38 return false 39 } 40 41 func dSliceTolEqual(a, b []float64) bool { 42 if len(a) != len(b) { 43 return false 44 } 45 for i := range a { 46 if !dTolEqual(a[i], b[i]) { 47 return false 48 } 49 } 50 return true 51 } 52 53 func dStridedSliceTolEqual(n int, a []float64, inca int, b []float64, incb int) bool { 54 ia := 0 55 ib := 0 56 if inca <= 0 { 57 ia = -(n - 1) * inca 58 } 59 if incb <= 0 { 60 ib = -(n - 1) * incb 61 } 62 for i := 0; i < n; i++ { 63 if !dTolEqual(a[ia], b[ib]) { 64 return false 65 } 66 ia += inca 67 ib += incb 68 } 69 return true 70 } 71 72 func dSliceEqual(a, b []float64) bool { 73 if len(a) != len(b) { 74 return false 75 } 76 for i := range a { 77 if !dTolEqual(a[i], b[i]) { 78 return false 79 } 80 } 81 return true 82 } 83 84 func dCopyTwoTmp(x, xTmp, y, yTmp []float64) { 85 if len(x) != len(xTmp) { 86 panic("x size mismatch") 87 } 88 if len(y) != len(yTmp) { 89 panic("y size mismatch") 90 } 91 copy(xTmp, x) 92 copy(yTmp, y) 93 } 94 95 // returns true if the function panics 96 func panics(f func()) (b bool) { 97 defer func() { 98 err := recover() 99 if err != nil { 100 b = true 101 } 102 }() 103 f() 104 return 105 } 106 107 func testpanics(f func(), name string, t *testing.T) { 108 b := panics(f) 109 if !b { 110 t.Errorf("%v should panic and does not", name) 111 } 112 } 113 114 func sliceOfSliceCopy(a [][]float64) [][]float64 { 115 n := make([][]float64, len(a)) 116 for i := range a { 117 n[i] = make([]float64, len(a[i])) 118 copy(n[i], a[i]) 119 } 120 return n 121 } 122 123 func sliceCopy(a []float64) []float64 { 124 n := make([]float64, len(a)) 125 copy(n, a) 126 return n 127 } 128 129 func flatten(a [][]float64) []float64 { 130 if len(a) == 0 { 131 return nil 132 } 133 m := len(a) 134 n := len(a[0]) 135 s := make([]float64, m*n) 136 for i := 0; i < m; i++ { 137 for j := 0; j < n; j++ { 138 s[i*n+j] = a[i][j] 139 } 140 } 141 return s 142 } 143 144 func unflatten(a []float64, m, n int) [][]float64 { 145 s := make([][]float64, m) 146 for i := 0; i < m; i++ { 147 s[i] = make([]float64, n) 148 for j := 0; j < n; j++ { 149 s[i][j] = a[i*n+j] 150 } 151 } 152 return s 153 } 154 155 // flattenTriangular turns the upper or lower triangle of a dense slice of slice 156 // into a single slice with packed storage. a must be a square matrix. 157 func flattenTriangular(a [][]float64, ul blas.Uplo) []float64 { 158 m := len(a) 159 aFlat := make([]float64, m*(m+1)/2) 160 var k int 161 if ul == blas.Upper { 162 for i := 0; i < m; i++ { 163 k += copy(aFlat[k:], a[i][i:]) 164 } 165 return aFlat 166 } 167 for i := 0; i < m; i++ { 168 k += copy(aFlat[k:], a[i][:i+1]) 169 } 170 return aFlat 171 } 172 173 // flattenBanded turns a dense banded slice of slice into the compact banded matrix format 174 func flattenBanded(a [][]float64, ku, kl int) []float64 { 175 m := len(a) 176 n := len(a[0]) 177 if ku < 0 || kl < 0 { 178 panic("testblas: negative band length") 179 } 180 nRows := m 181 nCols := (ku + kl + 1) 182 aflat := make([]float64, nRows*nCols) 183 for i := range aflat { 184 aflat[i] = math.NaN() 185 } 186 // loop over the rows, and then the bands 187 // elements in the ith row stay in the ith row 188 // order in bands is kept 189 for i := 0; i < nRows; i++ { 190 min := -kl 191 if i-kl < 0 { 192 min = -i 193 } 194 max := ku 195 if i+ku >= n { 196 max = n - i - 1 197 } 198 for j := min; j <= max; j++ { 199 col := kl + j 200 aflat[i*nCols+col] = a[i][i+j] 201 } 202 } 203 return aflat 204 } 205 206 // makeIncremented takes a float64 slice with inc == 1 and makes an incremented version 207 // and adds extra values on the end 208 func makeIncremented(x []float64, inc int, extra int) []float64 { 209 if inc == 0 { 210 panic("zero inc") 211 } 212 absinc := inc 213 if absinc < 0 { 214 absinc = -inc 215 } 216 xcopy := make([]float64, len(x)) 217 if inc > 0 { 218 copy(xcopy, x) 219 } else { 220 for i := 0; i < len(x); i++ { 221 xcopy[i] = x[len(x)-i-1] 222 } 223 } 224 225 // don't use NaN because it makes comparison hard 226 // Do use a weird unique value for easier debugging 227 counter := 100.0 228 var xnew []float64 229 for i, v := range xcopy { 230 xnew = append(xnew, v) 231 if i != len(x)-1 { 232 for j := 0; j < absinc-1; j++ { 233 xnew = append(xnew, counter) 234 counter++ 235 } 236 } 237 } 238 for i := 0; i < extra; i++ { 239 xnew = append(xnew, counter) 240 counter++ 241 } 242 return xnew 243 } 244 245 // makeIncremented32 takes a float32 slice with inc == 1 and makes an incremented version 246 // and adds extra values on the end 247 func makeIncremented32(x []float32, inc int, extra int) []float32 { 248 if inc == 0 { 249 panic("zero inc") 250 } 251 absinc := inc 252 if absinc < 0 { 253 absinc = -inc 254 } 255 xcopy := make([]float32, len(x)) 256 if inc > 0 { 257 copy(xcopy, x) 258 } else { 259 for i := 0; i < len(x); i++ { 260 xcopy[i] = x[len(x)-i-1] 261 } 262 } 263 264 // don't use NaN because it makes comparison hard 265 // Do use a weird unique value for easier debugging 266 var counter float32 = 100.0 267 var xnew []float32 268 for i, v := range xcopy { 269 xnew = append(xnew, v) 270 if i != len(x)-1 { 271 for j := 0; j < absinc-1; j++ { 272 xnew = append(xnew, counter) 273 counter++ 274 } 275 } 276 } 277 for i := 0; i < extra; i++ { 278 xnew = append(xnew, counter) 279 counter++ 280 } 281 return xnew 282 } 283 284 func abs(x int) int { 285 if x < 0 { 286 return -x 287 } 288 return x 289 } 290 291 func allPairs(x, y []int) [][2]int { 292 var p [][2]int 293 for _, v0 := range x { 294 for _, v1 := range y { 295 p = append(p, [2]int{v0, v1}) 296 } 297 } 298 return p 299 } 300 301 func sameFloat64(a, b float64) bool { 302 return a == b || math.IsNaN(a) && math.IsNaN(b) 303 } 304 305 func sameComplex128(x, y complex128) bool { 306 return sameFloat64(real(x), real(y)) && sameFloat64(imag(x), imag(y)) 307 } 308 309 func zsame(x, y []complex128) bool { 310 if len(x) != len(y) { 311 return false 312 } 313 for i, v := range x { 314 w := y[i] 315 if !sameComplex128(v, w) { 316 return false 317 } 318 } 319 return true 320 } 321 322 // zSameAtNonstrided returns whether elements at non-stride positions of vectors 323 // x and y are same. 324 func zSameAtNonstrided(x, y []complex128, inc int) bool { 325 if len(x) != len(y) { 326 return false 327 } 328 if inc < 0 { 329 inc = -inc 330 } 331 for i, v := range x { 332 if i%inc == 0 { 333 continue 334 } 335 w := y[i] 336 if !sameComplex128(v, w) { 337 return false 338 } 339 } 340 return true 341 } 342 343 // zEqualApproxAtStrided returns whether elements at stride positions of vectors 344 // x and y are approximately equal within tol. 345 func zEqualApproxAtStrided(x, y []complex128, inc int, tol float64) bool { 346 if len(x) != len(y) { 347 return false 348 } 349 if inc < 0 { 350 inc = -inc 351 } 352 for i := 0; i < len(x); i += inc { 353 v := x[i] 354 w := y[i] 355 if !(cmplx.Abs(v-w) <= tol) { 356 return false 357 } 358 } 359 return true 360 } 361 362 func makeZVector(data []complex128, inc int) []complex128 { 363 if inc == 0 { 364 panic("bad test") 365 } 366 if len(data) == 0 { 367 return nil 368 } 369 inc = abs(inc) 370 x := make([]complex128, (len(data)-1)*inc+1) 371 for i := range x { 372 x[i] = znan 373 } 374 for i, v := range data { 375 x[i*inc] = v 376 } 377 return x 378 } 379 380 func makeZGeneral(data []complex128, m, n int, ld int) []complex128 { 381 if m < 0 || n < 0 { 382 panic("bad test") 383 } 384 if data != nil && len(data) != m*n { 385 panic("bad test") 386 } 387 if ld < max(1, n) { 388 panic("bad test") 389 } 390 if m == 0 || n == 0 { 391 return nil 392 } 393 a := make([]complex128, (m-1)*ld+n) 394 for i := range a { 395 a[i] = znan 396 } 397 if data != nil { 398 for i := 0; i < m; i++ { 399 copy(a[i*ld:i*ld+n], data[i*n:i*n+n]) 400 } 401 } 402 return a 403 } 404 405 // zPack returns the uplo triangle of an n×n matrix A in packed format. 406 func zPack(uplo blas.Uplo, n int, a []complex128, lda int) []complex128 { 407 if n == 0 { 408 return nil 409 } 410 ap := make([]complex128, n*(n+1)/2) 411 var ii int 412 if uplo == blas.Upper { 413 for i := 0; i < n; i++ { 414 for j := i; j < n; j++ { 415 ap[ii] = a[i*lda+j] 416 ii++ 417 } 418 } 419 } else { 420 for i := 0; i < n; i++ { 421 for j := 0; j <= i; j++ { 422 ap[ii] = a[i*lda+j] 423 ii++ 424 } 425 } 426 } 427 return ap 428 } 429 430 // zUnpackAsHermitian returns an n×n general Hermitian matrix (with stride n) 431 // whose packed uplo triangle is stored on entry in ap. 432 func zUnpackAsHermitian(uplo blas.Uplo, n int, ap []complex128) []complex128 { 433 if n == 0 { 434 return nil 435 } 436 a := make([]complex128, n*n) 437 lda := n 438 var ii int 439 if uplo == blas.Upper { 440 for i := 0; i < n; i++ { 441 for j := i; j < n; j++ { 442 a[i*lda+j] = ap[ii] 443 if i != j { 444 a[j*lda+i] = cmplx.Conj(ap[ii]) 445 } 446 ii++ 447 } 448 } 449 } else { 450 for i := 0; i < n; i++ { 451 for j := 0; j <= i; j++ { 452 a[i*lda+j] = ap[ii] 453 if i != j { 454 a[j*lda+i] = cmplx.Conj(ap[ii]) 455 } 456 ii++ 457 } 458 } 459 } 460 return a 461 } 462 463 // zPackBand returns the (kL+1+kU) band of an m×n general matrix A in band 464 // matrix format with ldab stride. Out-of-range elements are filled with NaN. 465 func zPackBand(kL, kU, ldab int, m, n int, a []complex128, lda int) []complex128 { 466 if m == 0 || n == 0 { 467 return nil 468 } 469 nRow := min(m, n+kL) 470 ab := make([]complex128, (nRow-1)*ldab+kL+1+kU) 471 for i := range ab { 472 ab[i] = znan 473 } 474 for i := 0; i < m; i++ { 475 off := max(0, kL-i) 476 var k int 477 for j := max(0, i-kL); j < min(n, i+kU+1); j++ { 478 ab[i*ldab+off+k] = a[i*lda+j] 479 k++ 480 } 481 } 482 return ab 483 } 484 485 // zPackTriBand returns in band matrix format the (k+1) band in the uplo 486 // triangle of an n×n matrix A. Out-of-range elements are filled with NaN. 487 func zPackTriBand(k, ldab int, uplo blas.Uplo, n int, a []complex128, lda int) []complex128 { 488 if n == 0 { 489 return nil 490 } 491 ab := make([]complex128, (n-1)*ldab+k+1) 492 for i := range ab { 493 ab[i] = znan 494 } 495 if uplo == blas.Upper { 496 for i := 0; i < n; i++ { 497 var k int 498 for j := i; j < min(n, i+k+1); j++ { 499 ab[i*ldab+k] = a[i*lda+j] 500 k++ 501 } 502 } 503 } else { 504 for i := 0; i < n; i++ { 505 off := max(0, k-i) 506 var kk int 507 for j := max(0, i-k); j <= i; j++ { 508 ab[i*ldab+off+kk] = a[i*lda+j] 509 kk++ 510 } 511 } 512 } 513 return ab 514 } 515 516 // zEqualApprox returns whether the slices a and b are approximately equal. 517 func zEqualApprox(a, b []complex128, tol float64) bool { 518 if len(a) != len(b) { 519 panic("mismatched slice length") 520 } 521 for i, ai := range a { 522 if !scalar.EqualWithinAbs(cmplx.Abs(ai), cmplx.Abs(b[i]), tol) { 523 return false 524 } 525 } 526 return true 527 } 528 529 // rndComplex128 returns a complex128 with random components. 530 func rndComplex128(rnd *rand.Rand) complex128 { 531 return complex(rnd.NormFloat64(), rnd.NormFloat64()) 532 } 533 534 // zmm returns the result of one of the matrix-matrix operations 535 // 536 // alpha * op(A) * op(B) + beta * C 537 // 538 // where op(X) is one of 539 // 540 // op(X) = X or op(X) = Xᵀ or op(X) = Xᴴ, 541 // 542 // alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix, 543 // op(B) a k×n matrix and C an m×n matrix. 544 // 545 // The returned slice is newly allocated, has the same length as c and the 546 // matrix it represents has the stride ldc. Out-of-range elements are equal to 547 // those of C to ease comparison of results from BLAS Level 3 functions. 548 func zmm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) []complex128 { 549 r := make([]complex128, len(c)) 550 copy(r, c) 551 for i := 0; i < m; i++ { 552 for j := 0; j < n; j++ { 553 r[i*ldc+j] = 0 554 } 555 } 556 switch tA { 557 case blas.NoTrans: 558 switch tB { 559 case blas.NoTrans: 560 for i := 0; i < m; i++ { 561 for j := 0; j < n; j++ { 562 for l := 0; l < k; l++ { 563 r[i*ldc+j] += a[i*lda+l] * b[l*ldb+j] 564 } 565 } 566 } 567 case blas.Trans: 568 for i := 0; i < m; i++ { 569 for j := 0; j < n; j++ { 570 for l := 0; l < k; l++ { 571 r[i*ldc+j] += a[i*lda+l] * b[j*ldb+l] 572 } 573 } 574 } 575 case blas.ConjTrans: 576 for i := 0; i < m; i++ { 577 for j := 0; j < n; j++ { 578 for l := 0; l < k; l++ { 579 r[i*ldc+j] += a[i*lda+l] * cmplx.Conj(b[j*ldb+l]) 580 } 581 } 582 } 583 } 584 case blas.Trans: 585 switch tB { 586 case blas.NoTrans: 587 for i := 0; i < m; i++ { 588 for j := 0; j < n; j++ { 589 for l := 0; l < k; l++ { 590 r[i*ldc+j] += a[l*lda+i] * b[l*ldb+j] 591 } 592 } 593 } 594 case blas.Trans: 595 for i := 0; i < m; i++ { 596 for j := 0; j < n; j++ { 597 for l := 0; l < k; l++ { 598 r[i*ldc+j] += a[l*lda+i] * b[j*ldb+l] 599 } 600 } 601 } 602 case blas.ConjTrans: 603 for i := 0; i < m; i++ { 604 for j := 0; j < n; j++ { 605 for l := 0; l < k; l++ { 606 r[i*ldc+j] += a[l*lda+i] * cmplx.Conj(b[j*ldb+l]) 607 } 608 } 609 } 610 } 611 case blas.ConjTrans: 612 switch tB { 613 case blas.NoTrans: 614 for i := 0; i < m; i++ { 615 for j := 0; j < n; j++ { 616 for l := 0; l < k; l++ { 617 r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j] 618 } 619 } 620 } 621 case blas.Trans: 622 for i := 0; i < m; i++ { 623 for j := 0; j < n; j++ { 624 for l := 0; l < k; l++ { 625 r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l] 626 } 627 } 628 } 629 case blas.ConjTrans: 630 for i := 0; i < m; i++ { 631 for j := 0; j < n; j++ { 632 for l := 0; l < k; l++ { 633 r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l]) 634 } 635 } 636 } 637 } 638 } 639 for i := 0; i < m; i++ { 640 for j := 0; j < n; j++ { 641 r[i*ldc+j] = alpha * r[i*ldc+j] 642 if beta != 0 { 643 r[i*ldc+j] += beta * c[i*ldc+j] 644 } 645 } 646 } 647 return r 648 } 649 650 // transString returns a string representation of blas.Transpose. 651 func transString(t blas.Transpose) string { 652 switch t { 653 case blas.NoTrans: 654 return "NoTrans" 655 case blas.Trans: 656 return "Trans" 657 case blas.ConjTrans: 658 return "ConjTrans" 659 } 660 return "unknown trans" 661 } 662 663 // uploString returns a string representation of blas.Uplo. 664 func uploString(uplo blas.Uplo) string { 665 switch uplo { 666 case blas.Lower: 667 return "Lower" 668 case blas.Upper: 669 return "Upper" 670 } 671 return "unknown uplo" 672 } 673 674 // sideString returns a string representation of blas.Side. 675 func sideString(side blas.Side) string { 676 switch side { 677 case blas.Left: 678 return "Left" 679 case blas.Right: 680 return "Right" 681 } 682 return "unknown side" 683 } 684 685 // diagString returns a string representation of blas.Diag. 686 func diagString(diag blas.Diag) string { 687 switch diag { 688 case blas.Unit: 689 return "Unit" 690 case blas.NonUnit: 691 return "NonUnit" 692 } 693 return "unknown diag" 694 } 695 696 // zSameLowerTri returns whether n×n matrices A and B are same under the diagonal. 697 func zSameLowerTri(n int, a []complex128, lda int, b []complex128, ldb int) bool { 698 for i := 1; i < n; i++ { 699 for j := 0; j < i; j++ { 700 aij := a[i*lda+j] 701 bij := b[i*ldb+j] 702 if !sameComplex128(aij, bij) { 703 return false 704 } 705 } 706 } 707 return true 708 } 709 710 // zSameUpperTri returns whether n×n matrices A and B are same above the diagonal. 711 func zSameUpperTri(n int, a []complex128, lda int, b []complex128, ldb int) bool { 712 for i := 0; i < n-1; i++ { 713 for j := i + 1; j < n; j++ { 714 aij := a[i*lda+j] 715 bij := b[i*ldb+j] 716 if !sameComplex128(aij, bij) { 717 return false 718 } 719 } 720 } 721 return true 722 }