github.com/gopherd/gonum@v0.0.4/mat/matrix.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  	"github.com/gopherd/gonum/floats/scalar"
    13  	"github.com/gopherd/gonum/lapack"
    14  )
    15  
    16  // Matrix is the basic matrix interface type.
    17  type Matrix interface {
    18  	// Dims returns the dimensions of a Matrix.
    19  	Dims() (r, c int)
    20  
    21  	// At returns the value of a matrix element at row i, column j.
    22  	// It will panic if i or j are out of bounds for the matrix.
    23  	At(i, j int) float64
    24  
    25  	// T returns the transpose of the Matrix. Whether T returns a copy of the
    26  	// underlying data is implementation dependent.
    27  	// This method may be implemented using the Transpose type, which
    28  	// provides an implicit matrix transpose.
    29  	T() Matrix
    30  }
    31  
    32  // allMatrix represents the extra set of methods that all mat Matrix types
    33  // should satisfy. This is used to enforce compile-time consistency between the
    34  // Dense types, especially helpful when adding new features.
    35  type allMatrix interface {
    36  	Reseter
    37  	IsEmpty() bool
    38  	Zero()
    39  }
    40  
    41  // denseMatrix represents the extra set of methods that all Dense Matrix types
    42  // should satisfy. This is used to enforce compile-time consistency between the
    43  // Dense types, especially helpful when adding new features.
    44  type denseMatrix interface {
    45  	DiagView() Diagonal
    46  	Tracer
    47  	Normer
    48  }
    49  
    50  var (
    51  	_ Matrix       = Transpose{}
    52  	_ Untransposer = Transpose{}
    53  )
    54  
    55  // Transpose is a type for performing an implicit matrix transpose. It implements
    56  // the Matrix interface, returning values from the transpose of the matrix within.
    57  type Transpose struct {
    58  	Matrix Matrix
    59  }
    60  
    61  // At returns the value of the element at row i and column j of the transposed
    62  // matrix, that is, row j and column i of the Matrix field.
    63  func (t Transpose) At(i, j int) float64 {
    64  	return t.Matrix.At(j, i)
    65  }
    66  
    67  // Dims returns the dimensions of the transposed matrix. The number of rows returned
    68  // is the number of columns in the Matrix field, and the number of columns is
    69  // the number of rows in the Matrix field.
    70  func (t Transpose) Dims() (r, c int) {
    71  	c, r = t.Matrix.Dims()
    72  	return r, c
    73  }
    74  
    75  // T performs an implicit transpose by returning the Matrix field.
    76  func (t Transpose) T() Matrix {
    77  	return t.Matrix
    78  }
    79  
    80  // Untranspose returns the Matrix field.
    81  func (t Transpose) Untranspose() Matrix {
    82  	return t.Matrix
    83  }
    84  
    85  // Untransposer is a type that can undo an implicit transpose.
    86  type Untransposer interface {
    87  	// Note: This interface is needed to unify all of the Transpose types. In
    88  	// the mat methods, we need to test if the Matrix has been implicitly
    89  	// transposed. If this is checked by testing for the specific Transpose type
    90  	// then the behavior will be different if the user uses T() or TTri() for a
    91  	// triangular matrix.
    92  
    93  	// Untranspose returns the underlying Matrix stored for the implicit transpose.
    94  	Untranspose() Matrix
    95  }
    96  
    97  // UntransposeBander is a type that can undo an implicit band transpose.
    98  type UntransposeBander interface {
    99  	// Untranspose returns the underlying Banded stored for the implicit transpose.
   100  	UntransposeBand() Banded
   101  }
   102  
   103  // UntransposeTrier is a type that can undo an implicit triangular transpose.
   104  type UntransposeTrier interface {
   105  	// Untranspose returns the underlying Triangular stored for the implicit transpose.
   106  	UntransposeTri() Triangular
   107  }
   108  
   109  // UntransposeTriBander is a type that can undo an implicit triangular banded
   110  // transpose.
   111  type UntransposeTriBander interface {
   112  	// Untranspose returns the underlying Triangular stored for the implicit transpose.
   113  	UntransposeTriBand() TriBanded
   114  }
   115  
   116  // Mutable is a matrix interface type that allows elements to be altered.
   117  type Mutable interface {
   118  	// Set alters the matrix element at row i, column j to v.
   119  	// It will panic if i or j are out of bounds for the matrix.
   120  	Set(i, j int, v float64)
   121  
   122  	Matrix
   123  }
   124  
   125  // A RowViewer can return a Vector reflecting a row that is backed by the matrix
   126  // data. The Vector returned will have length equal to the number of columns.
   127  type RowViewer interface {
   128  	RowView(i int) Vector
   129  }
   130  
   131  // A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
   132  // data.
   133  type RawRowViewer interface {
   134  	RawRowView(i int) []float64
   135  }
   136  
   137  // A ColViewer can return a Vector reflecting a column that is backed by the matrix
   138  // data. The Vector returned will have length equal to the number of rows.
   139  type ColViewer interface {
   140  	ColView(j int) Vector
   141  }
   142  
   143  // A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
   144  // data.
   145  type RawColViewer interface {
   146  	RawColView(j int) []float64
   147  }
   148  
   149  // A ClonerFrom can make a copy of a into the receiver, overwriting the previous value of the
   150  // receiver. The clone operation does not make any restriction on shape and will not cause
   151  // shadowing.
   152  type ClonerFrom interface {
   153  	CloneFrom(a Matrix)
   154  }
   155  
   156  // A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
   157  // restricted operation. This is commonly used when the matrix is being used as a workspace
   158  // or temporary matrix.
   159  //
   160  // If the matrix is a view, using Reset may result in data corruption in elements outside
   161  // the view. Similarly, if the matrix shares backing data with another variable, using
   162  // Reset may lead to unexpected changes in data values.
   163  type Reseter interface {
   164  	Reset()
   165  }
   166  
   167  // A Copier can make a copy of elements of a into the receiver. The submatrix copied
   168  // starts at row and column 0 and has dimensions equal to the minimum dimensions of
   169  // the two matrices. The number of row and columns copied is returned.
   170  // Copy will copy from a source that aliases the receiver unless the source is transposed;
   171  // an aliasing transpose copy will panic with the exception for a special case when
   172  // the source data has a unitary increment or stride.
   173  type Copier interface {
   174  	Copy(a Matrix) (r, c int)
   175  }
   176  
   177  // A Grower can grow the size of the represented matrix by the given number of rows and columns.
   178  // Growing beyond the size given by the Caps method will result in the allocation of a new
   179  // matrix and copying of the elements. If Grow is called with negative increments it will
   180  // panic with ErrIndexOutOfRange.
   181  type Grower interface {
   182  	Caps() (r, c int)
   183  	Grow(r, c int) Matrix
   184  }
   185  
   186  // A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
   187  // on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
   188  type RawMatrixSetter interface {
   189  	SetRawMatrix(a blas64.General)
   190  }
   191  
   192  // A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
   193  // slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
   194  type RawMatrixer interface {
   195  	RawMatrix() blas64.General
   196  }
   197  
   198  // A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
   199  // slice will be reflected in the original matrix, changes to the Inc field will not.
   200  type RawVectorer interface {
   201  	RawVector() blas64.Vector
   202  }
   203  
   204  // A NonZeroDoer can call a function for each non-zero element of the receiver.
   205  // The parameters of the function are the element indices and its value.
   206  type NonZeroDoer interface {
   207  	DoNonZero(func(i, j int, v float64))
   208  }
   209  
   210  // A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver.
   211  // The parameters of the function are the element indices and its value.
   212  type RowNonZeroDoer interface {
   213  	DoRowNonZero(i int, fn func(i, j int, v float64))
   214  }
   215  
   216  // A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver.
   217  // The parameters of the function are the element indices and its value.
   218  type ColNonZeroDoer interface {
   219  	DoColNonZero(j int, fn func(i, j int, v float64))
   220  }
   221  
   222  // untranspose untransposes a matrix if applicable. If a is an Untransposer, then
   223  // untranspose returns the underlying matrix and true. If it is not, then it returns
   224  // the input matrix and false.
   225  func untranspose(a Matrix) (Matrix, bool) {
   226  	if ut, ok := a.(Untransposer); ok {
   227  		return ut.Untranspose(), true
   228  	}
   229  	return a, false
   230  }
   231  
