gonum.org/v1/gonum@v0.14.0/lapack/gonum/dpstrf.go (about) 1 // Copyright ©2021 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 ) 13 14 // Dpstrf computes the Cholesky factorization with complete pivoting of an n×n 15 // symmetric positive semidefinite matrix A. 16 // 17 // The factorization has the form 18 // 19 // Pᵀ * A * P = Uᵀ * U , if uplo = blas.Upper, 20 // Pᵀ * A * P = L * Lᵀ, if uplo = blas.Lower, 21 // 22 // where U is an upper triangular matrix, L is lower triangular, and P is a 23 // permutation matrix. 24 // 25 // tol is a user-defined tolerance. The algorithm terminates if the pivot is 26 // less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be 27 // used instead. 28 // 29 // On return, A contains the factor U or L from the Cholesky factorization and 30 // piv contains P stored such that P[piv[k],k] = 1. 31 // 32 // Dpstrf returns the computed rank of A and whether the factorization can be 33 // used to solve a system. Dpstrf does not attempt to check that A is positive 34 // semi-definite, so if ok is false, the matrix A is either rank deficient or is 35 // not positive semidefinite. 36 // 37 // The length of piv must be n and the length of work must be at least 2*n, 38 // otherwise Dpstrf will panic. 39 // 40 // Dpstrf is an internal routine. It is exported for testing purposes. 41 func (impl Implementation) Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) { 42 switch { 43 case uplo != blas.Upper && uplo != blas.Lower: 44 panic(badUplo) 45 case n < 0: 46 panic(nLT0) 47 case lda < max(1, n): 48 panic(badLdA) 49 } 50 51 // Quick return if possible. 52 if n == 0 { 53 return 0, true 54 } 55 56 switch { 57 case len(a) < (n-1)*lda+n: 58 panic(shortA) 59 case len(piv) != n: 60 panic(badLenPiv) 61 case len(work) < 2*n: 62 panic(shortWork) 63 } 64 65 // Get block size. 66 nb := impl.Ilaenv(1, "DPOTRF", string(uplo), n, -1, -1, -1) 67 if nb <= 1 || n <= nb { 68 // Use unblocked code. 69 return impl.Dpstf2(uplo, n, a, lda, piv, tol, work) 70 } 71 72 // Initialize piv. 73 for i := range piv[:n] { 74 piv[i] = i 75 } 76 77 // Compute the first pivot. 78 pvt := 0 79 ajj := a[0] 80 for i := 1; i < n; i++ { 81 aii := a[i*lda+i] 82 if aii > ajj { 83 pvt = i 84 ajj = aii 85 } 86 } 87 if ajj <= 0 || math.IsNaN(ajj) { 88 return 0, false 89 } 90 91 // Compute stopping value if not supplied. 92 dstop := tol 93 if dstop < 0 { 94 dstop = float64(n) * dlamchE * ajj 95 } 96 97 bi := blas64.Implementation() 98 // Split work in half, the first half holds dot products. 99 dots := work[:n] 100 work2 := work[n : 2*n] 101 if uplo == blas.Upper { 102 // Compute the Cholesky factorization Pᵀ * A * P = Uᵀ * U. 103 for k := 0; k < n; k += nb { 104 // Account for last block not being nb wide. 105 jb := min(nb, n-k) 106 // Set relevant part of dot products to zero. 107 for i := k; i < n; i++ { 108 dots[i] = 0 109 } 110 for j := k; j < k+jb; j++ { 111 // Update dot products and compute possible pivots which are stored 112 // in the second half of work. 113 for i := j; i < n; i++ { 114 if j > k { 115 tmp := a[(j-1)*lda+i] 116 dots[i] += tmp * tmp 117 } 118 work2[i] = a[i*lda+i] - dots[i] 119 } 120 if j > 0 { 121 // Find the pivot. 122 pvt = j 123 ajj = work2[pvt] 124 for l := j + 1; l < n; l++ { 125 wl := work2[l] 126 if wl > ajj { 127 pvt = l 128 ajj = wl 129 } 130 } 131 // Test for exit. 132 if ajj <= dstop || math.IsNaN(ajj) { 133 a[j*lda+j] = ajj 134 return j, false 135 } 136 } 137 if j != pvt { 138 // Swap pivot rows and columns. 139 a[pvt*lda+pvt] = a[j*lda+j] 140 bi.Dswap(j, a[j:], lda, a[pvt:], lda) 141 if pvt < n-1 { 142 bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1) 143 } 144 bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda) 145 // Swap dot products and piv. 146 dots[j], dots[pvt] = dots[pvt], dots[j] 147 piv[j], piv[pvt] = piv[pvt], piv[j] 148 } 149 ajj = math.Sqrt(ajj) 150 a[j*lda+j] = ajj 151 // Compute elements j+1:n of row j. 152 if j < n-1 { 153 bi.Dgemv(blas.Trans, j-k, n-j-1, 154 -1, a[k*lda+j+1:], lda, a[k*lda+j:], lda, 155 1, a[j*lda+j+1:], 1) 156 bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1) 157 } 158 } 159 // Update trailing matrix. 160 if k+jb < n { 161 j := k + jb 162 bi.Dsyrk(blas.Upper, blas.Trans, n-j, jb, 163 -1, a[k*lda+j:], lda, 1, a[j*lda+j:], lda) 164 } 165 } 166 } else { 167 // Compute the Cholesky factorization Pᵀ * A * P = L * Lᵀ. 168 for k := 0; k < n; k += nb { 169 // Account for last block not being nb wide. 170 jb := min(nb, n-k) 171 // Set relevant part of dot products to zero. 172 for i := k; i < n; i++ { 173 dots[i] = 0 174 } 175 for j := k; j < k+jb; j++ { 176 // Update dot products and compute possible pivots which are stored 177 // in the second half of work. 178 for i := j; i < n; i++ { 179 if j > k { 180 tmp := a[i*lda+(j-1)] 181 dots[i] += tmp * tmp 182 } 183 work2[i] = a[i*lda+i] - dots[i] 184 } 185 if j > 0 { 186 // Find the pivot. 187 pvt = j 188 ajj = work2[pvt] 189 for l := j + 1; l < n; l++ { 190 wl := work2[l] 191 if wl > ajj { 192 pvt = l 193 ajj = wl 194 } 195 } 196 // Test for exit. 197 if ajj <= dstop || math.IsNaN(ajj) { 198 a[j*lda+j] = ajj 199 return j, false 200 } 201 } 202 if j != pvt { 203 // Swap pivot rows and columns. 204 a[pvt*lda+pvt] = a[j*lda+j] 205 bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1) 206 if pvt < n-1 { 207 bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda) 208 } 209 bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1) 210 // Swap dot products and piv. 211 dots[j], dots[pvt] = dots[pvt], dots[j] 212 piv[j], piv[pvt] = piv[pvt], piv[j] 213 } 214 ajj = math.Sqrt(ajj) 215 a[j*lda+j] = ajj 216 // Compute elements j+1:n of column j. 217 if j < n-1 { 218 bi.Dgemv(blas.NoTrans, n-j-1, j-k, 219 -1, a[(j+1)*lda+k:], lda, a[j*lda+k:], 1, 220 1, a[(j+1)*lda+j:], lda) 221 bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda) 222 } 223 } 224 // Update trailing matrix. 225 if k+jb < n { 226 j := k + jb 227 bi.Dsyrk(blas.Lower, blas.NoTrans, n-j, jb, 228 -1, a[j*lda+k:], lda, 1, a[j*lda+j:], lda) 229 } 230 } 231 } 232 return n, true 233 }