gonum.org/v1/gonum@v0.14.0/lapack/gonum/dgetrs.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/blas64"
    10  )
    11  
    12  // Dgetrs solves a system of equations using an LU factorization.
    13  // The system of equations solved is
    14  //
    15  //	A * X = B  if trans == blas.Trans
    16  //	Aᵀ * X = B if trans == blas.NoTrans
    17  //
    18  // A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs.
    19  //
    20  // On entry b contains the elements of the matrix B. On exit, b contains the
    21  // elements of X, the solution to the system of equations.
    22  //
    23  // a and ipiv contain the LU factorization of A and the permutation indices as
    24  // computed by Dgetrf. ipiv is zero-indexed.
    25  func (impl Implementation) Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) {
    26  	switch {
    27  	case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
    28  		panic(badTrans)
    29  	case n < 0:
    30  		panic(nLT0)
    31  	case nrhs < 0:
    32  		panic(nrhsLT0)
    33  	case lda < max(1, n):
    34  		panic(badLdA)
    35  	case ldb < max(1, nrhs):
    36  		panic(badLdB)
    37  	}
    38  
    39  	// Quick return if possible.
    40  	if n == 0 || nrhs == 0 {
    41  		return
    42  	}
    43  
    44  	switch {
    45  	case len(a) < (n-1)*lda+n:
    46  		panic(shortA)
    47  	case len(b) < (n-1)*ldb+nrhs:
    48  		panic(shortB)
    49  	case len(ipiv) != n:
    50  		panic(badLenIpiv)
    51  	}
    52  
    53  	bi := blas64.Implementation()
    54  
    55  	if trans == blas.NoTrans {
    56  		// Solve A * X = B.
    57  		impl.Dlaswp(nrhs, b, ldb, 0, n-1, ipiv, 1)
    58  		// Solve L * X = B, updating b.
    59  		bi.Dtrsm(blas.Left, blas.Lower, blas.NoTrans, blas.Unit,
    60  			n, nrhs, 1, a, lda, b, ldb)
    61  		// Solve U * X = B, updating b.
    62  		bi.Dtrsm(blas.Left, blas.Upper, blas.NoTrans, blas.NonUnit,
    63  			n, nrhs, 1, a, lda, b, ldb)
    64  		return
    65  	}
    66  	// Solve Aᵀ * X = B.
    67  	// Solve Uᵀ * X = B, updating b.
    68  	bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit,
    69  		n, nrhs, 1, a, lda, b, ldb)
    70  	// Solve Lᵀ * X = B, updating b.
    71  	bi.Dtrsm(blas.Left, blas.Lower, blas.Trans, blas.Unit,
    72  		n, nrhs, 1, a, lda, b, ldb)
    73  	impl.Dlaswp(nrhs, b, ldb, 0, n-1, ipiv, -1)
    74  }