github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/colexec/execgen/inline.go (about) 1 // Copyright 2020 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package execgen 12 13 import ( 14 "fmt" 15 "go/token" 16 17 "github.com/cockroachdb/errors" 18 "github.com/dave/dst" 19 "github.com/dave/dst/dstutil" 20 ) 21 22 // inlineFuncs takes an input file's contents and inlines all functions 23 // annotated with // execgen:inline into their callsites via AST manipulation. 24 func inlineFuncs(f *dst.File) { 25 // First, run over the input file, searching for functions that are annotated 26 // with execgen:inline. 27 inlineFuncMap := extractInlineFuncDecls(f) 28 29 // Do a second pass over the AST, this time replacing calls to the inlined 30 // functions with the inlined function itself. 31 inlineFunc(inlineFuncMap, f) 32 } 33 34 func inlineFunc(inlineFuncMap map[string]funcInfo, n dst.Node) dst.Node { 35 var funcIdx int 36 return dstutil.Apply(n, func(cursor *dstutil.Cursor) bool { 37 cursor.Index() 38 n := cursor.Node() 39 // There are two cases. AssignStmt, which are like: 40 // a = foo() 41 // and ExprStmt, which are simply: 42 // foo() 43 // AssignStmts need to do extra work for inlining, because we have to 44 // simulate producing return values. 45 switch n := n.(type) { 46 case *dst.AssignStmt: 47 // Search for assignment function call: 48 // a = foo() 49 callExpr, ok := n.Rhs[0].(*dst.CallExpr) 50 if !ok { 51 return true 52 } 53 funcInfo := getInlinedFunc(inlineFuncMap, callExpr) 54 if funcInfo == nil { 55 return true 56 } 57 // We want to recurse here because funcInfo itself might have calls to 58 // inlined functions. 59 funcInfo.decl = inlineFunc(inlineFuncMap, funcInfo.decl).(*dst.FuncDecl) 60 61 if len(n.Rhs) > 1 { 62 panic("can't do template replacement with more than a single RHS to a CallExpr") 63 } 64 65 if n.Tok == token.DEFINE { 66 // We need to put a variable declaration for the new defined variables 67 // in the parent scope. 68 newDefinitions := &dst.GenDecl{ 69 Tok: token.VAR, 70 Specs: make([]dst.Spec, len(n.Lhs)), 71 } 72 73 for i, e := range n.Lhs { 74 // If we had foo, bar := thingToInline(), we'd get 75 // var ( 76 // foo int 77 // bar int 78 // ) 79 newDefinitions.Specs[i] = &dst.ValueSpec{ 80 Names: []*dst.Ident{dst.NewIdent(e.(*dst.Ident).Name)}, 81 Type: dst.Clone(funcInfo.decl.Type.Results.List[i].Type).(dst.Expr), 82 } 83 } 84 85 cursor.InsertBefore(&dst.DeclStmt{Decl: newDefinitions}) 86 } 87 88 // Now we've got a callExpr. We need to inline the function call, and 89 // convert the result into the assignment variable. 90 91 decl := funcInfo.decl 92 // Produce declarations for each return value of the function to inline. 93 retValDeclStmt, retValNames := extractReturnValues(decl) 94 // inlinedStatements is a BlockStmt (a set of statements within curly 95 // braces) that contains the entirety of the statements that result from 96 // inlining the call. We make this a BlockStmt to avoid issues with 97 // variable shadowing. 98 // The first thing that goes in the BlockStmt is the ret val declarations. 99 // When we're done, the BlockStmt for a statement 100 // a, b = foo(x, y) 101 // where foo was defined as 102 // func foo(b string, c string) { ... } 103 // will look like: 104 // { 105 // var ( 106 // __retval_0 bool 107 // __retval_1 int 108 // ) 109 // ... 110 // { 111 // b := x 112 // c := y 113 // ... the contents of func foo() except its return ... 114 // { 115 // // If foo() had `return true, j`, we'll generate the code: 116 // __retval_0 = true 117 // __retval_1 = j 118 // } 119 // } 120 // a = __retval_0 121 // b = __retval_1 122 // } 123 inlinedStatements := &dst.BlockStmt{ 124 List: []dst.Stmt{retValDeclStmt}, 125 Decs: dst.BlockStmtDecorations{ 126 NodeDecs: n.Decs.NodeDecs, 127 }, 128 } 129 body := dst.Clone(decl.Body).(*dst.BlockStmt) 130 131 // Replace return statements with assignments to the return values. 132 // Make a copy of the function to inline, and walk through it, replacing 133 // return statements at the end of the body with assignments to the return 134 // value declarations we made first. 135 body = replaceReturnStatements(decl.Name.Name, funcIdx, body, func(stmt *dst.ReturnStmt) dst.Stmt { 136 returnAssignmentSpecs := make([]dst.Stmt, len(retValNames)) 137 for i := range retValNames { 138 returnAssignmentSpecs[i] = &dst.AssignStmt{ 139 Lhs: []dst.Expr{dst.NewIdent(retValNames[i])}, 140 Tok: token.ASSIGN, 141 Rhs: []dst.Expr{stmt.Results[i]}, 142 } 143 } 144 // Replace the return with the new assignments. 145 return &dst.BlockStmt{List: returnAssignmentSpecs} 146 }) 147 // Reassign input parameters to formal parameters. 148 reassignmentStmt := getFormalParamReassignments(decl, callExpr) 149 inlinedStatements.List = append(inlinedStatements.List, &dst.BlockStmt{ 150 List: append([]dst.Stmt{reassignmentStmt}, body.List...), 151 }) 152 // Assign mangled return values to the original assignment variables. 153 newAssignment := dst.Clone(n).(*dst.AssignStmt) 154 newAssignment.Tok = token.ASSIGN 155 newAssignment.Rhs = make([]dst.Expr, len(retValNames)) 156 for i := range retValNames { 157 newAssignment.Rhs[i] = dst.NewIdent(retValNames[i]) 158 } 159 inlinedStatements.List = append(inlinedStatements.List, newAssignment) 160 cursor.Replace(inlinedStatements) 161 162 case *dst.ExprStmt: 163 // Search for raw function call: 164 // foo() 165 callExpr, ok := n.X.(*dst.CallExpr) 166 if !ok { 167 return true 168 } 169 funcInfo := getInlinedFunc(inlineFuncMap, callExpr) 170 if funcInfo == nil { 171 return true 172 } 173 // We want to recurse here because funcInfo itself might have calls to 174 // inlined functions. 175 funcInfo.decl = inlineFunc(inlineFuncMap, funcInfo.decl).(*dst.FuncDecl) 176 decl := funcInfo.decl 177 178 reassignments := getFormalParamReassignments(decl, callExpr) 179 180 // This case is simpler than the AssignStmt case. It's identical, except 181 // there is no mangled return value name block, nor re-assignment to 182 // the mangled returns after the inlined function. 183 funcBlock := &dst.BlockStmt{ 184 List: []dst.Stmt{reassignments}, 185 Decs: dst.BlockStmtDecorations{ 186 NodeDecs: n.Decs.NodeDecs, 187 }, 188 } 189 body := dst.Clone(decl.Body).(*dst.BlockStmt) 190 191 // Remove return values if there are any, since we're ignoring returns 192 // as a raw function call. 193 body = replaceReturnStatements(decl.Name.Name, funcIdx, body, nil) 194 // Add the inlined function body to the block. 195 funcBlock.List = append(funcBlock.List, body.List...) 196 197 cursor.Replace(funcBlock) 198 default: 199 return true 200 } 201 funcIdx++ 202 return true 203 }, nil) 204 } 205 206 // extractInlineFuncDecls searches the input file for functions that are 207 // annotated with execgen:inline, extracts them into templateFuncMap, and 208 // deletes them from the AST. 209 func extractInlineFuncDecls(f *dst.File) map[string]funcInfo { 210 ret := make(map[string]funcInfo) 211 dstutil.Apply(f, func(cursor *dstutil.Cursor) bool { 212 n := cursor.Node() 213 switch n := n.(type) { 214 case *dst.FuncDecl: 215 var mustInline bool 216 for _, dec := range n.Decorations().Start.All() { 217 if dec == "// execgen:inline" { 218 mustInline = true 219 } 220 } 221 if !mustInline { 222 // Nothing to do, but recurse further. 223 return true 224 } 225 for _, p := range n.Type.Params.List { 226 if len(p.Names) > 1 { 227 // If we have a definition like this: 228 // func a (a, b int) int 229 // We're just giving up for now out of complete laziness. 230 panic("can't currently deal with multiple names per type in decls") 231 } 232 } 233 234 var info funcInfo 235 info.decl = dst.Clone(n).(*dst.FuncDecl) 236 // Store the function in a map. 237 ret[n.Name.Name] = info 238 239 // Replace the function textually with a fake constant, such as: 240 // `const _ = "inlined_blahFunc"`. We do this instead 241 // of completely deleting it to prevent "important comments" above the 242 // function to be deleted, such as template comments like {{end}}. This 243 // is kind of a quirk of the way the comments are parsed, but nonetheless 244 // this is an easy fix so we'll leave it for now. 245 cursor.Replace(&dst.GenDecl{ 246 Tok: token.CONST, 247 Specs: []dst.Spec{ 248 &dst.ValueSpec{ 249 Names: []*dst.Ident{dst.NewIdent("_")}, 250 Values: []dst.Expr{ 251 &dst.BasicLit{ 252 Kind: token.STRING, 253 Value: fmt.Sprintf(`"inlined_%s"`, n.Name.Name), 254 }, 255 }, 256 }, 257 }, 258 Decs: dst.GenDeclDecorations{ 259 NodeDecs: n.Decs.NodeDecs, 260 }, 261 }) 262 return false 263 } 264 return true 265 }, nil) 266 return ret 267 } 268 269 // extractReturnValues generates return value variables. It will produce one 270 // statement per return value of the input FuncDecl. For example, for 271 // a FuncDecl that returns two boolean arguments, lastVal and lastValNull, 272 // two statements will be returned: 273 // 274 // var __retval_lastVal bool 275 // var __retval_lastValNull bool 276 // 277 // The second return is a slice of the names of each of the mangled return 278 // declarations, in this example, __retval_lastVal and __retval_lastValNull. 279 func extractReturnValues(decl *dst.FuncDecl) (retValDeclStmt dst.Stmt, retValNames []string) { 280 if decl.Type.Results == nil { 281 return &dst.EmptyStmt{}, nil 282 } 283 results := decl.Type.Results.List 284 retValNames = make([]string, len(results)) 285 specs := make([]dst.Spec, len(results)) 286 for i, result := range results { 287 var retvalName string 288 // Make a mangled name. 289 if len(result.Names) == 0 { 290 retvalName = fmt.Sprintf("__retval_%d", i) 291 } else { 292 retvalName = fmt.Sprintf("__retval_%s", result.Names[0]) 293 } 294 retValNames[i] = retvalName 295 specs[i] = &dst.ValueSpec{ 296 Names: []*dst.Ident{dst.NewIdent(retvalName)}, 297 Type: dst.Clone(result.Type).(dst.Expr), 298 } 299 } 300 return &dst.DeclStmt{ 301 Decl: &dst.GenDecl{ 302 Tok: token.VAR, 303 Specs: specs, 304 }, 305 }, retValNames 306 } 307 308 // getFormalParamReassignments creates a new DEFINE (:=) statement per parameter 309 // to a FuncDecl, which makes a fresh variable with the same name as the formal 310 // parameter name and assigns it to the corresponding name in the CallExpr. 311 // 312 // For example, given a FuncDecl: 313 // 314 // func foo(a int, b string) { ... } 315 // 316 // and a CallExpr 317 // 318 // foo(x, y) 319 // 320 // we'll return the statement: 321 // 322 // var ( 323 // 324 // a int = x 325 // b string = y 326 // 327 // ) 328 // 329 // In the case where the formal parameter name is the same as the input 330 // parameter name, no extra assignment is created. 331 func getFormalParamReassignments(decl *dst.FuncDecl, callExpr *dst.CallExpr) dst.Stmt { 332 formalParams := decl.Type.Params.List 333 reassignmentSpecs := make([]dst.Spec, 0, len(formalParams)) 334 for i, formalParam := range formalParams { 335 if inputIdent, ok := callExpr.Args[i].(*dst.Ident); ok && inputIdent.Name == formalParam.Names[0].Name { 336 continue 337 } 338 reassignmentSpecs = append(reassignmentSpecs, &dst.ValueSpec{ 339 Names: []*dst.Ident{dst.NewIdent(formalParam.Names[0].Name)}, 340 Type: dst.Clone(formalParam.Type).(dst.Expr), 341 Values: []dst.Expr{callExpr.Args[i]}, 342 }) 343 } 344 if len(reassignmentSpecs) == 0 { 345 return &dst.EmptyStmt{} 346 } 347 return &dst.DeclStmt{ 348 Decl: &dst.GenDecl{ 349 Tok: token.VAR, 350 Specs: reassignmentSpecs, 351 }, 352 } 353 } 354 355 // replaceReturnStatements edits the input BlockStmt, from the function funcName, 356 // replacing ReturnStmts at the end of the BlockStmts with the results of 357 // applying returnEditor on the ReturnStmt or deleting them if the modifier is 358 // nil. 359 // It will panic if any return statements are not in the final position of the 360 // input block. 361 func replaceReturnStatements( 362 funcName string, funcIdx int, stmt *dst.BlockStmt, returnModifier func(*dst.ReturnStmt) dst.Stmt, 363 ) *dst.BlockStmt { 364 if len(stmt.List) == 0 { 365 return stmt 366 } 367 // Insert an explicit return at the end if there isn't one. 368 // We'll need to edit this later to make early returns work properly. 369 lastStmt := stmt.List[len(stmt.List)-1] 370 if _, ok := lastStmt.(*dst.ReturnStmt); !ok { 371 ret := &dst.ReturnStmt{} 372 stmt.List = append(stmt.List, ret) 373 lastStmt = ret 374 } 375 retStmt := lastStmt.(*dst.ReturnStmt) 376 if returnModifier == nil { 377 stmt.List[len(stmt.List)-1] = &dst.EmptyStmt{} 378 } else { 379 stmt.List[len(stmt.List)-1] = returnModifier(retStmt) 380 } 381 382 label := dst.NewIdent(fmt.Sprintf("%s_return_%d", funcName, funcIdx)) 383 384 // Find returns that weren't at the end of the function and replace them with 385 // labeled gotos. 386 var foundInlineReturn bool 387 stmt = dstutil.Apply(stmt, func(cursor *dstutil.Cursor) bool { 388 n := cursor.Node() 389 switch n := n.(type) { 390 case *dst.FuncLit: 391 // A FuncLit is a function literal, like: 392 // x := func() int { return 3 } 393 // We don't recurse into function literals since the return statements 394 // they contain aren't relevant to the inliner. 395 return false 396 case *dst.ReturnStmt: 397 foundInlineReturn = true 398 gotoStmt := &dst.BranchStmt{ 399 Tok: token.GOTO, 400 Label: dst.Clone(label).(*dst.Ident), 401 } 402 if returnModifier != nil { 403 cursor.Replace(returnModifier(n)) 404 cursor.InsertAfter(gotoStmt) 405 } else { 406 cursor.Replace(gotoStmt) 407 } 408 return false 409 } 410 return true 411 }, nil).(*dst.BlockStmt) 412 413 if foundInlineReturn { 414 // Add the label at the end. 415 stmt.List = append(stmt.List, 416 &dst.LabeledStmt{ 417 Label: label, 418 Stmt: &dst.EmptyStmt{Implicit: true}, 419 }) 420 } 421 return stmt 422 } 423 424 // getInlinedFunc returns the corresponding FuncDecl for a CallExpr from the 425 // map, using the CallExpr's name to look up the FuncDecl from templateFuncs. 426 func getInlinedFunc(templateFuncs map[string]funcInfo, n *dst.CallExpr) *funcInfo { 427 ident, ok := n.Fun.(*dst.Ident) 428 if !ok { 429 return nil 430 } 431 432 info, ok := templateFuncs[ident.Name] 433 if !ok { 434 return nil 435 } 436 decl := info.decl 437 if decl.Type.Params.NumFields()+len(info.templateParams) != len(n.Args) { 438 panic(errors.Newf( 439 "%s expected %d arguments, found %d", 440 decl.Name, decl.Type.Params.NumFields(), len(n.Args)), 441 ) 442 } 443 return &info 444 }