github.com/wzzhu/tensor@v0.9.24/genlib2/dense_getset_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"reflect"
     7  	"text/template"
     8  )
     9  
// testData is one row fed to the GetTest, MemsetTest and ZeroTest templates.
//
// Kind is the element kind the row is generated for. TestData0 holds the six
// initial element values; string values are stored pre-quoted (e.g. "\"zero\"")
// because the templates print them verbatim into Go source. Set is the value
// written by Set/Memset (makeZeroTests leaves it nil). Correct holds the six
// element values expected after the operation under test.
type testData struct {
	Kind      reflect.Kind
	TestData0 []interface{}
	Set       interface{}
	Correct   []interface{}
}
    16  
    17  func makeTests(generic Kinds) []testData {
    18  	retVal := make([]testData, 0)
    19  	for _, k := range generic.Kinds {
    20  		if isParameterized(k) {
    21  			continue
    22  		}
    23  
    24  		td := testData{Kind: k}
    25  
    26  		data := make([]interface{}, 6)
    27  		correct := make([]interface{}, 6)
    28  
    29  		switch {
    30  		case isRangeable(k):
    31  			raw := []int{0, 1, 2, 3, 4, 5}
    32  			for i := range data {
    33  				data[i] = raw[i]
    34  				correct[i] = 45
    35  			}
    36  			td.Set = 45
    37  		case k == reflect.Bool:
    38  			raw := []bool{true, false, true, false, true, false}
    39  			for i := range data {
    40  				data[i] = raw[i]
    41  				correct[i] = false
    42  			}
    43  			td.Set = false
    44  		case k == reflect.String:
    45  			raw := []string{"\"zero\"", "\"one\"", "\"two\"", "\"three\"", "\"four\"", "\"five\""}
    46  			for i := range data {
    47  				data[i] = raw[i]
    48  				correct[i] = "\"HELLO WORLD\""
    49  			}
    50  			td.Set = "\"HELLO WORLD\""
    51  		default:
    52  			continue
    53  		}
    54  		td.TestData0 = data
    55  		td.Correct = correct
    56  		retVal = append(retVal, td)
    57  
    58  	}
    59  	return retVal
    60  }
    61  
    62  func makeZeroTests(generic Kinds) []testData {
    63  	retVal := make([]testData, 0)
    64  	for _, k := range generic.Kinds {
    65  		if isParameterized(k) {
    66  			continue
    67  		}
    68  
    69  		td := testData{Kind: k}
    70  
    71  		data := make([]interface{}, 6)
    72  		correct := make([]interface{}, 6)
    73  
    74  		switch {
    75  		case isRangeable(k):
    76  			raw := []int{0, 1, 2, 3, 4, 5}
    77  			for i := range data {
    78  				data[i] = raw[i]
    79  				correct[i] = 0
    80  			}
    81  		case k == reflect.Bool:
    82  			raw := []bool{true, false, true, false, true, false}
    83  			for i := range data {
    84  				data[i] = raw[i]
    85  				correct[i] = false
    86  			}
    87  		case k == reflect.String:
    88  			raw := []string{"\"zero\"", "\"one\"", "\"two\"", "\"three\"", "\"four\"", "\"five\""}
    89  			for i := range data {
    90  				data[i] = raw[i]
    91  				correct[i] = "\"\""
    92  			}
    93  		default:
    94  			continue
    95  		}
    96  		td.TestData0 = data
    97  		td.Correct = correct
    98  		retVal = append(retVal, td)
    99  
   100  	}
   101  	return retVal
   102  }
   103  
// getTestRaw is the GetTest template. generateDenseGetSetTests executes it with
// the []testData from makeTests. It emits the denseSetGetTests table plus
// TestDense_setget.
//
// For each row, the table's correct slice is built from TestData0, with every
// value wrapped in a conversion to the kind's name ({{$k}}(...)). The generated
// test writes each correct value with Set and reads it back with Get; the
// emitted data and set fields are not used by it.
// The pipeline functions title, strip and clean come from funcs, which is
// defined elsewhere in this package.
const getTestRaw = `var denseSetGetTests = []struct {
	of Dtype
	data interface{} 
	set interface{}

	correct []interface{}
}{
	{{range . -}}
	{{$k := .Kind -}}
	{ {{title .Kind.String | strip}}, []{{.Kind.String | clean}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{printf "%v" .Set}}, []interface{}{ {{range .TestData0 -}} {{$k}}({{printf "%v" .}}), {{end -}} }},
	{{end -}}
}

func TestDense_setget(t *testing.T) {
	assert := assert.New(t)
	for _, gts := range denseSetGetTests {
		T := New(Of(gts.of), WithShape(len(gts.correct)))
		for i, v := range gts.correct {
			T.Set(i, v)
			got := T.Get(i)
			assert.Equal(v, got)
		}
	}
}

`
   130  
