github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/matrix.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/jingcheng-WU/gonum/blas"
    11  	"github.com/jingcheng-WU/gonum/blas/blas64"
    12  	"github.com/jingcheng-WU/gonum/floats/scalar"
    13  	"github.com/jingcheng-WU/gonum/lapack"
    14  	"github.com/jingcheng-WU/gonum/lapack/lapack64"
    15  )
    16  
// Matrix is the basic matrix interface type. Functions in this package accept
// any Matrix and fall back to element-wise access through At when no faster
// raw-storage path applies.
type Matrix interface {
	// Dims returns the dimensions of a Matrix: r rows and c columns.
	Dims() (r, c int)

	// At returns the value of a matrix element at row i, column j.
	// It will panic if i or j are out of bounds for the matrix.
	At(i, j int) float64

	// T returns the transpose of the Matrix. Whether T returns a copy of the
	// underlying data is implementation dependent.
	// This method may be implemented using the Transpose type, which
	// provides an implicit matrix transpose.
	T() Matrix
}
    32  
// allMatrix represents the extra set of methods that all mat Matrix types
// should satisfy. This is used to enforce compile-time consistency between all
// of the matrix types, especially helpful when adding new features.
type allMatrix interface {
	Reseter
	IsEmpty() bool
	Zero()
}

// denseMatrix represents the extra set of methods that all Dense Matrix types
// should satisfy. This is used to enforce compile-time consistency between the
// Dense types, especially helpful when adding new features.
type denseMatrix interface {
	DiagView() Diagonal
	Tracer
}

// Compile-time assertions that Transpose satisfies Matrix and Untransposer.
var (
	_ Matrix       = Transpose{}
	_ Untransposer = Transpose{}
)
    54  
    55  // Transpose is a type for performing an implicit matrix transpose. It implements
    56  // the Matrix interface, returning values from the transpose of the matrix within.
    57  type Transpose struct {
    58  	Matrix Matrix
    59  }
    60  
    61  // At returns the value of the element at row i and column j of the transposed
    62  // matrix, that is, row j and column i of the Matrix field.
    63  func (t Transpose) At(i, j int) float64 {
    64  	return t.Matrix.At(j, i)
    65  }
    66  
    67  // Dims returns the dimensions of the transposed matrix. The number of rows returned
    68  // is the number of columns in the Matrix field, and the number of columns is
    69  // the number of rows in the Matrix field.
    70  func (t Transpose) Dims() (r, c int) {
    71  	c, r = t.Matrix.Dims()
    72  	return r, c
    73  }
    74  
    75  // T performs an implicit transpose by returning the Matrix field.
    76  func (t Transpose) T() Matrix {
    77  	return t.Matrix
    78  }
    79  
    80  // Untranspose returns the Matrix field.
    81  func (t Transpose) Untranspose() Matrix {
    82  	return t.Matrix
    83  }
    84  
// Untransposer is a type that can undo an implicit transpose.
type Untransposer interface {
	// Note: This interface is needed to unify all of the Transpose types. In
	// the mat methods, we need to test if the Matrix has been implicitly
	// transposed. If this is checked by testing for the specific Transpose type
	// then the behavior will be different if the user uses T() or TTri() for a
	// triangular matrix.

	// Untranspose returns the underlying Matrix stored for the implicit transpose.
	Untranspose() Matrix
}

// UntransposeBander is a type that can undo an implicit band transpose.
type UntransposeBander interface {
	// UntransposeBand returns the underlying Banded stored for the implicit
	// transpose.
	UntransposeBand() Banded
}

// UntransposeTrier is a type that can undo an implicit triangular transpose.
type UntransposeTrier interface {
	// UntransposeTri returns the underlying Triangular stored for the implicit
	// transpose.
	UntransposeTri() Triangular
}

// UntransposeTriBander is a type that can undo an implicit triangular banded
// transpose.
type UntransposeTriBander interface {
	// UntransposeTriBand returns the underlying TriBanded stored for the
	// implicit transpose.
	UntransposeTriBand() TriBanded
}
   115  
// Mutable is a matrix interface type that allows elements to be altered.
type Mutable interface {
	// Set alters the matrix element at row i, column j to v.
	// It will panic if i or j are out of bounds for the matrix.
	Set(i, j int, v float64)

	Matrix
}

