github.com/wzzhu/tensor@v0.9.24/defaultengine_linalg.go (about)

     1  package tensor
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"github.com/pkg/errors"
     7  	"gonum.org/v1/gonum/blas"
     8  	"gonum.org/v1/gonum/mat"
     9  )
    10  
    11  //  Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error
    12  func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) {
    13  	if t.Dims() != 2 {
    14  		err = errors.Errorf(dimMismatch, 2, t.Dims())
    15  		return
    16  	}
    17  
    18  	if err = typeclassCheck(t.Dtype(), numberTypes); err != nil {
    19  		return nil, errors.Wrap(err, "Trace")
    20  	}
    21  
    22  	rstride := t.Strides()[0]
    23  	cstride := t.Strides()[1]
    24  
    25  	r := t.Shape()[0]
    26  	c := t.Shape()[1]
    27  
    28  	m := MinInt(r, c)
    29  	stride := rstride + cstride
    30  
    31  	switch data := t.Data().(type) {
    32  	case []int:
    33  		var trace int
    34  		for i := 0; i < m; i++ {
    35  			trace += data[i*stride]
    36  		}
    37  		retVal = trace
    38  	case []int8:
    39  		var trace int8
    40  		for i := 0; i < m; i++ {
    41  			trace += data[i*stride]
    42  		}
    43  		retVal = trace
    44  	case []int16:
    45  		var trace int16
    46  		for i := 0; i < m; i++ {
    47  			trace += data[i*stride]
    48  		}
    49  		retVal = trace
    50  	case []int32:
    51  		var trace int32
    52  		for i := 0; i < m; i++ {
    53  			trace += data[i*stride]
    54  		}
    55  		retVal = trace
    56  	case []int64:
    57  		var trace int64
    58  		for i := 0; i < m; i++ {
    59  			trace += data[i*stride]
    60  		}
    61  		retVal = trace
    62  	case []uint:
    63  		var trace uint
    64  		for i := 0; i < m; i++ {
    65  			trace += data[i*stride]
    66  		}
    67  		retVal = trace
    68  	case []uint8:
    69  		var trace uint8
    70  		for i := 0; i < m; i++ {
    71  			trace += data[i*stride]
    72  		}
    73  		retVal = trace
    74  	case []uint16:
    75  		var trace uint16
    76  		for i := 0; i < m; i++ {
    77  			trace += data[i*stride]
    78  		}
    79  		retVal = trace
    80  	case []uint32:
    81  		var trace uint32
    82  		for i := 0; i < m; i++ {
    83  			trace += data[i*stride]
    84  		}
    85  		retVal = trace
    86  	case []uint64:
    87  		var trace uint64
    88  		for i := 0; i < m; i++ {
    89  			trace += data[i*stride]
    90  		}
    91  		retVal = trace
    92  	case []float32:
    93  		var trace float32
    94  		for i := 0; i < m; i++ {
    95  			trace += data[i*stride]
    96  		}
    97  		retVal = trace
    98  	case []float64:
    99  		var trace float64
   100  		for i := 0; i < m; i++ {
   101  			trace += data[i*stride]
   102  		}
   103  		retVal = trace
   104  	case []complex64:
   105  		var trace complex64
   106  		for i := 0; i < m; i++ {
   107  			trace += data[i*stride]
   108  		}
   109  		retVal = trace
   110  	case []complex128:
   111  		var trace complex128
   112  		for i := 0; i < m; i++ {
   113  			trace += data[i*stride]
   114  		}
   115  		retVal = trace
   116  	}
   117  	return
   118  }
   119  
   120  func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
   121  	if _, ok := x.(DenseTensor); !ok {
   122  		err = errors.Errorf("Engine only supports working on x that is a DenseTensor. Got %T instead", x)
   123  		return
   124  	}
   125  
   126  	if _, ok := y.(DenseTensor); !ok {
   127  		err = errors.Errorf("Engine only supports working on y that is a DenseTensor. Got %T instead", y)
   128  		return
   129  	}
   130  
   131  	var a, b DenseTensor
   132  	if a, err = getFloatDenseTensor(x); err != nil {
   133  		err = errors.Wrapf(err, opFail, "Dot")
   134  		return
   135  	}
   136  	if b, err = getFloatDenseTensor(y); err != nil {
   137  		err = errors.Wrapf(err, opFail, "Dot")
   138  		return
   139  	}
   140  
   141  	fo := ParseFuncOpts(opts...)
   142  
   143  	var reuse, incr DenseTensor
   144  	if reuse, err = getFloatDenseTensor(fo.reuse); err != nil {
   145  		err = errors.Wrapf(err, opFail, "Dot - reuse")
   146  		return
   147  
   148  	}
   149  
   150  	if incr, err = getFloatDenseTensor(fo.incr); err != nil {
   151  		err = errors.Wrapf(err, opFail, "Dot - incr")
   152  		return
   153  	}
   154  
   155  	switch {
   156  	case a.IsScalar() && b.IsScalar():
   157  		var res interface{}
   158  		switch a.Dtype().Kind() {
   159  		case reflect.Float64:
   160  			res = a.GetF64(0) * b.GetF64(0)
   161  		case reflect.Float32:
   162  			res = a.GetF32(0) * b.GetF32(0)
   163  		}
   164  
   165  		switch {
   166  		case incr != nil:
   167  			if !incr.IsScalar() {
   168  				err = errors.Errorf(shapeMismatch, ScalarShape(), incr.Shape())
   169  				return
   170  			}
   171  			if err = e.E.MulIncr(a.Dtype().Type, a.hdr(), b.hdr(), incr.hdr()); err != nil {
   172  				err = errors.Wrapf(err, opFail, "Dot scalar incr")
   173  				return
   174  
   175  			}
   176  			retVal = incr
   177  		case reuse != nil:
   178  			reuse.Set(0, res)
   179  			reuse.reshape()
   180  			retVal = reuse
   181  		default:
   182  			retVal = New(FromScalar(res))
   183  		}
   184  		return
   185  	case a.IsScalar():
   186  		switch {
   187  		case incr != nil:
   188  			return Mul(a.ScalarValue(), b, WithIncr(incr))
   189  		case reuse != nil:
   190  			return Mul(a.ScalarValue(), b, WithReuse(reuse))
   191  		}
   192  		// default moved out
   193  		return Mul(a.ScalarValue(), b)
   194  	case b.IsScalar():
   195  		switch {
   196  		case incr != nil:
   197  			return Mul(a, b.ScalarValue(), WithIncr(incr))
   198  		case reuse != nil:
   199  			return Mul(a, b.ScalarValue(), WithReuse(reuse))
   200  		}
   201  		return Mul(a, b.ScalarValue())
   202  	}
   203  
   204  	switch {
   205  	case a.IsVector():
   206  		switch {
   207  		case b.IsVector():
   208  			// check size
   209  			if a.len() != b.len() {
   210  				err = errors.Errorf(shapeMismatch, a.Shape(), b.Shape())
   211  				return
   212  			}
   213  			var ret interface{}
   214  			if ret, err = e.Inner(a, b); err != nil {
   215  				return nil, errors.Wrapf(err, opFail, "Dot")
   216  			}
   217  			return New(FromScalar(ret)), nil
   218  		case b.IsMatrix():
   219  			b.T()
   220  			defer b.UT()
   221  			switch {
   222  			case reuse != nil && incr != nil:
   223  				return b.MatVecMul(a, WithReuse(reuse), WithIncr(incr))
   224  			case reuse != nil:
   225  				return b.MatVecMul(a, WithReuse(reuse))
   226  			case incr != nil:
   227  				return b.MatVecMul(a, WithIncr(incr))
   228  			default:
   229  			}
   230  			return b.MatVecMul(a)
   231  		default:
   232  
   233  		}
   234  	case a.IsMatrix():
   235  		switch {
   236  		case b.IsVector():
   237  			switch {
   238  			case reuse != nil && incr != nil:
   239  				return a.MatVecMul(b, WithReuse(reuse), WithIncr(incr))
   240  			case reuse != nil:
   241  				return a.MatVecMul(b, WithReuse(reuse))
   242  			case incr != nil:
   243  				return a.MatVecMul(b, WithIncr(incr))
   244  			default:
   245  			}
   246  			return a.MatVecMul(b)
   247  
   248  		case b.IsMatrix():
   249  			switch {
   250  			case reuse != nil && incr != nil:
   251  				return a.MatMul(b, WithReuse(reuse), WithIncr(incr))
   252  			case reuse != nil:
   253  				return a.MatMul(b, WithReuse(reuse))
   254  			case incr != nil:
   255  				return a.MatMul(b, WithIncr(incr))
   256  			default:
   257  			}
   258  			return a.MatMul(b)
   259  		default:
   260  		}
   261  	default:
   262  	}
   263  
   264  	as := a.Shape()
   265  	bs := b.Shape()
   266  	axesA := BorrowInts(1)
   267  	axesB := BorrowInts(1)
   268  	defer ReturnInts(axesA)
   269  	defer ReturnInts(axesB)
   270  
   271  	var lastA, secondLastB int
   272  
   273  	lastA = len(as) - 1
   274  	axesA[0] = lastA
   275  	if len(bs) >= 2 {
   276  		secondLastB = len(bs) - 2
   277  	} else {
   278  		secondLastB = 0
   279  	}
   280  	axesB[0] = secondLastB
   281  
   282  	if as[lastA] != bs[secondLastB] {
   283  		err = errors.Errorf(shapeMismatch, as, bs)
   284  		return
   285  	}
   286  
   287  	var rd *Dense
   288  	if rd, err = a.TensorMul(b, axesA, axesB); err != nil {
   289  		panic(err)
   290  	}
   291  
   292  	if reuse != nil {
   293  		copyDense(reuse, rd)
   294  		ap := rd.Info().Clone()
   295  		reuse.setAP(&ap)
   296  		defer ReturnTensor(rd)
   297  		// swap out the underlying data and metadata
   298  		// reuse.data, rd.data = rd.data, reuse.data
   299  		// reuse.AP, rd.AP = rd.AP, reuse.AP
   300  		// defer ReturnTensor(rd)
   301  
   302  		retVal = reuse
   303  	} else {
   304  		retVal = rd
   305  	}
   306  
   307  	return
   308  }
   309  
   310  // TODO: make it take DenseTensor
   311  func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) {
   312  	var t *Dense
   313  	var ok bool
   314  	if err = e.checkAccessible(a); err != nil {
   315  		return nil, nil, nil, errors.Wrapf(err, "opFail %v", "SVD")
   316  	}
   317  	if t, ok = a.(*Dense); !ok {
   318  		return nil, nil, nil, errors.Errorf("StdEng only performs SVDs for DenseTensors. Got %T instead", a)
   319  	}
   320  	if err = typeclassCheck(a.Dtype(), floatTypes); err != nil {
   321  		return nil, nil, nil, errors.Errorf("StdEng can only perform SVDs for float64 and float32 type. Got tensor of %v instead", t.Dtype())
   322  	}
   323  
   324  	if !t.IsMatrix() {
   325  		return nil, nil, nil, errors.Errorf(dimMismatch, 2, t.Dims())
   326  	}
   327  
   328  	var m *mat.Dense
   329  	var svd mat.SVD
   330  
   331  	if m, err = ToMat64(t, UseUnsafe()); err != nil {
   332  		return
   333  	}
   334  
   335  	switch {
   336  	case full && uv:
   337  		ok = svd.Factorize(m, mat.SVDFull)
   338  	case !full && uv:
   339  		ok = svd.Factorize(m, mat.SVDThin)
   340  	case full && !uv:
   341  		// illogical state - if you specify "full", you WANT the UV matrices
   342  		// error
   343  		err = errors.Errorf("SVD requires computation of `u` and `v` matrices if `full` was specified.")
   344  		return
   345  	default:
   346  		// by default, we return only the singular values
   347  		ok = svd.Factorize(m, mat.SVDNone)
   348  	}
   349  
   350  	if !ok {
   351  		// error
   352  		err = errors.Errorf("Unable to compute SVD")
   353  		return
   354  	}
   355  
   356  	// extract values
   357  	var um, vm mat.Dense
   358  	s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])}, WithEngine(e))
   359  	svd.Values(s.Data().([]float64))
   360  	if uv {
   361  		svd.UTo(&um)
   362  		svd.VTo(&vm)
   363  		// vm.VFromSVD(&svd)
   364  
   365  		u = FromMat64(&um, UseUnsafe(), As(t.t))
   366  		v = FromMat64(&vm, UseUnsafe(), As(t.t))
   367  	}
   368  
   369  	return
   370  }
   371  
   372  // Inner is a thin layer over BLAS's D/Sdot.
   373  // It returns a scalar value, wrapped in an interface{}, which is not quite nice.
   374  func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) {
   375  	var ad, bd DenseTensor
   376  	if ad, bd, err = e.checkTwoFloatComplexTensors(a, b); err != nil {
   377  		return nil, errors.Wrapf(err, opFail, "StdEng.Inner")
   378  	}
   379  
   380  	switch A := ad.Data().(type) {
   381  	case []float32:
   382  		B := bd.Float32s()
   383  		retVal = whichblas.Sdot(len(A), A, 1, B, 1)
   384  	case []float64:
   385  		B := bd.Float64s()
   386  		retVal = whichblas.Ddot(len(A), A, 1, B, 1)
   387  	case []complex64:
   388  		B := bd.Complex64s()
   389  		retVal = whichblas.Cdotu(len(A), A, 1, B, 1)
   390  	case []complex128:
   391  		B := bd.Complex128s()
   392  		retVal = whichblas.Zdotu(len(A), A, 1, B, 1)
   393  	}
   394  	return
   395  }
   396  
   397  // MatVecMul is a thin layer over BLAS' DGEMV
   398  // Because DGEMV computes:
   399  // 		y = αA * x + βy
   400  // we set beta to 0, so we don't have to manually zero out the reused/retval tensor data
   401  func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) {
   402  	// check all are DenseTensors
   403  	var ad, bd, pd DenseTensor
   404  	if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil {
   405  		return errors.Wrapf(err, opFail, "StdEng.MatVecMul")
   406  	}
   407  
   408  	m := ad.oshape()[0]
   409  	n := ad.oshape()[1]
   410  
   411  	tA := blas.NoTrans
   412  	do := a.DataOrder()
   413  	z := ad.oldAP().IsZero()
   414  
   415  	var lda int
   416  	switch {
   417  	case do.IsRowMajor() && z:
   418  		lda = n
   419  	case do.IsRowMajor() && !z:
   420  		tA = blas.Trans
   421  		lda = n
   422  	case do.IsColMajor() && z:
   423  		tA = blas.Trans
   424  		lda = m
   425  		m, n = n, m
   426  	case do.IsColMajor() && !z:
   427  		lda = m
   428  		m, n = n, m
   429  	}
   430  
   431  	incX, incY := 1, 1 // step size
   432  
   433  	// ASPIRATIONAL TODO: different incX and incY
   434  	// TECHNICAL DEBT. TECHDEBT. TECH DEBT
   435  	// Example use case:
   436  	// log.Printf("a %v %v", ad.Strides(), ad.ostrides())
   437  	// log.Printf("b %v", b.Strides())
   438  	// incX := a.Strides()[0]
   439  	// incY = b.Strides()[0]
   440  
   441  	switch A := ad.Data().(type) {
   442  	case []float64:
   443  		x := bd.Float64s()
   444  		y := pd.Float64s()
   445  		alpha, beta := float64(1), float64(0)
   446  		whichblas.Dgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY)
   447  	case []float32:
   448  		x := bd.Float32s()
   449  		y := pd.Float32s()
   450  		alpha, beta := float32(1), float32(0)
   451  		whichblas.Sgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY)
   452  	case []complex64:
   453  		x := bd.Complex64s()
   454  		y := pd.Complex64s()
   455  		var alpha, beta complex64 = complex(1, 0), complex(0, 0)
   456  		whichblas.Cgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY)
   457  	case []complex128:
   458  		x := bd.Complex128s()
   459  		y := pd.Complex128s()
   460  		var alpha, beta complex128 = complex(1, 0), complex(0, 0)
   461  		whichblas.Zgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY)
   462  	default:
   463  		return errors.Errorf(typeNYI, "matVecMul", bd.Data())
   464  	}
   465  
   466  	return nil
   467  }
   468  
