github.com/gopherd/gonum@v0.0.4/lapack/gonum/dgesc2.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas/blas64"
    11  )
    12  
    13  // Dgesc2 solves a system of linear equations
    14  //  A * x = scale * b
    15  // with a general n×n matrix A represented by the LU factorization with complete
    16  // pivoting
    17  //  A = P * L * U * Q
    18  // as computed by Dgetc2.
    19  //
    20  // On entry, rhs contains the right hand side vector b. On return, it is
    21  // overwritten with the solution vector x.
    22  //
    23  // Dgesc2 returns a scale factor
    24  //  0 <= scale <= 1
    25  // chosen to prevent overflow in the solution.
    26  //
    27  // Dgesc2 is an internal routine. It is exported for testing purposes.
    28  func (impl Implementation) Dgesc2(n int, a []float64, lda int, rhs []float64, ipiv, jpiv []int) (scale float64) {
    29  	switch {
    30  	case n < 0:
    31  		panic(nLT0)
    32  	case lda < max(1, n):
    33  		panic(badLdA)
    34  	}
    35  
    36  	// Quick return if possible.
    37  	if n == 0 {
    38  		return 0
    39  	}
    40  
    41  	switch {
    42  	case len(a) < (n-1)*lda+n:
    43  		panic(shortA)
    44  	case len(rhs) < n:
    45  		panic(shortRHS)
    46  	case len(ipiv) != n:
    47  		panic(badLenIpiv)
    48  	case len(jpiv) != n:
    49  		panic(badLenJpiv)
    50  	}
    51  
    52  	const smlnum = dlamchS / dlamchP
    53  
    54  	// Apply permutations ipiv to rhs.
    55  	impl.Dlaswp(1, rhs, 1, 0, n-1, ipiv[:n], 1)
    56  
    57  	// Solve for L part.
    58  	for i := 0; i < n-1; i++ {
    59  		for j := i + 1; j < n; j++ {
    60  			rhs[j] -= float64(a[j*lda+i] * rhs[i])
    61  		}
    62  	}
    63  
    64  	// Check for scaling.
    65  	scale = 1.0
    66  	bi := blas64.Implementation()
    67  	i := bi.Idamax(n, rhs, 1)
    68  	if 2*smlnum*math.Abs(rhs[i]) > math.Abs(a[(n-1)*lda+(n-1)]) {
    69  		temp := 0.5 / math.Abs(rhs[i])
    70  		bi.Dscal(n, temp, rhs, 1)
    71  		scale *= temp
    72  	}
    73  
    74  	// Solve for U part.
    75  	for i := n - 1; i >= 0; i-- {
    76  		temp := 1.0 / a[i*lda+i]
    77  		rhs[i] *= temp
    78  		for j := i + 1; j < n; j++ {
    79  			rhs[i] -= float64(rhs[j] * (a[i*lda+j] * temp))
    80  		}
    81  	}
    82  
    83  	// Apply permutations jpiv to the solution (rhs).
    84  	impl.Dlaswp(1, rhs, 1, 0, n-1, jpiv[:n], -1)
    85  
    86  	return scale
    87  }