gorgonia.org/gorgonia@v0.9.17/op_by_indices_test.go (about) 1 package gorgonia 2 3 import ( 4 "log" 5 "testing" 6 7 "github.com/stretchr/testify/require" 8 "gorgonia.org/tensor" 9 ) 10 11 func TestByIndicesOpDo(t *testing.T) { 12 testCases := []struct { 13 desc string 14 input tensor.Tensor 15 indices tensor.Tensor 16 axis int 17 expectedOutput []float64 18 expectedShape tensor.Shape 19 }{ 20 { 21 desc: "Example 1", 22 input: tensor.New( 23 tensor.WithShape(4, 2), 24 tensor.WithBacking(tensor.Range(tensor.Float64, 0, 8)), 25 ), 26 indices: tensor.New( 27 tensor.WithShape(4), 28 tensor.WithBacking([]int{0, 3, 2, 1}), 29 ), 30 axis: 0, 31 expectedOutput: []float64{0, 1, 6, 7, 4, 5, 2, 3}, 32 expectedShape: tensor.Shape{4, 2}, 33 }, 34 { 35 // 0 1 2 36 // 3 4 5 37 desc: "Example 2", 38 input: tensor.New( 39 tensor.WithShape(2, 3), 40 tensor.WithBacking(tensor.Range(tensor.Float64, 0, 6)), 41 ), 42 indices: tensor.New( 43 tensor.WithShape(4), 44 tensor.WithBacking([]int{0, 2, 1, 1}), 45 ), 46 axis: 1, 47 expectedOutput: []float64{0, 2, 1, 1, 3, 5, 4, 4}, 48 expectedShape: tensor.Shape{2, 4}, 49 }, 50 { 51 desc: "Example 3", 52 input: tensor.New( 53 tensor.WithShape(2, 5), 54 tensor.WithBacking(tensor.Range(tensor.Float64, 0, 10)), 55 ), 56 indices: tensor.New( 57 tensor.WithShape(2), 58 tensor.WithBacking([]int{1, 1}), 59 ), 60 axis: 0, 61 expectedOutput: []float64{5, 6, 7, 8, 9, 5, 6, 7, 8, 9}, 62 expectedShape: tensor.Shape{2, 5}, 63 }, 64 } 65 66 for _, tcase := range testCases { 67 t.Run(tcase.desc, func(t *testing.T) { 68 c := require.New(t) 69 70 op := newByIndicesOp(tcase.axis) 71 72 inputV, _, _, err := anyToValue(tcase.input) 73 c.NoError(err) 74 75 indicesV, _, _, err := anyToValue(tcase.indices) 76 c.NoError(err) 77 78 output, err := op.Do(inputV, indicesV) 79 c.NoError(err) 80 81 c.Equal(tcase.expectedOutput, output.Data()) 82 c.Equal(tcase.expectedShape, output.Shape()) 83 }) 84 } 85 } 86 87 func TestByIndicesOpFull(t *testing.T) { 88 testCases := []struct { 89 desc string 90 input tensor.Tensor 91 indices tensor.Tensor 92 axis int 93 expectedOutput []float64 94 expectedShape tensor.Shape 95 }{ 96 { 97 desc: "Example 1", 98 input: tensor.New( 99 tensor.WithShape(4, 2), 100 tensor.WithBacking(tensor.Range(tensor.Float64, 0, 8)), 101 ), 102 indices: tensor.New( 103 tensor.WithShape(4), 104 tensor.WithBacking([]int{0, 3, 2, 1}), 105 ), 106 axis: 0, 107 expectedOutput: []float64{0, 1, 6, 7, 4, 5, 2, 3}, 108 expectedShape: tensor.Shape{4, 2}, 109 }, 110 } 111 112 for _, tcase := range testCases { 113 t.Run(tcase.desc, func(t *testing.T) { 114 c := require.New(t) 115 116 g := NewGraph() 117 118 indices := NewTensor(g, tensor.Int, 1, WithName("indices"), WithShape(tcase.indices.Shape().TotalSize()), WithValue(tcase.indices)) 119 input := NewTensor(g, tensor.Float64, tcase.input.Shape().Dims(), WithName("input"), WithShape(tcase.input.Shape()...), WithValue(tcase.input)) 120 121 output, err := ByIndices(input, indices, tcase.axis) 122 c.NoError(err) 123 124 log.Printf("output shape: %v", output.Shape()) 125 log.Printf("input shape: %v", input.Shape()) 126 127 y := NewTensor(g, tensor.Float64, tcase.input.Shape().Dims(), WithName("target"), WithShape(tcase.input.Shape()...), WithValue(tcase.input)) 128 129 cost := Must(Mean(Must((Sub(output, y))))) // MSE 130 131 _, err = Grad(cost, input) 132 c.NoError(err) 133 134 // logger := log.New(os.Stdout, "", 0) 135 136 vm := NewTapeMachine( 137 g, 138 //WithLogger(logger), 139 WithWatchlist(), 140 BindDualValues(output), 141 TraceExec(), 142 ) 143 144 c.NoError(vm.RunAll()) 145 c.NoError(vm.Close()) 146 147 log.Printf("input %v", input.Value()) 148 log.Printf("result: %v", output.Value()) 149 log.Printf("cost: %v", cost.Value()) 150 151 c.Equal(tcase.expectedOutput, output.Value().Data()) 152 c.Equal(tcase.expectedShape, output.Shape()) 153 c.Equal(0.0, cost.Value().Data().(float64)) 154 }) 155 } 156 }