gorgonia.org/gorgonia@v0.9.17/op_by_indices_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"log"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  func TestByIndicesOpDo(t *testing.T) {
    12  	testCases := []struct {
    13  		desc           string
    14  		input          tensor.Tensor
    15  		indices        tensor.Tensor
    16  		axis           int
    17  		expectedOutput []float64
    18  		expectedShape  tensor.Shape
    19  	}{
    20  		{
    21  			desc: "Example 1",
    22  			input: tensor.New(
    23  				tensor.WithShape(4, 2),
    24  				tensor.WithBacking(tensor.Range(tensor.Float64, 0, 8)),
    25  			),
    26  			indices: tensor.New(
    27  				tensor.WithShape(4),
    28  				tensor.WithBacking([]int{0, 3, 2, 1}),
    29  			),
    30  			axis:           0,
    31  			expectedOutput: []float64{0, 1, 6, 7, 4, 5, 2, 3},
    32  			expectedShape:  tensor.Shape{4, 2},
    33  		},
    34  		{
    35  			// 0 1 2
    36  			// 3 4 5
    37  			desc: "Example 2",
    38  			input: tensor.New(
    39  				tensor.WithShape(2, 3),
    40  				tensor.WithBacking(tensor.Range(tensor.Float64, 0, 6)),
    41  			),
    42  			indices: tensor.New(
    43  				tensor.WithShape(4),
    44  				tensor.WithBacking([]int{0, 2, 1, 1}),
    45  			),
    46  			axis:           1,
    47  			expectedOutput: []float64{0, 2, 1, 1, 3, 5, 4, 4},
    48  			expectedShape:  tensor.Shape{2, 4},
    49  		},
    50  		{
    51  			desc: "Example 3",
    52  			input: tensor.New(
    53  				tensor.WithShape(2, 5),
    54  				tensor.WithBacking(tensor.Range(tensor.Float64, 0, 10)),
    55  			),
    56  			indices: tensor.New(
    57  				tensor.WithShape(2),
    58  				tensor.WithBacking([]int{1, 1}),
    59  			),
    60  			axis:           0,
    61  			expectedOutput: []float64{5, 6, 7, 8, 9, 5, 6, 7, 8, 9},
    62  			expectedShape:  tensor.Shape{2, 5},
    63  		},
    64  	}
    65  
    66  	for _, tcase := range testCases {
    67  		t.Run(tcase.desc, func(t *testing.T) {
    68  			c := require.New(t)
    69  
    70  			op := newByIndicesOp(tcase.axis)
    71  
    72  			inputV, _, _, err := anyToValue(tcase.input)
    73  			c.NoError(err)
    74  
    75  			indicesV, _, _, err := anyToValue(tcase.indices)
    76  			c.NoError(err)
    77  
    78  			output, err := op.Do(inputV, indicesV)
    79  			c.NoError(err)
    80  
    81  			c.Equal(tcase.expectedOutput, output.Data())
    82  			c.Equal(tcase.expectedShape, output.Shape())
    83  		})
    84  	}
    85  }
    86  
    87  func TestByIndicesOpFull(t *testing.T) {
    88  	testCases := []struct {
    89  		desc           string
    90  		input          tensor.Tensor
    91  		indices        tensor.Tensor
    92  		axis           int
    93  		expectedOutput []float64
    94  		expectedShape  tensor.Shape
    95  	}{
    96  		{
    97  			desc: "Example 1",
    98  			input: tensor.New(
    99  				tensor.WithShape(4, 2),
   100  				tensor.WithBacking(tensor.Range(tensor.Float64, 0, 8)),
   101  			),
   102  			indices: tensor.New(
   103  				tensor.WithShape(4),
   104  				tensor.WithBacking([]int{0, 3, 2, 1}),
   105  			),
   106  			axis:           0,
   107  			expectedOutput: []float64{0, 1, 6, 7, 4, 5, 2, 3},
   108  			expectedShape:  tensor.Shape{4, 2},
   109  		},
   110  	}
   111  
   112  	for _, tcase := range testCases {
   113  		t.Run(tcase.desc, func(t *testing.T) {
   114  			c := require.New(t)
   115  
   116  			g := NewGraph()
   117  
   118  			indices := NewTensor(g, tensor.Int, 1, WithName("indices"), WithShape(tcase.indices.Shape().TotalSize()), WithValue(tcase.indices))
   119  			input := NewTensor(g, tensor.Float64, tcase.input.Shape().Dims(), WithName("input"), WithShape(tcase.input.Shape()...), WithValue(tcase.input))
   120  
   121  			output, err := ByIndices(input, indices, tcase.axis)
   122  			c.NoError(err)
   123  
   124  			log.Printf("output shape: %v", output.Shape())
   125  			log.Printf("input shape: %v", input.Shape())
   126  
   127  			y := NewTensor(g, tensor.Float64, tcase.input.Shape().Dims(), WithName("target"), WithShape(tcase.input.Shape()...), WithValue(tcase.input))
   128  
   129  			cost := Must(Mean(Must((Sub(output, y))))) // MSE
   130  
   131  			_, err = Grad(cost, input)
   132  			c.NoError(err)
   133  
   134  			// logger := log.New(os.Stdout, "", 0)
   135  
   136  			vm := NewTapeMachine(
   137  				g,
   138  				//WithLogger(logger),
   139  				WithWatchlist(),
   140  				BindDualValues(output),
   141  				TraceExec(),
   142  			)
   143  
   144  			c.NoError(vm.RunAll())
   145  			c.NoError(vm.Close())
   146  
   147  			log.Printf("input %v", input.Value())
   148  			log.Printf("result: %v", output.Value())
   149  			log.Printf("cost: %v", cost.Value())
   150  
   151  			c.Equal(tcase.expectedOutput, output.Value().Data())
   152  			c.Equal(tcase.expectedShape, output.Shape())
   153  			c.Equal(0.0, cost.Value().Data().(float64))
   154  		})
   155  	}
   156  }