gorgonia.org/gorgonia@v0.9.17/op_sparsemax_test.go (about) 1 package gorgonia 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/require" 7 "gorgonia.org/tensor" 8 ) 9 10 var testCasesSparseMaxDo = []struct { 11 size tensor.Shape 12 input interface{} 13 expected interface{} 14 axis int 15 }{ 16 { 17 tensor.Shape{4}, []float64{0.3, 0.1, 1.2, 2.3}, []float64{0, 0, 0, 1.0}, -1, 18 }, 19 { 20 tensor.Shape{10}, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, -1, 21 }, 22 { 23 tensor.Shape{3}, []float64{0.1, 0.1, 0.1}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333}, -1, 24 }, 25 { 26 tensor.Shape{4}, []float64{-0.1, 0.3, -1.1, 2.7}, []float64{0, 0, 0, 1.0}, -1, 27 }, 28 { 29 tensor.Shape{4}, []float32{0.3, 0.1, 1.2, 2.3}, []float32{0, 0, 0, 1.0}, -1, 30 }, 31 { 32 tensor.Shape{10}, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, -1, 33 }, 34 { 35 tensor.Shape{3}, []float32{0.1, 0.1, 0.1}, []float32{0.33333334, 0.33333334, 0.33333334}, -1, 36 }, 37 { 38 tensor.Shape{4}, []float32{-0.1, 0.3, -1.1, 2.7}, []float32{0, 0, 0, 1.0}, -1, 39 }, 40 { 41 tensor.Shape{4}, []float64{0.9, 0.9, 0.9, 0.5}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.0000}, -1, 42 }, 43 { 44 tensor.Shape{6, 2}, 45 []float64{-1.0000, -1.0000, 1.0000, 1.0000, -0.9998, -0.9998, 0.9998, 0.9998, 0.9945, 0.9945, -0.9945, -0.9945}, 46 []float64{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, 47 -1, 48 }, 49 // { 50 // tensor.Shape{6, 2}, 51 // []float64{-1.0, -1.0, 1.0, 1.0, -0.9998, -0.9998, 0.9998, 0.9998, 0.9945, 0.9945, -0.9945, -0.9945}, 52 // []float64{0.0000, 0.0000, 0.3352, 0.3352, 0.0000, 0.0000, 0.3350, 0.3350, 0.3297, 0.3297, 0.0000, 0.0000}, 53 // 0, // TODO 54 // }, 55 { 56 tensor.Shape{6, 2}, 57 []float32{-1.0000, -1.0000, 1.0000, 1.0000, -0.9998, -0.9998, 0.9998, 0.9998, 0.9945, 0.9945, -0.9945, -0.9945}, 58 []float32{0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000}, 59 -1, 60 }, 61 } 62 63 var testCasesSparseMaxDoDiff = []struct { 64 shape tensor.Shape 65 input interface{} 66 grad interface{} 67 68 expected interface{} 69 expectedShape tensor.Shape 70 }{ 71 { 72 tensor.Shape{5}, 73 []float64{1.9968e-05, 1.9968e-05, 5.2120e-02, 2.3542e-01, 7.1242e-01}, 74 []float64{0.2860, -0.0702, 0.8080, 0.9913, 1.4683}, 75 []float64{-0.41068, -0.76688, 0.11132000000000009, 0.29462, 0.77162}, 76 tensor.Shape{5}, 77 }, 78 { 79 tensor.Shape{5}, 80 []float64{5.5620e-02, 2.0027e-05, 7.1182e-01, 2.3252e-01, 2.0027e-05}, 81 []float64{0.1109, -1.4741, 0.7671, 0.2878, 0.0334}, 82 []float64{0.16588, -1.41912, 0.82208, 0.34278, 0.08837999999999999}, 83 tensor.Shape{5}, 84 }, 85 { 86 tensor.Shape{5}, 87 []float64{0.0369, 0.3210, 0.0000, 0.3210, 0.3210}, 88 []float64{0.2094, -1.0000, 0.6411, -0.5032, -0.3909}, 89 []float64{0.630575, -0.5788249999999999, 0, -0.08202499999999996, 0.030274999999999996}, 90 tensor.Shape{5}, 91 }, 92 { 93 tensor.Shape{5}, 94 []float64{0.2592, 0.0000, 0.6909, 0.0498, 0.0000}, 95 []float64{0.2094, -1.0000, 0.6411, 0.0000, -0.3909}, 96 []float64{-0.07410000000000003, 0, 0.3576, -0.28350000000000003, 0}, 97 tensor.Shape{5}, 98 }, 99 { 100 tensor.Shape{5}, 101 []float32{0.0000, 0.0000, 0.0521, 0.2354, 0.7124}, 102 []float32{0.2860, -0.0702, 0.8080, 0.9913, 1.4683}, 103 []float32{-0, -0, -0.2812, -0.09790003, 0.37909997}, 104 tensor.Shape{5}, 105 }, 106 { 107 tensor.Shape{5}, 108 []float32{0.0556, 0.0000, 0.7118, 0.2325, 0.0000}, 109 []float32{0.1109, -1.4741, 0.7671, 0.2878, 0.0334}, 110 []float32{-0.2777, -0, 0.37849998, -0.10079998, -0}, 111 tensor.Shape{5}, 112 }, 113 { 114 tensor.Shape{5}, 115 []float32{0.2841, 0.0000, 0.7159, 0.0000, 0.0000}, 116 []float32{0.2094, -1.0000, 0.6411, -0.5032, -0.3909}, 117 []float32{-0.21585, -0, 0.21585, -0, -0}, 118 tensor.Shape{5}, 119 }, 120 { 121 tensor.Shape{5}, 122 []float32{0.2592, 0.0000, 0.6909, 0.0498, 0.0000}, 123 []float32{0.2094, -1.0000, 0.6411, 0.0000, -0.3909}, 124 []float32{-0.07409999, -0, 0.3576, -0.2835, -0}, 125 tensor.Shape{5}, 126 }, 127 { 128 tensor.Shape{5, 1}, 129 []float32{1, 1, 1, 1, 1}, 130 []float32{0.2094, -1.0000, 0.6411, -0.5032, -0.3909}, 131 []float32{1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995}, 132 tensor.Shape{5, 5}, 133 }, 134 } 135 136 func TestSparsemaxDo(t *testing.T) { 137 c := require.New(t) 138 139 for i, testCase := range testCasesSparseMaxDo { 140 dtype := tensor.Float64 141 142 switch testCase.input.(type) { 143 case []float32: 144 dtype = tensor.Float32 145 } 146 147 tt := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size...), tensor.WithBacking(testCase.input)) 148 op := newSparsemaxOp(testCase.axis) 149 150 out, err := op.Do(tt) 151 c.NoError(err, "failed test case: %d", i) 152 c.Equal(testCase.expected, out.Data(), "failed test case: %d", i) 153 } 154 } 155 156 func TestSparsemaxDoDiff(t *testing.T) { 157 c := require.New(t) 158 159 for i, testCase := range testCasesSparseMaxDoDiff { 160 g := NewGraph() 161 a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1)) 162 b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1)) 163 164 op := newSparsemaxOp() 165 r, err := ApplyOp(op, a) 166 c.NoError(err) 167 168 var backing interface{} 169 170 switch testCase.input.(type) { 171 case []float64: 172 backing = make([]float64, testCase.expectedShape.TotalSize()) 173 case []float32: 174 backing = make([]float32, testCase.expectedShape.TotalSize()) 175 } 176 177 aT := tensor.New(tensor.WithShape(testCase.shape...), tensor.WithBacking(testCase.input)) 178 bT := tensor.New(tensor.WithShape(testCase.shape.TotalSize()), tensor.WithBacking(testCase.grad)) 179 rT := tensor.New(tensor.WithShape(testCase.expectedShape...), tensor.WithBacking(backing)) 180 181 aVal, _, _, _ := anyToValue(aT) 182 bVal, _, _, _ := anyToValue(bT) 183 rVal, _, _, _ := anyToValue(rT) 184 185 a.bind(dvUnit(aVal)) 186 b.bind(dvUnit(bVal)) 187 r.bind(dvUnitVar(rVal)) 188 189 err = op.DoDiff(ExecutionContext{}, Nodes{a, b}, r) 190 c.NoError(err, "failed test case: %d", i) 191 192 c.Equal(testCase.expected, r.boundTo.Data()) 193 } 194 } 195 196 func TestSparsemaxDoSymDiff(t *testing.T) { 197 c := require.New(t) 198 199 for i, testCase := range testCasesSparseMaxDoDiff { 200 g := NewGraph() 201 a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1)) 202 b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1)) 203 204 aT := tensor.New(tensor.WithShape(testCase.shape...), tensor.WithBacking(testCase.input)) 205 bT := tensor.New(tensor.WithShape(testCase.shape.TotalSize()), tensor.WithBacking(testCase.grad)) 206 207 aVal, _, _, _ := anyToValue(aT) 208 bVal, _, _, _ := anyToValue(bT) 209 210 a.bind(dvUnit(aVal)) 211 b.bind(dvUnit(bVal)) 212 213 op := newSparsemaxOp() 214 diff, err := op.SymDiff(Nodes{a}, nil, b) 215 c.NoError(err, "failed test case: %d", i) 216 217 c.Len(diff, 1) 218 219 vm := NewTapeMachine(g) 220 221 c.NoError(vm.RunAll()) 222 c.NoError(vm.Close()) 223 224 c.Equal(testCase.expected, diff[0].boundTo.Data(), "failed test case: %d", i) 225 } 226 } 227 228 func TestSparsemaxFull(t *testing.T) { 229 c := require.New(t) 230 231 for i, testCase := range testCasesSparseMaxDo { 232 dtype := tensor.Float64 233 234 if _, ok := testCase.input.([]float32); ok { 235 dtype = tensor.Float32 236 } 237 238 tt := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size...), tensor.WithBacking(testCase.input)) 239 expected := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size...), tensor.WithBacking(testCase.expected)) 240 241 g := NewGraph() 242 inp := NewTensor(g, dtype, testCase.size.Dims(), WithShape(testCase.size...), WithName("inp")) 243 out := Must(Sparsemax(inp, testCase.axis)) 244 245 vm := NewTapeMachine(g) 246 err := Let(inp, tt) 247 c.NoError(err, "failed assigning input on case %d", i) 248 249 c.NoError(vm.RunAll()) 250 c.NoError(vm.Close()) 251 252 c.Equal(expected.Data(), out.Value().(*tensor.Dense).Data(), "output is not equal to expected value for case %d", i) 253 } 254 }