github.com/gopherd/gonum@v0.0.4/lapack/gonum/dpstrf.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  )
    13  
    14  // Dpstrf computes the Cholesky factorization with complete pivoting of an n×n
    15  // symmetric positive semidefinite matrix A.
    16  //
    17  // The factorization has the form
    18  //  Pᵀ * A * P = Uᵀ * U ,  if uplo = blas.Upper,
    19  //  Pᵀ * A * P = L  * Lᵀ,  if uplo = blas.Lower,
    20  // where U is an upper triangular matrix, L is lower triangular, and P is a
    21  // permutation matrix.
    22  //
    23  // tol is a user-defined tolerance. The algorithm terminates if the pivot is
    24  // less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be
    25  // used instead.
    26  //
    27  // On return, A contains the factor U or L from the Cholesky factorization and
    28  // piv contains P stored such that P[piv[k],k] = 1.
    29  //
    30  // Dpstrf returns the computed rank of A and whether the factorization can be
    31  // used to solve a system. Dpstrf does not attempt to check that A is positive
    32  // semi-definite, so if ok is false, the matrix A is either rank deficient or is
    33  // not positive semidefinite.
    34  //
    35  // The length of piv must be n and the length of work must be at least 2*n,
    36  // otherwise Dpstrf will panic.
    37  //
    38  // Dpstrf is an internal routine. It is exported for testing purposes.
    39  func (impl Implementation) Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) {
    40  	switch {
    41  	case uplo != blas.Upper && uplo != blas.Lower:
    42  		panic(badUplo)
    43  	case n < 0:
    44  		panic(nLT0)
    45  	case lda < max(1, n):
    46  		panic(badLdA)
    47  	}
    48  
    49  	// Quick return if possible.
    50  	if n == 0 {
    51  		return 0, true
    52  	}
    53  
    54  	switch {
    55  	case len(a) < (n-1)*lda+n:
    56  		panic(shortA)
    57  	case len(piv) != n:
    58  		panic(badLenPiv)
    59  	case len(work) < 2*n:
    60  		panic(shortWork)
    61  	}
    62  
    63  	// Get block size.
    64  	nb := impl.Ilaenv(1, "DPOTRF", string(uplo), n, -1, -1, -1)
    65  	if nb <= 1 || n <= nb {
    66  		// Use unblocked code.
    67  		return impl.Dpstf2(uplo, n, a, lda, piv, tol, work)
    68  	}
    69  
    70  	// Initialize piv.
    71  	for i := range piv[:n] {
    72  		piv[i] = i
    73  	}
    74  
    75  	// Compute the first pivot.
    76  	pvt := 0
    77  	ajj := a[0]
    78  	for i := 1; i < n; i++ {
    79  		aii := a[i*lda+i]
    80  		if aii > ajj {
    81  			pvt = i
    82  			ajj = aii
    83  		}
    84  	}
    85  	if ajj <= 0 || math.IsNaN(ajj) {
    86  		return 0, false
    87  	}
    88  
    89  	// Compute stopping value if not supplied.
    90  	dstop := tol
    91  	if dstop < 0 {
    92  		dstop = float64(n) * dlamchE * ajj
    93  	}
    94  
    95  	bi := blas64.Implementation()
    96  	// Split work in half, the first half holds dot products.
    97  	dots := work[:n]
    98  	work2 := work[n : 2*n]
    99  	if uplo == blas.Upper {
   100  		// Compute the Cholesky factorization  Pᵀ * A * P = Uᵀ * U.
   101  		for k := 0; k < n; k += nb {
   102  			// Account for last block not being nb wide.
   103  			jb := min(nb, n-k)
   104  			// Set relevant part of dot products to zero.
   105  			for i := k; i < n; i++ {
   106  				dots[i] = 0
   107  			}
   108  			for j := k; j < k+jb; j++ {
   109  				// Update dot products and compute possible pivots which are stored
   110  				// in the second half of work.
   111  				for i := j; i < n; i++ {
   112  					if j > k {
   113  						tmp := a[(j-1)*lda+i]
   114  						dots[i] += tmp * tmp
   115  					}
   116  					work2[i] = a[i*lda+i] - dots[i]
   117  				}
   118  				if j > 0 {
   119  					// Find the pivot.
   120  					pvt = j
   121  					ajj = work2[pvt]
   122  					for l := j + 1; l < n; l++ {
   123  						wl := work2[l]
   124  						if wl > ajj {
   125  							pvt = l
   126  							ajj = wl
   127  						}
   128  					}
   129  					// Test for exit.
   130  					if ajj <= dstop || math.IsNaN(ajj) {
   131  						a[j*lda+j] = ajj
   132  						return j, false
   133  					}
   134  				}
   135  				if j != pvt {
   136  					// Swap pivot rows and columns.
   137  					a[pvt*lda+pvt] = a[j*lda+j]
   138  					bi.Dswap(j, a[j:], lda, a[pvt:], lda)
   139  					if pvt < n-1 {
   140  						bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1)
   141  					}
   142  					bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda)
   143  					// Swap dot products and piv.
   144  					dots[j], dots[pvt] = dots[pvt], dots[j]
   145  					piv[j], piv[pvt] = piv[pvt], piv[j]
   146  				}
   147  				ajj = math.Sqrt(ajj)
   148  				a[j*lda+j] = ajj
   149  				// Compute elements j+1:n of row j.
   150  				if j < n-1 {
   151  					bi.Dgemv(blas.Trans, j-k, n-j-1,
   152  						-1, a[k*lda+j+1:], lda, a[k*lda+j:], lda,
   153  						1, a[j*lda+j+1:], 1)
   154  					bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
   155  				}
   156  			}
   157  			// Update trailing matrix.
   158  			if k+jb < n {
   159  				j := k + jb
   160  				bi.Dsyrk(blas.Upper, blas.Trans, n-j, jb,
   161  					-1, a[k*lda+j:], lda, 1, a[j*lda+j:], lda)
   162  			}
   163  		}
   164  	} else {
   165  		// Compute the Cholesky factorization  Pᵀ * A * P = L * Lᵀ.
   166  		for k := 0; k < n; k += nb {
   167  			// Account for last block not being nb wide.
   168  			jb := min(nb, n-k)
   169  			// Set relevant part of dot products to zero.
   170  			for i := k; i < n; i++ {
   171  				dots[i] = 0
   172  			}
   173  			for j := k; j < k+jb; j++ {
   174  				// Update dot products and compute possible pivots which are stored
   175  				// in the second half of work.
   176  				for i := j; i < n; i++ {
   177  					if j > k {
   178  						tmp := a[i*lda+(j-1)]
   179  						dots[i] += tmp * tmp
   180  					}
   181  					work2[i] = a[i*lda+i] - dots[i]
   182  				}
   183  				if j > 0 {
   184  					// Find the pivot.
   185  					pvt = j
   186  					ajj = work2[pvt]
   187  					for l := j + 1; l < n; l++ {
   188  						wl := work2[l]
   189  						if wl > ajj {
   190  							pvt = l
   191  							ajj = wl
   192  						}
   193  					}
   194  					// Test for exit.
   195  					if ajj <= dstop || math.IsNaN(ajj) {
   196  						a[j*lda+j] = ajj
   197  						return j, false
   198  					}
   199  				}
   200  				if j != pvt {
   201  					// Swap pivot rows and columns.
   202  					a[pvt*lda+pvt] = a[j*lda+j]
   203  					bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1)
   204  					if pvt < n-1 {
   205  						bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda)
   206  					}
   207  					bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1)
   208  					// Swap dot products and piv.
   209  					dots[j], dots[pvt] = dots[pvt], dots[j]
   210  					piv[j], piv[pvt] = piv[pvt], piv[j]
   211  				}
   212  				ajj = math.Sqrt(ajj)
   213  				a[j*lda+j] = ajj
   214  				// Compute elements j+1:n of column j.
   215  				if j < n-1 {
   216  					bi.Dgemv(blas.NoTrans, n-j-1, j-k,
   217  						-1, a[(j+1)*lda+k:], lda, a[j*lda+k:], 1,
   218  						1, a[(j+1)*lda+j:], lda)
   219  					bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
   220  				}
   221  			}
   222  			// Update trailing matrix.
   223  			if k+jb < n {
   224  				j := k + jb
   225  				bi.Dsyrk(blas.Lower, blas.NoTrans, n-j, jb,
   226  					-1, a[j*lda+k:], lda, 1, a[j*lda+j:], lda)
   227  			}
   228  		}
   229  	}
   230  	return n, true
   231  }