github.com/wzzhu/tensor@v0.9.24/ap_test.go (about) 1 package tensor 2 3 import ( 4 //"fmt" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 ) 9 10 func dummyScalar1() AP { return AP{} } 11 12 func dummyScalar2() AP { return AP{shape: Shape{1}} } 13 14 func dummyColVec() AP { 15 return AP{ 16 shape: Shape{5, 1}, 17 strides: []int{1}, 18 } 19 } 20 21 func dummyRowVec() AP { 22 return AP{ 23 shape: Shape{1, 5}, 24 strides: []int{1}, 25 } 26 } 27 28 func dummyVec() AP { 29 return AP{ 30 shape: Shape{5}, 31 strides: []int{1}, 32 } 33 } 34 35 func twothree() AP { 36 return AP{ 37 shape: Shape{2, 3}, 38 strides: []int{3, 1}, 39 } 40 } 41 42 func twothreefour() AP { 43 return AP{ 44 shape: Shape{2, 3, 4}, 45 strides: []int{12, 4, 1}, 46 } 47 } 48 49 func TestAccessPatternBasics(t *testing.T) { 50 assert := assert.New(t) 51 ap := new(AP) 52 53 ap.SetShape(1, 2) 54 assert.Equal(Shape{1, 2}, ap.Shape()) 55 assert.Equal([]int{2, 1}, ap.Strides()) 56 assert.Equal(2, ap.Dims()) 57 assert.Equal(2, ap.Size()) 58 59 ap.SetShape(2, 3, 2) 60 assert.Equal(Shape{2, 3, 2}, ap.Shape()) 61 assert.Equal([]int{6, 2, 1}, ap.Strides()) 62 assert.Equal(12, ap.Size()) 63 64 ap.lock() 65 ap.SetShape(1, 2, 3) 66 assert.Equal(Shape{2, 3, 2}, ap.shape) 67 assert.Equal([]int{6, 2, 1}, ap.strides) 68 69 ap.unlock() 70 ap.SetShape(1, 2) 71 assert.Equal(Shape{1, 2}, ap.Shape()) 72 assert.Equal([]int{2, 1}, ap.Strides()) 73 assert.Equal(2, ap.Dims()) 74 assert.Equal(2, ap.Size()) 75 76 if ap.String() != "Shape: (1, 2), Stride: [2 1], Lock: false" { 77 t.Errorf("AP formatting error. Got %q", ap.String()) 78 } 79 80 ap2 := ap.Clone() 81 assert.Equal(*ap, ap2) 82 } 83 84 func TestAccessPatternIsX(t *testing.T) { 85 assert := assert.New(t) 86 var ap AP 87 88 ap = dummyScalar1() 89 assert.True(ap.IsScalar()) 90 assert.True(ap.IsScalarEquiv()) 91 assert.False(ap.IsVector()) 92 assert.False(ap.IsColVec()) 93 assert.False(ap.IsRowVec()) 94 95 ap = dummyScalar2() 96 assert.False(ap.IsScalar()) 97 assert.True(ap.IsScalarEquiv()) 98 assert.True(ap.IsVectorLike()) 99 assert.True(ap.IsVector()) 100 assert.False(ap.IsColVec()) 101 assert.False(ap.IsRowVec()) 102 103 ap = dummyColVec() 104 assert.True(ap.IsColVec()) 105 assert.True(ap.IsVector()) 106 assert.False(ap.IsRowVec()) 107 assert.False(ap.IsScalar()) 108 109 ap = dummyRowVec() 110 assert.True(ap.IsRowVec()) 111 assert.True(ap.IsVector()) 112 assert.False(ap.IsColVec()) 113 assert.False(ap.IsScalar()) 114 115 ap = twothree() 116 assert.True(ap.IsMatrix()) 117 assert.False(ap.IsScalar()) 118 assert.False(ap.IsVector()) 119 assert.False(ap.IsRowVec()) 120 assert.False(ap.IsColVec()) 121 122 } 123 124 func TestAccessPatternT(t *testing.T) { 125 assert := assert.New(t) 126 var ap, apT AP 127 var axes []int 128 var err error 129 130 ap = twothree() 131 132 // test no axes 133 apT, axes, err = ap.T() 134 if err != nil { 135 t.Error(err) 136 } 137 138 assert.Equal(Shape{3, 2}, apT.shape) 139 assert.Equal([]int{1, 3}, apT.strides) 140 assert.Equal([]int{1, 0}, axes) 141 assert.Equal(2, apT.Dims()) 142 143 // test no op 144 apT, _, err = ap.T(0, 1) 145 if err != nil { 146 if _, ok := err.(NoOpError); !ok { 147 t.Error(err) 148 } 149 } 150 151 // test 3D 152 ap = twothreefour() 153 apT, axes, err = ap.T(2, 0, 1) 154 if err != nil { 155 t.Error(err) 156 } 157 assert.Equal(Shape{4, 2, 3}, apT.shape) 158 assert.Equal([]int{1, 12, 4}, apT.strides) 159 assert.Equal([]int{2, 0, 1}, axes) 160 assert.Equal(3, apT.Dims()) 161 162 // test stupid axes 163 _, _, err = ap.T(1, 2, 3) 164 if err == nil { 165 t.Error("Expected an error") 166 } 167 } 168 169 var sliceTests = []struct { 170 name string 171 shape Shape 172 slices []Slice 173 174 correctStart int 175 correctEnd int 176 correctShape Shape 177 correctStride []int 178 contiguous bool 179 }{ 180 // vectors 181 {"a[0]", Shape{5}, []Slice{S(0)}, 0, 1, ScalarShape(), nil, true}, 182 {"a[0:2]", Shape{5}, []Slice{S(0, 2)}, 0, 2, Shape{2}, []int{1}, true}, 183 {"a[1:3]", Shape{5}, []Slice{S(1, 3)}, 1, 3, Shape{2}, []int{1}, true}, 184 {"a[1:5:2]", Shape{5}, []Slice{S(1, 5, 2)}, 1, 5, Shape{2}, []int{2}, false}, 185 186 // matrix 187 {"A[0]", Shape{2, 3}, []Slice{S(0)}, 0, 3, Shape{1, 3}, []int{1}, true}, 188 {"A[1:3]", Shape{4, 5}, []Slice{S(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true}, 189 {"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{S(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened 190 {"A[:, 1:3]", Shape{4, 5}, []Slice{nil, S(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false}, 191 192 // tensor 193 {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, 0, 4, Shape{2, 2}, []int{2, 1}, true}, 194 {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, 0, 2, Shape{1, 2}, []int{4, 1}, false}, 195 {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, 196 {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, 197 } 198 199 func TestAccessPatternS(t *testing.T) { 200 assert := assert.New(t) 201 var ap, apS AP 202 var ndStart, ndEnd int 203 var err error 204 205 for _, sts := range sliceTests { 206 ap = MakeAP(sts.shape, sts.shape.CalcStrides(), 0, 0) 207 if apS, ndStart, ndEnd, err = ap.S(sts.shape.TotalSize(), sts.slices...); err != nil { 208 t.Errorf("%v errored: %v", sts.name, err) 209 continue 210 } 211 assert.Equal(sts.correctStart, ndStart, "Wrong start: %v. Want %d Got %d", sts.name, sts.correctStart, ndStart) 212 assert.Equal(sts.correctEnd, ndEnd, "Wrong end: %v. Want %d Got %d", sts.name, sts.correctEnd, ndEnd) 213 assert.True(sts.correctShape.Eq(apS.shape), "Wrong shape: %v. Want %v. Got %v", sts.name, sts.correctShape, apS.shape) 214 assert.Equal(sts.correctStride, apS.strides, "Wrong strides: %v. Want %v. Got %v", sts.name, sts.correctStride, apS.strides) 215 assert.Equal(sts.contiguous, apS.DataOrder().IsContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous) 216 } 217 } 218 219 func TestTransposeIndex(t *testing.T) { 220 var newInd int 221 var oldShape Shape 222 var pattern, oldStrides, newStrides, corrects []int 223 224 /* 225 (2,3)->(3,2) 226 0, 1, 2 227 3, 4, 5 228 229 becomes 230 231 0, 3 232 1, 4 233 2, 5 234 235 1 -> 2 236 2 -> 4 237 3 -> 1 238 4 -> 3 239 0 and 5 stay the same 240 */ 241 242 oldShape = Shape{2, 3} 243 pattern = []int{1, 0} 244 oldStrides = []int{3, 1} 245 newStrides = []int{2, 1} 246 corrects = []int{0, 2, 4, 1, 3, 5} 247 for i := 0; i < 6; i++ { 248 newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides) 249 if newInd != corrects[i] { 250 t.Errorf("Want %d, got %d instead", corrects[i], newInd) 251 } 252 } 253 254 /* 255 (2,3,4) -(1,0,2)-> (3,2,4) 256 0, 1, 2, 3 257 4, 5, 6, 7 258 8, 9, 10, 11 259 260 12, 13, 14, 15 261 16, 17, 18, 19 262 20, 21, 22, 23 263 264 becomes 265 266 0, 1, 2, 3 267 12, 13, 14, 15, 268 269 4, 5, 6, 7 270 16, 17, 18, 19 271 272 8, 9, 10, 11 273 20, 21, 22, 23 274 */ 275 oldShape = Shape{2, 3, 4} 276 pattern = []int{1, 0, 2} 277 oldStrides = []int{12, 4, 1} 278 newStrides = []int{8, 4, 1} 279 corrects = []int{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23} 280 for i := 0; i < len(corrects); i++ { 281 newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides) 282 if newInd != corrects[i] { 283 t.Errorf("Want %d, got %d instead", corrects[i], newInd) 284 } 285 } 286 287 /* 288 (2,3,4) -(2,0,1)-> (4,2,3) 289 0, 1, 2, 3 290 4, 5, 6, 7 291 8, 9, 10, 11 292 293 12, 13, 14, 15 294 16, 17, 18, 19 295 20, 21, 22, 23 296 297 becomes 298 299 0, 4, 8 300 12, 16, 20 301 302 1, 5, 9 303 13, 17, 21 304 305 2, 6, 10 306 14, 18, 22 307 308 3, 7, 11 309 15, 19, 23 310 */ 311 312 oldShape = Shape{2, 3, 4} 313 pattern = []int{2, 0, 1} 314 oldStrides = []int{12, 4, 1} 315 newStrides = []int{6, 3, 1} 316 corrects = []int{0, 6, 12, 18, 1, 7, 13, 19, 2, 8, 14, 20, 3, 9, 15, 21, 4, 10, 16, 22, 5, 11, 17, 23} 317 for i := 0; i < len(corrects); i++ { 318 newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides) 319 if newInd != corrects[i] { 320 t.Errorf("Want %d, got %d instead", corrects[i], newInd) 321 } 322 } 323 324 } 325 326 func TestUntransposeIndex(t *testing.T) { 327 var newInd int 328 var oldShape Shape 329 var pattern, oldStrides, newStrides, corrects []int 330 331 // vice versa 332 oldShape = Shape{3, 2} 333 oldStrides = []int{2, 1} 334 newStrides = []int{3, 1} 335 corrects = []int{0, 3, 1, 4, 2, 5} 336 pattern = []int{1, 0} 337 for i := 0; i < 6; i++ { 338 newInd = UntransposeIndex(i, oldShape, pattern, oldStrides, newStrides) 339 if newInd != corrects[i] { 340 t.Errorf("Want %d, got %d instead", corrects[i], newInd) 341 } 342 } 343 344 oldShape = Shape{3, 2, 4} 345 oldStrides = []int{8, 4, 1} 346 newStrides = []int{12, 4, 1} 347 pattern = []int{1, 0, 2} 348 corrects = []int{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23} 349 for i := 0; i < len(corrects); i++ { 350 newInd = TransposeIndex(i, oldShape, pattern, oldStrides, newStrides) 351 if newInd != corrects[i] { 352 t.Errorf("Want %d, got %d instead", corrects[i], newInd) 353 } 354 } 355 356 oldShape = Shape{4, 2, 3} 357 pattern = []int{2, 0, 1} 358 newStrides = []int{12, 4, 1} 359 oldStrides = []int{6, 3, 1} 360 corrects = []int{0, 4, 8, 12, 16, 20} 361 for i := 0; i < len(corrects); i++ { 362 newInd = UntransposeIndex(i, oldShape, pattern, oldStrides, newStrides) 363 if newInd != corrects[i] { 364 t.Errorf("Want %d, got %d instead", corrects[i], newInd) 365 } 366 } 367 } 368 369 func TestBroadcastStrides(t *testing.T) { 370 ds := Shape{4, 4} 371 ss := Shape{4} 372 dst := []int{4, 1} 373 sst := []int{1} 374 375 st, err := BroadcastStrides(ds, ss, dst, sst) 376 if err != nil { 377 t.Error(err) 378 } 379 t.Log(st) 380 }