// A RowViewer can return a Vector reflecting a row that is backed by the matrix
// data. The Vector returned will have length equal to the number of columns.
type RowViewer interface {
	// RowView returns a view of row i.
	RowView(i int) Vector
}

// A RawRowViewer can return a slice of float64 reflecting a row that is backed
// by the matrix data.
type RawRowViewer interface {
	// RawRowView returns the backing data of row i.
	RawRowView(i int) []float64
}

// A ColViewer can return a Vector reflecting a column that is backed by the matrix
// data. The Vector returned will have length equal to the number of rows.
type ColViewer interface {
	// ColView returns a view of column j.
	ColView(j int) Vector
}

// A RawColViewer can return a slice of float64 reflecting a column that is
// backed by the matrix data.
type RawColViewer interface {
	// RawColView returns the backing data of column j.
	RawColView(j int) []float64
}
   148  
// A ClonerFrom can make a copy of a into the receiver, overwriting the previous value of the
// receiver. The clone operation does not make any restriction on shape and will not cause
// shadowing.
type ClonerFrom interface {
	CloneFrom(a Matrix)
}

// A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
// restricted operation. This is commonly used when the matrix is being used as a workspace
// or temporary matrix.
//
// If the matrix is a view, using Reset may result in data corruption in elements outside
// the view. Similarly, if the matrix shares backing data with another variable, using
// Reset may lead to unexpected changes in data values.
type Reseter interface {
	Reset()
}

// A Copier can make a copy of elements of a into the receiver. The submatrix copied
// starts at row and column 0 and has dimensions equal to the minimum dimensions of
// the two matrices. The number of rows and columns copied is returned.
// Copy will copy from a source that aliases the receiver unless the source is transposed;
// an aliasing transpose copy will panic with the exception for a special case when
// the source data has a unitary increment or stride.
type Copier interface {
	Copy(a Matrix) (r, c int)
}

// A Grower can grow the size of the represented matrix by the given number of rows and columns.
// Growing beyond the size given by the Caps method will result in the allocation of a new
// matrix and copying of the elements. If Grow is called with negative increments it will
// panic with ErrIndexOutOfRange.
type Grower interface {
	// Caps returns the number of rows and columns the matrix can grow to
	// without a new allocation.
	Caps() (r, c int)
	// Grow returns the matrix enlarged by r rows and c columns.
	Grow(r, c int) Matrix
}
   185  
// A BandWidther represents a banded matrix and can return the left and right
// half-bandwidths, k1 and k2.
type BandWidther interface {
	BandWidth() (k1, k2 int)
}

// A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
// on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
type RawMatrixSetter interface {
	SetRawMatrix(a blas64.General)
}

// A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
// slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
// Functions in this package index the returned Data as row-major, with element (i, j) at Data[i*Stride+j].
type RawMatrixer interface {
	RawMatrix() blas64.General
}

// A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
// slice will be reflected in the original matrix, changes to the Inc field will not.
type RawVectorer interface {
	RawVector() blas64.Vector
}
   209  
// A NonZeroDoer can call a function for each non-zero element of the receiver.
// The parameters of the function are the element indices and its value.
type NonZeroDoer interface {
	DoNonZero(func(i, j int, v float64))
}

// A RowNonZeroDoer can call a function for each non-zero element of row i of
// the receiver. The parameters of the function are the element indices and its value.
type RowNonZeroDoer interface {
	DoRowNonZero(i int, fn func(i, j int, v float64))
}

// A ColNonZeroDoer can call a function for each non-zero element of column j of
// the receiver. The parameters of the function are the element indices and its value.
type ColNonZeroDoer interface {
	DoColNonZero(j int, fn func(i, j int, v float64))
}
   227  
   228  // untranspose untransposes a matrix if applicable. If a is an Untransposer, then
   229  // untranspose returns the underlying matrix and true. If it is not, then it returns
   230  // the input matrix and false.
   231  func untranspose(a Matrix) (Matrix, bool) {
   232  	if ut, ok := a.(Untransposer); ok {
   233  		return ut.Untranspose(), true
   234  	}
   235  	return a, false
   236  }
   237  
