gonum.org/v1/gonum@v0.14.0/blas/testblas/ztrsm.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 ) 15 16 type Ztrsmer interface { 17 Ztrsm(side blas.Side, uplo blas.Uplo, transA blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) 18 } 19 20 func ZtrsmTest(t *testing.T, impl Ztrsmer) { 21 for _, side := range []blas.Side{blas.Left, blas.Right} { 22 for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} { 23 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} { 24 for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} { 25 name := sideString(side) + "-" + uploString(uplo) + "-" + transString(trans) + "-" + diagString(diag) 26 t.Run(name, func(t *testing.T) { 27 for _, m := range []int{0, 1, 2, 3, 4, 5} { 28 for _, n := range []int{0, 1, 2, 3, 4, 5} { 29 ztrsmTest(t, impl, side, uplo, trans, diag, m, n) 30 } 31 } 32 }) 33 } 34 } 35 } 36 } 37 } 38 39 func ztrsmTest(t *testing.T, impl Ztrsmer, side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int) { 40 const tol = 1e-13 41 42 rnd := rand.New(rand.NewSource(1)) 43 44 nA := m 45 if side == blas.Right { 46 nA = n 47 } 48 for _, lda := range []int{max(1, nA), nA + 2} { 49 for _, ldb := range []int{max(1, n), n + 3} { 50 for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} { 51 // Allocate the matrix A and fill it with random numbers. 52 a := make([]complex128, nA*lda) 53 for i := range a { 54 a[i] = rndComplex128(rnd) 55 } 56 // Set some elements of A to 0 and 1 to cover special cases in Ztrsm. 57 if nA > 2 { 58 if uplo == blas.Upper { 59 a[nA-2] = 1 60 a[nA-1] = 0 61 } else { 62 a[(nA-2)*lda] = 1 63 a[(nA-1)*lda] = 0 64 } 65 } 66 // Create a copy of A for checking that Ztrsm 67 // does not modify its triangle opposite to uplo. 68 aCopy := make([]complex128, len(a)) 69 copy(aCopy, a) 70 // Create a dense representation of A for 71 // computing the right-hand side matrix using zmm. 72 aTri := make([]complex128, len(a)) 73 copy(aTri, a) 74 if uplo == blas.Upper { 75 for i := 0; i < nA; i++ { 76 // Zero out the lower triangle. 77 for j := 0; j < i; j++ { 78 aTri[i*lda+j] = 0 79 } 80 if diag == blas.Unit { 81 aTri[i*lda+i] = 1 82 } 83 } 84 } else { 85 for i := 0; i < nA; i++ { 86 if diag == blas.Unit { 87 aTri[i*lda+i] = 1 88 } 89 // Zero out the upper triangle. 90 for j := i + 1; j < nA; j++ { 91 aTri[i*lda+j] = 0 92 } 93 } 94 } 95 96 // Allocate the right-hand side matrix B and fill it with random numbers. 97 b := make([]complex128, m*ldb) 98 for i := range b { 99 b[i] = rndComplex128(rnd) 100 } 101 // Set some elements of B to 0 to cover special cases in Ztrsm. 102 if m > 1 && n > 1 { 103 b[0] = 0 104 b[(m-1)*ldb+n-1] = 0 105 } 106 bCopy := make([]complex128, len(b)) 107 copy(bCopy, b) 108 109 // Compute the solution matrix X using Ztrsm. 110 // X is overwritten on B. 111 impl.Ztrsm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) 112 x := b 113 114 prefix := fmt.Sprintf("m=%v,n=%v,lda=%v,ldb=%v,alpha=%v", m, n, lda, ldb, alpha) 115 116 if !zsame(a, aCopy) { 117 t.Errorf("%v: unexpected modification of A", prefix) 118 continue 119 } 120 121 // Compute the left-hand side matrix of op(A)*X=alpha*B or X*op(A)=alpha*B 122 // using an internal Zgemm implementation. 123 var lhs []complex128 124 if side == blas.Left { 125 lhs = zmm(trans, blas.NoTrans, m, n, m, 1, aTri, lda, x, ldb, 0, b, ldb) 126 } else { 127 lhs = zmm(blas.NoTrans, trans, m, n, n, 1, x, ldb, aTri, lda, 0, b, ldb) 128 } 129 130 // Compute the right-hand side matrix alpha*B. 131 rhs := bCopy 132 for i := 0; i < m; i++ { 133 for j := 0; j < n; j++ { 134 rhs[i*ldb+j] *= alpha 135 } 136 } 137 138 if !zEqualApprox(lhs, rhs, tol) { 139 t.Errorf("%v: unexpected result", prefix) 140 } 141 } 142 } 143 } 144 }