github.com/wzzhu/tensor@v0.9.24/genlib2/generic_arith.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "strings" 7 "text/template" 8 ) 9 10 type GenericVecVecArith struct { 11 TypedBinOp 12 Iter bool 13 Incr bool 14 WithRecv bool // not many BinOps have this 15 Check TypeClass // can be nil 16 CheckTemplate string 17 } 18 19 func (fn *GenericVecVecArith) Name() string { 20 switch { 21 case fn.Iter && fn.Incr: 22 return fmt.Sprintf("%sIterIncr", fn.TypedBinOp.Name()) 23 case fn.Iter && !fn.Incr: 24 return fmt.Sprintf("%sIter", fn.TypedBinOp.Name()) 25 case !fn.Iter && fn.Incr: 26 return fmt.Sprintf("%sIncr", fn.TypedBinOp.Name()) 27 case fn.WithRecv: 28 return fmt.Sprintf("%vRecv", fn.TypedBinOp.Name()) 29 default: 30 return fmt.Sprintf("Vec%s", fn.TypedBinOp.Name()) 31 } 32 } 33 34 func (fn *GenericVecVecArith) Signature() *Signature { 35 var paramNames []string 36 var paramTemplates []*template.Template 37 var err bool 38 39 switch { 40 case fn.Iter && fn.Incr: 41 paramNames = []string{"a", "b", "incr", "ait", "bit", "iit"} 42 paramTemplates = []*template.Template{sliceType, sliceType, sliceType, iteratorType, iteratorType, iteratorType} 43 err = true 44 case fn.Iter && !fn.Incr: 45 paramNames = []string{"a", "b", "ait", "bit"} 46 paramTemplates = []*template.Template{sliceType, sliceType, iteratorType, iteratorType} 47 err = true 48 case !fn.Iter && fn.Incr: 49 paramNames = []string{"a", "b", "incr"} 50 paramTemplates = []*template.Template{sliceType, sliceType, sliceType} 51 case fn.WithRecv: 52 paramNames = []string{"a", "b", "recv"} 53 paramTemplates = []*template.Template{sliceType, sliceType, sliceType} 54 default: 55 paramNames = []string{"a", "b"} 56 paramTemplates = []*template.Template{sliceType, sliceType} 57 } 58 59 if fn.Check != nil { 60 err = true 61 } 62 63 return &Signature{ 64 Name: fn.Name(), 65 NameTemplate: typeAnnotatedName, 66 ParamNames: paramNames, 67 ParamTemplates: paramTemplates, 68 69 Kind: fn.Kind(), 70 Err: err, 71 } 72 } 73 74 func (fn *GenericVecVecArith) WriteBody(w io.Writer) { 75 var Range, Left, Right string 76 var Index0, Index1, Index2 string 77 var IterName0, IterName1, IterName2 string 78 var T *template.Template 79 80 Range = "a" 81 Index0 = "i" 82 Index1 = "j" 83 Left = "a[i]" 84 Right = "b[j]" 85 86 T = template.New(fn.Name()).Funcs(funcs) 87 switch { 88 case fn.Iter && fn.Incr: 89 Range = "incr" 90 Index2 = "k" 91 IterName0 = "ait" 92 IterName1 = "bit" 93 IterName2 = "iit" 94 T = template.Must(T.Parse(genericTernaryIterLoopRaw)) 95 template.Must(T.New("loopbody").Parse(iterIncrLoopBody)) 96 case fn.Iter && !fn.Incr: 97 IterName0 = "ait" 98 IterName1 = "bit" 99 T = template.Must(T.Parse(genericBinaryIterLoopRaw)) 100 template.Must(T.New("loopbody").Parse(basicSet)) 101 case !fn.Iter && fn.Incr: 102 Range = "incr" 103 Right = "b[i]" 104 T = template.Must(T.Parse(genericLoopRaw)) 105 template.Must(T.New("loopbody").Parse(basicIncr)) 106 case fn.WithRecv: 107 Range = "recv" 108 Right = "b[i]" 109 T = template.Must(T.Parse(genericLoopRaw)) 110 template.Must(T.New("loopbody").Parse(basicSet)) 111 default: 112 Right = "b[i]" 113 T = template.Must(T.Parse(genericLoopRaw)) 114 template.Must(T.New("loopbody").Parse(basicSet)) 115 } 116 template.Must(T.New("callFunc").Parse(binOpCallFunc)) 117 template.Must(T.New("opDo").Parse(binOpDo)) 118 template.Must(T.New("symbol").Parse(fn.SymbolTemplate())) 119 120 if fn.Check != nil && fn.Check(fn.Kind()) { 121 w.Write([]byte("var errs errorIndices\n")) 122 } 123 template.Must(T.New("check").Parse(fn.CheckTemplate)) 124 125 lb := LoopBody{ 126 TypedOp: fn.TypedBinOp, 127 Range: Range, 128 Left: Left, 129 Right: Right, 130 131 Index0: Index0, 132 Index1: Index1, 133 Index2: Index2, 134 135 IterName0: IterName0, 136 IterName1: IterName1, 137 IterName2: IterName2, 138 } 139 T.Execute(w, lb) 140 } 141 142 func (fn *GenericVecVecArith) Write(w io.Writer) { 143 sig := fn.Signature() 144 if !fn.Iter && isFloat(fn.Kind()) && !fn.WithRecv { 145 // golinkPragma.Execute(w, fn) 146 w.Write([]byte("func ")) 147 sig.Write(w) 148 if fn.Incr { 149 fmt.Fprintf(w, "{ %v%v(a, b, incr)}\n", vecPkg(fn.Kind()), getalias(fn.Name())) 150 } else { 151 fmt.Fprintf(w, "{ %v%v(a, b)}\n", vecPkg(fn.Kind()), getalias(fn.Name())) 152 } 153 return 154 } 155 156 w.Write([]byte("func ")) 157 sig.Write(w) 158 159 switch { 160 case !fn.Iter && fn.Incr: 161 w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]; incr = incr[:len(a)]\n")) 162 case fn.WithRecv: 163 w.Write([]byte("{\na = a[:len(recv)]; b = b[:len(recv)]\n")) 164 case !fn.Iter && !fn.Incr && !fn.WithRecv: 165 w.Write([]byte("{\na = a[:len(a)]; b = b[:len(a)]\n")) 166 default: 167 w.Write([]byte("{\n")) 168 } 169 fn.WriteBody(w) 170 if sig.Err { 171 if fn.Check != nil { 172 w.Write([]byte("\nif err != nil {\n return\n}\nif len(errs) > 0 {\n return errs }\nreturn nil")) 173 } else { 174 w.Write([]byte("\nreturn\n")) 175 } 176 } 177 w.Write([]byte("}\n\n")) 178 } 179 180 type GenericMixedArith struct { 181 GenericVecVecArith 182 LeftVec bool 183 } 184 185 func (fn *GenericMixedArith) Name() string { 186 n := fn.GenericVecVecArith.Name() 187 n = strings.TrimPrefix(n, "Vec") 188 if fn.LeftVec { 189 n += "VS" 190 } else { 191 n += "SV" 192 } 193 return n 194 } 195 196 func (fn *GenericMixedArith) Signature() *Signature { 197 var paramNames []string 198 var paramTemplates []*template.Template 199 var err bool 200 201 switch { 202 case fn.Iter && fn.Incr: 203 paramNames = []string{"a", "b", "incr", "ait", "iit"} 204 paramTemplates = []*template.Template{sliceType, sliceType, sliceType, iteratorType, iteratorType} 205 if fn.LeftVec { 206 paramTemplates[1] = scalarType 207 } else { 208 paramTemplates[0] = scalarType 209 paramNames[3] = "bit" 210 } 211 err = true 212 case fn.Iter && !fn.Incr: 213 paramNames = []string{"a", "b", "ait"} 214 paramTemplates = []*template.Template{sliceType, sliceType, iteratorType} 215 if fn.LeftVec { 216 paramTemplates[1] = scalarType 217 } else { 218 paramTemplates[0] = scalarType 219 paramNames[2] = "bit" 220 } 221 222 err = true 223 case !fn.Iter && fn.Incr: 224 paramNames = []string{"a", "b", "incr"} 225 paramTemplates = []*template.Template{sliceType, sliceType, sliceType} 226 if fn.LeftVec { 227 paramTemplates[1] = scalarType 228 } else { 229 paramTemplates[0] = scalarType 230 } 231 232 default: 233 paramNames = []string{"a", "b"} 234 paramTemplates = []*template.Template{sliceType, sliceType} 235 if fn.LeftVec { 236 paramTemplates[1] = scalarType 237 } else { 238 paramTemplates[0] = scalarType 239 } 240 } 241 242 if fn.Check != nil { 243 err = true 244 } 245 246 return &Signature{ 247 Name: fn.Name(), 248 NameTemplate: typeAnnotatedName, 249 ParamNames: paramNames, 250 ParamTemplates: paramTemplates, 251 252 Kind: fn.Kind(), 253 Err: err, 254 } 255 } 256 257 func (fn *GenericMixedArith) WriteBody(w io.Writer) { 258 var Range, Left, Right string 259 var Index0, Index1 string 260 var IterName0, IterName1 string 261 262 Range = "a" 263 Left = "a[i]" 264 Right = "b[i]" 265 Index0 = "i" 266 267 T := template.New(fn.Name()).Funcs(funcs) 268 switch { 269 case fn.Iter && fn.Incr: 270 Range = "incr" 271 T = template.Must(T.Parse(genericBinaryIterLoopRaw)) 272 template.Must(T.New("loopbody").Parse(iterIncrLoopBody)) 273 case fn.Iter && !fn.Incr: 274 T = template.Must(T.Parse(genericUnaryIterLoopRaw)) 275 template.Must(T.New("loopbody").Parse(basicSet)) 276 case !fn.Iter && fn.Incr: 277 Range = "incr" 278 T = template.Must(T.Parse(genericLoopRaw)) 279 template.Must(T.New("loopbody").Parse(basicIncr)) 280 default: 281 T = template.Must(T.Parse(genericLoopRaw)) 282 template.Must(T.New("loopbody").Parse(basicSet)) 283 } 284 285 if fn.LeftVec { 286 Right = "b" 287 } else { 288 Left = "a" 289 if !fn.Incr { 290 Range = "b" 291 } 292 // Index0 = "j" 293 } 294 295 switch { 296 case fn.Iter && fn.Incr && fn.LeftVec: 297 IterName0 = "ait" 298 IterName1 = "iit" 299 Index1 = "k" 300 case fn.Iter && !fn.Incr && fn.LeftVec: 301 IterName0 = "ait" 302 case fn.Iter && fn.Incr && !fn.LeftVec: 303 IterName0 = "bit" 304 IterName1 = "iit" 305 Index1 = "k" 306 case fn.Iter && !fn.Incr && !fn.LeftVec: 307 IterName0 = "bit" 308 } 309 310 template.Must(T.New("callFunc").Parse(binOpCallFunc)) 311 template.Must(T.New("opDo").Parse(binOpDo)) 312 template.Must(T.New("symbol").Parse(fn.SymbolTemplate())) 313 314 if fn.Check != nil && fn.Check(fn.Kind()) { 315 w.Write([]byte("var errs errorIndices\n")) 316 } 317 template.Must(T.New("check").Parse(fn.CheckTemplate)) 318 319 lb := LoopBody{ 320 TypedOp: fn.TypedBinOp, 321 Range: Range, 322 Left: Left, 323 Right: Right, 324 325 Index0: Index0, 326 Index1: Index1, 327 IterName0: IterName0, 328 IterName1: IterName1, 329 } 330 T.Execute(w, lb) 331 } 332 333 func (fn *GenericMixedArith) Write(w io.Writer) { 334 sig := fn.Signature() 335 336 w.Write([]byte("func ")) 337 sig.Write(w) 338 339 w.Write([]byte("{\n")) 340 341 fn.WriteBody(w) 342 if sig.Err { 343 if fn.Check != nil { 344 w.Write([]byte("\nif err != nil {\n return\n}\nif len(errs) > 0 {\n return errs }\nreturn nil")) 345 } else { 346 w.Write([]byte("\nreturn\n")) 347 } 348 } 349 w.Write([]byte("}\n\n")) 350 } 351 352 type GenericScalarScalarArith struct { 353 TypedBinOp 354 } 355 356 func (fn *GenericScalarScalarArith) Signature() *Signature { 357 return &Signature{ 358 Name: fn.Name(), 359 NameTemplate: typeAnnotatedName, 360 ParamNames: []string{"a", "b"}, 361 ParamTemplates: []*template.Template{scalarType, scalarType}, 362 RetVals: []string{""}, 363 RetValTemplates: []*template.Template{scalarType}, 364 Kind: fn.Kind(), 365 } 366 } 367 368 func (fn *GenericScalarScalarArith) WriteBody(w io.Writer) { 369 tmpl := `return {{if .IsFunc -}} 370 {{ template "callFunc" . -}} 371 {{else -}} 372 {{template "opDo" . -}} 373 {{end -}}` 374 opDo := `a {{template "symbol" .Kind}} b` 375 callFunc := `{{template "symbol" .Kind}}(a, b)` 376 377 T := template.Must(template.New(fn.Name()).Funcs(funcs).Parse(tmpl)) 378 template.Must(T.New("opDo").Parse(opDo)) 379 template.Must(T.New("callFunc").Parse(callFunc)) 380 template.Must(T.New("symbol").Parse(fn.SymbolTemplate())) 381 382 T.Execute(w, fn) 383 } 384 385 func (fn *GenericScalarScalarArith) Write(w io.Writer) { 386 w.Write([]byte("func ")) 387 sig := fn.Signature() 388 sig.Write(w) 389 w.Write([]byte("{")) 390 fn.WriteBody(w) 391 w.Write([]byte("}\n")) 392 } 393 394 func makeGenericVecVecAriths(tbo []TypedBinOp) (retVal []*GenericVecVecArith) { 395 for _, tb := range tbo { 396 if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) { 397 continue 398 } 399 fn := &GenericVecVecArith{ 400 TypedBinOp: tb, 401 } 402 if tb.Name() == "Div" && !isFloatCmplx(tb.Kind()) { 403 fn.Check = panicsDiv0 404 fn.CheckTemplate = check0 405 } 406 407 retVal = append(retVal, fn) 408 409 } 410 411 return retVal 412 } 413 414 func makeGenericMixedAriths(tbo []TypedBinOp) (retVal []*GenericMixedArith) { 415 for _, tb := range tbo { 416 if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) { 417 continue 418 } 419 fn := &GenericMixedArith{ 420 GenericVecVecArith: GenericVecVecArith{ 421 TypedBinOp: tb, 422 }, 423 } 424 if tb.Name() == "Div" && !isFloatCmplx(tb.Kind()) { 425 fn.Check = panicsDiv0 426 fn.CheckTemplate = check0 427 } 428 retVal = append(retVal, fn) 429 } 430 return 431 } 432 433 func makeGenericScalarScalarAriths(tbo []TypedBinOp) (retVal []*GenericScalarScalarArith) { 434 for _, tb := range tbo { 435 if tc := tb.TypeClass(); tc != nil && !tc(tb.Kind()) { 436 continue 437 } 438 fn := &GenericScalarScalarArith{ 439 TypedBinOp: tb, 440 } 441 retVal = append(retVal, fn) 442 } 443 return 444 } 445 446 func generateGenericVecVecArith(f io.Writer, ak Kinds) { 447 gen := makeGenericVecVecAriths(typedAriths) 448 449 // importStmt := ` 450 // import ( 451 // _ "unsafe" 452 453 // _ "gorgonia.org/vecf32" 454 // _ "gorgonia.org/vecf64") 455 // ` 456 // f.Write([]byte(importStmt)) 457 458 for _, g := range gen { 459 g.Write(f) 460 g.Incr = true 461 } 462 for _, g := range gen { 463 g.Write(f) 464 g.Incr = false 465 g.Iter = true 466 } 467 for _, g := range gen { 468 g.Write(f) 469 g.Incr = true 470 } 471 for _, g := range gen { 472 g.Write(f) 473 } 474 for _, g := range gen { 475 g.Incr = false 476 g.Iter = false 477 g.WithRecv = true 478 g.Write(f) 479 } 480 } 481 482 func generateGenericMixedArith(f io.Writer, ak Kinds) { 483 gen := makeGenericMixedAriths(typedAriths) 484 485 // SV first 486 for _, g := range gen { 487 g.Write(f) 488 g.Incr = true 489 } 490 for _, g := range gen { 491 g.Write(f) 492 g.Incr = false 493 g.Iter = true 494 } 495 for _, g := range gen { 496 g.Write(f) 497 g.Incr = true 498 } 499 for _, g := range gen { 500 g.Write(f) 501 502 // reset 503 g.LeftVec = true 504 g.Incr = false 505 g.Iter = false 506 } 507 508 // VS 509 for _, g := range gen { 510 g.Write(f) 511 g.Incr = true 512 } 513 for _, g := range gen { 514 g.Write(f) 515 g.Incr = false 516 g.Iter = true 517 } 518 for _, g := range gen { 519 g.Write(f) 520 g.Incr = true 521 } 522 for _, g := range gen { 523 g.Write(f) 524 } 525 } 526 527 func generateGenericScalarScalarArith(f io.Writer, ak Kinds) { 528 gen := makeGenericScalarScalarAriths(typedAriths) 529 for _, g := range gen { 530 g.Write(f) 531 } 532 }