// untransposeExtract returns an untransposed matrix in a built-in matrix type.
//
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
// Otherwise, if it implements a Raw method, an appropriate built-in type value
// is returned holding the raw matrix value of the input. If neither of these
// is possible, the untransposed matrix is returned.
//
// The returned bool reports whether a was an implicit transpose, as reported
// by untranspose. Raw values whose storage layout is not converted (symmetric
// storage other than blas.Upper, or a unit-diagonal triangular) are returned
// unconverted. The Raw cases are interface checks, so their order matters: a
// value implementing several Raw interfaces is handled by the first match.
func untransposeExtract(a Matrix) (Matrix, bool) {
	ut, trans := untranspose(a)
	switch m := ut.(type) {
	case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense, *VecDense:
		return m, trans
	// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
	case RawSymBander:
		rsb := m.RawSymBand()
		// Only upper-stored symmetric band data is wrapped.
		if rsb.Uplo != blas.Upper {
			return ut, trans
		}
		var sb SymBandDense
		sb.SetRawSymBand(rsb)
		return &sb, trans
	case RawTriBander:
		rtb := m.RawTriBand()
		// Unit-diagonal storage is not wrapped.
		if rtb.Diag == blas.Unit {
			return ut, trans
		}
		var tb TriBandDense
		tb.SetRawTriBand(rtb)
		return &tb, trans
	case RawBander:
		var b BandDense
		b.SetRawBand(m.RawBand())
		return &b, trans
	case RawTriangular:
		rt := m.RawTriangular()
		// Unit-diagonal storage is not wrapped.
		if rt.Diag == blas.Unit {
			return ut, trans
		}
		var t TriDense
		t.SetRawTriangular(rt)
		return &t, trans
	case RawSymmetricer:
		rs := m.RawSymmetric()
		// Only upper-stored symmetric data is wrapped.
		if rs.Uplo != blas.Upper {
			return ut, trans
		}
		var s SymDense
		s.SetRawSymmetric(rs)
		return &s, trans
	case RawMatrixer:
		var d Dense
		d.SetRawMatrix(m.RawMatrix())
		return &d, trans
	case RawVectorer:
		var v VecDense
		v.SetRawVector(m.RawVector())
		return &v, trans
	default:
		return ut, trans
	}
}
   298  
   299  // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
   300  // TODO(btracey): Add in fast paths to Row/Col for the other concrete types
   301  // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
   302  
   303  // Col copies the elements in the jth column of the matrix into the slice dst.
   304  // The length of the provided slice must equal the number of rows, unless the
   305  // slice is nil in which case a new slice is first allocated.
   306  func Col(dst []float64, j int, a Matrix) []float64 {
   307  	r, c := a.Dims()
   308  	if j < 0 || j >= c {
   309  		panic(ErrColAccess)
   310  	}
   311  	if dst == nil {
   312  		dst = make([]float64, r)
   313  	} else {
   314  		if len(dst) != r {
   315  			panic(ErrColLength)
   316  		}
   317  	}
   318  	aU, aTrans := untranspose(a)
   319  	if rm, ok := aU.(RawMatrixer); ok {
   320  		m := rm.RawMatrix()
   321  		if aTrans {
   322  			copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
   323  			return dst
   324  		}
   325  		blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]},
   326  			blas64.Vector{N: r, Inc: 1, Data: dst},
   327  		)
   328  		return dst
   329  	}
   330  	for i := 0; i < r; i++ {
   331  		dst[i] = a.At(i, j)
   332  	}
   333  	return dst
   334  }
   335  
   336  // Row copies the elements in the ith row of the matrix into the slice dst.
   337  // The length of the provided slice must equal the number of columns, unless the
   338  // slice is nil in which case a new slice is first allocated.
   339  func Row(dst []float64, i int, a Matrix) []float64 {
   340  	r, c := a.Dims()
   341  	if i < 0 || i >= r {
   342  		panic(ErrColAccess)
   343  	}
   344  	if dst == nil {
   345  		dst = make([]float64, c)
   346  	} else {
   347  		if len(dst) != c {
   348  			panic(ErrRowLength)
   349  		}
   350  	}
   351  	aU, aTrans := untranspose(a)
   352  	if rm, ok := aU.(RawMatrixer); ok {
   353  		m := rm.RawMatrix()
   354  		if aTrans {
   355  			blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]},
   356  				blas64.Vector{N: c, Inc: 1, Data: dst},
   357  			)
   358  			return dst
   359  		}
   360  		copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
   361  		return dst
   362  	}
   363  	for j := 0; j < c; j++ {
   364  		dst[j] = a.At(i, j)
   365  	}
   366  	return dst
   367  }
   368  
   369  // Cond returns the condition number of the given matrix under the given norm.
   370  // The condition number must be based on the 1-norm, 2-norm or ∞-norm.
   371  // Cond will panic with matrix.ErrShape if the matrix has zero size.
   372  //
   373  // BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
   374  // is inaccurate, although is typically the right order of magnitude. See
   375  // https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
   376  // change with the resolution of this bug, the result from Cond will match the
   377  // condition number used internally.
   378  func Cond(a Matrix, norm float64) float64 {
   379  	m, n := a.Dims()
   380  	if m == 0 || n == 0 {
   381  		panic(ErrShape)
   382  	}
   383  	var lnorm lapack.MatrixNorm
   384  	switch norm {
   385  	default:
   386  		panic("mat: bad norm value")
   387  	case 1:
   388  		lnorm = lapack.MaxColumnSum
   389  	case 2:
   390  		var svd SVD
   391  		ok := svd.Factorize(a, SVDNone)
   392  		if !ok {
   393  			return math.Inf(1)
   394  		}
   395  		return svd.Cond()
   396  	case math.Inf(1):
   397  		lnorm = lapack.MaxRowSum
   398  	}
   399  
   400  	if m == n {
   401  		// Use the LU decomposition to compute the condition number.
   402  		var lu LU
   403  		lu.factorize(a, lnorm)
   404  		return lu.Cond()
   405  	}
   406  	if m > n {
   407  		// Use the QR factorization to compute the condition number.
   408  		var qr QR
   409  		qr.factorize(a, lnorm)
   410  		return qr.Cond()
   411  	}
   412  	// Use the LQ factorization to compute the condition number.
   413  	var lq LQ
   414  	lq.factorize(a, lnorm)
   415  	return lq.Cond()
   416  }
   417  
   418  // Det returns the determinant of the matrix a. In many expressions using LogDet
   419  // will be more numerically stable.
   420  func Det(a Matrix) float64 {
   421  	det, sign := LogDet(a)
   422  	return math.Exp(det) * sign
   423  }
   424  
   425  // Dot returns the sum of the element-wise product of a and b.
   426  // Dot panics if the matrix sizes are unequal.
   427  func Dot(a, b Vector) float64 {
   428  	la := a.Len()
   429  	lb := b.Len()
   430  	if la != lb {
   431  		panic(ErrShape)
   432  	}
   433  	if arv, ok := a.(RawVectorer); ok {
   434  		if brv, ok := b.(RawVectorer); ok {
   435  			return blas64.Dot(arv.RawVector(), brv.RawVector())
   436  		}
   437  	}
   438  	var sum float64
   439  	for i := 0; i < la; i++ {
   440  		sum += a.At(i, 0) * b.At(i, 0)
   441  	}
   442  	return sum
   443  }
   444  
   445  // Equal returns whether the matrices a and b have the same size
   446  // and are element-wise equal.
   447  func Equal(a, b Matrix) bool {
   448  	ar, ac := a.Dims()
   449  	br, bc := b.Dims()
   450  	if ar != br || ac != bc {
   451  		return false
   452  	}
   453  	aU, aTrans := untranspose(a)
   454  	bU, bTrans := untranspose(b)
   455  	if rma, ok := aU.(RawMatrixer); ok {
   456  		if rmb, ok := bU.(RawMatrixer); ok {
   457  			ra := rma.RawMatrix()
   458  			rb := rmb.RawMatrix()
   459  			if aTrans == bTrans {
   460  				for i := 0; i < ra.Rows; i++ {
   461  					for j := 0; j < ra.Cols; j++ {
   462  						if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
   463  							return false
   464  						}
   465  					}
   466  				}
   467  				return true
   468  			}
   469  			for i := 0; i < ra.Rows; i++ {
   470  				for j := 0; j < ra.Cols; j++ {
   471  					if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
   472  						return false
   473  					}
   474  				}
   475  			}
   476  			return true
   477  		}
   478  	}
   479  	if rma, ok := aU.(RawSymmetricer); ok {
   480  		if rmb, ok := bU.(RawSymmetricer); ok {
   481  			ra := rma.RawSymmetric()
   482  			rb := rmb.RawSymmetric()
   483  			// Symmetric matrices are always upper and equal to their transpose.
   484  			for i := 0; i < ra.N; i++ {
   485  				for j := i; j < ra.N; j++ {
   486  					if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
   487  						return false
   488  					}
   489  				}
   490  			}
   491  			return true
   492  		}
   493  	}
   494  	if ra, ok := aU.(*VecDense); ok {
   495  		if rb, ok := bU.(*VecDense); ok {
   496  			// If the raw vectors are the same length they must either both be
   497  			// transposed or both not transposed (or have length 1).
   498  			for i := 0; i < ra.mat.N; i++ {
   499  				if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
   500  					return false
   501  				}
   502  			}
   503  			return true
   504  		}
   505  	}
   506  	for i := 0; i < ar; i++ {
   507  		for j := 0; j < ac; j++ {
   508  			if a.At(i, j) != b.At(i, j) {
   509  				return false
   510  			}
   511  		}
   512  	}
   513  	return true
   514  }
   515  
   516  // EqualApprox returns whether the matrices a and b have the same size and contain all equal
   517  // elements with tolerance for element-wise equality specified by epsilon. Matrices
   518  // with non-equal shapes are not equal.
   519  func EqualApprox(a, b Matrix, epsilon float64) bool {
   520  	ar, ac := a.Dims()
   521  	br, bc := b.Dims()
   522  	if ar != br || ac != bc {
   523  		return false
   524  	}
   525  	aU, aTrans := untranspose(a)
   526  	bU, bTrans := untranspose(b)
   527  	if rma, ok := aU.(RawMatrixer); ok {
   528  		if rmb, ok := bU.(RawMatrixer); ok {
   529  			ra := rma.RawMatrix()
   530  			rb := rmb.RawMatrix()
   531  			if aTrans == bTrans {
   532  				for i := 0; i < ra.Rows; i++ {
   533  					for j := 0; j < ra.Cols; j++ {
   534  						if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
   535  							return false
   536  						}
   537  					}
   538  				}
   539  				return true
   540  			}
   541  			for i := 0; i < ra.Rows; i++ {
   542  				for j := 0; j < ra.Cols; j++ {
   543  					if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
   544  						return false
   545  					}
   546  				}
   547  			}
   548  			return true
   549  		}
   550  	}
   551  	if rma, ok := aU.(RawSymmetricer); ok {
   552  		if rmb, ok := bU.(RawSymmetricer); ok {
   553  			ra := rma.RawSymmetric()
   554  			rb := rmb.RawSymmetric()
   555  			// Symmetric matrices are always upper and equal to their transpose.
   556  			for i := 0; i < ra.N; i++ {
   557  				for j := i; j < ra.N; j++ {
   558  					if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
   559  						return false
   560  					}
   561  				}
   562  			}
   563  			return true
   564  		}
   565  	}
   566  	if ra, ok := aU.(*VecDense); ok {
   567  		if rb, ok := bU.(*VecDense); ok {
   568  			// If the raw vectors are the same length they must either both be
   569  			// transposed or both not transposed (or have length 1).
   570  			for i := 0; i < ra.mat.N; i++ {
   571  				if !scalar.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
   572  					return false
   573  				}
   574  			}
   575  			return true
   576  		}
   577  	}
   578  	for i := 0; i < ar; i++ {
   579  		for j := 0; j < ac; j++ {
   580  			if !scalar.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
   581  				return false
   582  			}
   583  		}
   584  	}
   585  	return true
   586  }
   587  
   588  // LogDet returns the log of the determinant and the sign of the determinant
   589  // for the matrix that has been factorized. Numerical stability in product and
   590  // division expressions is generally improved by working in log space.
   591  func LogDet(a Matrix) (det float64, sign float64) {
   592  	// TODO(btracey): Add specialized routines for TriDense, etc.
   593  	var lu LU
   594  	lu.Factorize(a)
   595  	return lu.LogDet()
   596  }
   597  
   598  // Max returns the largest element value of the matrix A.
   599  // Max will panic with matrix.ErrShape if the matrix has zero size.
   600  func Max(a Matrix) float64 {
   601  	r, c := a.Dims()
   602  	if r == 0 || c == 0 {
   603  		panic(ErrShape)
   604  	}
   605  	// Max(A) = Max(Aᵀ)
   606  	aU, _ := untranspose(a)
   607  	switch m := aU.(type) {
   608  	case RawMatrixer:
   609  		rm := m.RawMatrix()
   610  		max := math.Inf(-1)
   611  		for i := 0; i < rm.Rows; i++ {
   612  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
   613  				if v > max {
   614  					max = v
   615  				}
   616  			}
   617  		}
   618  		return max
   619  	case RawTriangular:
   620  		rm := m.RawTriangular()
   621  		// The max of a triangular is at least 0 unless the size is 1.
   622  		if rm.N == 1 {
   623  			return rm.Data[0]
   624  		}
   625  		max := 0.0
   626  		if rm.Uplo == blas.Upper {
   627  			for i := 0; i < rm.N; i++ {
   628  				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   629  					if v > max {
   630  						max = v
   631  					}
   632  				}
   633  			}
   634  			return max
   635  		}
   636  		for i := 0; i < rm.N; i++ {
   637  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
   638  				if v > max {
   639  					max = v
   640  				}
   641  			}
   642  		}
   643  		return max
   644  	case RawSymmetricer:
   645  		rm := m.RawSymmetric()
   646  		if rm.Uplo != blas.Upper {
   647  			panic(badSymTriangle)
   648  		}
   649  		max := math.Inf(-1)
   650  		for i := 0; i < rm.N; i++ {
   651  			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   652  				if v > max {
   653  					max = v
   654  				}
   655  			}
   656  		}
   657  		return max
   658  	default:
   659  		r, c := aU.Dims()
   660  		max := math.Inf(-1)
   661  		for i := 0; i < r; i++ {
   662  			for j := 0; j < c; j++ {
   663  				v := aU.At(i, j)
   664  				if v > max {
   665  					max = v
   666  				}
   667  			}
   668  		}
   669  		return max
   670  	}
   671  }
   672  
   673  // Min returns the smallest element value of the matrix A.
   674  // Min will panic with matrix.ErrShape if the matrix has zero size.
   675  func Min(a Matrix) float64 {
   676  	r, c := a.Dims()
   677  	if r == 0 || c == 0 {
   678  		panic(ErrShape)
   679  	}
   680  	// Min(A) = Min(Aᵀ)
   681  	aU, _ := untranspose(a)
   682  	switch m := aU.(type) {
   683  	case RawMatrixer:
   684  		rm := m.RawMatrix()
   685  		min := math.Inf(1)
   686  		for i := 0; i < rm.Rows; i++ {
   687  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
   688  				if v < min {
   689  					min = v
   690  				}
   691  			}
   692  		}
   693  		return min
   694  	case RawTriangular:
   695  		rm := m.RawTriangular()
   696  		// The min of a triangular is at most 0 unless the size is 1.
   697  		if rm.N == 1 {
   698  			return rm.Data[0]
   699  		}
   700  		min := 0.0
   701  		if rm.Uplo == blas.Upper {
   702  			for i := 0; i < rm.N; i++ {
   703  				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   704  					if v < min {
   705  						min = v
   706  					}
   707  				}
   708  			}
   709  			return min
   710  		}
   711  		for i := 0; i < rm.N; i++ {
   712  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
   713  				if v < min {
   714  					min = v
   715  				}
   716  			}
   717  		}
   718  		return min
   719  	case RawSymmetricer:
   720  		rm := m.RawSymmetric()
   721  		if rm.Uplo != blas.Upper {
   722  			panic(badSymTriangle)
   723  		}
   724  		min := math.Inf(1)
   725  		for i := 0; i < rm.N; i++ {
   726  			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
   727  				if v < min {
   728  					min = v
   729  				}
   730  			}
   731  		}
   732  		return min
   733  	default:
   734  		r, c := aU.Dims()
   735  		min := math.Inf(1)
   736  		for i := 0; i < r; i++ {
   737  			for j := 0; j < c; j++ {
   738  				v := aU.At(i, j)
   739  				if v < min {
   740  					min = v
   741  				}
   742  			}
   743  		}
   744  		return min
   745  	}
   746  }
   747  
