gonum.org/v1/gonum@v0.14.0/lapack/gonum/dpstf2.go (about) 1 // Copyright ©2021 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 ) 13 14 // Dpstf2 computes the Cholesky factorization with complete pivoting of an n×n 15 // symmetric positive semidefinite matrix A. 16 // 17 // The factorization has the form 18 // 19 // Pᵀ * A * P = Uᵀ * U , if uplo = blas.Upper, 20 // Pᵀ * A * P = L * Lᵀ, if uplo = blas.Lower, 21 // 22 // where U is an upper triangular matrix, L is lower triangular, and P is a 23 // permutation matrix. 24 // 25 // tol is a user-defined tolerance. The algorithm terminates if the pivot is 26 // less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be 27 // used instead. 28 // 29 // On return, A contains the factor U or L from the Cholesky factorization and 30 // piv contains P stored such that P[piv[k],k] = 1. 31 // 32 // Dpstf2 returns the computed rank of A and whether the factorization can be 33 // used to solve a system. Dpstf2 does not attempt to check that A is positive 34 // semi-definite, so if ok is false, the matrix A is either rank deficient or is 35 // not positive semidefinite. 36 // 37 // The length of piv must be n and the length of work must be at least 2*n, 38 // otherwise Dpstf2 will panic. 39 // 40 // Dpstf2 is an internal routine. It is exported for testing purposes. 41 func (Implementation) Dpstf2(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) { 42 switch { 43 case uplo != blas.Upper && uplo != blas.Lower: 44 panic(badUplo) 45 case n < 0: 46 panic(nLT0) 47 case lda < max(1, n): 48 panic(badLdA) 49 } 50 51 // Quick return if possible. 52 if n == 0 { 53 return 0, true 54 } 55 56 switch { 57 case len(a) < (n-1)*lda+n: 58 panic(shortA) 59 case len(piv) != n: 60 panic(badLenPiv) 61 case len(work) < 2*n: 62 panic(shortWork) 63 } 64 65 // Initialize piv. 66 for i := range piv[:n] { 67 piv[i] = i 68 } 69 70 // Compute the first pivot. 71 pvt := 0 72 ajj := a[0] 73 for i := 1; i < n; i++ { 74 aii := a[i*lda+i] 75 if aii > ajj { 76 pvt = i 77 ajj = aii 78 } 79 } 80 if ajj <= 0 || math.IsNaN(ajj) { 81 return 0, false 82 } 83 84 // Compute stopping value if not supplied. 85 dstop := tol 86 if dstop < 0 { 87 dstop = float64(n) * dlamchE * ajj 88 } 89 90 // Set first half of work to zero, holds dot products. 91 dots := work[:n] 92 for i := range dots { 93 dots[i] = 0 94 } 95 work2 := work[n : 2*n] 96 97 bi := blas64.Implementation() 98 if uplo == blas.Upper { 99 // Compute the Cholesky factorization Pᵀ * A * P = Uᵀ * U. 100 for j := 0; j < n; j++ { 101 // Update dot products and compute possible pivots which are stored 102 // in the second half of work. 103 for i := j; i < n; i++ { 104 if j > 0 { 105 tmp := a[(j-1)*lda+i] 106 dots[i] += tmp * tmp 107 } 108 work2[i] = a[i*lda+i] - dots[i] 109 } 110 if j > 0 { 111 // Find the pivot. 112 pvt = j 113 ajj = work2[pvt] 114 for k := j + 1; k < n; k++ { 115 wk := work2[k] 116 if wk > ajj { 117 pvt = k 118 ajj = wk 119 } 120 } 121 // Test for exit. 122 if ajj <= dstop || math.IsNaN(ajj) { 123 a[j*lda+j] = ajj 124 return j, false 125 } 126 } 127 if j != pvt { 128 // Swap pivot rows and columns. 129 a[pvt*lda+pvt] = a[j*lda+j] 130 bi.Dswap(j, a[j:], lda, a[pvt:], lda) 131 if pvt < n-1 { 132 bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1) 133 } 134 bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda) 135 // Swap dot products and piv. 136 dots[j], dots[pvt] = dots[pvt], dots[j] 137 piv[j], piv[pvt] = piv[pvt], piv[j] 138 } 139 ajj = math.Sqrt(ajj) 140 a[j*lda+j] = ajj 141 // Compute elements j+1:n of row j. 142 if j < n-1 { 143 bi.Dgemv(blas.Trans, j, n-j-1, 144 -1, a[j+1:], lda, a[j:], lda, 145 1, a[j*lda+j+1:], 1) 146 bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1) 147 } 148 } 149 } else { 150 // Compute the Cholesky factorization Pᵀ * A * P = L * Lᵀ. 151 for j := 0; j < n; j++ { 152 // Update dot products and compute possible pivots which are stored 153 // in the second half of work. 154 for i := j; i < n; i++ { 155 if j > 0 { 156 tmp := a[i*lda+(j-1)] 157 dots[i] += tmp * tmp 158 } 159 work2[i] = a[i*lda+i] - dots[i] 160 } 161 if j > 0 { 162 // Find the pivot. 163 pvt = j 164 ajj = work2[pvt] 165 for k := j + 1; k < n; k++ { 166 wk := work2[k] 167 if wk > ajj { 168 pvt = k 169 ajj = wk 170 } 171 } 172 // Test for exit. 173 if ajj <= dstop || math.IsNaN(ajj) { 174 a[j*lda+j] = ajj 175 return j, false 176 } 177 } 178 if j != pvt { 179 // Swap pivot rows and columns. 180 a[pvt*lda+pvt] = a[j*lda+j] 181 bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1) 182 if pvt < n-1 { 183 bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda) 184 } 185 bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1) 186 // Swap dot products and piv. 187 dots[j], dots[pvt] = dots[pvt], dots[j] 188 piv[j], piv[pvt] = piv[pvt], piv[j] 189 } 190 ajj = math.Sqrt(ajj) 191 a[j*lda+j] = ajj 192 // Compute elements j+1:n of column j. 193 if j < n-1 { 194 bi.Dgemv(blas.NoTrans, n-j-1, j, 195 -1, a[(j+1)*lda:], lda, a[j*lda:], 1, 196 1, a[(j+1)*lda+j:], lda) 197 bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda) 198 } 199 } 200 } 201 return n, true 202 }