github.com/wzzhu/tensor@v0.9.24/genlib2/dense_reduction_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"text/template"
     6  )
     7  
// testDenseReduceRaw is the text/template source for the generated dense
// reduction tests. It is parsed in init and executed by
// generateDenseReductionTests.
//
// The template iterates over .Kinds of its data and, for every kind accepted
// by the isNumber template function, emits three denseReductionTests entries
// (one each for axis 0, 1 and 2) that reduce with execution.Add<short kind>
// and a zero default value. It then emits TestDense_Reduce, which builds a
// tensor of shape (2,3,2) backed by Range(of, 0, 12), checks the reduced shape
// and data against each entry, and finally checks that three kinds of invalid
// calls return an error: an absurd axis, a reduction function of the wrong
// type, and a default value of the wrong type.
//
// The helpers isNumber, asType, title and short are presumably supplied by the
// funcs map attached in init — confirm against its definition.
//
// NOTE(review): the "wrong default value type" call uses axis 3, which looks
// out of range for the 3-dimensional tensor built in the test (the table only
// uses axes 0-2), so the error may come from the axis check rather than from
// the default value — confirm whether axis 2 was intended.
// NOTE(review): the "wrong function type" call passes drt.correct rather than
// drt.def as the default value — confirm this is deliberate.
const testDenseReduceRaw = `var denseReductionTests = []struct {
	of Dtype
	fn interface{}
	def interface{}
	axis int

	correct interface{}
	correctShape Shape
}{
	{{range .Kinds -}}
	{{if isNumber . -}}
	// {{.}}
	{ {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 0, []{{asType .}}{6, 8, 10, 12, 14, 16}, Shape{3,2} },
	{ {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 1, []{{asType .}}{6, 9, 24, 27}, Shape{2, 2}},
	{ {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 2, []{{asType .}}{1, 5, 9, 13, 17, 21}, Shape{2, 3}},
	{{end -}}
	{{end -}}
}

func TestDense_Reduce(t *testing.T){
	assert := assert.New(t)
	for _, drt := range denseReductionTests {
		T := New(WithShape(2,3,2), WithBacking(Range(drt.of, 0, 2*3*2)))
		T2, err := T.Reduce(drt.fn, drt.axis, drt.def, )
		if err != nil {
			t.Error(err)
			continue
		}
		assert.True(drt.correctShape.Eq(T2.Shape()))
		assert.Equal(drt.correct, T2.Data())

		// stupids:
		_, err = T.Reduce(drt.fn, 1000, drt.def,)
		assert.NotNil(err)

		// wrong function type
		var f interface{}
		f = func(a, b float64)float64{return 0}
		if drt.of == Float64 {
			f = func(a, b int)int{return 0}
		}

		_, err = T.Reduce(f, 0, drt.correct)
		assert.NotNil(err)

		// wrong default value type
		var def2 interface{}
		def2 = 3.14
		if drt.of == Float64 {
			def2 = int(1)
		}

		_, err = T.Reduce(drt.fn, 3, def2) // only last axis requires a default value
		assert.NotNil(err)
	}
}
`
    65  
    66  var (
    67  	testDenseReduce *template.Template
    68  )
    69  
    70  func init() {
    71  	testDenseReduce = template.Must(template.New("testDenseReduce").Funcs(funcs).Parse(testDenseReduceRaw))
    72  }
    73  
    74  func generateDenseReductionTests(f io.Writer, generic Kinds) {
    75  	testDenseReduce.Execute(f, generic)
    76  }