// MatMul is a thin layer over BLAS ?gemm (Sgemm, Dgemm, Cgemm, Zgemm). It
// writes a × b into prealloc, where a is (m, k) and b is (k, n).
// DGEMM computes:
//		C = αA * B +  βC
// To prevent needless zeroing out of the slice, we just set β to 0 (and α to
// 1), so prealloc's previous contents are overwritten rather than accumulated.
//
// a, b and prealloc must pass checkThreeFloatComplexTensors (accessible,
// matching Dtype). Shapes are not validated here: m and k are read from a,
// n from b. NOTE(review): prealloc is presumably expected to be (m, n) —
// confirm that callers validate this.
func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) {
	// check all are DenseTensors
	var ad, bd, pd DenseTensor
	if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil {
		return errors.Wrapf(err, opFail, "StdEng.MatMul")
	}

	ado := a.DataOrder()
	bdo := b.DataOrder()
	cdo := prealloc.DataOrder()

	// get result shapes. k is the shared dimension
	// a is (m, k)
	// b is (k, n)
	// c is (m, n)
	var m, n, k int
	m = ad.Shape()[0]
	k = ad.Shape()[1]
	n = bd.Shape()[1]

	// wrt the strides, we use the original strides, because that's what BLAS needs, instead of calling .Strides()
	// lda in colmajor = number of rows;
	// lda in row major = number of cols
	var lda, ldb, ldc int
	switch {
	case ado.IsColMajor():
		lda = m
	case ado.IsRowMajor():
		lda = k
	}

	switch {
	case bdo.IsColMajor():
		ldb = bd.Shape()[0]
	case bdo.IsRowMajor():
		ldb = n
	}

	switch {
	case cdo.IsColMajor():
		ldc = prealloc.Shape()[0]
	case cdo.IsRowMajor():
		ldc = prealloc.Shape()[1]
	}

	// check for trans
	// A non-zero oldAP is treated as a pending (lazy) transpose. BLAS is asked
	// to transpose that operand, and its leading dimension switches to the
	// other axis of the operand's current shape.
	tA, tB := blas.NoTrans, blas.NoTrans
	if !ad.oldAP().IsZero() {
		tA = blas.Trans
		if ado.IsRowMajor() {
			lda = m
		} else {
			lda = k
		}
	}
	if !bd.oldAP().IsZero() {
		tB = blas.Trans
		if bdo.IsRowMajor() {
			ldb = bd.Shape()[0]
		} else {
			ldb = bd.Shape()[1]
		}
	}

	// When both a and b are column-major, the operands are swapped and m and n
	// exchanged. In row-major terms this computes Cᵀ = Bᵀ·Aᵀ, which is the
	// column-major C. Every other order combination uses the direct call.
	// NOTE(review): tA/tB are NOT swapped along with the operands, so tA is
	// applied to B and tB to A in the swapped call. That looks correct only when
	// tA == tB. TODO confirm with a test where exactly one column-major
	// operand is a transposed view. Also confirm that mixed-order operands
	// (only one of a, b column-major) are intended to take the direct call.
	switch A := ad.Data().(type) {
	case []float64:
		B := bd.Float64s()
		C := pd.Float64s()
		alpha, beta := float64(1), float64(0)
		if ado.IsColMajor() && bdo.IsColMajor() {
			whichblas.Dgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc)
		} else {
			whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
		}
	case []float32:
		B := bd.Float32s()
		C := pd.Float32s()
		alpha, beta := float32(1), float32(0)
		if ado.IsColMajor() && bdo.IsColMajor() {
			whichblas.Sgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc)
		} else {
			whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
		}
	case []complex64:
		B := bd.Complex64s()
		C := pd.Complex64s()
		var alpha, beta complex64 = complex(1, 0), complex(0, 0)
		if ado.IsColMajor() && bdo.IsColMajor() {
			whichblas.Cgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc)
		} else {
			whichblas.Cgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
		}
	case []complex128:
		B := bd.Complex128s()
		C := pd.Complex128s()
		var alpha, beta complex128 = complex(1, 0), complex(0, 0)
		if ado.IsColMajor() && bdo.IsColMajor() {
			whichblas.Zgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc)
		} else {
			whichblas.Zgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
		}
	default:
		return errors.Errorf(typeNYI, "matMul", ad.Data())
	}
	return
}
   579  
   580  // Outer is a thin wrapper over S/Dger
   581  func (e StdEng) Outer(a, b, prealloc Tensor) (err error) {
   582  	// check all are DenseTensors
   583  	var ad, bd, pd DenseTensor
   584  	if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil {
   585  		return errors.Wrapf(err, opFail, "StdEng.Outer")
   586  	}
   587  
   588  	m := ad.Size()
   589  	n := bd.Size()
   590  	pdo := pd.DataOrder()
   591  
   592  	// the stride of a Vector is always going to be [1],
   593  	// incX := t.Strides()[0]
   594  	// incY := other.Strides()[0]
   595  	incX, incY := 1, 1
   596  	// lda := pd.Strides()[0]
   597  	var lda int
   598  	switch {
   599  	case pdo.IsColMajor():
   600  		aShape := a.Shape().Clone()
   601  		bShape := b.Shape().Clone()
   602  		if err = a.Reshape(aShape[0], 1); err != nil {
   603  			return err
   604  		}
   605  		if err = b.Reshape(1, bShape[0]); err != nil {
   606  			return err
   607  		}
   608  
   609  		if err = e.MatMul(a, b, prealloc); err != nil {
   610  			return err
   611  		}
   612  
   613  		if err = b.Reshape(bShape...); err != nil {
   614  			return
   615  		}
   616  		if err = a.Reshape(aShape...); err != nil {
   617  			return
   618  		}
   619  		return nil
   620  
   621  	case pdo.IsRowMajor():
   622  		lda = pd.Shape()[1]
   623  	}
   624  
   625  	switch x := ad.Data().(type) {
   626  	case []float64:
   627  		y := bd.Float64s()
   628  		A := pd.Float64s()
   629  		alpha := float64(1)
   630  		whichblas.Dger(m, n, alpha, x, incX, y, incY, A, lda)
   631  	case []float32:
   632  		y := bd.Float32s()
   633  		A := pd.Float32s()
   634  		alpha := float32(1)
   635  		whichblas.Sger(m, n, alpha, x, incX, y, incY, A, lda)
   636  	case []complex64:
   637  		y := bd.Complex64s()
   638  		A := pd.Complex64s()
   639  		var alpha complex64 = complex(1, 0)
   640  		whichblas.Cgeru(m, n, alpha, x, incX, y, incY, A, lda)
   641  	case []complex128:
   642  		y := bd.Complex128s()
   643  		A := pd.Complex128s()
   644  		var alpha complex128 = complex(1, 0)
   645  		whichblas.Zgeru(m, n, alpha, x, incX, y, incY, A, lda)
   646  	default:
   647  		return errors.Errorf(typeNYI, "outer", b.Data())
   648  	}
   649  	return nil
   650  }
   651  
   652  /* UNEXPORTED UTILITY FUNCTIONS */
   653  
   654  func (e StdEng) checkTwoFloatTensors(a, b Tensor) (ad, bd DenseTensor, err error) {
   655  	if err = e.checkAccessible(a); err != nil {
   656  		return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible")
   657  	}
   658  	if err = e.checkAccessible(b); err != nil {
   659  		return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible")
   660  	}
   661  
   662  	if a.Dtype() != b.Dtype() {
   663  		return nil, nil, errors.New("Expected a and b to have the same Dtype")
   664  	}
   665  
   666  	if ad, err = getFloatDenseTensor(a); err != nil {
   667  		return nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor")
   668  	}
   669  	if bd, err = getFloatDenseTensor(b); err != nil {
   670  		return nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor")
   671  	}
   672  	return
   673  }
   674  
   675  func (e StdEng) checkThreeFloatTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) {
   676  	if err = e.checkAccessible(a); err != nil {
   677  		return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible")
   678  	}
   679  	if err = e.checkAccessible(b); err != nil {
   680  		return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible")
   681  	}
   682  	if err = e.checkAccessible(ret); err != nil {
   683  		return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible")
   684  	}
   685  
   686  	if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() {
   687  		return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype")
   688  	}
   689  
   690  	if ad, err = getFloatDenseTensor(a); err != nil {
   691  		return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor")
   692  	}
   693  	if bd, err = getFloatDenseTensor(b); err != nil {
   694  		return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor")
   695  	}
   696  	if retVal, err = getFloatDenseTensor(ret); err != nil {
   697  		return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects retVal to be be a DenseTensor")
   698  	}
   699  	return
   700  }
   701  
   702  func (e StdEng) checkTwoFloatComplexTensors(a, b Tensor) (ad, bd DenseTensor, err error) {
   703  	if err = e.checkAccessible(a); err != nil {
   704  		return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible")
   705  	}
   706  	if err = e.checkAccessible(b); err != nil {
   707  		return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible")
   708  	}
   709  
   710  	if a.Dtype() != b.Dtype() {
   711  		return nil, nil, errors.New("Expected a and b to have the same Dtype")
   712  	}
   713  
   714  	if ad, err = getFloatComplexDenseTensor(a); err != nil {
   715  		return nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor")
   716  	}
   717  	if bd, err = getFloatComplexDenseTensor(b); err != nil {
   718  		return nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor")
   719  	}
   720  	return
   721  }
   722  
   723  func (e StdEng) checkThreeFloatComplexTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) {
   724  	if err = e.checkAccessible(a); err != nil {
   725  		return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible")
   726  	}
   727  	if err = e.checkAccessible(b); err != nil {
   728  		return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible")
   729  	}
   730  	if err = e.checkAccessible(ret); err != nil {
   731  		return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible")
   732  	}
   733  
   734  	if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() {
   735  		return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype")
   736  	}
   737  
   738  	if ad, err = getFloatComplexDenseTensor(a); err != nil {
   739  		return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor")
   740  	}
   741  	if bd, err = getFloatComplexDenseTensor(b); err != nil {
   742  		return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor")
   743  	}
   744  	if retVal, err = getFloatComplexDenseTensor(ret); err != nil {
   745  		return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects retVal to be be a DenseTensor")
   746  	}
   747  	return
   748  }