gonum.org/v1/gonum@v0.14.0/blas/testblas/zherk.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "fmt" 9 "math/cmplx" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/blas" 15 ) 16 17 type Zherker interface { 18 Zherk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int) 19 } 20 21 func ZherkTest(t *testing.T, impl Zherker) { 22 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 23 for _, trans := range []blas.Transpose{blas.NoTrans, blas.ConjTrans} { 24 name := uploString(uplo) + "-" + transString(trans) 25 t.Run(name, func(t *testing.T) { 26 for _, n := range []int{0, 1, 2, 3, 4, 5} { 27 for _, k := range []int{0, 1, 2, 3, 4, 5, 7} { 28 zherkTest(t, impl, uplo, trans, n, k) 29 } 30 } 31 }) 32 } 33 } 34 } 35 36 func zherkTest(t *testing.T, impl Zherker, uplo blas.Uplo, trans blas.Transpose, n, k int) { 37 const tol = 1e-13 38 39 rnd := rand.New(rand.NewSource(1)) 40 41 rowA, colA := n, k 42 if trans == blas.ConjTrans { 43 rowA, colA = k, n 44 } 45 for _, lda := range []int{max(1, colA), colA + 2} { 46 for _, ldc := range []int{max(1, n), n + 4} { 47 for _, alpha := range []float64{0, 1, 0.7} { 48 for _, beta := range []float64{0, 1, 1.3} { 49 // Allocate the matrix A and fill it with random numbers. 50 a := make([]complex128, rowA*lda) 51 for i := range a { 52 a[i] = rndComplex128(rnd) 53 } 54 // Create a copy of A for checking that 55 // Zherk does not modify A. 56 aCopy := make([]complex128, len(a)) 57 copy(aCopy, a) 58 59 // Allocate the matrix C and fill it with random numbers. 60 c := make([]complex128, n*ldc) 61 for i := range c { 62 c[i] = rndComplex128(rnd) 63 } 64 if (alpha == 0 || k == 0) && beta == 1 { 65 // In case of a quick return 66 // zero out the diagonal. 67 for i := 0; i < n; i++ { 68 c[i*ldc+i] = complex(real(c[i*ldc+i]), 0) 69 } 70 } 71 // Create a copy of C for checking that 72 // Zherk does not modify its triangle 73 // opposite to uplo. 74 cCopy := make([]complex128, len(c)) 75 copy(cCopy, c) 76 // Create a copy of C expanded into a 77 // full hermitian matrix for computing 78 // the expected result using zmm. 79 cHer := make([]complex128, len(c)) 80 copy(cHer, c) 81 if uplo == blas.Upper { 82 for i := 0; i < n; i++ { 83 cHer[i*ldc+i] = complex(real(cHer[i*ldc+i]), 0) 84 for j := i + 1; j < n; j++ { 85 cHer[j*ldc+i] = cmplx.Conj(cHer[i*ldc+j]) 86 } 87 } 88 } else { 89 for i := 0; i < n; i++ { 90 for j := 0; j < i; j++ { 91 cHer[j*ldc+i] = cmplx.Conj(cHer[i*ldc+j]) 92 } 93 cHer[i*ldc+i] = complex(real(cHer[i*ldc+i]), 0) 94 } 95 } 96 97 // Compute the expected result using an internal Zgemm implementation. 98 var want []complex128 99 if trans == blas.NoTrans { 100 want = zmm(blas.NoTrans, blas.ConjTrans, n, n, k, complex(alpha, 0), a, lda, a, lda, complex(beta, 0), cHer, ldc) 101 } else { 102 want = zmm(blas.ConjTrans, blas.NoTrans, n, n, k, complex(alpha, 0), a, lda, a, lda, complex(beta, 0), cHer, ldc) 103 } 104 105 // Compute the result using Zherk. 106 impl.Zherk(uplo, trans, n, k, alpha, a, lda, beta, c, ldc) 107 108 prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldc, alpha, beta) 109 110 if !zsame(a, aCopy) { 111 t.Errorf("%v: unexpected modification of A", prefix) 112 continue 113 } 114 if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) { 115 t.Errorf("%v: unexpected modification in lower triangle of C", prefix) 116 continue 117 } 118 if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) { 119 t.Errorf("%v: unexpected modification in upper triangle of C", prefix) 120 continue 121 } 122 123 // Check that the diagonal of C has only real elements. 124 hasRealDiag := true 125 for i := 0; i < n; i++ { 126 if imag(c[i*ldc+i]) != 0 { 127 hasRealDiag = false 128 break 129 } 130 } 131 if !hasRealDiag { 132 t.Errorf("%v: diagonal of C has imaginary elements\ngot=%v", prefix, c) 133 continue 134 } 135 136 // Expand C into a full hermitian matrix 137 // for comparison with the result from zmm. 138 if uplo == blas.Upper { 139 for i := 0; i < n-1; i++ { 140 for j := i + 1; j < n; j++ { 141 c[j*ldc+i] = cmplx.Conj(c[i*ldc+j]) 142 } 143 } 144 } else { 145 for i := 1; i < n; i++ { 146 for j := 0; j < i; j++ { 147 c[j*ldc+i] = cmplx.Conj(c[i*ldc+j]) 148 } 149 } 150 } 151 if !zEqualApprox(c, want, tol) { 152 t.Errorf("%v: unexpected result\nwant=%v\ngot= %v", prefix, want, c) 153 } 154 } 155 } 156 } 157 } 158 }