github.com/wzzhu/tensor@v0.9.24/shape_test.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 ) 9 10 func TestShapeBasics(t *testing.T) { 11 var s Shape 12 var ds int 13 var err error 14 s = Shape{1, 2} 15 16 if ds, err = s.DimSize(0); err != nil { 17 t.Error(err) 18 } 19 if ds != 1 { 20 t.Error("Expected DimSize(0) to be 1") 21 } 22 23 if ds, err = s.DimSize(2); err == nil { 24 t.Error("Expected a DimensionMismatch error") 25 } 26 27 s = ScalarShape() 28 if ds, err = s.DimSize(0); err != nil { 29 t.Error(err) 30 } 31 32 if ds != 0 { 33 t.Error("Expected DimSize(0) of a scalar to be 0") 34 } 35 36 // format for completeness sake 37 s = Shape{2, 1} 38 if fmt.Sprintf("%d", s) != "[2 1]" { 39 t.Error("Shape.Format() error") 40 } 41 } 42 43 func TestShapeIsX(t *testing.T) { 44 assert := assert.New(t) 45 var s Shape 46 47 // scalar shape 48 s = Shape{} 49 assert.True(s.IsScalar()) 50 assert.True(s.IsScalarEquiv()) 51 assert.False(s.IsVector()) 52 assert.False(s.IsColVec()) 53 assert.False(s.IsRowVec()) 54 55 // vectors 56 57 // scalar-equiv vector 58 s = Shape{1} 59 assert.False(s.IsScalar()) 60 assert.True(s.IsScalarEquiv()) 61 assert.True(s.IsVector()) 62 assert.True(s.IsVectorLike()) 63 assert.True(s.IsVector()) 64 assert.False(s.IsColVec()) 65 assert.False(s.IsRowVec()) 66 67 // vanila vector 68 s = Shape{2} 69 assert.False(s.IsScalar()) 70 assert.True(s.IsVector()) 71 assert.False(s.IsColVec()) 72 assert.False(s.IsRowVec()) 73 74 // col vec 75 s = Shape{2, 1} 76 assert.False(s.IsScalar()) 77 assert.True(s.IsVector()) 78 assert.True(s.IsVectorLike()) 79 assert.True(s.IsColVec()) 80 assert.False(s.IsRowVec()) 81 82 // row vec 83 s = Shape{1, 2} 84 assert.False(s.IsScalar()) 85 assert.True(s.IsVector()) 86 assert.True(s.IsVectorLike()) 87 assert.False(s.IsColVec()) 88 assert.True(s.IsRowVec()) 89 90 // matrix and up 91 s = Shape{2, 2} 92 assert.False(s.IsScalar()) 93 assert.False(s.IsVector()) 94 assert.False(s.IsColVec()) 95 assert.False(s.IsRowVec()) 96 97 // scalar equiv matrix 98 s = Shape{1, 1} 99 assert.False(s.IsScalar()) 100 assert.True(s.IsScalarEquiv()) 101 assert.True(s.IsVectorLike()) 102 assert.False(s.IsVector()) 103 } 104 105 func TestShapeCalcStride(t *testing.T) { 106 assert := assert.New(t) 107 var s Shape 108 109 // scalar shape 110 s = Shape{} 111 assert.Nil(s.CalcStrides()) 112 113 // vector shape 114 s = Shape{1} 115 assert.Equal([]int{1}, s.CalcStrides()) 116 117 s = Shape{2, 1} 118 assert.Equal([]int{1, 1}, s.CalcStrides()) 119 120 s = Shape{1, 2} 121 assert.Equal([]int{2, 1}, s.CalcStrides()) 122 123 s = Shape{2} 124 assert.Equal([]int{1}, s.CalcStrides()) 125 126 // matrix strides 127 s = Shape{2, 2} 128 assert.Equal([]int{2, 1}, s.CalcStrides()) 129 130 s = Shape{5, 2} 131 assert.Equal([]int{2, 1}, s.CalcStrides()) 132 133 // 3D strides 134 s = Shape{2, 3, 4} 135 assert.Equal([]int{12, 4, 1}, s.CalcStrides()) 136 137 // stupid shape 138 s = Shape{-2, 1, 2} 139 fail := func() { 140 s.CalcStrides() 141 } 142 assert.Panics(fail) 143 } 144 145 func TestShapeEquality(t *testing.T) { 146 assert := assert.New(t) 147 var s1, s2 Shape 148 149 // scalar 150 s1 = Shape{} 151 s2 = Shape{} 152 assert.True(s1.Eq(s2)) 153 assert.True(s2.Eq(s1)) 154 155 // scalars and scalar equiv are not the same! 156 s1 = Shape{1} 157 s2 = Shape{} 158 assert.False(s1.Eq(s2)) 159 assert.False(s2.Eq(s1)) 160 161 // vector 162 s1 = Shape{3} 163 s2 = Shape{5} 164 assert.False(s1.Eq(s2)) 165 assert.False(s2.Eq(s1)) 166 167 s1 = Shape{2, 1} 168 s2 = Shape{2, 1} 169 assert.True(s1.Eq(s2)) 170 assert.True(s2.Eq(s1)) 171 172 s2 = Shape{2} 173 assert.True(s1.Eq(s2)) 174 assert.True(s2.Eq(s1)) 175 176 s2 = Shape{1, 2} 177 assert.False(s1.Eq(s2)) 178 assert.False(s2.Eq(s1)) 179 180 s1 = Shape{2} 181 assert.True(s1.Eq(s2)) 182 assert.True(s2.Eq(s1)) 183 184 s2 = Shape{2, 3} 185 assert.False(s1.Eq(s2)) 186 assert.False(s2.Eq(s1)) 187 188 // matrix 189 s1 = Shape{2, 3} 190 assert.True(s1.Eq(s2)) 191 assert.True(s2.Eq(s1)) 192 193 s2 = Shape{3, 2} 194 assert.False(s1.Eq(s2)) 195 assert.False(s2.Eq(s1)) 196 197 // just for that green coloured code 198 s1 = Shape{2} 199 s2 = Shape{1, 3} 200 assert.False(s1.Eq(s2)) 201 assert.False(s2.Eq(s1)) 202 } 203 204 var shapeSliceTests = []struct { 205 name string 206 s Shape 207 sli []Slice 208 209 expected Shape 210 err bool 211 }{ 212 {"slicing a scalar shape", ScalarShape(), nil, ScalarShape(), false}, 213 {"slicing a scalar shape", ScalarShape(), []Slice{rs{0, 0, 0}}, nil, true}, 214 {"vec[0]", Shape{2}, []Slice{rs{0, 1, 0}}, ScalarShape(), false}, 215 {"vec[3]", Shape{2}, []Slice{rs{3, 4, 0}}, nil, true}, 216 {"vec[:, 0]", Shape{2}, []Slice{nil, rs{0, 1, 0}}, nil, true}, 217 {"vec[1:4:2]", Shape{5}, []Slice{rs{1, 4, 2}}, ScalarShape(), false}, 218 {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, Shape{2, 2}, false}, 219 {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, Shape{1, 2}, false}, 220 {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, Shape{1, 2, 2}, false}, 221 {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, Shape{1, 2, 2}, false}, 222 } 223 224 func TestShape_Slice(t *testing.T) { 225 for i, ssts := range shapeSliceTests { 226 newShape, err := ssts.s.S(ssts.sli...) 227 if checkErr(t, ssts.err, err, "Shape slice", i) { 228 continue 229 } 230 231 if !ssts.expected.Eq(newShape) { 232 t.Errorf("Test %q: Expected shape %v. Got %v instead", ssts.name, ssts.expected, newShape) 233 } 234 } 235 } 236 237 var shapeRepeatTests = []struct { 238 name string 239 s Shape 240 repeats []int 241 axis int 242 243 expected Shape 244 expectedRepeats []int 245 expectedSize int 246 err bool 247 }{ 248 {"scalar repeat on axis 0", ScalarShape(), []int{3}, 0, Shape{3}, []int{3}, 1, false}, 249 {"scalar repeat on axis 1", ScalarShape(), []int{3}, 1, Shape{1, 3}, []int{3}, 1, false}, 250 {"vector repeat on axis 0", Shape{2}, []int{3}, 0, Shape{6}, []int{3, 3}, 2, false}, 251 {"vector repeat on axis 1", Shape{2}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, 252 {"colvec repeats on axis 0", Shape{2, 1}, []int{3}, 0, Shape{6, 1}, []int{3, 3}, 2, false}, 253 {"colvec repeats on axis 1", Shape{2, 1}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, 254 {"rowvec repeats on axis 0", Shape{1, 2}, []int{3}, 0, Shape{3, 2}, []int{3}, 1, false}, 255 {"rowvec repeats on axis 1", Shape{1, 2}, []int{3}, 1, Shape{1, 6}, []int{3, 3}, 2, false}, 256 {"3-Tensor repeats", Shape{2, 3, 2}, []int{1, 2, 1}, 1, Shape{2, 4, 2}, []int{1, 2, 1}, 3, false}, 257 {"3-Tensor generic repeats", Shape{2, 3, 2}, []int{2}, AllAxes, Shape{24}, []int{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 12, false}, 258 {"3-Tensor generic repeat, axis specified", Shape{2, 3, 2}, []int{2}, 2, Shape{2, 3, 4}, []int{2, 2}, 2, false}, 259 260 // stupids 261 {"nonexisting axis 2", Shape{2, 1}, []int{3}, 2, nil, nil, 0, true}, 262 {"mismatching repeats", Shape{2, 3, 2}, []int{3, 1, 2}, 0, nil, nil, 0, true}, 263 } 264 265 func TestShape_Repeat(t *testing.T) { 266 assert := assert.New(t) 267 for _, srts := range shapeRepeatTests { 268 newShape, reps, size, err := srts.s.Repeat(srts.axis, srts.repeats...) 269 270 switch { 271 case srts.err: 272 if err == nil { 273 t.Error("Expected an error") 274 } 275 continue 276 case !srts.err && err != nil: 277 t.Error(err) 278 continue 279 } 280 281 assert.True(srts.expected.Eq(newShape), "Test %q: Want: %v. Got %v", srts.name, srts.expected, newShape) 282 assert.Equal(srts.expectedRepeats, reps, "Test %q: ", srts.name) 283 assert.Equal(srts.expectedSize, size, "Test %q: ", srts.name) 284 } 285 } 286 287 var shapeConcatTests = []struct { 288 name string 289 s Shape 290 axis int 291 ss []Shape 292 293 expected Shape 294 err bool 295 }{ 296 {"standard, axis 0 ", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, 297 {"standard, axis 1 ", Shape{2, 2}, 1, []Shape{{2, 2}, {2, 2}}, Shape{2, 6}, false}, 298 {"standard, axis AllAxes ", Shape{2, 2}, -1, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, 299 {"concat to empty", Shape{2}, 0, nil, Shape{2}, false}, 300 301 {"stupids: different dims", Shape{2, 2}, 0, []Shape{{2, 3, 2}}, nil, true}, 302 {"stupids: negative axes", Shape{2, 2}, -5, []Shape{{2, 2}}, nil, true}, 303 {"stupids: toobig axis", Shape{2, 2}, 5, []Shape{{2, 2}}, nil, true}, 304 {"subtle stupids: dim mismatch", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 3}}, nil, true}, 305 } 306 307 func TestShape_Concat(t *testing.T) { 308 assert := assert.New(t) 309 for _, scts := range shapeConcatTests { 310 newShape, err := scts.s.Concat(scts.axis, scts.ss...) 311 switch { 312 case scts.err: 313 if err == nil { 314 t.Error("Expected an error") 315 } 316 continue 317 case !scts.err && err != nil: 318 t.Error(err) 319 continue 320 } 321 assert.Equal(scts.expected, newShape) 322 } 323 }