github.com/wzzhu/tensor@v0.9.24/genlib2/cmp_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
     9  const (
    10  	APICallVVaxbRaw    = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})`
    11  	APICallVVbxcRaw    = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})`
    12  	APICallVVaxcRaw    = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})`
    13  	APICallVVbxaRaw    = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})`
    14  	APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})`
    15  	APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})`
    16  	APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})`
    17  	APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})`
    18  
    19  	DenseMethodCallVVaxbRaw    = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})`
    20  	DenseMethodCallVVbxcRaw    = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})`
    21  	DenseMethodCallVVaxcRaw    = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})`
    22  	DenseMethodCallVVbxaRaw    = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . -}})`
    23  	DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})`
    24  	DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})`
    25  	DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})`
    26  	DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})`
    27  )
    28  
// transitivityCheckRaw is the "transitivityCheck" sub-template, executed after
// axb, bxc and axc have been computed by the enclosing transitivity test.
//
// With FuncOpt "assame", the three Data() slices are passed to threewayEq. On
// failure the inputs and results are logged and the property returns false.
// Otherwise the results are read as bool slices via Bools(). At API level they
// are first type-asserted to *Dense. The check then requires that
// ab[i] && bc[i] implies ac[i] for every element.
const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}}
	if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){
		t.Errorf("a: %-v", a)
		t.Errorf("b: %-v", b)
		t.Errorf("c: %-v", c)
		t.Errorf("axb.Data() %v", axb.Data())
		t.Errorf("bxc.Data() %v", bxc.Data())
		t.Errorf("axc.Data() %v", axc.Data())
		return false
	}
{{else -}}
	{{if eq .Level "API" -}}
		ab := axb.(*Dense).Bools()
		bc := bxc.(*Dense).Bools()
		ac := axc.(*Dense).Bools()
	{{else -}}
		ab := axb.Bools()
		bc := bxc.Bools()
		ac := axc.Bools()
	{{end -}}
	for i, vab := range ab {
		if vab && bc[i] {
			if !ac[i]{
				return false
			}
		}
	}
{{end -}}
`
    58  
// transitivityBodyRaw is the outer template for the tensor-tensor transitivity
// quickcheck.
//
// For each generated q, the template does the following:
//   - It clones q into a, b and c, then Memsets b and c with one random value
//     each.
//   - It computes a∙b, b∙c and a∙c through the "axb", "bxc" and "axc"
//     sub-templates, and runs "transitivityCheck" on the results.
//   - `we` combines willerr(...) with whether q's engine implements the
//     interface named by interfaceName. qcErrCheck then decides after each
//     call whether the property stops early; it fails only if it also returns
//     a non-nil error.
const transitivityBodyRaw = `transFn := func(q *Dense) bool {
	we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}})
	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok

	{{template "funcoptdecl" . -}}

	r := newRand()
	a := q.Clone().(*Dense)
	b := q.Clone().(*Dense)
	c := q.Clone().(*Dense)

	bv, _ := quick.Value(b.Dtype().Type, r)
	cv, _ := quick.Value(c.Dtype().Type, r)
	b.Memset(bv.Interface())
	c.Memset(cv.Interface())

	{{template "axb" .}}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}

	{{template "bxc" . }}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}

	{{template "axc" . }}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}

	{{template "transitivityCheck" .}}
	return true
}
if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
	t.Errorf("Transitivity test for {{.Name}} failed: %v", err)
}
`
   106  
// transitivityMixedBodyRaw is the transitivity quickcheck template for
// tensor-scalar comparisons. It differs from the tensor-tensor variant in
// three ways:
//   - b is a single random scalar of a's dtype (quick.Value(...).Interface())
//     rather than a tensor.
//   - c is a clone of q Memset with one random value.
//   - The b∙c error check passes c before b, so the first operand given to
//     qcErrCheck is always a *Dense.
const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool {
	we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}})
	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok

	{{template "funcoptdecl" . -}}

	r := newRand()
	a := q.Clone().(*Dense)
	bv, _ := quick.Value(a.Dtype().Type, r)
	b := bv.Interface()
	c := q.Clone().(*Dense)
	cv, _ := quick.Value(c.Dtype().Type, r)
	c.Memset(cv.Interface())

	{{template "axb" . }}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}

	{{template "bxc" . }}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}

	{{template "axc" . }}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}

	{{template "transitivityCheck" .}}
	return true
}
if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
	t.Errorf("Transitivity test for {{.Name}} failed: %v", err)
}
`
   152  
   153  const symmetryBodyRaw = `symFn := func(q *Dense) bool {
   154  	we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}})
   155  	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok
   156  
   157  	{{template "funcoptdecl" . -}}
   158  
   159  	r := newRand()
   160  	a := q.Clone().(*Dense)
   161  	b := q.Clone().(*Dense)
   162  
   163  	bv, _ := quick.Value(b.Dtype().Type, r)
   164  	b.Memset(bv.Interface())
   165  
   166  	{{template "axb" .}}
   167  	if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{
   168  		if err != nil {
   169  			return false
   170  		}
   171  		return true
   172  	}
   173  
   174  	{{template "bxa" .}}
   175  	if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{
   176  		if err != nil {
   177  			return false
   178  		}
   179  		return true
   180  	}
   181  	return reflect.DeepEqual(axb.Data(), bxa.Data())
   182  
   183  }
   184  if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
   185  	t.Errorf("Transitivity test for {{.Name}} failed: %v", err)
   186  }
   187  `
   188  
