github.com/wzzhu/tensor@v0.9.24/defaultenginefloat32.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"github.com/wzzhu/tensor/internal/execution"
     6  	"github.com/wzzhu/tensor/internal/storage"
     7  
     8  	"gorgonia.org/vecf32"
     9  )
    10  
    11  func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) {
    12  	fo := ParseFuncOpts(opts...)
    13  
    14  	reuseT, incr := fo.IncrReuse()
    15  	safe = fo.Safe()
    16  	toReuse = reuseT != nil
    17  
    18  	if toReuse {
    19  		var ok bool
    20  		if reuse, ok = reuseT.(DenseTensor); !ok {
    21  			returnOpOpt(fo)
    22  			err = errors.Errorf("Cannot reuse a different type of Tensor in a *Dense-Scalar operation. Reuse is of %T", reuseT)
    23  			return
    24  		}
    25  		if reuse.len() != expShape.TotalSize() && !expShape.IsScalar() {
    26  			returnOpOpt(fo)
    27  			err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape)
    28  			err = errors.Wrapf(err, "Cannot use reuse: shape mismatch")
    29  			return
    30  		}
    31  
    32  		if !incr && reuse != nil {
    33  			reuse.setDataOrder(o)
    34  			// err = reuse.reshape(expShape...)
    35  		}
    36  
    37  	}
    38  	returnOpOpt(fo)
    39  	return
    40  }
    41  
    42  func prepDataVSF32(a Tensor, b interface{}, reuse Tensor) (dataA *storage.Header, dataB float32, dataReuse *storage.Header, ait, iit Iterator, useIter bool, err error) {
    43  	// get data
    44  	dataA = a.hdr()
    45  	switch bt := b.(type) {
    46  	case float32:
    47  		dataB = bt
    48  	case *float32:
    49  		dataB = *bt
    50  	default:
    51  		err = errors.Errorf("b is not a float32: %T", b)
    52  		return
    53  	}
    54  	if reuse != nil {
    55  		dataReuse = reuse.hdr()
    56  	}
    57  
    58  	if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) {
    59  		ait = a.Iterator()
    60  		if reuse != nil {
    61  			iit = reuse.Iterator()
    62  		}
    63  		useIter = true
    64  	}
    65  	return
    66  }
    67  
    68  func (e Float32Engine) checkThree(a, b Tensor, reuse Tensor) error {
    69  	if !a.IsNativelyAccessible() {
    70  		return errors.Errorf(inaccessibleData, a)
    71  	}
    72  	if !b.IsNativelyAccessible() {
    73  		return errors.Errorf(inaccessibleData, b)
    74  	}
    75  
    76  	if reuse != nil && !reuse.IsNativelyAccessible() {
    77  		return errors.Errorf(inaccessibleData, reuse)
    78  	}
    79  
    80  	if a.Dtype() != Float32 {
    81  		return errors.Errorf("Expected a to be of Float32. Got %v instead", a.Dtype())
    82  	}
    83  	if a.Dtype() != b.Dtype() || (reuse != nil && b.Dtype() != reuse.Dtype()) {
    84  		return errors.Errorf("Expected a, b and reuse to have the same Dtype. Got %v, %v and %v instead", a.Dtype(), b.Dtype(), reuse.Dtype())
    85  	}
    86  	return nil
    87  }
    88  
    89  func (e Float32Engine) checkTwo(a Tensor, reuse Tensor) error {
    90  	if !a.IsNativelyAccessible() {
    91  		return errors.Errorf(inaccessibleData, a)
    92  	}
    93  	if reuse != nil && !reuse.IsNativelyAccessible() {
    94  		return errors.Errorf(inaccessibleData, reuse)
    95  	}
    96  
    97  	if a.Dtype() != Float32 {
    98  		return errors.Errorf("Expected a to be of Float32. Got %v instead", a.Dtype())
    99  	}
   100  
   101  	if reuse != nil && reuse.Dtype() != a.Dtype() {
   102  		return errors.Errorf("Expected reuse to be the same as a. Got %v instead", reuse.Dtype())
   103  	}
   104  	return nil
   105  }
   106  