// Norm returns the specified norm of the matrix A. Valid norms are:
//  1 - The maximum absolute column sum
//  2 - The Frobenius norm, the square root of the sum of the squares of the elements
//  Inf - The maximum absolute row sum
//
// Norm will panic with ErrNormOrder if an illegal norm order is specified and
// with ErrShape if the matrix has zero size.
func Norm(a Matrix, norm float64) float64 {
	r, c := a.Dims()
	if r == 0 || c == 0 {
		panic(ErrShape)
	}
	// The raw fast paths work on the untransposed matrix. normLapack swaps the
	// column-sum and row-sum norms when a was an implicit transpose.
	aU, aTrans := untranspose(a)
	var work []float64
	switch rma := aU.(type) {
	case RawMatrixer:
		rm := rma.RawMatrix()
		n := normLapack(norm, aTrans)
		// Workspace is only requested for the column-sum norm here; it is
		// handed back with putFloats when Norm returns.
		if n == lapack.MaxColumnSum {
			work = getFloats(rm.Cols, false)
			defer putFloats(work)
		}
		return lapack64.Lange(n, rm, work)
	case RawTriangular:
		rm := rma.RawTriangular()
		n := normLapack(norm, aTrans)
		if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
			work = getFloats(rm.N, false)
			defer putFloats(work)
		}
		return lapack64.Lantr(n, rm, work)
	case RawSymmetricer:
		rm := rma.RawSymmetric()
		n := normLapack(norm, aTrans)
		if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
			work = getFloats(rm.N, false)
			defer putFloats(work)
		}
		return lapack64.Lansy(n, rm, work)
	case *VecDense:
		rv := rma.RawVector()
		switch norm {
		default:
			panic(ErrNormOrder)
		case 1:
			// A transposed vector is a single row, so each column sum is one
			// element and the 1-norm is the largest magnitude.
			if aTrans {
				imax := blas64.Iamax(rv)
				return math.Abs(rma.At(imax, 0))
			}
			return blas64.Asum(rv)
		case 2:
			return blas64.Nrm2(rv)
		case math.Inf(1):
			// A transposed vector is a single row, so the max row sum is the
			// sum of all magnitudes; otherwise each row holds one element.
			if aTrans {
				return blas64.Asum(rv)
			}
			imax := blas64.Iamax(rv)
			return math.Abs(rma.At(imax, 0))
		}
	}
	// Generic fallback for other Matrix implementations: evaluate the norm
	// element by element through At on a itself.
	switch norm {
	default:
		panic(ErrNormOrder)
	case 1:
		var max float64
		for j := 0; j < c; j++ {
			var sum float64
			for i := 0; i < r; i++ {
				sum += math.Abs(a.At(i, j))
			}
			if sum > max {
				max = sum
			}
		}
		return max
	case 2:
		var sum float64
		for i := 0; i < r; i++ {
			for j := 0; j < c; j++ {
				v := a.At(i, j)
				sum += v * v
			}
		}
		return math.Sqrt(sum)
	case math.Inf(1):
		var max float64
		for i := 0; i < r; i++ {
			var sum float64
			for j := 0; j < c; j++ {
				sum += math.Abs(a.At(i, j))
			}
			if sum > max {
				max = sum
			}
		}
		return max
	}
}
   846  
   847  // normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
   848  func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
   849  	switch norm {
   850  	case 1:
   851  		n := lapack.MaxColumnSum
   852  		if aTrans {
   853  			n = lapack.MaxRowSum
   854  		}
   855  		return n
   856  	case 2:
   857  		return lapack.Frobenius
   858  	case math.Inf(1):
   859  		n := lapack.MaxRowSum
   860  		if aTrans {
   861  			n = lapack.MaxColumnSum
   862  		}
   863  		return n
   864  	default:
   865  		panic(ErrNormOrder)
   866  	}
   867  }
   868  
   869  // Sum returns the sum of the elements of the matrix.
   870  func Sum(a Matrix) float64 {
   871  
   872  	var sum float64
   873  	aU, _ := untranspose(a)
   874  	switch rma := aU.(type) {
   875  	case RawSymmetricer:
   876  		rm := rma.RawSymmetric()
   877  		for i := 0; i < rm.N; i++ {
   878  			// Diagonals count once while off-diagonals count twice.
   879  			sum += rm.Data[i*rm.Stride+i]
   880  			var s float64
   881  			for _, v := range rm.Data[i*rm.Stride+i+1 : i*rm.Stride+rm.N] {
   882  				s += v
   883  			}
   884  			sum += 2 * s
   885  		}
   886  		return sum
   887  	case RawTriangular:
   888  		rm := rma.RawTriangular()
   889  		var startIdx, endIdx int
   890  		for i := 0; i < rm.N; i++ {
   891  			// Start and end index for this triangle-row.
   892  			switch rm.Uplo {
   893  			case blas.Upper:
   894  				startIdx = i
   895  				endIdx = rm.N
   896  			case blas.Lower:
   897  				startIdx = 0
   898  				endIdx = i + 1
   899  			default:
   900  				panic(badTriangle)
   901  			}
   902  			for _, v := range rm.Data[i*rm.Stride+startIdx : i*rm.Stride+endIdx] {
   903  				sum += v
   904  			}
   905  		}
   906  		return sum
   907  	case RawMatrixer:
   908  		rm := rma.RawMatrix()
   909  		for i := 0; i < rm.Rows; i++ {
   910  			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
   911  				sum += v
   912  			}
   913  		}
   914  		return sum
   915  	case *VecDense:
   916  		rm := rma.RawVector()
   917  		for i := 0; i < rm.N; i++ {
   918  			sum += rm.Data[i*rm.Inc]
   919  		}
   920  		return sum
   921  	default:
   922  		r, c := a.Dims()
   923  		for i := 0; i < r; i++ {
   924  			for j := 0; j < c; j++ {
   925  				sum += a.At(i, j)
   926  			}
   927  		}
   928  		return sum
   929  	}
   930  }
   931  
