gorgonia.org/tensor@v0.9.24/defaultengine_prep.go (about) 1 package tensor 2 3 import ( 4 "reflect" 5 6 "github.com/pkg/errors" 7 "gorgonia.org/tensor/internal/storage" 8 // "log" 9 ) 10 11 func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { 12 fo := ParseFuncOpts(opts...) 13 14 reuseT, incr := fo.IncrReuse() 15 safe = fo.Safe() 16 same = fo.Same() 17 toReuse = reuseT != nil 18 19 if toReuse { 20 if reuse, err = getDenseTensor(reuseT); err != nil { 21 returnOpOpt(fo) 22 err = errors.Wrapf(err, "Cannot reuse a Tensor that isn't a DenseTensor. Got %T instead", reuseT) 23 return 24 } 25 26 if reuse != nil && !reuse.IsNativelyAccessible() { 27 returnOpOpt(fo) 28 err = errors.Errorf(inaccessibleData, reuse) 29 return 30 } 31 32 if (strict || same) && reuse.Dtype() != expType { 33 returnOpOpt(fo) 34 err = errors.Errorf(typeMismatch, expType, reuse.Dtype()) 35 err = errors.Wrapf(err, "Cannot use reuse") 36 return 37 } 38 39 if reuse.len() != expShape.TotalSize() && !expShape.IsScalar() { 40 returnOpOpt(fo) 41 err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape) 42 err = errors.Wrapf(err, "Cannot use reuse: shape mismatch - reuse.len() %v, expShape.TotalSize() %v", reuse.len(), expShape.TotalSize()) 43 return 44 } 45 if !reuse.Shape().Eq(expShape) { 46 cloned := expShape.Clone() 47 if err = reuse.Reshape(cloned...); err != nil { 48 return 49 50 } 51 ReturnInts([]int(cloned)) 52 } 53 54 if !incr && reuse != nil { 55 reuse.setDataOrder(o) 56 // err = reuse.reshape(expShape...) 57 } 58 59 } 60 returnOpOpt(fo) 61 return 62 } 63 64 func binaryCheck(a, b Tensor, tc *typeclass) (err error) { 65 // check if the tensors are accessible 66 if !a.IsNativelyAccessible() { 67 return errors.Errorf(inaccessibleData, a) 68 } 69 70 if !b.IsNativelyAccessible() { 71 return errors.Errorf(inaccessibleData, b) 72 } 73 74 at := a.Dtype() 75 bt := b.Dtype() 76 if tc != nil { 77 if err = typeclassCheck(at, tc); err != nil { 78 return errors.Wrapf(err, typeclassMismatch, "a") 79 } 80 if err = typeclassCheck(bt, tc); err != nil { 81 return errors.Wrapf(err, typeclassMismatch, "b") 82 } 83 } 84 85 if at.Kind() != bt.Kind() { 86 return errors.Errorf(typeMismatch, at, bt) 87 } 88 if !a.Shape().Eq(b.Shape()) { 89 return errors.Errorf(shapeMismatch, b.Shape(), a.Shape()) 90 } 91 return nil 92 } 93 94 func unaryCheck(a Tensor, tc *typeclass) error { 95 if !a.IsNativelyAccessible() { 96 return errors.Errorf(inaccessibleData, a) 97 } 98 at := a.Dtype() 99 if tc != nil { 100 if err := typeclassCheck(at, tc); err != nil { 101 return errors.Wrapf(err, typeclassMismatch, "a") 102 } 103 } 104 return nil 105 } 106 107 // scalarDtypeCheck checks that a scalar value has the same dtype as the dtype of a given tensor. 108 func scalarDtypeCheck(a Tensor, b interface{}) error { 109 var dt Dtype 110 switch bt := b.(type) { 111 case Dtyper: 112 dt = bt.Dtype() 113 default: 114 t := reflect.TypeOf(b) 115 dt = Dtype{t} 116 } 117 118 if a.Dtype() != dt { 119 return errors.Errorf("Expected scalar to have the same Dtype as the tensor (%v). Got %T instead ", a.Dtype(), b) 120 } 121 return nil 122 } 123 124 // prepDataVV prepares the data given the input and reuse tensors. It also retruns several indicators 125 // 126 // useIter indicates that the iterator methods should be used. 127 // swap indicates that the operands are swapped. 128 func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, bit, iit Iterator, useIter, swap bool, err error) { 129 // get data 130 dataA = a.hdr() 131 dataB = b.hdr() 132 if reuse != nil { 133 dataReuse = reuse.hdr() 134 } 135 136 // iter 137 useIter = a.RequiresIterator() || 138 b.RequiresIterator() || 139 (reuse != nil && reuse.RequiresIterator()) || 140 !a.DataOrder().HasSameOrder(b.DataOrder()) || 141 (reuse != nil && (!a.DataOrder().HasSameOrder(reuse.DataOrder()) || !b.DataOrder().HasSameOrder(reuse.DataOrder()))) 142 if useIter { 143 ait = a.Iterator() 144 bit = b.Iterator() 145 if reuse != nil { 146 iit = reuse.Iterator() 147 } 148 } 149 150 // swap 151 if _, ok := a.(*CS); ok { 152 if _, ok := b.(DenseTensor); ok { 153 swap = true 154 dataA, dataB = dataB, dataA 155 ait, bit = bit, ait 156 } 157 } 158 159 return 160 } 161 162 func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, iit Iterator, useIter bool, newAlloc bool, err error) { 163 // get data 164 dataA = a.hdr() 165 dataB, newAlloc = scalarToHeader(b) 166 if reuse != nil { 167 dataReuse = reuse.hdr() 168 } 169 170 if a.IsScalar() { 171 return 172 } 173 useIter = a.RequiresIterator() || 174 (reuse != nil && reuse.RequiresIterator()) || 175 (reuse != nil && !reuse.DataOrder().HasSameOrder(a.DataOrder())) 176 if useIter { 177 ait = a.Iterator() 178 if reuse != nil { 179 iit = reuse.Iterator() 180 } 181 } 182 return 183 } 184 185 func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, bit, iit Iterator, useIter bool, newAlloc bool, err error) { 186 // get data 187 dataA, newAlloc = scalarToHeader(a) 188 dataB = b.hdr() 189 if reuse != nil { 190 dataReuse = reuse.hdr() 191 } 192 193 // get iterator 194 if b.IsScalar() { 195 return 196 } 197 useIter = b.RequiresIterator() || 198 (reuse != nil && reuse.RequiresIterator()) || 199 (reuse != nil && !reuse.DataOrder().HasSameOrder(b.DataOrder())) 200 201 if useIter { 202 bit = b.Iterator() 203 if reuse != nil { 204 iit = reuse.Iterator() 205 } 206 } 207 return 208 } 209 210 func prepDataUnary(a Tensor, reuse Tensor) (dataA, dataReuse *storage.Header, ait, rit Iterator, useIter bool, err error) { 211 // get data 212 dataA = a.hdr() 213 if reuse != nil { 214 dataReuse = reuse.hdr() 215 } 216 217 // get iterator 218 if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { 219 ait = a.Iterator() 220 if reuse != nil { 221 rit = reuse.Iterator() 222 } 223 useIter = true 224 } 225 return 226 }