
     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    11  package execgen
    13  import (
    14  	"fmt"
    15  	"go/token"
    16  	"regexp"
    17  	"sort"
    18  	"strings"
    20  	""
    21  	""
    22  )
// templateInfo is the result of scanning a file for execgen annotations; it is
// what findTemplateDecls returns.
type templateInfo struct {
	// funcInfos holds one entry per function annotated with execgen:template,
	// keyed by the function's name.
	funcInfos map[string]*funcInfo
	// letInfos holds one entry per variable declared under an execgen:let
	// annotation, keyed by the variable's name.
	letInfos map[string]*letInfo
}
// templateParamInfo describes a single template parameter of a template
// function.
type templateParamInfo struct {
	// fieldOrdinal is the index of the parameter within the function's original
	// parameter list (before template params were removed). It is also used to
	// index into the argument list of a call to the function.
	fieldOrdinal int
	// field is a clone of the parameter's declaration.
	field *dst.Field
}
// funcInfo contains everything the templater knows about one function that was
// annotated with execgen:template.
type funcInfo struct {
	// decl is the template function's declaration. findTemplateDecls strips the
	// template parameters from its parameter list before storing it here.
	decl *dst.FuncDecl
	// templateParams lists the template parameters, in the order in which they
	// were named in the execgen:template annotation.
	templateParams []templateParamInfo
	// instantiateArgs is a list of lists of arguments that were passed explicitly
	// as execgen:instantiate declarations.
	instantiateArgs [][]string
}
// letInfo contains a list of all of the values in an execgen:let declaration.
type letInfo struct {
	// typ is a type literal: the array type of the composite literal that the
	// execgen:let variable was initialized with.
	typ *dst.ArrayType
	// vals holds the pretty-printed source text of each element of the composite
	// literal, in declaration order.
	vals []string
}
// templateRe matches an annotation of the form
//
//	// execgen:template<foo, bar>
//
// Capture group 1 is the raw text between the angle brackets ("foo, bar" in the
// example above); trimTemplateDeclMatches splits it into individual names.
var templateRe = regexp.MustCompile(`\/\/ execgen:template<((?:(?:\w+),?\W*)+)>`)
// instantiateRe matches an annotation of the form
//
//	// execgen:instantiate<foo, bar>
//
// Capture group 1 is the raw text between the angle brackets ("foo, bar" in the
// example above); trimTemplateDeclMatches splits it into individual arguments.
var instantiateRe = regexp.MustCompile(`\/\/ execgen:instantiate<((?:(?:\w+),?\W*)+)>`)
    57  // replaceTemplateVars removes the template arguments from a callsite of a
    58  // templated function. It returns the template arguments that were used, and a
    59  // new CallExpr that doesn't have the template arguments.
    60  func replaceTemplateVars(
    61  	info *funcInfo, call *dst.CallExpr,
    62  ) (templateArgs []dst.Expr, newCall *dst.CallExpr, mangledName string) {
    63  	if len(info.templateParams) == 0 {
    64  		return nil, call, ""
    65  	}
    66  	templateArgs = make([]dst.Expr, len(info.templateParams))
    67  	// Collect template arguments.
    68  	for i, param := range info.templateParams {
    69  		templateArgs[i] = dst.Clone(call.Args[param.fieldOrdinal]).(dst.Expr)
    70  		// Clear the decorations so that argument comments are not used in
    71  		// template function names.
    72  		templateArgs[i].Decorations().Start.Clear()
    73  		templateArgs[i].Decorations().End.Clear()
    74  	}
    75  	// Remove template vars from callsite.
    76  	newArgs := make([]dst.Expr, 0, len(call.Args)-len(info.templateParams))
    77  	for i := range call.Args {
    78  		skip := false
    79  		for _, p := range info.templateParams {
    80  			if p.fieldOrdinal == i {
    81  				skip = true
    82  				break
    83  			}
    84  		}
    85  		if !skip {
    86  			newArgs = append(newArgs, dst.Clone(call.Args[i]).(dst.Expr))
    87  		}
    88  	}
    89  	ret := dst.Clone(call).(*dst.CallExpr)
    90  	newName := getTemplateVariantName(info, templateArgs)
    91  	ret.Fun = newName
    92  	ret.Args = newArgs
    93  	return templateArgs, ret, newName.Name
    94  }
    96  // monomorphizeTemplate produces a variant of the input function body, given the
    97  // definition of the function in funcInfo, and the concrete, template-time values
    98  // that the function is being invoked with. It will try to find conditional
    99  // statements that use the template variables and output only the branches that
   100  // match.
   101  //
   102  // For example, given the function:
   103  // // execgen:inline
   104  // // execgen:template<t, i>
   105  //
   106  //	func b(t bool, i int) int {
   107  //	  if t {
   108  //	    x = 3
   109  //	  } else {
   110  //	    x = 4
   111  //	  }
   112  //	  switch i {
   113  //	    case 5: fmt.Println("5")
   114  //	    case 6: fmt.Println("6")
   115  //	  }
   116  //	  return x
   117  //	}
   118  //
   119  // and a caller
   120  //
   121  //	b(true, 5)
   122  //
   123  // this function will generate
   124  //
   125  //	if true {
   126  //	  x = 3
   127  //	} else {
   128  //	  x = 4
   129  //	}
   130  //	switch 5 {
   131  //	  case 5: fmt.Println("5")
   132  //	  case 6: fmt.Println("6")
   133  //	}
   134  //	return x
   135  //
   136  // in its first pass. However, because the if's condition (true, in this case)
   137  // is a logical expression containing boolean literals, and the switch statement
   138  // is a switch on a template variable alone, a second pass "folds"
   139  // the conditionals and replaces them like so:
   140  //
   141  //	x = 3
   142  //	fmt.Println(5)
   143  //	return x
   144  //
   145  // Note that this method lexically replaces all formal parameters, so together
   146  // with createTemplateFuncVariant, it enables templates to call other templates
   147  // with template variables.
   148  func monomorphizeTemplate(n dst.Node, info *funcInfo, args []dst.Expr) dst.Node {
   149  	// Create map from formal param name to arg.
   150  	paramMap := make(map[string]dst.Expr)
   151  	for i, p := range info.templateParams {
   152  		paramMap[p.field.Names[0].Name] = args[i]
   153  	}
   154  	templateSwitches := make(map[*dst.SwitchStmt]struct{})
   155  	n = dstutil.Apply(n, func(cursor *dstutil.Cursor) bool {
   156  		// Replace all usages of the formal parameter with the template arg.
   157  		c := cursor.Node()
   158  		switch t := c.(type) {
   159  		case *dst.Ident:
   160  			if arg := paramMap[t.Name]; arg != nil {
   161  				p := cursor.Parent()
   162  				if s, ok := p.(*dst.SwitchStmt); ok {
   163  					if s.Tag.(*dst.Ident) == t {
   164  						// Write down the switch statements we see that are of the form:
   165  						// switch <templateParam> {
   166  						// ...
   167  						// }
   168  						// We'll replace these later.
   169  						templateSwitches[s] = struct{}{}
   170  					}
   171  				}
   172  				cursor.Replace(dst.Clone(arg))
   173  			}
   174  		}
   175  		return true
   176  	}, nil)
   178  	return foldConditionals(n, info, templateSwitches)
   179  }
