github.com/wzzhu/tensor@v0.9.24/dense_softmax_test.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 ) 9 10 func TestSoftMax(t *testing.T) { 11 testCases := []struct { 12 fn func(x Tensor, axis int, opts ...FuncOpt) (Tensor, error) 13 x Tensor 14 axis int 15 expectedOutput interface{} 16 }{ 17 { 18 fn: LogSoftMax, 19 x: New( 20 Of(Float64), 21 WithShape(3, 4), 22 WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 23 ), 24 axis: -1, 25 expectedOutput: []float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}, 26 }, 27 { 28 fn: LogSoftMax, 29 x: New( 30 Of(Float32), 31 WithShape(3, 4), 32 WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 33 ), 34 axis: -1, 35 expectedOutput: []float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}, 36 }, 37 { 38 fn: LogSoftMax, 39 x: New( 40 Of(Float32), 41 WithShape(3, 2, 2), 42 WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 43 ), 44 axis: -1, 45 expectedOutput: []float32{-0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443966, -0.64439666, -0.7443966, -0.64439666, -0.7443967, -0.64439666}, 46 }, 47 { 48 fn: LogSoftMax, 49 x: New( 50 Of(Float64), 51 WithShape(3, 2, 2), 52 WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 53 ), 54 axis: 1, 55 expectedOutput: []float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}, 56 }, 57 { 58 fn: SoftMax, 59 x: New( 60 Of(Float64), 61 WithShape(3, 2, 2), 62 WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 63 ), 64 axis: 1, 65 expectedOutput: []float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}, 66 }, 67 { 68 fn: SoftMax, 69 x: New( 70 Of(Float64), 71 WithShape(3, 2, 2), 72 WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 73 ), 74 axis: -1, 75 expectedOutput: []float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}, 76 }, 77 { 78 fn: SoftMax, 79 x: New( 80 Of(Float32), 81 WithShape(3, 4), 82 WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 83 ), 84 axis: -1, 85 expectedOutput: []float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}, 86 }, 87 { 88 fn: SoftMax, 89 x: New( 90 Of(Float64), 91 WithShape(3, 4), 92 WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}), 93 ), 94 axis: -1, 95 expectedOutput: []float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}, 96 }, 97 } 98 for i, tC := range testCases { 99 t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.x.Shape(), tC.x.Dtype()), func(t *testing.T) { 100 c := assert.New(t) 101 102 output, err := tC.fn(tC.x, tC.axis) 103 t.Logf("output: %#v", output.Data()) 104 105 c.NoError(err) 106 c.NotNil(output) 107 108 c.Equal(tC.x.Shape(), output.Shape()) 109 c.InDeltaSlice(tC.expectedOutput, output.Data(), 1e-6) 110 }) 111 } 112 } 113 114 func TestSoftMaxB(t *testing.T) { 115 testCases := []struct { 116 fn func(output, grad Tensor, axis int, opts ...FuncOpt) (Tensor, error) 117 output Tensor 118 grad Tensor 119 axis int 120 expectedOutput interface{} 121 }{ 122 { 123 fn: SoftMaxB, 124 output: New( 125 Of(Float64), 126 WithShape(3, 4), 127 WithBacking([]float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}), 128 ), 129 grad: New( 130 Of(Float64), 131 WithShape(3, 4), 132 WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 133 ), 134 axis: -1, 135 expectedOutput: []float64{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957}, 136 }, 137 { 138 fn: LogSoftMaxB, 139 output: New( 140 Of(Float64), 141 WithShape(3, 4), 142 WithBacking([]float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}), 143 ), 144 grad: New( 145 Of(Float64), 146 WithShape(3, 4), 147 WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 148 ), 149 axis: -1, 150 expectedOutput: []float64{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598}, 151 }, 152 { 153 fn: SoftMaxB, 154 output: New( 155 Of(Float64), 156 WithShape(3, 2, 2), 157 WithBacking([]float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}), 158 ), 159 grad: New( 160 Of(Float64), 161 WithShape(3, 2, 2), 162 WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 163 ), 164 axis: -1, 165 expectedOutput: []float64{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183}, 166 }, 167 { 168 fn: SoftMaxB, 169 output: New( 170 Of(Float64), 171 WithShape(3, 2, 2), 172 WithBacking([]float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}), 173 ), 174 grad: New( 175 Of(Float64), 176 WithShape(3, 2, 2), 177 WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 178 ), 179 axis: 1, 180 expectedOutput: []float64{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193}, 181 }, 182 { 183 fn: LogSoftMaxB, 184 output: New( 185 Of(Float64), 186 WithShape(3, 2, 2), 187 WithBacking([]float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}), 188 ), 189 grad: New( 190 Of(Float64), 191 WithShape(3, 2, 2), 192 WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 193 ), 194 axis: 1, 195 expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543}, 196 }, 197 { 198 fn: LogSoftMaxB, 199 output: New( 200 Of(Float32), 201 WithShape(3, 2, 2), 202 WithBacking([]float32{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}), 203 ), 204 grad: New( 205 Of(Float32), 206 WithShape(3, 2, 2), 207 WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 208 ), 209 axis: 1, 210 expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543}, 211 }, 212 { 213 fn: SoftMaxB, 214 output: New( 215 Of(Float32), 216 WithShape(3, 2, 2), 217 WithBacking([]float32{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}), 218 ), 219 grad: New( 220 Of(Float32), 221 WithShape(3, 2, 2), 222 WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 223 ), 224 axis: 1, 225 expectedOutput: []float32{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193}, 226 }, 227 { 228 fn: SoftMaxB, 229 output: New( 230 Of(Float32), 231 WithShape(3, 4), 232 WithBacking([]float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}), 233 ), 234 grad: New( 235 Of(Float64), 236 WithShape(3, 4), 237 WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 238 ), 239 axis: -1, 240 expectedOutput: []float32{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957}, 241 }, 242 { 243 fn: LogSoftMaxB, 244 output: New( 245 Of(Float64), 246 WithShape(3, 4), 247 WithBacking([]float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}), 248 ), 249 grad: New( 250 Of(Float64), 251 WithShape(3, 4), 252 WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 253 ), 254 axis: -1, 255 expectedOutput: []float32{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598}, 256 }, 257 { 258 fn: SoftMaxB, 259 output: New( 260 Of(Float64), 261 WithShape(3, 2, 2), 262 WithBacking([]float32{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}), 263 ), 264 grad: New( 265 Of(Float64), 266 WithShape(3, 2, 2), 267 WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}), 268 ), 269 axis: -1, 270 expectedOutput: []float32{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183}, 271 }, 272 } 273 for i, tC := range testCases { 274 t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.output.Shape(), tC.output.Dtype()), func(t *testing.T) { 275 c := assert.New(t) 276 277 dx, err := tC.fn(tC.output, tC.grad, tC.axis) 278 t.Logf("output: %#v", tC.output.Data()) 279 280 c.NoError(err) 281 c.NotNil(dx) 282 283 c.Equal(tC.output.Shape(), dx.Shape()) 284 c.InDeltaSlice(tC.expectedOutput, dx.Data(), 1e-6) 285 }) 286 } 287 }