gonum.org/v1/gonum@v0.14.0/lapack/gonum/dgtsv.go (about) 1 // Copyright ©2020 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import "math" 8 9 // Dgtsv solves the equation 10 // 11 // A * X = B 12 // 13 // where A is an n×n tridiagonal matrix. It uses Gaussian elimination with 14 // partial pivoting. The equation Aᵀ * X = B may be solved by swapping the 15 // arguments for du and dl. 16 // 17 // On entry, dl, d and du contain the sub-diagonal, the diagonal and the 18 // super-diagonal, respectively, of A. On return, the first n-2 elements of dl, 19 // the first n-1 elements of du and the first n elements of d may be 20 // overwritten. 21 // 22 // On entry, b contains the n×nrhs right-hand side matrix B. On return, b will 23 // be overwritten. If ok is true, it will be overwritten by the solution matrix X. 24 // 25 // Dgtsv returns whether the solution X has been successfully computed. 26 func (impl Implementation) Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) { 27 switch { 28 case n < 0: 29 panic(nLT0) 30 case nrhs < 0: 31 panic(nrhsLT0) 32 case ldb < max(1, nrhs): 33 panic(badLdB) 34 } 35 36 if n == 0 || nrhs == 0 { 37 return true 38 } 39 40 switch { 41 case len(dl) < n-1: 42 panic(shortDL) 43 case len(d) < n: 44 panic(shortD) 45 case len(du) < n-1: 46 panic(shortDU) 47 case len(b) < (n-1)*ldb+nrhs: 48 panic(shortB) 49 } 50 51 dl = dl[:n-1] 52 d = d[:n] 53 du = du[:n-1] 54 55 for i := 0; i < n-1; i++ { 56 if math.Abs(d[i]) >= math.Abs(dl[i]) { 57 // No row interchange required. 58 if d[i] == 0 { 59 return false 60 } 61 fact := dl[i] / d[i] 62 d[i+1] -= fact * du[i] 63 for j := 0; j < nrhs; j++ { 64 b[(i+1)*ldb+j] -= fact * b[i*ldb+j] 65 } 66 dl[i] = 0 67 } else { 68 // Interchange rows i and i+1. 69 fact := d[i] / dl[i] 70 d[i] = dl[i] 71 tmp := d[i+1] 72 d[i+1] = du[i] - fact*tmp 73 du[i] = tmp 74 if i+1 < n-1 { 75 dl[i] = du[i+1] 76 du[i+1] = -fact * dl[i] 77 } 78 for j := 0; j < nrhs; j++ { 79 tmp = b[i*ldb+j] 80 b[i*ldb+j] = b[(i+1)*ldb+j] 81 b[(i+1)*ldb+j] = tmp - fact*b[(i+1)*ldb+j] 82 } 83 } 84 } 85 if d[n-1] == 0 { 86 return false 87 } 88 89 // Back solve with the matrix U from the factorization. 90 for j := 0; j < nrhs; j++ { 91 b[(n-1)*ldb+j] /= d[n-1] 92 if n > 1 { 93 b[(n-2)*ldb+j] = (b[(n-2)*ldb+j] - du[n-2]*b[(n-1)*ldb+j]) / d[n-2] 94 } 95 for i := n - 3; i >= 0; i-- { 96 b[i*ldb+j] = (b[i*ldb+j] - du[i]*b[(i+1)*ldb+j] - dl[i]*b[(i+2)*ldb+j]) / d[i] 97 } 98 } 99 100 return true 101 }