gorgonia.org/gorgonia@v0.9.17/op_sparsemax_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/require"
     7  	"gorgonia.org/tensor"
     8  )
     9  
    10  var testCasesSparseMaxDo = []struct {
    11  	size     tensor.Shape
    12  	input    interface{}
    13  	expected interface{}
    14  	axis     int
    15  }{
    16  	{
    17  		tensor.Shape{4}, []float64{0.3, 0.1, 1.2, 2.3}, []float64{0, 0, 0, 1.0}, -1,
    18  	},
    19  	{
    20  		tensor.Shape{10}, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, -1,
    21  	},
    22  	{
    23  		tensor.Shape{3}, []float64{0.1, 0.1, 0.1}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333}, -1,
    24  	},
    25  	{
    26  		tensor.Shape{4}, []float64{-0.1, 0.3, -1.1, 2.7}, []float64{0, 0, 0, 1.0}, -1,
    27  	},
    28  	{
    29  		tensor.Shape{4}, []float32{0.3, 0.1, 1.2, 2.3}, []float32{0, 0, 0, 1.0}, -1,
    30  	},
    31  	{
    32  		tensor.Shape{10}, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, -1,
    33  	},
    34  	{
    35  		tensor.Shape{3}, []float32{0.1, 0.1, 0.1}, []float32{0.33333334, 0.33333334, 0.33333334}, -1,
    36  	},
    37  	{
    38  		tensor.Shape{4}, []float32{-0.1, 0.3, -1.1, 2.7}, []float32{0, 0, 0, 1.0}, -1,
    39  	},
    40  	{
    41  		tensor.Shape{4}, []float64{0.9, 0.9, 0.9, 0.5}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.0000}, -1,
    42  	},
    43  	{
    44  		tensor.Shape{6, 2},
    45  		[]float64{-1.0000, -1.0000, 1.0000, 1.0000, -0.9998, -0.9998, 0.9998, 0.9998, 0.9945, 0.9945, -0.9945, -0.9945},
    46  		[]float64{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5},
    47  		-1,
    48  	},
    49  	// {
    50  	// 	tensor.Shape{6, 2},
    51  	// 	[]float64{-1.0, -1.0, 1.0, 1.0, -0.9998, -0.9998, 0.9998, 0.9998, 0.9945, 0.9945, -0.9945, -0.9945},
    52  	// 	[]float64{0.0000, 0.0000, 0.3352, 0.3352, 0.0000, 0.0000, 0.3350, 0.3350, 0.3297, 0.3297, 0.0000, 0.0000},
    53  	// 	0, // TODO
    54  	// },
    55  	{
    56  		tensor.Shape{6, 2},
    57  		[]float32{-1.0000, -1.0000, 1.0000, 1.0000, -0.9998, -0.9998, 0.9998, 0.9998, 0.9945, 0.9945, -0.9945, -0.9945},
    58  		[]float32{0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000},
    59  		-1,
    60  	},
    61  }
    62  
    63  var testCasesSparseMaxDoDiff = []struct {
    64  	shape tensor.Shape
    65  	input interface{}
    66  	grad  interface{}
    67  
    68  	expected      interface{}
    69  	expectedShape tensor.Shape
    70  }{
    71  	{
    72  		tensor.Shape{5},
    73  		[]float64{1.9968e-05, 1.9968e-05, 5.2120e-02, 2.3542e-01, 7.1242e-01},
    74  		[]float64{0.2860, -0.0702, 0.8080, 0.9913, 1.4683},
    75  		[]float64{-0.41068, -0.76688, 0.11132000000000009, 0.29462, 0.77162},
    76  		tensor.Shape{5},
    77  	},
    78  	{
    79  		tensor.Shape{5},
    80  		[]float64{5.5620e-02, 2.0027e-05, 7.1182e-01, 2.3252e-01, 2.0027e-05},
    81  		[]float64{0.1109, -1.4741, 0.7671, 0.2878, 0.0334},
    82  		[]float64{0.16588, -1.41912, 0.82208, 0.34278, 0.08837999999999999},
    83  		tensor.Shape{5},
    84  	},
    85  	{
    86  		tensor.Shape{5},
    87  		[]float64{0.0369, 0.3210, 0.0000, 0.3210, 0.3210},
    88  		[]float64{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
    89  		[]float64{0.630575, -0.5788249999999999, 0, -0.08202499999999996, 0.030274999999999996},
    90  		tensor.Shape{5},
    91  	},
    92  	{
    93  		tensor.Shape{5},
    94  		[]float64{0.2592, 0.0000, 0.6909, 0.0498, 0.0000},
    95  		[]float64{0.2094, -1.0000, 0.6411, 0.0000, -0.3909},
    96  		[]float64{-0.07410000000000003, 0, 0.3576, -0.28350000000000003, 0},
    97  		tensor.Shape{5},
    98  	},
    99  	{
   100  		tensor.Shape{5},
   101  		[]float32{0.0000, 0.0000, 0.0521, 0.2354, 0.7124},
   102  		[]float32{0.2860, -0.0702, 0.8080, 0.9913, 1.4683},
   103  		[]float32{-0, -0, -0.2812, -0.09790003, 0.37909997},
   104  		tensor.Shape{5},
   105  	},
   106  	{
   107  		tensor.Shape{5},
   108  		[]float32{0.0556, 0.0000, 0.7118, 0.2325, 0.0000},
   109  		[]float32{0.1109, -1.4741, 0.7671, 0.2878, 0.0334},
   110  		[]float32{-0.2777, -0, 0.37849998, -0.10079998, -0},
   111  		tensor.Shape{5},
   112  	},
   113  	{
   114  		tensor.Shape{5},
   115  		[]float32{0.2841, 0.0000, 0.7159, 0.0000, 0.0000},
   116  		[]float32{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
   117  		[]float32{-0.21585, -0, 0.21585, -0, -0},
   118  		tensor.Shape{5},
   119  	},
   120  	{
   121  		tensor.Shape{5},
   122  		[]float32{0.2592, 0.0000, 0.6909, 0.0498, 0.0000},
   123  		[]float32{0.2094, -1.0000, 0.6411, 0.0000, -0.3909},
   124  		[]float32{-0.07409999, -0, 0.3576, -0.2835, -0},
   125  		tensor.Shape{5},
   126  	},
   127  	{
   128  		tensor.Shape{5, 1},
   129  		[]float32{1, 1, 1, 1, 1},
   130  		[]float32{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
   131  		[]float32{1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995, 1.253, 0.043599963, 1.6847, 0.54039997, 0.65269995},
   132  		tensor.Shape{5, 5},
   133  	},
   134  }
   135  
   136  func TestSparsemaxDo(t *testing.T) {
   137  	c := require.New(t)
   138  
   139  	for i, testCase := range testCasesSparseMaxDo {
   140  		dtype := tensor.Float64
   141  
   142  		switch testCase.input.(type) {
   143  		case []float32:
   144  			dtype = tensor.Float32
   145  		}
   146  
   147  		tt := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size...), tensor.WithBacking(testCase.input))
   148  		op := newSparsemaxOp(testCase.axis)
   149  
   150  		out, err := op.Do(tt)
   151  		c.NoError(err, "failed test case: %d", i)
   152  		c.Equal(testCase.expected, out.Data(), "failed test case: %d", i)
   153  	}
   154  }
   155  
   156  func TestSparsemaxDoDiff(t *testing.T) {
   157  	c := require.New(t)
   158  
   159  	for i, testCase := range testCasesSparseMaxDoDiff {
   160  		g := NewGraph()
   161  		a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1))
   162  		b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1))
   163  
   164  		op := newSparsemaxOp()
   165  		r, err := ApplyOp(op, a)
   166  		c.NoError(err)
   167  
   168  		var backing interface{}
   169  
   170  		switch testCase.input.(type) {
   171  		case []float64:
   172  			backing = make([]float64, testCase.expectedShape.TotalSize())
   173  		case []float32:
   174  			backing = make([]float32, testCase.expectedShape.TotalSize())
   175  		}
   176  
   177  		aT := tensor.New(tensor.WithShape(testCase.shape...), tensor.WithBacking(testCase.input))
   178  		bT := tensor.New(tensor.WithShape(testCase.shape.TotalSize()), tensor.WithBacking(testCase.grad))
   179  		rT := tensor.New(tensor.WithShape(testCase.expectedShape...), tensor.WithBacking(backing))
   180  
   181  		aVal, _, _, _ := anyToValue(aT)
   182  		bVal, _, _, _ := anyToValue(bT)
   183  		rVal, _, _, _ := anyToValue(rT)
   184  
   185  		a.bind(dvUnit(aVal))
   186  		b.bind(dvUnit(bVal))
   187  		r.bind(dvUnitVar(rVal))
   188  
   189  		err = op.DoDiff(ExecutionContext{}, Nodes{a, b}, r)
   190  		c.NoError(err, "failed test case: %d", i)
   191  
   192  		c.Equal(testCase.expected, r.boundTo.Data())
   193  	}
   194  }
   195  
   196  func TestSparsemaxDoSymDiff(t *testing.T) {
   197  	c := require.New(t)
   198  
   199  	for i, testCase := range testCasesSparseMaxDoDiff {
   200  		g := NewGraph()
   201  		a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1))
   202  		b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1))
   203  
   204  		aT := tensor.New(tensor.WithShape(testCase.shape...), tensor.WithBacking(testCase.input))
   205  		bT := tensor.New(tensor.WithShape(testCase.shape.TotalSize()), tensor.WithBacking(testCase.grad))
   206  
   207  		aVal, _, _, _ := anyToValue(aT)
   208  		bVal, _, _, _ := anyToValue(bT)
   209  
   210  		a.bind(dvUnit(aVal))
   211  		b.bind(dvUnit(bVal))
   212  
   213  		op := newSparsemaxOp()
   214  		diff, err := op.SymDiff(Nodes{a}, nil, b)
   215  		c.NoError(err, "failed test case: %d", i)
   216  
   217  		c.Len(diff, 1)
   218  
   219  		vm := NewTapeMachine(g)
   220  
   221  		c.NoError(vm.RunAll())
   222  		c.NoError(vm.Close())
   223  
   224  		c.Equal(testCase.expected, diff[0].boundTo.Data(), "failed test case: %d", i)
   225  	}
   226  }
   227  
   228  func TestSparsemaxFull(t *testing.T) {
   229  	c := require.New(t)
   230  
   231  	for i, testCase := range testCasesSparseMaxDo {
   232  		dtype := tensor.Float64
   233  
   234  		if _, ok := testCase.input.([]float32); ok {
   235  			dtype = tensor.Float32
   236  		}
   237  
   238  		tt := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size...), tensor.WithBacking(testCase.input))
   239  		expected := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size...), tensor.WithBacking(testCase.expected))
   240  
   241  		g := NewGraph()
   242  		inp := NewTensor(g, dtype, testCase.size.Dims(), WithShape(testCase.size...), WithName("inp"))
   243  		out := Must(Sparsemax(inp, testCase.axis))
   244  
   245  		vm := NewTapeMachine(g)
   246  		err := Let(inp, tt)
   247  		c.NoError(err, "failed assigning input on case %d", i)
   248  
   249  		c.NoError(vm.RunAll())
   250  		c.NoError(vm.Close())
   251  
   252  		c.Equal(expected.Data(), out.Value().(*tensor.Dense).Data(), "output is not equal to expected value for case %d", i)
   253  	}
   254  }