// foldConditionals edits conditional statements to try to remove branches that
// are statically falsifiable. It walks n and handles two cases:
//
//  1. if statements whose condition tryEvalBool can evaluate, e.g.
//
//     if <bool> { } else { } and if !<bool> { } else { }
//
//     The statement is replaced by the contents of whichever branch is taken
//     (or removed entirely if the condition is false and there is no else).
//
//  2. switch statements recorded in templateSwitches, i.e. those of the form
//
//     switch <ident> {
//     case <otherIdent>:
//     case <ident>:
//     ...
//     }
//
//     The statement is replaced by the body of the first case clause containing
//     an expression that pretty-prints identically to the switch tag. If no
//     clause matches, the switch is left in place.
//
// The selected branch is itself folded recursively before being spliced in.
// The (possibly replaced) root node is returned.
func foldConditionals(
	n dst.Node, info *funcInfo, templateSwitches map[*dst.SwitchStmt]struct{},
) dst.Node {
	return dstutil.Apply(n, func(cursor *dstutil.Cursor) bool {
		n := cursor.Node()
		switch n := n.(type) {
		case *dst.SwitchStmt:
			if _, ok := templateSwitches[n]; !ok {
				// Not a template switch.
				return true
			}
			// Compare the tag against each case expression by source text.
			t := prettyPrintExprs(n.Tag)
			for _, item := range n.Body.List {
				c := item.(*dst.CaseClause)
				for _, e := range c.List {
					if prettyPrintExprs(e) == t {
						// Wrap the matching clause's statements in a block so that
						// they can be folded recursively and spliced in with
						// insertBlockStmt; the clause's decorations are carried over
						// onto the block.
						body := &dst.BlockStmt{
							List: c.Body,
							Decs: dst.BlockStmtDecorations{
								NodeDecs: c.Decs.NodeDecs,
								Lbrace:   c.Decs.Colon,
							},
						}
						newBody := foldConditionals(body, info, templateSwitches).(*dst.BlockStmt)
						// Insert the clause's statements before the switch, then
						// remove the switch itself.
						insertBlockStmt(cursor, newBody)
						cursor.Delete()
						return true
					}
				}
			}
		case *dst.IfStmt:
			ret, ok := tryEvalBool(n.Cond)
			if !ok {
				// The condition isn't a static boolean expression; leave the
				// statement alone.
				return true
			}
			// Since we're replacing the node, make sure we preserve any comments.
			// They are attached to a placeholder `_ = true` assignment inserted
			// ahead of the if statement.
			// NOTE(review): the literal is tagged token.STRING although its text
			// is `true`; confirm this is intentional.
			if len(n.Decs.NodeDecs.Start) > 0 {
				cursor.InsertBefore(&dst.AssignStmt{
					Tok: token.ASSIGN,
					Lhs: []dst.Expr{dst.NewIdent("_")},
					Rhs: []dst.Expr{
						&dst.BasicLit{
							Kind:  token.STRING,
							Value: "true",
						},
					},
					Decs: dst.AssignStmtDecorations{
						NodeDecs: n.Decs.NodeDecs,
					},
				})
			}
			if ret {
				// Replace with the if side.
				newBody := foldConditionals(n.Body, info, templateSwitches).(*dst.BlockStmt)
				insertBlockStmt(cursor, newBody)
				cursor.Delete()
				return true
			}
			// Replace with the else side, if it exists.
			if n.Else != nil {
				newElse := foldConditionals(n.Else, info, templateSwitches)
				switch e := newElse.(type) {
				case *dst.BlockStmt:
					// Plain else block: splice its statements in place of the if.
					insertBlockStmt(cursor, e)
					cursor.Delete()
				default:
					// Anything else (e.g. an `else if` that could not be folded
					// away) takes the place of the whole statement.
					cursor.Replace(newElse)
				}
			} else {
				// Condition is false and there is no else: the statement vanishes.
				cursor.Delete()
			}
		}
		return true
	}, nil)
}
   269  // tryEvalBool attempts to statically evaluate the input expr as a logical
   270  // combination of boolean literals (like false || true). It returns the result
   271  // of the evaluation and whether or not the expression was actually evaluable
   272  // as such.
   273  func tryEvalBool(n dst.Expr) (ret bool, ok bool) {
   274  	switch n := n.(type) {
   275  	case *dst.UnaryExpr:
   276  		// !<expr>
   277  		if n.Op == token.NOT {
   278  			ret, ok = tryEvalBool(n.X)
   279  			ret = !ret
   280  			return ret, ok
   281  		}
   282  		return false, false
   283  	case *dst.BinaryExpr:
   284  		// expr && expr or expr || expr
   285  		if n.Op != token.LAND && n.Op != token.LOR {
   286  			return false, false
   287  		}
   288  		l, ok := tryEvalBool(n.X)
   289  		if !ok {
   290  			return false, false
   291  		}
   292  		r, ok := tryEvalBool(n.Y)
   293  		if !ok {
   294  			return false, false
   295  		}
   296  		switch n.Op {
   297  		case token.LAND:
   298  			return l && r, true
   299  		case token.LOR:
   300  			return l || r, true
   301  		default:
   302  			panic("unreachable")
   303  		}
   304  	case *dst.Ident:
   305  		switch n.Name {
   306  		case "true":
   307  			return true, true
   308  		case "false":
   309  			return false, true
   310  		}
   311  		return false, false
   312  	}
   313  	return false, false
   314  }
   316  func insertBlockStmt(cursor *dstutil.Cursor, block *dst.BlockStmt) {
   317  	// Make sure to preserve comments.
   318  	cursor.InsertBefore(&dst.EmptyStmt{
   319  		Implicit: true,
   320  		Decs: dst.EmptyStmtDecorations{
   321  			NodeDecs: dst.NodeDecs{
   322  				Before: dst.NewLine,
   323  				Start:  trimLeadingNewLines(append(block.Decs.Lbrace, block.Decs.NodeDecs.Start...)),
   324  				End:    block.Decs.End,
   325  				After:  dst.NewLine,
   326  			}},
   327  	})
   328  	for _, stmt := range block.List {
   329  		cursor.InsertBefore(stmt)
   330  	}
   331  }
   333  // trimTemplateDeclMatches takes a list of matches from an execgen:blah<a,b,c>
   334  // regexp match and returns the trimmed list of a, b, and c.
   335  func trimTemplateDeclMatches(matches []string) []string {
   336  	match := matches[1]
   338  	templateVars := strings.Split(match, ",")
   339  	for i, v := range templateVars {
   340  		templateVars[i] = strings.TrimSpace(v)
   341  	}
   342  	return templateVars
   343  }
