github.com/wzzhu/tensor@v0.9.24/dense_selbyidx_test.go (about) 1 package tensor 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 ) 8 9 type selByIndicesTest struct { 10 Name string 11 Data interface{} 12 Shape Shape 13 Indices []int 14 Axis int 15 WillErr bool 16 17 Correct interface{} 18 CorrectShape Shape 19 } 20 21 var selByIndicesTests = []selByIndicesTest{ 22 {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, 23 Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2}, 24 }, 25 {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, 26 Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, 27 28 {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, 29 Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, 30 31 {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, 32 Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, 33 34 {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, 35 Correct: []int{1, 1}, CorrectShape: Shape{2}}, 36 37 {Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true, 38 Correct: []int{1, 1}, CorrectShape: Shape{2}}, 39 {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, 40 Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, 41 {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, 42 Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, 43 }, 44 } 45 46 func TestDense_SelectByIndices(t *testing.T) { 47 assert := assert.New(t) 48 for i, tc := range selByIndicesTests { 49 T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) 50 indices := New(WithBacking(tc.Indices)) 51 ret, err := ByIndices(T, indices, tc.Axis) 52 if checkErr(t, tc.WillErr, err, tc.Name, i) { 53 continue 54 } 55 assert.Equal(tc.Correct, ret.Data()) 56 assert.True(tc.CorrectShape.Eq(ret.Shape())) 57 } 58 } 59 60 var selByIndicesBTests = []struct { 61 selByIndicesTest 62 63 CorrectGrad interface{} 64 CorrectGradShape Shape 65 }{ 66 // Basic 67 { 68 CorrectGrad: []float64{1, 1, 1, 1}, 69 }, 70 // 3-tensor, axis 0 71 { 72 CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0}, 73 }, 74 // 3-tensor, axis 1 75 { 76 CorrectGrad: []float64{0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2}, 77 }, 78 // 3-tensor, axis 2 79 { 80 CorrectGrad: []float64{0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0}, 81 }, 82 // vector, axis 0 83 { 84 CorrectGrad: []int{0, 2, 0, 0, 0}, 85 }, 86 // vector, axis 1 87 { 88 CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, 89 }, 90 // (4,2) Matrix with (10) indices 91 { 92 CorrectGrad: []float32{2, 2, 4, 4, 4, 4, 0, 0}, 93 }, 94 // (2, 1) Matrix (colvec) with (10) indices 95 { 96 CorrectGrad: []float64{0, 10}, 97 }, 98 } 99 100 func init() { 101 for i := range selByIndicesBTests { 102 selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] 103 selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape 104 } 105 } 106 107 func TestDense_SelectByIndicesB(t *testing.T) { 108 109 assert := assert.New(t) 110 for i, tc := range selByIndicesBTests { 111 T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) 112 indices := New(WithBacking(tc.Indices)) 113 ret, err := ByIndices(T, indices, tc.Axis) 114 if checkErr(t, tc.WillErr, err, tc.Name, i) { 115 continue 116 } 117 outGrad := ret.Clone().(*Dense) 118 switch outGrad.Dtype() { 119 case Float64: 120 outGrad.Memset(1.0) 121 case Float32: 122 outGrad.Memset(float32(1.0)) 123 } 124 125 grad, err := ByIndicesB(T, outGrad, indices, tc.Axis) 126 if checkErr(t, tc.WillErr, err, tc.Name, i) { 127 continue 128 } 129 assert.Equal(tc.CorrectGrad, grad.Data(), "%v - x:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, T, indices, ret, grad) 130 assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead.\n\nx:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, tc.CorrectGradShape, grad.Shape(), T, indices, ret, grad) 131 } 132 133 }