gonum.org/v1/gonum@v0.14.0/mat/dense_arithmetic.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  	"gonum.org/v1/gonum/lapack/lapack64"
    13  )
    14  
    15  // Add adds a and b element-wise, placing the result in the receiver. Add
    16  // will panic if the two matrices do not have the same shape.
    17  func (m *Dense) Add(a, b Matrix) {
    18  	ar, ac := a.Dims()
    19  	br, bc := b.Dims()
    20  	if ar != br || ac != bc {
    21  		panic(ErrShape)
    22  	}
    23  
    24  	aU, aTrans := untransposeExtract(a)
    25  	bU, bTrans := untransposeExtract(b)
    26  	m.reuseAsNonZeroed(ar, ac)
    27  
    28  	if arm, ok := a.(*Dense); ok {
    29  		if brm, ok := b.(*Dense); ok {
    30  			amat, bmat := arm.mat, brm.mat
    31  			if m != aU {
    32  				m.checkOverlap(amat)
    33  			}
    34  			if m != bU {
    35  				m.checkOverlap(bmat)
    36  			}
    37  			for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
    38  				for i, v := range amat.Data[ja : ja+ac] {
    39  					m.mat.Data[i+jm] = v + bmat.Data[i+jb]
    40  				}
    41  			}
    42  			return
    43  		}
    44  	}
    45  
    46  	m.checkOverlapMatrix(aU)
    47  	m.checkOverlapMatrix(bU)
    48  	var restore func()
    49  	if aTrans && m == aU {
    50  		m, restore = m.isolatedWorkspace(aU)
    51  		defer restore()
    52  	} else if bTrans && m == bU {
    53  		m, restore = m.isolatedWorkspace(bU)
    54  		defer restore()
    55  	}
    56  
    57  	for r := 0; r < ar; r++ {
    58  		for c := 0; c < ac; c++ {
    59  			m.set(r, c, a.At(r, c)+b.At(r, c))
    60  		}
    61  	}
    62  }
    63  
    64  // Sub subtracts the matrix b from a, placing the result in the receiver. Sub
    65  // will panic if the two matrices do not have the same shape.
    66  func (m *Dense) Sub(a, b Matrix) {
    67  	ar, ac := a.Dims()
    68  	br, bc := b.Dims()
    69  	if ar != br || ac != bc {
    70  		panic(ErrShape)
    71  	}
    72  
    73  	aU, aTrans := untransposeExtract(a)
    74  	bU, bTrans := untransposeExtract(b)
    75  	m.reuseAsNonZeroed(ar, ac)
    76  
    77  	if arm, ok := a.(*Dense); ok {
    78  		if brm, ok := b.(*Dense); ok {
    79  			amat, bmat := arm.mat, brm.mat
    80  			if m != aU {
    81  				m.checkOverlap(amat)
    82  			}
    83  			if m != bU {
    84  				m.checkOverlap(bmat)
    85  			}
    86  			for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
    87  				for i, v := range amat.Data[ja : ja+ac] {
    88  					m.mat.Data[i+jm] = v - bmat.Data[i+jb]
    89  				}
    90  			}
    91  			return
    92  		}
    93  	}
    94  
    95  	m.checkOverlapMatrix(aU)
    96  	m.checkOverlapMatrix(bU)
    97  	var restore func()
    98  	if aTrans && m == aU {
    99  		m, restore = m.isolatedWorkspace(aU)
   100  		defer restore()
   101  	} else if bTrans && m == bU {
   102  		m, restore = m.isolatedWorkspace(bU)
   103  		defer restore()
   104  	}
   105  
   106  	for r := 0; r < ar; r++ {
   107  		for c := 0; c < ac; c++ {
   108  			m.set(r, c, a.At(r, c)-b.At(r, c))
   109  		}
   110  	}
   111  }
   112  
   113  // MulElem performs element-wise multiplication of a and b, placing the result
   114  // in the receiver. MulElem will panic if the two matrices do not have the same
   115  // shape.
   116  func (m *Dense) MulElem(a, b Matrix) {
   117  	ar, ac := a.Dims()
   118  	br, bc := b.Dims()
   119  	if ar != br || ac != bc {
   120  		panic(ErrShape)
   121  	}
   122  
   123  	aU, aTrans := untransposeExtract(a)
   124  	bU, bTrans := untransposeExtract(b)
   125  	m.reuseAsNonZeroed(ar, ac)
   126  
   127  	if arm, ok := a.(*Dense); ok {
   128  		if brm, ok := b.(*Dense); ok {
   129  			amat, bmat := arm.mat, brm.mat
   130  			if m != aU {
   131  				m.checkOverlap(amat)
   132  			}
   133  			if m != bU {
   134  				m.checkOverlap(bmat)
   135  			}
   136  			for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
   137  				for i, v := range amat.Data[ja : ja+ac] {
   138  					m.mat.Data[i+jm] = v * bmat.Data[i+jb]
   139  				}
   140  			}
   141  			return
   142  		}
   143  	}
   144  
   145  	m.checkOverlapMatrix(aU)
   146  	m.checkOverlapMatrix(bU)
   147  	var restore func()
   148  	if aTrans && m == aU {
   149  		m, restore = m.isolatedWorkspace(aU)
   150  		defer restore()
   151  	} else if bTrans && m == bU {
   152  		m, restore = m.isolatedWorkspace(bU)
   153  		defer restore()
   154  	}
   155  
   156  	for r := 0; r < ar; r++ {
   157  		for c := 0; c < ac; c++ {
   158  			m.set(r, c, a.At(r, c)*b.At(r, c))
   159  		}
   160  	}
   161  }
   162  
   163  // DivElem performs element-wise division of a by b, placing the result
   164  // in the receiver. DivElem will panic if the two matrices do not have the same
   165  // shape.
   166  func (m *Dense) DivElem(a, b Matrix) {
   167  	ar, ac := a.Dims()
   168  	br, bc := b.Dims()
   169  	if ar != br || ac != bc {
   170  		panic(ErrShape)
   171  	}
   172  
   173  	aU, aTrans := untransposeExtract(a)
   174  	bU, bTrans := untransposeExtract(b)
   175  	m.reuseAsNonZeroed(ar, ac)
   176  
   177  	if arm, ok := a.(*Dense); ok {
   178  		if brm, ok := b.(*Dense); ok {
   179  			amat, bmat := arm.mat, brm.mat
   180  			if m != aU {
   181  				m.checkOverlap(amat)
   182  			}
   183  			if m != bU {
   184  				m.checkOverlap(bmat)
   185  			}
   186  			for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
   187  				for i, v := range amat.Data[ja : ja+ac] {
   188  					m.mat.Data[i+jm] = v / bmat.Data[i+jb]
   189  				}
   190  			}
   191  			return
   192  		}
   193  	}
   194  
   195  	m.checkOverlapMatrix(aU)
   196  	m.checkOverlapMatrix(bU)
   197  	var restore func()
   198  	if aTrans && m == aU {
   199  		m, restore = m.isolatedWorkspace(aU)
   200  		defer restore()
   201  	} else if bTrans && m == bU {
   202  		m, restore = m.isolatedWorkspace(bU)
   203  		defer restore()
   204  	}
   205  
   206  	for r := 0; r < ar; r++ {
   207  		for c := 0; c < ac; c++ {
   208  			m.set(r, c, a.At(r, c)/b.At(r, c))
   209  		}
   210  	}
   211  }
   212  
   213  // Inverse computes the inverse of the matrix a, storing the result into the
   214  // receiver. If a is ill-conditioned, a Condition error will be returned.
   215  // Note that matrix inversion is numerically unstable, and should generally
   216  // be avoided where possible, for example by using the Solve routines.
   217  func (m *Dense) Inverse(a Matrix) error {
   218  	// TODO(btracey): Special case for RawTriangular, etc.
   219  	r, c := a.Dims()
   220  	if r != c {
   221  		panic(ErrSquare)
   222  	}
   223  	m.reuseAsNonZeroed(a.Dims())
   224  	aU, aTrans := untransposeExtract(a)
   225  	switch rm := aU.(type) {
   226  	case *Dense:
   227  		if m != aU || aTrans {
   228  			if m == aU || m.checkOverlap(rm.mat) {
   229  				tmp := getDenseWorkspace(r, c, false)
   230  				tmp.Copy(a)
   231  				m.Copy(tmp)
   232  				putDenseWorkspace(tmp)
   233  				break
   234  			}
   235  			m.Copy(a)
   236  		}
   237  	default:
   238  		m.Copy(a)
   239  	}
   240  	// Compute the norm of A.
   241  	work := getFloat64s(4*r, false) // Length must be at least 4*r for Gecon.
   242  	norm := lapack64.Lange(CondNorm, m.mat, work)
   243  	// Compute the LU factorization of A.
   244  	ipiv := getInts(r, false)
   245  	defer putInts(ipiv)
   246  	ok := lapack64.Getrf(m.mat, ipiv)
   247  	if !ok {
   248  		// A is exactly singular.
   249  		return Condition(math.Inf(1))
   250  	}
   251  	// Compute the condition number of A using the LU factorization.
   252  	iwork := getInts(r, false)
   253  	defer putInts(iwork)
   254  	rcond := lapack64.Gecon(CondNorm, m.mat, norm, work, iwork)
   255  	// Compute A^{-1} from the LU factorization regardless of the value of rcond.
   256  	lapack64.Getri(m.mat, ipiv, work, -1)
   257  	if int(work[0]) > len(work) {
   258  		l := int(work[0])
   259  		putFloat64s(work)
   260  		work = getFloat64s(l, false)
   261  	}
   262  	defer putFloat64s(work)
   263  	ok = lapack64.Getri(m.mat, ipiv, work, len(work))
   264  	if !ok || rcond == 0 {
   265  		// A is exactly singular.
   266  		return Condition(math.Inf(1))
   267  	}
   268  	// Check whether A is singular for computational purposes.
   269  	cond := 1 / rcond
   270  	if cond > ConditionTolerance {
   271  		return Condition(cond)
   272  	}
   273  	return nil
   274  }
   275  
