gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/mat/tridiag.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  	"gonum.org/v1/gonum/internal/asm/f64"
    13  	"gonum.org/v1/gonum/lapack/lapack64"
    14  )
    15  
// Compile-time assertions that *Tridiag satisfies the listed interfaces.
// tridiagDense is a nil *Tridiag used as the value in each assignment to the
// blank identifier.
var (
	tridiagDense *Tridiag
	_            Matrix           = tridiagDense
	_            allMatrix        = tridiagDense
	_            denseMatrix      = tridiagDense
	_            Banded           = tridiagDense
	_            MutableBanded    = tridiagDense
	_            RawTridiagonaler = tridiagDense
)
    25  
// A RawTridiagonaler can return a lapack64.Tridiagonal representation of the
// receiver. Changes to the elements of DL, D, DU in lapack64.Tridiagonal will
// be reflected in the original matrix, changes to the N field will not.
type RawTridiagonaler interface {
	// RawTridiagonal returns the lapack64.Tridiagonal representation of the
	// receiver.
	RawTridiagonal() lapack64.Tridiagonal
}
    32  
// Tridiag represents a tridiagonal matrix by its three diagonals.
type Tridiag struct {
	// mat holds the matrix order N together with the sub-diagonal DL, the
	// main diagonal D and the super-diagonal DU.
	mat lapack64.Tridiagonal
}
    37  
    38  // NewTridiag creates a new n×n tridiagonal matrix with the first sub-diagonal
    39  // in dl, the main diagonal in d and the first super-diagonal in du. If all of
    40  // dl, d, and du are nil, new backing slices will be allocated for them. If dl
    41  // and du have length n-1 and d has length n, they will be used as backing
    42  // slices, and changes to the elements of the returned Tridiag will be reflected
    43  // in dl, d, du. If neither of these is true, NewTridiag will panic.
    44  func NewTridiag(n int, dl, d, du []float64) *Tridiag {
    45  	if n <= 0 {
    46  		if n == 0 {
    47  			panic(ErrZeroLength)
    48  		}
    49  		panic(ErrNegativeDimension)
    50  	}
    51  	if dl != nil || d != nil || du != nil {
    52  		if len(dl) != n-1 || len(d) != n || len(du) != n-1 {
    53  			panic(ErrShape)
    54  		}
    55  	} else {
    56  		d = make([]float64, n)
    57  		if n > 1 {
    58  			dl = make([]float64, n-1)
    59  			du = make([]float64, n-1)
    60  		}
    61  	}
    62  	return &Tridiag{
    63  		mat: lapack64.Tridiagonal{
    64  			N:  n,
    65  			DL: dl,
    66  			D:  d,
    67  			DU: du,
    68  		},
    69  	}
    70  }
    71  
    72  // Dims returns the number of rows and columns in the matrix.
    73  func (a *Tridiag) Dims() (r, c int) {
    74  	return a.mat.N, a.mat.N
    75  }
    76  
    77  // Bandwidth returns 1, 1 - the upper and lower bandwidths of the matrix.
    78  func (a *Tridiag) Bandwidth() (kl, ku int) {
    79  	return 1, 1
    80  }
    81  
    82  // T performs an implicit transpose by returning the receiver inside a Transpose.
    83  func (a *Tridiag) T() Matrix {
    84  	// An alternative would be to return the receiver with DL,DU swapped; the
    85  	// untranspose function would then always return false. With Transpose the
    86  	// diagonal swapping will be done in tridiagonal routines in lapack like
    87  	// lapack64.Gtsv or gonum.Dlagtm based on the trans parameter.
    88  	return Transpose{a}
    89  }
    90  
    91  // TBand performs an implicit transpose by returning the receiver inside a
    92  // TransposeBand.
    93  func (a *Tridiag) TBand() Banded {
    94  	// An alternative would be to return the receiver with DL,DU swapped; see
    95  	// explanation in T above.
    96  	return TransposeBand{a}
    97  }
    98  
    99  // RawTridiagonal returns the underlying lapack64.Tridiagonal used by the
   100  // receiver. Changes to elements in the receiver following the call will be
   101  // reflected in the returned matrix.
   102  func (a *Tridiag) RawTridiagonal() lapack64.Tridiagonal {
   103  	return a.mat
   104  }
   105  
   106  // SetRawTridiagonal sets the underlying lapack64.Tridiagonal used by the
   107  // receiver. Changes to elements in the receiver following the call will be
   108  // reflected in the input.
   109  func (a *Tridiag) SetRawTridiagonal(mat lapack64.Tridiagonal) {
   110  	a.mat = mat
   111  }
   112  
   113  // IsEmpty returns whether the receiver is empty. Empty matrices can be the
   114  // receiver for size-restricted operations. The receiver can be zeroed using
   115  // Reset.
   116  func (a *Tridiag) IsEmpty() bool {
   117  	return a.mat.N == 0
   118  }
   119  
   120  // Reset empties the matrix so that it can be reused as the receiver of a
   121  // dimensionally restricted operation.
   122  //
   123  // Reset should not be used when the matrix shares backing data. See the Reseter
   124  // interface for more information.
   125  func (a *Tridiag) Reset() {
   126  	a.mat.N = 0
   127  	a.mat.DL = a.mat.DL[:0]
   128  	a.mat.D = a.mat.D[:0]
   129  	a.mat.DU = a.mat.DU[:0]
   130  }
   131  
   132  // CloneFromTridiag makes a copy of the input Tridiag into the receiver,
   133  // overwriting the previous value of the receiver. CloneFromTridiag does not
   134  // place any restrictions on receiver shape.
   135  func (a *Tridiag) CloneFromTridiag(from *Tridiag) {
   136  	n := from.mat.N
   137  	switch n {
   138  	case 0:
   139  		panic(ErrZeroLength)
   140  	case 1:
   141  		a.mat = lapack64.Tridiagonal{
   142  			N:  1,
   143  			DL: use(a.mat.DL, 0),
   144  			D:  use(a.mat.D, 1),
   145  			DU: use(a.mat.DU, 0),
   146  		}
   147  		a.mat.D[0] = from.mat.D[0]
   148  	default:
   149  		a.mat = lapack64.Tridiagonal{
   150  			N:  n,
   151  			DL: use(a.mat.DL, n-1),
   152  			D:  use(a.mat.D, n),
   153  			DU: use(a.mat.DU, n-1),
   154  		}
   155  		copy(a.mat.DL, from.mat.DL)
   156  		copy(a.mat.D, from.mat.D)
   157  		copy(a.mat.DU, from.mat.DU)
   158  	}
   159  }
   160  
   161  // DiagView returns the diagonal as a matrix backed by the original data.
   162  func (a *Tridiag) DiagView() Diagonal {
   163  	return &DiagDense{
   164  		mat: blas64.Vector{
   165  			N:    a.mat.N,
   166  			Data: a.mat.D[:a.mat.N],
   167  			Inc:  1,
   168  		},
   169  	}
   170  }
   171  
   172  // Zero sets all of the matrix elements to zero.
   173  func (a *Tridiag) Zero() {
   174  	zero(a.mat.DL)
   175  	zero(a.mat.D)
   176  	zero(a.mat.DU)
   177  }
   178  
   179  // Trace returns the trace of the matrix.
   180  //
   181  // Trace will panic with ErrZeroLength if the matrix has zero size.
   182  func (a *Tridiag) Trace() float64 {
   183  	if a.IsEmpty() {
   184  		panic(ErrZeroLength)
   185  	}
   186  	return f64.Sum(a.mat.D)
   187  }
   188  
   189  // Norm returns the specified norm of the receiver. Valid norms are:
   190  //
   191  //	1 - The maximum absolute column sum
   192  //	2 - The Frobenius norm, the square root of the sum of the squares of the elements
   193  //	Inf - The maximum absolute row sum
   194  //
   195  // Norm will panic with ErrNormOrder if an illegal norm is specified and with
   196  // ErrZeroLength if the matrix has zero size.
   197  func (a *Tridiag) Norm(norm float64) float64 {
   198  	if a.IsEmpty() {
   199  		panic(ErrZeroLength)
   200  	}
   201  	return lapack64.Langt(normLapack(norm, false), a.mat)
   202  }
   203  
   204  // MulVecTo computes A⋅x or Aᵀ⋅x storing the result into dst.
   205  func (a *Tridiag) MulVecTo(dst *VecDense, trans bool, x Vector) {
   206  	n := a.mat.N
   207  	if x.Len() != n {
   208  		panic(ErrShape)
   209  	}
   210  	dst.reuseAsNonZeroed(n)
   211  	t := blas.NoTrans
   212  	if trans {
   213  		t = blas.Trans
   214  	}
   215  	xMat, _ := untransposeExtract(x)
   216  	if xVec, ok := xMat.(*VecDense); ok && dst != xVec {
   217  		dst.checkOverlap(xVec.mat)
   218  		lapack64.Lagtm(t, 1, a.mat, xVec.asGeneral(), 0, dst.asGeneral())
   219  	} else {
   220  		xCopy := getVecDenseWorkspace(n, false)
   221  		xCopy.CloneFromVec(x)
   222  		lapack64.Lagtm(t, 1, a.mat, xCopy.asGeneral(), 0, dst.asGeneral())
   223  		putVecDenseWorkspace(xCopy)
   224  	}
   225  }
   226  
   227  // SolveTo solves a tridiagonal system A⋅X = B  or  Aᵀ⋅X = B where A is an
   228  // n×n tridiagonal matrix represented by the receiver and B is a given n×nrhs
   229  // matrix. If A is non-singular, the result will be stored into dst and nil will
   230  // be returned. If A is singular, the contents of dst will be undefined and a
   231  // Condition error will be returned.
   232  func (a *Tridiag) SolveTo(dst *Dense, trans bool, b Matrix) error {
   233  	n, nrhs := b.Dims()
   234  	if n != a.mat.N {
   235  		panic(ErrShape)
   236  	}
   237  
   238  	dst.reuseAsNonZeroed(n, nrhs)
   239  	bU, bTrans := untranspose(b)
   240  	if dst == bU {
   241  		if bTrans {
   242  			work := getDenseWorkspace(n, nrhs, false)
   243  			defer putDenseWorkspace(work)
   244  			work.Copy(b)
   245  			dst.Copy(work)
   246  		}
   247  	} else {
   248  		if rm, ok := bU.(RawMatrixer); ok {
   249  			dst.checkOverlap(rm.RawMatrix())
   250  		}
   251  		dst.Copy(b)
   252  	}
   253  
   254  	var aCopy Tridiag
   255  	aCopy.CloneFromTridiag(a)
   256  	var ok bool
   257  	if trans {
   258  		ok = lapack64.Gtsv(blas.Trans, aCopy.mat, dst.mat)
   259  	} else {
   260  		ok = lapack64.Gtsv(blas.NoTrans, aCopy.mat, dst.mat)
   261  	}
   262  	if !ok {
   263  		return Condition(math.Inf(1))
   264  	}
   265  	return nil
   266  }
   267  
   268  // SolveVecTo solves a tridiagonal system A⋅X = B  or  Aᵀ⋅X = B where A is an
   269  // n×n tridiagonal matrix represented by the receiver and b is a given n-vector.
   270  // If A is non-singular, the result will be stored into dst and nil will be
   271  // returned. If A is singular, the contents of dst will be undefined and a
   272  // Condition error will be returned.
   273  func (a *Tridiag) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
   274  	n, nrhs := b.Dims()
   275  	if n != a.mat.N || nrhs != 1 {
   276  		panic(ErrShape)
   277  	}
   278  	if b, ok := b.(RawVectorer); ok && dst != b {
   279  		dst.checkOverlap(b.RawVector())
   280  	}
   281  	dst.reuseAsNonZeroed(n)
   282  	if dst != b {
   283  		dst.CopyVec(b)
   284  	}
   285  	var aCopy Tridiag
   286  	aCopy.CloneFromTridiag(a)
   287  	var ok bool
   288  	if trans {
   289  		ok = lapack64.Gtsv(blas.Trans, aCopy.mat, dst.asGeneral())
   290  	} else {
   291  		ok = lapack64.Gtsv(blas.NoTrans, aCopy.mat, dst.asGeneral())
   292  	}
   293  	if !ok {
   294  		return Condition(math.Inf(1))
   295  	}
   296  	return nil
   297  }
   298  
   299  // DoNonZero calls the function fn for each of the non-zero elements of A. The
   300  // function fn takes a row/column index and the element value of A at (i,j).
   301  func (a *Tridiag) DoNonZero(fn func(i, j int, v float64)) {
   302  	for i, aij := range a.mat.DU {
   303  		if aij != 0 {
   304  			fn(i, i+1, aij)
   305  		}
   306  	}
   307  	for i, aii := range a.mat.D {
   308  		if aii != 0 {
   309  			fn(i, i, aii)
   310  		}
   311  	}
   312  	for i, aij := range a.mat.DL {
   313  		if aij != 0 {
   314  			fn(i+1, i, aij)
   315  		}
   316  	}
   317  }
   318  
   319  // DoRowNonZero calls the function fn for each of the non-zero elements of row i
   320  // of A. The function fn takes a row/column index and the element value of A at
   321  // (i,j).
   322  func (a *Tridiag) DoRowNonZero(i int, fn func(i, j int, v float64)) {
   323  	n := a.mat.N
   324  	if uint(i) >= uint(n) {
   325  		panic(ErrRowAccess)
   326  	}
   327  	if n == 1 {
   328  		v := a.mat.D[0]
   329  		if v != 0 {
   330  			fn(0, 0, v)
   331  		}
   332  		return
   333  	}
   334  	switch i {
   335  	case 0:
   336  		v := a.mat.D[0]
   337  		if v != 0 {
   338  			fn(i, 0, v)
   339  		}
   340  		v = a.mat.DU[0]
   341  		if v != 0 {
   342  			fn(i, 1, v)
   343  		}
   344  	case n - 1:
   345  		v := a.mat.DL[n-2]
   346  		if v != 0 {
   347  			fn(n-1, n-2, v)
   348  		}
   349  		v = a.mat.D[n-1]
   350  		if v != 0 {
   351  			fn(n-1, n-1, v)
   352  		}
   353  	default:
   354  		v := a.mat.DL[i-1]
   355  		if v != 0 {
   356  			fn(i, i-1, v)
   357  		}
   358  		v = a.mat.D[i]
   359  		if v != 0 {
   360  			fn(i, i, v)
   361  		}
   362  		v = a.mat.DU[i]
   363  		if v != 0 {
   364  			fn(i, i+1, v)
   365  		}
   366  	}
   367  }
   368  
   369  // DoColNonZero calls the function fn for each of the non-zero elements of
   370  // column j of A. The function fn takes a row/column index and the element value
   371  // of A at (i, j).
   372  func (a *Tridiag) DoColNonZero(j int, fn func(i, j int, v float64)) {
   373  	n := a.mat.N
   374  	if uint(j) >= uint(n) {
   375  		panic(ErrColAccess)
   376  	}
   377  	if n == 1 {
   378  		v := a.mat.D[0]
   379  		if v != 0 {
   380  			fn(0, 0, v)
   381  		}
   382  		return
   383  	}
   384  	switch j {
   385  	case 0:
   386  		v := a.mat.D[0]
   387  		if v != 0 {
   388  			fn(0, 0, v)
   389  		}
   390  		v = a.mat.DL[0]
   391  		if v != 0 {
   392  			fn(1, 0, v)
   393  		}
   394  	case n - 1:
   395  		v := a.mat.DU[n-2]
   396  		if v != 0 {
   397  			fn(n-2, n-1, v)
   398  		}
   399  		v = a.mat.D[n-1]
   400  		if v != 0 {
   401  			fn(n-1, n-1, v)
   402  		}
   403  	default:
   404  		v := a.mat.DU[j-1]
   405  		if v != 0 {
   406  			fn(j-1, j, v)
   407  		}
   408  		v = a.mat.D[j]
   409  		if v != 0 {
   410  			fn(j, j, v)
   411  		}
   412  		v = a.mat.DL[j]
   413  		if v != 0 {
   414  			fn(j+1, j, v)
   415  		}
   416  	}
   417  }