gorgonia.org/tensor@v0.9.24/dense_softmax_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  )
     9  
    10  func TestSoftMax(t *testing.T) {
    11  	testCases := []struct {
    12  		fn             func(x Tensor, axis int, opts ...FuncOpt) (Tensor, error)
    13  		x              Tensor
    14  		axis           int
    15  		expectedOutput interface{}
    16  	}{
    17  		{
    18  			fn: LogSoftMax,
    19  			x: New(
    20  				Of(Float64),
    21  				WithShape(3, 4),
    22  				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    23  			),
    24  			axis:           -1,
    25  			expectedOutput: []float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628},
    26  		},
    27  		{
    28  			fn: LogSoftMax,
    29  			x: New(
    30  				Of(Float32),
    31  				WithShape(3, 4),
    32  				WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    33  			),
    34  			axis:           -1,
    35  			expectedOutput: []float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628},
    36  		},
    37  		{
    38  			fn: LogSoftMax,
    39  			x: New(
    40  				Of(Float32),
    41  				WithShape(3, 2, 2),
    42  				WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    43  			),
    44  			axis:           -1,
    45  			expectedOutput: []float32{-0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443966, -0.64439666, -0.7443966, -0.64439666, -0.7443967, -0.64439666},
    46  		},
    47  		{
    48  			fn: LogSoftMax,
    49  			x: New(
    50  				Of(Float64),
    51  				WithShape(3, 2, 2),
    52  				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    53  			),
    54  			axis:           1,
    55  			expectedOutput: []float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918},
    56  		},
    57  		{
    58  			fn: SoftMax,
    59  			x: New(
    60  				Of(Float64),
    61  				WithShape(3, 2, 2),
    62  				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    63  			),
    64  			axis:           1,
    65  			expectedOutput: []float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478},
    66  		},
    67  		{
    68  			fn: SoftMax,
    69  			x: New(
    70  				Of(Float64),
    71  				WithShape(3, 2, 2),
    72  				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    73  			),
    74  			axis:           -1,
    75  			expectedOutput: []float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894},
    76  		},
    77  		{
    78  			fn: SoftMax,
    79  			x: New(
    80  				Of(Float32),
    81  				WithShape(3, 4),
    82  				WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    83  			),
    84  			axis:           -1,
    85  			expectedOutput: []float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514},
    86  		},
    87  		{
    88  			fn: SoftMax,
    89  			x: New(
    90  				Of(Float64),
    91  				WithShape(3, 4),
    92  				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
    93  			),
    94  			axis:           -1,
    95  			expectedOutput: []float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514},
    96  		},
    97  	}
    98  	for i, tC := range testCases {
    99  		t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.x.Shape(), tC.x.Dtype()), func(t *testing.T) {
   100  			c := assert.New(t)
   101  
   102  			output, err := tC.fn(tC.x, tC.axis)
   103  			t.Logf("output: %#v", output.Data())
   104  
   105  			c.NoError(err)
   106  			c.NotNil(output)
   107  
   108  			c.Equal(tC.x.Shape(), output.Shape())
   109  			c.InDeltaSlice(tC.expectedOutput, output.Data(), 1e-6)
   110  		})
   111  	}
   112  }
   113  
   114  func TestSoftMaxB(t *testing.T) {
   115  	testCases := []struct {
   116  		fn             func(output, grad Tensor, axis int, opts ...FuncOpt) (Tensor, error)
   117  		output         Tensor
   118  		grad           Tensor
   119  		axis           int
   120  		expectedOutput interface{}
   121  	}{
   122  		{
   123  			fn: SoftMaxB,
   124  			output: New(
   125  				Of(Float64),
   126  				WithShape(3, 4),
   127  				WithBacking([]float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}),
   128  			),
   129  			grad: New(
   130  				Of(Float64),
   131  				WithShape(3, 4),
   132  				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   133  			),
   134  			axis:           -1,
   135  			expectedOutput: []float64{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957},
   136  		},
   137  		{
   138  			fn: LogSoftMaxB,
   139  			output: New(
   140  				Of(Float64),
   141  				WithShape(3, 4),
   142  				WithBacking([]float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}),
   143  			),
   144  			grad: New(
   145  				Of(Float64),
   146  				WithShape(3, 4),
   147  				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   148  			),
   149  			axis:           -1,
   150  			expectedOutput: []float64{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598},
   151  		},
   152  		{
   153  			fn: SoftMaxB,
   154  			output: New(
   155  				Of(Float64),
   156  				WithShape(3, 2, 2),
   157  				WithBacking([]float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}),
   158  			),
   159  			grad: New(
   160  				Of(Float64),
   161  				WithShape(3, 2, 2),
   162  				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   163  			),
   164  			axis:           -1,
   165  			expectedOutput: []float64{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183},
   166  		},
   167  		{
   168  			fn: SoftMaxB,
   169  			output: New(
   170  				Of(Float64),
   171  				WithShape(3, 2, 2),
   172  				WithBacking([]float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}),
   173  			),
   174  			grad: New(
   175  				Of(Float64),
   176  				WithShape(3, 2, 2),
   177  				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   178  			),
   179  			axis:           1,
   180  			expectedOutput: []float64{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193},
   181  		},
   182  		{
   183  			fn: LogSoftMaxB,
   184  			output: New(
   185  				Of(Float64),
   186  				WithShape(3, 2, 2),
   187  				WithBacking([]float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}),
   188  			),
   189  			grad: New(
   190  				Of(Float64),
   191  				WithShape(3, 2, 2),
   192  				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   193  			),
   194  			axis:           1,
   195  			expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543},
   196  		},
   197  		{
   198  			fn: LogSoftMaxB,
   199  			output: New(
   200  				Of(Float32),
   201  				WithShape(3, 2, 2),
   202  				WithBacking([]float32{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}),
   203  			),
   204  			grad: New(
   205  				Of(Float32),
   206  				WithShape(3, 2, 2),
   207  				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   208  			),
   209  			axis:           1,
   210  			expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543},
   211  		},
   212  		{
   213  			fn: SoftMaxB,
   214  			output: New(
   215  				Of(Float32),
   216  				WithShape(3, 2, 2),
   217  				WithBacking([]float32{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}),
   218  			),
   219  			grad: New(
   220  				Of(Float32),
   221  				WithShape(3, 2, 2),
   222  				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   223  			),
   224  			axis:           1,
   225  			expectedOutput: []float32{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193},
   226  		},
   227  		{
   228  			fn: SoftMaxB,
   229  			output: New(
   230  				Of(Float32),
   231  				WithShape(3, 4),
   232  				WithBacking([]float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}),
   233  			),
   234  			grad: New(
   235  				Of(Float64),
   236  				WithShape(3, 4),
   237  				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   238  			),
   239  			axis:           -1,
   240  			expectedOutput: []float32{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957},
   241  		},
   242  		{
   243  			fn: LogSoftMaxB,
   244  			output: New(
   245  				Of(Float64),
   246  				WithShape(3, 4),
   247  				WithBacking([]float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}),
   248  			),
   249  			grad: New(
   250  				Of(Float64),
   251  				WithShape(3, 4),
   252  				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   253  			),
   254  			axis:           -1,
   255  			expectedOutput: []float32{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598},
   256  		},
   257  		{
   258  			fn: SoftMaxB,
   259  			output: New(
   260  				Of(Float64),
   261  				WithShape(3, 2, 2),
   262  				WithBacking([]float32{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}),
   263  			),
   264  			grad: New(
   265  				Of(Float64),
   266  				WithShape(3, 2, 2),
   267  				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
   268  			),
   269  			axis:           -1,
   270  			expectedOutput: []float32{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183},
   271  		},
   272  	}
   273  	for i, tC := range testCases {
   274  		t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.output.Shape(), tC.output.Dtype()), func(t *testing.T) {
   275  			c := assert.New(t)
   276  
   277  			dx, err := tC.fn(tC.output, tC.grad, tC.axis)
   278  			t.Logf("output: %#v", tC.output.Data())
   279  
   280  			c.NoError(err)
   281  			c.NotNil(dx)
   282  
   283  			c.Equal(tC.output.Shape(), dx.Shape())
   284  			c.InDeltaSlice(tC.expectedOutput, dx.Data(), 1e-6)
   285  		})
   286  	}
   287  }