gonum.org/v1/gonum@v0.14.0/mat/tridiag.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  	"gonum.org/v1/gonum/internal/asm/f64"
    13  	"gonum.org/v1/gonum/lapack/lapack64"
    14  )
    15  
    16  var (
    17  	tridiagDense *Tridiag
    18  	_            Matrix           = tridiagDense
    19  	_            allMatrix        = tridiagDense
    20  	_            denseMatrix      = tridiagDense
    21  	_            Banded           = tridiagDense
    22  	_            MutableBanded    = tridiagDense
    23  	_            RawTridiagonaler = tridiagDense
    24  )
    25  
// A RawTridiagonaler can return a lapack64.Tridiagonal representation of the
// receiver. The returned value shares its DL, D and DU slices with the
// receiver, so changes to their elements are reflected in the original
// matrix; changes to the N field of the returned value are not.
type RawTridiagonaler interface {
	RawTridiagonal() lapack64.Tridiagonal
}
    32  
// Tridiag represents a tridiagonal matrix by its three diagonals.
//
// The zero value is an empty matrix (IsEmpty reports true) and can be used as
// the receiver of CloneFromTridiag.
type Tridiag struct {
	// mat holds the order N together with the sub-diagonal (DL), the main
	// diagonal (D) and the super-diagonal (DU). Matrices built by NewTridiag
	// have len(D) == N and len(DL) == len(DU) == N-1; when NewTridiag
	// allocates for N == 1 the off-diagonals are left nil.
	mat lapack64.Tridiagonal
}
    37  
    38  // NewTridiag creates a new n×n tridiagonal matrix with the first sub-diagonal
    39  // in dl, the main diagonal in d and the first super-diagonal in du. If all of
    40  // dl, d, and du are nil, new backing slices will be allocated for them. If dl
    41  // and du have length n-1 and d has length n, they will be used as backing
    42  // slices, and changes to the elements of the returned Tridiag will be reflected
    43  // in dl, d, du. If neither of these is true, NewTridiag will panic.
    44  func NewTridiag(n int, dl, d, du []float64) *Tridiag {
    45  	if n <= 0 {
    46  		if n == 0 {
    47  			panic(ErrZeroLength)
    48  		}
    49  		panic(ErrNegativeDimension)
    50  	}
    51  	if dl != nil || d != nil || du != nil {
    52  		if len(dl) != n-1 || len(d) != n || len(du) != n-1 {
    53  			panic(ErrShape)
    54  		}
    55  	} else {
    56  		d = make([]float64, n)
    57  		if n > 1 {
    58  			dl = make([]float64, n-1)
    59  			du = make([]float64, n-1)
    60  		}
    61  	}
    62  	return &Tridiag{
    63  		mat: lapack64.Tridiagonal{
    64  			N:  n,
    65  			DL: dl,
    66  			D:  d,
    67  			DU: du,
    68  		},
    69  	}
    70  }
    71  
    72  // Dims returns the number of rows and columns in the matrix.
    73  func (a *Tridiag) Dims() (r, c int) {
    74  	return a.mat.N, a.mat.N
    75  }
    76  
    77  // Bandwidth returns 1, 1 - the upper and lower bandwidths of the matrix.
    78  func (a *Tridiag) Bandwidth() (kl, ku int) {
    79  	return 1, 1
    80  }
    81  
    82  // T performs an implicit transpose by returning the receiver inside a Transpose.
    83  func (a *Tridiag) T() Matrix {
    84  	// An alternative would be to return the receiver with DL,DU swapped; the
    85  	// untranspose function would then always return false. With Transpose the
    86  	// diagonal swapping will be done in tridiagonal routines in lapack like
    87  	// lapack64.Gtsv or gonum.Dlagtm based on the trans parameter.
    88  	return Transpose{a}
    89  }
    90  
    91  // TBand performs an implicit transpose by returning the receiver inside a
    92  // TransposeBand.
    93  func (a *Tridiag) TBand() Banded {
    94  	// An alternative would be to return the receiver with DL,DU swapped; see
    95  	// explanation in T above.
    96  	return TransposeBand{a}
    97  }
    98  
    99  // RawTridiagonal returns the underlying lapack64.Tridiagonal used by the
   100  // receiver. Changes to elements in the receiver following the call will be
   101  // reflected in the returned matrix.
   102  func (a *Tridiag) RawTridiagonal() lapack64.Tridiagonal {
   103  	return a.mat
   104  }
   105  
   106  // SetRawTridiagonal sets the underlying lapack64.Tridiagonal used by the
   107  // receiver. Changes to elements in the receiver following the call will be
   108  // reflected in the input.
   109  func (a *Tridiag) SetRawTridiagonal(mat lapack64.Tridiagonal) {
   110  	a.mat = mat
   111  }
   112  
   113  // IsEmpty returns whether the receiver is empty. Empty matrices can be the
   114  // receiver for size-restricted operations. The receiver can be zeroed using
   115  // Reset.
   116  func (a *Tridiag) IsEmpty() bool {
   117  	return a.mat.N == 0
   118  }
   119  
   120  // Reset empties the matrix so that it can be reused as the receiver of a
   121  // dimensionally restricted operation.
   122  //
   123  // Reset should not be used when the matrix shares backing data. See the Reseter
   124  // interface for more information.
   125  func (a *Tridiag) Reset() {
   126  	a.mat.N = 0
   127  	a.mat.DL = a.mat.DL[:0]
   128  	a.mat.D = a.mat.D[:0]
   129  	a.mat.DU = a.mat.DU[:0]
   130  }
   131  
   132  // CloneFromTridiag makes a copy of the input Tridiag into the receiver,
   133  // overwriting the previous value of the receiver. CloneFromTridiag does not
   134  // place any restrictions on receiver shape.
   135  func (a *Tridiag) CloneFromTridiag(from *Tridiag) {
   136  	n := from.mat.N
   137  	switch n {
   138  	case 0:
   139  		panic(ErrZeroLength)
   140  	case 1:
   141  		a.mat = lapack64.Tridiagonal{
   142  			N:  1,
   143  			DL: use(a.mat.DL, 0),
   144  			D:  use(a.mat.D, 1),
   145  			DU: use(a.mat.DU, 0),
   146  		}
   147  		a.mat.D[0] = from.mat.D[0]
   148  	default:
   149  		a.mat = lapack64.Tridiagonal{
   150  			N:  n,
   151  			DL: use(a.mat.DL, n-1),
   152  			D:  use(a.mat.D, n),
   153  			DU: use(a.mat.DU, n-1),
   154  		}
   155  		copy(a.mat.DL, from.mat.DL)
   156  		copy(a.mat.D, from.mat.D)
   157  		copy(a.mat.DU, from.mat.DU)
   158  	}
   159  }
   160  
   161  // DiagView returns the diagonal as a matrix backed by the original data.
   162  func (a *Tridiag) DiagView() Diagonal {
   163  	return &DiagDense{
   164  		mat: blas64.Vector{
   165  			N:    a.mat.N,
   166  			Data: a.mat.D[:a.mat.N],
   167  			Inc:  1,
   168  		},
   169  	}
   170  }
   171  
   172  // Zero sets all of the matrix elements to zero.
   173  func (a *Tridiag) Zero() {
   174  	zero(a.mat.DL)
   175  	zero(a.mat.D)
   176  	zero(a.mat.DU)
   177  }
   178  
   179  // Trace returns the trace of the matrix.
   180  //
   181  // Trace will panic with ErrZeroLength if the matrix has zero size.
   182  func (a *Tridiag) Trace() float64 {
   183  	if a.IsEmpty() {
   184  		panic(ErrZeroLength)
   185  	}
   186  	return f64.Sum(a.mat.D)
   187  }
   188  
   189  // Norm returns the specified norm of the receiver. Valid norms are:
   190  //
   191  //	1 - The maximum absolute column sum
   192  //	2 - The Frobenius norm, the square root of the sum of the squares of the elements
   193  //	Inf - The maximum absolute row sum
   194  //
   195  // Norm will panic with ErrNormOrder if an illegal norm is specified and with
   196  // ErrZeroLength if the matrix has zero size.
   197  func (a *Tridiag) Norm(norm float64) float64 {
   198  	if a.IsEmpty() {
   199  		panic(ErrZeroLength)
   200  	}
   201  	return lapack64.Langt(normLapack(norm, false), a.mat)
   202  }
   203  
   204  // MulVecTo computes A⋅x or Aᵀ⋅x storing the result into dst.
   205  func (a *Tridiag) MulVecTo(dst *VecDense, trans bool, x Vector) {
   206  	n := a.mat.N
   207  	if x.Len() != n {
   208  		panic(ErrShape)
   209  	}
   210  	dst.reuseAsNonZeroed(n)
   211  	t := blas.NoTrans
   212  	if trans {
   213  		t = blas.Trans
   214  	}
   215  	xMat, _ := untransposeExtract(x)
   216  	if xVec, ok := xMat.(*VecDense); ok && dst != xVec {
   217  		dst.checkOverlap(xVec.mat)
   218  		lapack64.Lagtm(t, 1, a.mat, xVec.asGeneral(), 0, dst.asGeneral())
   219  	} else {
   220  		xCopy := getVecDenseWorkspace(n, false)
   221  		xCopy.CloneFromVec(x)
   222  		lapack64.Lagtm(t, 1, a.mat, xCopy.asGeneral(), 0, dst.asGeneral())
   223  		putVecDenseWorkspace(xCopy)
   224  	}
   225  }
   226  
   227  // SolveTo solves a tridiagonal system A⋅X = B  or  Aᵀ⋅X = B where A is an
   228  // n×n tridiagonal matrix represented by the receiver and B is a given n×nrhs
   229  // matrix. If A is non-singular, the result will be stored into dst and nil will
   230  // be returned. If A is singular, the contents of dst will be undefined and a
   231  // Condition error will be returned.
   232  func (a *Tridiag) SolveTo(dst *Dense, trans bool, b Matrix) error {
   233  	n, nrhs := b.Dims()
   234  	if n != a.mat.N {
   235  		panic(ErrShape)
   236  	}
   237  	if b, ok := b.(RawMatrixer); ok && dst != b {
   238  		dst.checkOverlap(b.RawMatrix())
   239  	}
   240  	dst.reuseAsNonZeroed(n, nrhs)
   241  	if dst != b {
   242  		dst.Copy(b)
   243  	}
   244  	var aCopy Tridiag
   245  	aCopy.CloneFromTridiag(a)
   246  	var ok bool
   247  	if trans {
   248  		ok = lapack64.Gtsv(blas.Trans, aCopy.mat, dst.mat)
   249  	} else {
   250  		ok = lapack64.Gtsv(blas.NoTrans, aCopy.mat, dst.mat)
   251  	}
   252  	if !ok {
   253  		return Condition(math.Inf(1))
   254  	}
   255  	return nil
   256  }
   257  
   258  // SolveVecTo solves a tridiagonal system A⋅X = B  or  Aᵀ⋅X = B where A is an
   259  // n×n tridiagonal matrix represented by the receiver and b is a given n-vector.
   260  // If A is non-singular, the result will be stored into dst and nil will be
   261  // returned. If A is singular, the contents of dst will be undefined and a
   262  // Condition error will be returned.
   263  func (a *Tridiag) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
   264  	n, nrhs := b.Dims()
   265  	if n != a.mat.N || nrhs != 1 {
   266  		panic(ErrShape)
   267  	}
   268  	if b, ok := b.(RawVectorer); ok && dst != b {
   269  		dst.checkOverlap(b.RawVector())
   270  	}
   271  	dst.reuseAsNonZeroed(n)
   272  	if dst != b {
   273  		dst.CopyVec(b)
   274  	}
   275  	var aCopy Tridiag
   276  	aCopy.CloneFromTridiag(a)
   277  	var ok bool
   278  	if trans {
   279  		ok = lapack64.Gtsv(blas.Trans, aCopy.mat, dst.asGeneral())
   280  	} else {
   281  		ok = lapack64.Gtsv(blas.NoTrans, aCopy.mat, dst.asGeneral())
   282  	}
   283  	if !ok {
   284  		return Condition(math.Inf(1))
   285  	}
   286  	return nil
   287  }
   288  
   289  // DoNonZero calls the function fn for each of the non-zero elements of A. The
   290  // function fn takes a row/column index and the element value of A at (i,j).
   291  func (a *Tridiag) DoNonZero(fn func(i, j int, v float64)) {
   292  	for i, aij := range a.mat.DU {
   293  		if aij != 0 {
   294  			fn(i, i+1, aij)
   295  		}
   296  	}
   297  	for i, aii := range a.mat.D {
   298  		if aii != 0 {
   299  			fn(i, i, aii)
   300  		}
   301  	}
   302  	for i, aij := range a.mat.DL {
   303  		if aij != 0 {
   304  			fn(i+1, i, aij)
   305  		}
   306  	}
   307  }
   308  
   309  // DoRowNonZero calls the function fn for each of the non-zero elements of row i
   310  // of A. The function fn takes a row/column index and the element value of A at
   311  // (i,j).
   312  func (a *Tridiag) DoRowNonZero(i int, fn func(i, j int, v float64)) {
   313  	n := a.mat.N
   314  	if uint(i) >= uint(n) {
   315  		panic(ErrRowAccess)
   316  	}
   317  	if n == 1 {
   318  		v := a.mat.D[0]
   319  		if v != 0 {
   320  			fn(0, 0, v)
   321  		}
   322  		return
   323  	}
   324  	switch i {
   325  	case 0:
   326  		v := a.mat.D[0]
   327  		if v != 0 {
   328  			fn(i, 0, v)
   329  		}
   330  		v = a.mat.DU[0]
   331  		if v != 0 {
   332  			fn(i, 1, v)
   333  		}
   334  	case n - 1:
   335  		v := a.mat.DL[n-2]
   336  		if v != 0 {
   337  			fn(n-1, n-2, v)
   338  		}
   339  		v = a.mat.D[n-1]
   340  		if v != 0 {
   341  			fn(n-1, n-1, v)
   342  		}
   343  	default:
   344  		v := a.mat.DL[i-1]
   345  		if v != 0 {
   346  			fn(i, i-1, v)
   347  		}
   348  		v = a.mat.D[i]
   349  		if v != 0 {
   350  			fn(i, i, v)
   351  		}
   352  		v = a.mat.DU[i]
   353  		if v != 0 {
   354  			fn(i, i+1, v)
   355  		}
   356  	}
   357  }
   358  
   359  // DoColNonZero calls the function fn for each of the non-zero elements of
   360  // column j of A. The function fn takes a row/column index and the element value
   361  // of A at (i, j).
   362  func (a *Tridiag) DoColNonZero(j int, fn func(i, j int, v float64)) {
   363  	n := a.mat.N
   364  	if uint(j) >= uint(n) {
   365  		panic(ErrColAccess)
   366  	}
   367  	if n == 1 {
   368  		v := a.mat.D[0]
   369  		if v != 0 {
   370  			fn(0, 0, v)
   371  		}
   372  		return
   373  	}
   374  	switch j {
   375  	case 0:
   376  		v := a.mat.D[0]
   377  		if v != 0 {
   378  			fn(0, 0, v)
   379  		}
   380  		v = a.mat.DL[0]
   381  		if v != 0 {
   382  			fn(1, 0, v)
   383  		}
   384  	case n - 1:
   385  		v := a.mat.DU[n-2]
   386  		if v != 0 {
   387  			fn(n-2, n-1, v)
   388  		}
   389  		v = a.mat.D[n-1]
   390  		if v != 0 {
   391  			fn(n-1, n-1, v)
   392  		}
   393  	default:
   394  		v := a.mat.DU[j-1]
   395  		if v != 0 {
   396  			fn(j-1, j, v)
   397  		}
   398  		v = a.mat.D[j]
   399  		if v != 0 {
   400  			fn(j, j, v)
   401  		}
   402  		v = a.mat.DL[j]
   403  		if v != 0 {
   404  			fn(j+1, j, v)
   405  		}
   406  	}
   407  }