github.com/wzzhu/tensor@v0.9.24/defaultenginefloat64.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 "github.com/wzzhu/tensor/internal/execution" 6 "github.com/wzzhu/tensor/internal/storage" 7 8 "gorgonia.org/vecf64" 9 ) 10 11 func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { 12 fo := ParseFuncOpts(opts...) 13 14 reuseT, incr := fo.IncrReuse() 15 safe = fo.Safe() 16 toReuse = reuseT != nil 17 18 if toReuse { 19 var ok bool 20 if reuse, ok = reuseT.(DenseTensor); !ok { 21 returnOpOpt(fo) 22 err = errors.Errorf("Cannot reuse a different type of Tensor in a *Dense-Scalar operation. Reuse is of %T", reuseT) 23 return 24 } 25 if reuse.len() != expShape.TotalSize() && !expShape.IsScalar() { 26 returnOpOpt(fo) 27 err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape) 28 err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") 29 return 30 } 31 32 if !incr && reuse != nil { 33 reuse.setDataOrder(o) 34 // err = reuse.reshape(expShape...) 35 } 36 37 } 38 returnOpOpt(fo) 39 return 40 } 41 42 func prepDataVSF64(a Tensor, b interface{}, reuse Tensor) (dataA *storage.Header, dataB float64, dataReuse *storage.Header, ait, iit Iterator, useIter bool, err error) { 43 // get data 44 dataA = a.hdr() 45 switch bt := b.(type) { 46 case float64: 47 dataB = bt 48 case *float64: 49 dataB = *bt 50 default: 51 err = errors.Errorf("b is not a float64: %T", b) 52 return 53 } 54 if reuse != nil { 55 dataReuse = reuse.hdr() 56 } 57 58 if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { 59 ait = a.Iterator() 60 if reuse != nil { 61 iit = reuse.Iterator() 62 } 63 useIter = true 64 } 65 return 66 } 67 68 func (e Float64Engine) checkThree(a, b Tensor, reuse Tensor) error { 69 if !a.IsNativelyAccessible() { 70 return errors.Errorf(inaccessibleData, a) 71 } 72 if !b.IsNativelyAccessible() { 73 return errors.Errorf(inaccessibleData, b) 74 } 75 76 if reuse != nil && !reuse.IsNativelyAccessible() { 77 return errors.Errorf(inaccessibleData, reuse) 78 } 79 80 if a.Dtype() != Float64 { 81 return errors.Errorf("Expected a to be of Float64. Got %v instead", a.Dtype()) 82 } 83 if a.Dtype() != b.Dtype() || (reuse != nil && b.Dtype() != reuse.Dtype()) { 84 return errors.Errorf("Expected a, b and reuse to have the same Dtype. Got %v, %v and %v instead", a.Dtype(), b.Dtype(), reuse.Dtype()) 85 } 86 return nil 87 } 88 89 func (e Float64Engine) checkTwo(a Tensor, reuse Tensor) error { 90 if !a.IsNativelyAccessible() { 91 return errors.Errorf(inaccessibleData, a) 92 } 93 if reuse != nil && !reuse.IsNativelyAccessible() { 94 return errors.Errorf(inaccessibleData, reuse) 95 } 96 97 if a.Dtype() != Float64 { 98 return errors.Errorf("Expected a to be of Float64. Got %v instead", a.Dtype()) 99 } 100 101 if reuse != nil && reuse.Dtype() != a.Dtype() { 102 return errors.Errorf("Expected reuse to be the same as a. Got %v instead", reuse.Dtype()) 103 } 104 return nil 105 } 106 107 // Float64Engine is an execution engine that is optimized to only work with float64s. It assumes all data will are float64s. 108 // 109 // Use this engine only as form of optimization. You should probably be using the basic default engine for most cases. 110 type Float64Engine struct { 111 StdEng 112 } 113 114 // makeArray allocates a slice for the array 115 func (e Float64Engine) makeArray(arr *array, t Dtype, size int) { 116 if t != Float64 { 117 panic("Float64Engine only creates float64s") 118 } 119 arr.Header.Raw = make([]byte, size*8) 120 arr.t = t 121 } 122 123 func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { 124 reuse := y 125 if err = e.checkThree(a, x, reuse); err != nil { 126 return nil, errors.Wrap(err, "Failed checks") 127 } 128 129 var dataA, dataB, dataReuse *storage.Header 130 var ait, bit, iit Iterator 131 var useIter bool 132 if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, x, reuse); err != nil { 133 return nil, errors.Wrap(err, "Float64Engine.FMA") 134 } 135 if useIter { 136 err = execution.MulIterIncrF64(dataA.Float64s(), dataB.Float64s(), dataReuse.Float64s(), ait, bit, iit) 137 retVal = reuse 138 return 139 } 140 141 vecf64.IncrMul(dataA.Float64s(), dataB.Float64s(), dataReuse.Float64s()) 142 retVal = reuse 143 return 144 } 145 146 func (e Float64Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { 147 reuse := y 148 if err = e.checkTwo(a, reuse); err != nil { 149 return nil, errors.Wrap(err, "Failed checks") 150 } 151 152 var ait, iit Iterator 153 var dataTensor, dataReuse *storage.Header 154 var scalar float64 155 var useIter bool 156 if dataTensor, scalar, dataReuse, ait, iit, useIter, err = prepDataVSF64(a, x, reuse); err != nil { 157 return nil, errors.Wrapf(err, opFail, "Float64Engine.FMAScalar") 158 } 159 if useIter { 160 err = execution.MulIterIncrVSF64(dataTensor.Float64s(), scalar, dataReuse.Float64s(), ait, iit) 161 retVal = reuse 162 } 163 164 execution.MulIncrVSF64(dataTensor.Float64s(), scalar, dataReuse.Float64s()) 165 retVal = reuse 166 return 167 } 168 169 // Add performs a + b elementwise. Both a and b must have the same shape. 170 // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) 171 func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { 172 if a.RequiresIterator() || b.RequiresIterator() { 173 return e.StdEng.Add(a, b, opts...) 174 } 175 176 var reuse DenseTensor 177 var safe, toReuse, incr bool 178 if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { 179 return nil, errors.Wrap(err, "Unable to handle funcOpts") 180 } 181 if err = e.checkThree(a, b, reuse); err != nil { 182 return nil, errors.Wrap(err, "Failed checks") 183 } 184 185 var hdrA, hdrB, hdrReuse *storage.Header 186 var dataA, dataB, dataReuse []float64 187 188 if hdrA, hdrB, hdrReuse, _, _, _, _, _, err = prepDataVV(a, b, reuse); err != nil { 189 return nil, errors.Wrapf(err, "Float64Engine.Add") 190 } 191 dataA = hdrA.Float64s() 192 dataB = hdrB.Float64s() 193 if hdrReuse != nil { 194 dataReuse = hdrReuse.Float64s() 195 } 196 197 switch { 198 case incr: 199 vecf64.IncrAdd(dataA, dataB, dataReuse) 200 retVal = reuse 201 case toReuse: 202 copy(dataReuse, dataA) 203 vecf64.Add(dataReuse, dataB) 204 retVal = reuse 205 case !safe: 206 vecf64.Add(dataA, dataB) 207 retVal = a 208 default: 209 ret := a.Clone().(headerer) 210 vecf64.Add(ret.hdr().Float64s(), dataB) 211 retVal = ret.(Tensor) 212 } 213 return 214 } 215 216 func (e Float64Engine) Inner(a, b Tensor) (retVal float64, err error) { 217 var A, B []float64 218 var AD, BD *Dense 219 var ok bool 220 221 if AD, ok = a.(*Dense); !ok { 222 return 0, errors.Errorf("a is not a *Dense") 223 } 224 if BD, ok = b.(*Dense); !ok { 225 return 0, errors.Errorf("b is not a *Dense") 226 } 227 228 A = AD.Float64s() 229 B = BD.Float64s() 230 retVal = whichblas.Ddot(len(A), A, 1, B, 1) 231 return 232 }