github.com/wzzhu/tensor@v0.9.24/dense_selbyidx_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  )
     8  
     9  type selByIndicesTest struct {
    10  	Name    string
    11  	Data    interface{}
    12  	Shape   Shape
    13  	Indices []int
    14  	Axis    int
    15  	WillErr bool
    16  
    17  	Correct      interface{}
    18  	CorrectShape Shape
    19  }
    20  
    21  var selByIndicesTests = []selByIndicesTest{
    22  	{Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false,
    23  		Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2},
    24  	},
    25  	{Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
    26  		Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}},
    27  
    28  	{Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false,
    29  		Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}},
    30  
    31  	{Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false,
    32  		Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}},
    33  
    34  	{Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
    35  		Correct: []int{1, 1}, CorrectShape: Shape{2}},
    36  
    37  	{Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true,
    38  		Correct: []int{1, 1}, CorrectShape: Shape{2}},
    39  	{Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false,
    40  		Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}},
    41  	{Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false,
    42  		Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10},
    43  	},
    44  }
    45  
    46  func TestDense_SelectByIndices(t *testing.T) {
    47  	assert := assert.New(t)
    48  	for i, tc := range selByIndicesTests {
    49  		T := New(WithShape(tc.Shape...), WithBacking(tc.Data))
    50  		indices := New(WithBacking(tc.Indices))
    51  		ret, err := ByIndices(T, indices, tc.Axis)
    52  		if checkErr(t, tc.WillErr, err, tc.Name, i) {
    53  			continue
    54  		}
    55  		assert.Equal(tc.Correct, ret.Data())
    56  		assert.True(tc.CorrectShape.Eq(ret.Shape()))
    57  	}
    58  }
    59  
    60  var selByIndicesBTests = []struct {
    61  	selByIndicesTest
    62  
    63  	CorrectGrad      interface{}
    64  	CorrectGradShape Shape
    65  }{
    66  	// Basic
    67  	{
    68  		CorrectGrad: []float64{1, 1, 1, 1},
    69  	},
    70  	// 3-tensor, axis 0
    71  	{
    72  		CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0},
    73  	},
    74  	// 3-tensor, axis 1
    75  	{
    76  		CorrectGrad: []float64{0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2},
    77  	},
    78  	// 3-tensor, axis 2
    79  	{
    80  		CorrectGrad: []float64{0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0},
    81  	},
    82  	// vector, axis 0
    83  	{
    84  		CorrectGrad: []int{0, 2, 0, 0, 0},
    85  	},
    86  	// vector, axis 1
    87  	{
    88  		CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0},
    89  	},
    90  	// (4,2) Matrix with (10) indices
    91  	{
    92  		CorrectGrad: []float32{2, 2, 4, 4, 4, 4, 0, 0},
    93  	},
    94  	// (2, 1) Matrix (colvec) with (10) indices
    95  	{
    96  		CorrectGrad: []float64{0, 10},
    97  	},
    98  }
    99  
   100  func init() {
   101  	for i := range selByIndicesBTests {
   102  		selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i]
   103  		selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape
   104  	}
   105  }
   106  
   107  func TestDense_SelectByIndicesB(t *testing.T) {
   108  
   109  	assert := assert.New(t)
   110  	for i, tc := range selByIndicesBTests {
   111  		T := New(WithShape(tc.Shape...), WithBacking(tc.Data))
   112  		indices := New(WithBacking(tc.Indices))
   113  		ret, err := ByIndices(T, indices, tc.Axis)
   114  		if checkErr(t, tc.WillErr, err, tc.Name, i) {
   115  			continue
   116  		}
   117  		outGrad := ret.Clone().(*Dense)
   118  		switch outGrad.Dtype() {
   119  		case Float64:
   120  			outGrad.Memset(1.0)
   121  		case Float32:
   122  			outGrad.Memset(float32(1.0))
   123  		}
   124  
   125  		grad, err := ByIndicesB(T, outGrad, indices, tc.Axis)
   126  		if checkErr(t, tc.WillErr, err, tc.Name, i) {
   127  			continue
   128  		}
   129  		assert.Equal(tc.CorrectGrad, grad.Data(), "%v - x:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, T, indices, ret, grad)
   130  		assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead.\n\nx:\n%v\nindices:\n%#v\ny:\n%#v\ngrad:\n%v", tc.Name, tc.CorrectGradShape, grad.Shape(), T, indices, ret, grad)
   131  	}
   132  
   133  }