github.com/wzzhu/tensor@v0.9.24/iterator_mult.go (about)

     1  package tensor
     2  
     3  import (
     4  	"runtime"
     5  )
     6  
// MultIterator is an iterator that iterates over multiple tensors, including masked tensors,
// stepping all of them in lock-step. It utilizes the *AP of a Tensor to determine what the next index is.
// This data structure is similar to Numpy's flatiter, with some standard Go based restrictions of course
// (such as not allowing negative indices).
//
// Operands whose strides are identical share a single FlatIterator ("block"): whichBlock maps each
// operand to its block in fitArr, and lastIndexArr holds each operand's index at the current position.
type MultIterator struct {
	*AP                // intended to be the AP of the largest iterator in fitArr
	fit0 *FlatIterator // largest fit in fitArr (by total size); its lastIndex is what Next returns
	mask []bool        // combined mask indexed by the index Next returns; only built by MultIteratorFromDense when more than one operand is masked

	numMasked    int             // number of operands that carry a mask (as counted by MultIteratorFromDense)
	lastIndexArr []int           // lastIndexArr[i] is operand i's index at the current position
	shape        Shape           // pooled copy of the iteration shape
	whichBlock   []int           // whichBlock[i] is the index into fitArr used by operand i
	fitArr       []*FlatIterator // one iterator per distinct stride pattern
	strides      []int           // broadcast strides of every block, maxDims ints per block

	size    int  // NOTE(review): not read or written anywhere in this file
	done    bool // set by Next once any underlying iterator finishes; Done recomputes it
	reverse bool // when true, NextValid/NextInvalid return negative step counts
}
    27  
    28  func genIterator(m map[int]int, strides []int, idx int) (int, bool) {
    29  	key := hashIntArray(strides)
    30  	f, ok := m[key]
    31  	if !ok {
    32  		m[key] = idx
    33  		return idx, ok
    34  	}
    35  	return f, ok
    36  }
    37  
    38  // NewMultIterator creates a new MultIterator from a list of APs
    39  func NewMultIterator(aps ...*AP) *MultIterator {
    40  	nit := len(aps)
    41  	if nit < 1 {
    42  		return nil
    43  	}
    44  	for _, ap := range aps {
    45  		if ap == nil {
    46  			panic("ap is nil") //TODO: Probably remove this panic
    47  		}
    48  	}
    49  
    50  	var maxDims int
    51  	var maxShape = aps[0].shape
    52  
    53  	for i := range aps {
    54  		if aps[i].Dims() >= maxDims {
    55  			maxDims = aps[i].Dims()
    56  			if aps[i].Size() > maxShape.TotalSize() {
    57  				maxShape = aps[i].shape
    58  			}
    59  		}
    60  
    61  	}
    62  
    63  	it := new(MultIterator)
    64  
    65  	it.whichBlock = BorrowInts(nit)
    66  	it.lastIndexArr = BorrowInts(nit)
    67  	it.strides = BorrowInts(nit * maxDims)
    68  
    69  	shape := BorrowInts(len(maxShape))
    70  	copy(shape, maxShape)
    71  	it.shape = shape
    72  
    73  	for _, ap := range aps {
    74  		_, err := BroadcastStrides(shape, ap.shape, it.strides[:maxDims], ap.strides)
    75  		if err != nil {
    76  			panic("can not broadcast strides")
    77  		}
    78  	}
    79  
    80  	for i := range it.strides {
    81  		it.strides[i] = 0
    82  	}
    83  
    84  	it.fitArr = make([]*FlatIterator, nit)
    85  
    86  	//TODO: Convert this make to Borrow perhaps?
    87  	m := make(map[int]int)
    88  
    89  	nBlocks := 0
    90  	offset := 0
    91  	for i, ap := range aps {
    92  		f, ok := genIterator(m, ap.strides, nBlocks)
    93  		if !ok {
    94  			offset = nBlocks * maxDims
    95  			apStrides, _ := BroadcastStrides(shape, ap.shape, it.strides[offset:offset+maxDims], ap.strides)
    96  			copy(it.strides[offset:offset+maxDims], apStrides)
    97  			ReturnInts(apStrides) // Borrowed in BroadcastStrides but returned here - dangerous pattern?
    98  			nBlocks++
    99  		}
   100  		ap2 := MakeAP(it.shape[:maxDims], it.strides[offset:offset+maxDims], ap.o, ap.Δ)
   101  		it.whichBlock[i] = f
   102  		it.fitArr[nBlocks-1] = newFlatIterator(&ap2)
   103  	}
   104  
   105  	it.fitArr = it.fitArr[:nBlocks]
   106  	it.strides = it.strides[:nBlocks*maxDims]
   107  	// fill 0s with 1s
   108  	for i := range it.strides {
   109  		if it.strides[i] == 0 {
   110  			it.strides[i] = 1
   111  		}
   112  	}
   113  
   114  	it.fit0 = it.fitArr[0]
   115  	for _, f := range it.fitArr {
   116  		if it.fit0.size < f.size {
   117  			it.fit0 = f
   118  			it.AP = f.AP
   119  		}
   120  	}
   121  	return it
   122  }
   123  
   124  // MultIteratorFromDense creates a new MultIterator from a list of dense tensors
   125  func MultIteratorFromDense(tts ...DenseTensor) *MultIterator {
   126  	aps := make([]*AP, len(tts))
   127  	hasMask := BorrowBools(len(tts))
   128  	defer ReturnBools(hasMask)
   129  
   130  	var masked = false
   131  	numMasked := 0
   132  	for i, tt := range tts {
   133  		aps[i] = tt.Info()
   134  		if mt, ok := tt.(MaskedTensor); ok {
   135  			hasMask[i] = mt.IsMasked()
   136  		}
   137  		masked = masked || hasMask[i]
   138  		if hasMask[i] {
   139  			numMasked++
   140  		}
   141  	}
   142  
   143  	it := NewMultIterator(aps...)
   144  	runtime.SetFinalizer(it, destroyIterator)
   145  
   146  	if masked {
   147  		// create new mask slice if more than tensor is masked
   148  		if numMasked > 1 {
   149  			it.mask = BorrowBools(it.shape.TotalSize())
   150  			memsetBools(it.mask, false)
   151  			for i, err := it.Start(); err == nil; i, err = it.Next() {
   152  				for j, k := range it.lastIndexArr {
   153  					if hasMask[j] {
   154  						it.mask[i] = it.mask[i] || tts[j].(MaskedTensor).Mask()[k]
   155  					}
   156  				}
   157  			}
   158  		}
   159  	}
   160  	it.numMasked = numMasked
   161  	return it
   162  }
   163  
   164  // destroyMultIterator returns any borrowed objects back to pool
   165  func destroyMultIterator(it *MultIterator) {
   166  
   167  	if cap(it.whichBlock) > 0 {
   168  		ReturnInts(it.whichBlock)
   169  		it.whichBlock = nil
   170  	}
   171  	if cap(it.lastIndexArr) > 0 {
   172  		ReturnInts(it.lastIndexArr)
   173  		it.lastIndexArr = nil
   174  	}
   175  	if cap(it.strides) > 0 {
   176  		ReturnInts(it.strides)
   177  		it.strides = nil
   178  	}
   179  	if it.numMasked > 1 {
   180  		if cap(it.mask) > 0 {
   181  			ReturnBools(it.mask)
   182  			it.mask = nil
   183  		}
   184  	}
   185  }
   186  
   187  // SetReverse initializes iterator to run backward
   188  func (it *MultIterator) SetReverse() {
   189  	for _, f := range it.fitArr {
   190  		f.SetReverse()
   191  	}
   192  }
   193  
   194  // SetForward initializes iterator to run forward
   195  func (it *MultIterator) SetForward() {
   196  	for _, f := range it.fitArr {
   197  		f.SetForward()
   198  	}
   199  }
   200  
   201  //Start begins iteration
   202  func (it *MultIterator) Start() (int, error) {
   203  	it.Reset()
   204  	return it.Next()
   205  }
   206  
   207  //Done checks whether iterators are done
   208  func (it *MultIterator) Done() bool {
   209  	for _, f := range it.fitArr {
   210  		if !f.done {
   211  			it.done = false
   212  			return false
   213  		}
   214  	}
   215  	it.done = true
   216  	return true
   217  }
   218  
   219  // Next returns the index of the next coordinate
   220  func (it *MultIterator) Next() (int, error) {
   221  	if it.done {
   222  		return -1, noopError{}
   223  	}
   224  	it.done = false
   225  	for _, f := range it.fitArr {
   226  		if _, err := f.Next(); err != nil {
   227  			return -1, err
   228  		}
   229  		it.done = it.done || f.done
   230  	}
   231  	for i, j := range it.whichBlock {
   232  		it.lastIndexArr[i] = it.fitArr[j].lastIndex
   233  	}
   234  	return it.fit0.lastIndex, nil
   235  }
   236  
   237  func (it *MultIterator) NextValidity() (int, bool, error) {
   238  	i, err := it.Next()
   239  	if err != nil {
   240  		return i, false, err
   241  	}
   242  
   243  	if len(it.mask) == 0 {
   244  		return i, true, err
   245  	}
   246  	return i, it.mask[i], err
   247  }
   248  
   249  // NextValid returns the index of the next valid coordinate
   250  func (it *MultIterator) NextValid() (int, int, error) {
   251  	var invalid = true
   252  	var count int
   253  	var mult = 1
   254  	if it.reverse {
   255  		mult = -1
   256  	}
   257  	for invalid {
   258  		if it.done {
   259  			for i, j := range it.whichBlock {
   260  				it.lastIndexArr[i] = it.fitArr[j].lastIndex
   261  			}
   262  			return -1, 0, noopError{}
   263  		}
   264  		for _, f := range it.fitArr {
   265  			f.Next()
   266  			it.done = it.done || f.done
   267  		}
   268  		count++
   269  		invalid = !it.mask[it.fit0.lastIndex]
   270  	}
   271  	return it.fit0.lastIndex, mult * count, nil
   272  }
   273  
   274  // NextInvalid returns the index of the next invalid coordinate
   275  func (it *MultIterator) NextInvalid() (int, int, error) {
   276  	var valid = true
   277  
   278  	var count = 0
   279  	var mult = 1
   280  	if it.reverse {
   281  		mult = -1
   282  	}
   283  	for valid {
   284  		if it.done {
   285  			for i, j := range it.whichBlock {
   286  				it.lastIndexArr[i] = it.fitArr[j].lastIndex
   287  			}
   288  			return -1, 0, noopError{}
   289  		}
   290  		for _, f := range it.fitArr {
   291  			f.Next()
   292  			it.done = it.done || f.done
   293  		}
   294  		count++
   295  		valid = !it.mask[it.fit0.lastIndex]
   296  	}
   297  	return it.fit0.lastIndex, mult * count, nil
   298  }
   299  
   300  // Coord returns the next coordinate.
   301  // When Next() is called, the coordinates are updated AFTER the Next() returned.
   302  // See example for more details.
   303  func (it *MultIterator) Coord() []int {
   304  	return it.fit0.track
   305  }
   306  
   307  // Reset resets the iterator state.
   308  func (it *MultIterator) Reset() {
   309  	for _, f := range it.fitArr {
   310  		f.Reset()
   311  	}
   312  	for i, j := range it.whichBlock {
   313  		it.lastIndexArr[i] = it.fitArr[j].lastIndex
   314  	}
   315  	it.done = false
   316  }
   317  
   318  // LastIndex returns index of requested iterator
   319  func (it *MultIterator) LastIndex(j int) int {
   320  	return it.lastIndexArr[j]
   321  }
   322  
   323  /*
   324  // Chan returns a channel of ints. This is useful for iterating multiple Tensors at the same time.
   325  func (it *FlatIterator) Chan() (retVal chan int) {
   326  	retVal = make(chan int)
   327  
   328  	go func() {
   329  		for next, err := it.Next(); err == nil; next, err = it.Next() {
   330  			retVal <- next
   331  		}
   332  		close(retVal)
   333  	}()
   334  
   335  	return
   336  }
   337  
   338  */