github.com/wzzhu/tensor@v0.9.24/genlib2/arith_tests.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 const ( 10 APICallVVxRaw = `correct, err := {{.Name}}(a, b)` // no funcopt 11 APICallVVReuseMutRaw = `ret, err = {{.Name}}(a, b` 12 APICallVVRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` 13 APICallVSRaw = `ret, err := {{.Name}}(a, b {{template "funcoptuse"}})` 14 APICallSVRaw = `ret, err := {{.Name}}(b, a {{template "funcoptuse"}})` 15 16 APIInvVVRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())` 17 APIInvVSRaw = `ret, err = {{.Inv}}(ret, b, UseUnsafe())` 18 APIInvSVRaw = `ret, err = {{.Name}}(b, ret, UseUnsafe())` 19 20 DenseMethodCallVVxRaw = `correct, err := a.{{.Name}}(b)` // no funcopt 21 DenseMethodCallVVReuseMutRaw = `ret, err = a.{{.Name}}(b` 22 DenseMethodCallVVRaw = `ret, err := a.{{.Name}}(b {{template "funcoptuse"}})` 23 DenseMethodCallVSRaw = `ret, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse"}})` 24 DenseMethodCallSVRaw = `ret, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse"}})` 25 26 DenseMethodInvVVRaw = `ret, err = ret.{{.Inv}}(b, UseUnsafe())` 27 DenseMethodInvVSRaw = `ret, err = ret.{{.Inv}}Scalar(b, true, UseUnsafe())` 28 DenseMethodInvSVRaw = `ret, err = ret.{{.Name}}Scalar(b, false, UseUnsafe())` 29 30 APIRetType = `Tensor` 31 DenseRetType = `*Dense` 32 ) 33 34 type ArithTest struct { 35 arithOp 36 scalars bool 37 lvl Level 38 FuncOpt string 39 EqFailTypeClassName string 40 } 41 42 func (fn *ArithTest) Signature() *Signature { 43 var name string 44 switch fn.lvl { 45 case API: 46 name = fmt.Sprintf("Test%s", fn.Name()) 47 48 case Dense: 49 name = fmt.Sprintf("TestDense_%s", fn.Name()) 50 } 51 if fn.scalars { 52 name += "Scalar" 53 } 54 if fn.FuncOpt != "" { 55 name += "_" + fn.FuncOpt 56 } 57 return &Signature{ 58 Name: name, 59 NameTemplate: plainName, 60 ParamNames: []string{"t"}, 61 ParamTemplates: []*template.Template{testingType}, 62 } 63 } 64 65 func (fn *ArithTest) WriteBody(w io.Writer) { 66 if fn.HasIdentity { 67 fn.writeIdentity(w) 68 fmt.Fprintf(w, "\n") 69 } 70 71 if fn.IsInv { 72 fn.writeInv(w) 73 } 74 fn.WriteScalarWrongType(w) 75 76 if fn.FuncOpt == "reuse" && fn.arithOp.Name() != "Pow" { 77 fn.writeReuseMutate(w) 78 } 79 } 80 81 func (fn *ArithTest) canWrite() bool { 82 if fn.HasIdentity || fn.IsInv { 83 return true 84 } 85 return false 86 } 87 88 func (fn *ArithTest) writeIdentity(w io.Writer) { 89 var t *template.Template 90 if fn.scalars { 91 t = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithScalarTestRaw)) 92 } else { 93 t = template.Must(template.New("dense identity test").Funcs(funcs).Parse(denseIdentityArithTestBodyRaw)) 94 } 95 switch fn.lvl { 96 case API: 97 if fn.scalars { 98 template.Must(t.New("call0").Parse(APICallVSRaw)) 99 template.Must(t.New("call1").Parse(APICallSVRaw)) 100 } else { 101 template.Must(t.New("call0").Parse(APICallVVRaw)) 102 } 103 case Dense: 104 if fn.scalars { 105 template.Must(t.New("call0").Parse(DenseMethodCallVSRaw)) 106 template.Must(t.New("call1").Parse(DenseMethodCallSVRaw)) 107 } else { 108 template.Must(t.New("call0").Parse(DenseMethodCallVVRaw)) 109 } 110 111 } 112 template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) 113 template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) 114 template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) 115 template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) 116 117 t.Execute(w, fn) 118 } 119 120 func (fn *ArithTest) writeInv(w io.Writer) { 121 var t *template.Template 122 if fn.scalars { 123 t = template.Must(template.New("dense involution test").Funcs(funcs).Parse(denseInvArithScalarTestRaw)) 124 } else { 125 t = template.Must(template.New("dense involution test").Funcs(funcs).Parse(denseInvArithTestBodyRaw)) 126 } 127 switch fn.lvl { 128 case API: 129 if fn.scalars { 130 template.Must(t.New("call0").Parse(APICallVSRaw)) 131 template.Must(t.New("call1").Parse(APICallSVRaw)) 132 template.Must(t.New("callInv0").Parse(APIInvVSRaw)) 133 template.Must(t.New("callInv1").Parse(APIInvSVRaw)) 134 } else { 135 template.Must(t.New("call0").Parse(APICallVVRaw)) 136 template.Must(t.New("callInv").Parse(APIInvVVRaw)) 137 } 138 case Dense: 139 if fn.scalars { 140 template.Must(t.New("call0").Parse(DenseMethodCallVSRaw)) 141 template.Must(t.New("call1").Parse(DenseMethodCallSVRaw)) 142 template.Must(t.New("callInv0").Parse(DenseMethodInvVSRaw)) 143 template.Must(t.New("callInv1").Parse(DenseMethodInvSVRaw)) 144 } else { 145 template.Must(t.New("call0").Parse(DenseMethodCallVVRaw)) 146 template.Must(t.New("callInv").Parse(DenseMethodInvVVRaw)) 147 } 148 } 149 150 template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) 151 template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) 152 template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) 153 template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) 154 155 t.Execute(w, fn) 156 } 157 158 func (fn *ArithTest) writeReuseMutate(w io.Writer) { 159 t := template.Must(template.New("Reuse mutation test").Funcs(funcs).Parse(denseArithReuseMutationTestRaw)) 160 switch fn.lvl { 161 case API: 162 return // tmp 163 case Dense: 164 template.Must(t.New("callVanilla").Parse(DenseMethodCallVVxRaw)) 165 template.Must(t.New("retType").Parse(DenseRetType)) 166 template.Must(t.New("call0").Parse(DenseMethodCallVVReuseMutRaw)) 167 168 } 169 template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) 170 template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) 171 template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) 172 t.Execute(w, fn) 173 } 174 175 func (fn *ArithTest) WriteScalarWrongType(w io.Writer) { 176 if !fn.scalars { 177 return 178 } 179 if fn.FuncOpt != "" { 180 return 181 } 182 t := template.Must(template.New("dense scalar wrongtype test").Funcs(funcs).Parse(denseArithScalarWrongTypeTestRaw)) 183 template.Must(t.New("call0").Parse(APICallVSRaw)) 184 template.Must(t.New("call1").Parse(APICallSVRaw)) 185 template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) 186 template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) 187 template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) 188 template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) 189 190 t.Execute(w, fn) 191 } 192 193 func (fn *ArithTest) Write(w io.Writer) { 194 sig := fn.Signature() 195 w.Write([]byte("func ")) 196 sig.Write(w) 197 w.Write([]byte("{\n")) 198 fn.WriteBody(w) 199 w.Write([]byte("}\n")) 200 } 201 202 func generateAPIArithTests(f io.Writer, ak Kinds) { 203 var tests []*ArithTest 204 for _, op := range arithBinOps { 205 t := &ArithTest{ 206 arithOp: op, 207 lvl: API, 208 EqFailTypeClassName: "nil", 209 } 210 if t.name == "Pow" { 211 t.EqFailTypeClassName = "complexTypes" 212 } 213 tests = append(tests, t) 214 } 215 216 for _, fn := range tests { 217 if fn.canWrite() { 218 fn.Write(f) 219 } 220 fn.FuncOpt = "unsafe" 221 } 222 223 for _, fn := range tests { 224 if fn.canWrite() { 225 fn.Write(f) 226 } 227 fn.FuncOpt = "reuse" 228 } 229 230 for _, fn := range tests { 231 if fn.canWrite() { 232 fn.Write(f) 233 } 234 fn.FuncOpt = "incr" 235 } 236 237 for _, fn := range tests { 238 if fn.canWrite() { 239 fn.Write(f) 240 } 241 } 242 } 243 244 func generateAPIArithScalarTests(f io.Writer, ak Kinds) { 245 var tests []*ArithTest 246 for _, op := range arithBinOps { 247 t := &ArithTest{ 248 arithOp: op, 249 scalars: true, 250 lvl: API, 251 EqFailTypeClassName: "nil", 252 } 253 switch t.name { 254 case "Pow": 255 t.EqFailTypeClassName = "complexTypes" 256 case "Sub": 257 t.EqFailTypeClassName = "unsignedTypes" 258 } 259 tests = append(tests, t) 260 } 261 262 for _, fn := range tests { 263 if fn.canWrite() { 264 fn.Write(f) 265 } 266 fn.FuncOpt = "unsafe" 267 } 268 269 for _, fn := range tests { 270 if fn.canWrite() { 271 fn.Write(f) 272 } 273 fn.FuncOpt = "reuse" 274 } 275 276 for _, fn := range tests { 277 if fn.canWrite() { 278 fn.Write(f) 279 } 280 fn.FuncOpt = "incr" 281 } 282 283 for _, fn := range tests { 284 if fn.canWrite() { 285 fn.Write(f) 286 } 287 } 288 } 289 290 func generateDenseMethodArithTests(f io.Writer, ak Kinds) { 291 var tests []*ArithTest 292 for _, op := range arithBinOps { 293 t := &ArithTest{ 294 arithOp: op, 295 lvl: Dense, 296 EqFailTypeClassName: "nil", 297 } 298 if t.name == "Pow" { 299 t.EqFailTypeClassName = "complexTypes" 300 } 301 tests = append(tests, t) 302 } 303 304 for _, fn := range tests { 305 if fn.canWrite() { 306 fn.Write(f) 307 } 308 fn.FuncOpt = "unsafe" 309 } 310 311 for _, fn := range tests { 312 if fn.canWrite() { 313 fn.Write(f) 314 } 315 fn.FuncOpt = "reuse" 316 } 317 318 for _, fn := range tests { 319 if fn.canWrite() { 320 fn.Write(f) 321 } 322 fn.FuncOpt = "incr" 323 } 324 325 for _, fn := range tests { 326 if fn.canWrite() { 327 fn.Write(f) 328 } 329 } 330 } 331 332 func generateDenseMethodScalarTests(f io.Writer, ak Kinds) { 333 var tests []*ArithTest 334 for _, op := range arithBinOps { 335 t := &ArithTest{ 336 arithOp: op, 337 scalars: true, 338 lvl: Dense, 339 EqFailTypeClassName: "nil", 340 } 341 switch t.name { 342 case "Pow": 343 t.EqFailTypeClassName = "complexTypes" 344 case "Sub": 345 t.EqFailTypeClassName = "unsignedTypes" 346 } 347 tests = append(tests, t) 348 } 349 350 for _, fn := range tests { 351 if fn.canWrite() { 352 fn.Write(f) 353 } 354 fn.FuncOpt = "unsafe" 355 } 356 357 for _, fn := range tests { 358 if fn.canWrite() { 359 fn.Write(f) 360 } 361 fn.FuncOpt = "reuse" 362 } 363 364 for _, fn := range tests { 365 if fn.canWrite() { 366 fn.Write(f) 367 } 368 fn.FuncOpt = "incr" 369 } 370 371 for _, fn := range tests { 372 if fn.canWrite() { 373 fn.Write(f) 374 } 375 } 376 }