github.com/wzzhu/tensor@v0.9.24/defaultengine_mapreduce.go (about)

     1  package tensor
     2  
     3  import (
     4  	"reflect"
     5  	"sort"
     6  
     7  	"github.com/pkg/errors"
     8  
     9  	"github.com/wzzhu/tensor/internal/execution"
    10  	"github.com/wzzhu/tensor/internal/storage"
    11  )
    12  
// Map applies fn elementwise to the tensor a.
//
// The FuncOpts steer where the result is written:
//   - safe (default): a fresh destination is created (a materialized copy or
//     clone for views, a new tensor otherwise) and a is left untouched.
//   - WithReuse: the result is written into the supplied reuse tensor.
//   - unsafe: a's own data is overwritten in place.
//   - incr: results are accumulated into the destination instead of
//     overwriting it (passed through to the internal engine).
func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
	if err = unaryCheck(a, nil); err != nil {
		err = errors.Wrap(err, "Failed Map()")
		return
	}

	var reuse DenseTensor
	var safe, _, incr bool
	if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil {
		return
	}
	// Ensure a destination exists for the safe case, and validate any
	// user-provided reuse tensor.
	switch {
	case safe && reuse == nil:
		// create reuse
		if v, ok := a.(View); ok {
			if v.IsMaterializable() {
				reuse = v.Materialize().(DenseTensor)
			} else {
				reuse = v.Clone().(DenseTensor)
			}
		} else {
			reuse = New(Of(a.Dtype()), WithShape(a.Shape().Clone()...))
		}
	case reuse != nil:
		// User-supplied destination: it must be accessible from Go and hold
		// exactly as many elements as a.
		if !reuse.IsNativelyAccessible() {
			return nil, errors.Errorf(inaccessibleData, reuse)
		}
		if a.Size() != reuse.Size() {
			return nil, errors.Errorf(shapeMismatch, a.Shape(), reuse.Shape())
		}
	}

	// PREP DATA
	typ := a.Dtype().Type
	var dataA, dataReuse, used *storage.Header
	var ait, rit, uit Iterator
	var useIter bool
	if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil {
		return nil, errors.Wrapf(err, "StdEng.Map")
	}

	// HANDLE USE CASES
	// Pick the destination header/iterator pair: unsafe writes into a's own
	// data; every other mode writes into the reuse tensor's data.
	switch {
	case !safe:
		used = dataA
		uit = ait
	default:
		used = dataReuse
		uit = rit
	}

	// DO
	// The iterator path is taken when either operand is not plainly
	// contiguous (as reported by prepDataUnary).
	if useIter {
		err = e.E.MapIter(typ, fn, used, incr, uit)
	} else {
		err = e.E.Map(typ, fn, used, incr)
	}
	if err != nil {
		err = errors.Wrapf(err, "Unable to apply function %v to tensor of %v", fn, typ)
		return
	}

	// SET RETVAL
	switch {
	case reuse != nil:
		if err = reuseCheckShape(reuse, a.Shape()); err != nil {
			err = errors.Wrapf(err, "Reuse shape check failed")
			return
		}
		retVal = reuse
	case !safe:
		retVal = a
	default:
		// NOTE(review): reuse is nil on this branch, yet safe && reuse == nil
		// was handled above by creating a reuse — this case looks unreachable.
		// TODO confirm.
		retVal = reuse
	}
	return
}
    90  
    91  func (e StdEng) Reduce(fn interface{}, a Tensor, axis int, defaultValue interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
    92  	if !a.IsNativelyAccessible() {
    93  		return nil, errors.Errorf(inaccessibleData, a)
    94  	}
    95  	var at, reuse DenseTensor
    96  	var dataA, dataReuse *storage.Header
    97  	if at, reuse, dataA, dataReuse, err = e.prepReduce(a, axis, opts...); err != nil {
    98  		err = errors.Wrap(err, "Prep Reduce failed")
    99  		return
   100  	}
   101  
   102  	lastAxis := a.Dims() - 1
   103  	typ := a.Dtype().Type
   104  
   105  	// actual call out to the internal engine
   106  	switch {
   107  	case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()):
   108  		var size, split int
   109  		if at.DataOrder().IsColMajor() {
   110  			return nil, errors.Errorf("NYI: colmajor")
   111  		}
   112  		size = a.Shape()[0]
   113  		split = a.DataSize() / size
   114  		storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split)
   115  		err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, fn)
   116  	case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()):
   117  		var dimSize int
   118  		if at.DataOrder().IsColMajor() {
   119  			return nil, errors.Errorf("NYI: colmajor")
   120  		}
   121  		dimSize = a.Shape()[axis]
   122  		err = e.E.ReduceLast(typ, dataA, dataReuse, dimSize, defaultValue, fn)
   123  	default:
   124  		dim0 := a.Shape()[0]
   125  		dimSize := a.Shape()[axis]
   126  		outerStride := a.Strides()[0]
   127  		stride := a.Strides()[axis]
   128  		expected := reuse.Strides()[0]
   129  		err = e.E.ReduceDefault(typ, dataA, dataReuse, dim0, dimSize, outerStride, stride, expected, fn)
   130  	}
   131  	retVal = reuse
   132  	return
   133  }
   134  