// Mul takes the matrix product of a and b, placing the result in the receiver.
// If the number of columns in a does not equal the number of rows in b, Mul will panic.
//
// The untransposed operand types select a BLAS routine: Gemm for two *Dense
// operands, Symm when one operand is a *SymDense, Trmm when one is a
// *TriDense, and Gemv or Gemm when one is a *VecDense. Every other
// combination goes through an element-wise loop that uses At.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()

	if ac != br {
		panic(ErrShape)
	}

	aU, aTrans := untransposeExtract(a)
	bU, bTrans := untransposeExtract(b)
	m.reuseAsNonZeroed(ar, bc)
	// When the receiver is also an operand, continue with the matrix
	// returned by isolatedWorkspace. A non-nil restore also marks that the
	// *Dense overlap checks below are skipped.
	var restore func()
	if m == aU {
		m, restore = m.isolatedWorkspace(aU)
		defer restore()
	} else if m == bU {
		m, restore = m.isolatedWorkspace(bU)
		defer restore()
	}
	// Translate the transpose flags into their BLAS equivalents.
	aT := blas.NoTrans
	if aTrans {
		aT = blas.Trans
	}
	bT := blas.NoTrans
	if bTrans {
		bT = blas.Trans
	}

	// Some of the cases do not have a transpose option, so create
	// temporary memory.
	// C = Aᵀ * B = (Bᵀ * A)ᵀ
	// Cᵀ = Bᵀ * A.
	if aU, ok := aU.(*Dense); ok {
		if restore == nil {
			m.checkOverlap(aU.mat)
		}
		switch bU := bU.(type) {
		case *Dense:
			if restore == nil {
				m.checkOverlap(bU.mat)
			}
			blas64.Gemm(aT, bT, 1, aU.mat, bU.mat, 0, m.mat)
			return

		case *SymDense:
			if aTrans {
				// Compute the transposed product into a workspace
				// and copy its transpose into the receiver.
				c := getDenseWorkspace(ac, ar, false)
				blas64.Symm(blas.Left, 1, bU.mat, aU.mat, 0, c.mat)
				strictCopy(m, c.T())
				putDenseWorkspace(c)
				return
			}
			blas64.Symm(blas.Right, 1, bU.mat, aU.mat, 0, m.mat)
			return

		case *TriDense:
			// Trmm updates in place, so copy aU first.
			if aTrans {
				c := getDenseWorkspace(ac, ar, false)
				var tmp Dense
				tmp.SetRawMatrix(aU.mat)
				c.Copy(&tmp)
				// The product is formed transposed, so the
				// transpose flag of b is inverted for this call.
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Trmm(blas.Left, bT, 1, bU.mat, c.mat)
				strictCopy(m, c.T())
				putDenseWorkspace(c)
				return
			}
			m.Copy(a)
			blas64.Trmm(blas.Right, bT, 1, bU.mat, m.mat)
			return

		case *VecDense:
			m.checkOverlap(bU.asGeneral())
			bvec := bU.RawVector()
			if bTrans {
				// {ar,1} x {1,bc}, which is not a vector.
				// Instead, construct B as a General.
				bmat := blas64.General{
					Rows:   bc,
					Cols:   1,
					Stride: bvec.Inc,
					Data:   bvec.Data,
				}
				blas64.Gemm(aT, bT, 1, aU.mat, bmat, 0, m.mat)
				return
			}
			// The result is a single column, so the receiver's
			// data is passed as a vector whose increment is the
			// receiver's stride.
			cvec := blas64.Vector{
				Inc:  m.mat.Stride,
				Data: m.mat.Data,
			}
			blas64.Gemv(aT, 1, aU.mat, bvec, 0, cvec)
			return
		}
	}
	// Mirror of the block above for the case where b, not a, is the *Dense
	// operand.
	if bU, ok := bU.(*Dense); ok {
		if restore == nil {
			m.checkOverlap(bU.mat)
		}
		switch aU := aU.(type) {
		case *SymDense:
			if bTrans {
				c := getDenseWorkspace(bc, br, false)
				blas64.Symm(blas.Right, 1, aU.mat, bU.mat, 0, c.mat)
				strictCopy(m, c.T())
				putDenseWorkspace(c)
				return
			}
			blas64.Symm(blas.Left, 1, aU.mat, bU.mat, 0, m.mat)
			return

		case *TriDense:
			// Trmm updates in place, so copy bU first.
			if bTrans {
				c := getDenseWorkspace(bc, br, false)
				var tmp Dense
				tmp.SetRawMatrix(bU.mat)
				c.Copy(&tmp)
				// The product is formed transposed, so the
				// transpose flag of a is inverted for this call.
				aT := blas.Trans
				if aTrans {
					aT = blas.NoTrans
				}
				blas64.Trmm(blas.Right, aT, 1, aU.mat, c.mat)
				strictCopy(m, c.T())
				putDenseWorkspace(c)
				return
			}
			m.Copy(b)
			blas64.Trmm(blas.Left, aT, 1, aU.mat, m.mat)
			return

		case *VecDense:
			m.checkOverlap(aU.asGeneral())
			avec := aU.RawVector()
			if aTrans {
				// {1,ac} x {ac, bc}
				// Transpose B so that the vector is on the right.
				cvec := blas64.Vector{
					Inc:  1,
					Data: m.mat.Data,
				}
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Gemv(bT, 1, bU.mat, avec, 0, cvec)
				return
			}
			// {ar,1} x {1,bc} which is not a vector result.
			// Instead, construct A as a General.
			amat := blas64.General{
				Rows:   ar,
				Cols:   1,
				Stride: avec.Inc,
				Data:   avec.Data,
			}
			blas64.Gemm(aT, bT, 1, amat, bU.mat, 0, m.mat)
			return
		}
	}

	// Generic fallback: no BLAS routine applies, so form each element as a
	// dot product using At. A row of a is cached in a pooled slice so that
	// a.At is called once per element of a rather than once per product term.
	m.checkOverlapMatrix(aU)
	m.checkOverlapMatrix(bU)
	row := getFloat64s(ac, false)
	defer putFloat64s(row)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			m.mat.Data[r*m.mat.Stride+c] = v
		}
	}
}
   459  
   460  // strictCopy copies a into m panicking if the shape of a and m differ.
   461  func strictCopy(m *Dense, a Matrix) {
   462  	r, c := m.Copy(a)
   463  	if r != m.mat.Rows || c != m.mat.Cols {
   464  		// Panic with a string since this
   465  		// is not a user-facing panic.
   466  		panic(ErrShape.Error())
   467  	}
   468  }
   469  
