gonum.org/v1/gonum@v0.14.0/lapack/gonum/dgtsv.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import "math"
     8  
     9  // Dgtsv solves the equation
    10  //
    11  //	A * X = B
    12  //
    13  // where A is an n×n tridiagonal matrix. It uses Gaussian elimination with
    14  // partial pivoting. The equation Aᵀ * X = B may be solved by swapping the
    15  // arguments for du and dl.
    16  //
    17  // On entry, dl, d and du contain the sub-diagonal, the diagonal and the
    18  // super-diagonal, respectively, of A. On return, the first n-2 elements of dl,
    19  // the first n-1 elements of du and the first n elements of d may be
    20  // overwritten.
    21  //
    22  // On entry, b contains the n×nrhs right-hand side matrix B. On return, b will
    23  // be overwritten. If ok is true, it will be overwritten by the solution matrix X.
    24  //
    25  // Dgtsv returns whether the solution X has been successfully computed.
    26  func (impl Implementation) Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) {
    27  	switch {
    28  	case n < 0:
    29  		panic(nLT0)
    30  	case nrhs < 0:
    31  		panic(nrhsLT0)
    32  	case ldb < max(1, nrhs):
    33  		panic(badLdB)
    34  	}
    35  
    36  	if n == 0 || nrhs == 0 {
    37  		return true
    38  	}
    39  
    40  	switch {
    41  	case len(dl) < n-1:
    42  		panic(shortDL)
    43  	case len(d) < n:
    44  		panic(shortD)
    45  	case len(du) < n-1:
    46  		panic(shortDU)
    47  	case len(b) < (n-1)*ldb+nrhs:
    48  		panic(shortB)
    49  	}
    50  
    51  	dl = dl[:n-1]
    52  	d = d[:n]
    53  	du = du[:n-1]
    54  
    55  	for i := 0; i < n-1; i++ {
    56  		if math.Abs(d[i]) >= math.Abs(dl[i]) {
    57  			// No row interchange required.
    58  			if d[i] == 0 {
    59  				return false
    60  			}
    61  			fact := dl[i] / d[i]
    62  			d[i+1] -= fact * du[i]
    63  			for j := 0; j < nrhs; j++ {
    64  				b[(i+1)*ldb+j] -= fact * b[i*ldb+j]
    65  			}
    66  			dl[i] = 0
    67  		} else {
    68  			// Interchange rows i and i+1.
    69  			fact := d[i] / dl[i]
    70  			d[i] = dl[i]
    71  			tmp := d[i+1]
    72  			d[i+1] = du[i] - fact*tmp
    73  			du[i] = tmp
    74  			if i+1 < n-1 {
    75  				dl[i] = du[i+1]
    76  				du[i+1] = -fact * dl[i]
    77  			}
    78  			for j := 0; j < nrhs; j++ {
    79  				tmp = b[i*ldb+j]
    80  				b[i*ldb+j] = b[(i+1)*ldb+j]
    81  				b[(i+1)*ldb+j] = tmp - fact*b[(i+1)*ldb+j]
    82  			}
    83  		}
    84  	}
    85  	if d[n-1] == 0 {
    86  		return false
    87  	}
    88  
    89  	// Back solve with the matrix U from the factorization.
    90  	for j := 0; j < nrhs; j++ {
    91  		b[(n-1)*ldb+j] /= d[n-1]
    92  		if n > 1 {
    93  			b[(n-2)*ldb+j] = (b[(n-2)*ldb+j] - du[n-2]*b[(n-1)*ldb+j]) / d[n-2]
    94  		}
    95  		for i := n - 3; i >= 0; i-- {
    96  			b[i*ldb+j] = (b[i*ldb+j] - du[i]*b[(i+1)*ldb+j] - dl[i]*b[(i+2)*ldb+j]) / d[i]
    97  		}
    98  	}
    99  
   100  	return true
   101  }