github.com/wzzhu/tensor@v0.9.24/iterator.go (about)

     1  package tensor
     2  
     3  import (
     4  	"runtime"
     5  )
     6  
     7  func requiresOrderedIterator(e Engine, t Tensor) bool {
     8  	if t.IsScalar() {
     9  		return false
    10  	}
    11  	if t.RequiresIterator() {
    12  		return true
    13  	}
    14  	switch tt := t.(type) {
    15  	case DenseTensor:
    16  		return !e.WorksWith(tt.DataOrder())
    17  	case SparseTensor:
    18  		return true
    19  	}
    20  	panic("Unreachable")
    21  }
    22  
// Iterator is the generic iterator interface.
// It is used to iterate across multi-dimensional slices, no matter how the underlying data is arranged.
type Iterator interface {
	// Start returns the first index.
	Start() (int, error)

	// Next returns the next index. "Next" is defined as the next value in the coordinates.
	// For example: let x be a (5,5) matrix that is row-major. The current index is for the coordinate (3,3).
	// Next() returns the index of (3,4).
	//
	// If there is no underlying data store for (3,4) - say, for example, the matrix is a sparse matrix - it returns an error.
	// If, however, there is an underlying data store for (3,4) but it is not valid (for example, masked tensors), it will not return an error.
	//
	// Second example: let x be a (5,5) matrix that is col-major. The current index is for the coordinate (3,3).
	// Next() returns the index of (4,3).
	Next() (int, error)

	// NextValidity is like Next, but also returns the validity of the value at the index.
	NextValidity() (int, bool, error)

	// NextValid returns the next valid index, as well as a skip count.
	NextValid() (int, int, error)

	// NextInvalid returns the next invalid index, as well as a skip count.
	NextInvalid() (int, int, error)

	// Reset resets the iterator.
	Reset()

	// SetReverse tells the iterator to iterate in reverse.
	SetReverse()

	// SetForward tells the iterator to iterate forwards.
	SetForward()

	// Coord returns the coordinates.
	Coord() []int

	// Done returns true when the iterator is done iterating.
	Done() bool

	// Shape returns the shape of the multidimensional tensor it is iterating on.
	Shape() Shape
}
    67  
    68  // NewIterator creates a new Iterator from an ap. The type of iterator depends on number of
    69  // aps passed, and whether they are masked or not
    70  func NewIterator(aps ...*AP) Iterator {
    71  	switch len(aps) {
    72  	case 0:
    73  		return nil
    74  	case 1:
    75  		return newFlatIterator(aps[0])
    76  	default:
    77  		return NewMultIterator(aps...)
    78  	}
    79  }
    80  
    81  // IteratorFromDense creates a new Iterator from a list of dense tensors
    82  func IteratorFromDense(tts ...DenseTensor) Iterator {
    83  	switch len(tts) {
    84  	case 0:
    85  		return nil
    86  	case 1:
    87  		if mt, ok := tts[0].(MaskedTensor); ok && mt.IsMasked() {
    88  			return FlatMaskedIteratorFromDense(mt)
    89  		}
    90  		return FlatIteratorFromDense(tts[0])
    91  	default:
    92  		return MultIteratorFromDense(tts...)
    93  	}
    94  }
    95  
    96  func destroyIterator(it Iterator) {
    97  	switch itt := it.(type) {
    98  	case *MultIterator:
    99  		destroyMultIterator(itt)
   100  	}
   101  }
   102  
   103  func iteratorLoadAP(it Iterator, ap *AP) {
   104  	switch itt := it.(type) {
   105  	case *FlatIterator:
   106  		itt.AP = ap
   107  	case *FlatMaskedIterator:
   108  		itt.AP = ap
   109  	case *MultIterator: // Do nothing, TODO: perhaps add something here
   110  
   111  	}
   112  }
   113  
   114  /* FLAT ITERATOR */
   115  
