github.com/wzzhu/tensor@v0.9.24/genlib2/dense_reduction_methods_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
     9  const testDenseSumRaw = `var sumTests = []struct {
    10  	name string
    11  	of Dtype
    12  	shape Shape
    13  	along []int
    14  
    15  	correctShape Shape
    16  	correct interface{}
    17  }{
    18  	{{range .Kinds -}}
    19  	{{if isNumber . -}}
    20  	{"common case: T.Sum() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(15)},
    21  	{"A.Sum(0) for {{.}}", {{asType . | title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{3, 5, 7}},
    22  	{"A.Sum(1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType .}}{3, 12}},
    23  	{"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)},
    24  	{"A.Sum(1,0) for {{.}}", {{asType . | title}},  Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)},
    25  	{"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }},
    26  	{"4T.Sum() for {{.}}", {{asType . | title}},  Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(120)},
    27  	{"4T.Sum(1,3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{10, 18, 42, 50}},
    28  	{"4T.Sum(0, 2, 3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{44, 76}},
    29  	{{end -}}
    30  	{{end -}}
    31  }
    32  func TestDense_Sum(t *testing.T){
    33  	assert := assert.New(t)
    34  	var T, T2 *Dense
    35  	var err error
    36  
    37  	for _, sts := range sumTests {
    38  		T = New(WithShape(sts.shape...), WithBacking(Range(sts.of, 0, sts.shape.TotalSize())))
    39  		if T2, err = T.Sum(sts.along ...); err != nil {
    40  			t.Error(err)
    41  			continue
    42  		}
    43  		assert.True(sts.correctShape.Eq(T2.Shape()))
    44  		assert.Equal(sts.correct, T2.Data())
    45  	}
    46  
    47  	// idiots
    48  	_,err =T.Sum(1000)
    49  	assert.NotNil(err)
    50  }
    51  `
    52  
    53  const testDenseMaxRaw = `var maxTests = []struct {
    54  	name  string
    55  	of Dtype
    56  	shape Shape
    57  	along []int
    58  
    59  	correctShape Shape
    60  	correct  interface{}
    61  }{
    62  	{{range .Kinds -}}
    63  	{{if isNumber . -}}
    64  	{{if isOrd . -}}
    65  	{"common case: T.Max() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(5)},
    66  	{"A.Max(0)", {{asType . | title}}, Shape{2,3},[]int{0}, Shape{3}, []{{asType . }}{3, 4, 5}},
    67  	{"A.Max(1)", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType . }}{2,5}},
    68  	{"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)},
    69  	{"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)},
    70  	{"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} },
    71  	{"4T.Max()", {{asType . | title}},  Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(15)},
    72  	{"4T.Max(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{5, 7, 13, 15}},
    73  	{"4T.Max(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{11, 15}},
    74  	{{end -}}
    75  	{{end -}}
    76  	{{end -}}
    77  }
    78  
    79  func TestDense_Max(t *testing.T){
    80  	assert := assert.New(t)
    81  	var T, T2 *Dense
    82  	var err error
    83  
    84  	for _, mts := range maxTests {
    85  		T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize())))
    86  		if T2, err = T.Max(mts.along...); err != nil{
    87  			t.Error(err)
    88  			continue
    89  		}
    90  		assert.True(mts.correctShape.Eq(T2.Shape()))
    91  		assert.Equal(mts.correct, T2.Data())
    92  	}
    93  	/* IDIOT TESTING TIME */
    94  	_, err = T.Max(1000)
    95  	assert.NotNil(err)
    96  }
    97  `
    98  
    99  const testDenseMinRaw = `var minTests = []struct {
   100  	name  string
   101  	of Dtype
   102  	shape Shape
   103  	along []int
   104  
   105  	correctShape Shape
   106  	correct  interface{}
   107  }{
   108  	{{range .Kinds -}}
   109  	{{if isNumber . -}}
   110  	{{if isOrd . -}}
   111  	{"common case: T.Min() for {{.}}", {{asType .|title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(0)},
   112  	{"A.Min(0)", {{asType .|title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{0, 1, 2}},
   113  	{"A.Min(1)", {{asType .|title}}, Shape{2,3}, []int{1}, Shape{2}, []{{asType .}}{0, 3}},
   114  	{"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)},
   115  	{"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)},
   116  	{"3T.Min(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} },
   117  	{"4T.Min()", {{asType . | title}},  Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(0)},
   118  	{"4T.Min(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{0, 2, 8, 10}},
   119  	{"4T.Min(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{0, 4}},
   120  	{{end -}}
   121  	{{end -}}
   122  	{{end -}}
   123  }
   124  
   125  func TestDense_Min(t *testing.T){
   126  	assert := assert.New(t)
   127  	var T, T2 *Dense
   128  	var err error
   129  
   130  	for _, mts := range minTests {
   131  		T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize())))
   132  		if T2, err = T.Min(mts.along...); err != nil{
   133  			t.Error(err)
   134  			continue
   135  		}
   136  		assert.True(mts.correctShape.Eq(T2.Shape()))
   137  		assert.Equal(mts.correct, T2.Data())
   138  	}
   139  
   140  	/* IDIOT TESTING TIME */
   141  	_, err = T.Min(1000)
   142  	assert.NotNil(err)
   143  }
   144  `
   145  
   146  var (
   147  	testDenseSum *template.Template
   148  	testDenseMax *template.Template
   149  	testDenseMin *template.Template
   150  )
   151  
   152  func init() {
   153  	testDenseSum = template.Must(template.New("testDenseSum").Funcs(funcs).Parse(testDenseSumRaw))
   154  	testDenseMax = template.Must(template.New("testDenseMax").Funcs(funcs).Parse(testDenseMaxRaw))
   155  	testDenseMin = template.Must(template.New("testDenseMin").Funcs(funcs).Parse(testDenseMinRaw))
   156  }
   157  
   158  func generateDenseReductionMethodsTests(f io.Writer, generic Kinds) {
   159  	testDenseSum.Execute(f, generic)
   160  	fmt.Fprint(f, "\n")
   161  	testDenseMax.Execute(f, generic)
   162  	fmt.Fprint(f, "\n")
   163  	testDenseMin.Execute(f, generic)
   164  }