github.com/wzzhu/tensor@v0.9.24/defaultenginefloat32.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 "github.com/wzzhu/tensor/internal/execution" 6 "github.com/wzzhu/tensor/internal/storage" 7 8 "gorgonia.org/vecf32" 9 ) 10 11 func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { 12 fo := ParseFuncOpts(opts...) 13 14 reuseT, incr := fo.IncrReuse() 15 safe = fo.Safe() 16 toReuse = reuseT != nil 17 18 if toReuse { 19 var ok bool 20 if reuse, ok = reuseT.(DenseTensor); !ok { 21 returnOpOpt(fo) 22 err = errors.Errorf("Cannot reuse a different type of Tensor in a *Dense-Scalar operation. Reuse is of %T", reuseT) 23 return 24 } 25 if reuse.len() != expShape.TotalSize() && !expShape.IsScalar() { 26 returnOpOpt(fo) 27 err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape) 28 err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") 29 return 30 } 31 32 if !incr && reuse != nil { 33 reuse.setDataOrder(o) 34 // err = reuse.reshape(expShape...) 35 } 36 37 } 38 returnOpOpt(fo) 39 return 40 } 41 42 func prepDataVSF32(a Tensor, b interface{}, reuse Tensor) (dataA *storage.Header, dataB float32, dataReuse *storage.Header, ait, iit Iterator, useIter bool, err error) { 43 // get data 44 dataA = a.hdr() 45 switch bt := b.(type) { 46 case float32: 47 dataB = bt 48 case *float32: 49 dataB = *bt 50 default: 51 err = errors.Errorf("b is not a float32: %T", b) 52 return 53 } 54 if reuse != nil { 55 dataReuse = reuse.hdr() 56 } 57 58 if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { 59 ait = a.Iterator() 60 if reuse != nil { 61 iit = reuse.Iterator() 62 } 63 useIter = true 64 } 65 return 66 } 67 68 func (e Float32Engine) checkThree(a, b Tensor, reuse Tensor) error { 69 if !a.IsNativelyAccessible() { 70 return errors.Errorf(inaccessibleData, a) 71 } 72 if !b.IsNativelyAccessible() { 73 return errors.Errorf(inaccessibleData, b) 74 } 75 76 if reuse != nil && !reuse.IsNativelyAccessible() { 77 return errors.Errorf(inaccessibleData, reuse) 78 } 79 80 if a.Dtype() != Float32 { 81 return errors.Errorf("Expected a to be of Float32. Got %v instead", a.Dtype()) 82 } 83 if a.Dtype() != b.Dtype() || (reuse != nil && b.Dtype() != reuse.Dtype()) { 84 return errors.Errorf("Expected a, b and reuse to have the same Dtype. Got %v, %v and %v instead", a.Dtype(), b.Dtype(), reuse.Dtype()) 85 } 86 return nil 87 } 88 89 func (e Float32Engine) checkTwo(a Tensor, reuse Tensor) error { 90 if !a.IsNativelyAccessible() { 91 return errors.Errorf(inaccessibleData, a) 92 } 93 if reuse != nil && !reuse.IsNativelyAccessible() { 94 return errors.Errorf(inaccessibleData, reuse) 95 } 96 97 if a.Dtype() != Float32 { 98 return errors.Errorf("Expected a to be of Float32. Got %v instead", a.Dtype()) 99 } 100 101 if reuse != nil && reuse.Dtype() != a.Dtype() { 102 return errors.Errorf("Expected reuse to be the same as a. Got %v instead", reuse.Dtype()) 103 } 104 return nil 105 } 106 107 // Float32Engine is an execution engine that is optimized to only work with float32s. It assumes all data will are float32s. 108 // 109 // Use this engine only as form of optimization. You should probably be using the basic default engine for most cases. 110 type Float32Engine struct { 111 StdEng 112 } 113 114 // makeArray allocates a slice for the array 115 func (e Float32Engine) makeArray(arr *array, t Dtype, size int) { 116 if t != Float32 { 117 panic("Float32Engine only creates float32s") 118 } 119 if size < 0 { 120 panic("Cannot have negative sizes when making array") 121 } 122 arr.Header.Raw = make([]byte, size*4) 123 arr.t = t 124 } 125 126 func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { 127 reuse := y 128 if err = e.checkThree(a, x, reuse); err != nil { 129 return nil, errors.Wrap(err, "Failed checks") 130 } 131 132 var dataA, dataB, dataReuse *storage.Header 133 var ait, bit, iit Iterator 134 var useIter bool 135 if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, x, reuse); err != nil { 136 return nil, errors.Wrap(err, "Float32Engine.FMA") 137 } 138 if useIter { 139 err = execution.MulIterIncrF32(dataA.Float32s(), dataB.Float32s(), dataReuse.Float32s(), ait, bit, iit) 140 retVal = reuse 141 return 142 } 143 144 vecf32.IncrMul(dataA.Float32s(), dataB.Float32s(), dataReuse.Float32s()) 145 retVal = reuse 146 return 147 } 148 149 func (e Float32Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { 150 reuse := y 151 if err = e.checkTwo(a, reuse); err != nil { 152 return nil, errors.Wrap(err, "Failed checks") 153 } 154 155 var ait, iit Iterator 156 var dataTensor, dataReuse *storage.Header 157 var scalar float32 158 var useIter bool 159 if dataTensor, scalar, dataReuse, ait, iit, useIter, err = prepDataVSF32(a, x, reuse); err != nil { 160 return nil, errors.Wrapf(err, opFail, "Float32Engine.FMAScalar") 161 } 162 if useIter { 163 err = execution.MulIterIncrVSF32(dataTensor.Float32s(), scalar, dataReuse.Float32s(), ait, iit) 164 retVal = reuse 165 } 166 167 execution.MulIncrVSF32(dataTensor.Float32s(), scalar, dataReuse.Float32s()) 168 retVal = reuse 169 return 170 } 171 172 // Add performs a + b elementwise. Both a and b must have the same shape. 173 // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) 174 func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { 175 if a.RequiresIterator() || b.RequiresIterator() { 176 return e.StdEng.Add(a, b, opts...) 177 } 178 179 var reuse DenseTensor 180 var safe, toReuse, incr bool 181 if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { 182 return nil, errors.Wrap(err, "Unable to handle funcOpts") 183 } 184 if err = e.checkThree(a, b, reuse); err != nil { 185 return nil, errors.Wrap(err, "Failed checks") 186 } 187 188 var hdrA, hdrB, hdrReuse *storage.Header 189 var dataA, dataB, dataReuse []float32 190 191 if hdrA, hdrB, hdrReuse, _, _, _, _, _, err = prepDataVV(a, b, reuse); err != nil { 192 return nil, errors.Wrapf(err, "Float32Engine.Add") 193 } 194 dataA = hdrA.Float32s() 195 dataB = hdrB.Float32s() 196 if hdrReuse != nil { 197 dataReuse = hdrReuse.Float32s() 198 } 199 200 switch { 201 case incr: 202 vecf32.IncrAdd(dataA, dataB, dataReuse) 203 retVal = reuse 204 case toReuse: 205 copy(dataReuse, dataA) 206 vecf32.Add(dataReuse, dataB) 207 retVal = reuse 208 case !safe: 209 vecf32.Add(dataA, dataB) 210 retVal = a 211 default: 212 ret := a.Clone().(headerer) 213 vecf32.Add(ret.hdr().Float32s(), dataB) 214 retVal = ret.(Tensor) 215 } 216 return 217 } 218 219 func (e Float32Engine) Inner(a, b Tensor) (retVal float32, err error) { 220 var A, B []float32 221 var AD, BD *Dense 222 var ok bool 223 224 if AD, ok = a.(*Dense); !ok { 225 return 0, errors.Errorf("a is not a *Dense") 226 } 227 if BD, ok = b.(*Dense); !ok { 228 return 0, errors.Errorf("b is not a *Dense") 229 } 230 231 A = AD.Float32s() 232 B = BD.Float32s() 233 retVal = whichblas.Sdot(len(A), A, 1, B, 1) 234 return 235 }