// FlatIterator is an iterator that iterates over Tensors according to the data's layout.
// It utilizes the *AP of a Tensor to determine what the next index is.
// This data structure is similar to Numpy's flatiter, with some standard Go based restrictions of course
// (such as, not allowing negative indices)
type FlatIterator struct {
	*AP

	// state
	//
	// track holds the per-dimension coordinates and is what Coord returns.
	// nextIndex is the index the next call to Next will return; lastIndex is
	// the index returned by the most recent call. size is the total number of
	// elements of the shape, captured at construction. done is set once the
	// iterator has been exhausted.
	track      []int
	nextIndex  int
	lastIndex  int
	size       int
	done       bool
	veclikeDim int  // the dimension of a  vectorlike shape that is not a 1.
	reverse    bool // if true, iterator starts at end of array and runs backwards

	// isScalar and isVector cache ap.IsScalar() and ap.IsVectorLike() as
	// evaluated when the iterator was constructed.
	isScalar bool
	isVector bool

	// outerFirst makes forward n-dimensional iteration use colMajorNDNext
	// instead of ndNext. newFlatIterator leaves it at its zero value.
	outerFirst bool
}
   137  
   138  // newFlatIterator creates a new FlatIterator.
   139  func newFlatIterator(ap *AP) *FlatIterator {
   140  	var dim int
   141  	if ap.IsVectorLike() {
   142  		for d, i := range ap.shape {
   143  			if i != 1 {
   144  				dim = d
   145  				break
   146  			}
   147  		}
   148  	}
   149  
   150  	return &FlatIterator{
   151  		AP:         ap,
   152  		track:      make([]int, len(ap.shape)),
   153  		size:       ap.shape.TotalSize(),
   154  		veclikeDim: dim,
   155  
   156  		isScalar: ap.IsScalar(),
   157  		isVector: ap.IsVectorLike(),
   158  	}
   159  }
   160  
   161  // FlatIteratorFromDense creates a new FlatIterator from a dense tensor
   162  func FlatIteratorFromDense(tt DenseTensor) *FlatIterator {
   163  	return newFlatIterator(tt.Info())
   164  }
   165  
   166  // SetReverse initializes iterator to run backwards
   167  func (it *FlatIterator) SetReverse() {
   168  	it.reverse = true
   169  	it.Reset()
   170  	return
   171  }
   172  
   173  // SetForward initializes iterator to run forwards
   174  func (it *FlatIterator) SetForward() {
   175  	it.reverse = false
   176  	it.Reset()
   177  	return
   178  }
   179  
   180  //Start begins iteration
   181  func (it *FlatIterator) Start() (int, error) {
   182  	it.Reset()
   183  	return it.Next()
   184  }
   185  
   186  //Done checks whether iterators are done
   187  func (it *FlatIterator) Done() bool {
   188  	return it.done
   189  }
   190  
   191  // Next returns the index of the current coordinate.
   192  func (it *FlatIterator) Next() (int, error) {
   193  	if it.done {
   194  		return -1, noopError{}
   195  	}
   196  
   197  	switch {
   198  	case it.isScalar:
   199  		it.done = true
   200  		return 0, nil
   201  	case it.isVector:
   202  		if it.reverse {
   203  			return it.singlePrevious()
   204  		}
   205  		return it.singleNext()
   206  	default:
   207  		if it.reverse {
   208  			return it.ndPrevious()
   209  		}
   210  		if it.outerFirst {
   211  			return it.colMajorNDNext()
   212  		}
   213  		return it.ndNext()
   214  	}
   215  }
   216  
   217  // NextValidity returns the index of the current coordinate, and whether or not it's valid. Identical to Next()
   218  func (it *FlatIterator) NextValidity() (int, bool, error) {
   219  	i, err := it.Next()
   220  	return i, true, err
   221  }
   222  
   223  // NextValid returns the index of the current coordinate. Identical to Next for FlatIterator
   224  // Also returns the number of increments to get to next element ( 1,  or -1 in reverse case). This is to maintain
   225  // consistency with the masked iterator, for which the step between valid elements can be more than 1
   226  func (it *FlatIterator) NextValid() (int, int, error) {
   227  	if it.done {
   228  		return -1, 1, noopError{}
   229  	}
   230  	switch {
   231  	case it.isScalar:
   232  		it.done = true
   233  		return 0, 0, nil
   234  	case it.isVector:
   235  		if it.reverse {
   236  			a, err := it.singlePrevious()
   237  			return a, -1, err
   238  		}
   239  		a, err := it.singleNext()
   240  		return a, 1, err
   241  	default:
   242  		if it.reverse {
   243  			a, err := it.ndPrevious()
   244  			return a, -1, err
   245  		}
   246  
   247  		if it.outerFirst {
   248  			a, err := it.colMajorNDNext()
   249  			return a, 1, err
   250  		}
   251  		a, err := it.ndNext()
   252  		return a, 1, err
   253  	}
   254  }
   255  
   256  // NextInvalid returns the index of the current coordinate. Identical to Next for FlatIterator
   257  // also returns the number of increments to get to next invalid element (1 or -1 in reverse case).
   258  // Like NextValid, this method's purpose is to maintain consistency with the masked iterator,
   259  // for which the step between invalid elements can be anywhere from 0 to the  tensor's length
   260  func (it *FlatIterator) NextInvalid() (int, int, error) {
   261  	if it.reverse {
   262  		return -1, -it.lastIndex, noopError{}
   263  	}
   264  	return -1, it.Size() - it.lastIndex, noopError{}
   265  }
   266  
   267  func (it *FlatIterator) singleNext() (int, error) {
   268  	it.lastIndex = it.nextIndex
   269  	it.nextIndex++
   270  
   271  	var tracked int
   272  	it.track[it.veclikeDim]++
   273  	tracked = it.track[it.veclikeDim]
   274  
   275  	if tracked >= it.size {
   276  		it.done = true
   277  	}
   278  
   279  	return it.lastIndex, nil
   280  }
   281  
   282  func (it *FlatIterator) singlePrevious() (int, error) {
   283  	it.lastIndex = it.nextIndex
   284  	it.nextIndex--
   285  
   286  	var tracked int
   287  	it.track[it.veclikeDim]--
   288  	tracked = it.track[it.veclikeDim]
   289  
   290  	if tracked < 0 {
   291  		it.done = true
   292  	}
   293  	return it.lastIndex, nil
   294  }
   295  