// untransposeExtract returns an untransposed matrix in a built-in matrix type.
//
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
// Otherwise, if it implements a Raw method, an appropriate built-in type value
// is returned holding the raw matrix value of the input. If neither of these
// is possible, the untransposed matrix is returned.
//
// The second result reports whether a was an implicit transpose, exactly as
// returned by untranspose.
//
// The type-switch cases are tried in order, so a value implementing several
// Raw interfaces is converted according to the first case it matches. Raw
// values with non-upper symmetric storage or unit-diagonal triangular storage
// are returned unconverted (presumably because the corresponding built-in
// types do not accept those formats — confirm against their SetRaw* methods).
func untransposeExtract(a Matrix) (Matrix, bool) {
	ut, trans := untranspose(a)
	switch m := ut.(type) {
	// Already a built-in type: nothing to convert.
	case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense, *VecDense, *Tridiag:
		return m, trans
	// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
	case RawSymBander:
		rsb := m.RawSymBand()
		// Only upper band storage is wrapped in a SymBandDense.
		if rsb.Uplo != blas.Upper {
			return ut, trans
		}
		var sb SymBandDense
		sb.SetRawSymBand(rsb)
		return &sb, trans
	case RawTriBander:
		rtb := m.RawTriBand()
		// Unit-diagonal band storage is left unconverted.
		if rtb.Diag == blas.Unit {
			return ut, trans
		}
		var tb TriBandDense
		tb.SetRawTriBand(rtb)
		return &tb, trans
	case RawBander:
		var b BandDense
		b.SetRawBand(m.RawBand())
		return &b, trans
	case RawTriangular:
		rt := m.RawTriangular()
		// Unit-diagonal triangular storage is left unconverted.
		if rt.Diag == blas.Unit {
			return ut, trans
		}
		var t TriDense
		t.SetRawTriangular(rt)
		return &t, trans
	case RawSymmetricer:
		rs := m.RawSymmetric()
		// Only upper symmetric storage is wrapped in a SymDense.
		if rs.Uplo != blas.Upper {
			return ut, trans
		}
		var s SymDense
		s.SetRawSymmetric(rs)
		return &s, trans
	case RawMatrixer:
		var d Dense
		d.SetRawMatrix(m.RawMatrix())
		return &d, trans
	case RawVectorer:
		var v VecDense
		v.SetRawVector(m.RawVector())
		return &v, trans
	case RawTridiagonaler:
		var d Tridiag
		d.SetRawTridiagonal(m.RawTridiagonal())
		return &d, trans
	default:
		// No built-in representation is available.
		return ut, trans
	}
}
   296  
   297  // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
   298  // TODO(btracey): Add in fast paths to Row/Col for the other concrete types
   299  // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
   300  
   301  // Col copies the elements in the jth column of the matrix into the slice dst.
   302  // The length of the provided slice must equal the number of rows, unless the
   303  // slice is nil in which case a new slice is first allocated.
   304  func Col(dst []float64, j int, a Matrix) []float64 {
   305  	r, c := a.Dims()
   306  	if j < 0 || j >= c {
   307  		panic(ErrColAccess)
   308  	}
   309  	if dst == nil {
   310  		dst = make([]float64, r)
   311  	} else {
   312  		if len(dst) != r {
   313  			panic(ErrColLength)
   314  		}
   315  	}
   316  	aU, aTrans := untranspose(a)
   317  	if rm, ok := aU.(RawMatrixer); ok {
   318  		m := rm.RawMatrix()
   319  		if aTrans {
   320  			copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
   321  			return dst
   322  		}
   323  		blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]},
   324  			blas64.Vector{N: r, Inc: 1, Data: dst},
   325  		)
   326  		return dst
   327  	}
   328  	for i := 0; i < r; i++ {
   329  		dst[i] = a.At(i, j)
   330  	}
   331  	return dst
   332  }
   333  
   334  // Row copies the elements in the ith row of the matrix into the slice dst.
   335  // The length of the provided slice must equal the number of columns, unless the
   336  // slice is nil in which case a new slice is first allocated.
   337  func Row(dst []float64, i int, a Matrix) []float64 {
   338  	r, c := a.Dims()
   339  	if i < 0 || i >= r {
   340  		panic(ErrColAccess)
   341  	}
   342  	if dst == nil {
   343  		dst = make([]float64, c)
   344  	} else {
   345  		if len(dst) != c {
   346  			panic(ErrRowLength)
   347  		}
   348  	}
   349  	aU, aTrans := untranspose(a)
   350  	if rm, ok := aU.(RawMatrixer); ok {
   351  		m := rm.RawMatrix()
   352  		if aTrans {
   353  			blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]},
   354  				blas64.Vector{N: c, Inc: 1, Data: dst},
   355  			)
   356  			return dst
   357  		}
   358  		copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
   359  		return dst
   360  	}
   361  	for j := 0; j < c; j++ {
   362  		dst[j] = a.At(i, j)
   363  	}
   364  	return dst
   365  }
   366  
   367  // Cond returns the condition number of the given matrix under the given norm.
   368  // The condition number must be based on the 1-norm, 2-norm or ∞-norm.
   369  // Cond will panic with ErrZeroLength if the matrix has zero size.
   370  //
   371  // BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
   372  // is inaccurate, although is typically the right order of magnitude. See
   373  // https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
   374  // change with the resolution of this bug, the result from Cond will match the
   375  // condition number used internally.
   376  func Cond(a Matrix, norm float64) float64 {
   377  	m, n := a.Dims()
   378  	if m == 0 || n == 0 {
   379  		panic(ErrZeroLength)
   380  	}
   381  	var lnorm lapack.MatrixNorm
   382  	switch norm {
   383  	default:
   384  		panic("mat: bad norm value")
   385  	case 1:
   386  		lnorm = lapack.MaxColumnSum
   387  	case 2:
   388  		var svd SVD
   389  		ok := svd.Factorize(a, SVDNone)
   390  		if !ok {
   391  			return math.Inf(1)
   392  		}
   393  		return svd.Cond()
   394  	case math.Inf(1):
   395  		lnorm = lapack.MaxRowSum
   396  	}
   397  
   398  	if m == n {
   399  		// Use the LU decomposition to compute the condition number.
   400  		var lu LU
   401  		lu.factorize(a, lnorm)
   402  		return lu.Cond()
   403  	}
   404  	if m > n {
   405  		// Use the QR factorization to compute the condition number.
   406  		var qr QR
   407  		qr.factorize(a, lnorm)
   408  		return qr.Cond()
   409  	}
   410  	// Use the LQ factorization to compute the condition number.
   411  	var lq LQ
   412  	lq.factorize(a, lnorm)
   413  	return lq.Cond()
   414  }
   415  
   416  // Det returns the determinant of the quare matrix a. In many expressions using
   417  // LogDet will be more numerically stable.
   418  //
   419  // Det panics with ErrSquare if a is not square and with ErrZeroLength if a has
   420  // zero size.
   421  func Det(a Matrix) float64 {
   422  	det, sign := LogDet(a)
   423  	return math.Exp(det) * sign
   424  }
   425  
   426  // Dot returns the sum of the element-wise product of a and b.
   427  //
   428  // Dot panics with ErrShape if the vector sizes are unequal and with
   429  // ErrZeroLength if the sizes are zero.
   430  func Dot(a, b Vector) float64 {
   431  	la := a.Len()
   432  	lb := b.Len()
   433  	if la != lb {
   434  		panic(ErrShape)
   435  	}
   436  	if la == 0 {
   437  		panic(ErrZeroLength)
   438  	}
   439  	if arv, ok := a.(RawVectorer); ok {
   440  		if brv, ok := b.(RawVectorer); ok {
   441  			return blas64.Dot(arv.RawVector(), brv.RawVector())
   442  		}
   443  	}
   444  	var sum float64
   445  	for i := 0; i < la; i++ {
   446  		sum += a.At(i, 0) * b.At(i, 0)
   447  	}
   448  	return sum
   449  }
   450  
   451  // Equal returns whether the matrices a and b have the same size
   452  // and are element-wise equal.
   453  func Equal(a, b Matrix) bool {
   454  	ar, ac := a.Dims()
   455  	br, bc := b.Dims()
   456  	if ar != br || ac != bc {
   457  		return false
   458  	}
   459  	aU, aTrans := untranspose(a)
   460  	bU, bTrans := untranspose(b)
   461  	if rma, ok := aU.(RawMatrixer); ok {
   462  		if rmb, ok := bU.(RawMatrixer); ok {
   463  			ra := rma.RawMatrix()
   464  			rb := rmb.RawMatrix()
   465  			if aTrans == bTrans {
   466  				for i := 0; i < ra.Rows; i++ {
   467  					for j := 0; j < ra.Cols; j++ {
   468  						if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
   469  							return false
   470  						}
   471  					}
   472  				}
   473  				return true
   474  			}
   475  			for i := 0; i < ra.Rows; i++ {
   476  				for j := 0; j < ra.Cols; j++ {
   477  					if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
   478  						return false
   479  					}
   480  				}
   481  			}
   482  			return true
   483  		}
   484  	}
   485  	if rma, ok := aU.(RawSymmetricer); ok {
   486  		if rmb, ok := bU.(RawSymmetricer); ok {
   487  			ra := rma.RawSymmetric()
   488  			rb := rmb.RawSymmetric()
   489  			// Symmetric matrices are always upper and equal to their transpose.
   490  			for i := 0; i < ra.N; i++ {
   491  				for j := i; j < ra.N; j++ {
   492  					if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
   493  						return false
   494  					}
   495  				}
   496  			}
   497  			return true
   498  		}
   499  	}
   500  	if ra, ok := aU.(*VecDense); ok {
   501  		if rb, ok := bU.(*VecDense); ok {
   502  			// If the raw vectors are the same length they must either both be
   503  			// transposed or both not transposed (or have length 1).
   504  			for i := 0; i < ra.mat.N; i++ {
   505  				if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
   506  					return false
   507  				}
   508  			}
   509  			return true
   510  		}
   511  	}
   512  	for i := 0; i < ar; i++ {
   513  		for j := 0; j < ac; j++ {
   514  			if a.At(i, j) != b.At(i, j) {
   515  				return false
   516  			}
   517  		}
   518  	}
   519  	return true
   520  }
   521  
   522  // EqualApprox returns whether the matrices a and b have the same size and contain all equal
   523  // elements with tolerance for element-wise equality specified by epsilon. Matrices
   524  // with non-equal shapes are not equal.
   525  func EqualApprox(a, b Matrix, epsilon float64) bool {
   526  	ar, ac := a.Dims()
   527  	br, bc := b.Dims()
   528  	if ar != br || ac != bc {
   529  		return false
   530  	}
   531  	aU, aTrans := untranspose(a)
   532  	bU, bTrans := untranspose(b)
   533  	if rma, ok := aU.(RawMatrixer); ok {
   534  		if rmb, ok := bU.(RawMatrixer); ok {
   535  			ra := rma.RawMatrix()
   536  			rb := rmb.RawMatrix()
   537  			if aTrans == bTrans {
   538  				for i := 0; i < ra.Rows; i++ {
   539  					for j := 0; j < ra.Cols; j++ {
   540  						if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
   541  							return false
   542  						}
   543  					}
   544  				}
   545  				return true
   546  			}
   547  			for i := 0; i < ra.Rows; i++ {
   548  				for j := 0; j < ra.Cols; j++ {
   549  					if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
   550  						return false
   551  					}
   552  				}
   553  			}
   554  			return true
   555  		}
   556  	}
   557  	if rma, ok := aU.(RawSymmetricer); ok {
   558  		if rmb, ok := bU.(RawSymmetricer); ok {
   559  			ra := rma.RawSymmetric()
   560  			rb := rmb.RawSymmetric()
   561  			// Symmetric matrices are always upper and equal to their transpose.
   562  			for i := 0; i < ra.N; i++ {
   563  				for j := i; j < ra.N; j++ {
   564  					if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
   565  						return false
   566  					}
   567  				}
   568  			}
   569  			return true
   570  		}
   571  	}
   572  	if ra, ok := aU.(*VecDense); ok {
   573  		if rb, ok := bU.(*VecDense); ok {
   574  			// If the raw vectors are the same length they must either both be
   575  			// transposed or both not transposed (or have length 1).
   576  			for i := 0; i < ra.mat.N; i++ {
   577  				if !scalar.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
   578  					return false
   579  				}
   580  			}
   581  			return true
   582  		}
   583  	}
   584  	for i := 0; i < ar; i++ {
   585  		for j := 0; j < ac; j++ {
   586  			if !scalar.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
   587  				return false
   588  			}
   589  		}
   590  	}
   591  	return true
   592  }
   593  
   594  // LogDet returns the log of the determinant and the sign of the determinant
   595  // for the matrix that has been factorized. Numerical stability in product and
   596  // division expressions is generally improved by working in log space.
   597  //
   598  // LogDet panics with ErrSquare is a is not square and with ErrZeroLength if a
   599  // has zero size.
   600  func LogDet(a Matrix) (det float64, sign float64) {
   601  	// TODO(btracey): Add specialized routines for TriDense, etc.
   602  	var lu LU
   603  	lu.Factorize(a)
   604  	return lu.LogDet()
   605  }
   606  
   607  // Max returns the largest element value of the matrix A.
   608  //
   609  // Max will panic with ErrZeroLength if the matrix has zero size.
   610  func Max(a Matrix) float64 {
   611  	r, c := a.Dims()
   612  	if r == 0 || c == 0 {
   613  		panic(ErrZeroLength)
   614  	}
   615  	// Max(A) = Max(Aᵀ)
   616  	aU, _ := untranspose(a)
   617  	switch m := aU.(type) {
   618  	case RawMatrixer:
   619  		rm := m.RawMatrix()
   620  		max := math.Inf(-1)
   621  		for i := 0; i < rm.Rows; i++ {
   622  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
   623  				if v > max {
   624  					max = v
   625  				}
   626  			}
   627  		}
   628  		return max
   629  	case RawTriangular:
   630  		rm := m.RawTriangular()
   631  		// The max of a triangular is at least 0 unless the size is 1.
   632  		if rm.N == 1 {
   633  			return rm.Data[0]
   634  		}
   635  		max := 0.0
   636  		if rm.Uplo == blas.Upper {
   637  			for i := 0; i < rm.N; i++ {
   638  				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   639  					if v > max {
   640  						max = v
   641  					}
   642  				}
   643  			}
   644  			return max
   645  		}
   646  		for i := 0; i < rm.N; i++ {
   647  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
   648  				if v > max {
   649  					max = v
   650  				}
   651  			}
   652  		}
   653  		return max
   654  	case RawSymmetricer:
   655  		rm := m.RawSymmetric()
   656  		if rm.Uplo != blas.Upper {
   657  			panic(badSymTriangle)
   658  		}
   659  		max := math.Inf(-1)
   660  		for i := 0; i < rm.N; i++ {
   661  			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   662  				if v > max {
   663  					max = v
   664  				}
   665  			}
   666  		}
   667  		return max
   668  	default:
   669  		r, c := aU.Dims()
   670  		max := math.Inf(-1)
   671  		for i := 0; i < r; i++ {
   672  			for j := 0; j < c; j++ {
   673  				v := aU.At(i, j)
   674  				if v > max {
   675  					max = v
   676  				}
   677  			}
   678  		}
   679  		return max
   680  	}
   681  }
   682  
   683  // Min returns the smallest element value of the matrix A.
   684  //
   685  // Min will panic with ErrZeroLength if the matrix has zero size.
   686  func Min(a Matrix) float64 {
   687  	r, c := a.Dims()
   688  	if r == 0 || c == 0 {
   689  		panic(ErrZeroLength)
   690  	}
   691  	// Min(A) = Min(Aᵀ)
   692  	aU, _ := untranspose(a)
   693  	switch m := aU.(type) {
   694  	case RawMatrixer:
   695  		rm := m.RawMatrix()
   696  		min := math.Inf(1)
   697  		for i := 0; i < rm.Rows; i++ {
   698  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
   699  				if v < min {
   700  					min = v
   701  				}
   702  			}
   703  		}
   704  		return min
   705  	case RawTriangular:
   706  		rm := m.RawTriangular()
   707  		// The min of a triangular is at most 0 unless the size is 1.
   708  		if rm.N == 1 {
   709  			return rm.Data[0]
   710  		}
   711  		min := 0.0
   712  		if rm.Uplo == blas.Upper {
   713  			for i := 0; i < rm.N; i++ {
   714  				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   715  					if v < min {
   716  						min = v
   717  					}
   718  				}
   719  			}
   720  			return min
   721  		}
   722  		for i := 0; i < rm.N; i++ {
   723  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
   724  				if v < min {
   725  					min = v
   726  				}
   727  			}
   728  		}
   729  		return min
   730  	case RawSymmetricer:
   731  		rm := m.RawSymmetric()
   732  		if rm.Uplo != blas.Upper {
   733  			panic(badSymTriangle)
   734  		}
   735  		min := math.Inf(1)
   736  		for i := 0; i < rm.N; i++ {
   737  			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   738  				if v < min {
   739  					min = v
   740  				}
   741  			}
   742  		}
   743  		return min
   744  	default:
   745  		r, c := aU.Dims()
   746  		min := math.Inf(1)
   747  		for i := 0; i < r; i++ {
   748  			for j := 0; j < c; j++ {
   749  				v := aU.At(i, j)
   750  				if v < min {
   751  					min = v
   752  				}
   753  			}
   754  		}
   755  		return min
   756  	}
   757  }
   758  
   759  // A Normer can compute a norm of the matrix. Valid norms are:
   760  //  1 - The maximum absolute column sum
   761  //  2 - The Frobenius norm, the square root of the sum of the squares of the elements
   762  //  Inf - The maximum absolute row sum
   763  type Normer interface {
   764  	Norm(norm float64) float64
   765  }
   766  
   767  // Norm returns the specified norm of the matrix A. Valid norms are:
   768  //  1 - The maximum absolute column sum
   769  //  2 - The Frobenius norm, the square root of the sum of the squares of the elements
   770  //  Inf - The maximum absolute row sum
   771  //
   772  // If a is a Normer, its Norm method will be used to calculate the norm.
   773  //
   774  // Norm will panic with ErrNormOrder if an illegal norm is specified and with
   775  // ErrShape if the matrix has zero size.
   776  func Norm(a Matrix, norm float64) float64 {
   777  	r, c := a.Dims()
   778  	if r == 0 || c == 0 {
   779  		panic(ErrZeroLength)
   780  	}
   781  	m, trans := untransposeExtract(a)
   782  	if m, ok := m.(Normer); ok {
   783  		if trans {
   784  			switch norm {
   785  			case 1:
   786  				norm = math.Inf(1)
   787  			case math.Inf(1):
   788  				norm = 1
   789  			}
   790  		}
   791  		return m.Norm(norm)
   792  	}
   793  	switch norm {
   794  	default:
   795  		panic(ErrNormOrder)
   796  	case 1:
   797  		var max float64
   798  		for j := 0; j < c; j++ {
   799  			var sum float64
   800  			for i := 0; i < r; i++ {
   801  				sum += math.Abs(a.At(i, j))
   802  			}
   803  			if sum > max {
   804  				max = sum
   805  			}
   806  		}
   807  		return max
   808  	case 2:
   809  		var sum float64
   810  		for i := 0; i < r; i++ {
   811  			for j := 0; j < c; j++ {
   812  				v := a.At(i, j)
   813  				sum += v * v
   814  			}
   815  		}
   816  		return math.Sqrt(sum)
   817  	case math.Inf(1):
   818  		var max float64
   819  		for i := 0; i < r; i++ {
   820  			var sum float64
   821  			for j := 0; j < c; j++ {
   822  				sum += math.Abs(a.At(i, j))
   823  			}
   824  			if sum > max {
   825  				max = sum
   826  			}
   827  		}
   828  		return max
   829  	}
   830  }
   831  
   832  // normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
   833  func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
   834  	switch norm {
   835  	case 1:
   836  		n := lapack.MaxColumnSum
   837  		if aTrans {
   838  			n = lapack.MaxRowSum
   839  		}
   840  		return n
   841  	case 2:
   842  		return lapack.Frobenius
   843  	case math.Inf(1):
   844  		n := lapack.MaxRowSum
   845  		if aTrans {
   846  			n = lapack.MaxColumnSum
   847  		}
   848  		return n
   849  	default:
   850  		panic(ErrNormOrder)
   851  	}
   852  }
   853  
   854  // Sum returns the sum of the elements of the matrix.
   855  //
   856  // Sum will panic with ErrZeroLength if the matrix has zero size.
   857  func Sum(a Matrix) float64 {
   858  	r, c := a.Dims()
   859  	if r == 0 || c == 0 {
   860  		panic(ErrZeroLength)
   861  	}
   862  	var sum float64
   863  	aU, _ := untranspose(a)
   864  	switch rma := aU.(type) {
   865  	case RawSymmetricer:
   866  		rm := rma.RawSymmetric()
   867  		for i := 0; i < rm.N; i++ {
   868  			// Diagonals count once while off-diagonals count twice.
   869  			sum += rm.Data[i*rm.Stride+i]
   870  			var s float64
   871  			for _, v := range rm.Data[i*rm.Stride+i+1 : i*rm.Stride+rm.N] {
   872  				s += v
   873  			}
   874  			sum += 2 * s
   875  		}
   876  		return sum
   877  	case RawTriangular:
   878  		rm := rma.RawTriangular()
   879  		var startIdx, endIdx int
   880  		for i := 0; i < rm.N; i++ {
   881  			// Start and end index for this triangle-row.
   882  			switch rm.Uplo {
   883  			case blas.Upper:
   884  				startIdx = i
   885  				endIdx = rm.N
   886  			case blas.Lower:
   887  				startIdx = 0
   888  				endIdx = i + 1
   889  			default:
   890  				panic(badTriangle)
   891  			}
   892  			for _, v := range rm.Data[i*rm.Stride+startIdx : i*rm.Stride+endIdx] {
   893  				sum += v
   894  			}
   895  		}
   896  		return sum
   897  	case RawMatrixer:
   898  		rm := rma.RawMatrix()
   899  		for i := 0; i < rm.Rows; i++ {
   900  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
   901  				sum += v
   902  			}
   903  		}
   904  		return sum
   905  	case *VecDense:
   906  		rm := rma.RawVector()
   907  		for i := 0; i < rm.N; i++ {
   908  			sum += rm.Data[i*rm.Inc]
   909  		}
   910  		return sum
   911  	default:
   912  		r, c := a.Dims()
   913  		for i := 0; i < r; i++ {
   914  			for j := 0; j < c; j++ {
   915  				sum += a.At(i, j)
   916  			}
   917  		}
   918  		return sum
   919  	}
   920  }
   921  
   922  // A Tracer can compute the trace of the matrix. Trace must panic with ErrSquare
   923  // if the matrix is not square.
   924  type Tracer interface {
   925  	Trace() float64
   926  }
   927  
   928  // Trace returns the trace of the matrix. If a is a Tracer, its Trace method
   929  // will be used to calculate the matrix trace.
   930  //
   931  // Trace will panic with ErrSquare if the matrix is not square and with
   932  // ErrZeroLength if the matrix has zero size.
   933  func Trace(a Matrix) float64 {
   934  	r, c := a.Dims()
   935  	if r == 0 || c == 0 {
   936  		panic(ErrZeroLength)
   937  	}
   938  	m, _ := untransposeExtract(a)
   939  	if t, ok := m.(Tracer); ok {
   940  		return t.Trace()
   941  	}
   942  	if r != c {
   943  		panic(ErrSquare)
   944  	}
   945  	var v float64
   946  	for i := 0; i < r; i++ {
   947  		v += a.At(i, i)
   948  	}
   949  	return v
   950  }
   951  
   952  func min(a, b int) int {
   953  	if a < b {
   954  		return a
   955  	}
   956  	return b
   957  }
   958  
   959  func max(a, b int) int {
   960  	if a > b {
   961  		return a
   962  	}
   963  	return b
   964  }
   965  
   966  // use returns a float64 slice with l elements, using f if it
   967  // has the necessary capacity, otherwise creating a new slice.
   968  func use(f []float64, l int) []float64 {
   969  	if l <= cap(f) {
   970  		return f[:l]
   971  	}
   972  	return make([]float64, l)
   973  }
   974  
   975  // useZeroed returns a float64 slice with l elements, using f if it
   976  // has the necessary capacity, otherwise creating a new slice. The
   977  // elements of the returned slice are guaranteed to be zero.
   978  func useZeroed(f []float64, l int) []float64 {
   979  	if l <= cap(f) {
   980  		f = f[:l]
   981  		zero(f)
   982  		return f
   983  	}
   984  	return make([]float64, l)
   985  }
   986  
   987  // zero zeros the given slice's elements.
   988  func zero(f []float64) {
   989  	for i := range f {
   990  		f[i] = 0
   991  	}
   992  }
   993  
   994  // useInt returns an int slice with l elements, using i if it
   995  // has the necessary capacity, otherwise creating a new slice.
   996  func useInt(i []int, l int) []int {
   997  	if l <= cap(i) {
   998  		return i[:l]
   999  	}
  1000  	return make([]int, l)
  1001  }