github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/testblas/zsyrk.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "fmt" 9 "math/cmplx" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "github.com/jingcheng-WU/gonum/blas" 15 ) 16 17 type Zsyrker interface { 18 Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) 19 } 20 21 func ZsyrkTest(t *testing.T, impl Zsyrker) { 22 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 23 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} { 24 name := uploString(uplo) + "-" + transString(trans) 25 t.Run(name, func(t *testing.T) { 26 for _, n := range []int{0, 1, 2, 3, 4, 5} { 27 for _, k := range []int{0, 1, 2, 3, 4, 5, 7} { 28 zsyrkTest(t, impl, uplo, trans, n, k) 29 } 30 } 31 }) 32 } 33 } 34 } 35 36 func zsyrkTest(t *testing.T, impl Zsyrker, uplo blas.Uplo, trans blas.Transpose, n, k int) { 37 const tol = 1e-13 38 39 rnd := rand.New(rand.NewSource(1)) 40 41 rowA, colA := n, k 42 if trans == blas.Trans { 43 rowA, colA = k, n 44 } 45 for _, lda := range []int{max(1, colA), colA + 2} { 46 for _, ldc := range []int{max(1, n), n + 4} { 47 for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} { 48 for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} { 49 for _, nanC := range []bool{false, true} { 50 if nanC && beta != 0 { 51 // Skip tests with C containing NaN values 52 // unless beta would zero out the NaNs. 53 continue 54 } 55 56 // Allocate the matrix A and fill it with random numbers. 57 a := make([]complex128, rowA*lda) 58 for i := range a { 59 a[i] = rndComplex128(rnd) 60 } 61 // Create a copy of A for checking that 62 // Zsyrk does not modify A. 63 aCopy := make([]complex128, len(a)) 64 copy(aCopy, a) 65 66 // Allocate the matrix C and fill it with random numbers. 67 c := make([]complex128, n*ldc) 68 for i := range c { 69 c[i] = rndComplex128(rnd) 70 } 71 if nanC { 72 for i := 0; i < n; i++ { 73 for j := 0; j < n; j++ { 74 c[i+j*ldc] = cmplx.NaN() 75 } 76 } 77 } 78 79 // Create a copy of C for checking that 80 // Zsyrk does not modify its triangle 81 // opposite to uplo. 82 cCopy := make([]complex128, len(c)) 83 copy(cCopy, c) 84 // Create a copy of C expanded into a 85 // full symmetric matrix for computing 86 // the expected result using zmm. 87 cSym := make([]complex128, len(c)) 88 copy(cSym, c) 89 if uplo == blas.Upper { 90 for i := 0; i < n-1; i++ { 91 for j := i + 1; j < n; j++ { 92 cSym[j*ldc+i] = cSym[i*ldc+j] 93 } 94 } 95 } else { 96 for i := 1; i < n; i++ { 97 for j := 0; j < i; j++ { 98 cSym[j*ldc+i] = cSym[i*ldc+j] 99 } 100 } 101 } 102 103 // Compute the expected result using an internal Zgemm implementation. 104 var want []complex128 105 if trans == blas.NoTrans { 106 want = zmm(blas.NoTrans, blas.Trans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc) 107 } else { 108 want = zmm(blas.Trans, blas.NoTrans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc) 109 } 110 111 // Compute the result using Zsyrk. 112 impl.Zsyrk(uplo, trans, n, k, alpha, a, lda, beta, c, ldc) 113 114 prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldc, alpha, beta) 115 116 if !zsame(a, aCopy) { 117 t.Errorf("%v: unexpected modification of A", prefix) 118 continue 119 } 120 if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) { 121 t.Errorf("%v: unexpected modification in lower triangle of C", prefix) 122 continue 123 } 124 if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) { 125 t.Errorf("%v: unexpected modification in upper triangle of C", prefix) 126 continue 127 } 128 129 // Expand C into a full symmetric matrix 130 // for comparison with the result from zmm. 131 if uplo == blas.Upper { 132 for i := 0; i < n-1; i++ { 133 for j := i + 1; j < n; j++ { 134 c[j*ldc+i] = c[i*ldc+j] 135 } 136 } 137 } else { 138 for i := 1; i < n; i++ { 139 for j := 0; j < i; j++ { 140 c[j*ldc+i] = c[i*ldc+j] 141 } 142 } 143 } 144 if !zEqualApprox(c, want, tol) { 145 t.Errorf("%v: unexpected result", prefix) 146 } 147 } 148 } 149 } 150 } 151 } 152 }