github.com/wzzhu/tensor@v0.9.24/iterator.go (about) 1 package tensor 2 3 import ( 4 "runtime" 5 ) 6 7 func requiresOrderedIterator(e Engine, t Tensor) bool { 8 if t.IsScalar() { 9 return false 10 } 11 if t.RequiresIterator() { 12 return true 13 } 14 switch tt := t.(type) { 15 case DenseTensor: 16 return !e.WorksWith(tt.DataOrder()) 17 case SparseTensor: 18 return true 19 } 20 panic("Unreachable") 21 } 22 23 // Iterator is the generic iterator interface. 24 // It's used to iterate across multi-dimensional slices, no matter the underlying data arrangement 25 type Iterator interface { 26 // Start returns the first index 27 Start() (int, error) 28 29 // Next returns the next index. Next is defined as the next value in the coordinates 30 // For example: let x be a (5,5) matrix that is row-major. Current index is for the coordinate (3,3). 31 // Next() returns the index of (3,4). 32 // 33 // If there is no underlying data store for (3,4) - say for example, the matrix is a sparse matrix, it return an error. 34 // If however, there is an underlying data store for (3,4), but it's not valid (for example, masked tensors), it will not return an error. 35 // 36 // Second example: let x be a (5,5) matrix that is col-major. Current index is for coordinate (3,3). 37 // Next() returns the index of (4,3). 38 Next() (int, error) 39 40 // NextValidity is like Next, but returns the validity of the value at the index as well. 41 NextValidity() (int, bool, error) 42 43 // NextValid returns the next valid index, as well as a skip count. 44 NextValid() (int, int, error) 45 46 // NextInvalid returns the next invalid index, as well as a skip count. 47 NextInvalid() (int, int, error) 48 49 // Reset resets the iterator 50 Reset() 51 52 // SetReverse tells the iterator to iterate in reverse 53 SetReverse() 54 55 // SetForward tells the iterator to iterate forwards 56 SetForward() 57 58 // Coord returns the coordinates 59 Coord() []int 60 61 // Done returns true when the iterator is done iterating. 62 Done() bool 63 64 // Shape returns the shape of the multidimensional tensor it's iterating on. 65 Shape() Shape 66 } 67 68 // NewIterator creates a new Iterator from an ap. The type of iterator depends on number of 69 // aps passed, and whether they are masked or not 70 func NewIterator(aps ...*AP) Iterator { 71 switch len(aps) { 72 case 0: 73 return nil 74 case 1: 75 return newFlatIterator(aps[0]) 76 default: 77 return NewMultIterator(aps...) 78 } 79 } 80 81 // IteratorFromDense creates a new Iterator from a list of dense tensors 82 func IteratorFromDense(tts ...DenseTensor) Iterator { 83 switch len(tts) { 84 case 0: 85 return nil 86 case 1: 87 if mt, ok := tts[0].(MaskedTensor); ok && mt.IsMasked() { 88 return FlatMaskedIteratorFromDense(mt) 89 } 90 return FlatIteratorFromDense(tts[0]) 91 default: 92 return MultIteratorFromDense(tts...) 93 } 94 } 95 96 func destroyIterator(it Iterator) { 97 switch itt := it.(type) { 98 case *MultIterator: 99 destroyMultIterator(itt) 100 } 101 } 102 103 func iteratorLoadAP(it Iterator, ap *AP) { 104 switch itt := it.(type) { 105 case *FlatIterator: 106 itt.AP = ap 107 case *FlatMaskedIterator: 108 itt.AP = ap 109 case *MultIterator: // Do nothing, TODO: perhaps add something here 110 111 } 112 } 113 114 /* FLAT ITERATOR */ 115 116 // FlatIterator is an iterator that iterates over Tensors according to the data's layout. 117 // It utilizes the *AP of a Tensor to determine what the next index is. 118 // This data structure is similar to Numpy's flatiter, with some standard Go based restrictions of course 119 // (such as, not allowing negative indices) 120 type FlatIterator struct { 121 *AP 122 123 //state 124 track []int 125 nextIndex int 126 lastIndex int 127 size int 128 done bool 129 veclikeDim int // the dimension of a vectorlike shape that is not a 1. 130 reverse bool // if true, iterator starts at end of array and runs backwards 131 132 isScalar bool 133 isVector bool 134 135 outerFirst bool 136 } 137 138 // newFlatIterator creates a new FlatIterator. 139 func newFlatIterator(ap *AP) *FlatIterator { 140 var dim int 141 if ap.IsVectorLike() { 142 for d, i := range ap.shape { 143 if i != 1 { 144 dim = d 145 break 146 } 147 } 148 } 149 150 return &FlatIterator{ 151 AP: ap, 152 track: make([]int, len(ap.shape)), 153 size: ap.shape.TotalSize(), 154 veclikeDim: dim, 155 156 isScalar: ap.IsScalar(), 157 isVector: ap.IsVectorLike(), 158 } 159 } 160 161 // FlatIteratorFromDense creates a new FlatIterator from a dense tensor 162 func FlatIteratorFromDense(tt DenseTensor) *FlatIterator { 163 return newFlatIterator(tt.Info()) 164 } 165 166 // SetReverse initializes iterator to run backwards 167 func (it *FlatIterator) SetReverse() { 168 it.reverse = true 169 it.Reset() 170 return 171 } 172 173 // SetForward initializes iterator to run forwards 174 func (it *FlatIterator) SetForward() { 175 it.reverse = false 176 it.Reset() 177 return 178 } 179 180 //Start begins iteration 181 func (it *FlatIterator) Start() (int, error) { 182 it.Reset() 183 return it.Next() 184 } 185 186 //Done checks whether iterators are done 187 func (it *FlatIterator) Done() bool { 188 return it.done 189 } 190 191 // Next returns the index of the current coordinate. 192 func (it *FlatIterator) Next() (int, error) { 193 if it.done { 194 return -1, noopError{} 195 } 196 197 switch { 198 case it.isScalar: 199 it.done = true 200 return 0, nil 201 case it.isVector: 202 if it.reverse { 203 return it.singlePrevious() 204 } 205 return it.singleNext() 206 default: 207 if it.reverse { 208 return it.ndPrevious() 209 } 210 if it.outerFirst { 211 return it.colMajorNDNext() 212 } 213 return it.ndNext() 214 } 215 } 216 217 // NextValidity returns the index of the current coordinate, and whether or not it's valid. Identical to Next() 218 func (it *FlatIterator) NextValidity() (int, bool, error) { 219 i, err := it.Next() 220 return i, true, err 221 } 222 223 // NextValid returns the index of the current coordinate. Identical to Next for FlatIterator 224 // Also returns the number of increments to get to next element ( 1, or -1 in reverse case). This is to maintain 225 // consistency with the masked iterator, for which the step between valid elements can be more than 1 226 func (it *FlatIterator) NextValid() (int, int, error) { 227 if it.done { 228 return -1, 1, noopError{} 229 } 230 switch { 231 case it.isScalar: 232 it.done = true 233 return 0, 0, nil 234 case it.isVector: 235 if it.reverse { 236 a, err := it.singlePrevious() 237 return a, -1, err 238 } 239 a, err := it.singleNext() 240 return a, 1, err 241 default: 242 if it.reverse { 243 a, err := it.ndPrevious() 244 return a, -1, err 245 } 246 247 if it.outerFirst { 248 a, err := it.colMajorNDNext() 249 return a, 1, err 250 } 251 a, err := it.ndNext() 252 return a, 1, err 253 } 254 } 255 256 // NextInvalid returns the index of the current coordinate. Identical to Next for FlatIterator 257 // also returns the number of increments to get to next invalid element (1 or -1 in reverse case). 258 // Like NextValid, this method's purpose is to maintain consistency with the masked iterator, 259 // for which the step between invalid elements can be anywhere from 0 to the tensor's length 260 func (it *FlatIterator) NextInvalid() (int, int, error) { 261 if it.reverse { 262 return -1, -it.lastIndex, noopError{} 263 } 264 return -1, it.Size() - it.lastIndex, noopError{} 265 } 266 267 func (it *FlatIterator) singleNext() (int, error) { 268 it.lastIndex = it.nextIndex 269 it.nextIndex++ 270 271 var tracked int 272 it.track[it.veclikeDim]++ 273 tracked = it.track[it.veclikeDim] 274 275 if tracked >= it.size { 276 it.done = true 277 } 278 279 return it.lastIndex, nil 280 } 281 282 func (it *FlatIterator) singlePrevious() (int, error) { 283 it.lastIndex = it.nextIndex 284 it.nextIndex-- 285 286 var tracked int 287 it.track[it.veclikeDim]-- 288 tracked = it.track[it.veclikeDim] 289 290 if tracked < 0 { 291 it.done = true 292 } 293 return it.lastIndex, nil 294 } 295 296 func (it *FlatIterator) ndNext() (int, error) { 297 // the reason for this weird looking bits of code is because the SSA compiler doesn't 298 // know how to optimize for this bit of code, not keeping things in registers correctly 299 // @stuartcarnie optimized this iout to great effect 300 301 v := len(it.shape) - 1 302 nextIndex := it.nextIndex 303 it.lastIndex = nextIndex 304 305 // the following 3 lines causes the compiler to perform bounds check here, 306 // instead of being done in the loop 307 coord := it.shape[:v+1] 308 track := it.track[:v+1] 309 strides := it.strides[:v+1] 310 for i := v; i >= 0; i-- { 311 track[i]++ 312 shapeI := coord[i] 313 strideI := strides[i] 314 315 if track[i] == shapeI { 316 if i == 0 { 317 it.done = true 318 } 319 track[i] = 0 320 nextIndex -= (shapeI - 1) * strideI 321 continue 322 } 323 nextIndex += strideI 324 break 325 } 326 it.nextIndex = nextIndex 327 return it.lastIndex, nil 328 } 329 330 func (it *FlatIterator) colMajorNDNext() (int, error) { 331 // the reason for this weird looking bits of code is because the SSA compiler doesn't 332 // know how to optimize for this bit of code, not keeping things in registers correctly 333 // @stuartcarnie optimized this iout to great effect 334 335 v := len(it.shape) - 1 336 nextIndex := it.nextIndex 337 it.lastIndex = nextIndex 338 339 // the following 3 lines causes the compiler to perform bounds check here, 340 // instead of being done in the loop 341 coord := it.shape[:v+1] 342 track := it.track[:v+1] 343 strides := it.strides[:v+1] 344 for i := 0; i <= v; i++ { 345 track[i]++ 346 shapeI := coord[i] 347 strideI := strides[i] 348 349 if track[i] == shapeI { 350 if i == v { 351 it.done = true 352 } 353 track[i] = 0 354 355 nextIndex -= (shapeI - 1) * strideI 356 continue 357 } 358 nextIndex += strideI 359 break 360 } 361 it.nextIndex = nextIndex 362 return it.lastIndex, nil 363 364 } 365 366 func (it *FlatIterator) ndPrevious() (int, error) { 367 it.lastIndex = it.nextIndex 368 for i := len(it.shape) - 1; i >= 0; i-- { 369 it.track[i]-- 370 if it.track[i] < 0 { 371 if i == 0 { 372 it.done = true 373 } 374 it.track[i] = it.shape[i] - 1 375 it.nextIndex += (it.shape[i] - 1) * it.strides[i] 376 continue 377 } 378 it.nextIndex -= it.strides[i] 379 break 380 } 381 return it.lastIndex, nil 382 } 383 384 // TODO v0.9.0 385 func (it *FlatIterator) colMajorNDPrevious() (int, error) { 386 return 0, nil 387 } 388 389 // Coord returns the next coordinate. 390 // When Next() is called, the coordinates are updated AFTER the Next() returned. 391 // See example for more details. 392 // 393 // The returned coordinates is mutable. Changing any values in the return value will 394 // change the state of the iterator 395 func (it *FlatIterator) Coord() []int { return it.track } 396 397 // Slice is a convenience function that augments 398 func (it *FlatIterator) Slice(sli Slice) (retVal []int, err error) { 399 var next int 400 var nexts []int 401 for next, err = it.Next(); err == nil; next, err = it.Next() { 402 nexts = append(nexts, next) 403 } 404 if _, ok := err.(NoOpError); err != nil && !ok { 405 return 406 } 407 408 if sli == nil { 409 retVal = nexts 410 return 411 } 412 413 start := sli.Start() 414 end := sli.End() 415 step := sli.Step() 416 417 // sanity checks 418 if err = CheckSlice(sli, len(nexts)); err != nil { 419 return 420 } 421 422 if step < 0 { 423 // reverse the nexts 424 for i := len(nexts)/2 - 1; i >= 0; i-- { 425 j := len(nexts) - 1 - i 426 nexts[i], nexts[j] = nexts[j], nexts[i] 427 } 428 step = -step 429 } 430 431 // cleanup before loop 432 if end > len(nexts) { 433 end = len(nexts) 434 } 435 // nexts = nexts[:end] 436 437 for i := start; i < end; i += step { 438 retVal = append(retVal, nexts[i]) 439 } 440 441 err = nil 442 return 443 } 444 445 // Reset resets the iterator state. 446 func (it *FlatIterator) Reset() { 447 it.done = false 448 if it.reverse { 449 for i := range it.track { 450 it.track[i] = it.shape[i] - 1 451 } 452 453 switch { 454 case it.IsScalar(): 455 it.nextIndex = 0 456 case it.isVector: 457 it.nextIndex = (it.shape[0] - 1) * it.strides[0] 458 // case it.IsRowVec(): 459 // it.nextIndex = (it.shape[1] - 1) * it.strides[1] 460 // case it.IsColVec(): 461 // it.nextIndex = (it.shape[0] - 1) * it.strides[0] 462 default: 463 it.nextIndex = 0 464 for i := range it.track { 465 it.nextIndex += (it.shape[i] - 1) * it.strides[i] 466 } 467 } 468 } else { 469 it.nextIndex = 0 470 for i := range it.track { 471 it.track[i] = 0 472 } 473 } 474 } 475 476 // Chan returns a channel of ints. This is useful for iterating multiple Tensors at the same time. 477 func (it *FlatIterator) Chan() (retVal chan int) { 478 retVal = make(chan int) 479 480 go func() { 481 for next, err := it.Next(); err == nil; next, err = it.Next() { 482 retVal <- next 483 } 484 close(retVal) 485 }() 486 487 return 488 } 489 490 /* FLAT MASKED ITERATOR */ 491 492 // FlatMaskedIterator is an iterator that iterates over simple masked Tensors. 493 // It is used when the mask stride is identical to data stride with the exception of trailing zeros, 494 // in which case the data index is always a perfect integer multiple of the mask index 495 type FlatMaskedIterator struct { 496 *FlatIterator 497 mask []bool 498 } 499 500 // FlatMaskedIteratorFromDense creates a new FlatMaskedIterator from dense tensor 501 func FlatMaskedIteratorFromDense(tt MaskedTensor) *FlatMaskedIterator { 502 it := new(FlatMaskedIterator) 503 runtime.SetFinalizer(it, destroyIterator) 504 it.FlatIterator = FlatIteratorFromDense(tt) 505 it.mask = tt.Mask() 506 return it 507 } 508 509 func (it *FlatMaskedIterator) NextValidity() (int, bool, error) { 510 if len(it.mask) == 0 { 511 return it.FlatIterator.NextValidity() 512 } 513 514 var i int 515 var err error 516 if i, err = it.Next(); err == nil { 517 return i, !it.mask[i], err 518 } 519 return -1, false, err 520 } 521 522 // NextValid returns the index of the next valid element, 523 // as well as the number of increments to get to next element 524 func (it *FlatMaskedIterator) NextValid() (int, int, error) { 525 if len(it.mask) == 0 { 526 return it.FlatIterator.NextValid() 527 } 528 var count int 529 var mult = 1 530 if it.reverse { 531 mult = -1 532 } 533 534 for i, err := it.Next(); err == nil; i, err = it.Next() { 535 count++ 536 if !(it.mask[i]) { 537 return i, mult * count, err 538 } 539 } 540 return -1, mult * count, noopError{} 541 } 542 543 // NextInvalid returns the index of the next invalid element 544 // as well as the number of increments to get to next invalid element 545 func (it *FlatMaskedIterator) NextInvalid() (int, int, error) { 546 if it.mask == nil { 547 return it.FlatIterator.NextInvalid() 548 } 549 var count int 550 var mult = 1 551 if it.reverse { 552 mult = -1 553 } 554 for i, err := it.Next(); err == nil; i, err = it.Next() { 555 count++ 556 if it.mask[i] { 557 return i, mult * count, err 558 } 559 } 560 return -1, mult * count, noopError{} 561 } 562 563 // FlatSparseIterator is an iterator that works very much in the same way as flatiterator, except for sparse tensors 564 type FlatSparseIterator struct { 565 *CS 566 567 //state 568 nextIndex int 569 lastIndex int 570 track []int 571 done bool 572 reverse bool 573 } 574 575 func NewFlatSparseIterator(t *CS) *FlatSparseIterator { 576 it := new(FlatSparseIterator) 577 it.CS = t 578 it.track = BorrowInts(len(t.s)) 579 return it 580 } 581 582 func (it *FlatSparseIterator) Start() (int, error) { 583 it.Reset() 584 return it.Next() 585 } 586 587 func (it *FlatSparseIterator) Next() (int, error) { 588 if it.done { 589 return -1, noopError{} 590 } 591 592 // var ok bool 593 it.lastIndex, _ = it.at(it.track...) 594 595 // increment the coordinates 596 for i := len(it.s) - 1; i >= 0; i-- { 597 it.track[i]++ 598 if it.track[i] == it.s[i] { 599 if i == 0 { 600 it.done = true 601 } 602 it.track[i] = 0 603 continue 604 } 605 break 606 } 607 608 return it.lastIndex, nil 609 } 610 611 func (it *FlatSparseIterator) NextValidity() (int, bool, error) { 612 i, err := it.Next() 613 if i == -1 { 614 return i, false, err 615 } 616 return i, true, err 617 } 618 619 func (it *FlatSparseIterator) NextValid() (int, int, error) { 620 var i int 621 var err error 622 for i, err = it.Next(); err == nil && i == -1; i, err = it.Next() { 623 624 } 625 return i, -1, err 626 } 627 628 func (it *FlatSparseIterator) NextInvalid() (int, int, error) { 629 var i int 630 var err error 631 for i, err = it.Next(); err == nil && i != -1; i, err = it.Next() { 632 633 } 634 return i, -1, err 635 } 636 637 func (it *FlatSparseIterator) Reset() { 638 if it.reverse { 639 for i := range it.track { 640 it.track[i] = it.s[i] - 1 641 } 642 643 } else { 644 it.nextIndex = 0 645 for i := range it.track { 646 it.track[i] = 0 647 } 648 } 649 it.done = false 650 } 651 652 func (it *FlatSparseIterator) SetReverse() { 653 it.reverse = true 654 it.Reset() 655 } 656 657 func (it *FlatSparseIterator) SetForward() { 658 it.reverse = false 659 it.Reset() 660 } 661 662 func (it *FlatSparseIterator) Coord() []int { 663 return it.track 664 } 665 666 func (it *FlatSparseIterator) Done() bool { 667 return it.done 668 } 669 670 /* TEMPORARILY REMOVED 671 // SortedMultiStridePerm takes multiple input strides, and creates a sorted stride permutation. 672 // It's based very closely on Numpy's PyArray_CreateMultiSortedStridePerm, where a stable insertion sort is used 673 // to create the permutations. 674 func SortedMultiStridePerm(dims int, aps []*AP) (retVal []int) { 675 retVal = BorrowInts(dims) 676 for i := 0; i < dims; i++ { 677 retVal[i] = i 678 } 679 680 for i := 1; i < dims; i++ { 681 ipos := i 682 axisi := retVal[i] 683 684 for j := i - 1; j >= 0; j-- { 685 var ambig, swap bool 686 ambig = true 687 axisj := retVal[j] 688 689 for _, ap := range aps { 690 if ap.shape[axisi] != 1 && ap.shape[axisj] != 1 { 691 if ap.strides[axisi] <= ap.strides[axisj] { 692 swap = true 693 } else if ambig { 694 swap = true 695 } 696 ambig = false 697 } 698 } 699 700 if !ambig && swap { 701 ipos = j 702 } else { 703 break 704 } 705 706 } 707 if ipos != i { 708 for j := i; j > ipos; j-- { 709 retVal[j] = retVal[j-1] 710 } 711 retVal[ipos] = axisi 712 } 713 } 714 return 715 } 716 */