// ndNext is the forward stepper for n-dimensional shapes in which the last
// dimension varies fastest. It returns the current index and then advances
// the coordinates in track, odometer style, starting from the last dimension.
// The iterator is marked done when the first dimension wraps around.
func (it *FlatIterator) ndNext() (int, error) {
	// The reason for this odd-looking code is that the SSA compiler doesn't
	// know how to optimize the straightforward version; it fails to keep
	// things in registers. @stuartcarnie optimized this to great effect.

	v := len(it.shape) - 1
	nextIndex := it.nextIndex
	it.lastIndex = nextIndex

	// the following 3 lines cause the compiler to perform the bounds checks here,
	// instead of inside the loop
	coord := it.shape[:v+1]
	track := it.track[:v+1]
	strides := it.strides[:v+1]
	for i := v; i >= 0; i-- {
		track[i]++
		shapeI := coord[i]
		strideI := strides[i]

		if track[i] == shapeI {
			// dimension i overflowed: wrap it to 0, rewind the index to the
			// start of that dimension and carry into dimension i-1.
			if i == 0 {
				it.done = true
			}
			track[i] = 0
			nextIndex -= (shapeI - 1) * strideI
			continue
		}
		// no overflow: a single stride along dimension i reaches the next element.
		nextIndex += strideI
		break
	}
	it.nextIndex = nextIndex
	return it.lastIndex, nil
}
   329  
