github.com/gopherd/gonum@v0.0.4/lapack/gonum/dgtsv.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import "math"
     8  
     9  // Dgtsv solves the equation
    10  //  A * X = B
    11  // where A is an n×n tridiagonal matrix. It uses Gaussian elimination with
    12  // partial pivoting. The equation Aᵀ * X = B may be solved by swapping the
    13  // arguments for du and dl.
    14  //
    15  // On entry, dl, d and du contain the sub-diagonal, the diagonal and the
    16  // super-diagonal, respectively, of A. On return, the first n-2 elements of dl,
    17  // the first n-1 elements of du and the first n elements of d may be
    18  // overwritten.
    19  //
    20  // On entry, b contains the n×nrhs right-hand side matrix B. On return, b will
    21  // be overwritten. If ok is true, it will be overwritten by the solution matrix X.
    22  //
    23  // Dgtsv returns whether the solution X has been successfuly computed.
    24  func (impl Implementation) Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) {
    25  	switch {
    26  	case n < 0:
    27  		panic(nLT0)
    28  	case nrhs < 0:
    29  		panic(nrhsLT0)
    30  	case ldb < max(1, nrhs):
    31  		panic(badLdB)
    32  	}
    33  
    34  	if n == 0 || nrhs == 0 {
    35  		return true
    36  	}
    37  
    38  	switch {
    39  	case len(dl) < n-1:
    40  		panic(shortDL)
    41  	case len(d) < n:
    42  		panic(shortD)
    43  	case len(du) < n-1:
    44  		panic(shortDU)
    45  	case len(b) < (n-1)*ldb+nrhs:
    46  		panic(shortB)
    47  	}
    48  
    49  	dl = dl[:n-1]
    50  	d = d[:n]
    51  	du = du[:n-1]
    52  
    53  	for i := 0; i < n-1; i++ {
    54  		if math.Abs(d[i]) >= math.Abs(dl[i]) {
    55  			// No row interchange required.
    56  			if d[i] == 0 {
    57  				return false
    58  			}
    59  			fact := dl[i] / d[i]
    60  			d[i+1] -= fact * du[i]
    61  			for j := 0; j < nrhs; j++ {
    62  				b[(i+1)*ldb+j] -= fact * b[i*ldb+j]
    63  			}
    64  			dl[i] = 0
    65  		} else {
    66  			// Interchange rows i and i+1.
    67  			fact := d[i] / dl[i]
    68  			d[i] = dl[i]
    69  			tmp := d[i+1]
    70  			d[i+1] = du[i] - fact*tmp
    71  			du[i] = tmp
    72  			if i+1 < n-1 {
    73  				dl[i] = du[i+1]
    74  				du[i+1] = -fact * dl[i]
    75  			}
    76  			for j := 0; j < nrhs; j++ {
    77  				tmp = b[i*ldb+j]
    78  				b[i*ldb+j] = b[(i+1)*ldb+j]
    79  				b[(i+1)*ldb+j] = tmp - fact*b[(i+1)*ldb+j]
    80  			}
    81  		}
    82  	}
    83  	if d[n-1] == 0 {
    84  		return false
    85  	}
    86  
    87  	// Back solve with the matrix U from the factorization.
    88  	for j := 0; j < nrhs; j++ {
    89  		b[(n-1)*ldb+j] /= d[n-1]
    90  		if n > 1 {
    91  			b[(n-2)*ldb+j] = (b[(n-2)*ldb+j] - du[n-2]*b[(n-1)*ldb+j]) / d[n-2]
    92  		}
    93  		for i := n - 3; i >= 0; i-- {
    94  			b[i*ldb+j] = (b[i*ldb+j] - du[i]*b[(i+1)*ldb+j] - dl[i]*b[(i+2)*ldb+j]) / d[i]
    95  		}
    96  	}
    97  
    98  	return true
    99  }