github.com/wzzhu/tensor@v0.9.24/dense_linalg.go

package tensor

import (
	"github.com/pkg/errors"
)

// Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). It only works for matrices
func (t *Dense) Trace() (retVal interface{}, err error) {
	e := t.e

	if tracer, ok := e.(Tracer); ok {
		return tracer.Trace(t)
	}
	return nil, errors.Errorf("Engine %T does not support Trace", e)
}
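
// Illustrative usage sketch (not part of the library): how Trace is typically
// called. It assumes the package's New, WithShape, WithBacking and Range
// constructors (defined elsewhere in this package) and an engine that
// implements Tracer, such as the default one; the function name is hypothetical.
func exampleTraceSketch() (interface{}, error) {
	// 3x3 matrix [[0 1 2] [3 4 5] [6 7 8]]; its trace is 0 + 4 + 8 = 12.
	m := New(WithShape(3, 3), WithBacking(Range(Float64, 0, 9)))
	return m.Trace() // an interface{} holding float64(12) on the default engine
}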

// Inner performs a dot product on two vectors. If t or other is not a vector, it will return an error.
func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) {
	// check that the data is a float or complex type
	if err = typeclassCheck(t.t, floatcmplxTypes); err != nil {
		return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner")
	}

	// check both are vectors
	if !t.Shape().IsVector() || !other.Shape().IsVector() {
		return nil, errors.Errorf("Inner only works when there are two vectors. t's Shape: %v; other's Shape %v", t.Shape(), other.Shape())
	}

	// we do this check instead of the more common t.Shape()[1] != other.Shape()[0]
	// to stay consistent with how NumPy's dot treats vectors.
	if t.len() != other.DataSize() {
		return nil, errors.Errorf(shapeMismatch, t.Shape(), other.Shape())
	}

	e := t.e
	switch ip := e.(type) {
	case InnerProderF32:
		return ip.Inner(t, other)
	case InnerProderF64:
		return ip.Inner(t, other)
	case InnerProder:
		return ip.Inner(t, other)
	}

	return nil, errors.Errorf("Engine does not support Inner()")
}
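
// Illustrative usage sketch (not part of the library): Inner on two float64
// vectors. New, WithShape and WithBacking are assumed from elsewhere in this
// package; the function name is hypothetical.
func exampleInnerSketch() (interface{}, error) {
	a := New(WithShape(3), WithBacking([]float64{1, 2, 3}))
	b := New(WithShape(3), WithBacking([]float64{4, 5, 6}))
	// 1*4 + 2*5 + 3*6 = 32, returned as an interface{} holding float64(32)
	return a.Inner(b)
}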

// MatVecMul performs a matrix-vector multiplication.
func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) {
	// check that it's a matrix x vector
	if t.Dims() != 2 || !other.Shape().IsVector() {
		err = errors.Errorf("MatVecMul requires t to be a matrix and other to be a vector. Got t's shape: %v, other's shape: %v", t.Shape(), other.Shape())
		return
	}

	// t is an m x n matrix
	m := t.Shape()[0]
	n := t.Shape()[1]

	// check shape
	var odim int
	oshape := other.Shape()
	switch {
	case oshape.IsColVec():
		odim = oshape[0]
	case oshape.IsRowVec():
		odim = oshape[1]
	case oshape.IsVector():
		odim = oshape[0]
	default:
		err = errors.Errorf(shapeMismatch, t.Shape(), other.Shape()) // should be unreachable
		return
	}

	if odim != n {
		err = errors.Errorf(shapeMismatch, n, other.Shape())
		return
	}

	expectedShape := Shape{m}

	// check whether retVal has the same size as the resulting vector would be: m
	fo := ParseFuncOpts(opts...)
	defer returnOpOpt(fo)
	if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
		err = errors.Wrapf(err, opFail, "MatVecMul")
		return
	}

	if retVal == nil {
		retVal = recycledDense(t.t, expectedShape, WithEngine(t.e))
		if t.o.IsColMajor() {
			AsFortran(nil)(retVal)
		}
	}

	e := t.e

	if mvm, ok := e.(MatVecMuler); ok {
		if err = mvm.MatVecMul(t, other, retVal); err != nil {
			return nil, errors.Wrapf(err, opFail, "MatVecMul")
		}
		return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape)
	}
	return nil, errors.New("engine does not support MatVecMul")
}
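
// Illustrative usage sketch (not part of the library): a (2, 3) matrix times a
// length-3 vector, giving a length-2 vector. New, WithShape, WithBacking,
// Float64 and Range are assumed from elsewhere in this package; the function
// name is hypothetical.
func exampleMatVecMulSketch() (*Dense, error) {
	a := New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) // [[0 1 2] [3 4 5]]
	x := New(WithShape(3), WithBacking([]float64{1, 1, 1}))
	// expected result: [0+1+2, 3+4+5] = [3, 12]
	return a.MatVecMul(x)
}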

// MatMul is the basic matrix multiplication that you learned in high school. It takes an optional reuse ndarray, which is reused as the result.
// If that isn't passed in, a new ndarray will be created instead.
func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) {
	// check that both are matrices
	if !t.Shape().IsMatrix() || !other.Shape().IsMatrix() {
		err = errors.Errorf("MatMul requires both operands to be matrices. Got t's shape: %v, other's shape: %v", t.Shape(), other.Shape())
		return
	}

	// t is an m x k matrix; other is a k x n matrix
	var m, n, k int
	m = t.Shape()[0]
	k = t.Shape()[1]
	n = other.Shape()[1]

	// check shape
	if k != other.Shape()[0] {
		err = errors.Errorf(shapeMismatch, t.Shape(), other.Shape())
		return
	}

	// check whether retVal has the same size as the resulting matrix would be: mxn
	expectedShape := Shape{m, n}

	fo := ParseFuncOpts(opts...)
	defer returnOpOpt(fo)
	if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
		err = errors.Wrapf(err, opFail, "MatMul")
		return
	}

	if retVal == nil {
		retVal = recycledDense(t.t, expectedShape, WithEngine(t.e))
		if t.o.IsColMajor() {
			AsFortran(nil)(retVal)
		}
	}

	e := t.e
	if mm, ok := e.(MatMuler); ok {
		if err = mm.MatMul(t, other, retVal); err != nil {
			return
		}
		return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape)
	}

	return nil, errors.New("engine does not support MatMul")
}
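
// Illustrative usage sketch (not part of the library): a (2, 3) x (3, 2)
// multiply, writing the result into a preallocated tensor via the WithReuse
// FuncOpt. New, WithShape, WithBacking, Of, Float64, Range and WithReuse are
// assumed from elsewhere in this package; the function name is hypothetical.
func exampleMatMulSketch() (*Dense, error) {
	a := New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) // [[0 1 2] [3 4 5]]
	b := New(WithShape(3, 2), WithBacking(Range(Float64, 0, 6))) // [[0 1] [2 3] [4 5]]
	reuse := New(WithShape(2, 2), Of(Float64))
	// expected result: [[10 13] [28 40]], stored in reuse
	return a.MatMul(b, WithReuse(reuse))
}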

// Outer finds the outer product of two vectors
func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) {
	// check both are vectors
	if !t.Shape().IsVector() || !other.Shape().IsVector() {
		err = errors.Errorf("Outer only works when there are two vectors. t's shape: %v. other's shape: %v", t.Shape(), other.Shape())
		return
	}

	m := t.Size()
	n := other.Size()

	// check whether retVal has the same size as the resulting matrix would be: mxn
	expectedShape := Shape{m, n}

	fo := ParseFuncOpts(opts...)
	defer returnOpOpt(fo)
	if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
		err = errors.Wrapf(err, opFail, "Outer")
		return
	}

	if retVal == nil {
		retVal = recycledDense(t.t, expectedShape, WithEngine(t.e))
		if t.o.IsColMajor() {
			AsFortran(nil)(retVal)
		}
	}

	e := t.e

	// DGER does not have any beta, so the values have to be zeroed first if the tensor is to be reused
	retVal.Zero()
	if op, ok := e.(OuterProder); ok {
		if err = op.Outer(t, other, retVal); err != nil {
			return nil, errors.Wrapf(err, opFail, "engine.Outer")
		}
		return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape)
	}
	return nil, errors.New("engine does not support Outer")
}
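
// Illustrative usage sketch (not part of the library): the outer product of a
// length-2 and a length-3 vector is a (2, 3) matrix. New, WithShape and
// WithBacking are assumed from elsewhere in this package; the function name is
// hypothetical.
func exampleOuterSketch() (*Dense, error) {
	a := New(WithShape(2), WithBacking([]float64{1, 2}))
	b := New(WithShape(3), WithBacking([]float64{3, 4, 5}))
	// expected result: [[3 4 5] [6 8 10]]
	return a.Outer(b)
}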

// TensorMul is for multiplying Tensors with more than 2 dimensions.
//
// The algorithm is conceptually simple (but tricky to get right):
//	1. Transpose and reshape the Tensors in such a way that both t and other are 2D matrices
//	2. Use DGEMM to multiply them
//	3. Reshape the results to be the new expected result
//
// This function is a Go implementation of NumPy's tensordot method. It simplifies a lot of what NumPy does.
func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err error) {
	ts := t.Shape()
	td := len(ts)

	os := other.Shape()
	od := len(os)

	na := len(axesA)
	nb := len(axesB)
	sameLength := na == nb
	if sameLength {
		for i := 0; i < na; i++ {
			// normalize negative axes before they are used to index the shapes
			if axesA[i] < 0 {
				axesA[i] += td
			}
			if axesB[i] < 0 {
				axesB[i] += od
			}

			if ts[axesA[i]] != os[axesB[i]] {
				sameLength = false
				break
			}
		}
	}

	if !sameLength {
		err = errors.Errorf(shapeMismatch, ts, os)
		return
	}

	// handle shapes
	var notins []int
	for i := 0; i < td; i++ {
		notin := true
		for _, a := range axesA {
			if i == a {
				notin = false
				break
			}
		}
		if notin {
			notins = append(notins, i)
		}
	}

	newAxesA := BorrowInts(len(notins) + len(axesA))
	defer ReturnInts(newAxesA)
	newAxesA = newAxesA[:0]
	newAxesA = append(notins, axesA...)
	n2 := 1
	for _, a := range axesA {
		n2 *= ts[a]
	}

	newShapeT := Shape(BorrowInts(2))
	defer ReturnInts(newShapeT)
	newShapeT[0] = ts.TotalSize() / n2
	newShapeT[1] = n2

	retShape1 := BorrowInts(len(ts))
	defer ReturnInts(retShape1)
	retShape1 = retShape1[:0]
	for _, ni := range notins {
		retShape1 = append(retShape1, ts[ni])
	}

	// work on other now
	notins = notins[:0]
	for i := 0; i < od; i++ {
		notin := true
		for _, a := range axesB {
			if i == a {
				notin = false
				break
			}
		}
		if notin {
			notins = append(notins, i)
		}
	}

	newAxesB := BorrowInts(len(notins) + len(axesB))
	defer ReturnInts(newAxesB)
	newAxesB = newAxesB[:0]
	newAxesB = append(axesB, notins...)

	newShapeO := Shape(BorrowInts(2))
	defer ReturnInts(newShapeO)
	newShapeO[0] = n2
	newShapeO[1] = os.TotalSize() / n2

	retShape2 := BorrowInts(len(ts))
	retShape2 = retShape2[:0]
	for _, ni := range notins {
		retShape2 = append(retShape2, os[ni])
	}

	// we clone because we don't want to touch the original Tensors
	doT := t.Clone().(*Dense)
	doOther := other.Clone().(*Dense)
	defer ReturnTensor(doT)
	defer ReturnTensor(doOther)

	if err = doT.T(newAxesA...); err != nil {
		return
	}
	doT.Transpose() // we have to materialize the transpose first or the underlying data won't be changed and the reshape that follows would be meaningless

	if err = doT.Reshape(newShapeT...); err != nil {
		return
	}

	if err = doOther.T(newAxesB...); err != nil {
		return
	}
	doOther.Transpose()
	if err = doOther.Reshape(newShapeO...); err != nil {
		return
	}

	// the magic happens here
	var rt Tensor
	if rt, err = Dot(doT, doOther); err != nil {
		return
	}
	retVal = rt.(*Dense)

	retShape := BorrowInts(len(retShape1) + len(retShape2))
	defer ReturnInts(retShape)

	retShape = retShape[:0]
	retShape = append(retShape, retShape1...)
	retShape = append(retShape, retShape2...)

	if len(retShape) == 0 { // in case a scalar is returned, it should be returned as shape = {1}
		retShape = append(retShape, 1)
	}

	if err = retVal.Reshape(retShape...); err != nil {
		return
	}

	return
}
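
// Illustrative usage sketch (not part of the library): the classic tensordot
// shape example. Contracting a (3, 4, 5) tensor with a (4, 3, 2) tensor over
// axes [1, 0] and [0, 1] (the paired sizes 4 and 3 match) yields a (5, 2)
// result. New, WithShape, WithBacking, Float64 and Range are assumed from
// elsewhere in this package; the function name is hypothetical.
func exampleTensorMulSketch() (*Dense, error) {
	a := New(WithShape(3, 4, 5), WithBacking(Range(Float64, 0, 60)))
	b := New(WithShape(4, 3, 2), WithBacking(Range(Float64, 0, 24)))
	// axesA[i] of a is contracted against axesB[i] of b
	return a.TensorMul(b, []int{1, 0}, []int{0, 1})
}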

// SVD does the Singular Value Decomposition for the *Dense.
//
// It works by temporarily converting the *Dense into a gonum/mat64 matrix and using Gonum's SVD function to perform the decomposition.
// In the future, when gonum/lapack fully supports float32, we'll look into rewriting this.
func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) {
	e := t.Engine()

	if svder, ok := e.(SVDer); ok {
		var sT, uT, vT Tensor
		if sT, uT, vT, err = svder.SVD(t, uv, full); err != nil {
			return nil, nil, nil, errors.Wrap(err, "Error while performing *Dense.SVD")
		}
		if s, err = assertDense(sT); err != nil {
			return nil, nil, nil, errors.Wrapf(err, "sT is not *Dense (uv %t full %t). Got %T instead", uv, full, sT)
		}
		// if not uv and not full, u can be nil
		if u, err = assertDense(uT); err != nil && !(!uv && !full) {
			return nil, nil, nil, errors.Wrapf(err, "uT is not *Dense (uv %t full %t). Got %T instead", uv, full, uT)
		}
		// if not uv and not full, v can be nil
		if v, err = assertDense(vT); err != nil && !(!uv && !full) {
			return nil, nil, nil, errors.Wrapf(err, "vT is not *Dense (uv %t full %t). Got %T instead", uv, full, vT)
		}
		return s, u, v, nil
	}
	return nil, nil, nil, errors.New("Engine does not support SVD")
}
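
// Illustrative usage sketch (not part of the library): a thin SVD of a 4x5
// float64 matrix, requesting the singular vectors (uv=true) but not the
// full-sized U and V (full=false). It assumes an engine that implements SVDer
// for float64, and that New, WithShape, WithBacking, Float64 and Range come
// from elsewhere in this package; the function name is hypothetical.
func exampleSVDSketch() (s, u, v *Dense, err error) {
	a := New(WithShape(4, 5), WithBacking(Range(Float64, 0, 20)))
	// s holds the singular values; u and v hold the left and right singular vectors
	return a.SVD(true, false)
}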

/* UTILITY FUNCTIONS */

// handleReuse extracts a *Dense from the reuse Tensor and, when safe, checks that its shape matches the expected shape.
func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, err error) {
	if reuse != nil {
		if retVal, err = assertDense(reuse); err != nil {
			err = errors.Wrapf(err, opFail, "handling reuse")
			return
		}
		if !safe {
			return
		}
		if err = reuseCheckShape(retVal, expectedShape); err != nil {
			err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. Shape error.")
			return
		}
		return
	}
	return
}

// handleIncr is the cleanup step for when there is a Tensor to increment. If the result tensor is the same as the reuse Tensor, the result tensor gets returned to the pool.
func handleIncr(res *Dense, reuse, incr Tensor, expectedShape Shape) (retVal *Dense, err error) {
	// handle increments
	if incr != nil {
		if !expectedShape.Eq(incr.Shape()) {
			err = errors.Errorf(shapeMismatch, expectedShape, incr.Shape())
			return
		}
		var incrD *Dense
		var ok bool
		if incrD, ok = incr.(*Dense); !ok {
			err = errors.Errorf(extractionFail, "*Dense", incr)
			return
		}

		if err = typeclassCheck(incrD.t, numberTypes); err != nil {
			err = errors.Wrapf(err, "handleIncr only handles Number types. Got %v instead", incrD.t)
			return
		}

		if _, err = incrD.Add(res, UseUnsafe()); err != nil {
			return
		}
		// vecAdd(incr.data, retVal.data)

		// return res to the pool - if and only if res is not the reuse Tensor.
		// reuse indicates that someone else also has a reference to the *Dense.
		if res != reuse {
			ReturnTensor(res)
		}

		// then
		retVal = incrD
		return
	}

	return res, nil
}
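
// Illustrative usage sketch (not part of the library) of the incr path that
// handleIncr implements for the operations above: passing WithIncr(c)
// accumulates the product into c (c += a * b) instead of allocating a fresh
// result. New, WithShape, WithBacking, Float64, Range and WithIncr are assumed
// from elsewhere in this package; the function name is hypothetical.
func exampleIncrSketch() (*Dense, error) {
	a := New(WithShape(2, 2), WithBacking([]float64{1, 0, 0, 1})) // identity
	b := New(WithShape(2, 2), WithBacking(Range(Float64, 0, 4)))  // [[0 1] [2 3]]
	c := New(WithShape(2, 2), WithBacking([]float64{10, 10, 10, 10}))
	// expected: c becomes [[10 11] [12 13]] and is also the returned tensor
	return a.MatMul(b, WithIncr(c))
}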