github.com/wzzhu/tensor@v0.9.24/ap_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	//"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  )
     9  
    10  func dummyScalar1() AP { return AP{} }
    11  
    12  func dummyScalar2() AP { return AP{shape: Shape{1}} }
    13  
    14  func dummyColVec() AP {
    15  	return AP{
    16  		shape:   Shape{5, 1},
    17  		strides: []int{1},
    18  	}
    19  }
    20  
    21  func dummyRowVec() AP {
    22  	return AP{
    23  		shape:   Shape{1, 5},
    24  		strides: []int{1},
    25  	}
    26  }
    27  
    28  func dummyVec() AP {
    29  	return AP{
    30  		shape:   Shape{5},
    31  		strides: []int{1},
    32  	}
    33  }
    34  
    35  func twothree() AP {
    36  	return AP{
    37  		shape:   Shape{2, 3},
    38  		strides: []int{3, 1},
    39  	}
    40  }
    41  
    42  func twothreefour() AP {
    43  	return AP{
    44  		shape:   Shape{2, 3, 4},
    45  		strides: []int{12, 4, 1},
    46  	}
    47  }
    48  
    49  func TestAccessPatternBasics(t *testing.T) {
    50  	assert := assert.New(t)
    51  	ap := new(AP)
    52  
    53  	ap.SetShape(1, 2)
    54  	assert.Equal(Shape{1, 2}, ap.Shape())
    55  	assert.Equal([]int{2, 1}, ap.Strides())
    56  	assert.Equal(2, ap.Dims())
    57  	assert.Equal(2, ap.Size())
    58  
    59  	ap.SetShape(2, 3, 2)
    60  	assert.Equal(Shape{2, 3, 2}, ap.Shape())
    61  	assert.Equal([]int{6, 2, 1}, ap.Strides())
    62  	assert.Equal(12, ap.Size())
    63  
    64  	ap.lock()
    65  	ap.SetShape(1, 2, 3)
    66  	assert.Equal(Shape{2, 3, 2}, ap.shape)
    67  	assert.Equal([]int{6, 2, 1}, ap.strides)
    68  
    69  	ap.unlock()
    70  	ap.SetShape(1, 2)
    71  	assert.Equal(Shape{1, 2}, ap.Shape())
    72  	assert.Equal([]int{2, 1}, ap.Strides())
    73  	assert.Equal(2, ap.Dims())
    74  	assert.Equal(2, ap.Size())
    75  
    76  	if ap.String() != "Shape: (1, 2), Stride: [2 1], Lock: false" {
    77  		t.Errorf("AP formatting error. Got %q", ap.String())
    78  	}
    79  
    80  	ap2 := ap.Clone()
    81  	assert.Equal(*ap, ap2)
    82  }
    83  
    84  func TestAccessPatternIsX(t *testing.T) {
    85  	assert := assert.New(t)
    86  	var ap AP
    87  
    88  	ap = dummyScalar1()
    89  	assert.True(ap.IsScalar())
    90  	assert.True(ap.IsScalarEquiv())
    91  	assert.False(ap.IsVector())
    92  	assert.False(ap.IsColVec())
    93  	assert.False(ap.IsRowVec())
    94  
    95  	ap = dummyScalar2()
    96  	assert.False(ap.IsScalar())
    97  	assert.True(ap.IsScalarEquiv())
    98  	assert.True(ap.IsVectorLike())
    99  	assert.True(ap.IsVector())
   100  	assert.False(ap.IsColVec())
   101  	assert.False(ap.IsRowVec())
   102  
   103  	ap = dummyColVec()
   104  	assert.True(ap.IsColVec())
   105  	assert.True(ap.IsVector())
   106  	assert.False(ap.IsRowVec())
   107  	assert.False(ap.IsScalar())
   108  
   109  	ap = dummyRowVec()
   110  	assert.True(ap.IsRowVec())
   111  	assert.True(ap.IsVector())
   112  	assert.False(ap.IsColVec())
   113  	assert.False(ap.IsScalar())
   114  
   115  	ap = twothree()
   116  	assert.True(ap.IsMatrix())
   117  	assert.False(ap.IsScalar())
   118  	assert.False(ap.IsVector())
   119  	assert.False(ap.IsRowVec())
   120  	assert.False(ap.IsColVec())
   121  
   122  }
   123  
   124  func TestAccessPatternT(t *testing.T) {
   125  	assert := assert.New(t)
   126  	var ap, apT AP
   127  	var axes []int
   128  	var err error
   129  
   130  	ap = twothree()
   131  
   132  	// test no axes
   133  	apT, axes, err = ap.T()
   134  	if err != nil {
   135  		t.Error(err)
   136  	}
   137  
   138  	assert.Equal(Shape{3, 2}, apT.shape)
   139  	assert.Equal([]int{1, 3}, apT.strides)
   140  	assert.Equal([]int{1, 0}, axes)
   141  	assert.Equal(2, apT.Dims())
   142  
   143  	// test no op
   144  	apT, _, err = ap.T(0, 1)
   145  	if err != nil {
   146  		if _, ok := err.(NoOpError); !ok {
   147  			t.Error(err)
   148  		}
   149  	}
   150  
   151  	// test 3D
   152  	ap = twothreefour()
   153  	apT, axes, err = ap.T(2, 0, 1)
   154  	if err != nil {
   155  		t.Error(err)
   156  	}
   157  	assert.Equal(Shape{4, 2, 3}, apT.shape)
   158  	assert.Equal([]int{1, 12, 4}, apT.strides)
   159  	assert.Equal([]int{2, 0, 1}, axes)
   160  	assert.Equal(3, apT.Dims())
   161  
   162  	// test stupid axes
   163  	_, _, err = ap.T(1, 2, 3)
   164  	if err == nil {
   165  		t.Error("Expected an error")
   166  	}
   167  }
   168  
   169  var sliceTests = []struct {
   170  	name   string
   171  	shape  Shape
   172  	slices []Slice
   173  
   174  	correctStart  int
   175  	correctEnd    int
   176  	correctShape  Shape
   177  	correctStride []int
   178  	contiguous    bool
   179  }{
   180  	// vectors
   181  	{"a[0]", Shape{5}, []Slice{S(0)}, 0, 1, ScalarShape(), nil, true},
   182  	{"a[0:2]", Shape{5}, []Slice{S(0, 2)}, 0, 2, Shape{2}, []int{1}, true},
   183  	{"a[1:3]", Shape{5}, []Slice{S(1, 3)}, 1, 3, Shape{2}, []int{1}, true},
   184  	{"a[1:5:2]", Shape{5}, []Slice{S(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false},
   185  
   186  	// matrix
   187  	{"A[0]", Shape{2, 3}, []Slice{S(0)}, 0, 3, Shape{1, 3}, []int{1}, true},
   188  	{"A[1:3]", Shape{4, 5}, []Slice{S(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true},
   189  	{"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{S(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened
   190  	{"A[:, 1:3]", Shape{4, 5}, []Slice{nil, S(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false},
   191  
   192  	// tensor
   193  	{"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, 0, 4, Shape{2, 2}, []int{2, 1}, true},
   194  	{"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, 0, 2, Shape{1, 2}, []int{4, 1}, false},
   195  	{"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true},
   196  	{"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true},
   197  }
   198  
   199  func TestAccessPatternS(t *testing.T) {
   200  	assert := assert.New(t)
   201  	var ap, apS AP
   202  	var ndStart, ndEnd int
   203  	var err error
   204  
   205  	for _, sts := range sliceTests {
   206  		ap = MakeAP(sts.shape, sts.shape.CalcStrides(), 0, 0)
   207  		if apS, ndStart, ndEnd, err = ap.S(sts.shape.TotalSize(), sts.slices...); err != nil {
   208  			t.Errorf("%v errored: %v", sts.name, err)
   209  			continue
   210  		}
   211  		assert.Equal(sts.correctStart, ndStart, "Wrong start: %v. Want %d Got %d", sts.name, sts.correctStart, ndStart)
   212  		assert.Equal(sts.correctEnd, ndEnd, "Wrong end: %v. Want %d Got %d", sts.name, sts.correctEnd, ndEnd)
   213  		assert.True(sts.correctShape.Eq(apS.shape), "Wrong shape: %v. Want %v. Got %v", sts.name, sts.correctShape, apS.shape)
   214  		assert.Equal(sts.correctStride, apS.strides, "Wrong strides: %v. Want %v. Got %v", sts.name, sts.correctStride, apS.strides)
   215  		assert.Equal(sts.contiguous, apS.DataOrder().IsContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous)
   216  	}
   217  }
   218  
   219  func TestTransposeIndex(t *testing.T) {
   220  	var newInd int
   221  	var oldShape Shape
   222  	var pattern, oldStrides, newStrides, corrects []int
   223  
   224  	/*
   225  		(2,3)->(3,2)
   226  		0, 1, 2
   227  		3, 4, 5
   228  
   229  		becomes
   230  
   231  		0, 3
   232  		1, 4
   233  		2, 5
   234  
   235  		1 -> 2
   236  		2 -> 4
   237  		3 -> 1
   238  		4 -> 3
   239  		0 and 5 stay the same
   240  	*/
   241  
   242  	oldShape = Shape{2, 3}
   243  	pattern = []int{1, 0}
   244  	oldStrides = []int{3, 1}
   245  	newStrides = []int{2, 1}
   246  	corrects = []int{0, 2, 4, 1, 3, 5}
   247  	for i := 0; i < 6; i++ {
   248  		newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides)
   249  		if newInd != corrects[i] {
   250  			t.Errorf("Want %d, got %d instead", corrects[i], newInd)
   251  		}
   252  	}
   253  
   254  	/*
   255  		(2,3,4) -(1,0,2)-> (3,2,4)
   256  		0, 1, 2, 3
   257  		4, 5, 6, 7
   258  		8, 9, 10, 11
   259  
   260  		12, 13, 14, 15
   261  		16, 17, 18, 19
   262  		20, 21, 22, 23
   263  
   264  		becomes
   265  
   266  		0,   1,  2,  3
   267  		12, 13, 14, 15,
   268  
   269  		4,   5,  6,  7
   270  		16, 17, 18, 19
   271  
   272  		8,   9, 10, 11
   273  		20, 21, 22, 23
   274  	*/
   275  	oldShape = Shape{2, 3, 4}
   276  	pattern = []int{1, 0, 2}
   277  	oldStrides = []int{12, 4, 1}
   278  	newStrides = []int{8, 4, 1}
   279  	corrects = []int{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23}
   280  	for i := 0; i < len(corrects); i++ {
   281  		newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides)
   282  		if newInd != corrects[i] {
   283  			t.Errorf("Want %d, got %d instead", corrects[i], newInd)
   284  		}
   285  	}
   286  
   287  	/*
   288  		(2,3,4) -(2,0,1)-> (4,2,3)
   289  		0, 1, 2, 3
   290  		4, 5, 6, 7
   291  		8, 9, 10, 11
   292  
   293  		12, 13, 14, 15
   294  		16, 17, 18, 19
   295  		20, 21, 22, 23
   296  
   297  		becomes
   298  
   299  		0,   4,  8
   300  		12, 16, 20
   301  
   302  		1,   5,  9
   303  		13, 17, 21
   304  
   305  		2,   6, 10
   306  		14, 18, 22
   307  
   308  		3,   7, 11
   309  		15, 19, 23
   310  	*/
   311  
   312  	oldShape = Shape{2, 3, 4}
   313  	pattern = []int{2, 0, 1}
   314  	oldStrides = []int{12, 4, 1}
   315  	newStrides = []int{6, 3, 1}
   316  	corrects = []int{0, 6, 12, 18, 1, 7, 13, 19, 2, 8, 14, 20, 3, 9, 15, 21, 4, 10, 16, 22, 5, 11, 17, 23}
   317  	for i := 0; i < len(corrects); i++ {
   318  		newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides)
   319  		if newInd != corrects[i] {
   320  			t.Errorf("Want %d, got %d instead", corrects[i], newInd)
   321  		}
   322  	}
   323  
   324  }
   325  
   326  func TestUntransposeIndex(t *testing.T) {
   327  	var newInd int
   328  	var oldShape Shape
   329  	var pattern, oldStrides, newStrides, corrects []int
   330  
   331  	// vice versa
   332  	oldShape = Shape{3, 2}
   333  	oldStrides = []int{2, 1}
   334  	newStrides = []int{3, 1}
   335  	corrects = []int{0, 3, 1, 4, 2, 5}
   336  	pattern = []int{1, 0}
   337  	for i := 0; i < 6; i++ {
   338  		newInd = UntransposeIndex(i, oldShape, pattern, oldStrides, newStrides)
   339  		if newInd != corrects[i] {
   340  			t.Errorf("Want %d, got %d instead", corrects[i], newInd)
   341  		}
   342  	}
   343  
   344  	oldShape = Shape{3, 2, 4}
   345  	oldStrides = []int{8, 4, 1}
   346  	newStrides = []int{12, 4, 1}
   347  	pattern = []int{1, 0, 2}
   348  	corrects = []int{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}
   349  	for i := 0; i < len(corrects); i++ {
   350  		newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides)
   351  		if newInd != corrects[i] {
   352  			t.Errorf("Want %d, got %d instead", corrects[i], newInd)
   353  		}
   354  	}
   355  
   356  	oldShape = Shape{4, 2, 3}
   357  	pattern = []int{2, 0, 1}
   358  	newStrides = []int{12, 4, 1}
   359  	oldStrides = []int{6, 3, 1}
   360  	corrects = []int{0, 4, 8, 12, 16, 20}
   361  	for i := 0; i < len(corrects); i++ {
   362  		newInd = UntransposeIndex(i, oldShape, pattern, oldStrides, newStrides)
   363  		if newInd != corrects[i] {
   364  			t.Errorf("Want %d, got %d instead", corrects[i], newInd)
   365  		}
   366  	}
   367  }
   368  
   369  func TestBroadcastStrides(t *testing.T) {
   370  	ds := Shape{4, 4}
   371  	ss := Shape{4}
   372  	dst := []int{4, 1}
   373  	sst := []int{1}
   374  
   375  	st, err := BroadcastStrides(ds, ss, dst, sst)
   376  	if err != nil {
   377  		t.Error(err)
   378  	}
   379  	t.Log(st)
   380  }