gonum.org/v1/gonum@v0.14.0/blas/testblas/zher2k.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "fmt" 9 "math/cmplx" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 "gonum.org/v1/gonum/blas" 14 ) 15 16 type Zher2ker interface { 17 Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) 18 } 19 20 func Zher2kTest(t *testing.T, impl Zher2ker) { 21 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 22 for _, trans := range []blas.Transpose{blas.NoTrans, blas.ConjTrans} { 23 name := uploString(uplo) + "-" + transString(trans) 24 t.Run(name, func(t *testing.T) { 25 for _, n := range []int{0, 1, 2, 3, 4, 5} { 26 for _, k := range []int{0, 1, 2, 3, 4, 5, 7} { 27 zher2kTest(t, impl, uplo, trans, n, k) 28 } 29 } 30 }) 31 } 32 } 33 } 34 35 func zher2kTest(t *testing.T, impl Zher2ker, uplo blas.Uplo, trans blas.Transpose, n, k int) { 36 const tol = 1e-13 37 38 rnd := rand.New(rand.NewSource(1)) 39 40 row, col := n, k 41 if trans == blas.ConjTrans { 42 row, col = k, n 43 } 44 for _, lda := range []int{max(1, col), col + 2} { 45 for _, ldb := range []int{max(1, col), col + 3} { 46 for _, ldc := range []int{max(1, n), n + 4} { 47 for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} { 48 for _, beta := range []float64{0, 1, 1.3} { 49 // Allocate the matrix A and fill it with random numbers. 50 a := make([]complex128, row*lda) 51 for i := range a { 52 a[i] = rndComplex128(rnd) 53 } 54 // Create a copy of A for checking that 55 // Zher2k does not modify A. 56 aCopy := make([]complex128, len(a)) 57 copy(aCopy, a) 58 59 // Allocate the matrix B and fill it with random numbers. 60 b := make([]complex128, row*ldb) 61 for i := range b { 62 b[i] = rndComplex128(rnd) 63 } 64 // Create a copy of B for checking that 65 // Zher2k does not modify B. 66 bCopy := make([]complex128, len(b)) 67 copy(bCopy, b) 68 69 // Allocate the matrix C and fill it with random numbers. 70 c := make([]complex128, n*ldc) 71 for i := range c { 72 c[i] = rndComplex128(rnd) 73 } 74 if (alpha == 0 || k == 0) && beta == 1 { 75 // In case of a quick return 76 // zero out the diagonal. 77 for i := 0; i < n; i++ { 78 c[i*ldc+i] = complex(real(c[i*ldc+i]), 0) 79 } 80 } 81 // Create a copy of C for checking that 82 // Zher2k does not modify its triangle 83 // opposite to uplo. 84 cCopy := make([]complex128, len(c)) 85 copy(cCopy, c) 86 // Create a copy of C expanded into a 87 // full hermitian matrix for computing 88 // the expected result using zmm. 89 cHer := make([]complex128, len(c)) 90 copy(cHer, c) 91 if uplo == blas.Upper { 92 for i := 0; i < n; i++ { 93 cHer[i*ldc+i] = complex(real(cHer[i*ldc+i]), 0) 94 for j := i + 1; j < n; j++ { 95 cHer[j*ldc+i] = cmplx.Conj(cHer[i*ldc+j]) 96 } 97 } 98 } else { 99 for i := 0; i < n; i++ { 100 for j := 0; j < i; j++ { 101 cHer[j*ldc+i] = cmplx.Conj(cHer[i*ldc+j]) 102 } 103 cHer[i*ldc+i] = complex(real(cHer[i*ldc+i]), 0) 104 } 105 } 106 107 // Compute the expected result using an internal Zgemm implementation. 108 var want []complex128 109 if trans == blas.NoTrans { 110 // C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C 111 tmp := zmm(blas.NoTrans, blas.ConjTrans, n, n, k, alpha, a, lda, b, ldb, complex(beta, 0), cHer, ldc) 112 want = zmm(blas.NoTrans, blas.ConjTrans, n, n, k, cmplx.Conj(alpha), b, ldb, a, lda, 1, tmp, ldc) 113 } else { 114 // C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C 115 tmp := zmm(blas.ConjTrans, blas.NoTrans, n, n, k, alpha, a, lda, b, ldb, complex(beta, 0), cHer, ldc) 116 want = zmm(blas.ConjTrans, blas.NoTrans, n, n, k, cmplx.Conj(alpha), b, ldb, a, lda, 1, tmp, ldc) 117 } 118 119 // Compute the result using Zher2k. 120 impl.Zher2k(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc) 121 122 prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldb=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldb, ldc, alpha, beta) 123 124 if !zsame(a, aCopy) { 125 t.Errorf("%v: unexpected modification of A", prefix) 126 continue 127 } 128 if !zsame(b, bCopy) { 129 t.Errorf("%v: unexpected modification of B", prefix) 130 continue 131 } 132 if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) { 133 t.Errorf("%v: unexpected modification in lower triangle of C", prefix) 134 continue 135 } 136 if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) { 137 t.Errorf("%v: unexpected modification in upper triangle of C", prefix) 138 continue 139 } 140 141 // Check that the diagonal of C has only real elements. 142 hasRealDiag := true 143 for i := 0; i < n; i++ { 144 if imag(c[i*ldc+i]) != 0 { 145 hasRealDiag = false 146 break 147 } 148 } 149 if !hasRealDiag { 150 t.Errorf("%v: diagonal of C has imaginary elements\ngot=%v", prefix, c) 151 continue 152 } 153 154 // Expand C into a full hermitian matrix 155 // for comparison with the result from zmm. 156 if uplo == blas.Upper { 157 for i := 0; i < n-1; i++ { 158 for j := i + 1; j < n; j++ { 159 c[j*ldc+i] = cmplx.Conj(c[i*ldc+j]) 160 } 161 } 162 } else { 163 for i := 1; i < n; i++ { 164 for j := 0; j < i; j++ { 165 c[j*ldc+i] = cmplx.Conj(c[i*ldc+j]) 166 } 167 } 168 } 169 if !zEqualApprox(c, want, tol) { 170 t.Errorf("%v: unexpected result\nwant=%v\ngot= %v", prefix, want, c) 171 } 172 } 173 } 174 } 175 } 176 } 177 }