// runtimeToTemplateSuffix is appended to a template function's name to form the
// name of its generated lookup function: the function that switches on the
// runtime values of the template arguments and dispatches to the matching
// monomorphized variant. Calls to template functions that have
// execgen:instantiate annotations are redirected to this lookup function.
const runtimeToTemplateSuffix = "_runtime_to_template"
// findTemplateDecls, given an AST, finds all functions annotated with
// execgen:template<foo,bar>, and returns a funcInfo for each of them, and
// finds all var decls annotated with execgen:let, returning a letInfo for
// each of them.
//
// The file is modified along the way:
//   - execgen:template and execgen:instantiate comments are stripped from
//     function decorations;
//   - each template function is deleted and replaced by a
//     `const _ = "template_<name>"` marker that keeps its comments;
//   - for every execgen:instantiate combination, a monomorphized variant is
//     appended to the file, and a <name>_runtime_to_template lookup function
//     is inserted;
//   - execgen:let declarations are deleted.
//
// It panics if a template variable doesn't name a parameter of the annotated
// function, or if an execgen:let declaration is malformed.
func findTemplateDecls(f *dst.File) templateInfo {
	ret := templateInfo{
		funcInfos: make(map[string]*funcInfo),
		letInfos:  make(map[string]*letInfo),
	}
	dstutil.Apply(f, func(cursor *dstutil.Cursor) bool {
		n := cursor.Node()
		switch n := n.(type) {
		case *dst.FuncDecl:
			var templateVars []string
			var instantiateArgs [][]string
			// i counts the decorations that are kept; annotations are dropped.
			i := 0
			for _, dec := range n.Decs.Start {
				if matches := templateRe.FindStringSubmatch(dec); matches != nil {
					templateVars = trimTemplateDeclMatches(matches)
					continue
				}
				if matches := instantiateRe.FindStringSubmatch(dec); matches != nil {
					instantiateMatches := trimTemplateDeclMatches(matches)
					// Only execgen:let declarations seen so far in the traversal
					// are available for expansion here.
					newInstantiateArgs := expandInstantiateArgs(instantiateMatches, ret.letInfos)
					instantiateArgs = append(instantiateArgs, newInstantiateArgs...)
					// Eventually let's delete the instantiate comments as well.
					continue
				}
				// Filter decorations in place.
				n.Decs.Start[i] = dec
				i++
			}
			n.Decs.Start = n.Decs.Start[:i]
			if templateVars == nil {
				// Not a template function; its body is not traversed.
				return false
			}
			// Process template funcs: find template params from runtime definition
			// and save in funcInfo.
			info := &funcInfo{
				instantiateArgs: instantiateArgs,
			}
			for _, v := range templateVars {
				var found bool
				for i, f := range n.Type.Params.List {
					// We can safely 0-index here because fields always have at least
					// one name, and we've already banned the case where they have more
					// than one. (e.g. func a (a int, b int, c, d int))
					if f.Names[0].Name == v {
						info.templateParams = append(info.templateParams, templateParamInfo{
							fieldOrdinal: i,
							field:        dst.Clone(f).(*dst.Field),
						})
						found = true
						break
					}
				}
				if !found {
					panic(fmt.Errorf("template var %s not found", v))
				}
			}
			// Delete template params from runtime definition.
			newParamList := make([]*dst.Field, 0, len(n.Type.Params.List)-len(info.templateParams))
			for i, field := range n.Type.Params.List {
				var skip bool
				for _, p := range info.templateParams {
					if i == p.fieldOrdinal {
						skip = true
						break
					}
				}
				if !skip {
					newParamList = append(newParamList, field)
				}
			}
			funcDecs := n.Decs
			// Replace the template function with a const marker, just so we can keep
			// the comments above the template function available.
			cursor.InsertBefore(&dst.GenDecl{
				Tok: token.CONST,
				Specs: []dst.Spec{
					&dst.ValueSpec{
						Names: []*dst.Ident{dst.NewIdent("_")},
						Values: []dst.Expr{
							&dst.BasicLit{
								Kind:  token.STRING,
								Value: fmt.Sprintf(`"template_%s"`, n.Name.Name),
							},
						},
					},
				},
				Decs: dst.GenDeclDecorations{
					NodeDecs: funcDecs.NodeDecs,
				},
			})
			// Keep the full parameter list around: the lookup function generated
			// below still takes the template params as runtime arguments.
			oldParamList := n.Type.Params.List
			n.Type.Params.List = newParamList
			n.Decs.Start = trimStartDecs(n)
			info.decl = n
			ret.funcInfos[info.decl.Name.Name] = info
			// Eagerly create a variant for every explicitly requested
			// instantiation.
			for _, args := range info.instantiateArgs {
				exprList := make([]dst.Expr, len(args))
				for j := range args {
					exprList[j] = dst.NewIdent(args[j])
				}
				createTemplateFuncVariant(f, info, exprList)
			}
			// Now, we need to generate the "look up table" that allows us to convert
			// runtime values into template values for the template args.
			// We only do this if there were execgen:instantiate statements, since we
			// assume that if there were no such statements, the concrete callsites
			// were already present.
			if info.instantiateArgs != nil {
				// At this point n.Type.Params.List holds only the non-template
				// parameters.
				runtimeArgs := make([]dst.Expr, len(n.Type.Params.List))
				for i, p := range n.Type.Params.List {
					runtimeArgs[i] = dst.NewIdent(p.Names[0].Name)
				}
				decl := &dst.FuncDecl{
					Name: dst.NewIdent(fmt.Sprintf("%s%s", info.decl.Name.Name, runtimeToTemplateSuffix)),
					Type: dst.Clone(info.decl.Type).(*dst.FuncType),
					Body: &dst.BlockStmt{
						List: []dst.Stmt{
							generateSwitchStatementLookup(info, runtimeArgs, templateVars, info.instantiateArgs),
						},
					},
				}
				// The lookup function has the template function's original
				// signature, template params included.
				decl.Type.Params.List = oldParamList
				cursor.InsertAfter(decl)
			}
			// The template function itself is removed from the file.
			cursor.Delete()
		case *dst.GenDecl:
			// Search for execgen:let declarations.
			isLet := false
			for _, dec := range n.Decs.Start {
				if dec == "// execgen:let" {
					isLet = true
					break
				}
			}
			if !isLet {
				return true
			}
			if n.Tok != token.VAR {
				panic("execgen:let only allowed on vars")
			}
			for _, spec := range n.Specs {
				n := spec.(*dst.ValueSpec)
				if len(n.Names) != 1 || len(n.Values) != 1 {
					panic("execgen:let must have 1 name and one value per var")
				}
				info := &letInfo{}
				name := n.Names[0].Name
				c, ok := n.Values[0].(*dst.CompositeLit)
				if !ok {
					panic("execgen:let must use a composite literal value")
				}
				typ, ok := c.Type.(*dst.ArrayType)
				if !ok {
					panic("execgen:let must be on an array type literal")
				}
				info.vals = make([]string, len(c.Elts))
				info.typ = typ
				// Record each element as its printed source text.
				for i := range c.Elts {
					info.vals[i] = prettyPrintExprs(c.Elts[i])
				}
				ret.letInfos[name] = info
			}
			// The execgen:let declaration is removed from the file.
			cursor.Delete()
		}
		return true
	}, nil)
	return ret
}
   530  // expandInstantiateArgs takes a list of strings, the arguments to an
   531  // execgen:instantiate annotation, and returns a list of list of strings, after
   532  // combinatorially expanding any execgen:let lists in the instantiate arguments.
   533  // For example, given the instantiateArgs:
   534  // ["Bools", "Bools", 3]
   535  // and an execgen:let that maps "Bools" to ["true", "false"], we'd return the
   536  // list of lists:
   537  // [true, true, 3]
   538  // [true, false, 3]
   539  // [false, true, 3]
   540  // [false, false, 3]
   541  func expandInstantiateArgs(instantiateArgs []string, letInfos map[string]*letInfo) [][]string {
   542  	expandedArgs := make([][]string, len(instantiateArgs))
   543  	for i, arg := range instantiateArgs {
   544  		if info := letInfos[arg]; info != nil {
   545  			expandedArgs[i] = info.vals
   546  		} else {
   547  			expandedArgs[i] = []string{arg}
   548  		}
   549  	}
   550  	return generateInstantiateCombinations(expandedArgs)
   551  }
   553  func generateInstantiateCombinations(args [][]string) [][]string {
   554  	if len(args) == 1 {
   555  		// Base case: transform the final options list into an arguments list of
   556  		// lists where each arguments list is a single element containing one of
   557  		// the final options.
   558  		// For example, given [[true, false]], we'll return:
   559  		// [[true], [false]]
   560  		ret := make([][]string, len(args[0]))
   561  		for i, arg := range args[0] {
   562  			ret[i] = []string{arg}
   563  		}
   564  		return ret
   565  	}
   566  	rest := generateInstantiateCombinations(args[1:])
   567  	ret := make([][]string, 0, len(rest)*len(args[0]))
   568  	for _, argOption := range args[0] {
   569  		// For every option of argument, prepend it to every args list from
   570  		// the recursive step.
   571  		for _, args := range rest {
   572  			ret = append(ret, append([]string{argOption}, args...))
   573  		}
   574  	}
   575  	return ret
   576  }
   578  // generateSwitchStatementLookup ...
   579  // remainingArgs is a list of lists of actual instantiations. For example, if
   580  // we had:
   581  // execgen:instantiate<red, potato>
   582  // execgen:instantiate<red, orange>
   583  // execgen:instantiate<yellow, orange>
   584  //
   585  // remainingArgs would be {{red, potato}, {red, orange}, {yellow orange}}
   586  func generateSwitchStatementLookup(
   587  	info *funcInfo, curArgs []dst.Expr, remainingTemplateParams []string, remainingArgs [][]string,
   588  ) *dst.SwitchStmt {
   589  	ret := &dst.SwitchStmt{
   590  		Tag:  dst.NewIdent(remainingTemplateParams[0]),
   591  		Body: &dst.BlockStmt{},
   592  	}
   593  	defaultCase := &dst.CaseClause{
   594  		Body: []dst.Stmt{
   595  			mustParseStmt(`panic(fmt.Sprint("unknown value", ` + remainingTemplateParams[0] + `))`),
   596  		},
   597  	}
   598  	if len(remainingArgs[0]) == 1 {
   599  		// Base case. We finished switching on all template params, time to actually
   600  		// invoke the fully specialized function.
   601  		stmtList := make([]dst.Stmt, len(remainingArgs)+1)
   602  		for i := range remainingArgs {
   603  			argList := append(curArgs, dst.NewIdent(remainingArgs[i][0]))
   604  			call := &dst.CallExpr{
   605  				Fun:  dst.NewIdent(info.decl.Name.Name),
   606  				Args: argList,
   607  			}
   608  			_, call, _ = replaceTemplateVars(info, call)
   609  			var stmt dst.Stmt
   610  			if info.decl.Type.Results != nil {
   611  				stmt = &dst.ReturnStmt{Results: []dst.Expr{call}}
   612  			} else {
   613  				stmt = &dst.ExprStmt{X: call}
   614  			}
   615  			stmtList[i] = &dst.CaseClause{
   616  				List: []dst.Expr{dst.NewIdent(remainingArgs[i][0])},
   617  				Body: []dst.Stmt{stmt},
   618  			}
   619  		}
   620  		stmtList[len(stmtList)-1] = defaultCase
   621  		ret.Body.List = stmtList
   622  		return ret
   623  	}
   625  	// Recursive case: we have more args to deal with
   626  	groupedArgs := make(map[string][][]string)
   627  	for _, argList := range remainingArgs {
   628  		firstArg := argList[0]
   629  		groupedArgs[firstArg] = append(groupedArgs[firstArg], argList[1:])
   630  	}
   631  	stmtList := make([]dst.Stmt, len(groupedArgs)+1)
   632  	// Sort firstArgs lexicographically, so we have a consistent output order.
   633  	firstArgs := make([]string, 0, len(groupedArgs))
   634  	for firstArg := range groupedArgs {
   635  		firstArgs = append(firstArgs, firstArg)
   636  	}
   637  	sort.Strings(firstArgs)
   639  	for i, firstArg := range firstArgs {
   640  		restArgs := groupedArgs[firstArg]
   641  		argList := append(curArgs, dst.NewIdent(firstArg))
   642  		stmtList[i] = &dst.CaseClause{
   643  			List: []dst.Expr{dst.NewIdent(firstArg)},
   644  			Body: []dst.Stmt{generateSwitchStatementLookup(
   645  				info,
   646  				argList,
   647  				remainingTemplateParams[1:],
   648  				restArgs,
   649  			)},
   650  		}
   651  	}
   652  	stmtList[len(stmtList)-1] = defaultCase
   653  	ret.Body.List = stmtList
   654  	return ret
   655  }
