gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgesv.go (about) 1 // Copyright ©2021 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/blas" 15 "gonum.org/v1/gonum/blas/blas64" 16 "gonum.org/v1/gonum/lapack" 17 ) 18 19 type Dgesver interface { 20 Dgesv(n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) bool 21 22 Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) bool 23 } 24 25 func DgesvTest(t *testing.T, impl Dgesver) { 26 rnd := rand.New(rand.NewSource(1)) 27 for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 50, 100} { 28 for _, nrhs := range []int{0, 1, 2, 5} { 29 for _, lda := range []int{max(1, n), n + 5} { 30 for _, ldb := range []int{max(1, nrhs), nrhs + 5} { 31 dgesvTest(t, impl, rnd, n, nrhs, lda, ldb) 32 } 33 } 34 } 35 } 36 } 37 38 func dgesvTest(t *testing.T, impl Dgesver, rnd *rand.Rand, n, nrhs, lda, ldb int) { 39 const tol = 1e-15 40 41 name := fmt.Sprintf("n=%v,nrhs=%v,lda=%v,ldb=%v", n, nrhs, lda, ldb) 42 43 // Create a random system matrix A and the solution X. 44 a := randomGeneral(n, n, lda, rnd) 45 xWant := randomGeneral(n, nrhs, ldb, rnd) 46 47 // Compute the right hand side matrix B = A*X. 48 b := zeros(n, nrhs, ldb) 49 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, xWant, 0, b) 50 51 // Allocate a slice for row pivots and fill it with invalid indices. 52 ipiv := make([]int, n) 53 for i := range ipiv { 54 ipiv[i] = -1 55 } 56 57 // Call Dgesv to solve A*X = B. 58 lu := cloneGeneral(a) 59 xGot := cloneGeneral(b) 60 ok := impl.Dgesv(n, nrhs, lu.Data, lu.Stride, ipiv, xGot.Data, xGot.Stride) 61 62 if !ok { 63 t.Errorf("%v: unexpected failure in Dgesv", name) 64 return 65 } 66 67 if n == 0 || nrhs == 0 { 68 return 69 } 70 71 // Check that all elements of ipiv have been set. 72 ipivSet := true 73 for _, ipv := range ipiv { 74 if ipv == -1 { 75 ipivSet = false 76 break 77 } 78 } 79 if !ipivSet { 80 t.Fatalf("%v: not all elements of ipiv set", name) 81 return 82 } 83 84 // Compute the reciprocal of the condition number of A from its LU 85 // decomposition before it's overwritten further below. 86 aInv := cloneGeneral(lu) 87 impl.Dgetri(n, aInv.Data, aInv.Stride, ipiv, make([]float64, n), n) 88 ainvnorm := dlange(lapack.MaxColumnSum, n, n, aInv.Data, aInv.Stride) 89 anorm := dlange(lapack.MaxColumnSum, n, n, a.Data, a.Stride) 90 rcond := 1 / anorm / ainvnorm 91 92 // Reconstruct matrix A from factors and compute residual. 93 // 94 // Extract L and U from lu. 95 l := zeros(n, n, n) 96 u := zeros(n, n, n) 97 for i := 0; i < n; i++ { 98 for j := 0; j < i; j++ { 99 l.Data[i*l.Stride+j] = lu.Data[i*lu.Stride+j] 100 } 101 l.Data[i*l.Stride+i] = 1 102 for j := i; j < n; j++ { 103 u.Data[i*u.Stride+j] = lu.Data[i*lu.Stride+j] 104 } 105 } 106 // Compute L*U. 107 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, lu) 108 // Apply P to L*U. 109 for i := n - 1; i >= 0; i-- { 110 ip := ipiv[i] 111 if ip == i { 112 continue 113 } 114 row1 := blas64.Vector{N: n, Data: lu.Data[i*lu.Stride:], Inc: 1} 115 row2 := blas64.Vector{N: n, Data: lu.Data[ip*lu.Stride:], Inc: 1} 116 blas64.Swap(row1, row2) 117 } 118 // Compute P*L*U - A. 119 for i := 0; i < n; i++ { 120 for j := 0; j < n; j++ { 121 lu.Data[i*lu.Stride+j] -= a.Data[i*a.Stride+j] 122 } 123 } 124 // Compute the residual |P*L*U - A|. 125 resid := dlange(lapack.MaxColumnSum, n, n, lu.Data, lu.Stride) 126 resid /= float64(n) * anorm 127 if resid > tol || math.IsNaN(resid) { 128 t.Errorf("%v: residual |P*L*U - A| is too large, got %v, want <= %v", name, resid, tol) 129 } 130 131 // Compute residual of the computed solution. 132 // 133 // Compute B - A*X. 134 blas64.Gemm(blas.NoTrans, blas.NoTrans, -1, a, xGot, 1, b) 135 // Compute the maximum over the number of right hand sides of |B - A*X| / (|A| * |X|). 136 resid = 0 137 for j := 0; j < nrhs; j++ { 138 bnorm := blas64.Asum(blas64.Vector{N: n, Data: b.Data[j:], Inc: b.Stride}) 139 xnorm := blas64.Asum(blas64.Vector{N: n, Data: xGot.Data[j:], Inc: xGot.Stride}) 140 resid = math.Max(resid, bnorm/anorm/xnorm) 141 } 142 if resid > tol || math.IsNaN(resid) { 143 t.Errorf("%v: residual |B - A*X| is too large, got %v, want <= %v", name, resid, tol) 144 } 145 146 // Compare the computed solution with the generated exact solution. 147 // 148 // Compute X - XWANT. 149 for i := 0; i < n; i++ { 150 for j := 0; j < nrhs; j++ { 151 xGot.Data[i*xGot.Stride+j] -= xWant.Data[i*xWant.Stride+j] 152 } 153 } 154 // Compute the maximum of |X - XWANT|/|XWANT| over all the vectors X and XWANT. 155 resid = 0 156 for j := 0; j < nrhs; j++ { 157 xnorm := dlange(lapack.MaxAbs, n, 1, xWant.Data[j:], xWant.Stride) 158 diff := dlange(lapack.MaxAbs, n, 1, xGot.Data[j:], xGot.Stride) 159 resid = math.Max(resid, diff/xnorm*rcond) 160 } 161 if resid > tol || math.IsNaN(resid) { 162 t.Errorf("%v: residual |X-XWANT| is too large, got %v, want <= %v", name, resid, tol) 163 } 164 }