github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/colexec/execgen/inline.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package execgen
    12  
    13  import (
    14  	"fmt"
    15  	"go/token"
    16  
    17  	"github.com/cockroachdb/errors"
    18  	"github.com/dave/dst"
    19  	"github.com/dave/dst/dstutil"
    20  )
    21  
    22  // inlineFuncs takes an input file's contents and inlines all functions
    23  // annotated with // execgen:inline into their callsites via AST manipulation.
    24  func inlineFuncs(f *dst.File) {
    25  	// First, run over the input file, searching for functions that are annotated
    26  	// with execgen:inline.
    27  	inlineFuncMap := extractInlineFuncDecls(f)
    28  
    29  	// Do a second pass over the AST, this time replacing calls to the inlined
    30  	// functions with the inlined function itself.
    31  	inlineFunc(inlineFuncMap, f)
    32  }
    33  
    34  func inlineFunc(inlineFuncMap map[string]funcInfo, n dst.Node) dst.Node {
    35  	var funcIdx int
    36  	return dstutil.Apply(n, func(cursor *dstutil.Cursor) bool {
    37  		cursor.Index()
    38  		n := cursor.Node()
    39  		// There are two cases. AssignStmt, which are like:
    40  		// a = foo()
    41  		// and ExprStmt, which are simply:
    42  		// foo()
    43  		// AssignStmts need to do extra work for inlining, because we have to
    44  		// simulate producing return values.
    45  		switch n := n.(type) {
    46  		case *dst.AssignStmt:
    47  			// Search for assignment function call:
    48  			// a = foo()
    49  			callExpr, ok := n.Rhs[0].(*dst.CallExpr)
    50  			if !ok {
    51  				return true
    52  			}
    53  			funcInfo := getInlinedFunc(inlineFuncMap, callExpr)
    54  			if funcInfo == nil {
    55  				return true
    56  			}
    57  			// We want to recurse here because funcInfo itself might have calls to
    58  			// inlined functions.
    59  			funcInfo.decl = inlineFunc(inlineFuncMap, funcInfo.decl).(*dst.FuncDecl)
    60  
    61  			if len(n.Rhs) > 1 {
    62  				panic("can't do template replacement with more than a single RHS to a CallExpr")
    63  			}
    64  
    65  			if n.Tok == token.DEFINE {
    66  				// We need to put a variable declaration for the new defined variables
    67  				// in the parent scope.
    68  				newDefinitions := &dst.GenDecl{
    69  					Tok:   token.VAR,
    70  					Specs: make([]dst.Spec, len(n.Lhs)),
    71  				}
    72  
    73  				for i, e := range n.Lhs {
    74  					// If we had foo, bar := thingToInline(), we'd get
    75  					// var (
    76  					//   foo int
    77  					//   bar int
    78  					// )
    79  					newDefinitions.Specs[i] = &dst.ValueSpec{
    80  						Names: []*dst.Ident{dst.NewIdent(e.(*dst.Ident).Name)},
    81  						Type:  dst.Clone(funcInfo.decl.Type.Results.List[i].Type).(dst.Expr),
    82  					}
    83  				}
    84  
    85  				cursor.InsertBefore(&dst.DeclStmt{Decl: newDefinitions})
    86  			}
    87  
    88  			// Now we've got a callExpr. We need to inline the function call, and
    89  			// convert the result into the assignment variable.
    90  
    91  			decl := funcInfo.decl
    92  			// Produce declarations for each return value of the function to inline.
    93  			retValDeclStmt, retValNames := extractReturnValues(decl)
    94  			// inlinedStatements is a BlockStmt (a set of statements within curly
    95  			// braces) that contains the entirety of the statements that result from
    96  			// inlining the call. We make this a BlockStmt to avoid issues with
    97  			// variable shadowing.
    98  			// The first thing that goes in the BlockStmt is the ret val declarations.
    99  			// When we're done, the BlockStmt for a statement
   100  			//    a, b = foo(x, y)
   101  			// where foo was defined as
   102  			//    func foo(b string, c string) { ... }
   103  			// will look like:
   104  			// {
   105  			//    var (
   106  			//     __retval_0 bool
   107  			//     __retval_1 int
   108  			//    )
   109  			//    ...
   110  			//    {
   111  			//       b := x
   112  			//       c := y
   113  			//       ... the contents of func foo() except its return ...
   114  			//       {
   115  			//          // If foo() had `return true, j`, we'll generate the code:
   116  			//          __retval_0 = true
   117  			//          __retval_1 = j
   118  			//       }
   119  			//    }
   120  			//    a   = __retval_0
   121  			//    b   = __retval_1
   122  			// }
   123  			inlinedStatements := &dst.BlockStmt{
   124  				List: []dst.Stmt{retValDeclStmt},
   125  				Decs: dst.BlockStmtDecorations{
   126  					NodeDecs: n.Decs.NodeDecs,
   127  				},
   128  			}
   129  			body := dst.Clone(decl.Body).(*dst.BlockStmt)
   130  
   131  			// Replace return statements with assignments to the return values.
   132  			// Make a copy of the function to inline, and walk through it, replacing
   133  			// return statements at the end of the body with assignments to the return
   134  			// value declarations we made first.
   135  			body = replaceReturnStatements(decl.Name.Name, funcIdx, body, func(stmt *dst.ReturnStmt) dst.Stmt {
   136  				returnAssignmentSpecs := make([]dst.Stmt, len(retValNames))
   137  				for i := range retValNames {
   138  					returnAssignmentSpecs[i] = &dst.AssignStmt{
   139  						Lhs: []dst.Expr{dst.NewIdent(retValNames[i])},
   140  						Tok: token.ASSIGN,
   141  						Rhs: []dst.Expr{stmt.Results[i]},
   142  					}
   143  				}
   144  				// Replace the return with the new assignments.
   145  				return &dst.BlockStmt{List: returnAssignmentSpecs}
   146  			})
   147  			// Reassign input parameters to formal parameters.
   148  			reassignmentStmt := getFormalParamReassignments(decl, callExpr)
   149  			inlinedStatements.List = append(inlinedStatements.List, &dst.BlockStmt{
   150  				List: append([]dst.Stmt{reassignmentStmt}, body.List...),
   151  			})
   152  			// Assign mangled return values to the original assignment variables.
   153  			newAssignment := dst.Clone(n).(*dst.AssignStmt)
   154  			newAssignment.Tok = token.ASSIGN
   155  			newAssignment.Rhs = make([]dst.Expr, len(retValNames))
   156  			for i := range retValNames {
   157  				newAssignment.Rhs[i] = dst.NewIdent(retValNames[i])
   158  			}
   159  			inlinedStatements.List = append(inlinedStatements.List, newAssignment)
   160  			cursor.Replace(inlinedStatements)
   161  
   162  		case *dst.ExprStmt:
   163  			// Search for raw function call:
   164  			// foo()
   165  			callExpr, ok := n.X.(*dst.CallExpr)
   166  			if !ok {
   167  				return true
   168  			}
   169  			funcInfo := getInlinedFunc(inlineFuncMap, callExpr)
   170  			if funcInfo == nil {
   171  				return true
   172  			}
   173  			// We want to recurse here because funcInfo itself might have calls to
   174  			// inlined functions.
   175  			funcInfo.decl = inlineFunc(inlineFuncMap, funcInfo.decl).(*dst.FuncDecl)
   176  			decl := funcInfo.decl
   177  
   178  			reassignments := getFormalParamReassignments(decl, callExpr)
   179  
   180  			// This case is simpler than the AssignStmt case. It's identical, except
   181  			// there is no mangled return value name block, nor re-assignment to
   182  			// the mangled returns after the inlined function.
   183  			funcBlock := &dst.BlockStmt{
   184  				List: []dst.Stmt{reassignments},
   185  				Decs: dst.BlockStmtDecorations{
   186  					NodeDecs: n.Decs.NodeDecs,
   187  				},
   188  			}
   189  			body := dst.Clone(decl.Body).(*dst.BlockStmt)
   190  
   191  			// Remove return values if there are any, since we're ignoring returns
   192  			// as a raw function call.
   193  			body = replaceReturnStatements(decl.Name.Name, funcIdx, body, nil)
   194  			// Add the inlined function body to the block.
   195  			funcBlock.List = append(funcBlock.List, body.List...)
   196  
   197  			cursor.Replace(funcBlock)
   198  		default:
   199  			return true
   200  		}
   201  		funcIdx++
   202  		return true
   203  	}, nil)
   204  }
   205  
   206  // extractInlineFuncDecls searches the input file for functions that are
   207  // annotated with execgen:inline, extracts them into templateFuncMap, and
   208  // deletes them from the AST.
   209  func extractInlineFuncDecls(f *dst.File) map[string]funcInfo {
   210  	ret := make(map[string]funcInfo)
   211  	dstutil.Apply(f, func(cursor *dstutil.Cursor) bool {
   212  		n := cursor.Node()
   213  		switch n := n.(type) {
   214  		case *dst.FuncDecl:
   215  			var mustInline bool
   216  			for _, dec := range n.Decorations().Start.All() {
   217  				if dec == "// execgen:inline" {
   218  					mustInline = true
   219  				}
   220  			}
   221  			if !mustInline {
   222  				// Nothing to do, but recurse further.
   223  				return true
   224  			}
   225  			for _, p := range n.Type.Params.List {
   226  				if len(p.Names) > 1 {
   227  					// If we have a definition like this:
   228  					// func a (a, b int) int
   229  					// We're just giving up for now out of complete laziness.
   230  					panic("can't currently deal with multiple names per type in decls")
   231  				}
   232  			}
   233  
   234  			var info funcInfo
   235  			info.decl = dst.Clone(n).(*dst.FuncDecl)
   236  			// Store the function in a map.
   237  			ret[n.Name.Name] = info
   238  
   239  			// Replace the function textually with a fake constant, such as:
   240  			// `const _ = "inlined_blahFunc"`. We do this instead
   241  			// of completely deleting it to prevent "important comments" above the
   242  			// function to be deleted, such as template comments like {{end}}. This
   243  			// is kind of a quirk of the way the comments are parsed, but nonetheless
   244  			// this is an easy fix so we'll leave it for now.
   245  			cursor.Replace(&dst.GenDecl{
   246  				Tok: token.CONST,
   247  				Specs: []dst.Spec{
   248  					&dst.ValueSpec{
   249  						Names: []*dst.Ident{dst.NewIdent("_")},
   250  						Values: []dst.Expr{
   251  							&dst.BasicLit{
   252  								Kind:  token.STRING,
   253  								Value: fmt.Sprintf(`"inlined_%s"`, n.Name.Name),
   254  							},
   255  						},
   256  					},
   257  				},
   258  				Decs: dst.GenDeclDecorations{
   259  					NodeDecs: n.Decs.NodeDecs,
   260  				},
   261  			})
   262  			return false
   263  		}
   264  		return true
   265  	}, nil)
   266  	return ret
   267  }
   268  
   269  // extractReturnValues generates return value variables. It will produce one
   270  // statement per return value of the input FuncDecl. For example, for
   271  // a FuncDecl that returns two boolean arguments, lastVal and lastValNull,
   272  // two statements will be returned:
   273  //
   274  //	var __retval_lastVal bool
   275  //	var __retval_lastValNull bool
   276  //
   277  // The second return is a slice of the names of each of the mangled return
   278  // declarations, in this example, __retval_lastVal and __retval_lastValNull.
   279  func extractReturnValues(decl *dst.FuncDecl) (retValDeclStmt dst.Stmt, retValNames []string) {
   280  	if decl.Type.Results == nil {
   281  		return &dst.EmptyStmt{}, nil
   282  	}
   283  	results := decl.Type.Results.List
   284  	retValNames = make([]string, len(results))
   285  	specs := make([]dst.Spec, len(results))
   286  	for i, result := range results {
   287  		var retvalName string
   288  		// Make a mangled name.
   289  		if len(result.Names) == 0 {
   290  			retvalName = fmt.Sprintf("__retval_%d", i)
   291  		} else {
   292  			retvalName = fmt.Sprintf("__retval_%s", result.Names[0])
   293  		}
   294  		retValNames[i] = retvalName
   295  		specs[i] = &dst.ValueSpec{
   296  			Names: []*dst.Ident{dst.NewIdent(retvalName)},
   297  			Type:  dst.Clone(result.Type).(dst.Expr),
   298  		}
   299  	}
   300  	return &dst.DeclStmt{
   301  		Decl: &dst.GenDecl{
   302  			Tok:   token.VAR,
   303  			Specs: specs,
   304  		},
   305  	}, retValNames
   306  }
   307  
   308  // getFormalParamReassignments creates a new DEFINE (:=) statement per parameter
   309  // to a FuncDecl, which makes a fresh variable with the same name as the formal
   310  // parameter name and assigns it to the corresponding name in the CallExpr.
   311  //
   312  // For example, given a FuncDecl:
   313  //
   314  // func foo(a int, b string) { ... }
   315  //
   316  // and a CallExpr
   317  //
   318  // foo(x, y)
   319  //
   320  // we'll return the statement:
   321  //
   322  // var (
   323  //
   324  //	a int = x
   325  //	b string = y
   326  //
   327  // )
   328  //
   329  // In the case where the formal parameter name is the same as the input
   330  // parameter name, no extra assignment is created.
   331  func getFormalParamReassignments(decl *dst.FuncDecl, callExpr *dst.CallExpr) dst.Stmt {
   332  	formalParams := decl.Type.Params.List
   333  	reassignmentSpecs := make([]dst.Spec, 0, len(formalParams))
   334  	for i, formalParam := range formalParams {
   335  		if inputIdent, ok := callExpr.Args[i].(*dst.Ident); ok && inputIdent.Name == formalParam.Names[0].Name {
   336  			continue
   337  		}
   338  		reassignmentSpecs = append(reassignmentSpecs, &dst.ValueSpec{
   339  			Names:  []*dst.Ident{dst.NewIdent(formalParam.Names[0].Name)},
   340  			Type:   dst.Clone(formalParam.Type).(dst.Expr),
   341  			Values: []dst.Expr{callExpr.Args[i]},
   342  		})
   343  	}
   344  	if len(reassignmentSpecs) == 0 {
   345  		return &dst.EmptyStmt{}
   346  	}
   347  	return &dst.DeclStmt{
   348  		Decl: &dst.GenDecl{
   349  			Tok:   token.VAR,
   350  			Specs: reassignmentSpecs,
   351  		},
   352  	}
   353  }
   354  
   355  // replaceReturnStatements edits the input BlockStmt, from the function funcName,
   356  // replacing ReturnStmts at the end of the BlockStmts with the results of
   357  // applying returnEditor on the ReturnStmt or deleting them if the modifier is
   358  // nil.
   359  // It will panic if any return statements are not in the final position of the
   360  // input block.
   361  func replaceReturnStatements(
   362  	funcName string, funcIdx int, stmt *dst.BlockStmt, returnModifier func(*dst.ReturnStmt) dst.Stmt,
   363  ) *dst.BlockStmt {
   364  	if len(stmt.List) == 0 {
   365  		return stmt
   366  	}
   367  	// Insert an explicit return at the end if there isn't one.
   368  	// We'll need to edit this later to make early returns work properly.
   369  	lastStmt := stmt.List[len(stmt.List)-1]
   370  	if _, ok := lastStmt.(*dst.ReturnStmt); !ok {
   371  		ret := &dst.ReturnStmt{}
   372  		stmt.List = append(stmt.List, ret)
   373  		lastStmt = ret
   374  	}
   375  	retStmt := lastStmt.(*dst.ReturnStmt)
   376  	if returnModifier == nil {
   377  		stmt.List[len(stmt.List)-1] = &dst.EmptyStmt{}
   378  	} else {
   379  		stmt.List[len(stmt.List)-1] = returnModifier(retStmt)
   380  	}
   381  
   382  	label := dst.NewIdent(fmt.Sprintf("%s_return_%d", funcName, funcIdx))
   383  
   384  	// Find returns that weren't at the end of the function and replace them with
   385  	// labeled gotos.
   386  	var foundInlineReturn bool
   387  	stmt = dstutil.Apply(stmt, func(cursor *dstutil.Cursor) bool {
   388  		n := cursor.Node()
   389  		switch n := n.(type) {
   390  		case *dst.FuncLit:
   391  			// A FuncLit is a function literal, like:
   392  			// x := func() int { return 3 }
   393  			// We don't recurse into function literals since the return statements
   394  			// they contain aren't relevant to the inliner.
   395  			return false
   396  		case *dst.ReturnStmt:
   397  			foundInlineReturn = true
   398  			gotoStmt := &dst.BranchStmt{
   399  				Tok:   token.GOTO,
   400  				Label: dst.Clone(label).(*dst.Ident),
   401  			}
   402  			if returnModifier != nil {
   403  				cursor.Replace(returnModifier(n))
   404  				cursor.InsertAfter(gotoStmt)
   405  			} else {
   406  				cursor.Replace(gotoStmt)
   407  			}
   408  			return false
   409  		}
   410  		return true
   411  	}, nil).(*dst.BlockStmt)
   412  
   413  	if foundInlineReturn {
   414  		// Add the label at the end.
   415  		stmt.List = append(stmt.List,
   416  			&dst.LabeledStmt{
   417  				Label: label,
   418  				Stmt:  &dst.EmptyStmt{Implicit: true},
   419  			})
   420  	}
   421  	return stmt
   422  }
   423  
   424  // getInlinedFunc returns the corresponding FuncDecl for a CallExpr from the
   425  // map, using the CallExpr's name to look up the FuncDecl from templateFuncs.
   426  func getInlinedFunc(templateFuncs map[string]funcInfo, n *dst.CallExpr) *funcInfo {
   427  	ident, ok := n.Fun.(*dst.Ident)
   428  	if !ok {
   429  		return nil
   430  	}
   431  
   432  	info, ok := templateFuncs[ident.Name]
   433  	if !ok {
   434  		return nil
   435  	}
   436  	decl := info.decl
   437  	if decl.Type.Params.NumFields()+len(info.templateParams) != len(n.Args) {
   438  		panic(errors.Newf(
   439  			"%s expected %d arguments, found %d",
   440  			decl.Name, decl.Type.Params.NumFields(), len(n.Args)),
   441  		)
   442  	}
   443  	return &info
   444  }