gorgonia.org/tensor@v0.9.24/dense_matop_test.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 "gorgonia.org/vecf64" 9 ) 10 11 func cloneArray(a interface{}) interface{} { 12 switch at := a.(type) { 13 case []float64: 14 retVal := make([]float64, len(at)) 15 copy(retVal, at) 16 return retVal 17 case []float32: 18 retVal := make([]float32, len(at)) 19 copy(retVal, at) 20 return retVal 21 case []int: 22 retVal := make([]int, len(at)) 23 copy(retVal, at) 24 return retVal 25 case []int64: 26 retVal := make([]int64, len(at)) 27 copy(retVal, at) 28 return retVal 29 case []int32: 30 retVal := make([]int32, len(at)) 31 copy(retVal, at) 32 return retVal 33 case []byte: 34 retVal := make([]byte, len(at)) 35 copy(retVal, at) 36 return retVal 37 case []bool: 38 retVal := make([]bool, len(at)) 39 copy(retVal, at) 40 return retVal 41 } 42 return nil 43 } 44 45 func castToDt(val float64, dt Dtype) interface{} { 46 switch dt { 47 case Bool: 48 return false 49 case Int: 50 return int(val) 51 case Int8: 52 return int8(val) 53 case Int16: 54 return int16(val) 55 case Int32: 56 return int32(val) 57 case Int64: 58 return int64(val) 59 case Uint: 60 return uint(val) 61 case Uint8: 62 return uint8(val) 63 case Uint16: 64 return uint16(val) 65 case Uint32: 66 return uint32(val) 67 case Uint64: 68 return uint64(val) 69 case Float32: 70 return float32(val) 71 case Float64: 72 return float64(val) 73 default: 74 return 0 75 } 76 } 77 78 var atTests = []struct { 79 data interface{} 80 shape Shape 81 coord []int 82 83 correct interface{} 84 err bool 85 }{ 86 // matrix 87 {[]float64{0, 1, 2, 3, 4, 5}, Shape{2, 3}, []int{0, 1}, float64(1), false}, 88 {[]float32{0, 1, 2, 3, 4, 5}, Shape{2, 3}, []int{1, 1}, float32(4), false}, 89 {[]float64{0, 1, 2, 3, 4, 5}, Shape{2, 3}, []int{1, 2, 3}, nil, true}, 90 91 // 3-tensor 92 {[]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 93 Shape{2, 3, 4}, []int{1, 1, 
1}, 17, false}, 94 {[]int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 95 Shape{2, 3, 4}, []int{1, 2, 3}, int64(23), false}, 96 {[]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 97 Shape{2, 3, 4}, []int{0, 3, 2}, 23, true}, 98 } 99 100 func TestDense_At(t *testing.T) { 101 for i, ats := range atTests { 102 T := New(WithShape(ats.shape...), WithBacking(ats.data)) 103 got, err := T.At(ats.coord...) 104 if checkErr(t, ats.err, err, "At", i) { 105 continue 106 } 107 108 if got != ats.correct { 109 t.Errorf("Expected %v. Got %v", ats.correct, got) 110 } 111 } 112 } 113 114 func Test_transposeIndex(t *testing.T) { 115 a := []byte{0, 1, 2, 3} 116 T := New(WithShape(2, 2), WithBacking(a)) 117 118 correct := []int{0, 2, 1, 3} 119 for i, v := range correct { 120 got := T.transposeIndex(i, []int{1, 0}, []int{2, 1}) 121 if v != got { 122 t.Errorf("transposeIndex error. Expected %v. Got %v", v, got) 123 } 124 } 125 } 126 127 var transposeTests = []struct { 128 name string 129 shape Shape 130 transposeWith []int 131 data interface{} 132 133 correctShape Shape 134 correctStrides []int // after .T() 135 correctStrides2 []int // after .Transpose() 136 correctData interface{} 137 }{ 138 {"c.T()", Shape{4, 1}, nil, []float64{0, 1, 2, 3}, 139 Shape{1, 4}, []int{1, 1}, []int{4, 1}, []float64{0, 1, 2, 3}}, 140 141 {"r.T()", Shape{1, 4}, nil, []float32{0, 1, 2, 3}, 142 Shape{4, 1}, []int{1, 1}, []int{1, 1}, []float32{0, 1, 2, 3}}, 143 144 {"v.T()", Shape{4}, nil, []int{0, 1, 2, 3}, 145 Shape{4}, []int{1}, []int{1}, []int{0, 1, 2, 3}}, 146 147 {"M.T()", Shape{2, 3}, nil, []int64{0, 1, 2, 3, 4, 5}, 148 Shape{3, 2}, []int{1, 3}, []int{2, 1}, []int64{0, 3, 1, 4, 2, 5}}, 149 150 {"M.T(0,1) (NOOP)", Shape{2, 3}, []int{0, 1}, []int32{0, 1, 2, 3, 4, 5}, 151 Shape{2, 3}, []int{3, 1}, []int{3, 1}, []int32{0, 1, 2, 3, 4, 5}}, 152 153 {"3T.T()", Shape{2, 3, 4}, nil, 154 []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 
11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 155 156 Shape{4, 3, 2}, []int{1, 4, 12}, []int{6, 2, 1}, 157 []byte{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, 158 159 {"3T.T(2, 1, 0) (Same as .T())", Shape{2, 3, 4}, []int{2, 1, 0}, 160 []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 161 Shape{4, 3, 2}, []int{1, 4, 12}, []int{6, 2, 1}, 162 []int{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, 163 164 {"3T.T(2, 1, 0) (Same as .T())", Shape{2, 3, 4}, []int{2, 1, 0}, 165 []int16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 166 Shape{4, 3, 2}, []int{1, 4, 12}, []int{6, 2, 1}, 167 []int16{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, 168 169 {"3T.T(0, 2, 1)", Shape{2, 3, 4}, []int{0, 2, 1}, 170 []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 171 Shape{2, 4, 3}, []int{12, 1, 4}, []int{12, 3, 1}, 172 []int32{0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23}}, 173 174 {"3T.T{1, 0, 2)", Shape{2, 3, 4}, []int{1, 0, 2}, 175 []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 176 Shape{3, 2, 4}, []int{4, 12, 1}, []int{8, 4, 1}, 177 []float64{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}}, 178 179 {"3T.T{1, 2, 0)", Shape{2, 3, 4}, []int{1, 2, 0}, 180 []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 181 Shape{3, 4, 2}, []int{4, 1, 12}, []int{8, 2, 1}, 182 []float64{0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23}}, 183 184 {"3T.T{2, 0, 1)", Shape{2, 3, 4}, []int{2, 0, 1}, 185 []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, 186 Shape{4, 2, 3}, []int{1, 12, 4}, []int{6, 
3, 1}, 187 []float32{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}}, 188 189 {"3T.T{0, 1, 2} (NOOP)", Shape{2, 3, 4}, []int{0, 1, 2}, 190 []bool{true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false}, 191 Shape{2, 3, 4}, []int{12, 4, 1}, []int{12, 4, 1}, 192 []bool{true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false}}, 193 194 {"M[2,2].T for bools, just for completeness sake", Shape{2, 2}, nil, 195 []bool{true, true, false, false}, 196 Shape{2, 2}, []int{1, 2}, []int{2, 1}, 197 []bool{true, false, true, false}, 198 }, 199 200 {"M[2,2].T for strings, just for completeness sake", Shape{2, 2}, nil, 201 []string{"hello", "world", "今日は", "世界"}, 202 Shape{2, 2}, []int{1, 2}, []int{2, 1}, 203 []string{"hello", "今日は", "world", "世界"}, 204 }, 205 } 206 207 func TestDense_Transpose(t *testing.T) { 208 assert := assert.New(t) 209 var err error 210 211 // standard transposes 212 for _, tts := range transposeTests { 213 T := New(WithShape(tts.shape...), WithBacking(tts.data)) 214 if err = T.T(tts.transposeWith...); err != nil { 215 t.Errorf("%v - %v", tts.name, err) 216 continue 217 } 218 219 assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape()) 220 assert.Equal(tts.correctStrides, T.Strides(), "Transpose %v. Expected stride: %v. Got %v", tts.name, tts.correctStrides, T.Strides()) 221 T.Transpose() 222 assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape()) 223 assert.Equal(tts.correctStrides2, T.Strides(), "Transpose2 %v - Expected stride %v. 
Got %v", tts.name, tts.correctStrides2, T.Strides()) 224 assert.Equal(tts.correctData, T.Data(), "Transpose %v", tts.name) 225 } 226 227 // test stacked .T() calls 228 var T *Dense 229 230 // column vector 231 T = New(WithShape(4, 1), WithBacking(Range(Int, 0, 4))) 232 if err = T.T(); err != nil { 233 t.Errorf("Stacked .T() #1 for vector. Error: %v", err) 234 goto matrev 235 } 236 if err = T.T(); err != nil { 237 t.Errorf("Stacked .T() #1 for vector. Error: %v", err) 238 goto matrev 239 } 240 assert.True(T.old.IsZero()) 241 assert.Nil(T.transposeWith) 242 assert.True(T.IsColVec()) 243 244 matrev: 245 // matrix, reversed 246 T = New(WithShape(2, 3), WithBacking(Range(Byte, 0, 6))) 247 if err = T.T(); err != nil { 248 t.Errorf("Stacked .T() #1 for matrix reverse. Error: %v", err) 249 goto matnorev 250 } 251 if err = T.T(); err != nil { 252 t.Errorf("Stacked .T() #2 for matrix reverse. Error: %v", err) 253 goto matnorev 254 } 255 assert.True(T.old.IsZero()) 256 assert.Nil(T.transposeWith) 257 assert.True(Shape{2, 3}.Eq(T.Shape())) 258 259 matnorev: 260 // 3-tensor, non reversed 261 T = New(WithShape(2, 3, 4), WithBacking(Range(Int64, 0, 24))) 262 if err = T.T(); err != nil { 263 t.Fatalf("Stacked .T() #1 for tensor with no reverse. Error: %v", err) 264 } 265 if err = T.T(2, 0, 1); err != nil { 266 t.Fatalf("Stacked .T() #2 for tensor with no reverse. 
Error: %v", err) 267 } 268 correctData := []int64{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23} 269 assert.Equal(correctData, T.Data()) 270 assert.Equal([]int{2, 0, 1}, T.transposeWith) 271 assert.NotNil(T.old) 272 273 } 274 275 func TestTUT(t *testing.T) { 276 assert := assert.New(t) 277 var T *Dense 278 279 T = New(Of(Float64), WithShape(2, 3, 4)) 280 T.T() 281 T.UT() 282 assert.True(T.old.IsZero()) 283 assert.Nil(T.transposeWith) 284 285 T.T(2, 0, 1) 286 T.UT() 287 assert.True(T.old.IsZero()) 288 assert.Nil(T.transposeWith) 289 } 290 291 type repeatTest struct { 292 name string 293 tensor *Dense 294 ne bool // should assert tensor not equal 295 axis int 296 repeats []int 297 298 correct interface{} 299 shape Shape 300 err bool 301 } 302 303 var repeatTests = []repeatTest{ 304 {"Scalar Repeat on axis 0", New(FromScalar(true)), 305 true, 0, []int{3}, 306 []bool{true, true, true}, 307 Shape{3}, false, 308 }, 309 310 {"Scalar Repeat on axis 1", New(FromScalar(byte(255))), 311 false, 1, []int{3}, 312 []byte{255, 255, 255}, 313 Shape{1, 3}, false, 314 }, 315 316 {"Vector Repeat on axis 0", New(WithShape(2), WithBacking([]int32{1, 2})), 317 false, 0, []int{3}, 318 []int32{1, 1, 1, 2, 2, 2}, 319 Shape{6}, false, 320 }, 321 322 {"ColVec Repeat on axis 0", New(WithShape(2, 1), WithBacking([]int64{1, 2})), 323 false, 0, []int{3}, 324 []int64{1, 1, 1, 2, 2, 2}, 325 Shape{6, 1}, false, 326 }, 327 328 {"RowVec Repeat on axis 0", New(WithShape(1, 2), WithBacking([]int{1, 2})), 329 false, 0, []int{3}, 330 []int{1, 2, 1, 2, 1, 2}, 331 Shape{3, 2}, false, 332 }, 333 334 {"ColVec Repeat on axis 1", New(WithShape(2, 1), WithBacking([]float32{1, 2})), 335 false, 1, []int{3}, 336 []float32{1, 1, 1, 2, 2, 2}, 337 Shape{2, 3}, false, 338 }, 339 340 {"RowVec Repeat on axis 1", New(WithShape(1, 2), WithBacking([]float64{1, 2})), 341 false, 1, []int{3}, 342 []float64{1, 1, 1, 2, 2, 2}, 343 Shape{1, 6}, false, 344 }, 345 346 {"Vector Repeat on all 
axes", New(WithShape(2), WithBacking([]byte{1, 2})), 347 false, AllAxes, []int{3}, 348 []byte{1, 1, 1, 2, 2, 2}, 349 Shape{6}, false, 350 }, 351 352 {"ColVec Repeat on all axes", New(WithShape(2, 1), WithBacking([]int32{1, 2})), 353 false, AllAxes, []int{3}, 354 []int32{1, 1, 1, 2, 2, 2}, 355 Shape{6}, false, 356 }, 357 358 {"RowVec Repeat on all axes", New(WithShape(1, 2), WithBacking([]int64{1, 2})), 359 false, AllAxes, []int{3}, 360 []int64{1, 1, 1, 2, 2, 2}, 361 Shape{6}, false, 362 }, 363 364 {"M[2,2] Repeat on all axes with repeats = (1,2,1,1)", New(WithShape(2, 2), WithBacking([]int{1, 2, 3, 4})), 365 false, AllAxes, []int{1, 2, 1, 1}, 366 []int{1, 2, 2, 3, 4}, 367 Shape{5}, false, 368 }, 369 370 {"M[2,2] Repeat on axis 1 with repeats = (2, 1)", New(WithShape(2, 2), WithBacking([]float32{1, 2, 3, 4})), 371 false, 1, []int{2, 1}, 372 []float32{1, 1, 2, 3, 3, 4}, 373 Shape{2, 3}, false, 374 }, 375 376 {"M[2,2] Repeat on axis 1 with repeats = (1, 2)", New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4})), 377 false, 1, []int{1, 2}, 378 []float64{1, 2, 2, 3, 4, 4}, 379 Shape{2, 3}, false, 380 }, 381 382 {"M[2,2] Repeat on axis 0 with repeats = (1, 2)", New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4})), 383 false, 0, []int{1, 2}, 384 []float64{1, 2, 3, 4, 3, 4}, 385 Shape{3, 2}, false, 386 }, 387 388 {"M[2,2] Repeat on axis 0 with repeats = (2, 1)", New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4})), 389 false, 0, []int{2, 1}, 390 []float64{1, 2, 1, 2, 3, 4}, 391 Shape{3, 2}, false, 392 }, 393 394 {"3T[2,3,2] Repeat on axis 1 with repeats = (1,2,1)", New(WithShape(2, 3, 2), WithBacking(vecf64.Range(1, 2*3*2+1))), 395 false, 1, []int{1, 2, 1}, 396 []float64{1, 2, 3, 4, 3, 4, 5, 6, 7, 8, 9, 10, 9, 10, 11, 12}, 397 Shape{2, 4, 2}, false, 398 }, 399 400 {"3T[2,3,2] Generic Repeat by 2", New(WithShape(2, 3, 2), WithBacking(vecf64.Range(1, 2*3*2+1))), 401 false, AllAxes, []int{2}, 402 []float64{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 
10, 11, 11, 12, 12}, 403 Shape{24}, false, 404 }, 405 406 {"3T[2,3,2] repeat with broadcast errors", New(WithShape(2, 3, 2), WithBacking(vecf64.Range(1, 2*3*2+1))), 407 false, 0, []int{1, 2, 1}, 408 nil, nil, true, 409 }, 410 411 // idiots 412 {"Nonexistent axis", New(WithShape(2, 1), WithBacking([]bool{true, false})), 413 false, 2, []int{3}, nil, nil, true, 414 }, 415 } 416 417 func TestDense_Repeat(t *testing.T) { 418 assert := assert.New(t) 419 420 for i, test := range repeatTests { 421 T, err := test.tensor.Repeat(test.axis, test.repeats...) 422 if checkErr(t, test.err, err, "Repeat", i) { 423 continue 424 } 425 426 var D DenseTensor 427 if D, err = getDenseTensor(T); err != nil { 428 t.Errorf("Expected Repeat to return a *Dense. got %v of %T instead", T, T) 429 continue 430 } 431 432 if test.ne { 433 assert.NotEqual(test.tensor, D, test.name) 434 } 435 436 assert.Equal(test.correct, D.Data(), test.name) 437 assert.Equal(test.shape, D.Shape(), test.name) 438 } 439 } 440 441 func TestDense_Repeat_Slow(t *testing.T) { 442 rt2 := make([]repeatTest, len(repeatTests)) 443 for i, rt := range repeatTests { 444 rt2[i] = repeatTest{ 445 name: rt.name, 446 ne: rt.ne, 447 axis: rt.axis, 448 repeats: rt.repeats, 449 correct: rt.correct, 450 shape: rt.shape, 451 err: rt.err, 452 tensor: rt.tensor.Clone().(*Dense), 453 } 454 } 455 for i := range rt2 { 456 maskLen := rt2[i].tensor.len() 457 mask := make([]bool, maskLen) 458 rt2[i].tensor.mask = mask 459 } 460 461 assert := assert.New(t) 462 463 for i, test := range rt2 { 464 T, err := test.tensor.Repeat(test.axis, test.repeats...) 465 if checkErr(t, test.err, err, "Repeat", i) { 466 continue 467 } 468 469 var D DenseTensor 470 if D, err = getDenseTensor(T); err != nil { 471 t.Errorf("Expected Repeat to return a *Dense. 
got %v of %T instead", T, T) 472 continue 473 } 474 475 if test.ne { 476 assert.NotEqual(test.tensor, D, test.name) 477 } 478 479 assert.Equal(test.correct, D.Data(), test.name) 480 assert.Equal(test.shape, D.Shape(), test.name) 481 } 482 } 483 484 func TestDense_CopyTo(t *testing.T) { 485 assert := assert.New(t) 486 var T, T2 *Dense 487 var T3 Tensor 488 var err error 489 490 T = New(WithShape(2), WithBacking([]float64{1, 2})) 491 T2 = New(Of(Float64), WithShape(1, 2)) 492 493 err = T.CopyTo(T2) 494 if err != nil { 495 t.Fatal(err) 496 } 497 assert.Equal(T2.Data(), T.Data()) 498 499 // now, modify T1's data 500 T.Set(0, float64(5000)) 501 assert.NotEqual(T2.Data(), T.Data()) 502 503 // test views 504 T = New(Of(Byte), WithShape(3, 3)) 505 T2 = New(Of(Byte), WithShape(2, 2)) 506 T3, _ = T.Slice(makeRS(0, 2), makeRS(0, 2)) // T[0:2, 0:2], shape == (2,2) 507 if err = T2.CopyTo(T3.(*Dense)); err != nil { 508 t.Log(err) // for now it's a not yet implemented error. TODO: FIX THIS 509 } 510 511 // dumbass time 512 513 T = New(Of(Float32), WithShape(3, 3)) 514 T2 = New(Of(Float32), WithShape(2, 2)) 515 if err = T.CopyTo(T2); err == nil { 516 t.Error("Expected an error") 517 } 518 519 if err = T.CopyTo(T); err != nil { 520 t.Error("Copying a *Tensor to itself should yield no error. 
") 521 } 522 523 } 524 525 var denseSliceTests = []struct { 526 name string 527 data interface{} 528 shape Shape 529 slices []Slice 530 531 correctShape Shape 532 correctStride []int 533 correctData interface{} 534 }{ 535 // scalar-equiv vector (issue 102) 536 {"a[0], a is scalar-equiv", []float64{2}, 537 Shape{1}, []Slice{ss(0)}, ScalarShape(), nil, 2.0}, 538 539 // vector 540 {"a[0]", []bool{true, true, false, false, false}, 541 Shape{5}, []Slice{ss(0)}, ScalarShape(), nil, true}, 542 {"a[0:2]", Range(Byte, 0, 5), Shape{5}, []Slice{makeRS(0, 2)}, Shape{2}, []int{1}, []byte{0, 1}}, 543 {"a[1:5:2]", Range(Int32, 0, 5), Shape{5}, []Slice{makeRS(1, 5, 2)}, Shape{2}, []int{2}, []int32{1, 2, 3, 4}}, 544 545 // colvec 546 {"c[0]", Range(Int64, 0, 5), Shape{5, 1}, []Slice{ss(0)}, ScalarShape(), nil, int64(0)}, 547 {"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1, 1}, []float32{0, 1}}, 548 {"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, Shape{2, 1}, []int{2, 1}, []float64{0, 1, 2, 3, 4}}, 549 550 // // rowvec 551 {"r[0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{ss(0)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, 552 {"r[0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, 553 {"r[0:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 5, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, 554 {"r[:, 0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, ss(0)}, ScalarShape(), nil, float64(0)}, 555 {"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{5, 1}, []float64{0, 1}}, 556 {"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{5, 2}, []float64{1, 2, 3, 4}}, 557 558 // // matrix 559 {"A[0]", Range(Float64, 0, 6), Shape{2, 3}, []Slice{ss(0)}, Shape{1, 3}, []int{1}, Range(Float64, 0, 3)}, 560 {"A[0:2]", Range(Float64, 0, 20), Shape{4, 5}, 
[]Slice{makeRS(0, 2)}, Shape{2, 5}, []int{5, 1}, Range(Float64, 0, 10)}, 561 {"A[0, 0]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{ss(0), ss(0)}, ScalarShape(), nil, float64(0)}, 562 {"A[0, 1:5]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{ss(0), makeRS(1, 5)}, Shape{4}, []int{1}, Range(Float64, 1, 5)}, 563 {"A[0, 1:5:2]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{ss(0), makeRS(1, 5, 2)}, Shape{1, 2}, []int{2}, Range(Float64, 1, 5)}, 564 {"A[:, 0]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, ss(0)}, Shape{4, 1}, []int{5}, Range(Float64, 0, 16)}, 565 {"A[:, 1:5]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5)}, Shape{4, 4}, []int{5, 1}, Range(Float64, 1, 20)}, 566 {"A[:, 1:5:2]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{4, 2}, []int{5, 2}, Range(Float64, 1, 20)}, 567 568 // 3tensor with leading and trailing 1s 569 570 {"3T1[0]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{ss(0)}, Shape{9, 1}, []int{1, 1}, Range(Float64, 0, 9)}, 571 {"3T1[nil, 0:2]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2, 1}, []int{9, 1, 1}, Range(Float64, 0, 2)}, 572 {"3T1[nil, 0:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 0, 5)}, 573 {"3T1[nil, 1:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 1, 5)}, 574 {"3T1[nil, 1:9:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 9, 3)}, Shape{1, 3, 1}, []int{9, 3, 1}, Range(Float64, 1, 9)}, 575 576 // 3tensor 577 {"3T[0]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(0)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 0, 18)}, 578 {"3T[1]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 18, 36)}, 579 {"3T[1, 2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), ss(2)}, Shape{2}, []int{1}, Range(Float64, 22, 24)}, 580 {"3T[1, 2:4]", 
Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 4)}, Shape{2, 2}, []int{2, 1}, Range(Float64, 22, 26)}, 581 {"3T[1, 2:8:2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 2)}, Shape{3, 2}, []int{4, 1}, Range(Float64, 22, 34)}, 582 {"3T[1, 2:8:3]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 3)}, Shape{2, 2}, []int{6, 1}, Range(Float64, 22, 34)}, 583 {"3T[1, 2:9:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2)}, Shape{4, 7}, []int{14, 1}, Range(Float64, 77, 126)}, 584 {"3T[1, 2:9:2, 1]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), ss(1)}, Shape{4}, []int{14}, Range(Float64, 78, 121)}, // should this be a colvec? 585 {"3T[1, 2:9:2, 1:4:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), makeRS(1, 4, 2)}, Shape{4, 2}, []int{14, 2}, Range(Float64, 78, 123)}, 586 } 587 588 func TestDense_Slice(t *testing.T) { 589 assert := assert.New(t) 590 var T *Dense 591 var V Tensor 592 var err error 593 594 for _, sts := range denseSliceTests { 595 T = New(WithShape(sts.shape...), WithBacking(sts.data)) 596 t.Log(sts.name) 597 if V, err = T.Slice(sts.slices...); err != nil { 598 t.Error(err) 599 continue 600 } 601 assert.True(sts.correctShape.Eq(V.Shape()), "Test: %v - Incorrect Shape. Correct: %v. 
Got %v", sts.name, sts.correctShape, V.Shape()) 602 assert.Equal(sts.correctStride, V.Strides(), "Test: %v - Incorrect Stride", sts.name) 603 assert.Equal(sts.correctData, V.Data(), "Test: %v - Incorrect Data", sts.name) 604 } 605 606 // Transposed slice 607 T = New(WithShape(2, 3), WithBacking(Range(Float32, 0, 6))) 608 T.T() 609 V, err = T.Slice(ss(0)) 610 assert.True(Shape{2}.Eq(V.Shape())) 611 assert.Equal([]int{3}, V.Strides()) 612 assert.Equal([]float32{0, 1, 2, 3}, V.Data()) 613 assert.True(V.(*Dense).old.IsZero()) 614 615 // slice a sliced 616 t.Logf("%v", V) 617 V, err = V.Slice(makeRS(1, 2)) 618 t.Logf("%v", V) 619 assert.True(ScalarShape().Eq(V.Shape())) 620 assert.Equal(float32(3), V.Data()) 621 622 // And now, ladies and gentlemen, the idiots! 623 624 // too many slices 625 _, err = T.Slice(ss(1), ss(2), ss(3), ss(4)) 626 if err == nil { 627 t.Error("Expected a DimMismatchError error") 628 } 629 630 // out of range sliced 631 _, err = T.Slice(makeRS(20, 5)) 632 if err == nil { 633 t.Error("Expected a IndexError") 634 } 635 636 // surely nobody can be this dumb? 
Having a start of negatives 637 _, err = T.Slice(makeRS(-1, 1)) 638 if err == nil { 639 t.Error("Expected a IndexError") 640 } 641 } 642 643 func TestDense_Narrow(t *testing.T) { 644 testCases := []struct { 645 x *Dense 646 dim, start, length int 647 expected *Dense 648 }{ 649 { 650 x: New( 651 WithShape(3), 652 WithBacking([]int{1, 2, 3}), 653 ), 654 dim: 0, 655 start: 1, 656 length: 1, 657 expected: New( 658 WithShape(), 659 WithBacking([]int{2}), 660 ), 661 }, 662 { 663 x: New( 664 WithShape(3, 3), 665 WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), 666 ), 667 dim: 0, 668 start: 0, 669 length: 2, 670 expected: New( 671 WithShape(2, 3), 672 WithBacking([]int{1, 2, 3, 4, 5, 6}), 673 ), 674 }, 675 { 676 x: New( 677 WithShape(3, 3), 678 WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), 679 ), 680 dim: 1, 681 start: 1, 682 length: 2, 683 expected: New( 684 WithShape(3, 2), 685 WithBacking([]int{2, 3, 5, 6, 8, 9}), 686 ), 687 }, 688 { 689 x: New( 690 WithShape(3, 3), 691 WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}), 692 ), 693 dim: 1, 694 start: 0, 695 length: 1, 696 expected: New( 697 WithShape(3), 698 WithBacking([]int{1, 4, 7}), 699 ), 700 }, 701 } 702 703 for i, tC := range testCases { 704 t.Run(fmt.Sprintf("Example #%d narrow(%v,%d,%d,%v)", i+1, tC.x.Shape(), tC.dim, tC.start, tC.length), func(t *testing.T) { 705 c := assert.New(t) 706 // t.Logf("X:\n%v", tC.x) 707 708 y, err := tC.x.Narrow(tC.dim, tC.start, tC.length) 709 c.NoError(err) 710 // t.Logf("y:\n%v", y) 711 712 yMat := y.Materialize() 713 c.Equal(tC.expected.Shape(), yMat.Shape()) 714 c.Equal(tC.expected.Data(), yMat.Data()) 715 716 // err = y.Memset(1024) 717 // c.NoError(err) 718 // t.Logf("After Memset\nY: %v\nX:\n%v", y, tC.x) 719 }) 720 } 721 } 722 723 func TestDense_SliceInto(t *testing.T) { 724 V := New(WithShape(100), Of(Byte)) 725 T := New(WithBacking([]float64{1, 2, 3, 4, 5, 6}), WithShape(2, 3)) 726 T.SliceInto(V, ss(0)) 727 728 assert.True(t, Shape{3}.Eq(V.Shape()), "Got %v", V.Shape()) 729 
assert.Equal(t, []float64{1, 2, 3}, V.Data()) 730 } 731 732 var rollaxisTests = []struct { 733 axis, start int 734 735 correctShape Shape 736 }{ 737 {0, 0, Shape{1, 2, 3, 4}}, 738 {0, 1, Shape{1, 2, 3, 4}}, 739 {0, 2, Shape{2, 1, 3, 4}}, 740 {0, 3, Shape{2, 3, 1, 4}}, 741 {0, 4, Shape{2, 3, 4, 1}}, 742 743 {1, 0, Shape{2, 1, 3, 4}}, 744 {1, 1, Shape{1, 2, 3, 4}}, 745 {1, 2, Shape{1, 2, 3, 4}}, 746 {1, 3, Shape{1, 3, 2, 4}}, 747 {1, 4, Shape{1, 3, 4, 2}}, 748 749 {2, 0, Shape{3, 1, 2, 4}}, 750 {2, 1, Shape{1, 3, 2, 4}}, 751 {2, 2, Shape{1, 2, 3, 4}}, 752 {2, 3, Shape{1, 2, 3, 4}}, 753 {2, 4, Shape{1, 2, 4, 3}}, 754 755 {3, 0, Shape{4, 1, 2, 3}}, 756 {3, 1, Shape{1, 4, 2, 3}}, 757 {3, 2, Shape{1, 2, 4, 3}}, 758 {3, 3, Shape{1, 2, 3, 4}}, 759 {3, 4, Shape{1, 2, 3, 4}}, 760 } 761 762 // The RollAxis tests are directly adapted from Numpy's test cases. 763 func TestDense_RollAxis(t *testing.T) { 764 assert := assert.New(t) 765 var T *Dense 766 var err error 767 768 for _, rats := range rollaxisTests { 769 T = New(Of(Byte), WithShape(1, 2, 3, 4)) 770 if _, err = T.RollAxis(rats.axis, rats.start, false); assert.NoError(err) { 771 assert.True(rats.correctShape.Eq(T.Shape()), "%d %d Expected %v, got %v", rats.axis, rats.start, rats.correctShape, T.Shape()) 772 } 773 } 774 } 775 776 var concatTests = []struct { 777 name string 778 dt Dtype 779 a interface{} 780 b interface{} 781 shape Shape 782 shapeB Shape 783 axis int 784 785 correctShape Shape 786 correctData interface{} 787 }{ 788 // Float64 789 {"vector", Float64, nil, nil, Shape{2}, nil, 0, Shape{4}, []float64{0, 1, 0, 1}}, 790 {"matrix; axis 0 ", Float64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}}, 791 {"matrix; axis 1 ", Float64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float64{0, 1, 0, 1, 2, 3, 2, 3}}, 792 793 // Float32 794 {"vector", Float32, nil, nil, Shape{2}, nil, 0, Shape{4}, []float32{0, 1, 0, 1}}, 795 {"matrix; axis 0 ", Float32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 
2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}}, 796 {"matrix; axis 1 ", Float32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}}, 797 798 // Int 799 {"vector", Int, nil, nil, Shape{2}, nil, 0, Shape{4}, []int{0, 1, 0, 1}}, 800 {"matrix; axis 0 ", Int, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}}, 801 {"matrix; axis 1 ", Int, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}}, 802 803 // Int64 804 {"vector", Int64, nil, nil, Shape{2}, nil, 0, Shape{4}, []int64{0, 1, 0, 1}}, 805 {"matrix; axis 0 ", Int64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}}, 806 {"matrix; axis 1 ", Int64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}}, 807 808 // Int32 809 {"vector", Int32, nil, nil, Shape{2}, nil, 0, Shape{4}, []int32{0, 1, 0, 1}}, 810 {"matrix; axis 0 ", Int32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}}, 811 {"matrix; axis 1 ", Int32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}}, 812 813 // Byte 814 {"vector", Byte, nil, nil, Shape{2}, nil, 0, Shape{4}, []byte{0, 1, 0, 1}}, 815 {"matrix; axis 0 ", Byte, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}}, 816 {"matrix; axis 1 ", Byte, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}}, 817 818 // Bool 819 {"vector", Bool, []bool{true, false}, nil, Shape{2}, nil, 0, Shape{4}, []bool{true, false, true, false}}, 820 {"matrix; axis 0 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []bool{true, false, true, false, true, false, true, false}}, 821 {"matrix; axis 1 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, true, false}}, 822 823 // gorgonia/gorgonia#218 related 824 {"matrix; axis 0", Float64, nil, nil, Shape{2, 2}, Shape{1, 2}, 0, Shape{3, 2}, []float64{0, 1, 2, 
3, 0, 1}},
	{"matrix; axis 1", Float64, nil, nil, Shape{2, 2}, Shape{2, 1}, 1, Shape{2, 3}, []float64{0, 1, 0, 2, 3, 1}},
	{"colvec matrix, axis 0", Float64, nil, nil, Shape{2, 1}, Shape{1, 1}, 0, Shape{3, 1}, []float64{0, 1, 0}},
	{"rowvec matrix, axis 1", Float64, nil, nil, Shape{1, 2}, Shape{1, 1}, 1, Shape{1, 3}, []float64{0, 1, 0}},

	{"3tensor; axis 0", Float64, nil, nil, Shape{2, 3, 2}, Shape{1, 3, 2}, 0, Shape{3, 3, 2}, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5}},
	{"3tensor; axis 2", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 3, 1}, 2, Shape{2, 3, 3}, []float64{0, 1, 0, 2, 3, 1, 4, 5, 2, 6, 7, 3, 8, 9, 4, 10, 11, 5}},
	{"3tensor; axis 1", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 1, 2}, 1, Shape{2, 4, 2}, []float64{0, 1, 2, 3, 4, 5, 0, 1, 6, 7, 8, 9, 10, 11, 2, 3}},
}

