gorgonia.org/tensor@v0.9.24/ap.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 6 "github.com/pkg/errors" 7 ) 8 9 // An AP is an access pattern. It tells the various ndarrays how to access their data through the use of strides 10 // Through the AP, there are several definitions of things, most notably there are two very specific "special cases": 11 // Scalar has Dims() of 0. 12 // - (1) 13 // Scalarlikes are higher order tensors, but each with a size of 1. The Dims() are not 0. 14 // - (1, 1) 15 // - (1, 1, 1) 16 // - (1, 1, 1, 1), etc 17 // Vector has Dims() of 1, but its shape can take several forms: 18 // - (x, 1) 19 // - (1, x) 20 // - (x) 21 // Matrix has Dims() of 2. This is the most basic form. The len(shape) has to be equal to 2 as well 22 // ndarray has Dims() of n. 23 type AP struct { 24 shape Shape // len(shape) is the operational definition of the dimensions 25 strides []int // strides is usually calculated from shape 26 fin bool // is this struct change-proof? 27 28 o DataOrder 29 Δ Triangle 30 } 31 32 func makeAP(size int) AP { 33 return AP{ 34 shape: Shape(BorrowInts(size)), 35 strides: BorrowInts(size), 36 } 37 } 38 39 // MakeAP creates an AP, given the shape and strides. 40 func MakeAP(shape Shape, strides []int, o DataOrder, Δ Triangle) AP { 41 return AP{ 42 shape: shape, 43 strides: strides, 44 o: o, 45 Δ: Δ, 46 fin: true, 47 } 48 } 49 50 // Init initializes an already created AP with a shape and stries. 51 // It will panic if AP is nil. 52 func (ap *AP) Init(shape Shape, strides []int) { 53 ap.shape = shape 54 ap.strides = strides 55 ap.fin = true 56 } 57 58 // SetShape is for very specific times when modifying the AP is necessary, such as reshaping and doing I/O related stuff 59 // 60 // Caveats: 61 // 62 // - SetShape will recalculate the strides. 63 // 64 // - If the AP is locked, nothing will happen 65 func (ap *AP) SetShape(s ...int) { 66 if !ap.fin { 67 // scalars are a special case, we don't want to remove it completely 68 if len(s) == 0 { 69 if ap.shape == nil || ap.strides == nil { 70 ap.shape = Shape{} 71 } 72 ap.shape = ap.shape[:0] 73 ap.strides = ap.strides[:0] 74 return 75 } 76 77 if ap.shape != nil { 78 ReturnInts(ap.shape) 79 ap.shape = nil 80 } 81 if ap.strides != nil { 82 ReturnInts(ap.strides) 83 ap.strides = nil 84 } 85 ap.shape = Shape(s).Clone() 86 ap.strides = ap.calcStrides() 87 } 88 } 89 90 // Shape returns the shape of the AP 91 func (ap *AP) Shape() Shape { return ap.shape } 92 93 // Strides returns the strides of the AP 94 func (ap *AP) Strides() []int { return ap.strides } 95 96 // Dims returns the dimensions of the shape in the AP 97 func (ap *AP) Dims() int { return ap.shape.Dims() } 98 99 // Size returns the expected array size of the shape 100 func (ap *AP) Size() int { return ap.shape.TotalSize() } 101 102 // String implements fmt.Stringer and runtime.Stringer 103 func (ap *AP) String() string { return fmt.Sprintf("%v", ap) } 104 105 // Format implements fmt.Formatter 106 func (ap *AP) Format(state fmt.State, c rune) { 107 fmt.Fprintf(state, "Shape: %v, Stride: %v, Lock: %t", ap.shape, ap.strides, ap.fin) 108 } 109 110 // IsVector returns whether the access pattern falls into one of three possible definitions of vectors: 111 // vanilla vector (not a row or a col) 112 // column vector 113 // row vector 114 func (ap *AP) IsVector() bool { return ap.shape.IsVector() } 115 116 // IsVectorLike returns true if the shape is vector-like (i.e. the shape only has one dim that is a non-1). 117 func (ap *AP) IsVectorLike() bool { 118 return ap.shape.IsVectorLike() && allones(ap.strides) 119 } 120 121 // IsColVec returns true when the access pattern has the shape (x, 1) 122 func (ap *AP) IsColVec() bool { return ap.shape.IsColVec() } 123 124 // IsRowVec returns true when the access pattern has the shape (1, x) 125 func (ap *AP) IsRowVec() bool { return ap.shape.IsRowVec() } 126 127 // IsScalar returns true if the access pattern indicates it's a scalar value. 128 func (ap *AP) IsScalar() bool { return ap.shape.IsScalar() } 129 130 // IsScalarEquiv returns true if the access pattern is equivalent to a scalar shape. 131 func (ap *AP) IsScalarEquiv() bool { return ap.shape.IsScalarEquiv() } 132 133 // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices 134 func (ap *AP) IsMatrix() bool { return len(ap.shape) == 2 } 135 136 // IsZero tell us if the ap has zero size 137 func (ap *AP) IsZero() bool { 138 return len(ap.shape) == 0 && len(ap.strides) == 0 && !ap.fin && ap.o == 0 && ap.Δ == 0 139 } 140 141 // Zero zeros out an AP. 142 func (ap *AP) zero() { 143 // log.Printf("ZEROING. Called by %v", string(debug.Stack())) 144 145 // Jorge's original implementation for zeroing a AP is as below 146 // but to cater for the (*Dense).fix() method of the *Dense 147 // a nil shape is used to signal unsetness 148 // so we cannot just truncate the shape even though it would be a lot more efficient 149 150 // ap.shape = ap.shape[:0] 151 // ap.strides = ap.strides[:0] 152 ReturnInts([]int(ap.shape)) 153 ReturnInts(ap.strides) 154 ap.zeroOnly() 155 } 156 157 // side effect free zeroing 158 func (ap *AP) zeroOnly() { 159 ap.shape = nil 160 ap.strides = nil 161 162 ap.fin = false 163 ap.o = 0 164 ap.Δ = 0 165 } 166 167 func (ap *AP) zeroWithDims(dims int) { 168 //ap.shape = BorrowInts(dims) 169 //ap.strides = BorrowInts(dims) 170 if cap(ap.shape) >= dims { 171 ap.shape = ap.shape[:dims] 172 } 173 ap.shape = BorrowInts(dims) 174 if cap(ap.strides) >= dims { 175 ap.strides = ap.strides[:dims] 176 } 177 ap.strides = BorrowInts(dims) 178 } 179 180 // Clone clones the *AP. Clearly. It returns AP 181 func (ap *AP) Clone() (retVal AP) { 182 retVal = makeAP(cap(ap.shape)) 183 184 copy(retVal.shape, ap.shape) 185 copy(retVal.strides, ap.strides) 186 187 // handle vectors 188 retVal.shape = retVal.shape[:len(ap.shape)] 189 retVal.strides = retVal.strides[:len(ap.strides)] 190 191 retVal.fin = ap.fin 192 retVal.o = ap.o 193 retVal.Δ = ap.Δ 194 return 195 } 196 197 func (ap *AP) CloneTo(dest *AP) { 198 dest.shape = append(dest.shape[:0], ap.shape...) 199 dest.strides = append(dest.strides[:0], ap.strides...) 200 dest.fin = ap.fin 201 dest.o = ap.o 202 dest.Δ = ap.Δ 203 } 204 205 // DataOrder returns the data order of the AP. 206 func (ap *AP) DataOrder() DataOrder { return ap.o } 207 208 // C returns true if the access pattern is C-contiguous array 209 func (ap *AP) C() bool { return ap.o.IsRowMajor() && ap.o.IsContiguous() } 210 211 // F returns true if the access pattern is Fortran contiguous array 212 func (ap *AP) F() bool { return ap.o.IsColMajor() && ap.o.IsContiguous() } 213 214 // S returns the metadata of the sliced tensor. 215 func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err error) { 216 if len(slices) > len(ap.shape) { 217 // error 218 err = errors.Errorf(dimMismatch, len(ap.shape), len(slices)) 219 return 220 } 221 222 ndEnd = size 223 newShape := ap.shape.Clone() // the new shape 224 dims := ap.Dims() // reported dimensions 225 newStrides := BorrowInts(dims) // the new strides 226 227 var outerDim int 228 order := ap.o 229 if ap.o.IsRowMajor() || ap.IsVector() { 230 outerDim = 0 231 } else { 232 outerDim = len(ap.shape) - 1 233 } 234 235 for i := 0; i < dims; i++ { 236 var sl Slice 237 if i <= len(slices)-1 { 238 sl = slices[i] 239 } 240 241 size := ap.shape[i] 242 var stride int 243 stride = ap.strides[i] 244 // if ap.IsVector() { 245 // // handles non-vanilla vectors 246 // stride = ap.strides[0] 247 // } else { 248 // stride = ap.strides[i] 249 // } 250 251 var start, end, step int 252 if start, end, step, err = SliceDetails(sl, size); err != nil { 253 err = errors.Wrapf(err, "Unable to get slice details on slice %d with size %d: %v", i, sl, size) 254 return 255 } 256 257 // a slice where start == end is [] 258 ndStart = ndStart + start*stride 259 ndEnd = ndEnd - (size-end)*stride 260 261 if step > 0 { 262 if newShape[i] = (end - start) / step; (end-start)%step > 0 && i > 0 { 263 newShape[i]++ 264 } 265 newStrides[i] = stride * step 266 267 //fix 268 if newShape[i] <= 0 { 269 newShape[i] = 1 270 } 271 } else { 272 newShape[i] = (end - start) 273 newStrides[i] = stride 274 } 275 276 if (sl != nil && (!ap.IsVector() && i != outerDim)) || step > 1 { 277 order = MakeDataOrder(order, NonContiguous) 278 } 279 } 280 281 if ndEnd-ndStart == 1 { 282 // scalars are a special case 283 newAP = AP{} 284 newAP.SetShape() // make it a Scalar 285 newAP.lock() 286 } else { 287 288 // drop any dimension with size 1, except the last dimension 289 offset := 0 290 for d := 0; d < dims; d++ { 291 if newShape[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { 292 newShape = append(newShape[:d], newShape[d+1:]...) 293 newStrides = append(newStrides[:d], newStrides[d+1:]...) 294 d-- 295 dims-- 296 offset++ 297 } 298 } 299 300 newAP = MakeAP(newShape, newStrides, order, ap.Δ) 301 } 302 return 303 } 304 305 // T returns the transposed metadata based on the given input 306 func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) { 307 308 // prep axes 309 if len(axes) > 0 && len(axes) != ap.Dims() { 310 err = errors.Errorf(dimMismatch, ap.Dims(), len(axes)) 311 return 312 } 313 314 dims := len(ap.shape) 315 if len(axes) == 0 || axes == nil { 316 axes = make([]int, dims) 317 for i := 0; i < dims; i++ { 318 axes[i] = dims - 1 - i 319 } 320 } 321 a = axes 322 323 if ap.shape.IsScalarEquiv() { 324 return ap.Clone(), a, noopError{} 325 } 326 327 // if axes is 0, 1, 2, 3... then no op 328 if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 { 329 return ap.Clone(), a, noopError{} 330 } 331 332 currentShape := ap.shape 333 currentStride := ap.strides 334 shape := make(Shape, len(currentShape)) 335 strides := make([]int, len(currentStride)) 336 337 switch { 338 case ap.IsScalar(): 339 return 340 case ap.IsVector(): 341 if axes[0] == 0 { 342 return 343 } 344 strides[0], strides[1] = 1, 1 345 shape[0], shape[1] = currentShape[1], currentShape[0] 346 default: 347 copy(shape, currentShape) 348 copy(strides, currentStride) 349 err = UnsafePermute(axes, shape, strides) 350 if err != nil { 351 err = handleNoOp(err) 352 } 353 } 354 355 o := MakeDataOrder(ap.o, Transposed) 356 retVal = MakeAP(shape, strides, o, ap.Δ) 357 retVal.fin = true 358 return 359 } 360 361 // locking and unlocking is used to ensure that the shape and stride doesn't change (it's not really safe though, as a direct mutation of the strides/shape would still mutate it, but at least the dimensions cannot change) 362 func (ap *AP) lock() { ap.fin = true } 363 func (ap *AP) unlock() { ap.fin = false } 364 365 func (ap *AP) calcStrides() []int { 366 switch { 367 case ap.o.IsRowMajor(): 368 return ap.shape.CalcStrides() 369 case ap.o.IsColMajor(): 370 return ap.shape.CalcStridesColMajor() 371 } 372 panic("unreachable") 373 } 374 375 // setDataOrder is a method such that any tensor that embeds *AP will have the same method 376 func (ap *AP) setDataOrder(o DataOrder) { 377 if !o.HasSameOrder(ap.o) { 378 ap.o = ap.o.toggleColMajor() 379 } 380 } 381 382 // TransposeIndex returns the new index given the old index 383 func TransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int { 384 oldCoord, err := Itol(i, oldShape, oldStrides) 385 if err != nil { 386 panic(err) // or return error? 387 } 388 /* 389 coordss, _ := Permute(pattern, oldCoord) 390 coords := coordss[0] 391 index, _ := Ltoi(newShape, strides, coords...) 392 */ 393 394 // The above is the "conceptual" algorithm. 395 // Too many checks above slows things down, so the below is the "optimized" edition 396 var index int 397 for i, axis := range pattern { 398 index += oldCoord[axis] * newStrides[i] 399 } 400 return index 401 } 402 403 // UntransposeIndex returns the old index given the new index 404 func UntransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int { 405 newPattern := make([]int, len(pattern)) 406 for i, p := range pattern { 407 newPattern[p] = i 408 } 409 return TransposeIndex(i, oldShape, newPattern, oldStrides, newStrides) 410 } 411 412 // BroadcastStrides handles broadcasting from different shapes. 413 // 414 // Deprecated: this function will be unexported 415 func BroadcastStrides(destShape, srcShape Shape, destStrides, srcStrides []int) (retVal []int, err error) { 416 dims := len(destShape) 417 start := dims - len(srcShape) 418 419 if destShape.IsVector() && srcShape.IsVector() { 420 return []int{srcStrides[0]}, nil 421 } 422 423 if start < 0 { 424 //error 425 err = errors.Errorf(dimMismatch, dims, len(srcShape)) 426 return 427 } 428 429 retVal = BorrowInts(len(destStrides)) 430 for i := dims - 1; i >= start; i-- { 431 s := srcShape[i-start] 432 switch { 433 case s == 1: 434 retVal[i] = 0 435 case s != destShape[i]: 436 // error 437 err = errors.Errorf("Cannot broadcast from %v to %v", srcShape, destShape) 438 return 439 default: 440 retVal[i] = srcStrides[i-start] 441 } 442 } 443 for i := 0; i < start; i++ { 444 retVal[i] = 0 445 } 446 return 447 }