github.com/gopherd/gonum@v0.0.4/mat/tridiag.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  	"github.com/gopherd/gonum/internal/asm/f64"
    13  	"github.com/gopherd/gonum/lapack/lapack64"
    14  )
    15  
    16  var (
    17  	tridiagDense *Tridiag
    18  	_            Matrix           = tridiagDense
    19  	_            allMatrix        = tridiagDense
    20  	_            denseMatrix      = tridiagDense
    21  	_            Banded           = tridiagDense
    22  	_            MutableBanded    = tridiagDense
    23  	_            RawTridiagonaler = tridiagDense
    24  )
    25  
    26  // A RawTridiagonaler can return a lapack64.Tridiagonal representation of the
    27  // receiver. Changes to the elements of DL, D, DU in lapack64.Tridiagonal will
    28  // be reflected in the original matrix, changes to the N field will not.
    29  type RawTridiagonaler interface {
    30  	RawTridiagonal() lapack64.Tridiagonal
    31  }
    32  
    33  // Tridiag represents a tridiagonal matrix by its three diagonals.
    34  type Tridiag struct {
    35  	mat lapack64.Tridiagonal
    36  }
    37  
    38  // NewTridiag creates a new n×n tridiagonal matrix with the first sub-diagonal
    39  // in dl, the main diagonal in d and the first super-diagonal in du. If all of
    40  // dl, d, and du are nil, new backing slices will be allocated for them. If dl
    41  // and du have length n-1 and d has length n, they will be used as backing
    42  // slices, and changes to the elements of the returned Tridiag will be reflected
    43  // in dl, d, du. If neither of these is true, NewTridiag will panic.
    44  func NewTridiag(n int, dl, d, du []float64) *Tridiag {
    45  	if n <= 0 {
    46  		if n == 0 {
    47  			panic(ErrZeroLength)
    48  		}
    49  		panic(ErrNegativeDimension)
    50  	}
    51  	if dl != nil || d != nil || du != nil {
    52  		if len(dl) != n-1 || len(d) != n || len(du) != n-1 {
    53  			panic(ErrShape)
    54  		}
    55  	} else {
    56  		d = make([]float64, n)
    57  		if n > 1 {
    58  			dl = make([]float64, n-1)
    59  			du = make([]float64, n-1)
    60  		}
    61  	}
    62  	return &Tridiag{
    63  		mat: lapack64.Tridiagonal{
    64  			N:  n,
    65  			DL: dl,
    66  			D:  d,
    67  			DU: du,
    68  		},
    69  	}
    70  }
    71  
    72  // Dims returns the number of rows and columns in the matrix.
    73  func (a *Tridiag) Dims() (r, c int) {
    74  	return a.mat.N, a.mat.N
    75  }
    76  
    77  // Bandwidth returns 1, 1 - the upper and lower bandwidths of the matrix.
    78  func (a *Tridiag) Bandwidth() (kl, ku int) {
    79  	return 1, 1
    80  }
    81  
    82  // T performs an implicit transpose by returning the receiver inside a Transpose.
    83  func (a *Tridiag) T() Matrix {
    84  	// An alternative would be to return the receiver with DL,DU swapped; the
    85  	// untranspose function would then always return false. With Transpose the
    86  	// diagonal swapping will be done in tridiagonal routines in lapack like
    87  	// lapack64.Gtsv or gonum.Dlagtm based on the trans parameter.
    88  	return Transpose{a}
    89  }
    90  
    91  // TBand performs an implicit transpose by returning the receiver inside a
    92  // TransposeBand.
    93  func (a *Tridiag) TBand() Banded {
    94  	// An alternative would be to return the receiver with DL,DU swapped; see
    95  	// explanation in T above.
    96  	return TransposeBand{a}
    97  }
    98  
    99  // RawTridiagonal returns the underlying lapack64.Tridiagonal used by the
   100  // receiver. Changes to elements in the receiver following the call will be
   101  // reflected in the returned matrix.
   102  func (a *Tridiag) RawTridiagonal() lapack64.Tridiagonal {
   103  	return a.mat
   104  }
   105  
   106  // SetRawTridiagonal sets the underlying lapack64.Tridiagonal used by the
   107  // receiver. Changes to elements in the receiver following the call will be
   108  // reflected in the input.
   109  func (a *Tridiag) SetRawTridiagonal(mat lapack64.Tridiagonal) {
   110  	a.mat = mat
   111  }
   112  
   113  // IsEmpty returns whether the receiver is empty. Empty matrices can be the
   114  // receiver for size-restricted operations. The receiver can be zeroed using
   115  // Reset.
   116  func (a *Tridiag) IsEmpty() bool {
   117  	return a.mat.N == 0
   118  }
   119  
   120  // Reset empties the matrix so that it can be reused as the receiver of a
   121  // dimensionally restricted operation.
   122  //
   123  // Reset should not be used when the matrix shares backing data. See the Reseter
   124  // interface for more information.
   125  func (a *Tridiag) Reset() {
   126  	a.mat.N = 0
   127  	a.mat.DL = a.mat.DL[:0]
   128  	a.mat.D = a.mat.D[:0]
   129  	a.mat.DU = a.mat.DU[:0]
   130  }
   131  
   132  // CloneFromTridiag makes a copy of the input Tridiag into the receiver,
   133  // overwriting the previous value of the receiver. CloneFromTridiag does not
   134  // place any restrictions on receiver shape.
   135  func (a *Tridiag) CloneFromTridiag(from *Tridiag) {
   136  	n := from.mat.N
   137  	switch n {
   138  	case 0:
   139  		panic(ErrZeroLength)
   140  	case 1:
   141  		a.mat = lapack64.Tridiagonal{
   142  			N:  1,
   143  			DL: use(a.mat.DL, 0),
   144  			D:  use(a.mat.D, 1),
   145  			DU: use(a.mat.DU, 0),
   146  		}
   147  		a.mat.D[0] = from.mat.D[0]
   148  	default:
   149  		a.mat = lapack64.Tridiagonal{
   150  			N:  n,
   151  			DL: use(a.mat.DL, n-1),
   152  			D:  use(a.mat.D, n),
   153  			DU: use(a.mat.DU, n-1),
   154  		}
   155  		copy(a.mat.DL, from.mat.DL)
   156  		copy(a.mat.D, from.mat.D)
   157  		copy(a.mat.DU, from.mat.DU)
   158  	}
   159  }
   160  
   161  // DiagView returns the diagonal as a matrix backed by the original data.
   162  func (a *Tridiag) DiagView() Diagonal {
   163  	return &DiagDense{
   164  		mat: blas64.Vector{
   165  			N:    a.mat.N,
   166  			Data: a.mat.D[:a.mat.N],
   167  			Inc:  1,
   168  		},
   169  	}
   170  }
   171  
   172  // Zero sets all of the matrix elements to zero.
   173  func (a *Tridiag) Zero() {
   174  	zero(a.mat.DL)
   175  	zero(a.mat.D)
   176  	zero(a.mat.DU)
   177  }
   178  
   179  // Trace returns the trace of the matrix.
   180  //
   181  // Trace will panic with ErrZeroLength if the matrix has zero size.
   182  func (a *Tridiag) Trace() float64 {
   183  	if a.IsEmpty() {
   184  		panic(ErrZeroLength)
   185  	}
   186  	return f64.Sum(a.mat.D)
   187  }
   188  
   189  // Norm returns the specified norm of the receiver. Valid norms are:
   190  //  1 - The maximum absolute column sum
   191  //  2 - The Frobenius norm, the square root of the sum of the squares of the elements
   192  //  Inf - The maximum absolute row sum
   193  //
   194  // Norm will panic with ErrNormOrder if an illegal norm is specified and with
   195  // ErrZeroLength if the matrix has zero size.
   196  func (a *Tridiag) Norm(norm float64) float64 {
   197  	if a.IsEmpty() {
   198  		panic(ErrZeroLength)
   199  	}
   200  	return lapack64.Langt(normLapack(norm, false), a.mat)
   201  }
   202  
   203  // MulVecTo computes A⋅x or Aᵀ⋅x storing the result into dst.
   204  func (a *Tridiag) MulVecTo(dst *VecDense, trans bool, x Vector) {
   205  	n := a.mat.N
   206  	if x.Len() != n {
   207  		panic(ErrShape)
   208  	}
   209  	dst.reuseAsNonZeroed(n)
   210  	t := blas.NoTrans
   211  	if trans {
   212  		t = blas.Trans
   213  	}
   214  	xMat, _ := untransposeExtract(x)
   215  	if xVec, ok := xMat.(*VecDense); ok && dst != xVec {
   216  		dst.checkOverlap(xVec.mat)
   217  		lapack64.Lagtm(t, 1, a.mat, xVec.asGeneral(), 0, dst.asGeneral())
   218  	} else {
   219  		xCopy := getVecDenseWorkspace(n, false)
   220  		xCopy.CloneFromVec(x)
   221  		lapack64.Lagtm(t, 1, a.mat, xCopy.asGeneral(), 0, dst.asGeneral())
   222  		putVecDenseWorkspace(xCopy)
   223  	}
   224  }
   225  
   226  // SolveTo solves a tridiagonal system A⋅X = B  or  Aᵀ⋅X = B where A is an
   227  // n×n tridiagonal matrix represented by the receiver and B is a given n×nrhs
   228  // matrix. If A is non-singular, the result will be stored into dst and nil will
   229  // be returned. If A is singular, the contents of dst will be undefined and a
   230  // Condition error will be returned.
   231  func (a *Tridiag) SolveTo(dst *Dense, trans bool, b Matrix) error {
   232  	n, nrhs := b.Dims()
   233  	if n != a.mat.N {
   234  		panic(ErrShape)
   235  	}
   236  	if b, ok := b.(RawMatrixer); ok && dst != b {
   237  		dst.checkOverlap(b.RawMatrix())
   238  	}
   239  	dst.reuseAsNonZeroed(n, nrhs)
   240  	if dst != b {
   241  		dst.Copy(b)
   242  	}
   243  	var aCopy Tridiag
   244  	aCopy.CloneFromTridiag(a)
   245  	var ok bool
   246  	if trans {
   247  		ok = lapack64.Gtsv(blas.Trans, aCopy.mat, dst.mat)
   248  	} else {
   249  		ok = lapack64.Gtsv(blas.NoTrans, aCopy.mat, dst.mat)
   250  	}
   251  	if !ok {
   252  		return Condition(math.Inf(1))
   253  	}
   254  	return nil
   255  }
   256  
   257  // SolveVecTo solves a tridiagonal system A⋅X = B  or  Aᵀ⋅X = B where A is an
   258  // n×n tridiagonal matrix represented by the receiver and b is a given n-vector.
   259  // If A is non-singular, the result will be stored into dst and nil will be
   260  // returned. If A is singular, the contents of dst will be undefined and a
   261  // Condition error will be returned.
   262  func (a *Tridiag) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
   263  	n, nrhs := b.Dims()
   264  	if n != a.mat.N || nrhs != 1 {
   265  		panic(ErrShape)
   266  	}
   267  	if b, ok := b.(RawVectorer); ok && dst != b {
   268  		dst.checkOverlap(b.RawVector())
   269  	}
   270  	dst.reuseAsNonZeroed(n)
   271  	if dst != b {
   272  		dst.CopyVec(b)
   273  	}
   274  	var aCopy Tridiag
   275  	aCopy.CloneFromTridiag(a)
   276  	var ok bool
   277  	if trans {
   278  		ok = lapack64.Gtsv(blas.Trans, aCopy.mat, dst.asGeneral())
   279  	} else {
   280  		ok = lapack64.Gtsv(blas.NoTrans, aCopy.mat, dst.asGeneral())
   281  	}
   282  	if !ok {
   283  		return Condition(math.Inf(1))
   284  	}
   285  	return nil
   286  }
   287  
   288  // DoNonZero calls the function fn for each of the non-zero elements of A. The
   289  // function fn takes a row/column index and the element value of A at (i,j).
   290  func (a *Tridiag) DoNonZero(fn func(i, j int, v float64)) {
   291  	for i, aij := range a.mat.DU {
   292  		if aij != 0 {
   293  			fn(i, i+1, aij)
   294  		}
   295  	}
   296  	for i, aii := range a.mat.D {
   297  		if aii != 0 {
   298  			fn(i, i, aii)
   299  		}
   300  	}
   301  	for i, aij := range a.mat.DL {
   302  		if aij != 0 {
   303  			fn(i+1, i, aij)
   304  		}
   305  	}
   306  }
   307  
   308  // DoRowNonZero calls the function fn for each of the non-zero elements of row i
   309  // of A. The function fn takes a row/column index and the element value of A at
   310  // (i,j).
   311  func (a *Tridiag) DoRowNonZero(i int, fn func(i, j int, v float64)) {
   312  	n := a.mat.N
   313  	if uint(i) >= uint(n) {
   314  		panic(ErrRowAccess)
   315  	}
   316  	if n == 1 {
   317  		v := a.mat.D[0]
   318  		if v != 0 {
   319  			fn(0, 0, v)
   320  		}
   321  		return
   322  	}
   323  	switch i {
   324  	case 0:
   325  		v := a.mat.D[0]
   326  		if v != 0 {
   327  			fn(i, 0, v)
   328  		}
   329  		v = a.mat.DU[0]
   330  		if v != 0 {
   331  			fn(i, 1, v)
   332  		}
   333  	case n - 1:
   334  		v := a.mat.DL[n-2]
   335  		if v != 0 {
   336  			fn(n-1, n-2, v)
   337  		}
   338  		v = a.mat.D[n-1]
   339  		if v != 0 {
   340  			fn(n-1, n-1, v)
   341  		}
   342  	default:
   343  		v := a.mat.DL[i-1]
   344  		if v != 0 {
   345  			fn(i, i-1, v)
   346  		}
   347  		v = a.mat.D[i]
   348  		if v != 0 {
   349  			fn(i, i, v)
   350  		}
   351  		v = a.mat.DU[i]
   352  		if v != 0 {
   353  			fn(i, i+1, v)
   354  		}
   355  	}
   356  }
   357  
   358  // DoColNonZero calls the function fn for each of the non-zero elements of
   359  // column j of A. The function fn takes a row/column index and the element value
   360  // of A at (i, j).
   361  func (a *Tridiag) DoColNonZero(j int, fn func(i, j int, v float64)) {
   362  	n := a.mat.N
   363  	if uint(j) >= uint(n) {
   364  		panic(ErrColAccess)
   365  	}
   366  	if n == 1 {
   367  		v := a.mat.D[0]
   368  		if v != 0 {
   369  			fn(0, 0, v)
   370  		}
   371  		return
   372  	}
   373  	switch j {
   374  	case 0:
   375  		v := a.mat.D[0]
   376  		if v != 0 {
   377  			fn(0, 0, v)
   378  		}
   379  		v = a.mat.DL[0]
   380  		if v != 0 {
   381  			fn(1, 0, v)
   382  		}
   383  	case n - 1:
   384  		v := a.mat.DU[n-2]
   385  		if v != 0 {
   386  			fn(n-2, n-1, v)
   387  		}
   388  		v = a.mat.D[n-1]
   389  		if v != 0 {
   390  			fn(n-1, n-1, v)
   391  		}
   392  	default:
   393  		v := a.mat.DU[j-1]
   394  		if v != 0 {
   395  			fn(j-1, j, v)
   396  		}
   397  		v = a.mat.D[j]
   398  		if v != 0 {
   399  			fn(j, j, v)
   400  		}
   401  		v = a.mat.DL[j]
   402  		if v != 0 {
   403  			fn(j+1, j, v)
   404  		}
   405  	}
   406  }