// TestDense_Concat runs every concatTests case twice: first on plain
// tensors, then with MaskedEqual(0) applied to both operands. In the masked
// pass the mask of the result is additionally compared against the mask of
// the expected data after the same MaskedEqual(0) call.
func TestDense_Concat(t *testing.T) {
	assert := assert.New(t)

	for _, masked := range []bool{false, true} {
		for _, cts := range concatTests {
			var T0, T1 *Dense

			// First operand: explicit backing when given, otherwise Range(dt, 0, size).
			if cts.a == nil {
				T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize())))
			} else {
				T0 = New(WithShape(cts.shape...), WithBacking(cts.a))
			}

			// Second operand: shapeB (when set) gives it its own shape and
			// backing; otherwise it mirrors the first operand.
			switch {
			case cts.shapeB == nil && cts.a == nil:
				T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize())))
			case cts.shapeB == nil && cts.a != nil:
				// clone so the two operands do not share a backing array
				T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a)))
			case cts.shapeB != nil && cts.b == nil:
				T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize())))
			case cts.shapeB != nil && cts.b != nil:
				T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b))
			}

			if masked {
				T0.MaskedEqual(castToDt(0.0, cts.dt))
				T1.MaskedEqual(castToDt(0.0, cts.dt))
			}

			T2, err := T0.Concat(cts.axis, T1)
			if err != nil {
				t.Errorf("Test %v (masked: %t) failed: %v", cts.name, masked, err)
				continue
			}

			assert.True(cts.correctShape.Eq(T2.Shape()), "%v (masked: %t): wrong shape %v", cts.name, masked, T2.Shape())
			assert.Equal(cts.correctData, T2.Data(), "%v (masked: %t): wrong data", cts.name, masked)

			if masked {
				T3 := New(WithShape(cts.correctShape...), WithBacking(cts.correctData))
				T3.MaskedEqual(castToDt(0.0, cts.dt))
				assert.Equal(T3.mask, T2.mask, "%v: wrong mask", cts.name)
			}
		}
	}
}

