github.com/wzzhu/tensor@v0.9.24/defaultengine_misc.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"github.com/wzzhu/tensor/internal/storage"
     6  )
     7  
     8  func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
     9  	if err = unaryCheck(a, nonComplexNumberTypes); err != nil {
    10  		return nil, errors.Wrap(err, "Clamp failed")
    11  	}
    12  
    13  	var reuse DenseTensor
    14  	var safe, toReuse, incr bool
    15  	if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil {
    16  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
    17  	}
    18  
    19  	typ := a.Dtype().Type
    20  	var ait, rit Iterator
    21  	var dataA, dataReuse *storage.Header
    22  	var useIter bool
    23  
    24  	if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil {
    25  		return nil, errors.Wrapf(err, opFail, "StdEng.Neg")
    26  	}
    27  
    28  	if useIter {
    29  		switch {
    30  		case incr:
    31  			cloned := a.Clone().(Tensor)
    32  			if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != nil {
    33  				return nil, errors.Wrapf(err, "Unable to perform Clamp")
    34  			}
    35  			ait.Reset()
    36  			err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait)
    37  			retVal = reuse
    38  		case toReuse:
    39  			storage.CopyIter(typ, dataReuse, dataA, rit, ait)
    40  			rit.Reset()
    41  			err = e.E.ClampIter(typ, dataReuse, rit, min, max)
    42  			retVal = reuse
    43  		case !safe:
    44  			err = e.E.ClampIter(typ, dataA, ait, min, max)
    45  			retVal = a
    46  		default:
    47  			cloned := a.Clone().(Tensor)
    48  			err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max)
    49  			retVal = cloned
    50  		}
    51  		return
    52  	}
    53  	switch {
    54  	case incr:
    55  		cloned := a.Clone().(Tensor)
    56  		if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil {
    57  			return nil, errors.Wrapf(err, "Unable to perform Clamp")
    58  		}
    59  		err = e.E.Add(typ, dataReuse, cloned.hdr())
    60  		retVal = reuse
    61  	case toReuse:
    62  		storage.Copy(typ, dataReuse, dataA)
    63  		err = e.E.Clamp(typ, dataReuse, min, max)
    64  		retVal = reuse
    65  	case !safe:
    66  		err = e.E.Clamp(typ, dataA, min, max)
    67  		retVal = a
    68  	default:
    69  		cloned := a.Clone().(Tensor)
    70  		err = e.E.Clamp(typ, cloned.hdr(), min, max)
    71  		retVal = cloned
    72  	}
    73  	return
    74  }
    75  
    76  func (e StdEng) FMA(a, x, y Tensor) (Tensor, error) {
    77  	return e.Mul(a, x, WithIncr(y))
    78  }
    79  func (e StdEng) FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) {
    80  	return e.MulScalar(a, x, true, WithIncr(y))
    81  }