github.com/wzzhu/tensor@v0.9.24/genlib2/internaleng.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "reflect" 7 "strings" 8 "text/template" 9 ) 10 11 type InternalEngArithMethod struct { 12 BinOp 13 Kinds []reflect.Kind 14 Incr bool 15 Iter bool 16 WithRecv bool 17 } 18 19 type eLoopBody struct { 20 BinOp 21 22 Err bool 23 Kinds []reflect.Kind 24 } 25 26 func (fn *InternalEngArithMethod) Name() string { 27 switch { 28 case fn.Incr && fn.Iter: 29 return fmt.Sprintf("%sIterIncr", fn.BinOp.Name()) 30 case fn.Incr && !fn.Iter: 31 return fmt.Sprintf("%sIncr", fn.BinOp.Name()) 32 case !fn.Incr && fn.Iter: 33 return fmt.Sprintf("%sIter", fn.BinOp.Name()) 34 case fn.WithRecv: 35 return fmt.Sprintf("%sRecv", fn.BinOp.Name()) 36 default: 37 return fn.BinOp.Name() 38 } 39 } 40 41 func (fn *InternalEngArithMethod) Signature() *Signature { 42 var paramNames []string 43 var paramTemplates []*template.Template 44 switch { 45 case fn.Iter && fn.Incr: 46 paramNames = []string{"t", "a", "b", "incr", "ait", "bit", "iit"} 47 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType, iteratorType, iteratorType, iteratorType} 48 case fn.Iter && !fn.Incr: 49 paramNames = []string{"t", "a", "b", "ait", "bit"} 50 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, iteratorType, iteratorType} 51 case !fn.Iter && fn.Incr: 52 paramNames = []string{"t", "a", "b", "incr"} 53 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType} 54 case fn.WithRecv: 55 paramNames = []string{"t", "a", "b", "recv"} 56 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType} 57 default: 58 paramNames = []string{"t", "a", "b"} 59 paramTemplates = []*template.Template{reflectType, arrayType, arrayType} 60 61 } 62 63 return &Signature{ 64 Name: fn.Name(), 65 NameTemplate: plainName, 66 ParamNames: paramNames, 67 ParamTemplates: paramTemplates, 68 Err: true, 69 } 70 } 71 72 func (fn *InternalEngArithMethod) WriteBody(w io.Writer) { 73 var T *template.Template 74 switch { 75 case fn.Incr && fn.Iter: 76 T = eArithIterIncr 77 case fn.Incr && !fn.Iter: 78 T = eArithIncr 79 case fn.Iter && !fn.Incr: 80 T = eArithIter 81 case fn.WithRecv: 82 T = eArithRecv 83 default: 84 T = eArith 85 } 86 lb := eLoopBody{ 87 BinOp: fn.BinOp, 88 Kinds: fn.Kinds, 89 } 90 T.Execute(w, lb) 91 } 92 93 func (fn *InternalEngArithMethod) Write(w io.Writer) { 94 w.Write([]byte("func (e E) ")) 95 sig := fn.Signature() 96 sig.Write(w) 97 w.Write([]byte("{ \n")) 98 fn.WriteBody(w) 99 w.Write([]byte("}\n\n")) 100 } 101 102 func generateEArith(f io.Writer, kinds Kinds) { 103 var methods []*InternalEngArithMethod 104 for _, bo := range arithBinOps { 105 var ks []reflect.Kind 106 for _, k := range kinds.Kinds { 107 if tc := bo.TypeClass(); tc != nil && tc(k) { 108 ks = append(ks, k) 109 } 110 } 111 meth := &InternalEngArithMethod{ 112 BinOp: bo, 113 Kinds: ks, 114 } 115 methods = append(methods, meth) 116 } 117 118 // write vanilla 119 for _, meth := range methods { 120 meth.Write(f) 121 meth.Incr = true 122 } 123 124 // write incr 125 for _, meth := range methods { 126 meth.Write(f) 127 meth.Incr = false 128 meth.Iter = true 129 } 130 131 // write iter 132 for _, meth := range methods { 133 meth.Write(f) 134 meth.Incr = true 135 } 136 137 // write iter incr 138 for _, meth := range methods { 139 meth.Write(f) 140 meth.Incr = false 141 meth.Iter = false 142 } 143 144 // write recv 145 for _, meth := range methods { 146 meth.WithRecv = true 147 meth.Write(f) 148 } 149 } 150 151 /* MAP */ 152 153 type InternalEngMap struct { 154 Kinds []reflect.Kind 155 Iter bool 156 } 157 158 func (fn *InternalEngMap) Signature() *Signature { 159 paramNames := []string{"t", "fn", "a", "incr"} 160 paramTemplates := []*template.Template{reflectType, interfaceType, arrayType, boolType} 161 name := "Map" 162 if fn.Iter { 163 paramNames = append(paramNames, "ait") 164 paramTemplates = append(paramTemplates, iteratorType) 165 name += "Iter" 166 } 167 168 return &Signature{ 169 Name: name, 170 NameTemplate: plainName, 171 ParamNames: paramNames, 172 ParamTemplates: paramTemplates, 173 Err: true, 174 } 175 } 176 177 func (fn *InternalEngMap) WriteBody(w io.Writer) { 178 T := eMap 179 if fn.Iter { 180 T = eMapIter 181 } 182 template.Must(T.New("fntype0").Funcs(funcs).Parse(unaryFuncTypeRaw)) 183 template.Must(T.New("fntype1").Funcs(funcs).Parse(unaryFuncErrTypeRaw)) 184 185 lb := eLoopBody{ 186 Kinds: fn.Kinds, 187 } 188 T.Execute(w, lb) 189 } 190 191 func (fn *InternalEngMap) Write(w io.Writer) { 192 w.Write([]byte("func (e E) ")) 193 sig := fn.Signature() 194 sig.Write(w) 195 w.Write([]byte("{ \n")) 196 fn.WriteBody(w) 197 198 w.Write([]byte("\nreturn\n}\n\n")) 199 } 200 201 func generateEMap(f io.Writer, kinds Kinds) { 202 m := new(InternalEngMap) 203 for _, k := range kinds.Kinds { 204 if isParameterized(k) { 205 continue 206 } 207 m.Kinds = append(m.Kinds, k) 208 } 209 m.Write(f) 210 m.Iter = true 211 m.Write(f) 212 } 213 214 /* Cmp */ 215 216 // InternalEngCmpMethod is exactly the same structure as the arith one, except it's Same instead of Incr. 217 // Some copy and paste leads to more clarity, rather than reusing the structure. 218 type InternalEngCmp struct { 219 BinOp 220 Kinds []reflect.Kind 221 RetSame bool 222 Iter bool 223 } 224 225 func (fn *InternalEngCmp) Name() string { 226 switch { 227 case fn.Iter && fn.RetSame: 228 return fmt.Sprintf("%sSameIter", fn.BinOp.Name()) 229 case fn.Iter && !fn.RetSame: 230 return fmt.Sprintf("%sIter", fn.BinOp.Name()) 231 case !fn.Iter && fn.RetSame: 232 return fmt.Sprintf("%sSame", fn.BinOp.Name()) 233 default: 234 return fn.BinOp.Name() 235 } 236 } 237 238 func (fn *InternalEngCmp) Signature() *Signature { 239 var paramNames []string 240 var paramTemplates []*template.Template 241 242 switch { 243 case fn.Iter && fn.RetSame: 244 paramNames = []string{"t", "a", "b", "ait", "bit"} 245 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, iteratorType, iteratorType} 246 case fn.Iter && !fn.RetSame: 247 paramNames = []string{"t", "a", "b", "retVal", "ait", "bit", "rit"} 248 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType, iteratorType, iteratorType, iteratorType} 249 case !fn.Iter && fn.RetSame: 250 paramNames = []string{"t", "a", "b"} 251 paramTemplates = []*template.Template{reflectType, arrayType, arrayType} 252 default: 253 paramNames = []string{"t", "a", "b", "retVal"} 254 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, arrayType} 255 } 256 257 return &Signature{ 258 Name: fn.Name(), 259 NameTemplate: plainName, 260 ParamNames: paramNames, 261 ParamTemplates: paramTemplates, 262 263 Err: true, 264 } 265 } 266 267 func (fn *InternalEngCmp) WriteBody(w io.Writer) { 268 var T *template.Template 269 switch { 270 case fn.Iter && fn.RetSame: 271 T = eCmpSameIter 272 case fn.Iter && !fn.RetSame: 273 T = eCmpBoolIter 274 case !fn.Iter && fn.RetSame: 275 T = eCmpSame 276 default: 277 T = eCmpBool 278 } 279 280 lb := eLoopBody{ 281 BinOp: fn.BinOp, 282 Kinds: fn.Kinds, 283 } 284 T.Execute(w, lb) 285 } 286 287 func (fn *InternalEngCmp) Write(w io.Writer) { 288 w.Write([]byte("func (e E) ")) 289 sig := fn.Signature() 290 sig.Write(w) 291 w.Write([]byte("{ \n")) 292 fn.WriteBody(w) 293 w.Write([]byte("}\n\n")) 294 } 295 296 func generateECmp(f io.Writer, kinds Kinds) { 297 var methods []*InternalEngCmp 298 for _, bo := range cmpBinOps { 299 var ks []reflect.Kind 300 for _, k := range kinds.Kinds { 301 if tc := bo.TypeClass(); tc != nil && tc(k) { 302 ks = append(ks, k) 303 } 304 } 305 meth := &InternalEngCmp{ 306 BinOp: bo, 307 Kinds: ks, 308 } 309 methods = append(methods, meth) 310 } 311 312 for _, meth := range methods { 313 meth.Write(f) 314 meth.RetSame = true 315 } 316 317 for _, meth := range methods { 318 meth.Write(f) 319 meth.RetSame = false 320 meth.Iter = true 321 } 322 for _, meth := range methods { 323 meth.Write(f) 324 meth.RetSame = true 325 } 326 327 for _, meth := range methods { 328 meth.Write(f) 329 } 330 } 331 332 /* MIN/MAX BETWEEN */ 333 334 type InternalEngMinMaxBetween struct { 335 BinOp 336 Kinds []reflect.Kind 337 Iter bool 338 } 339 340 func (fn *InternalEngMinMaxBetween) Name() string { 341 name := fn.BinOp.Name() 342 343 switch { 344 case fn.Iter: 345 return fmt.Sprintf("%sBetweenIter", name) 346 default: 347 return name + "Between" 348 } 349 } 350 351 func (fn *InternalEngMinMaxBetween) Signature() *Signature { 352 var paramNames []string 353 var paramTemplates []*template.Template 354 355 switch { 356 case fn.Iter: 357 paramNames = []string{"t", "a", "b", "ait", "bit"} 358 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, iteratorType, iteratorType} 359 default: 360 paramNames = []string{"t", "a", "b"} 361 paramTemplates = []*template.Template{reflectType, arrayType, arrayType} 362 } 363 return &Signature{ 364 Name: fn.Name(), 365 NameTemplate: plainName, 366 ParamNames: paramNames, 367 ParamTemplates: paramTemplates, 368 369 Err: true, 370 } 371 } 372 373 func (fn *InternalEngMinMaxBetween) WriteBody(w io.Writer) { 374 var T *template.Template 375 376 switch { 377 case fn.Iter: 378 T = eMinMaxIter 379 default: 380 T = eMinMaxSame 381 } 382 383 lb := eLoopBody{ 384 BinOp: fn.BinOp, 385 Kinds: fn.Kinds, 386 } 387 T.Execute(w, lb) 388 } 389 390 func (fn *InternalEngMinMaxBetween) Write(w io.Writer) { 391 w.Write([]byte("func (e E) ")) 392 sig := fn.Signature() 393 sig.Write(w) 394 w.Write([]byte("{\n")) 395 fn.WriteBody(w) 396 w.Write([]byte("}\n\n")) 397 } 398 399 func generateEMinMaxBetween(f io.Writer, kinds Kinds) { 400 minmaxOps := []cmpOp{cmpBinOps[0], cmpBinOps[2]} // Gt and Lt 401 minmaxOps[0].name = "Max" 402 minmaxOps[1].name = "Min" 403 var methods []*InternalEngMinMaxBetween 404 for _, bo := range minmaxOps { 405 var ks []reflect.Kind 406 for _, k := range kinds.Kinds { 407 if tc := bo.TypeClass(); tc != nil && tc(k) { 408 ks = append(ks, k) 409 } 410 } 411 meth := &InternalEngMinMaxBetween{ 412 BinOp: bo, 413 Kinds: ks, 414 } 415 methods = append(methods, meth) 416 } 417 418 for _, meth := range methods { 419 meth.Write(f) 420 meth.Iter = true 421 } 422 for _, meth := range methods { 423 meth.Write(f) 424 } 425 426 } 427 428 /* REDUCTION */ 429 430 type InternalEngReduce struct { 431 Kinds []reflect.Kind 432 433 Dim int // 0 == first dim, -1 == last dim 434 Flat bool 435 } 436 437 func (fn *InternalEngReduce) Name() string { 438 switch { 439 case fn.Flat: 440 return "Reduce" 441 case fn.Dim == 0: 442 return "ReduceFirst" 443 case fn.Dim < 0: 444 return "ReduceLast" 445 case fn.Dim > 0: 446 return "ReduceDefault" 447 } 448 panic("unreachable") 449 } 450 451 func (fn *InternalEngReduce) Signature() *Signature { 452 var paramNames, retVals []string 453 var paramTemplates, retValTemplates []*template.Template 454 455 switch { 456 case fn.Flat: 457 paramNames = []string{"t", "a", "defaultValue", "fn"} 458 paramTemplates = []*template.Template{reflectType, arrayType, interfaceType, interfaceType} 459 retVals = []string{"retVal"} 460 retValTemplates = []*template.Template{interfaceType} 461 case fn.Dim == 0: 462 paramNames = []string{"t", "data", "retVal", "split", "size", "fn"} 463 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, intType, intType, interfaceType} 464 case fn.Dim < 0: 465 paramNames = []string{"t", "data", "retVal", "dimSize", "defaultValue", "fn"} 466 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, intType, interfaceType, interfaceType} 467 case fn.Dim > 0: 468 paramNames = []string{"t", "data", "retVal", "dim0", "dimSize", "outerStride", "stride", "expected", "fn"} 469 paramTemplates = []*template.Template{reflectType, arrayType, arrayType, intType, intType, intType, intType, intType, interfaceType} 470 } 471 return &Signature{ 472 Name: fn.Name(), 473 NameTemplate: plainName, 474 ParamNames: paramNames, 475 ParamTemplates: paramTemplates, 476 RetVals: retVals, 477 RetValTemplates: retValTemplates, 478 Err: true, 479 } 480 } 481 482 func (fn *InternalEngReduce) WriteBody(w io.Writer) { 483 var T *template.Template 484 switch { 485 case fn.Flat: 486 T = eReduce 487 case fn.Dim == 0: 488 T = eReduceFirst 489 case fn.Dim < 0: 490 T = eReduceLast 491 case fn.Dim > 0: 492 T = eReduceDefault 493 } 494 495 T.Execute(w, fn) 496 } 497 498 func (fn *InternalEngReduce) Write(w io.Writer) { 499 w.Write([]byte("func (e E) ")) 500 sig := fn.Signature() 501 sig.Write(w) 502 w.Write([]byte("{ \n")) 503 fn.WriteBody(w) 504 w.Write([]byte("}\n\n")) 505 } 506 507 func generateEReduce(f io.Writer, kinds Kinds) { 508 ks := filter(kinds.Kinds, isNotParameterized) 509 510 fn := &InternalEngReduce{ 511 Kinds: ks, 512 } 513 fn.Write(f) 514 fn.Dim = -1 515 fn.Write(f) 516 fn.Dim = 1 517 fn.Write(f) 518 fn.Flat = true 519 fn.Write(f) 520 } 521 522 /* UNARY */ 523 524 type InternalEngUnary struct { 525 UnaryOp 526 Kinds []reflect.Kind 527 Iter bool 528 } 529 530 func (fn *InternalEngUnary) Signature() *Signature { 531 paramNames := []string{"t", "a"} 532 paramTemplates := []*template.Template{reflectType, arrayType} 533 534 if fn.Iter { 535 paramNames = append(paramNames, "ait") 536 paramTemplates = append(paramTemplates, iteratorType) 537 } 538 539 if strings.HasPrefix(fn.Name(), "Clamp") { 540 paramNames = append(paramNames, "minVal", "maxVal") 541 paramTemplates = append(paramTemplates, interfaceType, interfaceType) 542 } 543 544 return &Signature{ 545 Name: fn.Name(), 546 NameTemplate: plainName, 547 ParamNames: paramNames, 548 ParamTemplates: paramTemplates, 549 Err: true, 550 } 551 } 552 553 func (fn *InternalEngUnary) Name() string { 554 n := fn.UnaryOp.Name() 555 if fn.Iter { 556 n += "Iter" 557 } 558 return n 559 } 560 561 func (fn *InternalEngUnary) WriteBody(w io.Writer) { 562 var T *template.Template 563 switch { 564 case fn.Name() == "Clamp": 565 T = eUnaryClamp 566 case fn.Name() == "ClampIter": 567 T = eUnaryClampIter 568 case fn.Iter: 569 T = eUnaryIter 570 default: 571 T = eUnary 572 } 573 574 T.Execute(w, fn) 575 } 576 577 func (fn *InternalEngUnary) Write(w io.Writer) { 578 w.Write([]byte("func (e E) ")) 579 sig := fn.Signature() 580 sig.Write(w) 581 w.Write([]byte("{ \n")) 582 fn.WriteBody(w) 583 w.Write([]byte("}\n\n")) 584 } 585 586 func generateUncondEUnary(f io.Writer, kinds Kinds) { 587 var unaries []*InternalEngUnary 588 for _, u := range unconditionalUnaries { 589 var ks []reflect.Kind 590 for _, k := range kinds.Kinds { 591 if tc := u.TypeClass(); tc != nil && !tc(k) { 592 continue 593 } 594 ks = append(ks, k) 595 } 596 ieu := &InternalEngUnary{ 597 UnaryOp: u, 598 Kinds: ks, 599 } 600 unaries = append(unaries, ieu) 601 } 602 603 for _, u := range unaries { 604 u.Write(f) 605 u.Iter = true 606 } 607 608 for _, u := range unaries { 609 u.Write(f) 610 } 611 } 612 613 func generateCondEUnary(f io.Writer, kinds Kinds) { 614 var unaries []*InternalEngUnary 615 for _, u := range conditionalUnaries { 616 var ks []reflect.Kind 617 for _, k := range kinds.Kinds { 618 if tc := u.TypeClass(); tc != nil && !tc(k) { 619 continue 620 } 621 // special case for complex 622 if isComplex(k) { 623 continue 624 } 625 ks = append(ks, k) 626 } 627 ieu := &InternalEngUnary{ 628 UnaryOp: u, 629 Kinds: ks, 630 } 631 unaries = append(unaries, ieu) 632 } 633 634 for _, u := range unaries { 635 u.Write(f) 636 u.Iter = true 637 } 638 639 for _, u := range unaries { 640 u.Write(f) 641 } 642 } 643 644 func generateSpecialEUnaries(f io.Writer, kinds Kinds) { 645 var unaries []*InternalEngUnary 646 for _, u := range specialUnaries { 647 var ks []reflect.Kind 648 for _, k := range kinds.Kinds { 649 if tc := u.TypeClass(); tc != nil && !tc(k) { 650 continue 651 } 652 653 ks = append(ks, k) 654 } 655 656 ieu := &InternalEngUnary{ 657 UnaryOp: u, 658 Kinds: ks, 659 } 660 unaries = append(unaries, ieu) 661 } 662 663 for _, u := range unaries { 664 u.Write(f) 665 u.Iter = true 666 } 667 668 for _, u := range unaries { 669 u.Write(f) 670 } 671 } 672 673 /* Argmethods */ 674 675 type InternalEngArgMethod struct { 676 Name string 677 Masked bool 678 Flat bool 679 Kinds []reflect.Kind 680 } 681 682 func (fn *InternalEngArgMethod) Signature() *Signature { 683 var name string 684 var paramNames []string 685 var paramTemplates []*template.Template 686 var retVals []string 687 var retValTemplates []*template.Template 688 var err bool 689 switch { 690 case fn.Masked && fn.Flat: 691 name = fmt.Sprintf("Arg%sFlatMasked", fn.Name) 692 paramNames = []string{"t", "a", "mask"} 693 paramTemplates = []*template.Template{reflectType, arrayType, boolsType} 694 retVals = []string{"retVal"} 695 retValTemplates = []*template.Template{intType} 696 err = false 697 case fn.Masked && !fn.Flat: 698 name = fmt.Sprintf("Arg%sIterMasked", fn.Name) 699 paramNames = []string{"t", "a", "mask", "it", "lastSize"} 700 paramTemplates = []*template.Template{reflectType, arrayType, boolsType, iteratorType, intType} 701 retVals = []string{"indices"} 702 retValTemplates = []*template.Template{intsType} 703 err = true 704 case !fn.Masked && fn.Flat: 705 name = fmt.Sprintf("Arg%sFlat", fn.Name) 706 paramNames = []string{"t", "a"} 707 paramTemplates = []*template.Template{reflectType, arrayType} 708 retVals = []string{"retVal"} 709 retValTemplates = []*template.Template{intType} 710 err = false 711 default: 712 name = fmt.Sprintf("Arg%sIter", fn.Name) 713 paramNames = []string{"t", "a", "it", "lastSize"} 714 paramTemplates = []*template.Template{reflectType, arrayType, iteratorType, intType} 715 retVals = []string{"indices"} 716 retValTemplates = []*template.Template{intsType} 717 err = true 718 } 719 720 return &Signature{ 721 Name: name, 722 NameTemplate: plainName, 723 ParamNames: paramNames, 724 ParamTemplates: paramTemplates, 725 RetVals: retVals, 726 RetValTemplates: retValTemplates, 727 Err: err, 728 } 729 } 730 731 func (fn *InternalEngArgMethod) WriteBody(w io.Writer) { 732 switch { 733 case fn.Masked && fn.Flat: 734 eArgmaxFlatMasked.Execute(w, fn) 735 case fn.Masked && !fn.Flat: 736 eArgmaxMasked.Execute(w, fn) 737 case !fn.Masked && fn.Flat: 738 eArgmaxFlat.Execute(w, fn) 739 default: 740 eArgmax.Execute(w, fn) 741 } 742 } 743 744 func (fn *InternalEngArgMethod) Write(w io.Writer) { 745 sig := fn.Signature() 746 w.Write([]byte("func (e E) ")) 747 sig.Write(w) 748 w.Write([]byte("{\n")) 749 fn.WriteBody(w) 750 w.Write([]byte("}\n")) 751 } 752 753 func generateInternalEngArgmethods(f io.Writer, ak Kinds) { 754 ks := filter(ak.Kinds, isOrd) 755 meths := []*InternalEngArgMethod{ 756 &InternalEngArgMethod{Name: "max", Kinds: ks}, 757 &InternalEngArgMethod{Name: "min", Kinds: ks}, 758 } 759 760 // default 761 for _, fn := range meths { 762 fn.Write(f) 763 fn.Masked = true 764 } 765 // masked 766 for _, fn := range meths { 767 fn.Write(f) 768 fn.Flat = true 769 } 770 // flat masked 771 for _, fn := range meths { 772 fn.Write(f) 773 fn.Flat = true 774 fn.Masked = false 775 } 776 // flat 777 for _, fn := range meths { 778 fn.Write(f) 779 } 780 }