gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dpstrf.go (about) 1 // Copyright ©2021 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/blas" 15 "gonum.org/v1/gonum/blas/blas64" 16 "gonum.org/v1/gonum/lapack" 17 ) 18 19 type Dpstrfer interface { 20 Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) 21 } 22 23 func DpstrfTest(t *testing.T, impl Dpstrfer) { 24 rnd := rand.New(rand.NewSource(1)) 25 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 26 t.Run(uploToString(uplo), func(t *testing.T) { 27 for _, n := range []int{0, 1, 2, 3, 4, 5, 31, 32, 33, 63, 64, 65, 127, 128, 129} { 28 for _, lda := range []int{max(1, n), n + 5} { 29 for _, rank := range []int{int(0.7 * float64(n)), n} { 30 dpstrfTest(t, impl, rnd, uplo, n, lda, rank) 31 } 32 } 33 } 34 }) 35 } 36 } 37 38 func dpstrfTest(t *testing.T, impl Dpstrfer, rnd *rand.Rand, uplo blas.Uplo, n, lda, rankWant int) { 39 const tol = 1e-13 40 41 name := fmt.Sprintf("n=%v,lda=%v", n, lda) 42 bi := blas64.Implementation() 43 44 // Generate a random, symmetric A with the given rank by applying rankWant 45 // rank-1 updates to the zero matrix. 46 a := make([]float64, n*lda) 47 for i := 0; i < rankWant; i++ { 48 x := randomSlice(n, rnd) 49 bi.Dsyr(uplo, n, 1, x, 1, a, lda) 50 } 51 52 // Make a copy of A for storing the factorization. 53 aFac := make([]float64, len(a)) 54 copy(aFac, a) 55 56 // Allocate a slice for pivots and fill it with invalid index values. 57 piv := make([]int, n) 58 for i := range piv { 59 piv[i] = -1 60 } 61 62 // Allocate the work slice. 63 work := make([]float64, 2*n) 64 65 // Call Dpstrf to Compute the Cholesky factorization with complete pivoting. 66 rank, ok := impl.Dpstrf(uplo, n, aFac, lda, piv, -1, work) 67 68 if ok != (rank == n) { 69 t.Errorf("%v: unexpected ok; got %v, want %v", name, ok, rank == n) 70 } 71 if rank != rankWant { 72 t.Errorf("%v: unexpected rank; got %v, want %v", name, rank, rankWant) 73 } 74 75 if n == 0 { 76 return 77 } 78 79 // Check that the residual |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n is 80 // sufficiently small. 81 resid := residualDpstrf(uplo, n, a, aFac, lda, rank, piv) 82 if resid > tol || math.IsNaN(resid) { 83 t.Errorf("%v: residual too large; got %v, want<=%v", name, resid, tol) 84 } 85 } 86 87 func residualDpstrf(uplo blas.Uplo, n int, a, aFac []float64, lda int, rank int, piv []int) float64 { 88 bi := blas64.Implementation() 89 // Reconstruct the symmetric positive semi-definite matrix A from its L or U 90 // factors and the permutation matrix P. 91 perm := zeros(n, n, n) 92 if uplo == blas.Upper { 93 // Change notation. 94 u, ldu := aFac, lda 95 // Zero out last n-rank rows of the factor U. 96 for i := rank; i < n; i++ { 97 for j := i; j < n; j++ { 98 u[i*ldu+j] = 0 99 } 100 } 101 // Extract U to aRec. 102 aRec := zeros(n, n, n) 103 for i := 0; i < n; i++ { 104 for j := i; j < n; j++ { 105 aRec.Data[i*aRec.Stride+j] = u[i*ldu+j] 106 } 107 } 108 // Multiply U by Uᵀ from the left. 109 bi.Dtrmm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, n, n, 110 1, u, ldu, aRec.Data, aRec.Stride) 111 // Form P * Uᵀ * U * Pᵀ. 112 for i := 0; i < n; i++ { 113 for j := 0; j < n; j++ { 114 if piv[i] > piv[j] { 115 // Don't set the lower triangle. 116 continue 117 } 118 if i <= j { 119 perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j] 120 } else { 121 perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i] 122 } 123 } 124 } 125 // Compute the difference P*Uᵀ*U*Pᵀ - A. 126 for i := 0; i < n; i++ { 127 for j := i; j < n; j++ { 128 perm.Data[i*perm.Stride+j] -= a[i*lda+j] 129 } 130 } 131 } else { 132 // Change notation. 133 l, ldl := aFac, lda 134 // Zero out last n-rank columns of the factor L. 135 for i := rank; i < n; i++ { 136 for j := rank; j <= i; j++ { 137 l[i*ldl+j] = 0 138 } 139 } 140 // Extract L to aRec. 141 aRec := zeros(n, n, n) 142 for i := 0; i < n; i++ { 143 for j := 0; j <= i; j++ { 144 aRec.Data[i*aRec.Stride+j] = l[i*ldl+j] 145 } 146 } 147 // Multiply L by Lᵀ from the right. 148 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n, n, 149 1, l, ldl, aRec.Data, aRec.Stride) 150 // Form P * L * Lᵀ * Pᵀ. 151 for i := 0; i < n; i++ { 152 for j := 0; j < n; j++ { 153 if piv[i] < piv[j] { 154 // Don't set the upper triangle. 155 continue 156 } 157 if i >= j { 158 perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j] 159 } else { 160 perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i] 161 } 162 } 163 } 164 // Compute the difference P*L*Lᵀ*Pᵀ - A. 165 for i := 0; i < n; i++ { 166 for j := 0; j <= i; j++ { 167 perm.Data[i*perm.Stride+j] -= a[i*lda+j] 168 } 169 } 170 } 171 // Compute |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n. 172 return dlansy(lapack.MaxColumnSum, uplo, n, perm.Data, perm.Stride) / float64(n) 173 }