github.com/wzzhu/tensor@v0.9.24/genlib2/generic_cmp.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 type GenericVecVecCmp struct { 10 TypedBinOp 11 RetSame bool 12 Iter bool 13 } 14 15 func (fn *GenericVecVecCmp) Name() string { 16 switch { 17 case fn.Iter && fn.RetSame: 18 return fmt.Sprintf("%sSameIter", fn.TypedBinOp.Name()) 19 case fn.Iter && !fn.RetSame: 20 return fmt.Sprintf("%sIter", fn.TypedBinOp.Name()) 21 case !fn.Iter && fn.RetSame: 22 return fmt.Sprintf("%sSame", fn.TypedBinOp.Name()) 23 default: 24 return fn.TypedBinOp.Name() 25 } 26 } 27 28 func (fn *GenericVecVecCmp) Signature() *Signature { 29 var paramNames []string 30 var paramTemplates []*template.Template 31 var err bool 32 33 switch { 34 case fn.Iter && fn.RetSame: 35 paramNames = []string{"a", "b", "ait", "bit"} 36 paramTemplates = []*template.Template{sliceType, sliceType, iteratorType, iteratorType} 37 err = true 38 case fn.Iter && !fn.RetSame: 39 paramNames = []string{"a", "b", "retVal", "ait", "bit", "rit"} 40 paramTemplates = []*template.Template{sliceType, sliceType, boolsType, iteratorType, iteratorType, iteratorType} 41 err = true 42 case !fn.Iter && fn.RetSame: 43 paramNames = []string{"a", "b"} 44 paramTemplates = []*template.Template{sliceType, sliceType} 45 default: 46 paramNames = []string{"a", "b", "retVal"} 47 paramTemplates = []*template.Template{sliceType, sliceType, boolsType} 48 } 49 50 return &Signature{ 51 Name: fn.Name(), 52 NameTemplate: typeAnnotatedName, 53 ParamNames: paramNames, 54 ParamTemplates: paramTemplates, 55 56 Kind: fn.Kind(), 57 Err: err, 58 } 59 } 60 61 func (fn *GenericVecVecCmp) WriteBody(w io.Writer) { 62 var Range, Left, Right string 63 var Index0, Index1, Index2 string 64 var IterName0, IterName1, IterName2 string 65 var T *template.Template 66 67 Range = "a" 68 Left = "a[i]" 69 Right = "b[j]" 70 Index0 = "i" 71 Index1 = "j" 72 switch { 73 case fn.Iter && fn.RetSame: 74 IterName0 = "ait" 75 IterName1 = "bit" 76 T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericBinaryIterLoopRaw)) 77 template.Must(T.New("loopbody").Funcs(funcs).Parse(sameSet)) 78 case fn.Iter && !fn.RetSame: 79 Range = "retVal" 80 Index2 = "k" 81 IterName0 = "ait" 82 IterName1 = "bit" 83 IterName2 = "rit" 84 T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericTernaryIterLoopRaw)) 85 template.Must(T.New("loopbody").Funcs(funcs).Parse(ternaryIterSet)) 86 case !fn.Iter && fn.RetSame: 87 Right = "b[i]" 88 T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericLoopRaw)) 89 template.Must(T.New("loopbody").Funcs(funcs).Parse(sameSet)) 90 default: 91 Range = "retVal" 92 Right = "b[i]" 93 T = template.Must(template.New(fn.Name()).Funcs(funcs).Parse(genericLoopRaw)) 94 template.Must(T.New("loopbody").Funcs(funcs).Parse(basicSet)) 95 } 96 template.Must(T.New("opDo").Funcs(funcs).Parse(binOpDo)) 97 template.Must(T.New("callFunc").Funcs(funcs).Parse("")) 98 template.Must(T.New("check").Funcs(funcs).Parse("")) 99 template.Must(T.New("symbol").Funcs(funcs).Parse(fn.SymbolTemplate())) 100 101 lb := LoopBody{ 102 TypedOp: fn.TypedBinOp, 103 Range: Range, 104 Left: Left, 105 Right: Right, 106 107 Index0: Index0, 108 Index1: Index1, 109 Index2: Index2, 110 111 IterName0: IterName0, 112 IterName1: IterName1, 113 IterName2: IterName2, 114 } 115 T.Execute(w, lb) 116 } 117 118 func (fn *GenericVecVecCmp) Write(w io.Writer) { 119 sig := fn.Signature() 120 w.Write([]byte("func ")) 121 sig.Write(w) 122 switch { 123 case !fn.Iter && !fn.RetSame: 124 w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]; retVal=retVal[:len(a)]\n")) 125 case !fn.Iter && fn.RetSame: 126 w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]\n")) 127 default: 128 w.Write([]byte("{")) 129 } 130 fn.WriteBody(w) 131 if sig.Err { 132 w.Write([]byte("\n return\n")) 133 } 134 w.Write([]byte("}\n\n")) 135 } 136 137 type GenericMixedCmp struct { 138 GenericVecVecCmp 139 LeftVec bool 140 } 141 142 func (fn *GenericMixedCmp) Name() string { 143 n := fn.GenericVecVecCmp.Name() 144 if fn.LeftVec { 145 n += "VS" 146 } else { 147 n += "SV" 148 } 149 return n 150 } 151 152 func (fn *GenericMixedCmp) Signature() *Signature { 153 var paramNames []string 154 var paramTemplates []*template.Template 155 var err bool 156 157 switch { 158 case fn.Iter && !fn.RetSame: 159 paramNames = []string{"a", "b", "retVal", "ait", "rit"} 160 paramTemplates = []*template.Template{sliceType, sliceType, boolsType, iteratorType, iteratorType} 161 err = true 162 case fn.Iter && fn.RetSame: 163 paramNames = []string{"a", "b", "ait"} 164 paramTemplates = []*template.Template{sliceType, sliceType, iteratorType} 165 err = true 166 case !fn.Iter && fn.RetSame: 167 paramNames = []string{"a", "b"} 168 paramTemplates = []*template.Template{sliceType, sliceType} 169 default: 170 paramNames = []string{"a", "b", "retVal"} 171 paramTemplates = []*template.Template{sliceType, sliceType, boolsType} 172 } 173 if fn.LeftVec { 174 paramTemplates[1] = scalarType 175 } else { 176 paramTemplates[0] = scalarType 177 if fn.Iter && !fn.RetSame { 178 paramNames[3] = "bit" 179 } else if fn.Iter && fn.RetSame { 180 paramNames[2] = "bit" 181 } 182 } 183 return &Signature{ 184 Name: fn.Name(), 185 NameTemplate: typeAnnotatedName, 186 ParamNames: paramNames, 187 ParamTemplates: paramTemplates, 188 189 Kind: fn.Kind(), 190 Err: err, 191 } 192 } 193 194 func (fn *GenericMixedCmp) WriteBody(w io.Writer) { 195 var Range, Left, Right string 196 var Index0, Index1 string 197 var IterName0, IterName1 string 198 var T *template.Template 199 200 Range = "a" 201 Left = "a[i]" 202 Right = "b[i]" 203 Index0 = "i" 204 205 T = template.New(fn.Name()).Funcs(funcs) 206 switch { 207 case fn.Iter && !fn.RetSame: 208 Range = "retVal" 209 T = template.Must(T.Parse(genericBinaryIterLoopRaw)) 210 template.Must(T.New("loopbody").Parse(ternaryIterSet)) 211 case fn.Iter && fn.RetSame: 212 T = template.Must(T.Parse(genericUnaryIterLoopRaw)) 213 template.Must(T.New("loopbody").Parse(sameSet)) 214 case !fn.Iter && fn.RetSame: 215 T = template.Must(T.Parse(genericLoopRaw)) 216 template.Must(T.New("loopbody").Parse(sameSet)) 217 default: 218 T = template.Must(T.Parse(genericLoopRaw)) 219 template.Must(T.New("loopbody").Parse(basicSet)) 220 } 221 222 if fn.LeftVec { 223 Right = "b" 224 } else { 225 Left = "a" 226 } 227 if !fn.RetSame { 228 Range = "retVal" 229 } else { 230 if !fn.LeftVec { 231 Range = "b" 232 } 233 } 234 235 switch { 236 case fn.Iter && !fn.RetSame && fn.LeftVec: 237 IterName0 = "ait" 238 IterName1 = "rit" 239 Index1 = "k" 240 case fn.Iter && fn.RetSame && fn.LeftVec: 241 IterName0 = "ait" 242 case fn.Iter && !fn.RetSame && !fn.LeftVec: 243 IterName0 = "bit" 244 IterName1 = "rit" 245 Index1 = "k" 246 case fn.Iter && fn.RetSame && !fn.LeftVec: 247 IterName0 = "bit" 248 } 249 250 template.Must(T.New("callFunc").Parse("")) 251 template.Must(T.New("opDo").Parse(binOpDo)) 252 template.Must(T.New("symbol").Parse(fn.SymbolTemplate())) 253 template.Must(T.New("check").Parse("")) 254 255 lb := LoopBody{ 256 TypedOp: fn.TypedBinOp, 257 Range: Range, 258 Left: Left, 259 Right: Right, 260 261 Index0: Index0, 262 Index1: Index1, 263 IterName0: IterName0, 264 IterName1: IterName1, 265 } 266 T.Execute(w, lb) 267 } 268 269 func (fn *GenericMixedCmp) Write(w io.Writer) { 270 sig := fn.Signature() 271 w.Write([]byte("func ")) 272 sig.Write(w) 273 w.Write([]byte("{ \n")) 274 fn.WriteBody(w) 275 if sig.Err { 276 w.Write([]byte("\nreturn\n")) 277 } 278 w.Write([]byte("}\n\n")) 279 } 280 281 func makeGenericVecVecCmps(tbo []TypedBinOp) (retVal []*GenericVecVecCmp) { 282 for _, tb := range tbo { 283 if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) { 284 continue 285 } 286 fn := &GenericVecVecCmp{ 287 TypedBinOp: tb, 288 } 289 retVal = append(retVal, fn) 290 } 291 return 292 } 293 294 func makeGenericMixedCmps(tbo []TypedBinOp) (retVal []*GenericMixedCmp) { 295 for _, tb := range tbo { 296 if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) { 297 continue 298 } 299 fn := &GenericMixedCmp{ 300 GenericVecVecCmp: GenericVecVecCmp{ 301 TypedBinOp: tb, 302 }, 303 } 304 retVal = append(retVal, fn) 305 } 306 return 307 } 308 309 func generateGenericVecVecCmp(f io.Writer, ak Kinds) { 310 gen := makeGenericVecVecCmps(typedCmps) 311 for _, g := range gen { 312 g.Write(f) 313 g.RetSame = true 314 315 } 316 for _, g := range gen { 317 if isBoolRepr(g.Kind()) { 318 g.Write(f) 319 } 320 g.RetSame = false 321 g.Iter = true 322 } 323 for _, g := range gen { 324 g.Write(f) 325 g.RetSame = true 326 } 327 for _, g := range gen { 328 if isBoolRepr(g.Kind()) { 329 g.Write(f) 330 } 331 } 332 } 333 334 func generateGenericMixedCmp(f io.Writer, ak Kinds) { 335 gen := makeGenericMixedCmps(typedCmps) 336 for _, g := range gen { 337 g.Write(f) 338 g.RetSame = true 339 } 340 for _, g := range gen { 341 if isBoolRepr(g.Kind()) { 342 g.Write(f) 343 } 344 g.RetSame = false 345 g.Iter = true 346 } 347 for _, g := range gen { 348 g.Write(f) 349 g.RetSame = true 350 } 351 for _, g := range gen { 352 if isBoolRepr(g.Kind()) { 353 g.Write(f) 354 } 355 g.LeftVec = true 356 g.RetSame = false 357 g.Iter = false 358 } 359 360 // VS 361 362 for _, g := range gen { 363 g.Write(f) 364 g.RetSame = true 365 } 366 for _, g := range gen { 367 if isBoolRepr(g.Kind()) { 368 g.Write(f) 369 } 370 g.RetSame = false 371 g.Iter = true 372 } 373 for _, g := range gen { 374 g.Write(f) 375 g.RetSame = true 376 } 377 for _, g := range gen { 378 if isBoolRepr(g.Kind()) { 379 g.Write(f) 380 } 381 } 382 } 383 384 /* OTHER */ 385 386 // element wise Min/Max 387 const genericElMinMaxRaw = `func VecMin{{short . | title}}(a, b []{{asType .}}) { 388 a = a[:len(a)] 389 b = b[:len(a)] 390 for i, v := range a { 391 bv := b[i] 392 if bv < v { 393 a[i] = bv 394 } 395 } 396 } 397 398 func MinSV{{short . | title}}(a {{asType .}}, b []{{asType .}}){ 399 for i := range b { 400 if a < b[i]{ 401 b[i] = a 402 } 403 } 404 } 405 406 func MinVS{{short . | title}}(a []{{asType .}}, b {{asType .}}){ 407 for i := range a { 408 if b < a[i]{ 409 a[i] = b 410 } 411 } 412 } 413 414 func VecMax{{short . | title}}(a, b []{{asType .}}) { 415 a = a[:len(a)] 416 b = b[:len(a)] 417 for i, v := range a { 418 bv := b[i] 419 if bv > v { 420 a[i] = bv 421 } 422 } 423 } 424 425 426 427 func MaxSV{{short . | title}}(a {{asType .}}, b []{{asType .}}){ 428 for i := range b { 429 if a > b[i]{ 430 b[i] = a 431 } 432 } 433 } 434 435 func MaxVS{{short . | title}}(a []{{asType .}}, b {{asType .}}){ 436 for i := range a { 437 if b > a[i]{ 438 a[i] = b 439 } 440 } 441 } 442 ` 443 444 // Iter Min/Max 445 const genericIterMinMaxRaw = `func MinIterSV{{short . | title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){ 446 var i int 447 var validi bool 448 for { 449 if i, validi, err = bit.NextValidity(); err != nil{ 450 err = handleNoOp(err) 451 break 452 } 453 if validi { 454 if a < b[i] { 455 b[i] = a 456 } 457 } 458 } 459 return 460 } 461 462 func MinIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){ 463 var i int 464 var validi bool 465 for { 466 if i, validi, err = ait.NextValidity(); err != nil{ 467 err = handleNoOp(err) 468 break 469 } 470 if validi { 471 if b < a[i] { 472 a[i] = b 473 } 474 } 475 } 476 return 477 } 478 479 func VecMinIter{{short . | title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){ 480 var i,j int 481 var validi ,validj bool 482 for { 483 if i, validi, err = ait.NextValidity(); err != nil{ 484 err = handleNoOp(err) 485 break 486 } 487 if j, validj, err = bit.NextValidity(); err != nil{ 488 err = handleNoOp(err) 489 break 490 } 491 if validi && validj { 492 if b[j] < a[i] { 493 a[i] = b[j] 494 } 495 } 496 } 497 return 498 } 499 500 501 func MaxIterSV{{short . | title}}(a {{asType .}}, b []{{asType .}}, bit Iterator) (err error){ 502 var i int 503 var validi bool 504 for { 505 if i, validi, err = bit.NextValidity(); err != nil{ 506 err = handleNoOp(err) 507 break 508 } 509 if validi { 510 if a > b[i] { 511 b[i] = a 512 } 513 } 514 } 515 return 516 } 517 518 func MaxIterVS{{short . | title}}(a []{{asType .}}, b {{asType .}}, ait Iterator) (err error){ 519 var i int 520 var validi bool 521 for { 522 if i, validi, err = ait.NextValidity(); err != nil{ 523 err = handleNoOp(err) 524 break 525 } 526 if validi { 527 if b > a[i] { 528 a[i] = b 529 } 530 } 531 } 532 return 533 } 534 535 func VecMaxIter{{short . | title}}(a , b []{{asType .}}, ait, bit Iterator) (err error){ 536 var i,j int 537 var validi, validj bool 538 for { 539 if i, validi, err = ait.NextValidity(); err != nil{ 540 err = handleNoOp(err) 541 break 542 } 543 if j, validj, err = bit.NextValidity(); err != nil{ 544 err = handleNoOp(err) 545 break 546 } 547 if validi && validj { 548 if b[j] > a[i] { 549 a[i] = b[j] 550 } 551 } 552 } 553 return 554 } 555 556 ` 557 558 // scalar Min/Max 559 const genericScalarMinMaxRaw = `func Min{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a < b { 560 return a 561 } 562 return b 563 } 564 565 566 func Max{{short .}}(a, b {{asType .}}) (c {{asType .}}) {if a > b { 567 return a 568 } 569 return b 570 } 571 ` 572 573 var ( 574 genericElMinMax *template.Template 575 genericMinMax *template.Template 576 genericElMinMaxIter *template.Template 577 ) 578 579 func init() { 580 genericElMinMax = template.Must(template.New("genericVecVecMinMax").Funcs(funcs).Parse(genericElMinMaxRaw)) 581 genericMinMax = template.Must(template.New("genericMinMax").Funcs(funcs).Parse(genericScalarMinMaxRaw)) 582 genericElMinMaxIter = template.Must(template.New("genericIterMinMax").Funcs(funcs).Parse(genericIterMinMaxRaw)) 583 } 584 585 func generateMinMax(f io.Writer, ak Kinds) { 586 for _, k := range filter(ak.Kinds, isOrd) { 587 genericElMinMax.Execute(f, k) 588 } 589 590 for _, k := range filter(ak.Kinds, isOrd) { 591 genericMinMax.Execute(f, k) 592 } 593 594 for _, k := range filter(ak.Kinds, isOrd) { 595 genericElMinMaxIter.Execute(f, k) 596 } 597 }