gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dtbtrs.go (about) 1 // Copyright ©2020 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/blas" 15 "gonum.org/v1/gonum/blas/blas64" 16 "gonum.org/v1/gonum/floats" 17 "gonum.org/v1/gonum/lapack" 18 ) 19 20 type Dtbtrser interface { 21 Dtbtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, kd, nrhs int, a []float64, lda int, b []float64, ldb int) bool 22 } 23 24 func DtbtrsTest(t *testing.T, impl Dtbtrser) { 25 rnd := rand.New(rand.NewSource(1)) 26 27 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} { 28 name := transToString(trans) 29 t.Run(name, func(t *testing.T) { 30 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 31 for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} { 32 for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 23} { 33 for _, kd := range []int{0, 1, 2, n / 2, max(0, n-1), n, n + 5} { 34 for _, nrhs := range []int{0, 1, 2, 3, 4, 5} { 35 for _, lda := range []int{kd + 1, kd + 3} { 36 for _, ldb := range []int{max(1, nrhs), nrhs + 3} { 37 if diag == blas.Unit { 38 dtbtrsTest(t, impl, rnd, uplo, trans, diag, n, kd, nrhs, lda, ldb, false) 39 } else { 40 dtbtrsTest(t, impl, rnd, uplo, trans, diag, n, kd, nrhs, lda, ldb, true) 41 dtbtrsTest(t, impl, rnd, uplo, trans, diag, n, kd, nrhs, lda, ldb, false) 42 } 43 } 44 } 45 } 46 } 47 } 48 } 49 } 50 }) 51 } 52 } 53 54 func dtbtrsTest(t *testing.T, impl Dtbtrser, rnd *rand.Rand, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, kd, nrhs int, lda, ldb int, singular bool) { 55 if singular && diag == blas.Unit { 56 panic("blas.Unit triangular matrix cannot be singular") 57 } 58 59 const tol = 1e-14 60 61 if n == 0 { 62 singular = false 63 } 64 name := fmt.Sprintf("uplo=%v,diag=%v,n=%v,kd=%v,nrhs=%v,lda=%v,ldb=%v,sing=%v", string(uplo), string(diag), n, kd, nrhs, lda, ldb, singular) 65 66 // Generate a random triangular matrix A. One of its triangles won't be 67 // referenced. 68 a := make([]float64, n*lda) 69 for i := range a { 70 a[i] = rnd.NormFloat64() 71 } 72 if singular { 73 i := rnd.Intn(n) 74 if uplo == blas.Upper { 75 a[i*lda] = 0 76 } else { 77 a[i*lda+kd] = 0 78 } 79 } 80 aCopy := make([]float64, len(a)) 81 copy(aCopy, a) 82 83 // Generate a random solution matrix X. 84 x := make([]float64, n*ldb) 85 for i := range x { 86 x[i] = rnd.NormFloat64() 87 } 88 89 // Generate the right-hand side B as A * X or Aᵀ * X. 90 b := make([]float64, len(x)) 91 copy(b, x) 92 bi := blas64.Implementation() 93 if n > 0 { 94 for j := 0; j < nrhs; j++ { 95 bi.Dtbmv(uplo, trans, diag, n, kd, a, lda, b[j:], ldb) 96 } 97 } 98 99 got := make([]float64, len(b)) 100 copy(got, b) 101 ok := impl.Dtbtrs(uplo, trans, diag, n, kd, nrhs, a, lda, got, ldb) 102 103 if !floats.Equal(a, aCopy) { 104 t.Errorf("%v: unexpected modification of A", name) 105 } 106 107 if ok == singular { 108 t.Errorf("%v: misdetected singular matrix, ok=%v", name, ok) 109 } 110 111 if !ok { 112 if !floats.Equal(got, b) { 113 t.Errorf("%v: unexpected modification of B when singular", name) 114 } 115 return 116 } 117 118 if n == 0 || nrhs == 0 { 119 return 120 } 121 122 work := make([]float64, n) 123 124 // Compute the 1-norm of A or Aᵀ. 125 var aNorm float64 126 if trans == blas.NoTrans { 127 aNorm = dlantb(lapack.MaxColumnSum, uplo, diag, n, kd, a, lda, work) 128 } else { 129 aNorm = dlantb(lapack.MaxRowSum, uplo, diag, n, kd, a, lda, work) 130 } 131 132 // Compute the maximum over the number of right-hand sides of 133 // |op(A)*x-b| / (|op(A)| * |x|) 134 var resid float64 135 for j := 0; j < nrhs; j++ { 136 bi.Dcopy(n, got[j:], ldb, work, 1) 137 bi.Dtbmv(uplo, trans, diag, n, kd, a, lda, work, 1) 138 bi.Daxpy(n, -1, b[j:], ldb, work, 1) 139 rjNorm := bi.Dasum(n, work, 1) 140 xNorm := bi.Dasum(n, got[j:], ldb) 141 resid = math.Max(resid, rjNorm/aNorm/xNorm) 142 } 143 if resid > tol { 144 t.Errorf("%v: unexpected result; resid=%v,want<=%v", name, resid, tol) 145 } 146 }