github.com/wzzhu/tensor@v0.9.24/api_matop.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  )
     6  
// This file handles matops. While most of these matops would ideally already be defined as part of the
// Tensor interface, not all are possible (for example, concatenating a sparse tensor), hence the need for the following functions.
     9  
    10  // Narrow narrows the tensor.
    11  func Narrow(t Tensor, dim, start, length int) (View, error) {
    12  	dim = resolveAxis(dim, t.Dims())
    13  
    14  	slices := make([]Slice, MinInt(dim+1, t.Dims()))
    15  	slices[dim] = S(start, start+length, 1)
    16  
    17  	return t.Slice(slices...)
    18  }
    19  
    20  // Repeat repeats a Tensor along the axis and given the number of repeats.
    21  func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) {
    22  	if r, ok := t.Engine().(Repeater); ok {
    23  		return r.Repeat(t, axis, repeats...)
    24  	}
    25  	return nil, errors.New("Engine does not support Repeat")
    26  }
    27  
    28  // RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then  an error will be given, but the results will still be valid.
    29  func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) {
    30  	if r, ok := t.Engine().(Repeater); ok {
    31  		return r.RepeatReuse(t, reuse, axis, repeats...)
    32  	}
    33  	return nil, errors.New("Engine does not support Repeat")
    34  }
    35  
    36  // T safely transposes a Tensor. It returns a tensor that is not a view of the input tensor - rather, the data is all copied.
    37  func T(t Tensor, axes ...int) (retVal Tensor, err error) {
    38  	switch tt := t.(type) {
    39  	case *Dense:
    40  		return tt.SafeT(axes...)
    41  	}
    42  	panic("Unreachable")
    43  }
    44  
    45  // Transpose performs transposition of a tensor according to its axes.
    46  func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) {
    47  	switch tt := t.(type) {
    48  	case *Dense:
    49  		var ret *Dense
    50  		if ret, err = tt.SafeT(axes...); err != nil {
    51  			return
    52  		}
    53  		ret.Transpose()
    54  		retVal = ret
    55  		return
    56  	}
    57  	panic("Unreachable")
    58  }
    59  
    60  // Concat concatenates a list of Tensors. At the moment the operation only supports Tensors of the same type
    61  // (*Dense can only be concatenated with a bunch of *Dense, CSCs can only be concatenated with a bunch of CSC, etc)
    62  func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) {
    63  	if len(others) == 0 {
    64  		return t, nil
    65  	}
    66  	switch T := t.(type) {
    67  	case *Dense:
    68  		ts := make([]*Dense, len(others))
    69  		for i, o := range others {
    70  			if ot, ok := o.(*Dense); ok {
    71  				ts[i] = ot
    72  				continue
    73  			}
    74  			return nil, errors.Errorf("Expected all Tensors to be *Dense")
    75  		}
    76  		return T.Concat(axis, ts...)
    77  	}
    78  	panic("Unreachable")
    79  }
    80  
    81  // Copy copies a tensor to another. For *Dense views, only the relevant slots are copied.
    82  func Copy(dst, src Tensor) error {
    83  	switch st := src.(type) {
    84  	case DenseTensor:
    85  		dt, ok := dst.(DenseTensor)
    86  		if !ok {
    87  			return errors.Errorf("Cannot copy from DenseTensor to %T", dst)
    88  		}
    89  
    90  		if st.RequiresIterator() || dt.RequiresIterator() {
    91  			siter := st.Iterator()
    92  			diter := dt.Iterator()
    93  			_, err := copyDenseIter(dt, st, diter, siter)
    94  			return err
    95  		}
    96  		copyDense(dt, st)
    97  		return nil
    98  	default:
    99  		return errors.Errorf("NYI for Copy %T", src)
   100  	}
   101  	panic("Unreachable")
   102  }
   103  
   104  // Stack stacks a list of other Tensors. At the moment the operation only supports Tensors of the same type.
   105  // (*Dense can only be stacked with *Dense... etc)
   106  func Stack(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) {
   107  	if len(others) == 0 {
   108  		return t, nil
   109  	}
   110  
   111  	switch T := t.(type) {
   112  	case DenseTensor:
   113  		var dts []DenseTensor
   114  		if dts, err = tensorsToDenseTensors(others); err != nil {
   115  			return nil, errors.Wrap(err, "Cannot  convert others into a slice of DenseTensors")
   116  		}
   117  		return T.stackDense(axis, dts...)
   118  	}
   119  	panic("Unreachable")
   120  }
   121  
   122  // Materialize takes a View and copies out the data into a new allocation.
   123  func Materialize(t Tensor) Tensor {
   124  	switch tt := t.(type) {
   125  	case View:
   126  		return tt.Materialize()
   127  	default:
   128  		return t
   129  	}
   130  }
   131  
   132  func Diag(t Tensor) (retVal Tensor, err error) {
   133  	if d, ok := t.Engine().(Diager); ok {
   134  		return d.Diag(t)
   135  	}
   136  	return nil, errors.Errorf("Unable to perform diagonalization of tensor ")
   137  }
   138  
   139  // ByIndices allows for selection of value of `a`  byt the indices listed in the `indices` tensor.
   140  // The `indices` tensor has to be a vector-like tensor of ints.
   141  func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   142  	if axis >= a.Shape().Dims() {
   143  		return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
   144  	}
   145  	if sbi, ok := a.Engine().(ByIndiceser); ok {
   146  		return sbi.SelectByIndices(a, indices, axis, opts...)
   147  	}
   148  	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
   149  }
   150  
   151  // ByIndicesB is the backpropagation of ByIndices.
   152  func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   153  	if axis >= a.Shape().Dims() {
   154  		return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
   155  	}
   156  	if sbi, ok := a.Engine().(ByIndiceser); ok {
   157  		return sbi.SelectByIndicesB(a, b, indices, axis, opts...)
   158  	}
   159  	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
   160  }
   161  
   162  // LogSoftMax applies log softmax to the given tensor.
   163  func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   164  	if sm, ok := x.Engine().(SoftMaxer); ok {
   165  		return sm.LogSoftMax(x, axis, opts...)
   166  	}
   167  
   168  	return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", x.Engine())
   169  }
   170  
   171  // SoftMax applies softmax to the given tensor.
   172  func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   173  	if sm, ok := x.Engine().(SoftMaxer); ok {
   174  		return sm.SoftMax(x, axis, opts...)
   175  	}
   176  
   177  	return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", x.Engine())
   178  }
   179  
   180  // SoftMaxB applies softmax backwards operation
   181  func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   182  	if sm, ok := output.Engine().(SoftMaxer); ok {
   183  		return sm.SoftMaxB(output, grad, axis, opts...)
   184  	}
   185  
   186  	return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine())
   187  }
   188  
   189  // LogSoftMaxB applies softmax backwards operation
   190  func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   191  	if sm, ok := output.Engine().(SoftMaxer); ok {
   192  		return sm.LogSoftMaxB(output, grad, axis, opts...)
   193  	}
   194  
   195  	return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine())
   196  }