// nameMangler rewrites "." to "DOT" and "*" to "STAR". getTemplateVariantName
// applies it to the name it assembles for a template function variant.
var nameMangler = strings.NewReplacer(".", "DOT", "*", "STAR")
   659  func getTemplateVariantName(info *funcInfo, args []dst.Expr) *dst.Ident {
   660  	var newName strings.Builder
   661  	newName.WriteString(info.decl.Name.Name)
   662  	for j := range args {
   663  		newName.WriteByte('_')
   664  		newName.WriteString(prettyPrintExprs(args[j]))
   665  	}
   666  	s := newName.String()
   667  	s = nameMangler.Replace(s)
   668  	return dst.NewIdent(s)
   669  }
   671  func trimStartDecs(n *dst.FuncDecl) []string {
   672  	// The function declaration node can accidentally capture extra comments that
   673  	// we want to leave in their original position, and not duplicate. So, remove
   674  	// any decorations that are separated from the function declaration by one or
   675  	// more newlines.
   676  	startDecs := n.Decs.Start.All()
   677  	for i := len(startDecs) - 1; i >= 0; i-- {
   678  		if strings.TrimSpace(startDecs[i]) == "" {
   679  			return startDecs[i+1:]
   680  		}
   681  	}
   682  	return startDecs
   683  }
   685  func trimLeadingNewLines(decs []string) []string {
   686  	var i int
   687  	for ; i < len(decs); i++ {
   688  		if strings.TrimSpace(decs[i]) != "" {
   689  			break
   690  		}
   691  	}
   692  	return decs[i:]
   693  }
