// Copyright ©2021 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"math"

	"github.com/gopherd/gonum/blas"
	"github.com/gopherd/gonum/blas/blas64"
)

// Dpstrf computes the Cholesky factorization with complete pivoting of an n×n
// symmetric positive semidefinite matrix A.
//
// The factorization has the form
//
//	Pᵀ * A * P = Uᵀ * U,  if uplo = blas.Upper,
//	Pᵀ * A * P = L * Lᵀ,  if uplo = blas.Lower,
//
// where U is an upper triangular matrix, L is lower triangular, and P is a
// permutation matrix.
//
// A is stored in a with row stride lda, so that element (i,j) is a[i*lda+j].
//
// tol is a user-defined tolerance. The algorithm terminates if the pivot is
// less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be
// used instead.
//
// On return, A contains the factor U or L from the Cholesky factorization and
// piv contains P stored such that P[piv[k],k] = 1.
//
// Dpstrf returns the computed rank of A and whether the factorization can be
// used to solve a system. Dpstrf does not attempt to check that A is positive
// semi-definite, so if ok is false, the matrix A is either rank deficient or is
// not positive semidefinite.
//
// When n is 0, Dpstrf returns (0, true) without touching a, piv or work. When
// the block size reported by Ilaenv is at most 1 or is not smaller than n, the
// whole computation, including the return values, is delegated to Dpstf2. In
// the blocked code below the possible results are:
//   - (0, false) if the largest diagonal element of A is non-positive or NaN,
//   - (j, false) if the pivot chosen at step j is not greater than the stopping
//     value or is NaN; that pivot value is stored in a[j*lda+j],
//   - (n, true) if all n steps complete.
//
// The length of piv must be n and the length of work must be at least 2*n,
// otherwise Dpstrf will panic. Dpstrf also panics if uplo is neither
// blas.Upper nor blas.Lower, if n < 0, if lda < max(1, n), or if
// len(a) < (n-1)*lda+n.
//
// Dpstrf is an internal routine. It is exported for testing purposes.
func (impl Implementation) Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) {
	// Validate the arguments that do not depend on slice lengths.
	switch {
	case uplo != blas.Upper && uplo != blas.Lower:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case lda < max(1, n):
		panic(badLdA)
	}

	// Quick return if possible.
	if n == 0 {
		return 0, true
	}

	// Slice lengths are only checked for non-empty problems.
	switch {
	case len(a) < (n-1)*lda+n:
		panic(shortA)
	case len(piv) != n:
		panic(badLenPiv)
	case len(work) < 2*n:
		panic(shortWork)
	}

	// Get block size.
	nb := impl.Ilaenv(1, "DPOTRF", string(uplo), n, -1, -1, -1)
	if nb <= 1 || n <= nb {
		// Use unblocked code.
		return impl.Dpstf2(uplo, n, a, lda, piv, tol, work)
	}

	// Initialize piv.
	// piv starts as the identity permutation and is updated by each swap.
	for i := range piv[:n] {
		piv[i] = i
	}

	// Compute the first pivot.
	// The pivot is the largest diagonal element; the strict comparison keeps
	// the lowest index among equal candidates.
	pvt := 0
	ajj := a[0]
	for i := 1; i < n; i++ {
		aii := a[i*lda+i]
		if aii > ajj {
			pvt = i
			ajj = aii
		}
	}
	// A NaN never satisfies ajj <= 0, so it has to be tested for explicitly.
	if ajj <= 0 || math.IsNaN(ajj) {
		return 0, false
	}

	// Compute stopping value if not supplied.
	// ajj still holds the largest diagonal element of the input matrix here.
	dstop := tol
	if dstop < 0 {
		dstop = float64(n) * dlamchE * ajj
	}

	bi := blas64.Implementation()
	// Split work in half, the first half holds dot products.
	// dots[i] accumulates the squares of the already computed factor elements
	// of row/column i within the current block; work2[i] holds the resulting
	// candidate pivot values.
	dots := work[:n]
	work2 := work[n : 2*n]
	if uplo == blas.Upper {
		// Compute the Cholesky factorization Pᵀ * A * P = Uᵀ * U.
		for k := 0; k < n; k += nb {
			// Account for last block not being nb wide.
			jb := min(nb, n-k)
			// Set relevant part of dot products to zero.
			for i := k; i < n; i++ {
				dots[i] = 0
			}
			for j := k; j < k+jb; j++ {
				// Update dot products and compute possible pivots which are stored
				// in the second half of work.
				for i := j; i < n; i++ {
					if j > k {
						// Add the square of the element of row j-1 computed in
						// the previous step of this block.
						tmp := a[(j-1)*lda+i]
						dots[i] += tmp * tmp
					}
					work2[i] = a[i*lda+i] - dots[i]
				}
				// For j == 0 the pivot was already determined before the loop.
				if j > 0 {
					// Find the pivot.
					pvt = j
					ajj = work2[pvt]
					for l := j + 1; l < n; l++ {
						wl := work2[l]
						if wl > ajj {
							pvt = l
							ajj = wl
						}
					}
					// Test for exit.
					// The rejected pivot value is stored on the diagonal and j
					// is reported as the rank.
					if ajj <= dstop || math.IsNaN(ajj) {
						a[j*lda+j] = ajj
						return j, false
					}
				}
				if j != pvt {
					// Swap pivot rows and columns.
					// Move the diagonal element of j to position pvt; a[j*lda+j]
					// itself is overwritten with sqrt(ajj) below.
					a[pvt*lda+pvt] = a[j*lda+j]
					// Rows 0..j-1 of columns j and pvt.
					bi.Dswap(j, a[j:], lda, a[pvt:], lda)
					if pvt < n-1 {
						// Columns pvt+1..n-1 of rows j and pvt.
						bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1)
					}
					// Columns j+1..pvt-1 of row j with rows j+1..pvt-1 of column pvt.
					bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda)
					// Swap dot products and piv.
					dots[j], dots[pvt] = dots[pvt], dots[j]
					piv[j], piv[pvt] = piv[pvt], piv[j]
				}
				ajj = math.Sqrt(ajj)
				a[j*lda+j] = ajj
				// Compute elements j+1:n of row j.
				if j < n-1 {
					// Subtract the contribution of rows k..j-1 of the current
					// block from row j, then scale by the inverse of the
					// diagonal element.
					bi.Dgemv(blas.Trans, j-k, n-j-1,
						-1, a[k*lda+j+1:], lda, a[k*lda+j:], lda,
						1, a[j*lda+j+1:], 1)
					bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
				}
			}
			// Update trailing matrix.
			// The jb rows just computed are applied to the upper triangle of
			// the remaining (n-j)×(n-j) submatrix.
			if k+jb < n {
				j := k + jb
				bi.Dsyrk(blas.Upper, blas.Trans, n-j, jb,
					-1, a[k*lda+j:], lda, 1, a[j*lda+j:], lda)
			}
		}
	} else {
		// Compute the Cholesky factorization Pᵀ * A * P = L * Lᵀ.
		for k := 0; k < n; k += nb {
			// Account for last block not being nb wide.
			jb := min(nb, n-k)
			// Set relevant part of dot products to zero.
			for i := k; i < n; i++ {
				dots[i] = 0
			}
			for j := k; j < k+jb; j++ {
				// Update dot products and compute possible pivots which are stored
				// in the second half of work.
				for i := j; i < n; i++ {
					if j > k {
						// Add the square of the element of column j-1 computed
						// in the previous step of this block.
						tmp := a[i*lda+(j-1)]
						dots[i] += tmp * tmp
					}
					work2[i] = a[i*lda+i] - dots[i]
				}
				// For j == 0 the pivot was already determined before the loop.
				if j > 0 {
					// Find the pivot.
					pvt = j
					ajj = work2[pvt]
					for l := j + 1; l < n; l++ {
						wl := work2[l]
						if wl > ajj {
							pvt = l
							ajj = wl
						}
					}
					// Test for exit.
					// The rejected pivot value is stored on the diagonal and j
					// is reported as the rank.
					if ajj <= dstop || math.IsNaN(ajj) {
						a[j*lda+j] = ajj
						return j, false
					}
				}
				if j != pvt {
					// Swap pivot rows and columns.
					// Move the diagonal element of j to position pvt; a[j*lda+j]
					// itself is overwritten with sqrt(ajj) below.
					a[pvt*lda+pvt] = a[j*lda+j]
					// Columns 0..j-1 of rows j and pvt.
					bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1)
					if pvt < n-1 {
						// Rows pvt+1..n-1 of columns j and pvt.
						bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda)
					}
					// Rows j+1..pvt-1 of column j with columns j+1..pvt-1 of row pvt.
					bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1)
					// Swap dot products and piv.
					dots[j], dots[pvt] = dots[pvt], dots[j]
					piv[j], piv[pvt] = piv[pvt], piv[j]
				}
				ajj = math.Sqrt(ajj)
				a[j*lda+j] = ajj
				// Compute elements j+1:n of column j.
				if j < n-1 {
					// Subtract the contribution of columns k..j-1 of the
					// current block from column j, then scale by the inverse of
					// the diagonal element.
					bi.Dgemv(blas.NoTrans, n-j-1, j-k,
						-1, a[(j+1)*lda+k:], lda, a[j*lda+k:], 1,
						1, a[(j+1)*lda+j:], lda)
					bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
				}
			}
			// Update trailing matrix.
			// The jb columns just computed are applied to the lower triangle
			// of the remaining (n-j)×(n-j) submatrix.
			if k+jb < n {
				j := k + jb
				bi.Dsyrk(blas.Lower, blas.NoTrans, n-j, jb,
					-1, a[j*lda+k:], lda, 1, a[j*lda+j:], lda)
			}
		}
	}
	return n, true
}