github.com/wzzhu/tensor@v0.9.24/genlib2/unary_tests.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
// unaryTestBodyRaw is the text/template source for the body of a generated
// inverse round-trip test. (*unaryTest).WriteBody executes it with a
// *unaryTest as data. It requires:
//   - the associated sub-templates "funcoptdecl", "funcoptcorrect",
//     "funcoptuse" and "funcoptcheck" (WriteBody fills them from the
//     funcOpt* tables for the test's FuncOpt);
//   - the "interfaceName" template function, supplied via the funcs FuncMap;
//   - the data fields/methods TypeClassName, EqFailTypeClassName, Name,
//     InvTypeClass, FuncOpt and Inv.
//
// The generated code defines a quick.Check property that:
//  1. clones the input and keeps a second clone as the expected result;
//  2. calls {{.Name}} on it, returning early via qcErrCheck when an error
//     is expected or occurs;
//  3. passes trivially when InvTypeClass is set and the dtype is not in it;
//  4. for FuncOpt "incr", subtracts identityVal(100, dtype) from the result;
//  5. calls {{.Inv}} on the result with UseUnsafe() and compares the data
//     with the expected clone via qcEqCheck.
//
// The generated code refers to a *testing.T named t, which is the parameter
// declared by (*unaryTest).Signature. Whitespace and the "-}}" trim markers
// shape the emitted code, so edit this string with care.
//
// NOTE(review): the error returned by {{.Inv}} is discarded in the generated
// code, so a failing inverse is not reported directly — consider checking it.
const unaryTestBodyRaw = `invFn := func(q *Dense) bool {
	a := q.Clone().(*Dense)
	{{template "funcoptdecl" -}}
	correct := a.Clone().(*Dense)
	{{template "funcoptcorrect" -}}


	we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}})
	_, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok
	
	ret, err := {{.Name}}(a {{template "funcoptuse"}})
	if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{
		if err != nil {
			return false
		}
		return true
	}
	{{if ne .InvTypeClass "" -}}
	if err := typeclassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil {
		return true // uninvertible due to type class implementation issues
	}
	{{end -}}
	{{if eq .FuncOpt "incr" -}}
	if ret, err = Sub(ret, identityVal(100, a.Dtype()),  UseUnsafe()) ; err != nil {
		t.Errorf("err while subtracting incr: %v", err)
		return false
	}
	{{end -}}
	{{.Inv}}(ret, UseUnsafe())
	if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) {
		return false
	}
	{{template "funcoptcheck" -}}
	return true
}

if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{
	t.Errorf("Inv tests for {{.Name}} failed: %v", err)
}
`
    49  
    50  type unaryTest struct {
    51  	unaryOp
    52  	FuncOpt             string
    53  	EqFailTypeClassName string
    54  	InvTypeClass        string
    55  }
    56  
    57  func (fn *unaryTest) Name() string {
    58  	if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" {
    59  		return "El" + fn.unaryOp.Name()
    60  	}
    61  	return fn.unaryOp.Name()
    62  }
    63  
    64  func (fn *unaryTest) Signature() *Signature {
    65  	name := fmt.Sprintf("Test%s", fn.unaryOp.Name())
    66  	if fn.FuncOpt != "" {
    67  		name += "_" + fn.FuncOpt
    68  	}
    69  	return &Signature{
    70  		Name:           name,
    71  		NameTemplate:   plainName,
    72  		ParamNames:     []string{"t"},
    73  		ParamTemplates: []*template.Template{testingType},
    74  	}
    75  }
    76  
    77  func (fn *unaryTest) WriteBody(w io.Writer) {
    78  	t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw))
    79  	template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt]))
    80  	template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt]))
    81  	template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt]))
    82  	template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt]))
    83  	t.Execute(w, fn)
    84  }
    85  
    86  func (fn *unaryTest) canWrite() bool { return fn.Inv != "" }
    87  
    88  func (fn *unaryTest) Write(w io.Writer) {
    89  	sig := fn.Signature()
    90  	w.Write([]byte("func "))
    91  	sig.Write(w)
    92  	w.Write([]byte("{\n"))
    93  	fn.WriteBody(w)
    94  	w.Write([]byte("}\n"))
    95  }
    96  
    97  func generateAPIUnaryTests(f io.Writer, ak Kinds) {
    98  	var tests []*unaryTest
    99  	for _, op := range conditionalUnaries {
   100  		t := &unaryTest{
   101  			unaryOp:             op,
   102  			EqFailTypeClassName: "nil",
   103  		}
   104  
   105  		tests = append(tests, t)
   106  	}
   107  
   108  	for _, op := range unconditionalUnaries {
   109  		t := &unaryTest{
   110  			unaryOp:             op,
   111  			EqFailTypeClassName: "nil",
   112  		}
   113  		switch op.name {
   114  		case "Square":
   115  			t.InvTypeClass = "floatcmplxTypes"
   116  		case "Cube":
   117  			t.InvTypeClass = "floatTypes"
   118  		}
   119  
   120  		tests = append(tests, t)
   121  	}
   122  
   123  	for _, fn := range tests {
   124  		if fn.canWrite() {
   125  			fn.Write(f)
   126  		}
   127  		fn.FuncOpt = "unsafe"
   128  	}
   129  
   130  	for _, fn := range tests {
   131  		if fn.canWrite() {
   132  			fn.Write(f)
   133  		}
   134  		fn.FuncOpt = "reuse"
   135  	}
   136  
   137  	for _, fn := range tests {
   138  		if fn.canWrite() {
   139  			fn.Write(f)
   140  		}
   141  		fn.FuncOpt = "incr"
   142  	}
   143  
   144  	// for now incr cannot be quickchecked
   145  
   146  	for _, fn := range tests {
   147  		if fn.canWrite() {
   148  			fn.Write(f)
   149  		}
   150  	}
   151  }