// replaceAndExpandTemplates finds all CallExprs in the input AST that are calling
// the functions that had been annotated with // execgen:template that are
// passed in via the templateFuncInfos map. It recursively replaces the
// CallExprs with their expanded, mangled template function names, and creates
// the requisite monomorphized FuncDecls on demand.
//
// For example, given a template function:
//
// // execgen:template<b>
//
//	func foo (a int, b bool) {
//	  if b {
//	    return a
//	  } else {
//	    return a + 1
//	  }
//	}
//
// And callsites:
//
// foo(a, true)
// foo(a, false)
//
// This function will add 2 new func decls to the AST:
//
//	func foo_true(a int) {
//	  return a
//	}
//
//	func foo_false(a int) {
//	  return a + 1
//	}
//
// Calls to template functions that have execgen:instantiate annotations are
// instead redirected to the function's runtime lookup function, keeping all of
// their arguments.
//
// The file is modified in place; the returned node is always nil.
func replaceAndExpandTemplates(f *dst.File, templateFuncInfos map[string]*funcInfo) dst.Node {
	// Seed the work list with all "roots": template CallExprs, in any function
	// declaration of the file, that only have concrete arguments.
	var q []*dst.CallExpr
	dstutil.Apply(f, func(cursor *dstutil.Cursor) bool {
		n := cursor.Node()
		switch n := n.(type) {
		case *dst.FuncDecl:
			q = append(q, findConcreteTemplateCallSites(n, templateFuncInfos)...)
		}
		return true
	}, nil)
	// For every remaining concrete call site, replace it with its mangled template
	// function call, and generate the requisite monomorphized template function
	// if we haven't already.
	//
	// Then, process the new monomorphized template function and add any newly
	// created concrete template call sites to the queue. Do this until we have no
	// more concrete template call sites.
	//
	// Note that the contents of q are never read: q only signals whether the
	// previous pass discovered new call sites, and each pass re-walks the whole
	// file.
	seenCallsites := make(map[string]struct{})
	for len(q) > 0 {
		q = q[:0]
		dstutil.Apply(f, func(cursor *dstutil.Cursor) bool {
			n := cursor.Node()
			switch n := n.(type) {
			case *dst.CallExpr:
				// Only calls through a plain identifier can target a template
				// function.
				ident, ok := n.Fun.(*dst.Ident)
				if !ok {
					return true
				}
				info, ok := templateFuncInfos[ident.Name]
				if !ok {
					// Nothing to do, it's not a templated function.
					return true
				}
				// Critical moment: We need to know whether to replace with concrete
				// input args or to replace the call with the lookup version.
				if info.instantiateArgs != nil {
					n.Fun = dst.NewIdent(fmt.Sprintf("%s%s", info.decl.Name.Name, runtimeToTemplateSuffix))
					cursor.Replace(n)
					return true
				}
				templateArgs, newCall, newName := replaceTemplateVars(info, n)
				cursor.Replace(newCall)
				// Have we already replaced this template function with these args?
				funcInstance := newName + prettyPrintExprs(templateArgs...)
				if _, ok := seenCallsites[funcInstance]; !ok {
					seenCallsites[funcInstance] = struct{}{}
					// First time we see this instantiation: generate the variant
					// and queue up any concrete template calls found in its body.
					newFuncVariant := createTemplateFuncVariant(f, info, templateArgs)
					q = append(q, findConcreteTemplateCallSites(newFuncVariant, templateFuncInfos)...)
				}
			}
			return true
		}, nil)
	}
	return nil
}
   789  // findConcreteTemplateCallSites finds all CallExprs within the input funcDecl
   790  // that do not contain template arguments and thus can be immediately replaced.
   791  func findConcreteTemplateCallSites(
   792  	funcDecl *dst.FuncDecl, templateFuncInfos map[string]*funcInfo,
   793  ) []*dst.CallExpr {
   794  	info, calledFromTemplate := templateFuncInfos[funcDecl.Name.Name]
   795  	var ret []*dst.CallExpr
   796  	dstutil.Apply(funcDecl, func(cursor *dstutil.Cursor) bool {
   797  		n := cursor.Node()
   798  		switch callExpr := n.(type) {
   799  		case *dst.CallExpr:
   800  			ident, ok := callExpr.Fun.(*dst.Ident)
   801  			if !ok {
   802  				return true
   803  			}
   804  			_, ok = templateFuncInfos[ident.Name]
   805  			if !ok {
   806  				// Nothing to do, it's not a templated function.
   807  				return true
   808  			}
   809  			if !calledFromTemplate {
   810  				// All arguments are concrete since the callsite isn't within another
   811  				// templated function decl.
   812  				ret = append(ret, callExpr)
   813  				return true
   814  			}
   815  			for i := range callExpr.Args {
   816  				switch a := callExpr.Args[i].(type) {
   817  				case *dst.Ident:
   818  					for _, param := range info.templateParams {
   819  						if param.field.Names[0].Name == a.Name {
   820  							// Found a propagated template parameter, so we don't return
   821  							// this CallExpr (it's not concrete).
   822  							// NOTE: This is broken in the presence of shadowing.
   823  							// Let's assume nobody shadows template vars for now.
   824  							return true
   825  						}
   826  					}
   827  				}
   828  			}
   829  			ret = append(ret, callExpr)
   830  		}
   831  		return true
   832  	}, nil)
   833  	return ret
   834  }
   836  // expandTemplates is the main entry point to the templater. Given a dst.File,
   837  // it modifies the dst.File to include all expanded template functions, and
   838  // edits call sites to call the newly expanded functions.
   839  func expandTemplates(f *dst.File) {
   840  	templateInfo := findTemplateDecls(f)
   841  	replaceAndExpandTemplates(f, templateInfo.funcInfos)
   842  }
   844  // createTemplateFuncVariant creates a variant of the input funcInfo given the
   845  // template arguments passed in args, and adds the variant to the end of the
   846  // input file.
   847  func createTemplateFuncVariant(f *dst.File, info *funcInfo, args []dst.Expr) *dst.FuncDecl {
   848  	n := info.decl
   849  	directives := n.Decs.NodeDecs.Start
   850  	newBody := monomorphizeTemplate(dst.Clone(n.Body).(*dst.BlockStmt), info, args).(*dst.BlockStmt)
   851  	newName := getTemplateVariantName(info, args)
   852  	ret := &dst.FuncDecl{
   853  		Name: newName,
   854  		Type: dst.Clone(info.decl.Type).(*dst.FuncType),
   855  		Body: newBody,
   856  		Decs: dst.FuncDeclDecorations{
   857  			NodeDecs: dst.NodeDecs{
   858  				Before: dst.EmptyLine,
   859  				Start:  directives,
   860  			},
   861  		},
   862  	}
   863  	f.Decls = append(f.Decls, ret)
   864  	return ret
   865  }