github.com/wzzhu/tensor@v0.9.24/genlib2/dense_reduction_tests.go (about) 1 package main 2 3 import ( 4 "io" 5 "text/template" 6 ) 7 8 const testDenseReduceRaw = `var denseReductionTests = []struct { 9 of Dtype 10 fn interface{} 11 def interface{} 12 axis int 13 14 correct interface{} 15 correctShape Shape 16 }{ 17 {{range .Kinds -}} 18 {{if isNumber . -}} 19 // {{.}} 20 { {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 0, []{{asType .}}{6, 8, 10, 12, 14, 16}, Shape{3,2} }, 21 { {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 1, []{{asType .}}{6, 9, 24, 27}, Shape{2, 2}}, 22 { {{asType . | title}}, execution.Add{{short .}}, {{asType .}}(0), 2, []{{asType .}}{1, 5, 9, 13, 17, 21}, Shape{2, 3}}, 23 {{end -}} 24 {{end -}} 25 } 26 27 func TestDense_Reduce(t *testing.T){ 28 assert := assert.New(t) 29 for _, drt := range denseReductionTests { 30 T := New(WithShape(2,3,2), WithBacking(Range(drt.of, 0, 2*3*2))) 31 T2, err := T.Reduce(drt.fn, drt.axis, drt.def, ) 32 if err != nil { 33 t.Error(err) 34 continue 35 } 36 assert.True(drt.correctShape.Eq(T2.Shape())) 37 assert.Equal(drt.correct, T2.Data()) 38 39 // stupids: 40 _, err = T.Reduce(drt.fn, 1000, drt.def,) 41 assert.NotNil(err) 42 43 // wrong function type 44 var f interface{} 45 f = func(a, b float64)float64{return 0} 46 if drt.of == Float64 { 47 f = func(a, b int)int{return 0} 48 } 49 50 _, err = T.Reduce(f, 0, drt.correct) 51 assert.NotNil(err) 52 53 // wrong default value type 54 var def2 interface{} 55 def2 = 3.14 56 if drt.of == Float64 { 57 def2 = int(1) 58 } 59 60 _, err = T.Reduce(drt.fn, 3, def2) // only last axis requires a default value 61 assert.NotNil(err) 62 } 63 } 64 ` 65 66 var ( 67 testDenseReduce *template.Template 68 ) 69 70 func init() { 71 testDenseReduce = template.Must(template.New("testDenseReduce").Funcs(funcs).Parse(testDenseReduceRaw)) 72 } 73 74 func generateDenseReductionTests(f io.Writer, generic Kinds) { 75 testDenseReduce.Execute(f, generic) 76 }