github.com/wzzhu/tensor@v0.9.24/shape_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  )
     9  
    10  func TestShapeBasics(t *testing.T) {
    11  	var s Shape
    12  	var ds int
    13  	var err error
    14  	s = Shape{1, 2}
    15  
    16  	if ds, err = s.DimSize(0); err != nil {
    17  		t.Error(err)
    18  	}
    19  	if ds != 1 {
    20  		t.Error("Expected DimSize(0) to be 1")
    21  	}
    22  
    23  	if ds, err = s.DimSize(2); err == nil {
    24  		t.Error("Expected a DimensionMismatch error")
    25  	}
    26  
    27  	s = ScalarShape()
    28  	if ds, err = s.DimSize(0); err != nil {
    29  		t.Error(err)
    30  	}
    31  
    32  	if ds != 0 {
    33  		t.Error("Expected DimSize(0) of a scalar to be 0")
    34  	}
    35  
    36  	// format for completeness sake
    37  	s = Shape{2, 1}
    38  	if fmt.Sprintf("%d", s) != "[2 1]" {
    39  		t.Error("Shape.Format() error")
    40  	}
    41  }
    42  
    43  func TestShapeIsX(t *testing.T) {
    44  	assert := assert.New(t)
    45  	var s Shape
    46  
    47  	// scalar shape
    48  	s = Shape{}
    49  	assert.True(s.IsScalar())
    50  	assert.True(s.IsScalarEquiv())
    51  	assert.False(s.IsVector())
    52  	assert.False(s.IsColVec())
    53  	assert.False(s.IsRowVec())
    54  
    55  	// vectors
    56  
    57  	// scalar-equiv vector
    58  	s = Shape{1}
    59  	assert.False(s.IsScalar())
    60  	assert.True(s.IsScalarEquiv())
    61  	assert.True(s.IsVector())
    62  	assert.True(s.IsVectorLike())
    63  	assert.True(s.IsVector())
    64  	assert.False(s.IsColVec())
    65  	assert.False(s.IsRowVec())
    66  
    67  	// vanila vector
    68  	s = Shape{2}
    69  	assert.False(s.IsScalar())
    70  	assert.True(s.IsVector())
    71  	assert.False(s.IsColVec())
    72  	assert.False(s.IsRowVec())
    73  
    74  	// col vec
    75  	s = Shape{2, 1}
    76  	assert.False(s.IsScalar())
    77  	assert.True(s.IsVector())
    78  	assert.True(s.IsVectorLike())
    79  	assert.True(s.IsColVec())
    80  	assert.False(s.IsRowVec())
    81  
    82  	// row vec
    83  	s = Shape{1, 2}
    84  	assert.False(s.IsScalar())
    85  	assert.True(s.IsVector())
    86  	assert.True(s.IsVectorLike())
    87  	assert.False(s.IsColVec())
    88  	assert.True(s.IsRowVec())
    89  
    90  	// matrix and up
    91  	s = Shape{2, 2}
    92  	assert.False(s.IsScalar())
    93  	assert.False(s.IsVector())
    94  	assert.False(s.IsColVec())
    95  	assert.False(s.IsRowVec())
    96  
    97  	// scalar equiv matrix
    98  	s = Shape{1, 1}
    99  	assert.False(s.IsScalar())
   100  	assert.True(s.IsScalarEquiv())
   101  	assert.True(s.IsVectorLike())
   102  	assert.False(s.IsVector())
   103  }
   104  
   105  func TestShapeCalcStride(t *testing.T) {
   106  	assert := assert.New(t)
   107  	var s Shape
   108  
   109  	// scalar shape
   110  	s = Shape{}
   111  	assert.Nil(s.CalcStrides())
   112  
   113  	// vector shape
   114  	s = Shape{1}
   115  	assert.Equal([]int{1}, s.CalcStrides())
   116  
   117  	s = Shape{2, 1}
   118  	assert.Equal([]int{1, 1}, s.CalcStrides())
   119  
   120  	s = Shape{1, 2}
   121  	assert.Equal([]int{2, 1}, s.CalcStrides())
   122  
   123  	s = Shape{2}
   124  	assert.Equal([]int{1}, s.CalcStrides())
   125  
   126  	// matrix strides
   127  	s = Shape{2, 2}
   128  	assert.Equal([]int{2, 1}, s.CalcStrides())
   129  
   130  	s = Shape{5, 2}
   131  	assert.Equal([]int{2, 1}, s.CalcStrides())
   132  
   133  	// 3D strides
   134  	s = Shape{2, 3, 4}
   135  	assert.Equal([]int{12, 4, 1}, s.CalcStrides())
   136  
   137  	// stupid shape
   138  	s = Shape{-2, 1, 2}
   139  	fail := func() {
   140  		s.CalcStrides()
   141  	}
   142  	assert.Panics(fail)
   143  }
   144  
   145  func TestShapeEquality(t *testing.T) {
   146  	assert := assert.New(t)
   147  	var s1, s2 Shape
   148  
   149  	// scalar
   150  	s1 = Shape{}
   151  	s2 = Shape{}
   152  	assert.True(s1.Eq(s2))
   153  	assert.True(s2.Eq(s1))
   154  
   155  	// scalars and scalar equiv are not the same!
   156  	s1 = Shape{1}
   157  	s2 = Shape{}
   158  	assert.False(s1.Eq(s2))
   159  	assert.False(s2.Eq(s1))
   160  
   161  	// vector
   162  	s1 = Shape{3}
   163  	s2 = Shape{5}
   164  	assert.False(s1.Eq(s2))
   165  	assert.False(s2.Eq(s1))
   166  
   167  	s1 = Shape{2, 1}
   168  	s2 = Shape{2, 1}
   169  	assert.True(s1.Eq(s2))
   170  	assert.True(s2.Eq(s1))
   171  
   172  	s2 = Shape{2}
   173  	assert.True(s1.Eq(s2))
   174  	assert.True(s2.Eq(s1))
   175  
   176  	s2 = Shape{1, 2}
   177  	assert.False(s1.Eq(s2))
   178  	assert.False(s2.Eq(s1))
   179  
   180  	s1 = Shape{2}
   181  	assert.True(s1.Eq(s2))
   182  	assert.True(s2.Eq(s1))
   183  
   184  	s2 = Shape{2, 3}
   185  	assert.False(s1.Eq(s2))
   186  	assert.False(s2.Eq(s1))
   187  
   188  	// matrix
   189  	s1 = Shape{2, 3}
   190  	assert.True(s1.Eq(s2))
   191  	assert.True(s2.Eq(s1))
   192  
   193  	s2 = Shape{3, 2}
   194  	assert.False(s1.Eq(s2))
   195  	assert.False(s2.Eq(s1))
   196  
   197  	// just for that green coloured code
   198  	s1 = Shape{2}
   199  	s2 = Shape{1, 3}
   200  	assert.False(s1.Eq(s2))
   201  	assert.False(s2.Eq(s1))
   202  }
   203  
   204  var shapeSliceTests = []struct {
   205  	name string
   206  	s    Shape
   207  	sli  []Slice
   208  
   209  	expected Shape
   210  	err      bool
   211  }{
   212  	{"slicing a scalar shape", ScalarShape(), nil, ScalarShape(), false},
   213  	{"slicing a scalar shape", ScalarShape(), []Slice{rs{0, 0, 0}}, nil, true},
   214  	{"vec[0]", Shape{2}, []Slice{rs{0, 1, 0}}, ScalarShape(), false},
   215  	{"vec[3]", Shape{2}, []Slice{rs{3, 4, 0}}, nil, true},
   216  	{"vec[:, 0]", Shape{2}, []Slice{nil, rs{0, 1, 0}}, nil, true},
   217  	{"vec[1:4:2]", Shape{5}, []Slice{rs{1, 4, 2}}, ScalarShape(), false},
   218  	{"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, Shape{2, 2}, false},
   219  	{"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, Shape{1, 2}, false},
   220  	{"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, Shape{1, 2, 2}, false},
   221  	{"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, Shape{1, 2, 2}, false},
   222  }
   223  
   224  func TestShape_Slice(t *testing.T) {
   225  	for i, ssts := range shapeSliceTests {
   226  		newShape, err := ssts.s.S(ssts.sli...)
   227  		if checkErr(t, ssts.err, err, "Shape slice", i) {
   228  			continue
   229  		}
   230  
   231  		if !ssts.expected.Eq(newShape) {
   232  			t.Errorf("Test %q: Expected shape %v. Got %v instead", ssts.name, ssts.expected, newShape)
   233  		}
   234  	}
   235  }
   236  
// shapeRepeatTests are the table cases for TestShape_Repeat. Each case calls
// s.Repeat(axis, repeats...). The three results are compared with expected
// (the new shape), expectedRepeats and expectedSize. When err is true, an error
// is expected instead and the remaining expectation fields are not checked.
//
// NOTE(review): judging from the cases below, expectedSize looks like the
// length of the axis being repeated (the total element count for AllAxes).
// expectedRepeats appears to hold one repeat count per element along that axis.
// Confirm both against Shape.Repeat.
var shapeRepeatTests = []struct {
	name    string
	s       Shape
	repeats []int
	axis    int

	expected        Shape
	expectedRepeats []int
	expectedSize    int
	err             bool
}{
	{"scalar repeat on axis 0", ScalarShape(), []int{3}, 0, Shape{3}, []int{3}, 1, false},
	{"scalar repeat on axis 1", ScalarShape(), []int{3}, 1, Shape{1, 3}, []int{3}, 1, false},
	{"vector repeat on axis 0", Shape{2}, []int{3}, 0, Shape{6}, []int{3, 3}, 2, false},
	{"vector repeat on axis 1", Shape{2}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false},
	{"colvec repeats on axis 0", Shape{2, 1}, []int{3}, 0, Shape{6, 1}, []int{3, 3}, 2, false},
	{"colvec repeats on axis 1", Shape{2, 1}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false},
	{"rowvec repeats on axis 0", Shape{1, 2}, []int{3}, 0, Shape{3, 2}, []int{3}, 1, false},
	{"rowvec repeats on axis 1", Shape{1, 2}, []int{3}, 1, Shape{1, 6}, []int{3, 3}, 2, false},
	{"3-Tensor repeats", Shape{2, 3, 2}, []int{1, 2, 1}, 1, Shape{2, 4, 2}, []int{1, 2, 1}, 3, false},
	{"3-Tensor generic repeats", Shape{2, 3, 2}, []int{2}, AllAxes, Shape{24}, []int{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 12, false},
	{"3-Tensor generic repeat, axis specified", Shape{2, 3, 2}, []int{2}, 2, Shape{2, 3, 4}, []int{2, 2}, 2, false},

	// stupids: invalid inputs that must make Repeat return an error.
	{"nonexisting axis 2", Shape{2, 1}, []int{3}, 2, nil, nil, 0, true},
	{"mismatching repeats", Shape{2, 3, 2}, []int{3, 1, 2}, 0, nil, nil, 0, true},
}
   264  
   265  func TestShape_Repeat(t *testing.T) {
   266  	assert := assert.New(t)
   267  	for _, srts := range shapeRepeatTests {
   268  		newShape, reps, size, err := srts.s.Repeat(srts.axis, srts.repeats...)
   269  
   270  		switch {
   271  		case srts.err:
   272  			if err == nil {
   273  				t.Error("Expected an error")
   274  			}
   275  			continue
   276  		case !srts.err && err != nil:
   277  			t.Error(err)
   278  			continue
   279  		}
   280  
   281  		assert.True(srts.expected.Eq(newShape), "Test %q:  Want: %v. Got %v", srts.name, srts.expected, newShape)
   282  		assert.Equal(srts.expectedRepeats, reps, "Test %q: ", srts.name)
   283  		assert.Equal(srts.expectedSize, size, "Test %q: ", srts.name)
   284  	}
   285  }
   286  
// shapeConcatTests are the table cases for TestShape_Concat. Each case calls
// s.Concat(axis, ss...). If err is false, the result must equal expected. If
// err is true, Concat must return an error.
var shapeConcatTests = []struct {
	name string
	s    Shape
	axis int
	ss   []Shape

	expected Shape
	err      bool
}{
	{"standard, axis 0 ", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false},
	{"standard, axis 1 ", Shape{2, 2}, 1, []Shape{{2, 2}, {2, 2}}, Shape{2, 6}, false},
	// NOTE(review): the name says AllAxes but the literal -1 is used; presumably
	// AllAxes == -1 (shapeRepeatTests uses the constant) — confirm and consider
	// using AllAxes here for clarity.
	{"standard, axis AllAxes ", Shape{2, 2}, -1, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false},
	// Concatenating nothing onto s leaves its shape unchanged.
	{"concat to empty", Shape{2}, 0, nil, Shape{2}, false},

	// stupids: invalid inputs that must make Concat return an error.
	{"stupids: different dims", Shape{2, 2}, 0, []Shape{{2, 3, 2}}, nil, true},
	{"stupids: negative axes", Shape{2, 2}, -5, []Shape{{2, 2}}, nil, true},
	{"stupids: toobig axis", Shape{2, 2}, 5, []Shape{{2, 2}}, nil, true},
	// Only the second shape mismatches (on the non-concatenated axis).
	{"subtle stupids: dim mismatch", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 3}}, nil, true},
}
   306  
   307  func TestShape_Concat(t *testing.T) {
   308  	assert := assert.New(t)
   309  	for _, scts := range shapeConcatTests {
   310  		newShape, err := scts.s.Concat(scts.axis, scts.ss...)
   311  		switch {
   312  		case scts.err:
   313  			if err == nil {
   314  				t.Error("Expected an error")
   315  			}
   316  			continue
   317  		case !scts.err && err != nil:
   318  			t.Error(err)
   319  			continue
   320  		}
   321  		assert.Equal(scts.expected, newShape)
   322  	}
   323  }