github.com/wzzhu/tensor@v0.9.24/genlib2/cmp_tests.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 const ( 10 APICallVVaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` 11 APICallVVbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` 12 APICallVVaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` 13 APICallVVbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` 14 APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` 15 APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` 16 APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` 17 APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` 18 19 DenseMethodCallVVaxbRaw = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})` 20 DenseMethodCallVVbxcRaw = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})` 21 DenseMethodCallVVaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` 22 DenseMethodCallVVbxaRaw = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . -}})` 23 DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})` 24 DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` 25 DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` 26 DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` 27 ) 28 29 const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}} 30 if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){ 31 t.Errorf("a: %-v", a) 32 t.Errorf("b: %-v", b) 33 t.Errorf("c: %-v", c) 34 t.Errorf("axb.Data() %v", axb.Data()) 35 t.Errorf("bxc.Data() %v", bxc.Data()) 36 t.Errorf("axc.Data() %v", axc.Data()) 37 return false 38 } 39 {{else -}} 40 {{if eq .Level "API" -}} 41 ab := axb.(*Dense).Bools() 42 bc := bxc.(*Dense).Bools() 43 ac := axc.(*Dense).Bools() 44 {{else -}} 45 ab := axb.Bools() 46 bc := bxc.Bools() 47 ac := axc.Bools() 48 {{end -}} 49 for i, vab := range ab { 50 if vab && bc[i] { 51 if !ac[i]{ 52 return false 53 } 54 } 55 } 56 {{end -}} 57 ` 58 59 const transitivityBodyRaw = `transFn := func(q *Dense) bool { 60 we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 61 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 62 63 {{template "funcoptdecl" . -}} 64 65 r := newRand() 66 a := q.Clone().(*Dense) 67 b := q.Clone().(*Dense) 68 c := q.Clone().(*Dense) 69 70 bv, _ := quick.Value(b.Dtype().Type, r) 71 cv, _ := quick.Value(c.Dtype().Type, r) 72 b.Memset(bv.Interface()) 73 c.Memset(cv.Interface()) 74 75 {{template "axb" .}} 76 if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ 77 if err != nil { 78 return false 79 } 80 return true 81 } 82 83 {{template "bxc" . }} 84 if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{ 85 if err != nil { 86 return false 87 } 88 return true 89 } 90 91 {{template "axc" . }} 92 if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ 93 if err != nil { 94 return false 95 } 96 return true 97 } 98 99 {{template "transitivityCheck" .}} 100 return true 101 } 102 if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 103 t.Errorf("Transitivity test for {{.Name}} failed: %v", err) 104 } 105 ` 106 107 const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { 108 we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 109 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 110 111 {{template "funcoptdecl" . -}} 112 113 r := newRand() 114 a := q.Clone().(*Dense) 115 bv, _ := quick.Value(a.Dtype().Type, r) 116 b := bv.Interface() 117 c := q.Clone().(*Dense) 118 cv, _ := quick.Value(c.Dtype().Type, r) 119 c.Memset(cv.Interface()) 120 121 {{template "axb" . }} 122 if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ 123 if err != nil { 124 return false 125 } 126 return true 127 } 128 129 {{template "bxc" . }} 130 if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{ 131 if err != nil { 132 return false 133 } 134 return true 135 } 136 137 {{template "axc" . }} 138 if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ 139 if err != nil { 140 return false 141 } 142 return true 143 } 144 145 {{template "transitivityCheck" .}} 146 return true 147 } 148 if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 149 t.Errorf("Transitivity test for {{.Name}} failed: %v", err) 150 } 151 ` 152 153 const symmetryBodyRaw = `symFn := func(q *Dense) bool { 154 we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 155 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 156 157 {{template "funcoptdecl" . -}} 158 159 r := newRand() 160 a := q.Clone().(*Dense) 161 b := q.Clone().(*Dense) 162 163 bv, _ := quick.Value(b.Dtype().Type, r) 164 b.Memset(bv.Interface()) 165 166 {{template "axb" .}} 167 if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ 168 if err != nil { 169 return false 170 } 171 return true 172 } 173 174 {{template "bxa" .}} 175 if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ 176 if err != nil { 177 return false 178 } 179 return true 180 } 181 return reflect.DeepEqual(axb.Data(), bxa.Data()) 182 183 } 184 if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 185 t.Errorf("Transitivity test for {{.Name}} failed: %v", err) 186 } 187 ` 188 189 const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { 190 we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) 191 _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok 192 193 {{template "funcoptdecl" . -}} 194 195 r := newRand() 196 a := q.Clone().(*Dense) 197 bv, _ := quick.Value(a.Dtype().Type, r) 198 b := bv.Interface() 199 200 {{template "axb" .}} 201 if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ 202 if err != nil { 203 return false 204 } 205 return true 206 } 207 208 {{template "bxa" .}} 209 if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ 210 if err != nil { 211 return false 212 } 213 return true 214 } 215 return reflect.DeepEqual(axb.Data(), bxa.Data()) 216 217 } 218 if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { 219 t.Errorf("Symmetry test for {{.Name}} failed: %v", err) 220 } 221 ` 222 223 type CmpTest struct { 224 cmpOp 225 scalars bool 226 lvl Level 227 FuncOpt string 228 EqFailTypeClassName string 229 } 230 231 func (fn *CmpTest) Name() string { 232 if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" { 233 return "El" + fn.cmpOp.Name() 234 } 235 return fn.cmpOp.Name() 236 } 237 238 func (fn *CmpTest) Level() string { 239 switch fn.lvl { 240 case API: 241 return "API" 242 case Dense: 243 return "Dense" 244 } 245 return "" 246 } 247 248 func (fn *CmpTest) Signature() *Signature { 249 var name string 250 switch fn.lvl { 251 case API: 252 name = fmt.Sprintf("Test%s", fn.cmpOp.Name()) 253 case Dense: 254 name = fmt.Sprintf("TestDense_%s", fn.Name()) 255 } 256 if fn.scalars { 257 name += "Scalar" 258 } 259 if fn.FuncOpt != "" { 260 name += "_" + fn.FuncOpt 261 } 262 return &Signature{ 263 Name: name, 264 NameTemplate: plainName, 265 ParamNames: []string{"t"}, 266 ParamTemplates: []*template.Template{testingType}, 267 } 268 } 269 270 func (fn *CmpTest) canWrite() bool { 271 return fn.IsTransitive || fn.IsSymmetric 272 } 273 274 func (fn *CmpTest) WriteBody(w io.Writer) { 275 if fn.IsTransitive { 276 fn.writeTransitivity(w) 277 fmt.Fprintf(w, "\n") 278 } 279 if fn.IsSymmetric { 280 fn.writeSymmetry(w) 281 } 282 } 283 284 func (fn *CmpTest) writeTransitivity(w io.Writer) { 285 var t *template.Template 286 if fn.scalars { 287 t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw)) 288 } else { 289 t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityBodyRaw)) 290 } 291 292 switch fn.lvl { 293 case API: 294 if fn.scalars { 295 template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) 296 template.Must(t.New("bxc").Parse(APICallMixedbxcRaw)) 297 template.Must(t.New("axc").Parse(APICallMixedaxcRaw)) 298 } else { 299 template.Must(t.New("axb").Parse(APICallVVaxbRaw)) 300 template.Must(t.New("bxc").Parse(APICallVVbxcRaw)) 301 template.Must(t.New("axc").Parse(APICallVVaxcRaw)) 302 } 303 case Dense: 304 if fn.scalars { 305 template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) 306 template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw)) 307 template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw)) 308 } else { 309 template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) 310 template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw)) 311 template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw)) 312 } 313 } 314 template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw)) 315 template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) 316 template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) 317 template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) 318 template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) 319 320 t.Execute(w, fn) 321 } 322 323 func (fn *CmpTest) writeSymmetry(w io.Writer) { 324 var t *template.Template 325 if fn.scalars { 326 t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw)) 327 } else { 328 t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryBodyRaw)) 329 } 330 331 switch fn.lvl { 332 case API: 333 if fn.scalars { 334 template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) 335 template.Must(t.New("bxa").Parse(APICallMixedbxaRaw)) 336 } else { 337 template.Must(t.New("axb").Parse(APICallVVaxbRaw)) 338 template.Must(t.New("bxa").Parse(APICallVVbxaRaw)) 339 } 340 case Dense: 341 if fn.scalars { 342 template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) 343 template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw)) 344 } else { 345 template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) 346 template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw)) 347 } 348 } 349 template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) 350 template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) 351 template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) 352 template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) 353 354 t.Execute(w, fn) 355 } 356 357 func (fn *CmpTest) Write(w io.Writer) { 358 sig := fn.Signature() 359 w.Write([]byte("func ")) 360 sig.Write(w) 361 w.Write([]byte("{\n")) 362 fn.WriteBody(w) 363 w.Write([]byte("}\n")) 364 } 365 366 func generateAPICmpTests(f io.Writer, ak Kinds) { 367 var tests []*CmpTest 368 369 for _, op := range cmpBinOps { 370 t := &CmpTest{ 371 cmpOp: op, 372 lvl: API, 373 EqFailTypeClassName: "nil", 374 } 375 tests = append(tests, t) 376 } 377 378 for _, fn := range tests { 379 if fn.canWrite() { 380 fn.Write(f) 381 } 382 fn.FuncOpt = "assame" 383 fn.TypeClassName = "nonComplexNumberTypes" 384 } 385 for _, fn := range tests { 386 if fn.canWrite() { 387 fn.Write(f) 388 } 389 } 390 391 } 392 393 func generateAPICmpMixedTests(f io.Writer, ak Kinds) { 394 var tests []*CmpTest 395 396 for _, op := range cmpBinOps { 397 t := &CmpTest{ 398 cmpOp: op, 399 lvl: API, 400 scalars: true, 401 EqFailTypeClassName: "nil", 402 } 403 tests = append(tests, t) 404 } 405 406 for _, fn := range tests { 407 if fn.canWrite() { 408 fn.Write(f) 409 } 410 fn.FuncOpt = "assame" 411 fn.TypeClassName = "nonComplexNumberTypes" 412 } 413 for _, fn := range tests { 414 if fn.canWrite() { 415 fn.Write(f) 416 } 417 } 418 } 419 420 func generateDenseMethodCmpTests(f io.Writer, ak Kinds) { 421 var tests []*CmpTest 422 423 for _, op := range cmpBinOps { 424 t := &CmpTest{ 425 cmpOp: op, 426 lvl: Dense, 427 EqFailTypeClassName: "nil", 428 } 429 tests = append(tests, t) 430 } 431 432 for _, fn := range tests { 433 if fn.canWrite() { 434 fn.Write(f) 435 } 436 fn.FuncOpt = "assame" 437 fn.TypeClassName = "nonComplexNumberTypes" 438 } 439 for _, fn := range tests { 440 if fn.canWrite() { 441 fn.Write(f) 442 } 443 } 444 } 445 446 func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) { 447 var tests []*CmpTest 448 449 for _, op := range cmpBinOps { 450 t := &CmpTest{ 451 cmpOp: op, 452 lvl: Dense, 453 scalars: true, 454 EqFailTypeClassName: "nil", 455 } 456 tests = append(tests, t) 457 } 458 459 for _, fn := range tests { 460 if fn.canWrite() { 461 fn.Write(f) 462 } 463 fn.FuncOpt = "assame" 464 fn.TypeClassName = "nonComplexNumberTypes" 465 } 466 for _, fn := range tests { 467 if fn.canWrite() { 468 fn.Write(f) 469 } 470 } 471 }