github.com/wzzhu/tensor@v0.9.24/dense_matop.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  )
     6  
     7  // T performs a thunked transpose. It doesn't actually do anything, except store extra information about the post-transposed shapes and strides
     8  // Usually this is more than enough, as BLAS will handle the rest of the transpose
     9  func (t *Dense) T(axes ...int) (err error) {
    10  	var transform AP
    11  	if transform, axes, err = t.AP.T(axes...); err != nil {
    12  		return handleNoOp(err)
    13  	}
    14  
    15  	// is there any old transposes that need to be done first?
    16  	// this is important, because any old transposes for dim >=3 are merely permutations of the strides
    17  	if !t.old.IsZero() {
    18  		if t.IsVector() {
    19  			// the transform that was calculated was a waste of time - return it to the pool then untranspose
    20  			t.UT()
    21  			return
    22  		}
    23  
    24  		// check if the current axes are just a reverse of the previous transpose's
    25  		isReversed := true
    26  		for i, s := range t.oshape() {
    27  			if transform.Shape()[i] != s {
    28  				isReversed = false
    29  				break
    30  			}
    31  		}
    32  
    33  		// if it is reversed, well, we just restore the backed up one
    34  		if isReversed {
    35  			t.UT()
    36  			return
    37  		}
    38  
    39  		// cool beans. No funny reversals. We'd have to actually do transpose then
    40  		t.Transpose()
    41  	}
    42  
    43  	// swap out the old and the new
    44  	t.old = t.AP
    45  	t.transposeWith = axes
    46  	t.AP = transform
    47  	return nil
    48  }
    49  
    50  // UT is a quick way to untranspose a currently transposed *Dense
    51  // The reason for having this is quite simply illustrated by this problem:
    52  //		T = NewTensor(WithShape(2,3,4))
    53  //		T.T(1,2,0)
    54  //
    55  // To untranspose that, we'd need to apply a transpose of (2,0,1).
    56  // This means having to keep track and calculate the transposes.
    57  // Instead, here's a helpful convenience function to instantly untranspose any previous transposes.
    58  //
    59  // Nothing will happen if there was no previous transpose
    60  func (t *Dense) UT() {
    61  	if !t.old.IsZero() {
    62  		ReturnInts(t.transposeWith)
    63  		t.AP = t.old
    64  		t.old.zeroOnly()
    65  		t.transposeWith = nil
    66  	}
    67  }
    68  
    69  // SafeT is exactly like T(), except it returns a new *Dense. The data is also copied over, unmoved.
    70  func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) {
    71  	var transform AP
    72  	if transform, axes, err = t.AP.T(axes...); err != nil {
    73  		if err = handleNoOp(err); err != nil {
    74  			return
    75  		}
    76  	}
    77  
    78  	retVal = recycledDense(t.t, Shape{t.len()}, WithEngine(t.e))
    79  	copyDense(retVal, t)
    80  
    81  	retVal.e = t.e
    82  	retVal.oe = t.oe
    83  	retVal.AP = transform
    84  	t.AP.CloneTo(&retVal.old)
    85  	retVal.transposeWith = axes
    86  
    87  	return
    88  }
    89  
    90  // At returns the value at the given coordinate
    91  func (t *Dense) At(coords ...int) (interface{}, error) {
    92  	if !t.IsNativelyAccessible() {
    93  		return nil, errors.Errorf(inaccessibleData, t)
    94  	}
    95  	if len(coords) != t.Dims() {
    96  		return nil, errors.Errorf(dimMismatch, t.Dims(), len(coords))
    97  	}
    98  
    99  	at, err := t.at(coords...)
   100  	if err != nil {
   101  		return nil, errors.Wrap(err, "At()")
   102  	}
   103  
   104  	return t.Get(at), nil
   105  }
   106  
   107  // MaskAt returns the value of the mask at a given coordinate
   108  // returns false (valid) if not tensor is not masked
   109  func (t *Dense) MaskAt(coords ...int) (bool, error) {
   110  	if !t.IsMasked() {
   111  		return false, nil
   112  	}
   113  	if !t.IsNativelyAccessible() {
   114  		return false, errors.Errorf(inaccessibleData, t)
   115  	}
   116  	if len(coords) != t.Dims() {
   117  		return true, errors.Errorf(dimMismatch, t.Dims(), len(coords))
   118  	}
   119  
   120  	at, err := t.maskAt(coords...)
   121  	if err != nil {
   122  		return true, errors.Wrap(err, "MaskAt()")
   123  	}
   124  
   125  	return t.mask[at], nil
   126  }
   127  
   128  // SetAt sets the value at the given coordinate
   129  func (t *Dense) SetAt(v interface{}, coords ...int) error {
   130  	if !t.IsNativelyAccessible() {
   131  		return errors.Errorf(inaccessibleData, t)
   132  	}
   133  
   134  	if len(coords) != t.Dims() {
   135  		return errors.Errorf(dimMismatch, t.Dims(), len(coords))
   136  	}
   137  
   138  	at, err := t.at(coords...)
   139  	if err != nil {
   140  		return errors.Wrap(err, "SetAt()")
   141  	}
   142  	t.Set(at, v)
   143  	return nil
   144  }
   145  
   146  // SetMaskAtDataIndex set the value of the mask at a given index
   147  func (t *Dense) SetMaskAtIndex(v bool, i int) error {
   148  	if !t.IsMasked() {
   149  		return nil
   150  	}
   151  	t.mask[i] = v
   152  	return nil
   153  }
   154  
   155  // SetMaskAt sets the mask value at the given coordinate
   156  func (t *Dense) SetMaskAt(v bool, coords ...int) error {
   157  	if !t.IsMasked() {
   158  		return nil
   159  	}
   160  	if !t.IsNativelyAccessible() {
   161  		return errors.Errorf(inaccessibleData, t)
   162  	}
   163  	if len(coords) != t.Dims() {
   164  		return errors.Errorf(dimMismatch, t.Dims(), len(coords))
   165  	}
   166  
   167  	at, err := t.maskAt(coords...)
   168  	if err != nil {
   169  		return errors.Wrap(err, "SetAt()")
   170  	}
   171  	t.mask[at] = v
   172  	return nil
   173  }
   174  
   175  // CopyTo copies the underlying data to the destination *Dense. The original data is untouched.
   176  // Note: CopyTo doesn't care about the metadata of the destination *Dense. Take for example:
   177  //		T = NewTensor(WithShape(6))
   178  //		T2 = NewTensor(WithShape(2,3))
   179  //		err = T.CopyTo(T2) // err == nil
   180  //
   181  // The only time that this will fail is if the underlying sizes are different
   182  func (t *Dense) CopyTo(other *Dense) error {
   183  	if other == t {
   184  		return nil // nothing to copy to. Maybe return NoOpErr?
   185  	}
   186  
   187  	if other.Size() != t.Size() {
   188  		return errors.Errorf(sizeMismatch, t.Size(), other.Size())
   189  	}
   190  
   191  	// easy peasy lemon squeezy
   192  	if t.viewOf == 0 && other.viewOf == 0 {
   193  		copyDense(other, t)
   194  		return nil
   195  	}
   196  
   197  	// TODO: use copyDenseIter
   198  	return errors.Errorf(methodNYI, "CopyTo", "views")
   199  }
   200  
   201  // Narrow narrows the tensor.
   202  func (t *Dense) Narrow(dim, start, length int) (View, error) {
   203  	dim = resolveAxis(dim, t.Dims())
   204  
   205  	slices := make([]Slice, MinInt(dim+1, t.Dims()))
   206  	slices[dim] = S(start, start+length, 1)
   207  
   208  	return t.Slice(slices...)
   209  }
   210  
   211  // Slice performs slicing on the *Dense Tensor. It returns a view which shares the same underlying memory as the original *Dense.
   212  //
   213  // Given:
   214  //		T = NewTensor(WithShape(2,2), WithBacking(RangeFloat64(0,4)))
   215  //		V, _ := T.Slice(nil, singleSlice(1)) // T[:, 1]
   216  //
   217  // Any modification to the values in V, will be reflected in T as well.
   218  //
   219  // The method treats <nil> as equivalent to a colon slice. T.Slice(nil) is equivalent to T[:] in Numpy syntax
   220  func (t *Dense) Slice(slices ...Slice) (retVal View, err error) {
   221  	var newAP AP
   222  	var ndStart, ndEnd int
   223  
   224  	if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil {
   225  		return
   226  	}
   227  
   228  	view := borrowDense()
   229  	view.t = t.t
   230  	view.e = t.e
   231  	view.oe = t.oe
   232  	view.flag = t.flag
   233  	view.AP = newAP
   234  	view.setParentTensor(t)
   235  	t.sliceInto(ndStart, ndEnd, &view.array)
   236  
   237  	if t.IsMasked() {
   238  		view.mask = t.mask[ndStart:ndEnd]
   239  	}
   240  
   241  	return view, err
   242  }
   243  
   244  // SliceInto is a convenience method. It does NOT copy the values - it simply updates the AP of the view.
   245  // The underlying data is the same.
   246  // This method will override ALL the metadata in view.
   247  func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) {
   248  	var newAP AP
   249  	var ndStart, ndEnd int
   250  
   251  	if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil {
   252  		return
   253  	}
   254  
   255  	view.AP.zero()
   256  
   257  	view.t = t.t
   258  	view.e = t.e
   259  	view.oe = t.oe
   260  	view.flag = t.flag
   261  	view.AP = newAP
   262  	view.setParentTensor(t)
   263  	t.sliceInto(ndStart, ndEnd, &view.array)
   264  
   265  	if t.IsMasked() {
   266  		view.mask = t.mask[ndStart:ndEnd]
   267  	}
   268  
   269  	return view, err
   270  
   271  }
   272  
   273  // RollAxis rolls the axis backwards until it lies in the given position.
   274  //
   275  // This method was adapted from Numpy's Rollaxis. The licence for Numpy is a BSD-like licence and can be found here: https://github.com/numpy/numpy/blob/master/LICENSE.txt
   276  //
   277  // As a result of being adapted from Numpy, the quirks are also adapted. A good guide reducing the confusion around rollaxis can be found here: http://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing (see answer by hpaulj)
   278  func (t *Dense) RollAxis(axis, start int, safe bool) (retVal *Dense, err error) {
   279  	dims := t.Dims()
   280  
   281  	if !(axis >= 0 && axis < dims) {
   282  		err = errors.Errorf(invalidAxis, axis, dims)
   283  		return
   284  	}
   285  
   286  	if !(start >= 0 && start <= dims) {
   287  		err = errors.Wrap(errors.Errorf(invalidAxis, axis, dims), "Start axis is wrong")
   288  		return
   289  	}
   290  
   291  	if axis < start {
   292  		start--
   293  	}
   294  
   295  	if axis == start {
   296  		retVal = t
   297  		return
   298  	}
   299  
   300  	axes := BorrowInts(dims)
   301  	defer ReturnInts(axes)
   302  
   303  	for i := 0; i < dims; i++ {
   304  		axes[i] = i
   305  	}
   306  	copy(axes[axis:], axes[axis+1:])
   307  	copy(axes[start+1:], axes[start:])
   308  	axes[start] = axis
   309  
   310  	if safe {
   311  		return t.SafeT(axes...)
   312  	}
   313  	err = t.T(axes...)
   314  	retVal = t
   315  	return
   316  }
   317  
   318  /* Private Methods */
   319  
   320  // returns the new index given the old index
   321  func (t *Dense) transposeIndex(i int, transposePat, strides []int) int {
   322  	oldCoord, err := Itol(i, t.oshape(), t.ostrides())
   323  	if err != nil {
   324  		err = errors.Wrapf(err, "transposeIndex ItoL failure. i %d original shape %v. original strides %v", i, t.oshape(), t.ostrides())
   325  		panic(err)
   326  	}
   327  
   328  	/*
   329  		coordss, _ := Permute(transposePat, oldCoord)
   330  		coords := coordss[0]
   331  		expShape := t.Shape()
   332  		index, _ := Ltoi(expShape, strides, coords...)
   333  	*/
   334  
   335  	// The above is the "conceptual" algorithm.
   336  	// Too many checks above slows things down, so the below is the "optimized" edition
   337  	var index int
   338  	for i, axis := range transposePat {
   339  		index += oldCoord[axis] * strides[i]
   340  	}
   341  	return index
   342  }
   343  
   344  // at returns the index at which the coordinate is referring to.
   345  // This function encapsulates the addressing of elements in a contiguous block.
   346  // For a 2D ndarray, ndarray.at(i,j) is
   347  //		at = ndarray.strides[0]*i + ndarray.strides[1]*j
   348  // This is of course, extensible to any number of dimensions.
   349  func (t *Dense) at(coords ...int) (at int, err error) {
   350  	return Ltoi(t.Shape(), t.Strides(), coords...)
   351  }
   352  
   353  // maskat returns the mask index at which the coordinate is referring to.
   354  func (t *Dense) maskAt(coords ...int) (at int, err error) {
   355  	//TODO: Add check for non-masked tensor
   356  	return t.at(coords...)
   357  }