// colMajorNDNext is the forward stepper for n-dimensional shapes in which the
// first dimension varies fastest. It mirrors ndNext, but advances the
// coordinates starting from dimension 0 and marks the iterator done when the
// last dimension wraps around.
func (it *FlatIterator) colMajorNDNext() (int, error) {
	// The reason for this odd-looking code is that the SSA compiler doesn't
	// know how to optimize the straightforward version; it fails to keep
	// things in registers. @stuartcarnie optimized this to great effect.

	v := len(it.shape) - 1
	nextIndex := it.nextIndex
	it.lastIndex = nextIndex

	// the following 3 lines cause the compiler to perform the bounds checks here,
	// instead of inside the loop
	coord := it.shape[:v+1]
	track := it.track[:v+1]
	strides := it.strides[:v+1]
	for i := 0; i <= v; i++ {
		track[i]++
		shapeI := coord[i]
		strideI := strides[i]

		if track[i] == shapeI {
			// dimension i overflowed: wrap it to 0, rewind the index to the
			// start of that dimension and carry into dimension i+1.
			if i == v {
				it.done = true
			}
			track[i] = 0

			nextIndex -= (shapeI - 1) * strideI
			continue
		}
		// no overflow: a single stride along dimension i reaches the next element.
		nextIndex += strideI
		break
	}
	it.nextIndex = nextIndex
	return it.lastIndex, nil

}
   365  
   366  func (it *FlatIterator) ndPrevious() (int, error) {
   367  	it.lastIndex = it.nextIndex
   368  	for i := len(it.shape) - 1; i >= 0; i-- {
   369  		it.track[i]--
   370  		if it.track[i] < 0 {
   371  			if i == 0 {
   372  				it.done = true
   373  			}
   374  			it.track[i] = it.shape[i] - 1
   375  			it.nextIndex += (it.shape[i] - 1) * it.strides[i]
   376  			continue
   377  		}
   378  		it.nextIndex -= it.strides[i]
   379  		break
   380  	}
   381  	return it.lastIndex, nil
   382  }
   383  
// colMajorNDPrevious is a placeholder for the reverse counterpart of
// colMajorNDNext. It is not implemented yet: it always returns (0, nil) and
// does not touch the iterator state.
//
// TODO v0.9.0
func (it *FlatIterator) colMajorNDPrevious() (int, error) {
	return 0, nil
}
   388  
   389  // Coord returns the next coordinate.
   390  // When Next() is called, the coordinates are updated AFTER the Next() returned.
   391  // See example for more details.
   392  //
   393  // The returned coordinates is mutable. Changing any values in the return value will
   394  // change the state of the iterator
   395  func (it *FlatIterator) Coord() []int { return it.track }
   396  
   397  // Slice is a convenience function that augments
   398  func (it *FlatIterator) Slice(sli Slice) (retVal []int, err error) {
   399  	var next int
   400  	var nexts []int
   401  	for next, err = it.Next(); err == nil; next, err = it.Next() {
   402  		nexts = append(nexts, next)
   403  	}
   404  	if _, ok := err.(NoOpError); err != nil && !ok {
   405  		return
   406  	}
   407  
   408  	if sli == nil {
   409  		retVal = nexts
   410  		return
   411  	}
   412  
   413  	start := sli.Start()
   414  	end := sli.End()
   415  	step := sli.Step()
   416  
   417  	// sanity checks
   418  	if err = CheckSlice(sli, len(nexts)); err != nil {
   419  		return
   420  	}
   421  
   422  	if step < 0 {
   423  		// reverse the nexts
   424  		for i := len(nexts)/2 - 1; i >= 0; i-- {
   425  			j := len(nexts) - 1 - i
   426  			nexts[i], nexts[j] = nexts[j], nexts[i]
   427  		}
   428  		step = -step
   429  	}
   430  
   431  	// cleanup before loop
   432  	if end > len(nexts) {
   433  		end = len(nexts)
   434  	}
   435  	// nexts = nexts[:end]
   436  
   437  	for i := start; i < end; i += step {
   438  		retVal = append(retVal, nexts[i])
   439  	}
   440  
   441  	err = nil
   442  	return
   443  }
   444  
   445  // Reset resets the iterator state.
   446  func (it *FlatIterator) Reset() {
   447  	it.done = false
   448  	if it.reverse {
   449  		for i := range it.track {
   450  			it.track[i] = it.shape[i] - 1
   451  		}
   452  
   453  		switch {
   454  		case it.IsScalar():
   455  			it.nextIndex = 0
   456  		case it.isVector:
   457  			it.nextIndex = (it.shape[0] - 1) * it.strides[0]
   458  		// case it.IsRowVec():
   459  		// 	it.nextIndex = (it.shape[1] - 1) * it.strides[1]
   460  		// case it.IsColVec():
   461  		// 	it.nextIndex = (it.shape[0] - 1) * it.strides[0]
   462  		default:
   463  			it.nextIndex = 0
   464  			for i := range it.track {
   465  				it.nextIndex += (it.shape[i] - 1) * it.strides[i]
   466  			}
   467  		}
   468  	} else {
   469  		it.nextIndex = 0
   470  		for i := range it.track {
   471  			it.track[i] = 0
   472  		}
   473  	}
   474  }
   475  
   476  // Chan returns a channel of ints. This is useful for iterating multiple Tensors at the same time.
   477  func (it *FlatIterator) Chan() (retVal chan int) {
   478  	retVal = make(chan int)
   479  
   480  	go func() {
   481  		for next, err := it.Next(); err == nil; next, err = it.Next() {
   482  			retVal <- next
   483  		}
   484  		close(retVal)
   485  	}()
   486  
   487  	return
   488  }
   489  
   490  /* FLAT MASKED ITERATOR */
   491  
