github.com/wzzhu/tensor@v0.9.24/genlib2/dense_argmethods_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"reflect"
     7  	"text/template"
     8  )
     9  
    10  type ArgMethodTestData struct {
    11  	Kind reflect.Kind
    12  	Data []int
    13  }
    14  
// data is the shared backing data for every basicDense<Kind> fixture, which
// argMethodsDataRaw shapes as (2, 3, 4, 5, 2). It holds 2*3*4*5*2 = 240
// values, all in 0..9.
//
// The expected results in argmaxCorrect/argminCorrect and the all-axes
// answers in generateArgmethodsTests (7 and 11) are derived from these exact
// values. data[7] is the first 9 and data[11] is the first 0. Editing any
// value invalidates those expectations.
var data = []int{
	3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5,
	1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8,
	0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5,
	0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1,
	9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8,
	0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3,
	7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7,
	2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9,
	5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1,
	7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5,
	6, 2, 9, 4, 4, 2, 4, 4, 4, 3,
}
    28  
// argMethodsDataRaw renders one package-level fixture per element kind:
// basicDense<Kind>, a Dense of shape (2, 3, 4, 5, 2) backed by .Data
// converted to the Go type for .Kind. The input is an ArgMethodTestData.
const argMethodsDataRaw = `var basicDense{{short .Kind}} = New(WithShape(2,3,4,5,2), WithBacking([]{{asType .Kind}}{ {{range .Data -}}{{.}}, {{end -}} }))
`
    31  
// argmaxCorrect is emitted verbatim into the generated test file. Entry i is
// the expected result of Argmax(i) on a basicDense<Kind> fixture of shape
// (2, 3, 4, 5, 2). Its shape is the input shape with axis i removed, and its
// data holds the flattened indices along that axis. Entries must stay in
// axis order because the generated tests index this slice by axis.
const argmaxCorrect = `var argmaxCorrect = []struct {
	shape Shape
	data  []int
}{
	{Shape{3,4,5,2}, []int{
		1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1,
		1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0,
		1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
		1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1,
		0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0,
		1, 0, 0, 0, 0,
	}},
	{Shape{2,4,5,2}, []int{
		1, 0, 1, 1, 2, 0, 2, 0, 0, 1, 2, 1, 2, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1,
		2, 2, 0, 1, 1, 2, 2, 1, 0, 2, 0, 2, 0, 2, 2, 1, 0, 0, 0, 0, 0, 1, 0,
		0, 0, 2, 1, 0, 1, 2, 1, 0, 1, 1, 2, 0, 1, 0, 0, 0, 0, 2, 1, 0, 1, 0,
		0, 2, 1, 1, 0, 0, 0, 0, 0, 2, 0,
	}},
	{Shape{2,3,5,2}, []int{
		3, 2, 2, 1, 1, 2, 1, 0, 0, 1, 3, 2, 1, 0, 1, 0, 2, 2, 3, 0, 1, 0, 1,
		3, 0, 2, 3, 3, 2, 1, 2, 2, 0, 0, 1, 3, 2, 0, 1, 2, 0, 3, 0, 1, 0, 1,
		3, 2, 2, 1, 2, 1, 3, 1, 2, 0, 2, 2, 0, 0,
	}},
	{Shape{2,3,4,2}, []int{
		4, 3, 2, 1, 1, 2, 0, 1, 1, 1, 1, 3, 1, 0, 0, 2, 2, 1, 0, 4, 2, 2, 3,
		1, 1, 1, 0, 2, 0, 0, 2, 2, 1, 4, 0, 1, 4, 1, 1, 0, 4, 3, 1, 1, 2, 3,
		1, 1,
	}},
	{Shape{2,3,4,5}, []int{
		1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1,
		1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0,
		0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1,
		0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1,
		1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
		0, 0, 0, 0, 0,
	}},
}
`
    70  
