gonum.org/v1/gonum@v0.14.0/blas/testblas/ztrmm.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 ) 15 16 type Ztrmmer interface { 17 Ztrmm(side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) 18 } 19 20 func ZtrmmTest(t *testing.T, impl Ztrmmer) { 21 for _, side := range []blas.Side{blas.Left, blas.Right} { 22 for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} { 23 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} { 24 for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} { 25 name := sideString(side) + "-" + uploString(uplo) + "-" + transString(trans) + "-" + diagString(diag) 26 t.Run(name, func(t *testing.T) { 27 for _, m := range []int{0, 1, 2, 3, 4, 5} { 28 for _, n := range []int{0, 1, 2, 3, 4, 5} { 29 ztrmmTest(t, impl, side, uplo, trans, diag, m, n) 30 } 31 } 32 }) 33 } 34 } 35 } 36 } 37 } 38 39 func ztrmmTest(t *testing.T, impl Ztrmmer, side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int) { 40 const tol = 1e-13 41 42 rnd := rand.New(rand.NewSource(1)) 43 44 nA := m 45 if side == blas.Right { 46 nA = n 47 } 48 for _, lda := range []int{max(1, nA), nA + 2} { 49 for _, ldb := range []int{max(1, n), n + 3} { 50 for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} { 51 // Allocate the matrix A and fill it with random numbers. 52 a := make([]complex128, nA*lda) 53 for i := range a { 54 a[i] = rndComplex128(rnd) 55 } 56 // Put a zero into A to cover special cases in Ztrmm. 57 if nA > 1 { 58 if uplo == blas.Upper { 59 a[nA-1] = 0 60 } else { 61 a[(nA-1)*lda] = 0 62 } 63 } 64 // Create a copy of A for checking that Ztrmm 65 // does not modify its triangle opposite to 66 // uplo. 67 aCopy := make([]complex128, len(a)) 68 copy(aCopy, a) 69 // Create a dense representation of A for 70 // computing the expected result using zmm. 71 aTri := make([]complex128, len(a)) 72 copy(aTri, a) 73 if uplo == blas.Upper { 74 for i := 0; i < nA; i++ { 75 // Zero out the lower triangle. 76 for j := 0; j < i; j++ { 77 aTri[i*lda+j] = 0 78 } 79 if diag == blas.Unit { 80 aTri[i*lda+i] = 1 81 } 82 } 83 } else { 84 for i := 0; i < nA; i++ { 85 if diag == blas.Unit { 86 aTri[i*lda+i] = 1 87 } 88 // Zero out the upper triangle. 89 for j := i + 1; j < nA; j++ { 90 aTri[i*lda+j] = 0 91 } 92 } 93 } 94 95 // Allocate the matrix B and fill it with random numbers. 96 b := make([]complex128, m*ldb) 97 for i := range b { 98 b[i] = rndComplex128(rnd) 99 } 100 // Put a zero into B to cover special cases in Ztrmm. 101 if m > 0 && n > 0 { 102 b[0] = 0 103 } 104 105 // Compute the expected result using an internal Zgemm implementation. 106 var want []complex128 107 if side == blas.Left { 108 want = zmm(trans, blas.NoTrans, m, n, m, alpha, aTri, lda, b, ldb, 0, b, ldb) 109 } else { 110 want = zmm(blas.NoTrans, trans, m, n, n, alpha, b, ldb, aTri, lda, 0, b, ldb) 111 } 112 113 // Compute the result using Ztrmm. 114 impl.Ztrmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) 115 116 prefix := fmt.Sprintf("m=%v,n=%v,lda=%v,ldb=%v,alpha=%v", m, n, lda, ldb, alpha) 117 if !zsame(a, aCopy) { 118 t.Errorf("%v: unexpected modification of A", prefix) 119 continue 120 } 121 122 if !zEqualApprox(b, want, tol) { 123 t.Errorf("%v: unexpected result", prefix) 124 } 125 } 126 } 127 } 128 }