// FlatMaskedIterator is an iterator that iterates over simple masked Tensors.
// It is used when the mask stride is identical to the data stride with the exception of trailing zeros,
// in which case the data index is always a perfect integer multiple of the mask index.
type FlatMaskedIterator struct {
	*FlatIterator
	// mask is indexed by the data index returned from Next; a true entry
	// marks the element as invalid (masked out).
	mask []bool
}
   499  
   500  // FlatMaskedIteratorFromDense creates a new FlatMaskedIterator from dense tensor
   501  func FlatMaskedIteratorFromDense(tt MaskedTensor) *FlatMaskedIterator {
   502  	it := new(FlatMaskedIterator)
   503  	runtime.SetFinalizer(it, destroyIterator)
   504  	it.FlatIterator = FlatIteratorFromDense(tt)
   505  	it.mask = tt.Mask()
   506  	return it
   507  }
   508  
   509  func (it *FlatMaskedIterator) NextValidity() (int, bool, error) {
   510  	if len(it.mask) == 0 {
   511  		return it.FlatIterator.NextValidity()
   512  	}
   513  
   514  	var i int
   515  	var err error
   516  	if i, err = it.Next(); err == nil {
   517  		return i, !it.mask[i], err
   518  	}
   519  	return -1, false, err
   520  }
   521  
   522  // NextValid returns the index of the next valid element,
   523  // as well as the number of increments to get to next element
   524  func (it *FlatMaskedIterator) NextValid() (int, int, error) {
   525  	if len(it.mask) == 0 {
   526  		return it.FlatIterator.NextValid()
   527  	}
   528  	var count int
   529  	var mult = 1
   530  	if it.reverse {
   531  		mult = -1
   532  	}
   533  
   534  	for i, err := it.Next(); err == nil; i, err = it.Next() {
   535  		count++
   536  		if !(it.mask[i]) {
   537  			return i, mult * count, err
   538  		}
   539  	}
   540  	return -1, mult * count, noopError{}
   541  }
   542  
   543  // NextInvalid returns the index of the next invalid element
   544  // as well as the number of increments to get to next invalid element
   545  func (it *FlatMaskedIterator) NextInvalid() (int, int, error) {
   546  	if it.mask == nil {
   547  		return it.FlatIterator.NextInvalid()
   548  	}
   549  	var count int
   550  	var mult = 1
   551  	if it.reverse {
   552  		mult = -1
   553  	}
   554  	for i, err := it.Next(); err == nil; i, err = it.Next() {
   555  		count++
   556  		if it.mask[i] {
   557  			return i, mult * count, err
   558  		}
   559  	}
   560  	return -1, mult * count, noopError{}
   561  }
   562  