// memsetTestRaw is the MemsetTest template. generateDenseGetSetTests executes it
// with the []testData from makeTests. It emits the denseMemsetTests table (each
// row with shape 2x3, memset value {{asType .Kind}}(.Set) and expected slice
// .Correct) plus TestDense_memset.
//
// TestDense_memset checks Memset on a freshly allocated tensor, then on a
// Slice(nil) view of a tensor built WithBacking(mts.data).
// NOTE(review): if WithBacking aliases rather than copies the slice (TODO
// confirm), the second check overwrites the shared mts.data fixture in place,
// so repeated runs (go test -count>1) start from already-memset data.
const memsetTestRaw = `var denseMemsetTests = []struct{
	of Dtype
	data interface{}
	val interface{}
	shape Shape

	correct interface{}
}{
	{{range . -}}
	{{$val := .Set -}}
	{{$k := .Kind -}}
	{ {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{asType .Kind}}({{$val}}), Shape{2,3}, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, 
	{{end -}}
}

func TestDense_memset(t *testing.T){
	assert := assert.New(t)
	for _, mts := range denseMemsetTests {
		T := New(Of(mts.of), WithShape(mts.shape...))
		T.Memset(mts.val)
		assert.Equal(mts.correct, T.Data())

		T = New(Of(mts.of), WithShape(mts.shape...), WithBacking(mts.data))
		T2, _ := T.Slice(nil)
		T2.Memset(mts.val)
		assert.Equal(mts.correct, T2.Data())
	}
}
`
   160  
   161  const zeroTestRaw = `var denseZeroTests = []struct{
   162  	of Dtype
   163  	data interface{}
   164  
   165  	correct interface{}
   166  }{
   167  	{{range . -}}
   168  	{{$val := .Set -}}
   169  	{{$k := .Kind -}}
   170  	{ {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, 
   171  	{{end -}}
   172  }
   173  
   174  func TestDense_Zero(t *testing.T) {
   175  	assert := assert.New(t)
   176  	for _, mts := range denseZeroTests {
   177  		
   178  		typ := reflect.TypeOf(mts.data)
   179  		val := reflect.ValueOf(mts.data)
   180  		data := reflect.MakeSlice(typ, val.Len(), val.Cap())
   181  		reflect.Copy(data, val)	
   182  
   183  		T := New(Of(mts.of), WithBacking(data.Interface()))
   184  		T.Zero()
   185  		assert.Equal(mts.correct, T.Data())
   186  
   187  		T = New(Of(mts.of),  WithBacking(mts.data))
   188  		T2, _ := T.Slice(nil)
   189  		T2.Zero()
   190  		assert.Equal(mts.correct, T2.Data())
   191  	}	
   192  }
   193  `
   194  
   195  const denseEqTestRaw = `func TestDense_Eq(t *testing.T) {
   196  	eqFn := func(q *Dense) bool{
   197  		a := q.Clone().(*Dense)
   198  		if !q.Eq(a) {
   199  			t.Error("Expected a clone to be exactly equal")
   200  			return false
   201  		}
   202  		a.Zero()
   203  
   204  		// Bools are excluded because the probability of having an array of all false is very high
   205  		if q.Eq(a)  && a.len() > 3 && a.Dtype() != Bool {
   206  			t.Errorf("a %v", a.Data())
   207  			t.Errorf("q %v", q.Data())
   208  			t.Error("Expected *Dense to be not equal")
   209  			return false
   210  		}
   211  		return true
   212  	}
   213  	if err := quick.Check(eqFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   214  		t.Errorf("Failed to perform equality checks")
   215  	}
   216  }`
   217  
// Parsed templates for the generated Dense tests. init parses them from
// getTestRaw, memsetTestRaw and zeroTestRaw respectively, and
// generateDenseGetSetTests executes them.
var (
	GetTest    *template.Template
	MemsetTest *template.Template
	ZeroTest   *template.Template
)
   223  
   224  func init() {
   225  	GetTest = template.Must(template.New("GetTest").Funcs(funcs).Parse(getTestRaw))
   226  	MemsetTest = template.Must(template.New("MemsetTest").Funcs(funcs).Parse(memsetTestRaw))
   227  	ZeroTest = template.Must(template.New("ZeroTest").Funcs(funcs).Parse(zeroTestRaw))
   228  }
   229  
   230  func generateDenseGetSetTests(f io.Writer, generic Kinds) {
   231  	tests := makeTests(generic)
   232  	GetTest.Execute(f, tests)
   233  	fmt.Fprintf(f, "\n\n")
   234  	MemsetTest.Execute(f, tests)
   235  	fmt.Fprintf(f, "\n\n")
   236  	ZeroTest.Execute(f, makeZeroTests(generic))
   237  	fmt.Fprintf(f, "\n%v\n", denseEqTestRaw)
   238  }