gorgonia.org/tensor@v0.9.24/iterator_test.go (about) 1 package tensor 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 ) 8 9 // newAP is a helper function now 10 func newAP(shape Shape, strides []int) *AP { 11 ap := MakeAP(shape, strides, 0, 0) 12 return &ap 13 } 14 15 var flatIterTests1 = []struct { 16 shape Shape 17 strides []int 18 19 correct []int 20 }{ 21 {ScalarShape(), []int{}, []int{0}}, // scalar 22 {Shape{5}, []int{1}, []int{0, 1, 2, 3, 4}}, // vector 23 {Shape{5, 1}, []int{1, 1}, []int{0, 1, 2, 3, 4}}, // colvec 24 {Shape{1, 5}, []int{5, 1}, []int{0, 1, 2, 3, 4}}, // rowvec 25 {Shape{2, 3}, []int{3, 1}, []int{0, 1, 2, 3, 4, 5}}, // basic mat 26 {Shape{3, 2}, []int{1, 3}, []int{0, 3, 1, 4, 2, 5}}, // basic mat, transposed 27 {Shape{2}, []int{2}, []int{0, 2}}, // basic 2x2 mat, sliced: Mat[:, 1] 28 {Shape{2, 2}, []int{5, 1}, []int{0, 1, 5, 6}}, // basic 5x5, sliced: Mat[1:3, 2,4] 29 {Shape{2, 2}, []int{1, 5}, []int{0, 5, 1, 6}}, // basic 5x5, sliced: Mat[1:3, 2,4] then transposed 30 31 {Shape{2, 3, 4}, []int{12, 4, 1}, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, // basic 3-Tensor 32 {Shape{2, 4, 3}, []int{12, 1, 4}, []int{0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23}}, // basic 3-Tensor (under (0, 2, 1) transpose) 33 {Shape{4, 2, 3}, []int{1, 12, 4}, []int{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}}, // basic 3-Tensor (under (2, 0, 1) transpose) 34 {Shape{3, 2, 4}, []int{4, 12, 1}, []int{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}}, // basic 3-Tensor (under (1, 0, 2) transpose) 35 {Shape{4, 3, 2}, []int{1, 4, 12}, []int{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, // basic 3-Tensor (under (2, 1, 0) transpose) 36 37 // ARTIFICIAL CASES - TODO 38 // These cases should be impossible to reach in normal operation 39 // You would have to specially construct these 40 // {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}}, // rowvec - NEARLY IMPOSSIBLE CASE- TODO 41 } 42 43 var flatIterSlices = []struct { 44 slices []Slice 45 corrects [][]int 46 }{ 47 {[]Slice{nil}, [][]int{{0}}}, 48 {[]Slice{rs{0, 3, 1}, rs{0, 5, 2}, rs{0, 6, -1}}, [][]int{{0, 1, 2}, {0, 2, 4}, {4, 3, 2, 1, 0}}}, 49 } 50 51 func TestFlatIterator(t *testing.T) { 52 assert := assert.New(t) 53 54 var ap *AP 55 var it *FlatIterator 56 var err error 57 var nexts []int 58 59 // basic stuff 60 for i, fit := range flatIterTests1 { 61 nexts = nexts[:0] 62 err = nil 63 ap = newAP(fit.shape, fit.strides) 64 it = newFlatIterator(ap) 65 for next, err := it.Next(); err == nil; next, err = it.Next() { 66 nexts = append(nexts, next) 67 } 68 if _, ok := err.(NoOpError); err != nil && !ok { 69 t.Error(err) 70 } 71 assert.Equal(fit.correct, nexts, "Test %d", i) 72 } 73 } 74 75 func TestFlatIteratorReverse(t *testing.T) { 76 assert := assert.New(t) 77 78 var ap *AP 79 var it *FlatIterator 80 var err error 81 var nexts []int 82 83 // basic stuff 84 for i, fit := range flatIterTests1 { 85 nexts = nexts[:0] 86 err = nil 87 ap = newAP(fit.shape, fit.strides) 88 it = newFlatIterator(ap) 89 it.SetReverse() 90 for next, err := it.Next(); err == nil; next, err = it.Next() { 91 nexts = append(nexts, next) 92 } 93 if _, ok := err.(NoOpError); err != nil && !ok { 94 t.Error(err) 95 } 96 // reverse slice 97 for i, j := 0, len(nexts)-1; i < j; i, j = i+1, j-1 { 98 nexts[i], nexts[j] = nexts[j], nexts[i] 99 } 100 // and then check 101 assert.Equal(fit.correct, nexts, "Test %d", i) 102 } 103 } 104 105 func TestMultIterator(t *testing.T) { 106 assert := assert.New(t) 107 108 var ap []*AP 109 var it *MultIterator 110 var err error 111 var nexts [][]int 112 113 doReverse := []bool{false, true} 114 for _, reverse := range doReverse { 115 ap = make([]*AP, 6) 116 nexts = make([][]int, 6) 117 118 // Repeat flat tests 119 for i, fit := range flatIterTests1 { 120 nexts[0] = nexts[0][:0] 121 err = nil 122 ap[0] = newAP(fit.shape, fit.strides) 123 it = NewMultIterator(ap[0]) 124 if reverse { 125 it.SetReverse() 126 } 127 for next, err := it.Next(); err == nil; next, err = it.Next() { 128 nexts[0] = append(nexts[0], next) 129 } 130 if _, ok := err.(NoOpError); err != nil && !ok { 131 t.Error(err) 132 } 133 if reverse { 134 for i, j := 0, len(nexts[0])-1; i < j; i, j = i+1, j-1 { 135 nexts[0][i], nexts[0][j] = nexts[0][j], nexts[0][i] 136 } 137 } 138 assert.Equal(fit.correct, nexts[0], "Repeating flat test %d. Reverse? %v", i, reverse) 139 } 140 // Test multiple iterators simultaneously 141 /* 142 var choices = []int{0, 0, 9, 9, 0, 9} 143 for j := 0; j < 6; j++ { 144 fit := flatIterTests1[choices[j]] 145 nexts[j] = nexts[j][:0] 146 err = nil 147 ap[j] = newAP(fit.shape, fit.strides) 148 } 149 it = NewMultIterator(ap...) 150 if reverse { 151 it.SetReverse() 152 } 153 for _, err := it.Next(); err == nil; _, err = it.Next() { 154 for j := 0; j < 6; j++ { 155 nexts[j] = append(nexts[j], it.LastIndex(j)) 156 } 157 158 if _, ok := err.(NoOpError); err != nil && !ok { 159 t.Error(err) 160 } 161 } 162 163 for j := 0; j < 6; j++ { 164 fit := flatIterTests1[choices[j]] 165 if reverse { 166 for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 { 167 nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i] 168 } 169 } 170 if ap[j].IsScalar() { 171 assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j) 172 } else { 173 assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j) 174 } 175 } 176 */ 177 } 178 179 } 180 181 func TestIteratorInterface(t *testing.T) { 182 assert := assert.New(t) 183 184 var ap *AP 185 var it Iterator 186 var err error 187 var nexts []int 188 189 // basic stuff 190 for i, fit := range flatIterTests1 { 191 nexts = nexts[:0] 192 err = nil 193 ap = newAP(fit.shape, fit.strides) 194 it = NewIterator(ap) 195 for next, err := it.Start(); err == nil; next, err = it.Next() { 196 nexts = append(nexts, next) 197 } 198 if _, ok := err.(NoOpError); err != nil && !ok { 199 t.Error(err) 200 } 201 assert.Equal(fit.correct, nexts, "Test %d", i) 202 } 203 } 204 205 func TestMultIteratorFromDense(t *testing.T) { 206 assert := assert.New(t) 207 208 T1 := New(Of(Int), WithShape(3, 20)) 209 data1 := T1.Data().([]int) 210 T2 := New(Of(Int), WithShape(3, 20)) 211 data2 := T2.Data().([]int) 212 T3 := New(Of(Int), FromScalar(7)) 213 data3 := T3.Data().(int) 214 215 for i := 0; i < 60; i++ { 216 data1[i] = i 217 data2[i] = 7 * i 218 } 219 it := MultIteratorFromDense(T1, T2, T3) 220 221 for _, err := it.Next(); err == nil; _, err = it.Next() { 222 x := data1[it.LastIndex(0)] 223 y := data2[it.LastIndex(1)] 224 z := data3 225 assert.True(y == x*z) 226 } 227 } 228 229 func TestFlatIterator_Chan(t *testing.T) { 230 assert := assert.New(t) 231 232 var ap *AP 233 var it *FlatIterator 234 var nexts []int 235 236 // basic stuff 237 for i, fit := range flatIterTests1 { 238 nexts = nexts[:0] 239 ap = newAP(fit.shape, fit.strides) 240 it = newFlatIterator(ap) 241 ch := it.Chan() 242 for next := range ch { 243 nexts = append(nexts, next) 244 } 245 assert.Equal(fit.correct, nexts, "Test %d", i) 246 } 247 } 248 249 func TestFlatIterator_Slice(t *testing.T) { 250 assert := assert.New(t) 251 252 var ap *AP 253 var it *FlatIterator 254 var err error 255 var nexts []int 256 257 for i, fit := range flatIterTests1 { 258 ap = newAP(fit.shape, fit.strides) 259 it = newFlatIterator(ap) 260 nexts, err = it.Slice(nil) 261 if _, ok := err.(NoOpError); err != nil && !ok { 262 t.Error(err) 263 } 264 265 assert.Equal(fit.correct, nexts, "Test %d", i) 266 267 if i < len(flatIterSlices) { 268 fis := flatIterSlices[i] 269 for j, sli := range fis.slices { 270 it.Reset() 271 272 nexts, err = it.Slice(sli) 273 if _, ok := err.(NoOpError); err != nil && !ok { 274 t.Error(err) 275 } 276 277 assert.Equal(fis.corrects[j], nexts, "Test %d", i) 278 } 279 } 280 } 281 } 282 283 func TestFlatIterator_Coord(t *testing.T) { 284 assert := assert.New(t) 285 286 var ap *AP 287 var it *FlatIterator 288 var err error 289 // var nexts []int 290 var donecount int 291 292 ap = newAP(Shape{2, 3, 4}, []int{12, 4, 1}) 293 it = newFlatIterator(ap) 294 295 var correct = [][]int{ 296 {0, 0, 1}, 297 {0, 0, 2}, 298 {0, 0, 3}, 299 {0, 1, 0}, 300 {0, 1, 1}, 301 {0, 1, 2}, 302 {0, 1, 3}, 303 {0, 2, 0}, 304 {0, 2, 1}, 305 {0, 2, 2}, 306 {0, 2, 3}, 307 {1, 0, 0}, 308 {1, 0, 1}, 309 {1, 0, 2}, 310 {1, 0, 3}, 311 {1, 1, 0}, 312 {1, 1, 1}, 313 {1, 1, 2}, 314 {1, 1, 3}, 315 {1, 2, 0}, 316 {1, 2, 1}, 317 {1, 2, 2}, 318 {1, 2, 3}, 319 {0, 0, 0}, 320 } 321 322 for _, err = it.Next(); err == nil; _, err = it.Next() { 323 assert.Equal(correct[donecount], it.Coord()) 324 donecount++ 325 } 326 } 327 328 // really this is just for completeness sake 329 func TestFlatIterator_Reset(t *testing.T) { 330 assert := assert.New(t) 331 ap := newAP(Shape{2, 3, 4}, []int{12, 4, 1}) 332 it := newFlatIterator(ap) 333 334 it.Next() 335 it.Next() 336 it.Reset() 337 assert.Equal(0, it.nextIndex) 338 assert.Equal(false, it.done) 339 assert.Equal([]int{0, 0, 0}, it.track) 340 341 for _, err := it.Next(); err == nil; _, err = it.Next() { 342 } 343 344 it.Reset() 345 assert.Equal(0, it.nextIndex) 346 assert.Equal(false, it.done) 347 assert.Equal([]int{0, 0, 0}, it.track) 348 } 349 350 func TestDestroyIterator(t *testing.T) { 351 it := new(MultIterator) 352 destroyIterator(it) 353 } 354 355 /* BENCHMARK */ 356 type oldFlatIterator struct { 357 *AP 358 359 //state 360 lastIndex int 361 track []int 362 done bool 363 } 364 365 // newFlatIterator creates a new FlatIterator 366 func newOldFlatIterator(ap *AP) *oldFlatIterator { 367 return &oldFlatIterator{ 368 AP: ap, 369 track: make([]int, len(ap.shape)), 370 } 371 } 372 373 func (it *oldFlatIterator) Next() (int, error) { 374 if it.done { 375 return -1, noopError{} 376 } 377 378 retVal, err := Ltoi(it.shape, it.strides, it.track...) 379 it.lastIndex = retVal 380 381 if it.IsScalar() { 382 it.done = true 383 return retVal, err 384 } 385 386 for d := len(it.shape) - 1; d >= 0; d-- { 387 if d == 0 && it.track[0]+1 >= it.shape[0] { 388 it.done = true 389 it.track[d] = 0 // overflow it 390 break 391 } 392 393 if it.track[d] < it.shape[d]-1 { 394 it.track[d]++ 395 break 396 } 397 // overflow 398 it.track[d] = 0 399 } 400 401 return retVal, err 402 } 403 404 func (it *oldFlatIterator) Reset() { 405 it.done = false 406 it.lastIndex = 0 407 408 if it.done { 409 return 410 } 411 412 for i := range it.track { 413 it.track[i] = 0 414 } 415 } 416 417 func BenchmarkOldFlatIterator(b *testing.B) { 418 var err error 419 420 // as if T = NewTensor(WithShape(30, 1000, 1000)) 421 // then T[:, 0:900:15, 250:750:50] 422 ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) 423 it := newOldFlatIterator(ap) 424 425 for n := 0; n < b.N; n++ { 426 for _, err := it.Next(); err == nil; _, err = it.Next() { 427 428 } 429 if _, ok := err.(NoOpError); err != nil && !ok { 430 b.Error(err) 431 } 432 433 it.Reset() 434 } 435 } 436 437 func BenchmarkFlatIterator(b *testing.B) { 438 var err error 439 440 // as if T = NewTensor(WithShape(30, 1000, 1000)) 441 // then T[:, 0:900:15, 250:750:50] 442 ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) 443 it := newFlatIterator(ap) 444 445 for n := 0; n < b.N; n++ { 446 for _, err := it.Next(); err == nil; _, err = it.Next() { 447 448 } 449 if _, ok := err.(NoOpError); err != nil && !ok { 450 b.Error(err) 451 } 452 453 it.Reset() 454 } 455 } 456 457 func BenchmarkFlatIteratorParallel6(b *testing.B) { 458 var err error 459 460 // as if T = NewTensor(WithShape(30, 1000, 1000)) 461 // then T[:, 0:900:15, 250:750:50] 462 ap := make([]*AP, 6) 463 it := make([]*FlatIterator, 6) 464 465 for j := 0; j < 6; j++ { 466 ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) 467 it[j] = newFlatIterator(ap[j]) 468 } 469 470 for n := 0; n < b.N; n++ { 471 for _, err := it[0].Next(); err == nil; _, err = it[0].Next() { 472 for j := 1; j < 6; j++ { 473 it[j].Next() 474 } 475 476 } 477 if _, ok := err.(NoOpError); err != nil && !ok { 478 b.Error(err) 479 } 480 for j := 0; j < 6; j++ { 481 it[j].Reset() 482 } 483 } 484 485 } 486 487 func BenchmarkFlatIteratorMulti1(b *testing.B) { 488 var err error 489 490 // as if T = NewTensor(WithShape(30, 1000, 1000)) 491 // then T[:, 0:900:15, 250:750:50] 492 ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) 493 494 it := NewMultIterator(ap) 495 496 for n := 0; n < b.N; n++ { 497 for _, err := it.Next(); err == nil; _, err = it.Next() { 498 499 } 500 if _, ok := err.(NoOpError); err != nil && !ok { 501 b.Error(err) 502 } 503 it.Reset() 504 } 505 } 506 507 func BenchmarkFlatIteratorGeneric1(b *testing.B) { 508 var err error 509 510 // as if T = NewTensor(WithShape(30, 1000, 1000)) 511 // then T[:, 0:900:15, 250:750:50] 512 ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) 513 514 it := NewIterator(ap) 515 516 for n := 0; n < b.N; n++ { 517 for _, err := it.Next(); err == nil; _, err = it.Next() { 518 519 } 520 if _, ok := err.(NoOpError); err != nil && !ok { 521 b.Error(err) 522 } 523 it.Reset() 524 } 525 } 526 527 func BenchmarkFlatIteratorMulti6(b *testing.B) { 528 var err error 529 530 // as if T = NewTensor(WithShape(30, 1000, 1000)) 531 // then T[:, 0:900:15, 250:750:50] 532 ap := make([]*AP, 6) 533 534 for j := 0; j < 6; j++ { 535 ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) 536 } 537 538 it := NewMultIterator(ap...) 539 540 for n := 0; n < b.N; n++ { 541 for _, err := it.Next(); err == nil; _, err = it.Next() { 542 543 } 544 if _, ok := err.(NoOpError); err != nil && !ok { 545 b.Error(err) 546 } 547 it.Reset() 548 } 549 }