github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/colexec/execgen/template.go (about) 1 // Copyright 2020 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package execgen 12 13 import ( 14 "fmt" 15 "go/token" 16 "regexp" 17 "sort" 18 "strings" 19 20 "github.com/dave/dst" 21 "github.com/dave/dst/dstutil" 22 ) 23 24 type templateInfo struct { 25 funcInfos map[string]*funcInfo 26 27 letInfos map[string]*letInfo 28 } 29 30 type templateParamInfo struct { 31 fieldOrdinal int 32 field *dst.Field 33 } 34 35 type funcInfo struct { 36 decl *dst.FuncDecl 37 templateParams []templateParamInfo 38 39 // instantiateArgs is a list of lists of arguments that were passed explicitly 40 // as execgen:instantiate declarations. 41 instantiateArgs [][]string 42 } 43 44 // letInfo contains a list of all of the values in an execgen:let declaration. 45 type letInfo struct { 46 // typ is a type literal. 47 typ *dst.ArrayType 48 vals []string 49 } 50 51 // Match // execgen:template<foo, bar> 52 var templateRe = regexp.MustCompile(`\/\/ execgen:template<((?:(?:\w+),?\W*)+)>`) 53 54 // Match // execgen:instantiate<foo, bar> 55 var instantiateRe = regexp.MustCompile(`\/\/ execgen:instantiate<((?:(?:\w+),?\W*)+)>`) 56 57 // replaceTemplateVars removes the template arguments from a callsite of a 58 // templated function. It returns the template arguments that were used, and a 59 // new CallExpr that doesn't have the template arguments. 60 func replaceTemplateVars( 61 info *funcInfo, call *dst.CallExpr, 62 ) (templateArgs []dst.Expr, newCall *dst.CallExpr, mangledName string) { 63 if len(info.templateParams) == 0 { 64 return nil, call, "" 65 } 66 templateArgs = make([]dst.Expr, len(info.templateParams)) 67 // Collect template arguments. 68 for i, param := range info.templateParams { 69 templateArgs[i] = dst.Clone(call.Args[param.fieldOrdinal]).(dst.Expr) 70 // Clear the decorations so that argument comments are not used in 71 // template function names. 72 templateArgs[i].Decorations().Start.Clear() 73 templateArgs[i].Decorations().End.Clear() 74 } 75 // Remove template vars from callsite. 76 newArgs := make([]dst.Expr, 0, len(call.Args)-len(info.templateParams)) 77 for i := range call.Args { 78 skip := false 79 for _, p := range info.templateParams { 80 if p.fieldOrdinal == i { 81 skip = true 82 break 83 } 84 } 85 if !skip { 86 newArgs = append(newArgs, dst.Clone(call.Args[i]).(dst.Expr)) 87 } 88 } 89 ret := dst.Clone(call).(*dst.CallExpr) 90 newName := getTemplateVariantName(info, templateArgs) 91 ret.Fun = newName 92 ret.Args = newArgs 93 return templateArgs, ret, newName.Name 94 } 95 96 // monomorphizeTemplate produces a variant of the input function body, given the 97 // definition of the function in funcInfo, and the concrete, template-time values 98 // that the function is being invoked with. It will try to find conditional 99 // statements that use the template variables and output only the branches that 100 // match. 101 // 102 // For example, given the function: 103 // // execgen:inline 104 // // execgen:template<t, i> 105 // 106 // func b(t bool, i int) int { 107 // if t { 108 // x = 3 109 // } else { 110 // x = 4 111 // } 112 // switch i { 113 // case 5: fmt.Println("5") 114 // case 6: fmt.Println("6") 115 // } 116 // return x 117 // } 118 // 119 // and a caller 120 // 121 // b(true, 5) 122 // 123 // this function will generate 124 // 125 // if true { 126 // x = 3 127 // } else { 128 // x = 4 129 // } 130 // switch 5 { 131 // case 5: fmt.Println("5") 132 // case 6: fmt.Println("6") 133 // } 134 // return x 135 // 136 // in its first pass. However, because the if's condition (true, in this case) 137 // is a logical expression containing boolean literals, and the switch statement 138 // is a switch on a template variable alone, a second pass "folds" 139 // the conditionals and replaces them like so: 140 // 141 // x = 3 142 // fmt.Println(5) 143 // return x 144 // 145 // Note that this method lexically replaces all formal parameters, so together 146 // with createTemplateFuncVariant, it enables templates to call other templates 147 // with template variables. 148 func monomorphizeTemplate(n dst.Node, info *funcInfo, args []dst.Expr) dst.Node { 149 // Create map from formal param name to arg. 150 paramMap := make(map[string]dst.Expr) 151 for i, p := range info.templateParams { 152 paramMap[p.field.Names[0].Name] = args[i] 153 } 154 templateSwitches := make(map[*dst.SwitchStmt]struct{}) 155 n = dstutil.Apply(n, func(cursor *dstutil.Cursor) bool { 156 // Replace all usages of the formal parameter with the template arg. 157 c := cursor.Node() 158 switch t := c.(type) { 159 case *dst.Ident: 160 if arg := paramMap[t.Name]; arg != nil { 161 p := cursor.Parent() 162 if s, ok := p.(*dst.SwitchStmt); ok { 163 if s.Tag.(*dst.Ident) == t { 164 // Write down the switch statements we see that are of the form: 165 // switch <templateParam> { 166 // ... 167 // } 168 // We'll replace these later. 169 templateSwitches[s] = struct{}{} 170 } 171 } 172 cursor.Replace(dst.Clone(arg)) 173 } 174 } 175 return true 176 }, nil) 177 178 return foldConditionals(n, info, templateSwitches) 179 } 180 181 // foldConditionals edits conditional statements to try to remove branches that 182 // are statically falsifiable. It works with two cases: 183 // 184 // if <bool> { } else { } and if !<bool> { } else { } 185 // 186 // execgen:switch 187 // 188 // switch <ident> { 189 // case <otherIdent>: 190 // case <ident>: 191 // ... 192 // } 193 func foldConditionals( 194 n dst.Node, info *funcInfo, templateSwitches map[*dst.SwitchStmt]struct{}, 195 ) dst.Node { 196 return dstutil.Apply(n, func(cursor *dstutil.Cursor) bool { 197 n := cursor.Node() 198 switch n := n.(type) { 199 case *dst.SwitchStmt: 200 if _, ok := templateSwitches[n]; !ok { 201 // Not a template switch. 202 return true 203 } 204 t := prettyPrintExprs(n.Tag) 205 for _, item := range n.Body.List { 206 c := item.(*dst.CaseClause) 207 for _, e := range c.List { 208 if prettyPrintExprs(e) == t { 209 body := &dst.BlockStmt{ 210 List: c.Body, 211 Decs: dst.BlockStmtDecorations{ 212 NodeDecs: c.Decs.NodeDecs, 213 Lbrace: c.Decs.Colon, 214 }, 215 } 216 newBody := foldConditionals(body, info, templateSwitches).(*dst.BlockStmt) 217 insertBlockStmt(cursor, newBody) 218 cursor.Delete() 219 return true 220 } 221 } 222 } 223 case *dst.IfStmt: 224 ret, ok := tryEvalBool(n.Cond) 225 if !ok { 226 return true 227 } 228 // Since we're replacing the node, make sure we preserve any comments. 229 if len(n.Decs.NodeDecs.Start) > 0 { 230 cursor.InsertBefore(&dst.AssignStmt{ 231 Tok: token.ASSIGN, 232 Lhs: []dst.Expr{dst.NewIdent("_")}, 233 Rhs: []dst.Expr{ 234 &dst.BasicLit{ 235 Kind: token.STRING, 236 Value: "true", 237 }, 238 }, 239 Decs: dst.AssignStmtDecorations{ 240 NodeDecs: n.Decs.NodeDecs, 241 }, 242 }) 243 } 244 if ret { 245 // Replace with the if side. 246 newBody := foldConditionals(n.Body, info, templateSwitches).(*dst.BlockStmt) 247 insertBlockStmt(cursor, newBody) 248 cursor.Delete() 249 return true 250 } 251 // Replace with the else side, if it exists. 252 if n.Else != nil { 253 newElse := foldConditionals(n.Else, info, templateSwitches) 254 switch e := newElse.(type) { 255 case *dst.BlockStmt: 256 insertBlockStmt(cursor, e) 257 cursor.Delete() 258 default: 259 cursor.Replace(newElse) 260 } 261 } else { 262 cursor.Delete() 263 } 264 } 265 return true 266 }, nil) 267 } 268 269 // tryEvalBool attempts to statically evaluate the input expr as a logical 270 // combination of boolean literals (like false || true). It returns the result 271 // of the evaluation and whether or not the expression was actually evaluable 272 // as such. 273 func tryEvalBool(n dst.Expr) (ret bool, ok bool) { 274 switch n := n.(type) { 275 case *dst.UnaryExpr: 276 // !<expr> 277 if n.Op == token.NOT { 278 ret, ok = tryEvalBool(n.X) 279 ret = !ret 280 return ret, ok 281 } 282 return false, false 283 case *dst.BinaryExpr: 284 // expr && expr or expr || expr 285 if n.Op != token.LAND && n.Op != token.LOR { 286 return false, false 287 } 288 l, ok := tryEvalBool(n.X) 289 if !ok { 290 return false, false 291 } 292 r, ok := tryEvalBool(n.Y) 293 if !ok { 294 return false, false 295 } 296 switch n.Op { 297 case token.LAND: 298 return l && r, true 299 case token.LOR: 300 return l || r, true 301 default: 302 panic("unreachable") 303 } 304 case *dst.Ident: 305 switch n.Name { 306 case "true": 307 return true, true 308 case "false": 309 return false, true 310 } 311 return false, false 312 } 313 return false, false 314 } 315 316 func insertBlockStmt(cursor *dstutil.Cursor, block *dst.BlockStmt) { 317 // Make sure to preserve comments. 318 cursor.InsertBefore(&dst.EmptyStmt{ 319 Implicit: true, 320 Decs: dst.EmptyStmtDecorations{ 321 NodeDecs: dst.NodeDecs{ 322 Before: dst.NewLine, 323 Start: trimLeadingNewLines(append(block.Decs.Lbrace, block.Decs.NodeDecs.Start...)), 324 End: block.Decs.End, 325 After: dst.NewLine, 326 }}, 327 }) 328 for _, stmt := range block.List { 329 cursor.InsertBefore(stmt) 330 } 331 } 332 333 // trimTemplateDeclMatches takes a list of matches from an execgen:blah<a,b,c> 334 // regexp match and returns the trimmed list of a, b, and c. 335 func trimTemplateDeclMatches(matches []string) []string { 336 match := matches[1] 337 338 templateVars := strings.Split(match, ",") 339 for i, v := range templateVars { 340 templateVars[i] = strings.TrimSpace(v) 341 } 342 return templateVars 343 } 344 345 const runtimeToTemplateSuffix = "_runtime_to_template" 346 347 // findTemplateDecls, given an AST, finds all functions annotated with 348 // execgen:template<foo,bar>, and returns a funcInfo for each of them, and 349 // finds all var decls annotated with execgen:let, returning a letInfo for 350 // each of them. 351 func findTemplateDecls(f *dst.File) templateInfo { 352 ret := templateInfo{ 353 funcInfos: make(map[string]*funcInfo), 354 letInfos: make(map[string]*letInfo), 355 } 356 357 dstutil.Apply(f, func(cursor *dstutil.Cursor) bool { 358 n := cursor.Node() 359 switch n := n.(type) { 360 case *dst.FuncDecl: 361 var templateVars []string 362 var instantiateArgs [][]string 363 i := 0 364 for _, dec := range n.Decs.Start { 365 if matches := templateRe.FindStringSubmatch(dec); matches != nil { 366 templateVars = trimTemplateDeclMatches(matches) 367 continue 368 } 369 370 if matches := instantiateRe.FindStringSubmatch(dec); matches != nil { 371 instantiateMatches := trimTemplateDeclMatches(matches) 372 newInstantiateArgs := expandInstantiateArgs(instantiateMatches, ret.letInfos) 373 instantiateArgs = append(instantiateArgs, newInstantiateArgs...) 374 // Eventually let's delete the instantiate comments as well. 375 continue 376 } 377 // Filter decorations in place. 378 n.Decs.Start[i] = dec 379 i++ 380 } 381 n.Decs.Start = n.Decs.Start[:i] 382 if templateVars == nil { 383 return false 384 } 385 386 // Process template funcs: find template params from runtime definition 387 // and save in funcInfo. 388 info := &funcInfo{ 389 instantiateArgs: instantiateArgs, 390 } 391 for _, v := range templateVars { 392 var found bool 393 for i, f := range n.Type.Params.List { 394 // We can safely 0-index here because fields always have at least 395 // one name, and we've already banned the case where they have more 396 // than one. (e.g. func a (a int, b int, c, d int)) 397 if f.Names[0].Name == v { 398 info.templateParams = append(info.templateParams, templateParamInfo{ 399 fieldOrdinal: i, 400 field: dst.Clone(f).(*dst.Field), 401 }) 402 found = true 403 break 404 } 405 } 406 if !found { 407 panic(fmt.Errorf("template var %s not found", v)) 408 } 409 } 410 // Delete template params from runtime definition. 411 newParamList := make([]*dst.Field, 0, len(n.Type.Params.List)-len(info.templateParams)) 412 for i, field := range n.Type.Params.List { 413 var skip bool 414 for _, p := range info.templateParams { 415 if i == p.fieldOrdinal { 416 skip = true 417 break 418 } 419 } 420 if !skip { 421 newParamList = append(newParamList, field) 422 } 423 } 424 funcDecs := n.Decs 425 // Replace the template function with a const marker, just so we can keep 426 // the comments above the template function available. 427 cursor.InsertBefore(&dst.GenDecl{ 428 Tok: token.CONST, 429 Specs: []dst.Spec{ 430 &dst.ValueSpec{ 431 Names: []*dst.Ident{dst.NewIdent("_")}, 432 Values: []dst.Expr{ 433 &dst.BasicLit{ 434 Kind: token.STRING, 435 Value: fmt.Sprintf(`"template_%s"`, n.Name.Name), 436 }, 437 }, 438 }, 439 }, 440 Decs: dst.GenDeclDecorations{ 441 NodeDecs: funcDecs.NodeDecs, 442 }, 443 }) 444 oldParamList := n.Type.Params.List 445 n.Type.Params.List = newParamList 446 n.Decs.Start = trimStartDecs(n) 447 info.decl = n 448 ret.funcInfos[info.decl.Name.Name] = info 449 450 for _, args := range info.instantiateArgs { 451 exprList := make([]dst.Expr, len(args)) 452 for j := range args { 453 exprList[j] = dst.NewIdent(args[j]) 454 } 455 createTemplateFuncVariant(f, info, exprList) 456 } 457 458 // Now, we need to generate the "look up table" that allows us to convert 459 // runtime values into template values for the template args. 460 461 // We only do this if there were execgen:instantiate statements, since we 462 // assume that if there were no such statements, the concrete callsites 463 // were already present. 464 465 if info.instantiateArgs != nil { 466 runtimeArgs := make([]dst.Expr, len(n.Type.Params.List)) 467 for i, p := range n.Type.Params.List { 468 runtimeArgs[i] = dst.NewIdent(p.Names[0].Name) 469 } 470 decl := &dst.FuncDecl{ 471 Name: dst.NewIdent(fmt.Sprintf("%s%s", info.decl.Name.Name, runtimeToTemplateSuffix)), 472 Type: dst.Clone(info.decl.Type).(*dst.FuncType), 473 Body: &dst.BlockStmt{ 474 List: []dst.Stmt{ 475 generateSwitchStatementLookup(info, runtimeArgs, templateVars, info.instantiateArgs), 476 }, 477 }, 478 } 479 decl.Type.Params.List = oldParamList 480 cursor.InsertAfter(decl) 481 } 482 cursor.Delete() 483 484 case *dst.GenDecl: 485 // Search for execgen:let declarations. 486 isLet := false 487 for _, dec := range n.Decs.Start { 488 if dec == "// execgen:let" { 489 isLet = true 490 break 491 } 492 } 493 if !isLet { 494 return true 495 } 496 if n.Tok != token.VAR { 497 panic("execgen:let only allowed on vars") 498 } 499 for _, spec := range n.Specs { 500 n := spec.(*dst.ValueSpec) 501 if len(n.Names) != 1 || len(n.Values) != 1 { 502 panic("execgen:let must have 1 name and one value per var") 503 } 504 info := &letInfo{} 505 name := n.Names[0].Name 506 c, ok := n.Values[0].(*dst.CompositeLit) 507 if !ok { 508 panic("execgen:let must use a composite literal value") 509 } 510 typ, ok := c.Type.(*dst.ArrayType) 511 if !ok { 512 panic("execgen:let must be on an array type literal") 513 } 514 info.vals = make([]string, len(c.Elts)) 515 info.typ = typ 516 for i := range c.Elts { 517 info.vals[i] = prettyPrintExprs(c.Elts[i]) 518 } 519 ret.letInfos[name] = info 520 } 521 522 cursor.Delete() 523 } 524 return true 525 }, nil) 526 527 return ret 528 } 529 530 // expandInstantiateArgs takes a list of strings, the arguments to an 531 // execgen:instantiate annotation, and returns a list of list of strings, after 532 // combinatorially expanding any execgen:let lists in the instantiate arguments. 533 // For example, given the instantiateArgs: 534 // ["Bools", "Bools", 3] 535 // and an execgen:let that maps "Bools" to ["true", "false"], we'd return the 536 // list of lists: 537 // [true, true, 3] 538 // [true, false, 3] 539 // [false, true, 3] 540 // [false, false, 3] 541 func expandInstantiateArgs(instantiateArgs []string, letInfos map[string]*letInfo) [][]string { 542 expandedArgs := make([][]string, len(instantiateArgs)) 543 for i, arg := range instantiateArgs { 544 if info := letInfos[arg]; info != nil { 545 expandedArgs[i] = info.vals 546 } else { 547 expandedArgs[i] = []string{arg} 548 } 549 } 550 return generateInstantiateCombinations(expandedArgs) 551 } 552 553 func generateInstantiateCombinations(args [][]string) [][]string { 554 if len(args) == 1 { 555 // Base case: transform the final options list into an arguments list of 556 // lists where each arguments list is a single element containing one of 557 // the final options. 558 // For example, given [[true, false]], we'll return: 559 // [[true], [false]] 560 ret := make([][]string, len(args[0])) 561 for i, arg := range args[0] { 562 ret[i] = []string{arg} 563 } 564 return ret 565 } 566 rest := generateInstantiateCombinations(args[1:]) 567 ret := make([][]string, 0, len(rest)*len(args[0])) 568 for _, argOption := range args[0] { 569 // For every option of argument, prepend it to every args list from 570 // the recursive step. 571 for _, args := range rest { 572 ret = append(ret, append([]string{argOption}, args...)) 573 } 574 } 575 return ret 576 } 577 578 // generateSwitchStatementLookup ... 579 // remainingArgs is a list of lists of actual instantiations. For example, if 580 // we had: 581 // execgen:instantiate<red, potato> 582 // execgen:instantiate<red, orange> 583 // execgen:instantiate<yellow, orange> 584 // 585 // remainingArgs would be {{red, potato}, {red, orange}, {yellow orange}} 586 func generateSwitchStatementLookup( 587 info *funcInfo, curArgs []dst.Expr, remainingTemplateParams []string, remainingArgs [][]string, 588 ) *dst.SwitchStmt { 589 ret := &dst.SwitchStmt{ 590 Tag: dst.NewIdent(remainingTemplateParams[0]), 591 Body: &dst.BlockStmt{}, 592 } 593 defaultCase := &dst.CaseClause{ 594 Body: []dst.Stmt{ 595 mustParseStmt(`panic(fmt.Sprint("unknown value", ` + remainingTemplateParams[0] + `))`), 596 }, 597 } 598 if len(remainingArgs[0]) == 1 { 599 // Base case. We finished switching on all template params, time to actually 600 // invoke the fully specialized function. 601 stmtList := make([]dst.Stmt, len(remainingArgs)+1) 602 for i := range remainingArgs { 603 argList := append(curArgs, dst.NewIdent(remainingArgs[i][0])) 604 call := &dst.CallExpr{ 605 Fun: dst.NewIdent(info.decl.Name.Name), 606 Args: argList, 607 } 608 _, call, _ = replaceTemplateVars(info, call) 609 var stmt dst.Stmt 610 if info.decl.Type.Results != nil { 611 stmt = &dst.ReturnStmt{Results: []dst.Expr{call}} 612 } else { 613 stmt = &dst.ExprStmt{X: call} 614 } 615 stmtList[i] = &dst.CaseClause{ 616 List: []dst.Expr{dst.NewIdent(remainingArgs[i][0])}, 617 Body: []dst.Stmt{stmt}, 618 } 619 } 620 stmtList[len(stmtList)-1] = defaultCase 621 ret.Body.List = stmtList 622 return ret 623 } 624 625 // Recursive case: we have more args to deal with 626 groupedArgs := make(map[string][][]string) 627 for _, argList := range remainingArgs { 628 firstArg := argList[0] 629 groupedArgs[firstArg] = append(groupedArgs[firstArg], argList[1:]) 630 } 631 stmtList := make([]dst.Stmt, len(groupedArgs)+1) 632 // Sort firstArgs lexicographically, so we have a consistent output order. 633 firstArgs := make([]string, 0, len(groupedArgs)) 634 for firstArg := range groupedArgs { 635 firstArgs = append(firstArgs, firstArg) 636 } 637 sort.Strings(firstArgs) 638 639 for i, firstArg := range firstArgs { 640 restArgs := groupedArgs[firstArg] 641 argList := append(curArgs, dst.NewIdent(firstArg)) 642 stmtList[i] = &dst.CaseClause{ 643 List: []dst.Expr{dst.NewIdent(firstArg)}, 644 Body: []dst.Stmt{generateSwitchStatementLookup( 645 info, 646 argList, 647 remainingTemplateParams[1:], 648 restArgs, 649 )}, 650 } 651 } 652 stmtList[len(stmtList)-1] = defaultCase 653 ret.Body.List = stmtList 654 return ret 655 } 656 657 var nameMangler = strings.NewReplacer(".", "DOT", "*", "STAR") 658 659 func getTemplateVariantName(info *funcInfo, args []dst.Expr) *dst.Ident { 660 var newName strings.Builder 661 newName.WriteString(info.decl.Name.Name) 662 for j := range args { 663 newName.WriteByte('_') 664 newName.WriteString(prettyPrintExprs(args[j])) 665 } 666 s := newName.String() 667 s = nameMangler.Replace(s) 668 return dst.NewIdent(s) 669 } 670 671 func trimStartDecs(n *dst.FuncDecl) []string { 672 // The function declaration node can accidentally capture extra comments that 673 // we want to leave in their original position, and not duplicate. So, remove 674 // any decorations that are separated from the function declaration by one or 675 // more newlines. 676 startDecs := n.Decs.Start.All() 677 for i := len(startDecs) - 1; i >= 0; i-- { 678 if strings.TrimSpace(startDecs[i]) == "" { 679 return startDecs[i+1:] 680 } 681 } 682 return startDecs 683 } 684 685 func trimLeadingNewLines(decs []string) []string { 686 var i int 687 for ; i < len(decs); i++ { 688 if strings.TrimSpace(decs[i]) != "" { 689 break 690 } 691 } 692 return decs[i:] 693 } 694 695 // replaceAndExpandTemplates finds all CallExprs in the input AST that are calling 696 // the functions that had been annotated with // execgen:template that are 697 // passed in via the templateFuncInfos map. It recursively replaces the 698 // CallExprs with their expanded, mangled template function names, and creates 699 // the requisite monomorphized FuncDecls on demand. 700 // 701 // For example, given a template function: 702 // 703 // // execgen:template<b> 704 // 705 // func foo (a int, b bool) { 706 // if b { 707 // return a 708 // } else { 709 // return a + 1 710 // } 711 // } 712 // 713 // And callsites: 714 // 715 // foo(a, true) 716 // foo(a, false) 717 // 718 // This function will add 2 new func decls to the AST: 719 // 720 // func foo_true(a int) { 721 // return a 722 // } 723 // 724 // func foo_false(a int) { 725 // return a + 1 726 // } 727 func replaceAndExpandTemplates(f *dst.File, templateFuncInfos map[string]*funcInfo) dst.Node { 728 // First, create the DAG of template functions. This DAG points from template 729 // function to any other template functions that are called from within its 730 // body that propagate template arguments. 731 // First, find all "roots": template CallExprs that only have concrete 732 // arguments. 733 var q []*dst.CallExpr 734 dstutil.Apply(f, func(cursor *dstutil.Cursor) bool { 735 n := cursor.Node() 736 switch n := n.(type) { 737 case *dst.FuncDecl: 738 q = append(q, findConcreteTemplateCallSites(n, templateFuncInfos)...) 739 } 740 return true 741 }, nil) 742 743 // For every remaining concrete call site, replace it with its mangled template 744 // function call, and generate the requisite monomorphized template function 745 // if we haven't already. 746 // 747 // Then, process the new monomorphized template function and add any newly 748 // created concrete template call sites to the queue. Do this until we have no 749 // more concrete template call sites. 750 seenCallsites := make(map[string]struct{}) 751 for len(q) > 0 { 752 q = q[:0] 753 dstutil.Apply(f, func(cursor *dstutil.Cursor) bool { 754 n := cursor.Node() 755 switch n := n.(type) { 756 case *dst.CallExpr: 757 ident, ok := n.Fun.(*dst.Ident) 758 if !ok { 759 return true 760 } 761 info, ok := templateFuncInfos[ident.Name] 762 if !ok { 763 // Nothing to do, it's not a templated function. 764 return true 765 } 766 // Critical moment: We need to know whether to replace with concrete 767 // input args or to replace the call with the lookup version. 768 if info.instantiateArgs != nil { 769 n.Fun = dst.NewIdent(fmt.Sprintf("%s%s", info.decl.Name.Name, runtimeToTemplateSuffix)) 770 cursor.Replace(n) 771 return true 772 } 773 templateArgs, newCall, newName := replaceTemplateVars(info, n) 774 cursor.Replace(newCall) 775 // Have we already replaced this template function with these args? 776 funcInstance := newName + prettyPrintExprs(templateArgs...) 777 if _, ok := seenCallsites[funcInstance]; !ok { 778 seenCallsites[funcInstance] = struct{}{} 779 newFuncVariant := createTemplateFuncVariant(f, info, templateArgs) 780 q = append(q, findConcreteTemplateCallSites(newFuncVariant, templateFuncInfos)...) 781 } 782 } 783 return true 784 }, nil) 785 } 786 return nil 787 } 788 789 // findConcreteTemplateCallSites finds all CallExprs within the input funcDecl 790 // that do not contain template arguments and thus can be immediately replaced. 791 func findConcreteTemplateCallSites( 792 funcDecl *dst.FuncDecl, templateFuncInfos map[string]*funcInfo, 793 ) []*dst.CallExpr { 794 info, calledFromTemplate := templateFuncInfos[funcDecl.Name.Name] 795 var ret []*dst.CallExpr 796 dstutil.Apply(funcDecl, func(cursor *dstutil.Cursor) bool { 797 n := cursor.Node() 798 switch callExpr := n.(type) { 799 case *dst.CallExpr: 800 ident, ok := callExpr.Fun.(*dst.Ident) 801 if !ok { 802 return true 803 } 804 _, ok = templateFuncInfos[ident.Name] 805 if !ok { 806 // Nothing to do, it's not a templated function. 807 return true 808 } 809 if !calledFromTemplate { 810 // All arguments are concrete since the callsite isn't within another 811 // templated function decl. 812 ret = append(ret, callExpr) 813 return true 814 } 815 for i := range callExpr.Args { 816 switch a := callExpr.Args[i].(type) { 817 case *dst.Ident: 818 for _, param := range info.templateParams { 819 if param.field.Names[0].Name == a.Name { 820 // Found a propagated template parameter, so we don't return 821 // this CallExpr (it's not concrete). 822 // NOTE: This is broken in the presence of shadowing. 823 // Let's assume nobody shadows template vars for now. 824 return true 825 } 826 } 827 } 828 } 829 ret = append(ret, callExpr) 830 } 831 return true 832 }, nil) 833 return ret 834 } 835 836 // expandTemplates is the main entry point to the templater. Given a dst.File, 837 // it modifies the dst.File to include all expanded template functions, and 838 // edits call sites to call the newly expanded functions. 839 func expandTemplates(f *dst.File) { 840 templateInfo := findTemplateDecls(f) 841 replaceAndExpandTemplates(f, templateInfo.funcInfos) 842 } 843 844 // createTemplateFuncVariant creates a variant of the input funcInfo given the 845 // template arguments passed in args, and adds the variant to the end of the 846 // input file. 847 func createTemplateFuncVariant(f *dst.File, info *funcInfo, args []dst.Expr) *dst.FuncDecl { 848 n := info.decl 849 directives := n.Decs.NodeDecs.Start 850 newBody := monomorphizeTemplate(dst.Clone(n.Body).(*dst.BlockStmt), info, args).(*dst.BlockStmt) 851 newName := getTemplateVariantName(info, args) 852 ret := &dst.FuncDecl{ 853 Name: newName, 854 Type: dst.Clone(info.decl.Type).(*dst.FuncType), 855 Body: newBody, 856 Decs: dst.FuncDeclDecorations{ 857 NodeDecs: dst.NodeDecs{ 858 Before: dst.EmptyLine, 859 Start: directives, 860 }, 861 }, 862 } 863 f.Decls = append(f.Decls, ret) 864 return ret 865 }