github.com/wzzhu/tensor@v0.9.24/genlib2/unary_tests.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 const unaryTestBodyRaw = `invFn := func(q *Dense) bool { 10 a := q.Clone().(*Dense) 11 {{template "funcoptdecl" -}} 12 correct := a.Clone().(*Dense) 13 {{template "funcoptcorrect" -}} 14 15 16 we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 17 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 18 19 ret, err := {{.Name}}(a {{template "funcoptuse"}}) 20 if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ 21 if err != nil { 22 return false 23 } 24 return true 25 } 26 {{if ne .InvTypeClass "" -}} 27 if err := typeclassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { 28 return true // uninvertible due to type class implementation issues 29 } 30 {{end -}} 31 {{if eq .FuncOpt "incr" -}} 32 if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()) ; err != nil { 33 t.Errorf("err while subtracting incr: %v", err) 34 return false 35 } 36 {{end -}} 37 {{.Inv}}(ret, UseUnsafe()) 38 if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { 39 return false 40 } 41 {{template "funcoptcheck" -}} 42 return true 43 } 44 45 if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{ 46 t.Errorf("Inv tests for {{.Name}} failed: %v", err) 47 } 48 ` 49 50 type unaryTest struct { 51 unaryOp 52 FuncOpt string 53 EqFailTypeClassName string 54 InvTypeClass string 55 } 56 57 func (fn *unaryTest) Name() string { 58 if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" { 59 return "El" + fn.unaryOp.Name() 60 } 61 return fn.unaryOp.Name() 62 } 63 64 func (fn *unaryTest) Signature() *Signature { 65 name := fmt.Sprintf("Test%s", fn.unaryOp.Name()) 66 if fn.FuncOpt != "" { 67 name += "_" + fn.FuncOpt 68 } 69 return &Signature{ 70 Name: name, 71 NameTemplate: plainName, 72 ParamNames: []string{"t"}, 73 ParamTemplates: []*template.Template{testingType}, 74 } 75 } 76 77 func (fn *unaryTest) WriteBody(w io.Writer) { 78 t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw)) 79 template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) 80 template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) 81 template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) 82 template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) 83 t.Execute(w, fn) 84 } 85 86 func (fn *unaryTest) canWrite() bool { return fn.Inv != "" } 87 88 func (fn *unaryTest) Write(w io.Writer) { 89 sig := fn.Signature() 90 w.Write([]byte("func ")) 91 sig.Write(w) 92 w.Write([]byte("{\n")) 93 fn.WriteBody(w) 94 w.Write([]byte("}\n")) 95 } 96 97 func generateAPIUnaryTests(f io.Writer, ak Kinds) { 98 var tests []*unaryTest 99 for _, op := range conditionalUnaries { 100 t := &unaryTest{ 101 unaryOp: op, 102 EqFailTypeClassName: "nil", 103 } 104 105 tests = append(tests, t) 106 } 107 108 for _, op := range unconditionalUnaries { 109 t := &unaryTest{ 110 unaryOp: op, 111 EqFailTypeClassName: "nil", 112 } 113 switch op.name { 114 case "Square": 115 t.InvTypeClass = "floatcmplxTypes" 116 case "Cube": 117 t.InvTypeClass = "floatTypes" 118 } 119 120 tests = append(tests, t) 121 } 122 123 for _, fn := range tests { 124 if fn.canWrite() { 125 fn.Write(f) 126 } 127 fn.FuncOpt = "unsafe" 128 } 129 130 for _, fn := range tests { 131 if fn.canWrite() { 132 fn.Write(f) 133 } 134 fn.FuncOpt = "reuse" 135 } 136 137 for _, fn := range tests { 138 if fn.canWrite() { 139 fn.Write(f) 140 } 141 fn.FuncOpt = "incr" 142 } 143 144 // for now incr cannot be quickchecked 145 146 for _, fn := range tests { 147 if fn.canWrite() { 148 fn.Write(f) 149 } 150 } 151 }