github.com/gopherd/gonum@v0.0.4/lapack/gonum/dpstf2.go (about) 1 // Copyright ©2021 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "math" 9 10 "github.com/gopherd/gonum/blas" 11 "github.com/gopherd/gonum/blas/blas64" 12 ) 13 14 // Dpstf2 computes the Cholesky factorization with complete pivoting of an n×n 15 // symmetric positive semidefinite matrix A. 16 // 17 // The factorization has the form 18 // Pᵀ * A * P = Uᵀ * U , if uplo = blas.Upper, 19 // Pᵀ * A * P = L * Lᵀ, if uplo = blas.Lower, 20 // where U is an upper triangular matrix, L is lower triangular, and P is a 21 // permutation matrix. 22 // 23 // tol is a user-defined tolerance. The algorithm terminates if the pivot is 24 // less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be 25 // used instead. 26 // 27 // On return, A contains the factor U or L from the Cholesky factorization and 28 // piv contains P stored such that P[piv[k],k] = 1. 29 // 30 // Dpstf2 returns the computed rank of A and whether the factorization can be 31 // used to solve a system. Dpstf2 does not attempt to check that A is positive 32 // semi-definite, so if ok is false, the matrix A is either rank deficient or is 33 // not positive semidefinite. 34 // 35 // The length of piv must be n and the length of work must be at least 2*n, 36 // otherwise Dpstf2 will panic. 37 // 38 // Dpstf2 is an internal routine. It is exported for testing purposes. 39 func (Implementation) Dpstf2(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) { 40 switch { 41 case uplo != blas.Upper && uplo != blas.Lower: 42 panic(badUplo) 43 case n < 0: 44 panic(nLT0) 45 case lda < max(1, n): 46 panic(badLdA) 47 } 48 49 // Quick return if possible. 50 if n == 0 { 51 return 0, true 52 } 53 54 switch { 55 case len(a) < (n-1)*lda+n: 56 panic(shortA) 57 case len(piv) != n: 58 panic(badLenPiv) 59 case len(work) < 2*n: 60 panic(shortWork) 61 } 62 63 // Initialize piv. 64 for i := range piv[:n] { 65 piv[i] = i 66 } 67 68 // Compute the first pivot. 69 pvt := 0 70 ajj := a[0] 71 for i := 1; i < n; i++ { 72 aii := a[i*lda+i] 73 if aii > ajj { 74 pvt = i 75 ajj = aii 76 } 77 } 78 if ajj <= 0 || math.IsNaN(ajj) { 79 return 0, false 80 } 81 82 // Compute stopping value if not supplied. 83 dstop := tol 84 if dstop < 0 { 85 dstop = float64(n) * dlamchE * ajj 86 } 87 88 // Set first half of work to zero, holds dot products. 89 dots := work[:n] 90 for i := range dots { 91 dots[i] = 0 92 } 93 work2 := work[n : 2*n] 94 95 bi := blas64.Implementation() 96 if uplo == blas.Upper { 97 // Compute the Cholesky factorization Pᵀ * A * P = Uᵀ * U. 98 for j := 0; j < n; j++ { 99 // Update dot products and compute possible pivots which are stored 100 // in the second half of work. 101 for i := j; i < n; i++ { 102 if j > 0 { 103 tmp := a[(j-1)*lda+i] 104 dots[i] += tmp * tmp 105 } 106 work2[i] = a[i*lda+i] - dots[i] 107 } 108 if j > 0 { 109 // Find the pivot. 110 pvt = j 111 ajj = work2[pvt] 112 for k := j + 1; k < n; k++ { 113 wk := work2[k] 114 if wk > ajj { 115 pvt = k 116 ajj = wk 117 } 118 } 119 // Test for exit. 120 if ajj <= dstop || math.IsNaN(ajj) { 121 a[j*lda+j] = ajj 122 return j, false 123 } 124 } 125 if j != pvt { 126 // Swap pivot rows and columns. 127 a[pvt*lda+pvt] = a[j*lda+j] 128 bi.Dswap(j, a[j:], lda, a[pvt:], lda) 129 if pvt < n-1 { 130 bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1) 131 } 132 bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda) 133 // Swap dot products and piv. 134 dots[j], dots[pvt] = dots[pvt], dots[j] 135 piv[j], piv[pvt] = piv[pvt], piv[j] 136 } 137 ajj = math.Sqrt(ajj) 138 a[j*lda+j] = ajj 139 // Compute elements j+1:n of row j. 140 if j < n-1 { 141 bi.Dgemv(blas.Trans, j, n-j-1, 142 -1, a[j+1:], lda, a[j:], lda, 143 1, a[j*lda+j+1:], 1) 144 bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1) 145 } 146 } 147 } else { 148 // Compute the Cholesky factorization Pᵀ * A * P = L * Lᵀ. 149 for j := 0; j < n; j++ { 150 // Update dot products and compute possible pivots which are stored 151 // in the second half of work. 152 for i := j; i < n; i++ { 153 if j > 0 { 154 tmp := a[i*lda+(j-1)] 155 dots[i] += tmp * tmp 156 } 157 work2[i] = a[i*lda+i] - dots[i] 158 } 159 if j > 0 { 160 // Find the pivot. 161 pvt = j 162 ajj = work2[pvt] 163 for k := j + 1; k < n; k++ { 164 wk := work2[k] 165 if wk > ajj { 166 pvt = k 167 ajj = wk 168 } 169 } 170 // Test for exit. 171 if ajj <= dstop || math.IsNaN(ajj) { 172 a[j*lda+j] = ajj 173 return j, false 174 } 175 } 176 if j != pvt { 177 // Swap pivot rows and columns. 178 a[pvt*lda+pvt] = a[j*lda+j] 179 bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1) 180 if pvt < n-1 { 181 bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda) 182 } 183 bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1) 184 // Swap dot products and piv. 185 dots[j], dots[pvt] = dots[pvt], dots[j] 186 piv[j], piv[pvt] = piv[pvt], piv[j] 187 } 188 ajj = math.Sqrt(ajj) 189 a[j*lda+j] = ajj 190 // Compute elements j+1:n of column j. 191 if j < n-1 { 192 bi.Dgemv(blas.NoTrans, n-j-1, j, 193 -1, a[(j+1)*lda:], lda, a[j*lda:], 1, 194 1, a[(j+1)*lda+j:], lda) 195 bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda) 196 } 197 } 198 } 199 return n, true 200 }