gonum.org/v1/gonum@v0.14.0/lapack/gonum/dpstf2.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  )
    13  
    14  // Dpstf2 computes the Cholesky factorization with complete pivoting of an n×n
    15  // symmetric positive semidefinite matrix A.
    16  //
    17  // The factorization has the form
    18  //
    19  //	Pᵀ * A * P = Uᵀ * U ,  if uplo = blas.Upper,
    20  //	Pᵀ * A * P = L  * Lᵀ,  if uplo = blas.Lower,
    21  //
    22  // where U is an upper triangular matrix, L is lower triangular, and P is a
    23  // permutation matrix.
    24  //
    25  // tol is a user-defined tolerance. The algorithm terminates if the pivot is
    26  // less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be
    27  // used instead.
    28  //
    29  // On return, A contains the factor U or L from the Cholesky factorization and
    30  // piv contains P stored such that P[piv[k],k] = 1.
    31  //
    32  // Dpstf2 returns the computed rank of A and whether the factorization can be
    33  // used to solve a system. Dpstf2 does not attempt to check that A is positive
    34  // semi-definite, so if ok is false, the matrix A is either rank deficient or is
    35  // not positive semidefinite.
    36  //
    37  // The length of piv must be n and the length of work must be at least 2*n,
    38  // otherwise Dpstf2 will panic.
    39  //
    40  // Dpstf2 is an internal routine. It is exported for testing purposes.
    41  func (Implementation) Dpstf2(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) {
    42  	switch {
    43  	case uplo != blas.Upper && uplo != blas.Lower:
    44  		panic(badUplo)
    45  	case n < 0:
    46  		panic(nLT0)
    47  	case lda < max(1, n):
    48  		panic(badLdA)
    49  	}
    50  
    51  	// Quick return if possible.
    52  	if n == 0 {
    53  		return 0, true
    54  	}
    55  
    56  	switch {
    57  	case len(a) < (n-1)*lda+n:
    58  		panic(shortA)
    59  	case len(piv) != n:
    60  		panic(badLenPiv)
    61  	case len(work) < 2*n:
    62  		panic(shortWork)
    63  	}
    64  
    65  	// Initialize piv.
    66  	for i := range piv[:n] {
    67  		piv[i] = i
    68  	}
    69  
    70  	// Compute the first pivot.
    71  	pvt := 0
    72  	ajj := a[0]
    73  	for i := 1; i < n; i++ {
    74  		aii := a[i*lda+i]
    75  		if aii > ajj {
    76  			pvt = i
    77  			ajj = aii
    78  		}
    79  	}
    80  	if ajj <= 0 || math.IsNaN(ajj) {
    81  		return 0, false
    82  	}
    83  
    84  	// Compute stopping value if not supplied.
    85  	dstop := tol
    86  	if dstop < 0 {
    87  		dstop = float64(n) * dlamchE * ajj
    88  	}
    89  
    90  	// Set first half of work to zero, holds dot products.
    91  	dots := work[:n]
    92  	for i := range dots {
    93  		dots[i] = 0
    94  	}
    95  	work2 := work[n : 2*n]
    96  
    97  	bi := blas64.Implementation()
    98  	if uplo == blas.Upper {
    99  		// Compute the Cholesky factorization  Pᵀ * A * P = Uᵀ * U.
   100  		for j := 0; j < n; j++ {
   101  			// Update dot products and compute possible pivots which are stored
   102  			// in the second half of work.
   103  			for i := j; i < n; i++ {
   104  				if j > 0 {
   105  					tmp := a[(j-1)*lda+i]
   106  					dots[i] += tmp * tmp
   107  				}
   108  				work2[i] = a[i*lda+i] - dots[i]
   109  			}
   110  			if j > 0 {
   111  				// Find the pivot.
   112  				pvt = j
   113  				ajj = work2[pvt]
   114  				for k := j + 1; k < n; k++ {
   115  					wk := work2[k]
   116  					if wk > ajj {
   117  						pvt = k
   118  						ajj = wk
   119  					}
   120  				}
   121  				// Test for exit.
   122  				if ajj <= dstop || math.IsNaN(ajj) {
   123  					a[j*lda+j] = ajj
   124  					return j, false
   125  				}
   126  			}
   127  			if j != pvt {
   128  				// Swap pivot rows and columns.
   129  				a[pvt*lda+pvt] = a[j*lda+j]
   130  				bi.Dswap(j, a[j:], lda, a[pvt:], lda)
   131  				if pvt < n-1 {
   132  					bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1)
   133  				}
   134  				bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda)
   135  				// Swap dot products and piv.
   136  				dots[j], dots[pvt] = dots[pvt], dots[j]
   137  				piv[j], piv[pvt] = piv[pvt], piv[j]
   138  			}
   139  			ajj = math.Sqrt(ajj)
   140  			a[j*lda+j] = ajj
   141  			// Compute elements j+1:n of row j.
   142  			if j < n-1 {
   143  				bi.Dgemv(blas.Trans, j, n-j-1,
   144  					-1, a[j+1:], lda, a[j:], lda,
   145  					1, a[j*lda+j+1:], 1)
   146  				bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
   147  			}
   148  		}
   149  	} else {
   150  		// Compute the Cholesky factorization  Pᵀ * A * P = L * Lᵀ.
   151  		for j := 0; j < n; j++ {
   152  			// Update dot products and compute possible pivots which are stored
   153  			// in the second half of work.
   154  			for i := j; i < n; i++ {
   155  				if j > 0 {
   156  					tmp := a[i*lda+(j-1)]
   157  					dots[i] += tmp * tmp
   158  				}
   159  				work2[i] = a[i*lda+i] - dots[i]
   160  			}
   161  			if j > 0 {
   162  				// Find the pivot.
   163  				pvt = j
   164  				ajj = work2[pvt]
   165  				for k := j + 1; k < n; k++ {
   166  					wk := work2[k]
   167  					if wk > ajj {
   168  						pvt = k
   169  						ajj = wk
   170  					}
   171  				}
   172  				// Test for exit.
   173  				if ajj <= dstop || math.IsNaN(ajj) {
   174  					a[j*lda+j] = ajj
   175  					return j, false
   176  				}
   177  			}
   178  			if j != pvt {
   179  				// Swap pivot rows and columns.
   180  				a[pvt*lda+pvt] = a[j*lda+j]
   181  				bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1)
   182  				if pvt < n-1 {
   183  					bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda)
   184  				}
   185  				bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1)
   186  				// Swap dot products and piv.
   187  				dots[j], dots[pvt] = dots[pvt], dots[j]
   188  				piv[j], piv[pvt] = piv[pvt], piv[j]
   189  			}
   190  			ajj = math.Sqrt(ajj)
   191  			a[j*lda+j] = ajj
   192  			// Compute elements j+1:n of column j.
   193  			if j < n-1 {
   194  				bi.Dgemv(blas.NoTrans, n-j-1, j,
   195  					-1, a[(j+1)*lda:], lda, a[j*lda:], 1,
   196  					1, a[(j+1)*lda+j:], lda)
   197  				bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
   198  			}
   199  		}
   200  	}
   201  	return n, true
   202  }