github.com/wzzhu/tensor@v0.9.24/defaultengine_mapreduce.go

package tensor

import (
	"reflect"
	"sort"

	"github.com/pkg/errors"

	"github.com/wzzhu/tensor/internal/execution"
	"github.com/wzzhu/tensor/internal/storage"
)

// Map applies fn to every element of a. In the unsafe case the input tensor is
// mutated in place; otherwise the result is written into the reuse tensor,
// creating one when none is supplied.
func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
	if err = unaryCheck(a, nil); err != nil {
		err = errors.Wrap(err, "Failed Map()")
		return
	}

	var reuse DenseTensor
	var safe, _, incr bool
	if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil {
		return
	}
	switch {
	case safe && reuse == nil:
		// create reuse
		if v, ok := a.(View); ok {
			if v.IsMaterializable() {
				reuse = v.Materialize().(DenseTensor)
			} else {
				reuse = v.Clone().(DenseTensor)
			}
		} else {
			reuse = New(Of(a.Dtype()), WithShape(a.Shape().Clone()...))
		}
	case reuse != nil:
		if !reuse.IsNativelyAccessible() {
			return nil, errors.Errorf(inaccessibleData, reuse)
		}
		if a.Size() != reuse.Size() {
			return nil, errors.Errorf(shapeMismatch, a.Shape(), reuse.Shape())
		}
	}

	// PREP DATA
	typ := a.Dtype().Type
	var dataA, dataReuse, used *storage.Header
	var ait, rit, uit Iterator
	var useIter bool
	if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil {
		return nil, errors.Wrapf(err, "StdEng.Map")
	}

	// HANDLE USE CASES
	switch {
	case !safe:
		used = dataA
		uit = ait
	default:
		used = dataReuse
		uit = rit
	}

	// DO
	if useIter {
		err = e.E.MapIter(typ, fn, used, incr, uit)
	} else {
		err = e.E.Map(typ, fn, used, incr)
	}
	if err != nil {
		err = errors.Wrapf(err, "Unable to apply function %v to tensor of %v", fn, typ)
		return
	}

	// SET RETVAL
	switch {
	case reuse != nil:
		if err = reuseCheckShape(reuse, a.Shape()); err != nil {
			err = errors.Wrapf(err, "Reuse shape check failed")
			return
		}
		retVal = reuse
	case !safe:
		retVal = a
	default:
		retVal = reuse
	}
	return
}

// Reduce folds a along the given axis with fn, writing the result into the
// reuse tensor prepared by prepReduce. defaultValue is only consulted on the
// last-axis path.
func (e StdEng) Reduce(fn interface{}, a Tensor, axis int, defaultValue interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	if !a.IsNativelyAccessible() {
		return nil, errors.Errorf(inaccessibleData, a)
	}
	var at, reuse DenseTensor
	var dataA, dataReuse *storage.Header
	if at, reuse, dataA, dataReuse, err = e.prepReduce(a, axis, opts...); err != nil {
		err = errors.Wrap(err, "Prep Reduce failed")
		return
	}

	lastAxis := a.Dims() - 1
	typ := a.Dtype().Type

	// actual call out to the internal engine
	switch {
	case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()):
		var size, split int
		if at.DataOrder().IsColMajor() {
			return nil, errors.Errorf("NYI: colmajor")
		}
		size = a.Shape()[0]
		split = a.DataSize() / size
		storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split)
		err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, fn)
	case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()):
		var dimSize int
		if at.DataOrder().IsColMajor() {
			return nil, errors.Errorf("NYI: colmajor")
		}
		dimSize = a.Shape()[axis]
		err = e.E.ReduceLast(typ, dataA, dataReuse, dimSize, defaultValue, fn)
	default:
		dim0 := a.Shape()[0]
		dimSize := a.Shape()[axis]
		outerStride := a.Strides()[0]
		stride := a.Strides()[axis]
		expected := reuse.Strides()[0]
		err = e.E.ReduceDefault(typ, dataA, dataReuse, dim0, dimSize, outerStride, stride, expected, fn)
	}
	retVal = reuse
	return
}
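
// Usage sketch (illustrative, not from the original source): fn is assumed to
// be a binary function matching the tensor's dtype, e.g. func(x, y float64)
// float64 for Float64 data. Reducing a (2, 3) tensor along axis 0 takes the
// ReduceFirst path above.
//
//	a := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
//	add := func(x, y float64) float64 { return x + y }
//	out, err := StdEng{}.Reduce(add, a, 0, float64(0))
//	// out has shape (3) and holds {5, 7, 9}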

// OptimizedReduce is like Reduce, but takes specialised reduction functions
// for the first-axis (firstFn), last-axis (lastFn) and general (defaultFn)
// cases.
func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, defaultValue interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	if !a.IsNativelyAccessible() {
		return nil, errors.Errorf(inaccessibleData, a)
	}

	var at, reuse DenseTensor
	var dataA, dataReuse *storage.Header
	if at, reuse, dataA, dataReuse, err = e.prepReduce(a, axis, opts...); err != nil {
		err = errors.Wrap(err, "Prep Reduce failed")
		return
	}

	lastAxis := a.Dims() - 1
	typ := a.Dtype().Type

	// actual call out to the internal engine
	switch {
	case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()):
		var size, split int
		if at.DataOrder().IsColMajor() {
			return nil, errors.Errorf("NYI: colmajor")
		}
		size = a.Shape()[0]
		split = a.DataSize() / size
		storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split)
		err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, firstFn)
	case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()):
		var dimSize int
		if at.DataOrder().IsColMajor() {
			return nil, errors.Errorf("NYI: colmajor")
		}
		dimSize = a.Shape()[axis]
		err = e.E.ReduceLast(typ, dataA, dataReuse, dimSize, defaultValue, lastFn)
	default:
		dim0 := a.Shape()[0]
		dimSize := a.Shape()[axis]
		outerStride := a.Strides()[0]
		stride := a.Strides()[axis]
		expected := reuse.Strides()[0]
		err = e.E.ReduceDefault(typ, dataA, dataReuse, dim0, dimSize, outerStride, stride, expected, defaultFn)
	}
	retVal = reuse
	return
}

// Sum sums a along the given axes, materializing views first.
func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) {
	a2 := a
	if v, ok := a.(View); ok && v.IsMaterializable() {
		a2 = v.Materialize()
	}
	return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a2, along...)
}

// Min returns the minima of a along the given axes, materializing views first.
func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) {
	a2 := a
	if v, ok := a.(View); ok && v.IsMaterializable() {
		a2 = v.Materialize()
	}
	return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a2, along...)
}

// Max returns the maxima of a along the given axes, materializing views first.
func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) {
	a2 := a
	if v, ok := a.(View); ok && v.IsMaterializable() {
		a2 = v.Materialize()
	}
	return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a2, along...)
}
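
// Usage sketch (illustrative, not from the original source): Sum, Min and Max
// dispatch through reduce below. With no axes, or with every axis listed in
// increasing order, the monotonic fast path collapses the whole tensor into a
// scalar; otherwise each requested axis is reduced in turn via OptimizedReduce.
//
//	a := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
//	total, _ := StdEng{}.Sum(a)   // scalar tensor holding 21
//	cols, _ := StdEng{}.Sum(a, 0) // shape (3): {5, 7, 9}
//	rows, _ := StdEng{}.Sum(a, 1) // shape (2): {6, 15}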

// reduce performs the reduction named by op on a *Dense tensor. If along is
// empty, or is monotonically increasing by 1 and covers every axis, the
// monotonic fast path collapses the whole tensor into a scalar; otherwise
// each requested axis is reduced in turn via OptimizedReduce.
func (e StdEng) reduce(
	op string,
	monotonicMethod func(t reflect.Type, a *storage.Header) (interface{}, error),
	methods func(t reflect.Type) (interface{}, interface{}, interface{}, error),
	a Tensor,
	along ...int) (retVal Tensor, err error) {
	switch at := a.(type) {
	case *Dense:
		hdr := at.hdr()
		typ := at.t.Type
		monotonic, incr1 := IsMonotonicInts(along) // when along is monotonic, increasing by 1, and covers every axis, the result is a scalar
		if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 {
			var ret interface{}
			if ret, err = monotonicMethod(typ, hdr); err != nil {
				return
			}
			return New(FromScalar(ret)), nil
		}
		var firstFn, lastFn, defaultFn interface{}
		if firstFn, lastFn, defaultFn, err = methods(typ); err != nil {
			return
		}
		defaultVal := reflect.Zero(typ).Interface()

		retVal = a
		dimsReduced := 0
		sort.Slice(along, func(i, j int) bool { return along[i] < along[j] })

		for _, axis := range along {
			// each completed reduction removes a dimension, so shift the axis down accordingly
			axis -= dimsReduced
			dimsReduced++
			if axis >= retVal.Dims() {
				err = errors.Errorf(dimMismatch, retVal.Dims(), axis)
				return
			}

			if retVal, err = e.OptimizedReduce(retVal, axis, firstFn, lastFn, defaultFn, defaultVal); err != nil {
				return
			}
		}
		return

	default:
		return nil, errors.Errorf("Cannot perform %s on %T", op, a)
	}
}

// prepReduce validates the input, resolves the func-opts, and prepares a reuse
// tensor with the reduced shape (the input shape minus the reduced axis).
func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTensor, dataA, dataReuse *storage.Header, err error) {
	if axis >= a.Dims() {
		err = errors.Errorf(dimMismatch, axis, a.Dims())
		return
	}

	if err = unaryCheck(a, nil); err != nil {
		err = errors.Wrap(err, "prepReduce failed")
		return
	}

	// FUNC PREP
	var safe bool
	if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil {
		err = errors.Wrap(err, "Unable to prep unary tensor")
		return
	}

	var newShape Shape
	for i, s := range a.Shape() {
		if i == axis {
			continue
		}
		newShape = append(newShape, s)
	}

	switch {
	case !safe:
		err = errors.New("Reduce only supports safe operations.")
		return
	case reuse != nil && !reuse.IsNativelyAccessible():
		err = errors.Errorf(inaccessibleData, reuse)
		return
	case reuse != nil:
		if reuse.Shape().TotalSize() != newShape.TotalSize() {
			err = errors.Errorf(shapeMismatch, reuse.Shape(), newShape)
			return
		}
		reuse.Reshape(newShape...)
	case safe && reuse == nil:
		reuse = New(Of(a.Dtype()), WithShape(newShape...))
	}

	// DATA PREP
	var useIter bool
	if dataA, dataReuse, _, _, useIter, err = prepDataUnary(a, reuse); err != nil {
		err = errors.Wrapf(err, "StdEng.Reduce data prep")
		return
	}

	var ok bool
	if at, ok = a.(DenseTensor); !ok || useIter {
		err = errors.Errorf("Reduce does not (yet) support iterable tensors")
		return
	}
	return
}
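
// Usage sketch (illustrative, not from the original source): a reuse tensor
// supplied via the WithReuse func-opt is accepted as long as its total size
// matches the reduced shape; it is reshaped in place and the reduction writes
// into it, so the returned tensor is the reuse tensor itself.
//
//	a := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
//	reuse := New(Of(Float64), WithShape(3))
//	add := func(x, y float64) float64 { return x + y }
//	out, _ := StdEng{}.Reduce(add, a, 0, float64(0), WithReuse(reuse))
//	// out == reuse, shape (3), holding {5, 7, 9}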