// Float32Engine is an execution engine that is optimized to only work with float32s. It assumes all data are float32s.
//
// Use this engine only as a form of optimization. You should probably be using the basic default engine for most cases.
type Float32Engine struct {
	// StdEng is embedded; Add in this file delegates to it when an operand requires an iterator.
	StdEng
}
   113  
   114  // makeArray allocates a slice for the array
   115  func (e Float32Engine) makeArray(arr *array, t Dtype, size int) {
   116  	if t != Float32 {
   117  		panic("Float32Engine only creates float32s")
   118  	}
   119  	if size < 0 {
   120  		panic("Cannot have negative sizes when making array")
   121  	}
   122  	arr.Header.Raw = make([]byte, size*4)
   123  	arr.t = t
   124  }
   125  
   126  func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) {
   127  	reuse := y
   128  	if err = e.checkThree(a, x, reuse); err != nil {
   129  		return nil, errors.Wrap(err, "Failed checks")
   130  	}
   131  
   132  	var dataA, dataB, dataReuse *storage.Header
   133  	var ait, bit, iit Iterator
   134  	var useIter bool
   135  	if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, x, reuse); err != nil {
   136  		return nil, errors.Wrap(err, "Float32Engine.FMA")
   137  	}
   138  	if useIter {
   139  		err = execution.MulIterIncrF32(dataA.Float32s(), dataB.Float32s(), dataReuse.Float32s(), ait, bit, iit)
   140  		retVal = reuse
   141  		return
   142  	}
   143  
   144  	vecf32.IncrMul(dataA.Float32s(), dataB.Float32s(), dataReuse.Float32s())
   145  	retVal = reuse
   146  	return
   147  }
   148  
   149  func (e Float32Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) {
   150  	reuse := y
   151  	if err = e.checkTwo(a, reuse); err != nil {
   152  		return nil, errors.Wrap(err, "Failed checks")
   153  	}
   154  
   155  	var ait, iit Iterator
   156  	var dataTensor, dataReuse *storage.Header
   157  	var scalar float32
   158  	var useIter bool
   159  	if dataTensor, scalar, dataReuse, ait, iit, useIter, err = prepDataVSF32(a, x, reuse); err != nil {
   160  		return nil, errors.Wrapf(err, opFail, "Float32Engine.FMAScalar")
   161  	}
   162  	if useIter {
   163  		err = execution.MulIterIncrVSF32(dataTensor.Float32s(), scalar, dataReuse.Float32s(), ait, iit)
   164  		retVal = reuse
   165  	}
   166  
   167  	execution.MulIncrVSF32(dataTensor.Float32s(), scalar, dataReuse.Float32s())
   168  	retVal = reuse
   169  	return
   170  }
   171  
   172  // Add performs a + b elementwise. Both a and b must have the same shape.
   173  // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T)
   174  func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
   175  	if a.RequiresIterator() || b.RequiresIterator() {
   176  		return e.StdEng.Add(a, b, opts...)
   177  	}
   178  
   179  	var reuse DenseTensor
   180  	var safe, toReuse, incr bool
   181  	if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil {
   182  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
   183  	}
   184  	if err = e.checkThree(a, b, reuse); err != nil {
   185  		return nil, errors.Wrap(err, "Failed checks")
   186  	}
   187  
   188  	var hdrA, hdrB, hdrReuse *storage.Header
   189  	var dataA, dataB, dataReuse []float32
   190  
   191  	if hdrA, hdrB, hdrReuse, _, _, _, _, _, err = prepDataVV(a, b, reuse); err != nil {
   192  		return nil, errors.Wrapf(err, "Float32Engine.Add")
   193  	}
   194  	dataA = hdrA.Float32s()
   195  	dataB = hdrB.Float32s()
   196  	if hdrReuse != nil {
   197  		dataReuse = hdrReuse.Float32s()
   198  	}
   199  
   200  	switch {
   201  	case incr:
   202  		vecf32.IncrAdd(dataA, dataB, dataReuse)
   203  		retVal = reuse
   204  	case toReuse:
   205  		copy(dataReuse, dataA)
   206  		vecf32.Add(dataReuse, dataB)
   207  		retVal = reuse
   208  	case !safe:
   209  		vecf32.Add(dataA, dataB)
   210  		retVal = a
   211  	default:
   212  		ret := a.Clone().(headerer)
   213  		vecf32.Add(ret.hdr().Float32s(), dataB)
   214  		retVal = ret.(Tensor)
   215  	}
   216  	return
   217  }
   218  
   219  func (e Float32Engine) Inner(a, b Tensor) (retVal float32, err error) {
   220  	var A, B []float32
   221  	var AD, BD *Dense
   222  	var ok bool
   223  
   224  	if AD, ok = a.(*Dense); !ok {
   225  		return 0, errors.Errorf("a is not a *Dense")
   226  	}
   227  	if BD, ok = b.(*Dense); !ok {
   228  		return 0, errors.Errorf("b is not a *Dense")
   229  	}
   230  
   231  	A = AD.Float32s()
   232  	B = BD.Float32s()
   233  	retVal = whichblas.Sdot(len(A), A, 1, B, 1)
   234  	return
   235  }