// symmetryMixedBodyRaw is the symmetry quickcheck template for tensor-scalar
// comparisons. a is a clone of q and b is one random scalar of a's dtype. It
// computes a∙b and b∙a through the "axb" and "bxa" sub-templates, and requires
// that their Data() are reflect.DeepEqual. Both error checks pass the *Dense
// a as the first operand to qcErrCheck.
const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool {
	we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}})
	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok

	{{template "funcoptdecl" . -}}

	r := newRand()
	a := q.Clone().(*Dense)
	bv, _ := quick.Value(a.Dtype().Type, r)
	b := bv.Interface()

	{{template "axb" .}}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}

	{{template "bxa" .}}
	if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}
	return reflect.DeepEqual(axb.Data(), bxa.Data())

}
if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil {
	t.Errorf("Symmetry test for {{.Name}} failed: %v", err)
}
`
   222  
// CmpTest describes one generated test function for a comparison operator.
// The embedded cmpOp supplies the promoted fields read in this file
// (IsTransitive, IsSymmetric, TypeClassName), as well as the operator's Name.
type CmpTest struct {
	cmpOp
	// scalars selects the "Mixed" call templates (b is a scalar) and adds a
	// "Scalar" suffix to the generated test name.
	scalars             bool
	// lvl selects API-level function calls or *Dense method calls.
	lvl                 Level
	// FuncOpt keys the funcopt* sub-template tables (e.g. "assame") and, when
	// non-empty, adds a "_<FuncOpt>" suffix to the test name.
	FuncOpt             string
	// EqFailTypeClassName is rendered as the third argument of willerr in the
	// generated test; the generate* functions in this file set it to "nil".
	EqFailTypeClassName string
}
   230  
   231  func (fn *CmpTest) Name() string {
   232  	if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" {
   233  		return "El" + fn.cmpOp.Name()
   234  	}
   235  	return fn.cmpOp.Name()
   236  }
   237  
   238  func (fn *CmpTest) Level() string {
   239  	switch fn.lvl {
   240  	case API:
   241  		return "API"
   242  	case Dense:
   243  		return "Dense"
   244  	}
   245  	return ""
   246  }
   247  
   248  func (fn *CmpTest) Signature() *Signature {
   249  	var name string
   250  	switch fn.lvl {
   251  	case API:
   252  		name = fmt.Sprintf("Test%s", fn.cmpOp.Name())
   253  	case Dense:
   254  		name = fmt.Sprintf("TestDense_%s", fn.Name())
   255  	}
   256  	if fn.scalars {
   257  		name += "Scalar"
   258  	}
   259  	if fn.FuncOpt != "" {
   260  		name += "_" + fn.FuncOpt
   261  	}
   262  	return &Signature{
   263  		Name:           name,
   264  		NameTemplate:   plainName,
   265  		ParamNames:     []string{"t"},
   266  		ParamTemplates: []*template.Template{testingType},
   267  	}
   268  }
   269  
   270  func (fn *CmpTest) canWrite() bool {
   271  	return fn.IsTransitive || fn.IsSymmetric
   272  }
   273  
   274  func (fn *CmpTest) WriteBody(w io.Writer) {
   275  	if fn.IsTransitive {
   276  		fn.writeTransitivity(w)
   277  		fmt.Fprintf(w, "\n")
   278  	}
   279  	if fn.IsSymmetric {
   280  		fn.writeSymmetry(w)
   281  	}
   282  }
   283  
   284  func (fn *CmpTest) writeTransitivity(w io.Writer) {
   285  	var t *template.Template
   286  	if fn.scalars {
   287  		t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw))
   288  	} else {
   289  		t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityBodyRaw))
   290  	}
   291  
   292  	switch fn.lvl {
   293  	case API:
   294  		if fn.scalars {
   295  			template.Must(t.New("axb").Parse(APICallMixedaxbRaw))
   296  			template.Must(t.New("bxc").Parse(APICallMixedbxcRaw))
   297  			template.Must(t.New("axc").Parse(APICallMixedaxcRaw))
   298  		} else {
   299  			template.Must(t.New("axb").Parse(APICallVVaxbRaw))
   300  			template.Must(t.New("bxc").Parse(APICallVVbxcRaw))
   301  			template.Must(t.New("axc").Parse(APICallVVaxcRaw))
   302  		}
   303  	case Dense:
   304  		if fn.scalars {
   305  			template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw))
   306  			template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw))
   307  			template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw))
   308  		} else {
   309  			template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw))
   310  			template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw))
   311  			template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw))
   312  		}
   313  	}
   314  	template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw))
   315  	template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt]))
   316  	template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt]))
   317  	template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt]))
   318  	template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt]))
   319  
   320  	t.Execute(w, fn)
   321  }
   322  
   323  func (fn *CmpTest) writeSymmetry(w io.Writer) {
   324  	var t *template.Template
   325  	if fn.scalars {
   326  		t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw))
   327  	} else {
   328  		t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryBodyRaw))
   329  	}
   330  
   331  	switch fn.lvl {
   332  	case API:
   333  		if fn.scalars {
   334  			template.Must(t.New("axb").Parse(APICallMixedaxbRaw))
   335  			template.Must(t.New("bxa").Parse(APICallMixedbxaRaw))
   336  		} else {
   337  			template.Must(t.New("axb").Parse(APICallVVaxbRaw))
   338  			template.Must(t.New("bxa").Parse(APICallVVbxaRaw))
   339  		}
   340  	case Dense:
   341  		if fn.scalars {
   342  			template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw))
   343  			template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw))
   344  		} else {
   345  			template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw))
   346  			template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw))
   347  		}
   348  	}
   349  	template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt]))
   350  	template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt]))
   351  	template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt]))
   352  	template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt]))
   353  
   354  	t.Execute(w, fn)
   355  }
   356  
   357  func (fn *CmpTest) Write(w io.Writer) {
   358  	sig := fn.Signature()
   359  	w.Write([]byte("func "))
   360  	sig.Write(w)
   361  	w.Write([]byte("{\n"))
   362  	fn.WriteBody(w)
   363  	w.Write([]byte("}\n"))
   364  }
   365  
   366  func generateAPICmpTests(f io.Writer, ak Kinds) {
   367  	var tests []*CmpTest
   368  
   369  	for _, op := range cmpBinOps {
   370  		t := &CmpTest{
   371  			cmpOp:               op,
   372  			lvl:                 API,
   373  			EqFailTypeClassName: "nil",
   374  		}
   375  		tests = append(tests, t)
   376  	}
   377  
   378  	for _, fn := range tests {
   379  		if fn.canWrite() {
   380  			fn.Write(f)
   381  		}
   382  		fn.FuncOpt = "assame"
   383  		fn.TypeClassName = "nonComplexNumberTypes"
   384  	}
   385  	for _, fn := range tests {
   386  		if fn.canWrite() {
   387  			fn.Write(f)
   388  		}
   389  	}
   390  
   391  }
   392  
   393  func generateAPICmpMixedTests(f io.Writer, ak Kinds) {
   394  	var tests []*CmpTest
   395  
   396  	for _, op := range cmpBinOps {
   397  		t := &CmpTest{
   398  			cmpOp:               op,
   399  			lvl:                 API,
   400  			scalars:             true,
   401  			EqFailTypeClassName: "nil",
   402  		}
   403  		tests = append(tests, t)
   404  	}
   405  
   406  	for _, fn := range tests {
   407  		if fn.canWrite() {
   408  			fn.Write(f)
   409  		}
   410  		fn.FuncOpt = "assame"
   411  		fn.TypeClassName = "nonComplexNumberTypes"
   412  	}
   413  	for _, fn := range tests {
   414  		if fn.canWrite() {
   415  			fn.Write(f)
   416  		}
   417  	}
   418  }
   419  
   420  func generateDenseMethodCmpTests(f io.Writer, ak Kinds) {
   421  	var tests []*CmpTest
   422  
   423  	for _, op := range cmpBinOps {
   424  		t := &CmpTest{
   425  			cmpOp:               op,
   426  			lvl:                 Dense,
   427  			EqFailTypeClassName: "nil",
   428  		}
   429  		tests = append(tests, t)
   430  	}
   431  
   432  	for _, fn := range tests {
   433  		if fn.canWrite() {
   434  			fn.Write(f)
   435  		}
   436  		fn.FuncOpt = "assame"
   437  		fn.TypeClassName = "nonComplexNumberTypes"
   438  	}
   439  	for _, fn := range tests {
   440  		if fn.canWrite() {
   441  			fn.Write(f)
   442  		}
   443  	}
   444  }
   445  
   446  func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) {
   447  	var tests []*CmpTest
   448  
   449  	for _, op := range cmpBinOps {
   450  		t := &CmpTest{
   451  			cmpOp:               op,
   452  			lvl:                 Dense,
   453  			scalars:             true,
   454  			EqFailTypeClassName: "nil",
   455  		}
   456  		tests = append(tests, t)
   457  	}
   458  
   459  	for _, fn := range tests {
   460  		if fn.canWrite() {
   461  			fn.Write(f)
   462  		}
   463  		fn.FuncOpt = "assame"
   464  		fn.TypeClassName = "nonComplexNumberTypes"
   465  	}
   466  	for _, fn := range tests {
   467  		if fn.canWrite() {
   468  			fn.Write(f)
   469  		}
   470  	}
   471  }