github.com/wzzhu/tensor@v0.9.24/defaultengine_misc.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 "github.com/wzzhu/tensor/internal/storage" 6 ) 7 8 func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { 9 if err = unaryCheck(a, nonComplexNumberTypes); err != nil { 10 return nil, errors.Wrap(err, "Clamp failed") 11 } 12 13 var reuse DenseTensor 14 var safe, toReuse, incr bool 15 if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { 16 return nil, errors.Wrap(err, "Unable to handle funcOpts") 17 } 18 19 typ := a.Dtype().Type 20 var ait, rit Iterator 21 var dataA, dataReuse *storage.Header 22 var useIter bool 23 24 if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil { 25 return nil, errors.Wrapf(err, opFail, "StdEng.Neg") 26 } 27 28 if useIter { 29 switch { 30 case incr: 31 cloned := a.Clone().(Tensor) 32 if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != nil { 33 return nil, errors.Wrapf(err, "Unable to perform Clamp") 34 } 35 ait.Reset() 36 err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait) 37 retVal = reuse 38 case toReuse: 39 storage.CopyIter(typ, dataReuse, dataA, rit, ait) 40 rit.Reset() 41 err = e.E.ClampIter(typ, dataReuse, rit, min, max) 42 retVal = reuse 43 case !safe: 44 err = e.E.ClampIter(typ, dataA, ait, min, max) 45 retVal = a 46 default: 47 cloned := a.Clone().(Tensor) 48 err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max) 49 retVal = cloned 50 } 51 return 52 } 53 switch { 54 case incr: 55 cloned := a.Clone().(Tensor) 56 if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil { 57 return nil, errors.Wrapf(err, "Unable to perform Clamp") 58 } 59 err = e.E.Add(typ, dataReuse, cloned.hdr()) 60 retVal = reuse 61 case toReuse: 62 storage.Copy(typ, dataReuse, dataA) 63 err = e.E.Clamp(typ, dataReuse, min, max) 64 retVal = reuse 65 case !safe: 66 err = e.E.Clamp(typ, dataA, min, max) 67 retVal = a 68 default: 69 cloned := a.Clone().(Tensor) 70 err = e.E.Clamp(typ, cloned.hdr(), min, max) 71 retVal = cloned 72 } 73 return 74 } 75 76 func (e StdEng) FMA(a, x, y Tensor) (Tensor, error) { 77 return e.Mul(a, x, WithIncr(y)) 78 } 79 func (e StdEng) FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) { 80 return e.MulScalar(a, x, true, WithIncr(y)) 81 }