// A Tracer can compute the trace of the matrix, the sum of its diagonal
// elements. Trace must panic if the matrix is not square.
type Tracer interface {
	Trace() float64
}
   937  
   938  // Trace returns the trace of the matrix. Trace will panic if the
   939  // matrix is not square. If a is a Tracer, its Trace method will be
   940  // used to calculate the matrix trace.
   941  func Trace(a Matrix) float64 {
   942  	m, _ := untransposeExtract(a)
   943  	if t, ok := m.(Tracer); ok {
   944  		return t.Trace()
   945  	}
   946  	r, c := a.Dims()
   947  	if r != c {
   948  		panic(ErrSquare)
   949  	}
   950  	var v float64
   951  	for i := 0; i < r; i++ {
   952  		v += a.At(i, i)
   953  	}
   954  	return v
   955  }
   956  
   957  func min(a, b int) int {
   958  	if a < b {
   959  		return a
   960  	}
   961  	return b
   962  }
   963  
   964  func max(a, b int) int {
   965  	if a > b {
   966  		return a
   967  	}
   968  	return b
   969  }
   970  
   971  // use returns a float64 slice with l elements, using f if it
   972  // has the necessary capacity, otherwise creating a new slice.
   973  func use(f []float64, l int) []float64 {
   974  	if l <= cap(f) {
   975  		return f[:l]
   976  	}
   977  	return make([]float64, l)
   978  }
   979  
   980  // useZeroed returns a float64 slice with l elements, using f if it
   981  // has the necessary capacity, otherwise creating a new slice. The
   982  // elements of the returned slice are guaranteed to be zero.
   983  func useZeroed(f []float64, l int) []float64 {
   984  	if l <= cap(f) {
   985  		f = f[:l]
   986  		zero(f)
   987  		return f
   988  	}
   989  	return make([]float64, l)
   990  }
   991  
   992  // zero zeros the given slice's elements.
   993  func zero(f []float64) {
   994  	for i := range f {
   995  		f[i] = 0
   996  	}
   997  }
   998  
   999  // useInt returns an int slice with l elements, using i if it
  1000  // has the necessary capacity, otherwise creating a new slice.
  1001  func useInt(i []int, l int) []int {
  1002  	if l <= cap(i) {
  1003  		return i[:l]
  1004  	}
  1005  	return make([]int, l)
  1006  }