// OptimizedReduce reduces a along the given axis, collapsing that dimension
// into the result. Three specialized kernels are dispatched on the position
// of the axis relative to the data order: firstFn for the outermost
// dimension, lastFn (seeded with defaultValue) for the innermost, and
// defaultFn (strided) for anything in between. A reuse destination may be
// supplied via opts; otherwise one is allocated by prepReduce.
func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, defaultValue interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	if !a.IsNativelyAccessible() {
		return nil, errors.Errorf(inaccessibleData, a)
	}

	var at, reuse DenseTensor
	var dataA, dataReuse *storage.Header
	if at, reuse, dataA, dataReuse, err = e.prepReduce(a, axis, opts...); err != nil {
		err = errors.Wrap(err, "Prep Reduce failed")
		return
	}

	lastAxis := a.Dims() - 1
	typ := a.Dtype().Type

	// actual call out to the internal engine
	switch {
	case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()):
		// Reducing along the outermost dimension: seed the destination with
		// the first slice of a, then fold the remaining slices in.
		var size, split int
		if at.DataOrder().IsColMajor() {
			return nil, errors.Errorf("NYI: colmajor")
		}
		size = a.Shape()[0]
		split = a.DataSize() / size
		// NOTE(review): the return value of CopySliced is discarded — TODO
		// confirm a short copy cannot occur here.
		storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split)
		err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, firstFn)
	case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()):
		// Reducing along the innermost dimension: each run of dimSize
		// elements folds into one output element, starting from defaultValue.
		var dimSize int
		if at.DataOrder().IsColMajor() {
			return nil, errors.Errorf("NYI: colmajor")
		}
		dimSize = a.Shape()[axis]
		err = e.E.ReduceLast(typ, dataA, dataReuse, dimSize, defaultValue, lastFn)
	default:
		// Middle axis: strided reduction driven by a's shape/strides and the
		// destination's outer stride.
		dim0 := a.Shape()[0]
		dimSize := a.Shape()[axis]
		outerStride := a.Strides()[0]
		stride := a.Strides()[axis]
		expected := reuse.Strides()[0]
		err = e.E.ReduceDefault(typ, dataA, dataReuse, dim0, dimSize, outerStride, stride, expected, defaultFn)
	}
	retVal = reuse
	return
}
   179  
   180  func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) {
   181  	a2 := a
   182  	if v, ok := a.(View); ok && v.IsMaterializable() {
   183  		a2 = v.Materialize()
   184  	}
   185  	return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a2, along...)
   186  }
   187  
   188  func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) {
   189  	a2 := a
   190  	if v, ok := a.(View); ok && v.IsMaterializable() {
   191  		a2 = v.Materialize()
   192  	}
   193  	return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a2, along...)
   194  }
   195  
   196  func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) {
   197  	a2 := a
   198  	if v, ok := a.(View); ok && v.IsMaterializable() {
   199  		a2 = v.Materialize()
   200  	}
   201  	return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a2, along...)
   202  }
   203  
   204  func (e StdEng) reduce(
   205  	op string,
   206  	monotonicMethod func(t reflect.Type, a *storage.Header) (interface{}, error),
   207  	methods func(t reflect.Type) (interface{}, interface{}, interface{}, error),
   208  	a Tensor,
   209  	along ...int) (retVal Tensor, err error) {
   210  	switch at := a.(type) {
   211  	case *Dense:
   212  		hdr := at.hdr()
   213  		typ := at.t.Type
   214  		monotonic, incr1 := IsMonotonicInts(along) // if both are true, then it means all axes are accounted for, then it'll return a scalar value
   215  		if (monotonic && incr1 && len(along) == a.Dims()) || len(along) == 0 {
   216  			var ret interface{}
   217  			if ret, err = monotonicMethod(typ, hdr); err != nil {
   218  				return
   219  			}
   220  			return New(FromScalar(ret)), nil
   221  		}
   222  		var firstFn, lastFn, defaultFn interface{}
   223  		if firstFn, lastFn, defaultFn, err = methods(typ); err != nil {
   224  			return
   225  		}
   226  		defaultVal := reflect.Zero(typ).Interface()
   227  
   228  		retVal = a
   229  		dimsReduced := 0
   230  		sort.Slice(along, func(i, j int) bool { return along[i] < along[j] })
   231  
   232  		for _, axis := range along {
   233  			axis -= dimsReduced
   234  			dimsReduced++
   235  			if axis >= retVal.Dims() {
   236  				err = errors.Errorf(dimMismatch, retVal.Dims(), axis)
   237  				return
   238  			}
   239  
   240  			if retVal, err = e.OptimizedReduce(retVal, axis, firstFn, lastFn, defaultFn, defaultVal); err != nil {
   241  				return
   242  			}
   243  		}
   244  		return
   245  
   246  	default:
   247  		return nil, errors.Errorf("Cannot perform %s on %T", op, a)
   248  	}
   249  
   250  }
   251  
