gonum.org/v1/gonum@v0.14.0/blas/testblas/zhbmv.go (about) 1 // Copyright ©2018 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 "gonum.org/v1/gonum/blas" 14 ) 15 16 type Zhbmver interface { 17 Zhbmv(uplo blas.Uplo, n, k int, alpha complex128, ab []complex128, ldab int, x []complex128, incX int, beta complex128, y []complex128, incY int) 18 19 Zhemver 20 } 21 22 func ZhbmvTest(t *testing.T, impl Zhbmver) { 23 rnd := rand.New(rand.NewSource(1)) 24 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 25 for _, n := range []int{0, 1, 2, 3, 5} { 26 for k := 0; k < n; k++ { 27 for _, ldab := range []int{k + 1, k + 1 + 10} { 28 // Generate all possible combinations of given increments. 29 // Use slices to reduce indentation. 30 for _, inc := range allPairs([]int{-11, 1, 7}, []int{-3, 1, 5}) { 31 incX := inc[0] 32 incY := inc[1] 33 for _, ab := range []struct { 34 alpha complex128 35 beta complex128 36 }{ 37 // All potentially relevant values of 38 // alpha and beta. 39 {0, 0}, 40 {0, 1}, 41 {0, complex(rnd.NormFloat64(), rnd.NormFloat64())}, 42 {complex(rnd.NormFloat64(), rnd.NormFloat64()), 0}, 43 {complex(rnd.NormFloat64(), rnd.NormFloat64()), 1}, 44 {complex(rnd.NormFloat64(), rnd.NormFloat64()), complex(rnd.NormFloat64(), rnd.NormFloat64())}, 45 } { 46 testZhbmv(t, impl, rnd, uplo, n, k, ab.alpha, ab.beta, ldab, incX, incY) 47 } 48 } 49 } 50 } 51 } 52 } 53 } 54 55 // testZhbmv tests Zhbmv by comparing its output to that of Zhemv. 56 func testZhbmv(t *testing.T, impl Zhbmver, rnd *rand.Rand, uplo blas.Uplo, n, k int, alpha, beta complex128, ldab, incX, incY int) { 57 const tol = 1e-13 58 59 // Allocate a dense-storage Hermitian band matrix filled with NaNs that will be 60 // used as the reference matrix for Zhemv. 61 lda := max(1, n) 62 a := makeZGeneral(nil, n, n, lda) 63 // Fill the matrix with zeros. 64 for i := 0; i < n; i++ { 65 for j := 0; j < n; j++ { 66 a[i*lda+j] = 0 67 } 68 } 69 // Fill the triangle band with random data, invalidating the imaginary 70 // part of diagonal elements because it should not be referenced by 71 // Zhbmv and Zhemv. 72 if uplo == blas.Upper { 73 for i := 0; i < n; i++ { 74 a[i*lda+i] = complex(rnd.NormFloat64(), math.NaN()) 75 for j := i + 1; j < min(n, i+k+1); j++ { 76 re := rnd.NormFloat64() 77 im := rnd.NormFloat64() 78 a[i*lda+j] = complex(re, im) 79 } 80 } 81 } else { 82 for i := 0; i < n; i++ { 83 for j := max(0, i-k); j < i; j++ { 84 re := rnd.NormFloat64() 85 im := rnd.NormFloat64() 86 a[i*lda+j] = complex(re, im) 87 } 88 a[i*lda+i] = complex(rnd.NormFloat64(), math.NaN()) 89 } 90 } 91 // Create the actual Hermitian band matrix. 92 ab := zPackTriBand(k, ldab, uplo, n, a, lda) 93 abCopy := make([]complex128, len(ab)) 94 copy(abCopy, ab) 95 96 // Generate a random complex vector x. 97 xtest := make([]complex128, n) 98 for i := range xtest { 99 re := rnd.NormFloat64() 100 im := rnd.NormFloat64() 101 xtest[i] = complex(re, im) 102 } 103 x := makeZVector(xtest, incX) 104 xCopy := make([]complex128, len(x)) 105 copy(xCopy, x) 106 107 // Generate a random complex vector y. 108 ytest := make([]complex128, n) 109 for i := range ytest { 110 re := rnd.NormFloat64() 111 im := rnd.NormFloat64() 112 ytest[i] = complex(re, im) 113 } 114 y := makeZVector(ytest, incY) 115 116 want := make([]complex128, len(y)) 117 copy(want, y) 118 119 // Compute the reference result of alpha*op(A)*x + beta*y, storing it 120 // into want. 121 impl.Zhemv(uplo, n, alpha, a, lda, x, incX, beta, want, incY) 122 // Compute alpha*op(A)*x + beta*y, storing the result in-place into y. 123 impl.Zhbmv(uplo, n, k, alpha, ab, ldab, x, incX, beta, y, incY) 124 125 prefix := fmt.Sprintf("uplo=%v,n=%v,k=%v,incX=%v,incY=%v,ldab=%v", uplo, n, k, incX, incY, ldab) 126 if !zsame(x, xCopy) { 127 t.Errorf("%v: unexpected modification of x", prefix) 128 } 129 if !zsame(ab, abCopy) { 130 t.Errorf("%v: unexpected modification of ab", prefix) 131 } 132 if !zSameAtNonstrided(y, want, incY) { 133 t.Errorf("%v: unexpected modification of y", prefix) 134 } 135 if !zEqualApproxAtStrided(y, want, incY, tol) { 136 t.Errorf("%v: unexpected result\nwant %v\ngot %v", prefix, want, y) 137 } 138 }