// argminCorrect is emitted verbatim into the generated test file. Entry i is
// the expected result of Argmin(i) on a basicDense<Kind> fixture of shape
// (2, 3, 4, 5, 2). Its shape is the input shape with axis i removed, and its
// data holds the flattened indices along that axis. Entries must stay in
// axis order because the generated tests index this slice by axis.
const argminCorrect = `var argminCorrect = []struct {
	shape Shape
	data []int
}{
	{Shape{3,4,5,2}, []int{
		0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,
		0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1,
		0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0,
		0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0,
		1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1,
		0, 1, 1, 0, 1,
	}},
	{Shape{2,4,5,2}, []int{
		2, 1, 0, 0, 1, 2, 1, 2, 1, 2, 1, 0, 0, 2, 1, 0, 1, 2, 0, 1, 0, 2, 2,
		0, 0, 1, 2, 0, 0, 1, 2, 1, 0, 1, 0, 2, 0, 1, 0, 1, 2, 1, 2, 1, 2, 1,
		2, 1, 1, 0, 2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 2,
		2, 0, 0, 0, 1, 2, 2, 2, 2, 1, 1,
	}},
	{Shape{2,3,5,2}, []int{
		0, 1, 0, 2, 2, 1, 3, 2, 3, 2, 1, 0, 3, 3, 0, 1, 0, 3, 0, 2, 0, 1, 0,
		1, 3, 0, 2, 1, 0, 0, 3, 1, 3, 1, 2, 2, 1, 2, 0, 1, 3, 0, 1, 0, 1, 0,
		2, 1, 0, 3, 0, 2, 0, 0, 0, 1, 0, 1, 1, 1,
	}},
	{Shape{2,3,4,2}, []int{
		1, 0, 0, 0, 2, 3, 4, 0, 3, 0, 3, 0, 4, 4, 3, 1, 0, 2, 3, 0, 3, 0, 0,
		2, 4, 4, 3, 4, 2, 3, 0, 0, 4, 0, 1, 3, 3, 2, 0, 4, 2, 1, 4, 2, 4, 0,
		2, 0,
	}},
	{Shape{2,3,4,5}, []int{
		0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0,
		0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1,
		1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0,
		1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0,
		0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0,
		1, 1, 1, 0, 1,
	}},
}
`
   109  
   110  type ArgMethodTest struct {
   111  	Kind       reflect.Kind
   112  	ArgMethod  string
   113  	ArgAllAxes int
   114  }
   115  
   116  const testArgMethodsRaw = `func TestDense_{{title .ArgMethod}}_{{short .Kind}}(t *testing.T){
   117  	assert := assert.New(t)
   118  	var T, {{.ArgMethod}} *Dense
   119  	var err error
   120  	T = basicDense{{short .Kind}}.Clone().(*Dense)
   121  	for i:= 0; i < T.Dims(); i++ {
   122  		if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(i); err != nil {
   123  			t.Error(err)
   124  			continue
   125  		}
   126  
   127  		assert.True({{.ArgMethod}}Correct[i].shape.Eq({{.ArgMethod}}.Shape()), "{{title .ArgMethod}}(%d) error. Want shape %v. Got %v", i, {{.ArgMethod}}Correct[i].shape)
   128  		assert.Equal({{.ArgMethod}}Correct[i].data, {{.ArgMethod}}.Data(), "{{title .ArgMethod}}(%d) error. ", i)
   129  	}
   130  	// test all axes
   131  	if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil {
   132  		t.Error(err)
   133  		return
   134  	}
   135  	assert.True({{.ArgMethod}}.IsScalar())
   136  	assert.Equal({{.ArgAllAxes}}, {{.ArgMethod}}.ScalarValue())
   137  
   138  	{{if hasPrefix .Kind.String "float" -}}
   139  	// test with NaN
   140  	T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,2,{{mathPkg .Kind}}NaN(), 4}))
   141  	if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil {
   142  		t.Errorf("Failed test with NaN: %v", err)
   143  	}
   144  	assert.True({{.ArgMethod}}.IsScalar())
   145  	assert.Equal(2, {{.ArgMethod}}.ScalarValue(), "NaN test")
   146  
   147  	// test with Mask and Nan
   148  	T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,{{if eq .ArgMethod "argmax"}}9{{else}}-9{{end}},{{mathPkg .Kind}}NaN(), 4}, []bool{false,true,true,false}))
   149  	if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil {
   150  		t.Errorf("Failed test with NaN: %v", err)
   151  	}		
   152  	assert.True({{.ArgMethod}}.IsScalar())
   153  	assert.Equal({{if eq .ArgMethod "argmin"}}0{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "Masked NaN test")
   154  
   155  	// test with +Inf
   156  	T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,2,{{mathPkg .Kind}}Inf(1),4}))
   157  	if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil {
   158  		t.Errorf("Failed test with +Inf: %v", err)
   159  	}
   160  	assert.True({{.ArgMethod}}.IsScalar())
   161  	assert.Equal({{if eq .ArgMethod "argmax"}}2{{else}}0{{end}}, {{.ArgMethod}}.ScalarValue(), "+Inf test")
   162  
   163     // test with Mask and +Inf
   164  	T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,{{if eq .ArgMethod "argmax"}}9{{else}}-9{{end}},{{mathPkg .Kind}}Inf(1), 4}, []bool{false,true,true,false}))
   165  	if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil {
   166  		t.Errorf("Failed test with NaN: %v", err)
   167  	}		
   168  	assert.True({{.ArgMethod}}.IsScalar())
   169  	assert.Equal({{if eq .ArgMethod "argmin"}}0{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "Masked NaN test")
   170      
   171  	// test with -Inf
   172  	T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,2,{{mathPkg .Kind}}Inf(-1),4 }))
   173  	if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil {
   174  		t.Errorf("Failed test with -Inf: %v", err)
   175  	}
   176  	assert.True({{.ArgMethod}}.IsScalar())
   177  	assert.Equal({{if eq .ArgMethod "argmin"}}2{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "+Inf test")
   178  
   179  	// test with Mask and -Inf
   180  	T = New(WithShape(4), WithBacking([]{{asType .Kind}}{1,{{if eq .ArgMethod "argmax"}}9{{else}}-9{{end}},{{mathPkg .Kind}}Inf(-1), 4}, []bool{false,true,true,false}))
   181  	if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(AllAxes); err != nil {
   182  		t.Errorf("Failed test with NaN: %v", err)
   183  	}		
   184  	assert.True({{.ArgMethod}}.IsScalar())
   185  	assert.Equal({{if eq .ArgMethod "argmin"}}0{{else}}3{{end}}, {{.ArgMethod}}.ScalarValue(), "Masked -Inf test")
   186  
   187  	{{end -}}
   188  
   189  	// with different engine
   190  	T = basicDense{{short .Kind}}.Clone().(*Dense)
   191  	WithEngine(dummyEngine2{})(T)
   192  	for i:= 0; i < T.Dims(); i++ {
   193  		if {{.ArgMethod}}, err = T.{{title .ArgMethod}}(i); err != nil {
   194  			t.Error(err)
   195  			continue
   196  		}
   197  
   198  		assert.True({{.ArgMethod}}Correct[i].shape.Eq({{.ArgMethod}}.Shape()), "{{title .ArgMethod}}(%d) error. Want shape %v. Got %v", i, {{.ArgMethod}}Correct[i].shape)
   199  		assert.Equal({{.ArgMethod}}Correct[i].data, {{.ArgMethod}}.Data(), "{{title .ArgMethod}}(%d) error. ", i)
   200  	}
   201  
   202  
   203  
   204  	// idiotsville
   205  	_, err = T.{{title .ArgMethod}}(10000)
   206  	assert.NotNil(err)
   207  
   208  }
   209  `
   210  
   211  var (
   212  	argMethodsData *template.Template
   213  	testArgMethods *template.Template
   214  )
   215  
   216  func init() {
   217  	argMethodsData = template.Must(template.New("argmethodsData").Funcs(funcs).Parse(argMethodsDataRaw))
   218  	testArgMethods = template.Must(template.New("testArgMethod").Funcs(funcs).Parse(testArgMethodsRaw))
   219  }
   220  
   221  func generateArgmethodsTests(f io.Writer, generic Kinds) {
   222  	fmt.Fprintf(f, "/* Test data */\n\n")
   223  	for _, k := range generic.Kinds {
   224  		if isNumber(k) && isOrd(k) {
   225  			op := ArgMethodTestData{k, data}
   226  			argMethodsData.Execute(f, op)
   227  		}
   228  	}
   229  	fmt.Fprintf(f, "\n%s\n%s\n", argmaxCorrect, argminCorrect)
   230  	for _, k := range generic.Kinds {
   231  		if isNumber(k) && isOrd(k) {
   232  			op := ArgMethodTest{k, "argmax", 7}
   233  			testArgMethods.Execute(f, op)
   234  			op = ArgMethodTest{k, "argmin", 11}
   235  			testArgMethods.Execute(f, op)
   236  		}
   237  	}
   238  }