gonum.org/v1/gonum@v0.14.0/lapack/gonum/dpstrf.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  )
    13  
    14  // Dpstrf computes the Cholesky factorization with complete pivoting of an n×n
    15  // symmetric positive semidefinite matrix A.
    16  //
    17  // The factorization has the form
    18  //
    19  //	Pᵀ * A * P = Uᵀ * U ,  if uplo = blas.Upper,
    20  //	Pᵀ * A * P = L  * Lᵀ,  if uplo = blas.Lower,
    21  //
    22  // where U is an upper triangular matrix, L is lower triangular, and P is a
    23  // permutation matrix.
    24  //
    25  // tol is a user-defined tolerance; the algorithm terminates if the pivot is
    26  // less than or equal to it. If tol is negative, n*eps*max(A[k,k]) is used.
    27  //
    28  // On return, A contains the factor U or L from the Cholesky factorization and
    29  // piv contains P stored such that P[piv[k],k] = 1.
    30  //
    31  // Dpstrf returns the computed rank of A and whether the factorization can be
    32  // used to solve a system. Dpstrf does not attempt to check that A is positive
    33  // semidefinite, so if ok is false, the matrix A is either rank deficient or is
    34  // not positive semidefinite. If n is 0, Dpstrf returns (0, true).
    35  //
    36  // Dpstrf panics if uplo is neither blas.Upper nor blas.Lower, if n < 0, or if
    37  // lda < max(1,n). For n > 0 it also panics if len(a) < (n-1)*lda+n, if the
    38  // length of piv is not n, or if the length of work is less than 2*n.
    39  //
    40  // Dpstrf is an internal routine. It is exported for testing purposes.
    41  func (impl Implementation) Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) {
    42  	switch { // Validate the scalar arguments before looking at any slice.
    43  	case uplo != blas.Upper && uplo != blas.Lower:
    44  		panic(badUplo)
    45  	case n < 0:
    46  		panic(nLT0)
    47  	case lda < max(1, n):
    48  		panic(badLdA)
    49  	}
    50  
    51  	// Quick return if possible; the slice length checks below are skipped.
    52  	if n == 0 {
    53  		return 0, true
    54  	}
    55  
    56  	switch { // n > 0 here, so the largest index of a that is read is (n-1)*lda+n-1.
    57  	case len(a) < (n-1)*lda+n:
    58  		panic(shortA)
    59  	case len(piv) != n:
    60  		panic(badLenPiv)
    61  	case len(work) < 2*n:
    62  		panic(shortWork)
    63  	}
    64  
    65  	// Get block size from Ilaenv using the "DPOTRF" entry.
    66  	nb := impl.Ilaenv(1, "DPOTRF", string(uplo), n, -1, -1, -1)
    67  	if nb <= 1 || n <= nb {
    68  		// Use unblocked code: a single block would cover the whole matrix.
    69  		return impl.Dpstf2(uplo, n, a, lda, piv, tol, work)
    70  	}
    71  
    72  	// Initialize piv to the identity permutation.
    73  	for i := range piv[:n] {
    74  		piv[i] = i
    75  	}
    76  
    77  	// Compute the first pivot: the largest diagonal element, first index wins ties.
    78  	pvt := 0
    79  	ajj := a[0]
    80  	for i := 1; i < n; i++ {
    81  		aii := a[i*lda+i]
    82  		if aii > ajj {
    83  			pvt = i
    84  			ajj = aii
    85  		}
    86  	}
    87  	if ajj <= 0 || math.IsNaN(ajj) { // No usable pivot; a has not been modified.
    88  		return 0, false
    89  	}
    90  
    91  	// Compute stopping value from the largest diagonal element if tol is negative.
    92  	dstop := tol
    93  	if dstop < 0 {
    94  		dstop = float64(n) * dlamchE * ajj
    95  	}
    96  
    97  	bi := blas64.Implementation()
    98  	// Split work in half: dots holds dot products, work2 holds candidate pivots.
    99  	dots := work[:n]
   100  	work2 := work[n : 2*n]
   101  	if uplo == blas.Upper {
   102  		// Compute the Cholesky factorization  Pᵀ * A * P = Uᵀ * U.
   103  		for k := 0; k < n; k += nb { // k is the first row of the current block.
   104  			// Account for last block not being nb wide.
   105  			jb := min(nb, n-k)
   106  			// Reset the dot products; they only accumulate rows of the current block.
   107  			for i := k; i < n; i++ {
   108  				dots[i] = 0
   109  			}
   110  			for j := k; j < k+jb; j++ {
   111  				// Add the square of a[j-1,i] to dots[i] and store the candidate
   112  				// pivot a[i,i]-dots[i] in work2[i], the second half of work.
   113  				for i := j; i < n; i++ {
   114  					if j > k { // Row j-1 belongs to the current block only if j > k.
   115  						tmp := a[(j-1)*lda+i]
   116  						dots[i] += tmp * tmp
   117  					}
   118  					work2[i] = a[i*lda+i] - dots[i]
   119  				}
   120  				if j > 0 { // For j == 0 the pivot computed before the loop is used untested against dstop.
   121  					// Find the pivot: largest of work2[j:n], first index wins ties.
   122  					pvt = j
   123  					ajj = work2[pvt]
   124  					for l := j + 1; l < n; l++ {
   125  						wl := work2[l]
   126  						if wl > ajj {
   127  							pvt = l
   128  							ajj = wl
   129  						}
   130  					}
   131  					// Test for exit: pivot too small or NaN, stop with rank j.
   132  					if ajj <= dstop || math.IsNaN(ajj) {
   133  						a[j*lda+j] = ajj // The rejected pivot value is left on the diagonal.
   134  						return j, false
   135  					}
   136  				}
   137  				if j != pvt {
   138  					// Swap pivot rows and columns j and pvt in the upper triangle.
   139  					a[pvt*lda+pvt] = a[j*lda+j] // a[j,j] is overwritten below, so one copy suffices.
   140  					bi.Dswap(j, a[j:], lda, a[pvt:], lda) // Rows 0:j of columns j and pvt.
   141  					if pvt < n-1 {
   142  						bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1) // Columns pvt+1:n of rows j and pvt.
   143  					}
   144  					bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda) // a[j,j+1:pvt] with a[j+1:pvt,pvt].
   145  					// Swap dot products and piv.
   146  					dots[j], dots[pvt] = dots[pvt], dots[j]
   147  					piv[j], piv[pvt] = piv[pvt], piv[j]
   148  				}
   149  				ajj = math.Sqrt(ajj) // Diagonal element of the factor.
   150  				a[j*lda+j] = ajj
   151  				// Compute elements j+1:n of row j from rows k:j of the block, then scale by 1/ajj.
   152  				if j < n-1 {
   153  					bi.Dgemv(blas.Trans, j-k, n-j-1,
   154  						-1, a[k*lda+j+1:], lda, a[k*lda+j:], lda,
   155  						1, a[j*lda+j+1:], 1)
   156  					bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
   157  				}
   158  			}
   159  			// Update trailing matrix using the jb rows of the block just factored.
   160  			if k+jb < n {
   161  				j := k + jb
   162  				bi.Dsyrk(blas.Upper, blas.Trans, n-j, jb,
   163  					-1, a[k*lda+j:], lda, 1, a[j*lda+j:], lda)
   164  			}
   165  		}
   166  	} else {
   167  		// Compute the Cholesky factorization  Pᵀ * A * P = L * Lᵀ.
   168  		for k := 0; k < n; k += nb { // k is the first column of the current block.
   169  			// Account for last block not being nb wide.
   170  			jb := min(nb, n-k)
   171  			// Reset the dot products; they only accumulate columns of the current block.
   172  			for i := k; i < n; i++ {
   173  				dots[i] = 0
   174  			}
   175  			for j := k; j < k+jb; j++ {
   176  				// Add the square of a[i,j-1] to dots[i] and store the candidate
   177  				// pivot a[i,i]-dots[i] in work2[i], the second half of work.
   178  				for i := j; i < n; i++ {
   179  					if j > k { // Column j-1 belongs to the current block only if j > k.
   180  						tmp := a[i*lda+(j-1)]
   181  						dots[i] += tmp * tmp
   182  					}
   183  					work2[i] = a[i*lda+i] - dots[i]
   184  				}
   185  				if j > 0 { // For j == 0 the pivot computed before the loop is used untested against dstop.
   186  					// Find the pivot: largest of work2[j:n], first index wins ties.
   187  					pvt = j
   188  					ajj = work2[pvt]
   189  					for l := j + 1; l < n; l++ {
   190  						wl := work2[l]
   191  						if wl > ajj {
   192  							pvt = l
   193  							ajj = wl
   194  						}
   195  					}
   196  					// Test for exit: pivot too small or NaN, stop with rank j.
   197  					if ajj <= dstop || math.IsNaN(ajj) {
   198  						a[j*lda+j] = ajj // The rejected pivot value is left on the diagonal.
   199  						return j, false
   200  					}
   201  				}
   202  				if j != pvt {
   203  					// Swap pivot rows and columns j and pvt in the lower triangle.
   204  					a[pvt*lda+pvt] = a[j*lda+j] // a[j,j] is overwritten below, so one copy suffices.
   205  					bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1) // Columns 0:j of rows j and pvt.
   206  					if pvt < n-1 {
   207  						bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda) // Rows pvt+1:n of columns j and pvt.
   208  					}
   209  					bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1) // a[j+1:pvt,j] with a[pvt,j+1:pvt].
   210  					// Swap dot products and piv.
   211  					dots[j], dots[pvt] = dots[pvt], dots[j]
   212  					piv[j], piv[pvt] = piv[pvt], piv[j]
   213  				}
   214  				ajj = math.Sqrt(ajj) // Diagonal element of the factor.
   215  				a[j*lda+j] = ajj
   216  				// Compute elements j+1:n of column j from columns k:j of the block, then scale by 1/ajj.
   217  				if j < n-1 {
   218  					bi.Dgemv(blas.NoTrans, n-j-1, j-k,
   219  						-1, a[(j+1)*lda+k:], lda, a[j*lda+k:], 1,
   220  						1, a[(j+1)*lda+j:], lda)
   221  					bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
   222  				}
   223  			}
   224  			// Update trailing matrix using the jb columns of the block just factored.
   225  			if k+jb < n {
   226  				j := k + jb
   227  				bi.Dsyrk(blas.Lower, blas.NoTrans, n-j, jb,
   228  					-1, a[j*lda+k:], lda, 1, a[j*lda+j:], lda)
   229  			}
   230  		}
   231  	}
   232  	return n, true // Every pivot exceeded the stopping value.
   233  }