github.com/wzzhu/tensor@v0.9.24/api_matop.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 ) 6 7 // this file handles matops. While by default most of these matops should already have been defined as part of the 8 // Tensor interface, not all are possible(for example, concatenating a sparse tensor), hence the need for the following functions 9 10 // Narrow narrows the tensor. 11 func Narrow(t Tensor, dim, start, length int) (View, error) { 12 dim = resolveAxis(dim, t.Dims()) 13 14 slices := make([]Slice, MinInt(dim+1, t.Dims())) 15 slices[dim] = S(start, start+length, 1) 16 17 return t.Slice(slices...) 18 } 19 20 // Repeat repeats a Tensor along the axis and given the number of repeats. 21 func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) { 22 if r, ok := t.Engine().(Repeater); ok { 23 return r.Repeat(t, axis, repeats...) 24 } 25 return nil, errors.New("Engine does not support Repeat") 26 } 27 28 // RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then an error will be given, but the results will still be valid. 29 func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) { 30 if r, ok := t.Engine().(Repeater); ok { 31 return r.RepeatReuse(t, reuse, axis, repeats...) 32 } 33 return nil, errors.New("Engine does not support Repeat") 34 } 35 36 // T safely transposes a Tensor. It returns a tensor that is not a view of the input tensor - rather, the data is all copied. 37 func T(t Tensor, axes ...int) (retVal Tensor, err error) { 38 switch tt := t.(type) { 39 case *Dense: 40 return tt.SafeT(axes...) 41 } 42 panic("Unreachable") 43 } 44 45 // Transpose performs transposition of a tensor according to its axes. 46 func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) { 47 switch tt := t.(type) { 48 case *Dense: 49 var ret *Dense 50 if ret, err = tt.SafeT(axes...); err != nil { 51 return 52 } 53 ret.Transpose() 54 retVal = ret 55 return 56 } 57 panic("Unreachable") 58 } 59 60 // Concat concatenates a list of Tensors. At the moment the operation only supports Tensors of the same type 61 // (*Dense can only be concatenated with a bunch of *Dense, CSCs can only be concatenated with a bunch of CSC, etc) 62 func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) { 63 if len(others) == 0 { 64 return t, nil 65 } 66 switch T := t.(type) { 67 case *Dense: 68 ts := make([]*Dense, len(others)) 69 for i, o := range others { 70 if ot, ok := o.(*Dense); ok { 71 ts[i] = ot 72 continue 73 } 74 return nil, errors.Errorf("Expected all Tensors to be *Dense") 75 } 76 return T.Concat(axis, ts...) 77 } 78 panic("Unreachable") 79 } 80 81 // Copy copies a tensor to another. For *Dense views, only the relevant slots are copied. 82 func Copy(dst, src Tensor) error { 83 switch st := src.(type) { 84 case DenseTensor: 85 dt, ok := dst.(DenseTensor) 86 if !ok { 87 return errors.Errorf("Cannot copy from DenseTensor to %T", dst) 88 } 89 90 if st.RequiresIterator() || dt.RequiresIterator() { 91 siter := st.Iterator() 92 diter := dt.Iterator() 93 _, err := copyDenseIter(dt, st, diter, siter) 94 return err 95 } 96 copyDense(dt, st) 97 return nil 98 default: 99 return errors.Errorf("NYI for Copy %T", src) 100 } 101 panic("Unreachable") 102 } 103 104 // Stack stacks a list of other Tensors. At the moment the operation only supports Tensors of the same type. 105 // (*Dense can only be stacked with *Dense... etc) 106 func Stack(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) { 107 if len(others) == 0 { 108 return t, nil 109 } 110 111 switch T := t.(type) { 112 case DenseTensor: 113 var dts []DenseTensor 114 if dts, err = tensorsToDenseTensors(others); err != nil { 115 return nil, errors.Wrap(err, "Cannot convert others into a slice of DenseTensors") 116 } 117 return T.stackDense(axis, dts...) 118 } 119 panic("Unreachable") 120 } 121 122 // Materialize takes a View and copies out the data into a new allocation. 123 func Materialize(t Tensor) Tensor { 124 switch tt := t.(type) { 125 case View: 126 return tt.Materialize() 127 default: 128 return t 129 } 130 } 131 132 func Diag(t Tensor) (retVal Tensor, err error) { 133 if d, ok := t.Engine().(Diager); ok { 134 return d.Diag(t) 135 } 136 return nil, errors.Errorf("Unable to perform diagonalization of tensor ") 137 } 138 139 // ByIndices allows for selection of value of `a` byt the indices listed in the `indices` tensor. 140 // The `indices` tensor has to be a vector-like tensor of ints. 141 func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { 142 if axis >= a.Shape().Dims() { 143 return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) 144 } 145 if sbi, ok := a.Engine().(ByIndiceser); ok { 146 return sbi.SelectByIndices(a, indices, axis, opts...) 147 } 148 return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine()) 149 } 150 151 // ByIndicesB is the backpropagation of ByIndices. 152 func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { 153 if axis >= a.Shape().Dims() { 154 return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) 155 } 156 if sbi, ok := a.Engine().(ByIndiceser); ok { 157 return sbi.SelectByIndicesB(a, b, indices, axis, opts...) 158 } 159 return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine()) 160 } 161 162 // LogSoftMax applies log softmax to the given tensor. 163 func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { 164 if sm, ok := x.Engine().(SoftMaxer); ok { 165 return sm.LogSoftMax(x, axis, opts...) 166 } 167 168 return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", x.Engine()) 169 } 170 171 // SoftMax applies softmax to the given tensor. 172 func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { 173 if sm, ok := x.Engine().(SoftMaxer); ok { 174 return sm.SoftMax(x, axis, opts...) 175 } 176 177 return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", x.Engine()) 178 } 179 180 // SoftMaxB applies softmax backwards operation 181 func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { 182 if sm, ok := output.Engine().(SoftMaxer); ok { 183 return sm.SoftMaxB(output, grad, axis, opts...) 184 } 185 186 return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine()) 187 } 188 189 // LogSoftMaxB applies softmax backwards operation 190 func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { 191 if sm, ok := output.Engine().(SoftMaxer); ok { 192 return sm.LogSoftMaxB(output, grad, axis, opts...) 193 } 194 195 return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine()) 196 }