gorgonia.org/tensor@v0.9.24/iterator_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  )
     8  
     9  // newAP is  a helper function now
    10  func newAP(shape Shape, strides []int) *AP {
    11  	ap := MakeAP(shape, strides, 0, 0)
    12  	return &ap
    13  }
    14  
    15  var flatIterTests1 = []struct {
    16  	shape   Shape
    17  	strides []int
    18  
    19  	correct []int
    20  }{
    21  	{ScalarShape(), []int{}, []int{0}},                  // scalar
    22  	{Shape{5}, []int{1}, []int{0, 1, 2, 3, 4}},          // vector
    23  	{Shape{5, 1}, []int{1, 1}, []int{0, 1, 2, 3, 4}},    // colvec
    24  	{Shape{1, 5}, []int{5, 1}, []int{0, 1, 2, 3, 4}},    // rowvec
    25  	{Shape{2, 3}, []int{3, 1}, []int{0, 1, 2, 3, 4, 5}}, // basic mat
    26  	{Shape{3, 2}, []int{1, 3}, []int{0, 3, 1, 4, 2, 5}}, // basic mat, transposed
    27  	{Shape{2}, []int{2}, []int{0, 2}},                   // basic 2x2 mat, sliced: Mat[:, 1]
    28  	{Shape{2, 2}, []int{5, 1}, []int{0, 1, 5, 6}},       // basic 5x5, sliced: Mat[1:3, 2,4]
    29  	{Shape{2, 2}, []int{1, 5}, []int{0, 5, 1, 6}},       // basic 5x5, sliced: Mat[1:3, 2,4] then transposed
    30  
    31  	{Shape{2, 3, 4}, []int{12, 4, 1}, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, // basic 3-Tensor
    32  	{Shape{2, 4, 3}, []int{12, 1, 4}, []int{0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23}}, // basic 3-Tensor (under (0, 2, 1) transpose)
    33  	{Shape{4, 2, 3}, []int{1, 12, 4}, []int{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}}, // basic 3-Tensor (under (2, 0, 1) transpose)
    34  	{Shape{3, 2, 4}, []int{4, 12, 1}, []int{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}}, // basic 3-Tensor (under (1, 0, 2) transpose)
    35  	{Shape{4, 3, 2}, []int{1, 4, 12}, []int{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, // basic 3-Tensor (under (2, 1, 0) transpose)
    36  
    37  	// ARTIFICIAL CASES - TODO
    38  	// These cases should be impossible to reach in normal operation
    39  	// You would have to specially construct these
    40  	// {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}},       // rowvec - NEARLY IMPOSSIBLE CASE- TODO
    41  }
    42  
    43  var flatIterSlices = []struct {
    44  	slices   []Slice
    45  	corrects [][]int
    46  }{
    47  	{[]Slice{nil}, [][]int{{0}}},
    48  	{[]Slice{rs{0, 3, 1}, rs{0, 5, 2}, rs{0, 6, -1}}, [][]int{{0, 1, 2}, {0, 2, 4}, {4, 3, 2, 1, 0}}},
    49  }
    50  
    51  func TestFlatIterator(t *testing.T) {
    52  	assert := assert.New(t)
    53  
    54  	var ap *AP
    55  	var it *FlatIterator
    56  	var err error
    57  	var nexts []int
    58  
    59  	// basic stuff
    60  	for i, fit := range flatIterTests1 {
    61  		nexts = nexts[:0]
    62  		err = nil
    63  		ap = newAP(fit.shape, fit.strides)
    64  		it = newFlatIterator(ap)
    65  		for next, err := it.Next(); err == nil; next, err = it.Next() {
    66  			nexts = append(nexts, next)
    67  		}
    68  		if _, ok := err.(NoOpError); err != nil && !ok {
    69  			t.Error(err)
    70  		}
    71  		assert.Equal(fit.correct, nexts, "Test %d", i)
    72  	}
    73  }
    74  
    75  func TestFlatIteratorReverse(t *testing.T) {
    76  	assert := assert.New(t)
    77  
    78  	var ap *AP
    79  	var it *FlatIterator
    80  	var err error
    81  	var nexts []int
    82  
    83  	// basic stuff
    84  	for i, fit := range flatIterTests1 {
    85  		nexts = nexts[:0]
    86  		err = nil
    87  		ap = newAP(fit.shape, fit.strides)
    88  		it = newFlatIterator(ap)
    89  		it.SetReverse()
    90  		for next, err := it.Next(); err == nil; next, err = it.Next() {
    91  			nexts = append(nexts, next)
    92  		}
    93  		if _, ok := err.(NoOpError); err != nil && !ok {
    94  			t.Error(err)
    95  		}
    96  		// reverse slice
    97  		for i, j := 0, len(nexts)-1; i < j; i, j = i+1, j-1 {
    98  			nexts[i], nexts[j] = nexts[j], nexts[i]
    99  		}
   100  		// and then check
   101  		assert.Equal(fit.correct, nexts, "Test %d", i)
   102  	}
   103  }
   104  
   105  func TestMultIterator(t *testing.T) {
   106  	assert := assert.New(t)
   107  
   108  	var ap []*AP
   109  	var it *MultIterator
   110  	var err error
   111  	var nexts [][]int
   112  
   113  	doReverse := []bool{false, true}
   114  	for _, reverse := range doReverse {
   115  		ap = make([]*AP, 6)
   116  		nexts = make([][]int, 6)
   117  
   118  		// Repeat flat tests
   119  		for i, fit := range flatIterTests1 {
   120  			nexts[0] = nexts[0][:0]
   121  			err = nil
   122  			ap[0] = newAP(fit.shape, fit.strides)
   123  			it = NewMultIterator(ap[0])
   124  			if reverse {
   125  				it.SetReverse()
   126  			}
   127  			for next, err := it.Next(); err == nil; next, err = it.Next() {
   128  				nexts[0] = append(nexts[0], next)
   129  			}
   130  			if _, ok := err.(NoOpError); err != nil && !ok {
   131  				t.Error(err)
   132  			}
   133  			if reverse {
   134  				for i, j := 0, len(nexts[0])-1; i < j; i, j = i+1, j-1 {
   135  					nexts[0][i], nexts[0][j] = nexts[0][j], nexts[0][i]
   136  				}
   137  			}
   138  			assert.Equal(fit.correct, nexts[0], "Repeating flat test %d. Reverse? %v", i, reverse)
   139  		}
   140  		// Test multiple iterators simultaneously
   141  		/*
   142  			var choices = []int{0, 0, 9, 9, 0, 9}
   143  			for j := 0; j < 6; j++ {
   144  				fit := flatIterTests1[choices[j]]
   145  				nexts[j] = nexts[j][:0]
   146  				err = nil
   147  				ap[j] = newAP(fit.shape, fit.strides)
   148  			}
   149  			it = NewMultIterator(ap...)
   150  			if reverse {
   151  				it.SetReverse()
   152  			}
   153  			for _, err := it.Next(); err == nil; _, err = it.Next() {
   154  				for j := 0; j < 6; j++ {
   155  					nexts[j] = append(nexts[j], it.LastIndex(j))
   156  				}
   157  
   158  				if _, ok := err.(NoOpError); err != nil && !ok {
   159  					t.Error(err)
   160  				}
   161  			}
   162  
   163  			for j := 0; j < 6; j++ {
   164  				fit := flatIterTests1[choices[j]]
   165  				if reverse {
   166  					for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 {
   167  						nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i]
   168  					}
   169  				}
   170  				if ap[j].IsScalar() {
   171  					assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j)
   172  				} else {
   173  					assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j)
   174  				}
   175  			}
   176  		*/
   177  	}
   178  
   179  }
   180  
   181  func TestIteratorInterface(t *testing.T) {
   182  	assert := assert.New(t)
   183  
   184  	var ap *AP
   185  	var it Iterator
   186  	var err error
   187  	var nexts []int
   188  
   189  	// basic stuff
   190  	for i, fit := range flatIterTests1 {
   191  		nexts = nexts[:0]
   192  		err = nil
   193  		ap = newAP(fit.shape, fit.strides)
   194  		it = NewIterator(ap)
   195  		for next, err := it.Start(); err == nil; next, err = it.Next() {
   196  			nexts = append(nexts, next)
   197  		}
   198  		if _, ok := err.(NoOpError); err != nil && !ok {
   199  			t.Error(err)
   200  		}
   201  		assert.Equal(fit.correct, nexts, "Test %d", i)
   202  	}
   203  }
   204  
   205  func TestMultIteratorFromDense(t *testing.T) {
   206  	assert := assert.New(t)
   207  
   208  	T1 := New(Of(Int), WithShape(3, 20))
   209  	data1 := T1.Data().([]int)
   210  	T2 := New(Of(Int), WithShape(3, 20))
   211  	data2 := T2.Data().([]int)
   212  	T3 := New(Of(Int), FromScalar(7))
   213  	data3 := T3.Data().(int)
   214  
   215  	for i := 0; i < 60; i++ {
   216  		data1[i] = i
   217  		data2[i] = 7 * i
   218  	}
   219  	it := MultIteratorFromDense(T1, T2, T3)
   220  
   221  	for _, err := it.Next(); err == nil; _, err = it.Next() {
   222  		x := data1[it.LastIndex(0)]
   223  		y := data2[it.LastIndex(1)]
   224  		z := data3
   225  		assert.True(y == x*z)
   226  	}
   227  }
   228  
   229  func TestFlatIterator_Chan(t *testing.T) {
   230  	assert := assert.New(t)
   231  
   232  	var ap *AP
   233  	var it *FlatIterator
   234  	var nexts []int
   235  
   236  	// basic stuff
   237  	for i, fit := range flatIterTests1 {
   238  		nexts = nexts[:0]
   239  		ap = newAP(fit.shape, fit.strides)
   240  		it = newFlatIterator(ap)
   241  		ch := it.Chan()
   242  		for next := range ch {
   243  			nexts = append(nexts, next)
   244  		}
   245  		assert.Equal(fit.correct, nexts, "Test %d", i)
   246  	}
   247  }
   248  
   249  func TestFlatIterator_Slice(t *testing.T) {
   250  	assert := assert.New(t)
   251  
   252  	var ap *AP
   253  	var it *FlatIterator
   254  	var err error
   255  	var nexts []int
   256  
   257  	for i, fit := range flatIterTests1 {
   258  		ap = newAP(fit.shape, fit.strides)
   259  		it = newFlatIterator(ap)
   260  		nexts, err = it.Slice(nil)
   261  		if _, ok := err.(NoOpError); err != nil && !ok {
   262  			t.Error(err)
   263  		}
   264  
   265  		assert.Equal(fit.correct, nexts, "Test %d", i)
   266  
   267  		if i < len(flatIterSlices) {
   268  			fis := flatIterSlices[i]
   269  			for j, sli := range fis.slices {
   270  				it.Reset()
   271  
   272  				nexts, err = it.Slice(sli)
   273  				if _, ok := err.(NoOpError); err != nil && !ok {
   274  					t.Error(err)
   275  				}
   276  
   277  				assert.Equal(fis.corrects[j], nexts, "Test %d", i)
   278  			}
   279  		}
   280  	}
   281  }
   282  
   283  func TestFlatIterator_Coord(t *testing.T) {
   284  	assert := assert.New(t)
   285  
   286  	var ap *AP
   287  	var it *FlatIterator
   288  	var err error
   289  	// var nexts []int
   290  	var donecount int
   291  
   292  	ap = newAP(Shape{2, 3, 4}, []int{12, 4, 1})
   293  	it = newFlatIterator(ap)
   294  
   295  	var correct = [][]int{
   296  		{0, 0, 1},
   297  		{0, 0, 2},
   298  		{0, 0, 3},
   299  		{0, 1, 0},
   300  		{0, 1, 1},
   301  		{0, 1, 2},
   302  		{0, 1, 3},
   303  		{0, 2, 0},
   304  		{0, 2, 1},
   305  		{0, 2, 2},
   306  		{0, 2, 3},
   307  		{1, 0, 0},
   308  		{1, 0, 1},
   309  		{1, 0, 2},
   310  		{1, 0, 3},
   311  		{1, 1, 0},
   312  		{1, 1, 1},
   313  		{1, 1, 2},
   314  		{1, 1, 3},
   315  		{1, 2, 0},
   316  		{1, 2, 1},
   317  		{1, 2, 2},
   318  		{1, 2, 3},
   319  		{0, 0, 0},
   320  	}
   321  
   322  	for _, err = it.Next(); err == nil; _, err = it.Next() {
   323  		assert.Equal(correct[donecount], it.Coord())
   324  		donecount++
   325  	}
   326  }
   327  
   328  // really this is just for completeness sake
   329  func TestFlatIterator_Reset(t *testing.T) {
   330  	assert := assert.New(t)
   331  	ap := newAP(Shape{2, 3, 4}, []int{12, 4, 1})
   332  	it := newFlatIterator(ap)
   333  
   334  	it.Next()
   335  	it.Next()
   336  	it.Reset()
   337  	assert.Equal(0, it.nextIndex)
   338  	assert.Equal(false, it.done)
   339  	assert.Equal([]int{0, 0, 0}, it.track)
   340  
   341  	for _, err := it.Next(); err == nil; _, err = it.Next() {
   342  	}
   343  
   344  	it.Reset()
   345  	assert.Equal(0, it.nextIndex)
   346  	assert.Equal(false, it.done)
   347  	assert.Equal([]int{0, 0, 0}, it.track)
   348  }
   349  
   350  func TestDestroyIterator(t *testing.T) {
   351  	it := new(MultIterator)
   352  	destroyIterator(it)
   353  }
   354  
   355  /* BENCHMARK */
   356  type oldFlatIterator struct {
   357  	*AP
   358  
   359  	//state
   360  	lastIndex int
   361  	track     []int
   362  	done      bool
   363  }
   364  
   365  // newFlatIterator creates a new FlatIterator
   366  func newOldFlatIterator(ap *AP) *oldFlatIterator {
   367  	return &oldFlatIterator{
   368  		AP:    ap,
   369  		track: make([]int, len(ap.shape)),
   370  	}
   371  }
   372  
   373  func (it *oldFlatIterator) Next() (int, error) {
   374  	if it.done {
   375  		return -1, noopError{}
   376  	}
   377  
   378  	retVal, err := Ltoi(it.shape, it.strides, it.track...)
   379  	it.lastIndex = retVal
   380  
   381  	if it.IsScalar() {
   382  		it.done = true
   383  		return retVal, err
   384  	}
   385  
   386  	for d := len(it.shape) - 1; d >= 0; d-- {
   387  		if d == 0 && it.track[0]+1 >= it.shape[0] {
   388  			it.done = true
   389  			it.track[d] = 0 // overflow it
   390  			break
   391  		}
   392  
   393  		if it.track[d] < it.shape[d]-1 {
   394  			it.track[d]++
   395  			break
   396  		}
   397  		// overflow
   398  		it.track[d] = 0
   399  	}
   400  
   401  	return retVal, err
   402  }
   403  
   404  func (it *oldFlatIterator) Reset() {
   405  	it.done = false
   406  	it.lastIndex = 0
   407  
   408  	if it.done {
   409  		return
   410  	}
   411  
   412  	for i := range it.track {
   413  		it.track[i] = 0
   414  	}
   415  }
   416  
   417  func BenchmarkOldFlatIterator(b *testing.B) {
   418  	var err error
   419  
   420  	// as if T = NewTensor(WithShape(30, 1000, 1000))
   421  	// then T[:, 0:900:15, 250:750:50]
   422  	ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50})
   423  	it := newOldFlatIterator(ap)
   424  
   425  	for n := 0; n < b.N; n++ {
   426  		for _, err := it.Next(); err == nil; _, err = it.Next() {
   427  
   428  		}
   429  		if _, ok := err.(NoOpError); err != nil && !ok {
   430  			b.Error(err)
   431  		}
   432  
   433  		it.Reset()
   434  	}
   435  }
   436  
   437  func BenchmarkFlatIterator(b *testing.B) {
   438  	var err error
   439  
   440  	// as if T = NewTensor(WithShape(30, 1000, 1000))
   441  	// then T[:, 0:900:15, 250:750:50]
   442  	ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50})
   443  	it := newFlatIterator(ap)
   444  
   445  	for n := 0; n < b.N; n++ {
   446  		for _, err := it.Next(); err == nil; _, err = it.Next() {
   447  
   448  		}
   449  		if _, ok := err.(NoOpError); err != nil && !ok {
   450  			b.Error(err)
   451  		}
   452  
   453  		it.Reset()
   454  	}
   455  }
   456  
   457  func BenchmarkFlatIteratorParallel6(b *testing.B) {
   458  	var err error
   459  
   460  	// as if T = NewTensor(WithShape(30, 1000, 1000))
   461  	// then T[:, 0:900:15, 250:750:50]
   462  	ap := make([]*AP, 6)
   463  	it := make([]*FlatIterator, 6)
   464  
   465  	for j := 0; j < 6; j++ {
   466  		ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50})
   467  		it[j] = newFlatIterator(ap[j])
   468  	}
   469  
   470  	for n := 0; n < b.N; n++ {
   471  		for _, err := it[0].Next(); err == nil; _, err = it[0].Next() {
   472  			for j := 1; j < 6; j++ {
   473  				it[j].Next()
   474  			}
   475  
   476  		}
   477  		if _, ok := err.(NoOpError); err != nil && !ok {
   478  			b.Error(err)
   479  		}
   480  		for j := 0; j < 6; j++ {
   481  			it[j].Reset()
   482  		}
   483  	}
   484  
   485  }
   486  
   487  func BenchmarkFlatIteratorMulti1(b *testing.B) {
   488  	var err error
   489  
   490  	// as if T = NewTensor(WithShape(30, 1000, 1000))
   491  	// then T[:, 0:900:15, 250:750:50]
   492  	ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50})
   493  
   494  	it := NewMultIterator(ap)
   495  
   496  	for n := 0; n < b.N; n++ {
   497  		for _, err := it.Next(); err == nil; _, err = it.Next() {
   498  
   499  		}
   500  		if _, ok := err.(NoOpError); err != nil && !ok {
   501  			b.Error(err)
   502  		}
   503  		it.Reset()
   504  	}
   505  }
   506  
   507  func BenchmarkFlatIteratorGeneric1(b *testing.B) {
   508  	var err error
   509  
   510  	// as if T = NewTensor(WithShape(30, 1000, 1000))
   511  	// then T[:, 0:900:15, 250:750:50]
   512  	ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50})
   513  
   514  	it := NewIterator(ap)
   515  
   516  	for n := 0; n < b.N; n++ {
   517  		for _, err := it.Next(); err == nil; _, err = it.Next() {
   518  
   519  		}
   520  		if _, ok := err.(NoOpError); err != nil && !ok {
   521  			b.Error(err)
   522  		}
   523  		it.Reset()
   524  	}
   525  }
   526  
   527  func BenchmarkFlatIteratorMulti6(b *testing.B) {
   528  	var err error
   529  
   530  	// as if T = NewTensor(WithShape(30, 1000, 1000))
   531  	// then T[:, 0:900:15, 250:750:50]
   532  	ap := make([]*AP, 6)
   533  
   534  	for j := 0; j < 6; j++ {
   535  		ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50})
   536  	}
   537  
   538  	it := NewMultIterator(ap...)
   539  
   540  	for n := 0; n < b.N; n++ {
   541  		for _, err := it.Next(); err == nil; _, err = it.Next() {
   542  
   543  		}
   544  		if _, ok := err.(NoOpError); err != nil && !ok {
   545  			b.Error(err)
   546  		}
   547  		it.Reset()
   548  	}
   549  }