// TestDense_Concat_sliced concatenates column views sliced out of a single
// row vector and checks that the result reproduces the original data
// without sharing its backing memory.
func TestDense_Concat_sliced(t *testing.T) {
	v := New(
		WithShape(1, 5),
		WithBacking([]float64{0, 1, 2, 3, 4}),
	)
	cols := make([]Tensor, v.Shape().TotalSize())
	for i := range cols {
		sliced, err := v.Slice(nil, ss(i))
		if err != nil {
			t.Fatalf("Failed to slice %d. Error: %v", i, err)
		}
		if err = sliced.Reshape(sliced.Shape().TotalSize(), 1); err != nil {
			t.Fatalf("Failed to reshape %d. Error %v", i, err)
		}
		cols[i] = sliced
	}

	result, err := Concat(1, cols[0], cols[1:]...)
	if err != nil {
		// Fatal rather than Error: result cannot be inspected when Concat fails.
		t.Fatal(err)
	}
	assert.Equal(t, v.Data(), result.Data())
	if v.Uintptr() == result.Uintptr() {
		t.Error("They should not share the same backing data!")
	}
}

// simpleStackTests: the first operand is Range(dt, 0, size) and the k-th
// extra operand (k = 1..stackCount-1) is Range(dt, 100*k, 100*k+size).
var simpleStackTests = []struct {
	name       string
	dt         Dtype
	shape      Shape
	axis       int
	stackCount int

	correctShape Shape
	correctData  interface{}
}{
	// Size 8
	{"vector, axis 0, stack 2", Float64, Shape{2}, 0, 2, Shape{2, 2}, []float64{0, 1, 100, 101}},
	{"vector, axis 1, stack 2", Float64, Shape{2}, 1, 2, Shape{2, 2}, []float64{0, 100, 1, 101}},
	{"matrix, axis 0, stack 2", Float64, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []float64{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
	{"matrix, axis 1, stack 2", Float64, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []float64{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
	{"matrix, axis 2, stack 2", Float64, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []float64{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
	{"matrix, axis 0, stack 3", Float64, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []float64{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
	{"matrix, axis 1, stack 3", Float64, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []float64{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
	{"matrix, axis 2, stack 3", Float64, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []float64{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},

	// Size 4
	{"vector, axis 0, stack 2 (f32)", Float32, Shape{2}, 0, 2, Shape{2, 2}, []float32{0, 1, 100, 101}},
	{"vector, axis 1, stack 2 (f32)", Float32, Shape{2}, 1, 2, Shape{2, 2}, []float32{0, 100, 1, 101}},
	{"matrix, axis 0, stack 2 (f32)", Float32, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []float32{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
	{"matrix, axis 1, stack 2 (f32)", Float32, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []float32{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
	{"matrix, axis 2, stack 2 (f32)", Float32, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []float32{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
	{"matrix, axis 0, stack 3 (f32)", Float32, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []float32{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
	{"matrix, axis 1, stack 3 (f32)", Float32, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []float32{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
	{"matrix, axis 2, stack 3 (f32)", Float32, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []float32{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},

	// Size 2
	{"vector, axis 0, stack 2 (i16)", Int16, Shape{2}, 0, 2, Shape{2, 2}, []int16{0, 1, 100, 101}},
	{"vector, axis 1, stack 2 (i16)", Int16, Shape{2}, 1, 2, Shape{2, 2}, []int16{0, 100, 1, 101}},
	{"matrix, axis 0, stack 2 (i16)", Int16, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []int16{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
	{"matrix, axis 1, stack 2 (i16)", Int16, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []int16{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
	{"matrix, axis 2, stack 2 (i16)", Int16, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []int16{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
	{"matrix, axis 0, stack 3 (i16)", Int16, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []int16{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
	{"matrix, axis 1, stack 3 (i16)", Int16, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []int16{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
	{"matrix, axis 2, stack 3 (i16)", Int16, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []int16{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},

	// Size 1
	{"vector, axis 0, stack 2 (u8)", Byte, Shape{2}, 0, 2, Shape{2, 2}, []byte{0, 1, 100, 101}},
	{"vector, axis 1, stack 2 (u8)", Byte, Shape{2}, 1, 2, Shape{2, 2}, []byte{0, 100, 1, 101}},
	{"matrix, axis 0, stack 2 (u8)", Byte, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []byte{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
	{"matrix, axis 1, stack 2 (u8)", Byte, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []byte{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
	{"matrix, axis 2, stack 2 (u8)", Byte, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []byte{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
	{"matrix, axis 0, stack 3 (u8)", Byte, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []byte{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
	{"matrix, axis 1, stack 3 (u8)", Byte, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []byte{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
	{"matrix, axis 2, stack 3 (u8)", Byte, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []byte{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},
}

// viewStackTests: like simpleStackTests, but every operand is first viewed
// through slices (or transposed by transform) before being stacked.
var viewStackTests = []struct {
	name       string
	dt         Dtype
	shape      Shape
	transform  []int
	slices     []Slice
	axis       int
	stackCount int

	correctShape Shape
	correctData  interface{}
}{
	// Size 8
	{"matrix(4x4)[1:3, 1:3] axis 0", Float64, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []float64{5, 6, 9, 10, 105, 106, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 1", Float64, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []float64{5, 6, 105, 106, 9, 10, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 2", Float64, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []float64{5, 105, 6, 106, 9, 109, 10, 110}},

	// Size 4
	{"matrix(4x4)[1:3, 1:3] axis 0 (u32)", Uint32, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []uint32{5, 6, 9, 10, 105, 106, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 1 (u32)", Uint32, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []uint32{5, 6, 105, 106, 9, 10, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 2 (u32)", Uint32, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []uint32{5, 105, 6, 106, 9, 109, 10, 110}},

	// Size 2
	{"matrix(4x4)[1:3, 1:3] axis 0 (u16)", Uint16, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []uint16{5, 6, 9, 10, 105, 106, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 1 (u16)", Uint16, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []uint16{5, 6, 105, 106, 9, 10, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 2 (u16)", Uint16, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []uint16{5, 105, 6, 106, 9, 109, 10, 110}},

	// Size 1
	{"matrix(4x4)[1:3, 1:3] axis 0 (u8)", Byte, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []byte{5, 6, 9, 10, 105, 106, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 1 (u8)", Byte, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []byte{5, 6, 105, 106, 9, 10, 109, 110}},
	{"matrix(4x4)[1:3, 1:3] axis 2 (u8)", Byte, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []byte{5, 105, 6, 106, 9, 109, 10, 110}},
}

// TestDense_Stack runs simpleStackTests and viewStackTests twice: first on
// plain operands, then with MaskedInside(maskLo, maskHi) applied to every
// extra (non-receiver) operand. In the masked pass the result's mask is also
// compared against the expected data masked the same way. Finally it stacks
// a full-slice view of a string tensor.
func TestDense_Stack(t *testing.T) {
	assert := assert.New(t)

	// Bounds handed to MaskedInside in the masked pass.
	const maskLo, maskHi = 102.0, 225.0

	// applyView returns the view of T described by slices/transform: a
	// slice when only slices is set, an in-place transpose when only
	// transform is set, and T itself otherwise.
	applyView := func(T *Dense, transform []int, slices []Slice) (*Dense, error) {
		switch {
		case slices != nil && transform == nil:
			sliced, err := T.Slice(slices...)
			if err != nil {
				return nil, err
			}
			return sliced.(*Dense), nil
		case transform != nil && slices == nil:
			if err := T.T(transform...); err != nil {
				return nil, err
			}
		}
		return T, nil
	}

	for _, masked := range []bool{false, true} {
		for _, sts := range simpleStackTests {
			T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize())))

			var stacked []*Dense
			for i := 0; i < sts.stackCount-1; i++ {
				offset := (i + 1) * 100
				T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset)))
				if masked {
					T1.MaskedInside(castToDt(maskLo, sts.dt), castToDt(maskHi, sts.dt))
				}
				stacked = append(stacked, T1)
			}

			T2, err := T.Stack(sts.axis, stacked...)
			if err != nil {
				t.Errorf("%q (masked: %t): %v", sts.name, masked, err)
				continue
			}
			assert.True(sts.correctShape.Eq(T2.Shape()), "%q (masked: %t): wrong shape %v", sts.name, masked, T2.Shape())
			assert.Equal(sts.correctData, T2.Data(), "%q (masked: %t) failed", sts.name, masked)

			if masked {
				T3 := New(WithShape(sts.correctShape...), WithBacking(sts.correctData))
				T3.MaskedInside(castToDt(maskLo, sts.dt), castToDt(maskHi, sts.dt))
				assert.Equal(T3.mask, T2.mask, "%q: wrong mask", sts.name)
			}
		}

	viewCases:
		for _, sts := range viewStackTests {
			// operand builds Range(offset, offset+size), optionally masks it,
			// then views it through the case's slices/transform.
			operand := func(offset int, mask bool) (*Dense, error) {
				op := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset)))
				if mask {
					op.MaskedInside(castToDt(maskLo, sts.dt), castToDt(maskHi, sts.dt))
				}
				return applyView(op, sts.transform, sts.slices)
			}

			// The receiver is never masked, matching the expected masks below.
			T, err := operand(0, false)
			if err != nil {
				t.Errorf("%q (masked: %t): building receiver: %v", sts.name, masked, err)
				continue
			}

			var stacked []*Dense
			for i := 0; i < sts.stackCount-1; i++ {
				T1, err := operand((i+1)*100, masked)
				if err != nil {
					// Skip the whole case: stacking fewer operands would only
					// produce misleading follow-on failures.
					t.Errorf("%q (masked: %t): building operand %d: %v", sts.name, masked, i+1, err)
					continue viewCases
				}
				stacked = append(stacked, T1)
			}

			T2, err := T.Stack(sts.axis, stacked...)
			if err != nil {
				t.Errorf("%q (masked: %t): %v", sts.name, masked, err)
				continue
			}
			assert.True(sts.correctShape.Eq(T2.Shape()), "%q (masked: %t): wrong shape %v", sts.name, masked, T2.Shape())
			assert.Equal(sts.correctData, T2.Data(), "%q (masked: %t) failed", sts.name, masked)

			if masked {
				T3 := New(WithShape(sts.correctShape...), WithBacking(sts.correctData))
				T3.MaskedInside(castToDt(maskLo, sts.dt), castToDt(maskHi, sts.dt))
				assert.Equal(T3.mask, T2.mask, "%q: wrong mask", sts.name)
			}
		}
	}

	// arbitrary view slices: stack a full-slice view of a string tensor.
	T := New(WithShape(2, 2), WithBacking([]string{"hello", "world", "nihao", "sekai"}))
	T1 := New(WithShape(2, 2), WithBacking([]string{"blah1", "blah2", "blah3", "blah4"}))
	sliced, err := T1.Slice(nil, nil)
	if err != nil {
		t.Fatal(err)
	}
	T2, err := T.Stack(0, sliced.(*Dense))
	if err != nil {
		t.Fatal(err)
	}

	correctShape := Shape{2, 2, 2}
	correctData := []string{"hello", "world", "nihao", "sekai", "blah1", "blah2", "blah3", "blah4"}
	assert.True(correctShape.Eq(T2.Shape()))
	assert.Equal(correctData, T2.Data(), "%q failed", "arbitrary view slice")
}