github.com/wzzhu/tensor@v0.9.24/dense_matop.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 ) 6 7 // T performs a thunked transpose. It doesn't actually do anything, except store extra information about the post-transposed shapes and strides 8 // Usually this is more than enough, as BLAS will handle the rest of the transpose 9 func (t *Dense) T(axes ...int) (err error) { 10 var transform AP 11 if transform, axes, err = t.AP.T(axes...); err != nil { 12 return handleNoOp(err) 13 } 14 15 // is there any old transposes that need to be done first? 16 // this is important, because any old transposes for dim >=3 are merely permutations of the strides 17 if !t.old.IsZero() { 18 if t.IsVector() { 19 // the transform that was calculated was a waste of time - return it to the pool then untranspose 20 t.UT() 21 return 22 } 23 24 // check if the current axes are just a reverse of the previous transpose's 25 isReversed := true 26 for i, s := range t.oshape() { 27 if transform.Shape()[i] != s { 28 isReversed = false 29 break 30 } 31 } 32 33 // if it is reversed, well, we just restore the backed up one 34 if isReversed { 35 t.UT() 36 return 37 } 38 39 // cool beans. No funny reversals. We'd have to actually do transpose then 40 t.Transpose() 41 } 42 43 // swap out the old and the new 44 t.old = t.AP 45 t.transposeWith = axes 46 t.AP = transform 47 return nil 48 } 49 50 // UT is a quick way to untranspose a currently transposed *Dense 51 // The reason for having this is quite simply illustrated by this problem: 52 // T = NewTensor(WithShape(2,3,4)) 53 // T.T(1,2,0) 54 // 55 // To untranspose that, we'd need to apply a transpose of (2,0,1). 56 // This means having to keep track and calculate the transposes. 57 // Instead, here's a helpful convenience function to instantly untranspose any previous transposes. 58 // 59 // Nothing will happen if there was no previous transpose 60 func (t *Dense) UT() { 61 if !t.old.IsZero() { 62 ReturnInts(t.transposeWith) 63 t.AP = t.old 64 t.old.zeroOnly() 65 t.transposeWith = nil 66 } 67 } 68 69 // SafeT is exactly like T(), except it returns a new *Dense. The data is also copied over, unmoved. 70 func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) { 71 var transform AP 72 if transform, axes, err = t.AP.T(axes...); err != nil { 73 if err = handleNoOp(err); err != nil { 74 return 75 } 76 } 77 78 retVal = recycledDense(t.t, Shape{t.len()}, WithEngine(t.e)) 79 copyDense(retVal, t) 80 81 retVal.e = t.e 82 retVal.oe = t.oe 83 retVal.AP = transform 84 t.AP.CloneTo(&retVal.old) 85 retVal.transposeWith = axes 86 87 return 88 } 89 90 // At returns the value at the given coordinate 91 func (t *Dense) At(coords ...int) (interface{}, error) { 92 if !t.IsNativelyAccessible() { 93 return nil, errors.Errorf(inaccessibleData, t) 94 } 95 if len(coords) != t.Dims() { 96 return nil, errors.Errorf(dimMismatch, t.Dims(), len(coords)) 97 } 98 99 at, err := t.at(coords...) 100 if err != nil { 101 return nil, errors.Wrap(err, "At()") 102 } 103 104 return t.Get(at), nil 105 } 106 107 // MaskAt returns the value of the mask at a given coordinate 108 // returns false (valid) if not tensor is not masked 109 func (t *Dense) MaskAt(coords ...int) (bool, error) { 110 if !t.IsMasked() { 111 return false, nil 112 } 113 if !t.IsNativelyAccessible() { 114 return false, errors.Errorf(inaccessibleData, t) 115 } 116 if len(coords) != t.Dims() { 117 return true, errors.Errorf(dimMismatch, t.Dims(), len(coords)) 118 } 119 120 at, err := t.maskAt(coords...) 121 if err != nil { 122 return true, errors.Wrap(err, "MaskAt()") 123 } 124 125 return t.mask[at], nil 126 } 127 128 // SetAt sets the value at the given coordinate 129 func (t *Dense) SetAt(v interface{}, coords ...int) error { 130 if !t.IsNativelyAccessible() { 131 return errors.Errorf(inaccessibleData, t) 132 } 133 134 if len(coords) != t.Dims() { 135 return errors.Errorf(dimMismatch, t.Dims(), len(coords)) 136 } 137 138 at, err := t.at(coords...) 139 if err != nil { 140 return errors.Wrap(err, "SetAt()") 141 } 142 t.Set(at, v) 143 return nil 144 } 145 146 // SetMaskAtDataIndex set the value of the mask at a given index 147 func (t *Dense) SetMaskAtIndex(v bool, i int) error { 148 if !t.IsMasked() { 149 return nil 150 } 151 t.mask[i] = v 152 return nil 153 } 154 155 // SetMaskAt sets the mask value at the given coordinate 156 func (t *Dense) SetMaskAt(v bool, coords ...int) error { 157 if !t.IsMasked() { 158 return nil 159 } 160 if !t.IsNativelyAccessible() { 161 return errors.Errorf(inaccessibleData, t) 162 } 163 if len(coords) != t.Dims() { 164 return errors.Errorf(dimMismatch, t.Dims(), len(coords)) 165 } 166 167 at, err := t.maskAt(coords...) 168 if err != nil { 169 return errors.Wrap(err, "SetAt()") 170 } 171 t.mask[at] = v 172 return nil 173 } 174 175 // CopyTo copies the underlying data to the destination *Dense. The original data is untouched. 176 // Note: CopyTo doesn't care about the metadata of the destination *Dense. Take for example: 177 // T = NewTensor(WithShape(6)) 178 // T2 = NewTensor(WithShape(2,3)) 179 // err = T.CopyTo(T2) // err == nil 180 // 181 // The only time that this will fail is if the underlying sizes are different 182 func (t *Dense) CopyTo(other *Dense) error { 183 if other == t { 184 return nil // nothing to copy to. Maybe return NoOpErr? 185 } 186 187 if other.Size() != t.Size() { 188 return errors.Errorf(sizeMismatch, t.Size(), other.Size()) 189 } 190 191 // easy peasy lemon squeezy 192 if t.viewOf == 0 && other.viewOf == 0 { 193 copyDense(other, t) 194 return nil 195 } 196 197 // TODO: use copyDenseIter 198 return errors.Errorf(methodNYI, "CopyTo", "views") 199 } 200 201 // Narrow narrows the tensor. 202 func (t *Dense) Narrow(dim, start, length int) (View, error) { 203 dim = resolveAxis(dim, t.Dims()) 204 205 slices := make([]Slice, MinInt(dim+1, t.Dims())) 206 slices[dim] = S(start, start+length, 1) 207 208 return t.Slice(slices...) 209 } 210 211 // Slice performs slicing on the *Dense Tensor. It returns a view which shares the same underlying memory as the original *Dense. 212 // 213 // Given: 214 // T = NewTensor(WithShape(2,2), WithBacking(RangeFloat64(0,4))) 215 // V, _ := T.Slice(nil, singleSlice(1)) // T[:, 1] 216 // 217 // Any modification to the values in V, will be reflected in T as well. 218 // 219 // The method treats <nil> as equivalent to a colon slice. T.Slice(nil) is equivalent to T[:] in Numpy syntax 220 func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { 221 var newAP AP 222 var ndStart, ndEnd int 223 224 if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { 225 return 226 } 227 228 view := borrowDense() 229 view.t = t.t 230 view.e = t.e 231 view.oe = t.oe 232 view.flag = t.flag 233 view.AP = newAP 234 view.setParentTensor(t) 235 t.sliceInto(ndStart, ndEnd, &view.array) 236 237 if t.IsMasked() { 238 view.mask = t.mask[ndStart:ndEnd] 239 } 240 241 return view, err 242 } 243 244 // SliceInto is a convenience method. It does NOT copy the values - it simply updates the AP of the view. 245 // The underlying data is the same. 246 // This method will override ALL the metadata in view. 247 func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) { 248 var newAP AP 249 var ndStart, ndEnd int 250 251 if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { 252 return 253 } 254 255 view.AP.zero() 256 257 view.t = t.t 258 view.e = t.e 259 view.oe = t.oe 260 view.flag = t.flag 261 view.AP = newAP 262 view.setParentTensor(t) 263 t.sliceInto(ndStart, ndEnd, &view.array) 264 265 if t.IsMasked() { 266 view.mask = t.mask[ndStart:ndEnd] 267 } 268 269 return view, err 270 271 } 272 273 // RollAxis rolls the axis backwards until it lies in the given position. 274 // 275 // This method was adapted from Numpy's Rollaxis. The licence for Numpy is a BSD-like licence and can be found here: https://github.com/numpy/numpy/blob/master/LICENSE.txt 276 // 277 // As a result of being adapted from Numpy, the quirks are also adapted. A good guide reducing the confusion around rollaxis can be found here: http://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing (see answer by hpaulj) 278 func (t *Dense) RollAxis(axis, start int, safe bool) (retVal *Dense, err error) { 279 dims := t.Dims() 280 281 if !(axis >= 0 && axis < dims) { 282 err = errors.Errorf(invalidAxis, axis, dims) 283 return 284 } 285 286 if !(start >= 0 && start <= dims) { 287 err = errors.Wrap(errors.Errorf(invalidAxis, axis, dims), "Start axis is wrong") 288 return 289 } 290 291 if axis < start { 292 start-- 293 } 294 295 if axis == start { 296 retVal = t 297 return 298 } 299 300 axes := BorrowInts(dims) 301 defer ReturnInts(axes) 302 303 for i := 0; i < dims; i++ { 304 axes[i] = i 305 } 306 copy(axes[axis:], axes[axis+1:]) 307 copy(axes[start+1:], axes[start:]) 308 axes[start] = axis 309 310 if safe { 311 return t.SafeT(axes...) 312 } 313 err = t.T(axes...) 314 retVal = t 315 return 316 } 317 318 /* Private Methods */ 319 320 // returns the new index given the old index 321 func (t *Dense) transposeIndex(i int, transposePat, strides []int) int { 322 oldCoord, err := Itol(i, t.oshape(), t.ostrides()) 323 if err != nil { 324 err = errors.Wrapf(err, "transposeIndex ItoL failure. i %d original shape %v. original strides %v", i, t.oshape(), t.ostrides()) 325 panic(err) 326 } 327 328 /* 329 coordss, _ := Permute(transposePat, oldCoord) 330 coords := coordss[0] 331 expShape := t.Shape() 332 index, _ := Ltoi(expShape, strides, coords...) 333 */ 334 335 // The above is the "conceptual" algorithm. 336 // Too many checks above slows things down, so the below is the "optimized" edition 337 var index int 338 for i, axis := range transposePat { 339 index += oldCoord[axis] * strides[i] 340 } 341 return index 342 } 343 344 // at returns the index at which the coordinate is referring to. 345 // This function encapsulates the addressing of elements in a contiguous block. 346 // For a 2D ndarray, ndarray.at(i,j) is 347 // at = ndarray.strides[0]*i + ndarray.strides[1]*j 348 // This is of course, extensible to any number of dimensions. 349 func (t *Dense) at(coords ...int) (at int, err error) { 350 return Ltoi(t.Shape(), t.Strides(), coords...) 351 } 352 353 // maskat returns the mask index at which the coordinate is referring to. 354 func (t *Dense) maskAt(coords ...int) (at int, err error) { 355 //TODO: Add check for non-masked tensor 356 return t.at(coords...) 357 }