github.com/gopherd/gonum@v0.0.4/lapack/gonum/dpstf2.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  )
    13  
    14  // Dpstf2 computes the Cholesky factorization with complete pivoting of an n×n
    15  // symmetric positive semidefinite matrix A, with A[i,j] stored at a[i*lda+j].
    16  //
    17  // The factorization has the form
    18  //  Pᵀ * A * P = Uᵀ * U ,  if uplo = blas.Upper,
    19  //  Pᵀ * A * P = L  * Lᵀ,  if uplo = blas.Lower,
    20  // where U is an upper triangular matrix, L is lower triangular, and P is a
    21  // permutation matrix.
    22  //
    23  // tol is a user-defined tolerance. The algorithm terminates if a pivot after
    24  // the first is less than or equal to tol or is NaN. If tol is negative, then
    25  // n*eps*max(A[k,k]) will be used instead. The first pivot is only required to be positive.
    26  //
    27  // On return, A contains the factor U or L from the Cholesky factorization and
    28  // piv contains P stored such that P[piv[k],k] = 1.
    29  //
    30  // Dpstf2 returns the computed rank of A and whether the factorization can be
    31  // used to solve a system. Dpstf2 does not attempt to check that A is positive
    32  // semi-definite, so if ok is false, the matrix A is either rank deficient or is
    33  // not positive semidefinite. If n is 0, Dpstf2 returns (0, true).
    34  //
    35  // Dpstf2 panics if uplo is invalid, n < 0, lda < max(1,n), len(a) < (n-1)*lda+n,
    36  // len(piv) != n or len(work) < 2*n. The length checks are skipped when n is 0.
    37  //
    38  // Dpstf2 is an internal routine. It is exported for testing purposes.
    39  func (Implementation) Dpstf2(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) {
    40  	switch { // validate the scalar arguments before any slice is inspected
    41  	case uplo != blas.Upper && uplo != blas.Lower:
    42  		panic(badUplo)
    43  	case n < 0:
    44  		panic(nLT0)
    45  	case lda < max(1, n):
    46  		panic(badLdA)
    47  	}
    48  
    49  	// Quick return if possible; this happens before the slice lengths are checked.
    50  	if n == 0 {
    51  		return 0, true
    52  	}
    53  
    54  	switch { // slice lengths are only checked for n > 0
    55  	case len(a) < (n-1)*lda+n:
    56  		panic(shortA)
    57  	case len(piv) != n:
    58  		panic(badLenPiv)
    59  	case len(work) < 2*n:
    60  		panic(shortWork)
    61  	}
    62  
    63  	// Initialize piv to the identity permutation.
    64  	for i := range piv[:n] {
    65  		piv[i] = i
    66  	}
    67  
    68  	// Compute the first pivot: the largest diagonal element of A.
    69  	pvt := 0
    70  	ajj := a[0]
    71  	for i := 1; i < n; i++ {
    72  		aii := a[i*lda+i]
    73  		if aii > ajj { // strict comparison keeps the lowest index on ties; false whenever either side is NaN
    74  			pvt = i
    75  			ajj = aii
    76  		}
    77  	}
    78  	if ajj <= 0 || math.IsNaN(ajj) { // a has not been modified at this point
    79  		return 0, false
    80  	}
    81  
    82  	// Compute stopping value if not supplied (negative tol), scaled by the first pivot.
    83  	dstop := tol
    84  	if dstop < 0 {
    85  		dstop = float64(n) * dlamchE * ajj
    86  	}
    87  
    88  	// Set first half of work to zero, holds dot products (running sums of squared factor entries).
    89  	dots := work[:n]
    90  	for i := range dots {
    91  		dots[i] = 0
    92  	}
    93  	work2 := work[n : 2*n] // second half of work, holds the candidate pivots
    94  
    95  	bi := blas64.Implementation()
    96  	if uplo == blas.Upper {
    97  		// Compute the Cholesky factorization  Pᵀ * A * P = Uᵀ * U.
    98  		for j := 0; j < n; j++ {
    99  			// Update dot products with row j-1 of U and compute possible pivots
   100  			// a[i,i]-dots[i], which are stored in the second half of work.
   101  			for i := j; i < n; i++ {
   102  				if j > 0 {
   103  					tmp := a[(j-1)*lda+i]
   104  					dots[i] += tmp * tmp
   105  				}
   106  				work2[i] = a[i*lda+i] - dots[i]
   107  			}
   108  			if j > 0 { // for j == 0 the pivot found before the loop is used
   109  				// Find the pivot: the largest remaining candidate, lowest index on ties.
   110  				pvt = j
   111  				ajj = work2[pvt]
   112  				for k := j + 1; k < n; k++ {
   113  					wk := work2[k]
   114  					if wk > ajj {
   115  						pvt = k
   116  						ajj = wk
   117  					}
   118  				}
   119  				// Test for exit; the rejected pivot is stored on the diagonal and j is the rank.
   120  				if ajj <= dstop || math.IsNaN(ajj) {
   121  					a[j*lda+j] = ajj
   122  					return j, false
   123  				}
   124  			}
   125  			if j != pvt {
   126  				// Swap pivot rows and columns; pvt > j here, and only the upper triangle is addressed.
   127  				a[pvt*lda+pvt] = a[j*lda+j] // one-way copy: a[j,j] is overwritten with sqrt(ajj) below
   128  				bi.Dswap(j, a[j:], lda, a[pvt:], lda) // rows 0..j-1 of columns j and pvt
   129  				if pvt < n-1 { // no columns lie to the right of pvt otherwise
   130  					bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1) // columns pvt+1..n-1 of rows j and pvt
   131  				}
   132  				bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda) // row j, columns j+1..pvt-1 with column pvt, rows j+1..pvt-1
   133  				// Swap dot products and piv.
   134  				dots[j], dots[pvt] = dots[pvt], dots[j]
   135  				piv[j], piv[pvt] = piv[pvt], piv[j]
   136  			}
   137  			ajj = math.Sqrt(ajj)
   138  			a[j*lda+j] = ajj
   139  			// Compute elements j+1:n of row j: Dgemv uses the j×(n-j-1) block above them and column j, then Dscal applies 1/ajj.
   140  			if j < n-1 {
   141  				bi.Dgemv(blas.Trans, j, n-j-1,
   142  					-1, a[j+1:], lda, a[j:], lda,
   143  					1, a[j*lda+j+1:], 1)
   144  				bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
   145  			}
   146  		}
   147  	} else {
   148  		// Compute the Cholesky factorization  Pᵀ * A * P = L * Lᵀ.
   149  		for j := 0; j < n; j++ {
   150  			// Update dot products with column j-1 of L and compute possible pivots
   151  			// a[i,i]-dots[i], which are stored in the second half of work.
   152  			for i := j; i < n; i++ {
   153  				if j > 0 {
   154  					tmp := a[i*lda+(j-1)]
   155  					dots[i] += tmp * tmp
   156  				}
   157  				work2[i] = a[i*lda+i] - dots[i]
   158  			}
   159  			if j > 0 { // for j == 0 the pivot found before the loop is used
   160  				// Find the pivot: the largest remaining candidate, lowest index on ties.
   161  				pvt = j
   162  				ajj = work2[pvt]
   163  				for k := j + 1; k < n; k++ {
   164  					wk := work2[k]
   165  					if wk > ajj {
   166  						pvt = k
   167  						ajj = wk
   168  					}
   169  				}
   170  				// Test for exit; the rejected pivot is stored on the diagonal and j is the rank.
   171  				if ajj <= dstop || math.IsNaN(ajj) {
   172  					a[j*lda+j] = ajj
   173  					return j, false
   174  				}
   175  			}
   176  			if j != pvt {
   177  				// Swap pivot rows and columns; pvt > j here, and only the lower triangle is addressed.
   178  				a[pvt*lda+pvt] = a[j*lda+j] // one-way copy: a[j,j] is overwritten with sqrt(ajj) below
   179  				bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1) // columns 0..j-1 of rows j and pvt
   180  				if pvt < n-1 { // the slices below would start past the last row otherwise
   181  					bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda) // rows pvt+1..n-1 of columns j and pvt
   182  				}
   183  				bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1) // column j, rows j+1..pvt-1 with row pvt, columns j+1..pvt-1
   184  				// Swap dot products and piv.
   185  				dots[j], dots[pvt] = dots[pvt], dots[j]
   186  				piv[j], piv[pvt] = piv[pvt], piv[j]
   187  			}
   188  			ajj = math.Sqrt(ajj)
   189  			a[j*lda+j] = ajj
   190  			// Compute elements j+1:n of column j: Dgemv uses the (n-j-1)×j block left of them and row j, then Dscal applies 1/ajj.
   191  			if j < n-1 {
   192  				bi.Dgemv(blas.NoTrans, n-j-1, j,
   193  					-1, a[(j+1)*lda:], lda, a[j*lda:], 1,
   194  					1, a[(j+1)*lda+j:], lda)
   195  				bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
   196  			}
   197  		}
   198  	}
   199  	return n, true // every pivot passed the test: full rank
   200  }