// FlatSparseIterator is an iterator that works very much in the same way as FlatIterator, except that it is for sparse tensors.
type FlatSparseIterator struct {
	*CS

	// state
	//
	// lastIndex is the index returned by the most recent call to Next.
	// track holds the per-dimension coordinates and is what Coord returns.
	// done is set once the iterator has been exhausted. reverse records the
	// direction requested through SetReverse/SetForward.
	nextIndex int
	lastIndex int
	track     []int
	done      bool
	reverse   bool
}
   574  
   575  func NewFlatSparseIterator(t *CS) *FlatSparseIterator {
   576  	it := new(FlatSparseIterator)
   577  	it.CS = t
   578  	it.track = BorrowInts(len(t.s))
   579  	return it
   580  }
   581  
   582  func (it *FlatSparseIterator) Start() (int, error) {
   583  	it.Reset()
   584  	return it.Next()
   585  }
   586  
   587  func (it *FlatSparseIterator) Next() (int, error) {
   588  	if it.done {
   589  		return -1, noopError{}
   590  	}
   591  
   592  	// var ok bool
   593  	it.lastIndex, _ = it.at(it.track...)
   594  
   595  	// increment the coordinates
   596  	for i := len(it.s) - 1; i >= 0; i-- {
   597  		it.track[i]++
   598  		if it.track[i] == it.s[i] {
   599  			if i == 0 {
   600  				it.done = true
   601  			}
   602  			it.track[i] = 0
   603  			continue
   604  		}
   605  		break
   606  	}
   607  
   608  	return it.lastIndex, nil
   609  }
   610  
   611  func (it *FlatSparseIterator) NextValidity() (int, bool, error) {
   612  	i, err := it.Next()
   613  	if i == -1 {
   614  		return i, false, err
   615  	}
   616  	return i, true, err
   617  }
   618  
   619  func (it *FlatSparseIterator) NextValid() (int, int, error) {
   620  	var i int
   621  	var err error
   622  	for i, err = it.Next(); err == nil && i == -1; i, err = it.Next() {
   623  
   624  	}
   625  	return i, -1, err
   626  }
   627  
   628  func (it *FlatSparseIterator) NextInvalid() (int, int, error) {
   629  	var i int
   630  	var err error
   631  	for i, err = it.Next(); err == nil && i != -1; i, err = it.Next() {
   632  
   633  	}
   634  	return i, -1, err
   635  }
   636  
   637  func (it *FlatSparseIterator) Reset() {
   638  	if it.reverse {
   639  		for i := range it.track {
   640  			it.track[i] = it.s[i] - 1
   641  		}
   642  
   643  	} else {
   644  		it.nextIndex = 0
   645  		for i := range it.track {
   646  			it.track[i] = 0
   647  		}
   648  	}
   649  	it.done = false
   650  }
   651  
   652  func (it *FlatSparseIterator) SetReverse() {
   653  	it.reverse = true
   654  	it.Reset()
   655  }
   656  
   657  func (it *FlatSparseIterator) SetForward() {
   658  	it.reverse = false
   659  	it.Reset()
   660  }
   661  
   662  func (it *FlatSparseIterator) Coord() []int {
   663  	return it.track
   664  }
   665  
   666  func (it *FlatSparseIterator) Done() bool {
   667  	return it.done
   668  }
   669  
   670  /* TEMPORARILY REMOVED
   671  // SortedMultiStridePerm takes multiple input strides, and creates a sorted stride permutation.
   672  // It's based very closely on Numpy's PyArray_CreateMultiSortedStridePerm, where a stable insertion sort is used
   673  // to create the permutations.
   674  func SortedMultiStridePerm(dims int, aps []*AP) (retVal []int) {
   675  	retVal = BorrowInts(dims)
   676  	for i := 0; i < dims; i++ {
   677  		retVal[i] = i
   678  	}
   679  
   680  	for i := 1; i < dims; i++ {
   681  		ipos := i
   682  		axisi := retVal[i]
   683  
   684  		for j := i - 1; j >= 0; j-- {
   685  			var ambig, swap bool
   686  			ambig = true
   687  			axisj := retVal[j]
   688  
   689  			for _, ap := range aps {
   690  				if ap.shape[axisi] != 1 && ap.shape[axisj] != 1 {
   691  					if ap.strides[axisi] <= ap.strides[axisj] {
   692  						swap = true
   693  					} else if ambig {
   694  						swap = true
   695  					}
   696  					ambig = false
   697  				}
   698  			}
   699  
   700  			if !ambig && swap {
   701  				ipos = j
   702  			} else {
   703  				break
   704  			}
   705  
   706  		}
   707  		if ipos != i {
   708  			for j := i; j > ipos; j-- {
   709  				retVal[j] = retVal[j-1]
   710  			}
   711  			retVal[ipos] = axisi
   712  		}
   713  	}
   714  	return
   715  }
   716  */