github.com/wzzhu/tensor@v0.9.24/defaultenginefloat64.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"github.com/wzzhu/tensor/internal/execution"
     6  	"github.com/wzzhu/tensor/internal/storage"
     7  
     8  	"gorgonia.org/vecf64"
     9  )
    10  
    11  func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) {
    12  	fo := ParseFuncOpts(opts...)
    13  
    14  	reuseT, incr := fo.IncrReuse()
    15  	safe = fo.Safe()
    16  	toReuse = reuseT != nil
    17  
    18  	if toReuse {
    19  		var ok bool
    20  		if reuse, ok = reuseT.(DenseTensor); !ok {
    21  			returnOpOpt(fo)
    22  			err = errors.Errorf("Cannot reuse a different type of Tensor in a *Dense-Scalar operation. Reuse is of %T", reuseT)
    23  			return
    24  		}
    25  		if reuse.len() != expShape.TotalSize() && !expShape.IsScalar() {
    26  			returnOpOpt(fo)
    27  			err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape)
    28  			err = errors.Wrapf(err, "Cannot use reuse: shape mismatch")
    29  			return
    30  		}
    31  
    32  		if !incr && reuse != nil {
    33  			reuse.setDataOrder(o)
    34  			// err = reuse.reshape(expShape...)
    35  		}
    36  
    37  	}
    38  	returnOpOpt(fo)
    39  	return
    40  }
    41  
    42  func prepDataVSF64(a Tensor, b interface{}, reuse Tensor) (dataA *storage.Header, dataB float64, dataReuse *storage.Header, ait, iit Iterator, useIter bool, err error) {
    43  	// get data
    44  	dataA = a.hdr()
    45  	switch bt := b.(type) {
    46  	case float64:
    47  		dataB = bt
    48  	case *float64:
    49  		dataB = *bt
    50  	default:
    51  		err = errors.Errorf("b is not a float64: %T", b)
    52  		return
    53  	}
    54  	if reuse != nil {
    55  		dataReuse = reuse.hdr()
    56  	}
    57  
    58  	if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) {
    59  		ait = a.Iterator()
    60  		if reuse != nil {
    61  			iit = reuse.Iterator()
    62  		}
    63  		useIter = true
    64  	}
    65  	return
    66  }
    67  
    68  func (e Float64Engine) checkThree(a, b Tensor, reuse Tensor) error {
    69  	if !a.IsNativelyAccessible() {
    70  		return errors.Errorf(inaccessibleData, a)
    71  	}
    72  	if !b.IsNativelyAccessible() {
    73  		return errors.Errorf(inaccessibleData, b)
    74  	}
    75  
    76  	if reuse != nil && !reuse.IsNativelyAccessible() {
    77  		return errors.Errorf(inaccessibleData, reuse)
    78  	}
    79  
    80  	if a.Dtype() != Float64 {
    81  		return errors.Errorf("Expected a to be of Float64. Got %v instead", a.Dtype())
    82  	}
    83  	if a.Dtype() != b.Dtype() || (reuse != nil && b.Dtype() != reuse.Dtype()) {
    84  		return errors.Errorf("Expected a, b and reuse to have the same Dtype. Got %v, %v and %v instead", a.Dtype(), b.Dtype(), reuse.Dtype())
    85  	}
    86  	return nil
    87  }
    88  
    89  func (e Float64Engine) checkTwo(a Tensor, reuse Tensor) error {
    90  	if !a.IsNativelyAccessible() {
    91  		return errors.Errorf(inaccessibleData, a)
    92  	}
    93  	if reuse != nil && !reuse.IsNativelyAccessible() {
    94  		return errors.Errorf(inaccessibleData, reuse)
    95  	}
    96  
    97  	if a.Dtype() != Float64 {
    98  		return errors.Errorf("Expected a to be of Float64. Got %v instead", a.Dtype())
    99  	}
   100  
   101  	if reuse != nil && reuse.Dtype() != a.Dtype() {
   102  		return errors.Errorf("Expected reuse to be the same as a. Got %v instead", reuse.Dtype())
   103  	}
   104  	return nil
   105  }
   106  
    107  // Float64Engine is an execution engine that is optimized to only work with float64s. It assumes all data are float64s.
    108  //
    109  // Use this engine only as a form of optimization. You should probably be using the basic default engine for most cases.
    110  type Float64Engine struct {
    111  	StdEng
    112  }
   113  
   114  // makeArray allocates a slice for the array
   115  func (e Float64Engine) makeArray(arr *array, t Dtype, size int) {
   116  	if t != Float64 {
   117  		panic("Float64Engine only creates float64s")
   118  	}
   119  	arr.Header.Raw = make([]byte, size*8)
   120  	arr.t = t
   121  }
   122  
   123  func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) {
   124  	reuse := y
   125  	if err = e.checkThree(a, x, reuse); err != nil {
   126  		return nil, errors.Wrap(err, "Failed checks")
   127  	}
   128  
   129  	var dataA, dataB, dataReuse *storage.Header
   130  	var ait, bit, iit Iterator
   131  	var useIter bool
   132  	if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, x, reuse); err != nil {
   133  		return nil, errors.Wrap(err, "Float64Engine.FMA")
   134  	}
   135  	if useIter {
   136  		err = execution.MulIterIncrF64(dataA.Float64s(), dataB.Float64s(), dataReuse.Float64s(), ait, bit, iit)
   137  		retVal = reuse
   138  		return
   139  	}
   140  
   141  	vecf64.IncrMul(dataA.Float64s(), dataB.Float64s(), dataReuse.Float64s())
   142  	retVal = reuse
   143  	return
   144  }
   145  
   146  func (e Float64Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) {
   147  	reuse := y
   148  	if err = e.checkTwo(a, reuse); err != nil {
   149  		return nil, errors.Wrap(err, "Failed checks")
   150  	}
   151  
   152  	var ait, iit Iterator
   153  	var dataTensor, dataReuse *storage.Header
   154  	var scalar float64
   155  	var useIter bool
   156  	if dataTensor, scalar, dataReuse, ait, iit, useIter, err = prepDataVSF64(a, x, reuse); err != nil {
   157  		return nil, errors.Wrapf(err, opFail, "Float64Engine.FMAScalar")
   158  	}
   159  	if useIter {
   160  		err = execution.MulIterIncrVSF64(dataTensor.Float64s(), scalar, dataReuse.Float64s(), ait, iit)
   161  		retVal = reuse
   162  	}
   163  
   164  	execution.MulIncrVSF64(dataTensor.Float64s(), scalar, dataReuse.Float64s())
   165  	retVal = reuse
   166  	return
   167  }
   168  
   169  // Add performs a + b elementwise. Both a and b must have the same shape.
   170  // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T)
   171  func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
   172  	if a.RequiresIterator() || b.RequiresIterator() {
   173  		return e.StdEng.Add(a, b, opts...)
   174  	}
   175  
   176  	var reuse DenseTensor
   177  	var safe, toReuse, incr bool
   178  	if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil {
   179  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
   180  	}
   181  	if err = e.checkThree(a, b, reuse); err != nil {
   182  		return nil, errors.Wrap(err, "Failed checks")
   183  	}
   184  
   185  	var hdrA, hdrB, hdrReuse *storage.Header
   186  	var dataA, dataB, dataReuse []float64
   187  
   188  	if hdrA, hdrB, hdrReuse, _, _, _, _, _, err = prepDataVV(a, b, reuse); err != nil {
   189  		return nil, errors.Wrapf(err, "Float64Engine.Add")
   190  	}
   191  	dataA = hdrA.Float64s()
   192  	dataB = hdrB.Float64s()
   193  	if hdrReuse != nil {
   194  		dataReuse = hdrReuse.Float64s()
   195  	}
   196  
   197  	switch {
   198  	case incr:
   199  		vecf64.IncrAdd(dataA, dataB, dataReuse)
   200  		retVal = reuse
   201  	case toReuse:
   202  		copy(dataReuse, dataA)
   203  		vecf64.Add(dataReuse, dataB)
   204  		retVal = reuse
   205  	case !safe:
   206  		vecf64.Add(dataA, dataB)
   207  		retVal = a
   208  	default:
   209  		ret := a.Clone().(headerer)
   210  		vecf64.Add(ret.hdr().Float64s(), dataB)
   211  		retVal = ret.(Tensor)
   212  	}
   213  	return
   214  }
   215  
   216  func (e Float64Engine) Inner(a, b Tensor) (retVal float64, err error) {
   217  	var A, B []float64
   218  	var AD, BD *Dense
   219  	var ok bool
   220  
   221  	if AD, ok = a.(*Dense); !ok {
   222  		return 0, errors.Errorf("a is not a *Dense")
   223  	}
   224  	if BD, ok = b.(*Dense); !ok {
   225  		return 0, errors.Errorf("b is not a *Dense")
   226  	}
   227  
   228  	A = AD.Float64s()
   229  	B = BD.Float64s()
   230  	retVal = whichblas.Ddot(len(A), A, 1, B, 1)
   231  	return
   232  }