github.com/gopherd/gonum@v0.0.4/lapack/gonum/dgtsv.go (about) 1 // Copyright ©2020 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import "math" 8 9 // Dgtsv solves the equation 10 // A * X = B 11 // where A is an n×n tridiagonal matrix. It uses Gaussian elimination with 12 // partial pivoting. The equation Aᵀ * X = B may be solved by swapping the 13 // arguments for du and dl. 14 // 15 // On entry, dl, d and du contain the sub-diagonal, the diagonal and the 16 // super-diagonal, respectively, of A. On return, the first n-2 elements of dl, 17 // the first n-1 elements of du and the first n elements of d may be 18 // overwritten. 19 // 20 // On entry, b contains the n×nrhs right-hand side matrix B. On return, b will 21 // be overwritten. If ok is true, it will be overwritten by the solution matrix X. 22 // 23 // Dgtsv returns whether the solution X has been successfuly computed. 24 func (impl Implementation) Dgtsv(n, nrhs int, dl, d, du []float64, b []float64, ldb int) (ok bool) { 25 switch { 26 case n < 0: 27 panic(nLT0) 28 case nrhs < 0: 29 panic(nrhsLT0) 30 case ldb < max(1, nrhs): 31 panic(badLdB) 32 } 33 34 if n == 0 || nrhs == 0 { 35 return true 36 } 37 38 switch { 39 case len(dl) < n-1: 40 panic(shortDL) 41 case len(d) < n: 42 panic(shortD) 43 case len(du) < n-1: 44 panic(shortDU) 45 case len(b) < (n-1)*ldb+nrhs: 46 panic(shortB) 47 } 48 49 dl = dl[:n-1] 50 d = d[:n] 51 du = du[:n-1] 52 53 for i := 0; i < n-1; i++ { 54 if math.Abs(d[i]) >= math.Abs(dl[i]) { 55 // No row interchange required. 56 if d[i] == 0 { 57 return false 58 } 59 fact := dl[i] / d[i] 60 d[i+1] -= fact * du[i] 61 for j := 0; j < nrhs; j++ { 62 b[(i+1)*ldb+j] -= fact * b[i*ldb+j] 63 } 64 dl[i] = 0 65 } else { 66 // Interchange rows i and i+1. 67 fact := d[i] / dl[i] 68 d[i] = dl[i] 69 tmp := d[i+1] 70 d[i+1] = du[i] - fact*tmp 71 du[i] = tmp 72 if i+1 < n-1 { 73 dl[i] = du[i+1] 74 du[i+1] = -fact * dl[i] 75 } 76 for j := 0; j < nrhs; j++ { 77 tmp = b[i*ldb+j] 78 b[i*ldb+j] = b[(i+1)*ldb+j] 79 b[(i+1)*ldb+j] = tmp - fact*b[(i+1)*ldb+j] 80 } 81 } 82 } 83 if d[n-1] == 0 { 84 return false 85 } 86 87 // Back solve with the matrix U from the factorization. 88 for j := 0; j < nrhs; j++ { 89 b[(n-1)*ldb+j] /= d[n-1] 90 if n > 1 { 91 b[(n-2)*ldb+j] = (b[(n-2)*ldb+j] - du[n-2]*b[(n-1)*ldb+j]) / d[n-2] 92 } 93 for i := n - 3; i >= 0; i-- { 94 b[i*ldb+j] = (b[i*ldb+j] - du[i]*b[(i+1)*ldb+j] - dl[i]*b[(i+2)*ldb+j]) / d[i] 95 } 96 } 97 98 return true 99 }