// prepReduce validates a and the FuncOpts for a reduction along axis, and
// readies the destination. It returns a asserted to a DenseTensor (at), the
// output tensor (reuse — user-supplied via the opts, or freshly allocated
// with the reduced shape), and the raw storage headers of both.
func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTensor, dataA, dataReuse *storage.Header, err error) {
	if axis >= a.Dims() {
		err = errors.Errorf(dimMismatch, axis, a.Dims())
		return
	}

	if err = unaryCheck(a, nil); err != nil {
		err = errors.Wrap(err, "prepReduce failed")
		return
	}

	// FUNC PREP
	var safe bool
	if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil {
		err = errors.Wrap(err, "Unable to prep unary tensor")
		return
	}

	// The output shape is a's shape with the reduced axis removed.
	var newShape Shape
	for i, s := range a.Shape() {
		if i == axis {
			continue
		}
		newShape = append(newShape, s)
	}

	switch {
	case !safe:
		err = errors.New("Reduce only supports safe operations.")
		return
	case reuse != nil && !reuse.IsNativelyAccessible():
		err = errors.Errorf(inaccessibleData, reuse)
		return
	case reuse != nil:
		// A user-supplied destination only needs the right number of
		// elements; it is reshaped in place to the output shape.
		if reuse.Shape().TotalSize() != newShape.TotalSize() {
			err = errors.Errorf(shapeMismatch, reuse.Shape(), newShape)
			return
		}
		// NOTE(review): any value returned by Reshape is discarded here —
		// TODO confirm it cannot fail after the size check above.
		reuse.Reshape(newShape...)
	case safe && reuse == nil:
		reuse = New(Of(a.Dtype()), WithShape(newShape...))
	}

	// DATA PREP
	var useIter bool
	if dataA, dataReuse, _, _, useIter, err = prepDataUnary(a, reuse); err != nil {
		err = errors.Wrapf(err, "StdEng.Reduce data prep")
		return
	}

	// Reductions require a plain DenseTensor with contiguous data; tensors
	// that would need an iterator are rejected for now.
	var ok bool
	if at, ok = a.(DenseTensor); !ok || useIter {
		err = errors.Errorf("Reduce does not (yet) support iterable tensors")
		return
	}
	return
}