// Exp calculates the exponential of the matrix a, e^a, placing the result
// in the receiver. Exp will panic with ErrShape if a is not square.
//
// A Padé approximant is chosen from the 1-norm of a. For norms above the
// largest tabulated threshold, a is scaled down by a power of two first and
// the result is squared repeatedly afterwards.
func (m *Dense) Exp(a Matrix) {
	// The implementation used here is from Functions of Matrices: Theory and Computation
	// Chapter 10, Algorithm 10.20. https://doi.org/10.1137/1.9780898717778.ch10

	r, c := a.Dims()
	if r != c {
		panic(ErrShape)
	}

	m.reuseAsNonZeroed(r, r)
	if r == 1 {
		// A 1×1 matrix reduces to the scalar exponential.
		m.mat.Data[0] = math.Exp(a.At(0, 0))
		return
	}

	// Low-order Padé table: each entry is used when the 1-norm of a is at
	// most theta, with b holding that approximant's coefficients.
	pade := []struct {
		theta float64
		b     []float64
	}{
		{theta: 0.015, b: []float64{
			120, 60, 12, 1,
		}},
		{theta: 0.25, b: []float64{
			30240, 15120, 3360, 420, 30, 1,
		}},
		{theta: 0.95, b: []float64{
			17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1,
		}},
		{theta: 2.1, b: []float64{
			17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1,
		}},
	}

	// The receiver doubles as storage for the working copy of a.
	a1 := m
	a1.Copy(a)
	// v and u accumulate the V and U polynomials. The *vec values view the
	// same workspace data as flat vectors of length r*r for use with Axpy.
	v := getDenseWorkspace(r, r, true)
	vraw := v.RawMatrix()
	n := r * r
	vvec := blas64.Vector{N: n, Inc: 1, Data: vraw.Data}
	defer putDenseWorkspace(v)

	u := getDenseWorkspace(r, r, true)
	uraw := u.RawMatrix()
	uvec := blas64.Vector{N: n, Inc: 1, Data: uraw.Data}
	defer putDenseWorkspace(u)

	a2 := getDenseWorkspace(r, r, false)
	defer putDenseWorkspace(a2)

	n1 := Norm(a, 1)
	for i, t := range pade {
		if n1 > t.theta {
			continue
		}

		// This loop only executes once, so
		// this is not as horrible as it looks.
		p := getDenseWorkspace(r, r, true)
		praw := p.RawMatrix()
		pvec := blas64.Vector{N: n, Inc: 1, Data: praw.Data}
		defer putDenseWorkspace(p)

		// Seed p with the identity and v, u with b_0*I and b_1*I.
		for k := 0; k < r; k++ {
			p.set(k, k, 1)
			v.set(k, k, t.b[0])
			u.set(k, k, t.b[1])
		}

		// p steps through the even powers of a1; even-indexed
		// coefficients accumulate into v and odd-indexed ones into u.
		a2.Mul(a1, a1)
		for j := 0; j <= i; j++ {
			p.Mul(p, a2)
			blas64.Axpy(t.b[2*j+2], pvec, vvec)
			blas64.Axpy(t.b[2*j+3], pvec, uvec)
		}
		u.Mul(a1, u)

		// Use p as a workspace here and
		// rename u for the second call's
		// receiver.
		vmu, vpu := u, p
		vpu.Add(v, u)
		vmu.Sub(v, u)

		// NOTE(review): the error returned by Solve is discarded here.
		_ = m.Solve(vmu, vpu)
		return
	}

	// Remaining Padé table line.
	const theta13 = 5.4
	b := [...]float64{
		64764752532480000, 32382376266240000, 7771770303897600, 1187353796428800,
		129060195264000, 10559470521600, 670442572800, 33522128640,
		1323241920, 40840800, 960960, 16380, 182, 1,
	}

	// Scale a1 by 2^-s when its norm is at least theta13; the squaring loop
	// at the end runs s times to undo the scaling.
	s := math.Log2(n1 / theta13)
	if s >= 0 {
		s = math.Ceil(s)
		a1.Scale(1/math.Pow(2, s), a1)
	}
	a2.Mul(a1, a1)

	// i holds the identity; it is reused as scratch space once V and U
	// have been formed.
	i := getDenseWorkspace(r, r, true)
	for j := 0; j < r; j++ {
		i.set(j, j, 1)
	}
	iraw := i.RawMatrix()
	ivec := blas64.Vector{N: n, Inc: 1, Data: iraw.Data}
	defer putDenseWorkspace(i)

	a2raw := a2.RawMatrix()
	a2vec := blas64.Vector{N: n, Inc: 1, Data: a2raw.Data}

	a4 := getDenseWorkspace(r, r, false)
	a4raw := a4.RawMatrix()
	a4vec := blas64.Vector{N: n, Inc: 1, Data: a4raw.Data}
	defer putDenseWorkspace(a4)
	a4.Mul(a2, a2)

	a6 := getDenseWorkspace(r, r, false)
	a6raw := a6.RawMatrix()
	a6vec := blas64.Vector{N: n, Inc: 1, Data: a6raw.Data}
	defer putDenseWorkspace(a6)
	a6.Mul(a2, a4)

	// V = A_6(b_12*A_6 + b_10*A_4 + b_8*A_2) + b_6*A_6 + b_4*A_4 + b_2*A_2 +b_0*I
	blas64.Axpy(b[12], a6vec, vvec)
	blas64.Axpy(b[10], a4vec, vvec)
	blas64.Axpy(b[8], a2vec, vvec)
	v.Mul(v, a6)
	blas64.Axpy(b[6], a6vec, vvec)
	blas64.Axpy(b[4], a4vec, vvec)
	blas64.Axpy(b[2], a2vec, vvec)
	blas64.Axpy(b[0], ivec, vvec)

	// U = A(A_6(b_13*A_6 + b_11*A_4 + b_9*A_2) + b_7*A_6 + b_5*A_4 + b_3*A_2 +b_1*I)
	blas64.Axpy(b[13], a6vec, uvec)
	blas64.Axpy(b[11], a4vec, uvec)
	blas64.Axpy(b[9], a2vec, uvec)
	u.Mul(u, a6)
	blas64.Axpy(b[7], a6vec, uvec)
	blas64.Axpy(b[5], a4vec, uvec)
	blas64.Axpy(b[3], a2vec, uvec)
	blas64.Axpy(b[1], ivec, uvec)
	u.Mul(u, a1)

	// Use i as a workspace here and
	// rename u for the second call's
	// receiver.
	vmu, vpu := u, i
	vpu.Add(v, u)
	vmu.Sub(v, u)

	// NOTE(review): the error returned by Solve is discarded here.
	_ = m.Solve(vmu, vpu)

	// Undo the scaling by squaring the result s times.
	for ; s > 0; s-- {
		m.Mul(m, m)
	}
}
   631  
   632  // Pow calculates the integral power of the matrix a to n, placing the result
   633  // in the receiver. Pow will panic if n is negative or if a is not square.
   634  func (m *Dense) Pow(a Matrix, n int) {
   635  	if n < 0 {
   636  		panic("mat: illegal power")
   637  	}
   638  	r, c := a.Dims()
   639  	if r != c {
   640  		panic(ErrShape)
   641  	}
   642  
   643  	m.reuseAsNonZeroed(r, c)
   644  
   645  	// Take possible fast paths.
   646  	switch n {
   647  	case 0:
   648  		for i := 0; i < r; i++ {
   649  			zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
   650  			m.mat.Data[i*m.mat.Stride+i] = 1
   651  		}
   652  		return
   653  	case 1:
   654  		m.Copy(a)
   655  		return
   656  	case 2:
   657  		m.Mul(a, a)
   658  		return
   659  	}
   660  
   661  	// Perform iterative exponentiation by squaring in work space.
   662  	w := getDenseWorkspace(r, r, false)
   663  	w.Copy(a)
   664  	s := getDenseWorkspace(r, r, false)
   665  	s.Copy(a)
   666  	x := getDenseWorkspace(r, r, false)
   667  	for n--; n > 0; n >>= 1 {
   668  		if n&1 != 0 {
   669  			x.Mul(w, s)
   670  			w, x = x, w
   671  		}
   672  		if n != 1 {
   673  			x.Mul(s, s)
   674  			s, x = x, s
   675  		}
   676  	}
   677  	m.Copy(w)
   678  	putDenseWorkspace(w)
   679  	putDenseWorkspace(s)
   680  	putDenseWorkspace(x)
   681  }
   682  
   683  // Kronecker calculates the Kronecker product of a and b, placing the result in
   684  // the receiver.
   685  func (m *Dense) Kronecker(a, b Matrix) {
   686  	ra, ca := a.Dims()
   687  	rb, cb := b.Dims()
   688  
   689  	m.reuseAsNonZeroed(ra*rb, ca*cb)
   690  	for i := 0; i < ra; i++ {
   691  		for j := 0; j < ca; j++ {
   692  			m.slice(i*rb, (i+1)*rb, j*cb, (j+1)*cb).Scale(a.At(i, j), b)
   693  		}
   694  	}
   695  }
   696  
   697  // Scale multiplies the elements of a by f, placing the result in the receiver.
   698  //
   699  // See the Scaler interface for more information.
   700  func (m *Dense) Scale(f float64, a Matrix) {
   701  	ar, ac := a.Dims()
   702  
   703  	m.reuseAsNonZeroed(ar, ac)
   704  
   705  	aU, aTrans := untransposeExtract(a)
   706  	if rm, ok := aU.(*Dense); ok {
   707  		amat := rm.mat
   708  		if m == aU || m.checkOverlap(amat) {
   709  			var restore func()
   710  			m, restore = m.isolatedWorkspace(a)
   711  			defer restore()
   712  		}
   713  		if !aTrans {
   714  			for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride {
   715  				for i, v := range amat.Data[ja : ja+ac] {
   716  					m.mat.Data[i+jm] = v * f
   717  				}
   718  			}
   719  		} else {
   720  			for ja, jm := 0, 0; ja < ac*amat.Stride; ja, jm = ja+amat.Stride, jm+1 {
   721  				for i, v := range amat.Data[ja : ja+ar] {
   722  					m.mat.Data[i*m.mat.Stride+jm] = v * f
   723  				}
   724  			}
   725  		}
   726  		return
   727  	}
   728  
   729  	m.checkOverlapMatrix(a)
   730  	for r := 0; r < ar; r++ {
   731  		for c := 0; c < ac; c++ {
   732  			m.set(r, c, f*a.At(r, c))
   733  		}
   734  	}
   735  }
   736  
   737  // Apply applies the function fn to each of the elements of a, placing the
   738  // resulting matrix in the receiver. The function fn takes a row/column
   739  // index and element value and returns some function of that tuple.
   740  func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
   741  	ar, ac := a.Dims()
   742  
   743  	m.reuseAsNonZeroed(ar, ac)
   744  
   745  	aU, aTrans := untransposeExtract(a)
   746  	if rm, ok := aU.(*Dense); ok {
   747  		amat := rm.mat
   748  		if m == aU || m.checkOverlap(amat) {
   749  			var restore func()
   750  			m, restore = m.isolatedWorkspace(a)
   751  			defer restore()
   752  		}
   753  		if !aTrans {
   754  			for j, ja, jm := 0, 0, 0; ja < ar*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride {
   755  				for i, v := range amat.Data[ja : ja+ac] {
   756  					m.mat.Data[i+jm] = fn(j, i, v)
   757  				}
   758  			}
   759  		} else {
   760  			for j, ja, jm := 0, 0, 0; ja < ac*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+1 {
   761  				for i, v := range amat.Data[ja : ja+ar] {
   762  					m.mat.Data[i*m.mat.Stride+jm] = fn(i, j, v)
   763  				}
   764  			}
   765  		}
   766  		return
   767  	}
   768  
   769  	m.checkOverlapMatrix(a)
   770  	for r := 0; r < ar; r++ {
   771  		for c := 0; c < ac; c++ {
   772  			m.set(r, c, fn(r, c, a.At(r, c)))
   773  		}
   774  	}
   775  }
   776  
   777  // RankOne performs a rank-one update to the matrix a with the vectors x and
   778  // y, where x and y are treated as column vectors. The result is stored in the
   779  // receiver. The Outer method can be used instead of RankOne if a is not needed.
   780  //
   781  //	m = a + alpha * x * yᵀ
   782  func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
   783  	ar, ac := a.Dims()
   784  	if x.Len() != ar {
   785  		panic(ErrShape)
   786  	}
   787  	if y.Len() != ac {
   788  		panic(ErrShape)
   789  	}
   790  
   791  	if a != m {
   792  		aU, _ := untransposeExtract(a)
   793  		if rm, ok := aU.(*Dense); ok {
   794  			m.checkOverlap(rm.RawMatrix())
   795  		}
   796  	}
   797  
   798  	var xmat, ymat blas64.Vector
   799  	fast := true
   800  	xU, _ := untransposeExtract(x)
   801  	if rv, ok := xU.(*VecDense); ok {
   802  		r, c := xU.Dims()
   803  		xmat = rv.mat
   804  		m.checkOverlap(generalFromVector(xmat, r, c))
   805  	} else {
   806  		fast = false
   807  	}
   808  	yU, _ := untransposeExtract(y)
   809  	if rv, ok := yU.(*VecDense); ok {
   810  		r, c := yU.Dims()
   811  		ymat = rv.mat
   812  		m.checkOverlap(generalFromVector(ymat, r, c))
   813  	} else {
   814  		fast = false
   815  	}
   816  
   817  	if fast {
   818  		if m != a {
   819  			m.reuseAsNonZeroed(ar, ac)
   820  			m.Copy(a)
   821  		}
   822  		blas64.Ger(alpha, xmat, ymat, m.mat)
   823  		return
   824  	}
   825  
   826  	m.reuseAsNonZeroed(ar, ac)
   827  	for i := 0; i < ar; i++ {
   828  		for j := 0; j < ac; j++ {
   829  			m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j))
   830  		}
   831  	}
   832  }
   833  
   834  // Outer calculates the outer product of the vectors x and y, where x and y
   835  // are treated as column vectors, and stores the result in the receiver.
   836  //
   837  //	m = alpha * x * yᵀ
   838  //
   839  // In order to update an existing matrix, see RankOne.
   840  func (m *Dense) Outer(alpha float64, x, y Vector) {
   841  	r, c := x.Len(), y.Len()
   842  
   843  	m.reuseAsZeroed(r, c)
   844  
   845  	var xmat, ymat blas64.Vector
   846  	fast := true
   847  	xU, _ := untransposeExtract(x)
   848  	if rv, ok := xU.(*VecDense); ok {
   849  		r, c := xU.Dims()
   850  		xmat = rv.mat
   851  		m.checkOverlap(generalFromVector(xmat, r, c))
   852  	} else {
   853  		fast = false
   854  	}
   855  	yU, _ := untransposeExtract(y)
   856  	if rv, ok := yU.(*VecDense); ok {
   857  		r, c := yU.Dims()
   858  		ymat = rv.mat
   859  		m.checkOverlap(generalFromVector(ymat, r, c))
   860  	} else {
   861  		fast = false
   862  	}
   863  
   864  	if fast {
   865  		for i := 0; i < r; i++ {
   866  			zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
   867  		}
   868  		blas64.Ger(alpha, xmat, ymat, m.mat)
   869  		return
   870  	}
   871  
   872  	for i := 0; i < r; i++ {
   873  		for j := 0; j < c; j++ {
   874  			m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j))
   875  		}
   876  	}
   877  }