github.com/wzzhu/tensor@v0.9.24/iterator_mult.go (about) 1 package tensor 2 3 import ( 4 "runtime" 5 ) 6 7 // MultIterator is an iterator that iterates over multiple tensors, including masked tensors. 8 // It utilizes the *AP of a Tensor to determine what the next index is. 9 // This data structure is similar to Numpy's flatiter, with some standard Go based restrictions of course 10 // (such as, not allowing negative indices) 11 type MultIterator struct { 12 *AP // Uses AP of the largest tensor in list 13 fit0 *FlatIterator //largest fit in fitArr (by AP total size) 14 mask []bool 15 16 numMasked int 17 lastIndexArr []int 18 shape Shape 19 whichBlock []int 20 fitArr []*FlatIterator 21 strides []int 22 23 size int 24 done bool 25 reverse bool 26 } 27 28 func genIterator(m map[int]int, strides []int, idx int) (int, bool) { 29 key := hashIntArray(strides) 30 f, ok := m[key] 31 if !ok { 32 m[key] = idx 33 return idx, ok 34 } 35 return f, ok 36 } 37 38 // NewMultIterator creates a new MultIterator from a list of APs 39 func NewMultIterator(aps ...*AP) *MultIterator { 40 nit := len(aps) 41 if nit < 1 { 42 return nil 43 } 44 for _, ap := range aps { 45 if ap == nil { 46 panic("ap is nil") //TODO: Probably remove this panic 47 } 48 } 49 50 var maxDims int 51 var maxShape = aps[0].shape 52 53 for i := range aps { 54 if aps[i].Dims() >= maxDims { 55 maxDims = aps[i].Dims() 56 if aps[i].Size() > maxShape.TotalSize() { 57 maxShape = aps[i].shape 58 } 59 } 60 61 } 62 63 it := new(MultIterator) 64 65 it.whichBlock = BorrowInts(nit) 66 it.lastIndexArr = BorrowInts(nit) 67 it.strides = BorrowInts(nit * maxDims) 68 69 shape := BorrowInts(len(maxShape)) 70 copy(shape, maxShape) 71 it.shape = shape 72 73 for _, ap := range aps { 74 _, err := BroadcastStrides(shape, ap.shape, it.strides[:maxDims], ap.strides) 75 if err != nil { 76 panic("can not broadcast strides") 77 } 78 } 79 80 for i := range it.strides { 81 it.strides[i] = 0 82 } 83 84 it.fitArr = make([]*FlatIterator, nit) 85 86 //TODO: Convert this make to Borrow perhaps? 87 m := make(map[int]int) 88 89 nBlocks := 0 90 offset := 0 91 for i, ap := range aps { 92 f, ok := genIterator(m, ap.strides, nBlocks) 93 if !ok { 94 offset = nBlocks * maxDims 95 apStrides, _ := BroadcastStrides(shape, ap.shape, it.strides[offset:offset+maxDims], ap.strides) 96 copy(it.strides[offset:offset+maxDims], apStrides) 97 ReturnInts(apStrides) // Borrowed in BroadcastStrides but returned here - dangerous pattern? 98 nBlocks++ 99 } 100 ap2 := MakeAP(it.shape[:maxDims], it.strides[offset:offset+maxDims], ap.o, ap.Δ) 101 it.whichBlock[i] = f 102 it.fitArr[nBlocks-1] = newFlatIterator(&ap2) 103 } 104 105 it.fitArr = it.fitArr[:nBlocks] 106 it.strides = it.strides[:nBlocks*maxDims] 107 // fill 0s with 1s 108 for i := range it.strides { 109 if it.strides[i] == 0 { 110 it.strides[i] = 1 111 } 112 } 113 114 it.fit0 = it.fitArr[0] 115 for _, f := range it.fitArr { 116 if it.fit0.size < f.size { 117 it.fit0 = f 118 it.AP = f.AP 119 } 120 } 121 return it 122 } 123 124 // MultIteratorFromDense creates a new MultIterator from a list of dense tensors 125 func MultIteratorFromDense(tts ...DenseTensor) *MultIterator { 126 aps := make([]*AP, len(tts)) 127 hasMask := BorrowBools(len(tts)) 128 defer ReturnBools(hasMask) 129 130 var masked = false 131 numMasked := 0 132 for i, tt := range tts { 133 aps[i] = tt.Info() 134 if mt, ok := tt.(MaskedTensor); ok { 135 hasMask[i] = mt.IsMasked() 136 } 137 masked = masked || hasMask[i] 138 if hasMask[i] { 139 numMasked++ 140 } 141 } 142 143 it := NewMultIterator(aps...) 144 runtime.SetFinalizer(it, destroyIterator) 145 146 if masked { 147 // create new mask slice if more than tensor is masked 148 if numMasked > 1 { 149 it.mask = BorrowBools(it.shape.TotalSize()) 150 memsetBools(it.mask, false) 151 for i, err := it.Start(); err == nil; i, err = it.Next() { 152 for j, k := range it.lastIndexArr { 153 if hasMask[j] { 154 it.mask[i] = it.mask[i] || tts[j].(MaskedTensor).Mask()[k] 155 } 156 } 157 } 158 } 159 } 160 it.numMasked = numMasked 161 return it 162 } 163 164 // destroyMultIterator returns any borrowed objects back to pool 165 func destroyMultIterator(it *MultIterator) { 166 167 if cap(it.whichBlock) > 0 { 168 ReturnInts(it.whichBlock) 169 it.whichBlock = nil 170 } 171 if cap(it.lastIndexArr) > 0 { 172 ReturnInts(it.lastIndexArr) 173 it.lastIndexArr = nil 174 } 175 if cap(it.strides) > 0 { 176 ReturnInts(it.strides) 177 it.strides = nil 178 } 179 if it.numMasked > 1 { 180 if cap(it.mask) > 0 { 181 ReturnBools(it.mask) 182 it.mask = nil 183 } 184 } 185 } 186 187 // SetReverse initializes iterator to run backward 188 func (it *MultIterator) SetReverse() { 189 for _, f := range it.fitArr { 190 f.SetReverse() 191 } 192 } 193 194 // SetForward initializes iterator to run forward 195 func (it *MultIterator) SetForward() { 196 for _, f := range it.fitArr { 197 f.SetForward() 198 } 199 } 200 201 //Start begins iteration 202 func (it *MultIterator) Start() (int, error) { 203 it.Reset() 204 return it.Next() 205 } 206 207 //Done checks whether iterators are done 208 func (it *MultIterator) Done() bool { 209 for _, f := range it.fitArr { 210 if !f.done { 211 it.done = false 212 return false 213 } 214 } 215 it.done = true 216 return true 217 } 218 219 // Next returns the index of the next coordinate 220 func (it *MultIterator) Next() (int, error) { 221 if it.done { 222 return -1, noopError{} 223 } 224 it.done = false 225 for _, f := range it.fitArr { 226 if _, err := f.Next(); err != nil { 227 return -1, err 228 } 229 it.done = it.done || f.done 230 } 231 for i, j := range it.whichBlock { 232 it.lastIndexArr[i] = it.fitArr[j].lastIndex 233 } 234 return it.fit0.lastIndex, nil 235 } 236 237 func (it *MultIterator) NextValidity() (int, bool, error) { 238 i, err := it.Next() 239 if err != nil { 240 return i, false, err 241 } 242 243 if len(it.mask) == 0 { 244 return i, true, err 245 } 246 return i, it.mask[i], err 247 } 248 249 // NextValid returns the index of the next valid coordinate 250 func (it *MultIterator) NextValid() (int, int, error) { 251 var invalid = true 252 var count int 253 var mult = 1 254 if it.reverse { 255 mult = -1 256 } 257 for invalid { 258 if it.done { 259 for i, j := range it.whichBlock { 260 it.lastIndexArr[i] = it.fitArr[j].lastIndex 261 } 262 return -1, 0, noopError{} 263 } 264 for _, f := range it.fitArr { 265 f.Next() 266 it.done = it.done || f.done 267 } 268 count++ 269 invalid = !it.mask[it.fit0.lastIndex] 270 } 271 return it.fit0.lastIndex, mult * count, nil 272 } 273 274 // NextInvalid returns the index of the next invalid coordinate 275 func (it *MultIterator) NextInvalid() (int, int, error) { 276 var valid = true 277 278 var count = 0 279 var mult = 1 280 if it.reverse { 281 mult = -1 282 } 283 for valid { 284 if it.done { 285 for i, j := range it.whichBlock { 286 it.lastIndexArr[i] = it.fitArr[j].lastIndex 287 } 288 return -1, 0, noopError{} 289 } 290 for _, f := range it.fitArr { 291 f.Next() 292 it.done = it.done || f.done 293 } 294 count++ 295 valid = !it.mask[it.fit0.lastIndex] 296 } 297 return it.fit0.lastIndex, mult * count, nil 298 } 299 300 // Coord returns the next coordinate. 301 // When Next() is called, the coordinates are updated AFTER the Next() returned. 302 // See example for more details. 303 func (it *MultIterator) Coord() []int { 304 return it.fit0.track 305 } 306 307 // Reset resets the iterator state. 308 func (it *MultIterator) Reset() { 309 for _, f := range it.fitArr { 310 f.Reset() 311 } 312 for i, j := range it.whichBlock { 313 it.lastIndexArr[i] = it.fitArr[j].lastIndex 314 } 315 it.done = false 316 } 317 318 // LastIndex returns index of requested iterator 319 func (it *MultIterator) LastIndex(j int) int { 320 return it.lastIndexArr[j] 321 } 322 323 /* 324 // Chan returns a channel of ints. This is useful for iterating multiple Tensors at the same time. 325 func (it *FlatIterator) Chan() (retVal chan int) { 326 retVal = make(chan int) 327 328 go func() { 329 for next, err := it.Next(); err == nil; next, err = it.Next() { 330 retVal <- next 331 } 332 close(retVal) 333 }() 334 335 return 336 } 337 338 */