github.com/pingcap/failpoint@v0.0.0-20240412033321-fd0796e60f86/code/rewriter.go (about) 1 // Copyright 2019 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package code 16 17 import ( 18 "fmt" 19 "go/ast" 20 "go/format" 21 "go/parser" 22 "go/token" 23 "io" 24 "os" 25 "path/filepath" 26 "runtime/debug" 27 "strings" 28 ) 29 30 const ( 31 packagePath = "github.com/pingcap/failpoint" 32 packageName = "failpoint" 33 evalFunction = "Eval" 34 evalCtxFunction = "EvalContext" 35 ExtendPkgName = "_curpkg_" 36 // It is an indicator to indicate the label is converted from `failpoint.Label("...")` 37 // We use an illegal suffix to avoid conflict with the user's code 38 // So `failpoint.Label("label1")` will be converted to `label1-tmp-marker:` in expression 39 // rewrite and be converted to the legal form in label statement organization. 40 labelSuffix = "-tmp-marker" 41 ) 42 43 // Rewriter represents a rewriting tool for converting the failpoint marker functions to 44 // corresponding statements in Golang. It will traverse the specified path and filter 45 // out files which do not have failpoint injection sites, and rewrite the remain files. 46 type Rewriter struct { 47 rewriteDir string 48 currentPath string 49 currentFile *ast.File 50 currsetFset *token.FileSet 51 failpointName string 52 allowNotChecked bool 53 rewritten bool 54 55 output io.Writer 56 } 57 58 // NewRewriter returns a non-nil rewriter which is used to rewrite the specified path 59 func NewRewriter(path string) *Rewriter { 60 return &Rewriter{ 61 rewriteDir: path, 62 } 63 } 64 65 // SetOutput sets a writer and the rewrite results will write to the writer instead of generate a stash file 66 func (r *Rewriter) SetOutput(out io.Writer) { 67 r.output = out 68 } 69 70 // SetAllowNotChecked sets whether the rewriter allows the file which does not import failpoint package. 71 func (r *Rewriter) SetAllowNotChecked(b bool) { 72 r.allowNotChecked = b 73 } 74 75 // GetRewritten returns whether the rewriter has rewritten the file in a RewriteFile call. 76 func (r *Rewriter) GetRewritten() bool { 77 return r.rewritten 78 } 79 80 // GetCurrentFile returns the current file which is being rewritten 81 func (r *Rewriter) GetCurrentFile() *ast.File { 82 return r.currentFile 83 } 84 85 func (r *Rewriter) pos(pos token.Pos) string { 86 p := r.currsetFset.Position(pos) 87 return fmt.Sprintf("%s:%d", p.Filename, p.Line) 88 } 89 90 func (r *Rewriter) rewriteFuncLit(fn *ast.FuncLit) error { 91 return r.rewriteStmts(fn.Body.List) 92 } 93 94 func (r *Rewriter) rewriteAssign(v *ast.AssignStmt) error { 95 // fn1, fn2, fn3, ... := func(){...}, func(){...}, func(){...}, ... 96 // x, fn := 100, func() { 97 // failpoint.Marker(fpname, func() { 98 // ... 99 // }) 100 // } 101 // ch := <-func() chan interface{} { 102 // failpoint.Marker(fpname, func() { 103 // ... 104 // }) 105 // } 106 for _, v := range v.Rhs { 107 err := r.rewriteExpr(v) 108 if err != nil { 109 return err 110 } 111 } 112 return nil 113 } 114 115 // rewriteInitStmt rewrites non-nil initialization statement 116 func (r *Rewriter) rewriteInitStmt(v ast.Stmt) error { 117 var err error 118 switch stmt := v.(type) { 119 case *ast.ExprStmt: 120 err = r.rewriteExpr(stmt.X) 121 case *ast.AssignStmt: 122 err = r.rewriteAssign(stmt) 123 } 124 return err 125 } 126 127 func (r *Rewriter) rewriteIfStmt(v *ast.IfStmt) error { 128 // if a, b := func() {...}, func() int {...}(); cond {...} 129 // if func() {...}(); cond {...} 130 if v.Init != nil { 131 err := r.rewriteInitStmt(v.Init) 132 if err != nil { 133 return err 134 } 135 } 136 137 if err := r.rewriteExpr(v.Cond); err != nil { 138 return err 139 } 140 141 err := r.rewriteStmts(v.Body.List) 142 if err != nil { 143 return err 144 } 145 if v.Else != nil { 146 if elseIf, ok := v.Else.(*ast.IfStmt); ok { 147 return r.rewriteIfStmt(elseIf) 148 } 149 if els, ok := v.Else.(*ast.BlockStmt); ok { 150 return r.rewriteStmts(els.List) 151 } 152 } 153 return nil 154 } 155 156 func (r *Rewriter) rewriteExpr(expr ast.Expr) error { 157 if expr == nil { 158 return nil 159 } 160 161 switch ex := expr.(type) { 162 case *ast.BadExpr, 163 *ast.Ident, 164 *ast.Ellipsis, 165 *ast.BasicLit, 166 *ast.ArrayType, 167 *ast.StructType, 168 *ast.FuncType, 169 *ast.InterfaceType, 170 *ast.MapType, 171 *ast.ChanType: 172 // expressions that can not inject failpoint 173 case *ast.SelectorExpr: 174 return r.rewriteExpr(ex.X) 175 176 case *ast.IndexExpr: 177 // func()[]int {}()[func()int{}()] 178 if err := r.rewriteExpr(ex.X); err != nil { 179 return err 180 } 181 return r.rewriteExpr(ex.Index) 182 183 case *ast.SliceExpr: 184 // array[low:high:max] 185 // => func()[]int {}()[func()int{}():func()int{}():func()int{}()] 186 if err := r.rewriteExpr(ex.Low); err != nil { 187 return err 188 } 189 if err := r.rewriteExpr(ex.High); err != nil { 190 return err 191 } 192 if err := r.rewriteExpr(ex.Max); err != nil { 193 return err 194 } 195 return r.rewriteExpr(ex.X) 196 197 case *ast.FuncLit: 198 // return func(){...}, 199 return r.rewriteFuncLit(ex) 200 201 case *ast.CompositeLit: 202 // []int{func() int {...}()} 203 for _, elt := range ex.Elts { 204 if err := r.rewriteExpr(elt); err != nil { 205 return err 206 } 207 } 208 209 case *ast.CallExpr: 210 // return func() int {...}() 211 if fn, ok := ex.Fun.(*ast.FuncLit); ok { 212 err := r.rewriteFuncLit(fn) 213 if err != nil { 214 return err 215 } 216 } 217 218 // return fn(func() int{...}) 219 for _, arg := range ex.Args { 220 if fn, ok := arg.(*ast.FuncLit); ok { 221 err := r.rewriteFuncLit(fn) 222 if err != nil { 223 return err 224 } 225 } 226 } 227 228 case *ast.StarExpr: 229 // *func() *T{}() 230 return r.rewriteExpr(ex.X) 231 232 case *ast.UnaryExpr: 233 // !func() {...}() 234 return r.rewriteExpr(ex.X) 235 236 case *ast.BinaryExpr: 237 // a && func() bool {...} () 238 // func() bool {...} () && a 239 // func() bool {...} () && func() bool {...} () && a 240 // func() bool {...} () && a && func() bool {...} () && a 241 err := r.rewriteExpr(ex.X) 242 if err != nil { 243 return err 244 } 245 return r.rewriteExpr(ex.Y) 246 247 case *ast.ParenExpr: 248 // (func() {...}()) 249 return r.rewriteExpr(ex.X) 250 251 case *ast.TypeAssertExpr: 252 // (func() {...}()).(type) 253 return r.rewriteExpr(ex.X) 254 255 case *ast.KeyValueExpr: 256 // Key: (func() {...}()) 257 return r.rewriteExpr(ex.Value) 258 259 default: 260 fmt.Printf("unspport expression: %T\n", expr) 261 } 262 return nil 263 } 264 265 func (r *Rewriter) rewriteExprs(exprs []ast.Expr) error { 266 for _, expr := range exprs { 267 err := r.rewriteExpr(expr) 268 if err != nil { 269 return err 270 } 271 } 272 return nil 273 } 274 275 func (r *Rewriter) rewriteStmts(stmts []ast.Stmt) error { 276 for i, block := range stmts { 277 switch v := block.(type) { 278 case *ast.DeclStmt: 279 // var fn1, fn2, fn3, ... = func(){...}, func(){...}, func(){...}, ... 280 // var x, fn = 100, func() { 281 // failpoint.Marker(fpname, func() { 282 // ... 283 // }) 284 // } 285 specs := v.Decl.(*ast.GenDecl).Specs 286 for _, spec := range specs { 287 vs, ok := spec.(*ast.ValueSpec) 288 if !ok { 289 continue 290 } 291 for _, v := range vs.Values { 292 fn, ok := v.(*ast.FuncLit) 293 if !ok { 294 continue 295 } 296 err := r.rewriteStmts(fn.Body.List) 297 if err != nil { 298 return err 299 } 300 } 301 } 302 303 case *ast.ExprStmt: 304 // failpoint.Marker("failpoint.name", func(context.Context, *failpoint.Arg)) {...} 305 // failpoint.Break() 306 // failpoint.Break("label") 307 // failpoint.Continue() 308 // failpoint.Fallthrough() 309 // failpoint.Continue("label") 310 // failpoint.Goto("label") 311 // failpoint.Label("label") 312 call, ok := v.X.(*ast.CallExpr) 313 if !ok { 314 break 315 } 316 for _, arg := range call.Args { 317 err := r.rewriteExpr(arg) 318 if err != nil { 319 return err 320 } 321 } 322 323 switch expr := call.Fun.(type) { 324 case *ast.FuncLit: 325 err := r.rewriteFuncLit(expr) 326 if err != nil { 327 return err 328 } 329 case *ast.SelectorExpr: 330 packageName, ok := expr.X.(*ast.Ident) 331 if !ok || packageName.Name != r.failpointName { 332 break 333 } 334 exprRewriter, found := exprRewriters[expr.Sel.Name] 335 if !found { 336 break 337 } 338 rewritten, stmt, err := exprRewriter(r, call) 339 if err != nil { 340 return err 341 } 342 if !rewritten { 343 continue 344 } 345 346 if ifStmt, ok := stmt.(*ast.IfStmt); ok { 347 err := r.rewriteIfStmt(ifStmt) 348 if err != nil { 349 return err 350 } 351 } 352 353 stmts[i] = stmt 354 r.rewritten = true 355 } 356 357 case *ast.AssignStmt: 358 // x := (func() {...} ()) 359 err := r.rewriteAssign(v) 360 if err != nil { 361 return err 362 } 363 364 case *ast.GoStmt: 365 // go func() {...}() 366 // go func(fn) {...}(func(){...}) 367 err := r.rewriteExpr(v.Call) 368 if err != nil { 369 return err 370 } 371 372 case *ast.DeferStmt: 373 // defer func() {...}() 374 // defer func(fn) {...}(func(){...}) 375 err := r.rewriteExpr(v.Call) 376 if err != nil { 377 return err 378 } 379 380 case *ast.ReturnStmt: 381 // return func() {...}() 382 // return func(fn) {...}(func(){...}) 383 err := r.rewriteExprs(v.Results) 384 if err != nil { 385 return err 386 } 387 388 case *ast.BlockStmt: 389 // { 390 // func() {...}() 391 // } 392 err := r.rewriteStmts(v.List) 393 if err != nil { 394 return err 395 } 396 397 case *ast.IfStmt: 398 // if func() {...}() {...} 399 err := r.rewriteIfStmt(v) 400 if err != nil { 401 return err 402 } 403 404 case *ast.CaseClause: 405 // case func() int {...}() > 100 && func () bool {...}() 406 if len(v.List) > 0 { 407 err := r.rewriteExprs(v.List) 408 if err != nil { 409 return err 410 } 411 } 412 // case func() int {...}() > 100 && func () bool {...}(): 413 // fn := func(){...} 414 // fn() 415 if len(v.Body) > 0 { 416 err := r.rewriteStmts(v.Body) 417 if err != nil { 418 return err 419 } 420 } 421 422 case *ast.SwitchStmt: 423 // switch x := func() {...}(); {...} 424 if v.Init != nil { 425 err := r.rewriteAssign(v.Init.(*ast.AssignStmt)) 426 if err != nil { 427 return err 428 } 429 } 430 431 // switch (func() {...}()) {...} 432 if err := r.rewriteExpr(v.Tag); err != nil { 433 return err 434 } 435 436 // switch x { 437 // case 1: 438 // func() {...}() 439 // } 440 err := r.rewriteStmts(v.Body.List) 441 if err != nil { 442 return err 443 } 444 445 case *ast.CommClause: 446 // select { 447 // case ch := <-func() chan bool {...}(): 448 // case <- fromCh: 449 // case toCh <- x: 450 // case <- func() chan bool {...}(): 451 // default: 452 // } 453 if v.Comm != nil { 454 if assign, ok := v.Comm.(*ast.AssignStmt); ok { 455 err := r.rewriteAssign(assign) 456 if err != nil { 457 return err 458 } 459 } 460 if expr, ok := v.Comm.(*ast.ExprStmt); ok { 461 err := r.rewriteExpr(expr.X) 462 if err != nil { 463 return err 464 } 465 } 466 } 467 err := r.rewriteStmts(v.Body) 468 if err != nil { 469 return err 470 } 471 472 case *ast.SelectStmt: 473 if len(v.Body.List) < 1 { 474 continue 475 } 476 err := r.rewriteStmts(v.Body.List) 477 if err != nil { 478 return err 479 } 480 481 case *ast.ForStmt: 482 // for i := func() int {...}(); i < func() int {...}(); i += func() int {...}() {...} 483 // for iter.Begin(); !iter.End(); iter.Next() {...} 484 if v.Init != nil { 485 err := r.rewriteInitStmt(v.Init) 486 if err != nil { 487 return err 488 } 489 } 490 if v.Cond != nil { 491 err := r.rewriteExpr(v.Cond) 492 if err != nil { 493 return err 494 } 495 } 496 if v.Post != nil { 497 assign, ok := v.Post.(*ast.AssignStmt) 498 if ok { 499 err := r.rewriteAssign(assign) 500 if err != nil { 501 return err 502 } 503 } 504 } 505 err := r.rewriteStmts(v.Body.List) 506 if err != nil { 507 return err 508 } 509 510 case *ast.RangeStmt: 511 // for i := range func() {...}() {...} 512 if err := r.rewriteExpr(v.X); err != nil { 513 return err 514 } 515 err := r.rewriteStmts(v.Body.List) 516 if err != nil { 517 return err 518 } 519 520 case *ast.TypeSwitchStmt: 521 if v.Assign != nil { 522 // switch x := (func () {...}()).(type) {...} 523 if assign, ok := v.Assign.(*ast.AssignStmt); ok { 524 err := r.rewriteAssign(assign) 525 if err != nil { 526 return err 527 } 528 } 529 // switch (func () {...}()).(type) {...} 530 if expr, ok := v.Assign.(*ast.ExprStmt); ok { 531 err := r.rewriteExpr(expr.X) 532 if err != nil { 533 return err 534 } 535 } 536 } 537 err := r.rewriteStmts(v.Body.List) 538 if err != nil { 539 return err 540 } 541 542 case *ast.SendStmt: 543 // ch <- func () {...}() 544 err := r.rewriteExprs([]ast.Expr{v.Chan, v.Value}) 545 if err != nil { 546 return err 547 } 548 549 case *ast.LabeledStmt: 550 // Label: 551 // func () {...}() 552 stmts := []ast.Stmt{v.Stmt} 553 err := r.rewriteStmts(stmts) 554 if err != nil { 555 return err 556 } 557 v.Stmt = stmts[0] 558 559 case *ast.IncDecStmt: 560 // func() *FooType {...}().Field++ 561 // func() *FooType {...}().Field-- 562 err := r.rewriteExpr(v.X) 563 if err != nil { 564 return err 565 } 566 567 case *ast.BranchStmt: 568 // ignore keyword token (BREAK, CONTINUE, GOTO, FALLTHROUGH) 569 570 default: 571 fmt.Printf("unsupported statement: %T in %s\n", v, r.pos(v.Pos())) 572 } 573 } 574 575 // Label statement must ahead of for loop 576 for i := 0; i < len(stmts); i++ { 577 stmt := stmts[i] 578 if label, ok := stmt.(*ast.LabeledStmt); ok && strings.HasSuffix(label.Label.Name, labelSuffix) { 579 label.Label.Name = label.Label.Name[:len(label.Label.Name)-len(labelSuffix)] 580 label.Stmt = stmts[i+1] 581 stmts[i+1] = &ast.EmptyStmt{} 582 } 583 } 584 return nil 585 } 586 587 func (r *Rewriter) rewriteFuncDecl(fn *ast.FuncDecl) error { 588 if fn.Body == nil { 589 return nil 590 } 591 return r.rewriteStmts(fn.Body.List) 592 } 593 594 // RewriteFile rewrites a single file 595 func (r *Rewriter) RewriteFile(path string) (err error) { 596 defer func() { 597 if e := recover(); e != nil { 598 err = fmt.Errorf("%s %v\n%s", r.currentPath, e, debug.Stack()) 599 } 600 }() 601 fset := token.NewFileSet() 602 file, err := parser.ParseFile(fset, path, nil, parser.ParseComments) 603 if err != nil { 604 return err 605 } 606 if len(file.Decls) < 1 { 607 return nil 608 } 609 r.currentPath = path 610 r.currentFile = file 611 r.currsetFset = fset 612 r.rewritten = false 613 614 var failpointImport *ast.ImportSpec 615 for _, imp := range file.Imports { 616 if strings.Trim(imp.Path.Value, "`\"") == packagePath { 617 failpointImport = imp 618 break 619 } 620 } 621 if failpointImport == nil { 622 if r.allowNotChecked { 623 return nil 624 } 625 panic("import path should be check before rewrite") 626 } 627 if failpointImport.Name != nil { 628 r.failpointName = failpointImport.Name.Name 629 } else { 630 r.failpointName = packageName 631 } 632 633 for _, decl := range file.Decls { 634 fn, ok := decl.(*ast.FuncDecl) 635 if !ok { 636 continue 637 } 638 if err := r.rewriteFuncDecl(fn); err != nil { 639 return err 640 } 641 } 642 643 if !r.rewritten { 644 return nil 645 } 646 647 if r.output != nil { 648 return format.Node(r.output, fset, file) 649 } 650 651 // Generate binding code 652 found, err := isBindingFileExists(path) 653 if err != nil { 654 return err 655 } 656 if !found { 657 err := writeBindingFile(path, file.Name.Name) 658 if err != nil { 659 return err 660 } 661 } 662 663 // Backup origin file and replace content 664 targetPath := path + failpointStashFileSuffix 665 if err := os.Rename(path, targetPath); err != nil { 666 return err 667 } 668 669 newFile, err := os.OpenFile(path, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644) 670 if err != nil { 671 return err 672 } 673 defer newFile.Close() 674 return format.Node(newFile, fset, file) 675 } 676 677 // Rewrite does the rewrite action for specified path. It contains the main steps: 678 // 679 // 1. Filter out failpoint binding files and files that have no suffix `.go` 680 // 2. Filter out files which have not imported failpoint package (implying no failpoints) 681 // 3. Parse file to `ast.File` and rewrite the AST 682 // 4. Create failpoint binding file (which contains `_curpkg_` function) if it does not exist 683 // 5. Rename original file to `original-file-name + __failpoint_stash__` 684 // 6. Replace original file content base on the new AST 685 func (r *Rewriter) Rewrite() error { 686 var files []string 687 err := filepath.Walk(r.rewriteDir, func(path string, info os.FileInfo, err error) error { 688 if err != nil { 689 return err 690 } 691 if info.IsDir() { 692 return nil 693 } 694 if !strings.HasSuffix(path, ".go") { 695 return nil 696 } 697 if strings.HasSuffix(path, failpointBindingFileName) { 698 return nil 699 } 700 // Will rewrite a file only if the file has imported "github.com/pingcap/failpoint" 701 fset := token.NewFileSet() 702 file, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly) 703 if err != nil { 704 return err 705 } 706 if len(file.Imports) < 1 { 707 return nil 708 } 709 for _, imp := range file.Imports { 710 // import path maybe in the form of: 711 // 712 // 1. normal import 713 // - "github.com/pingcap/failpoint" 714 // - `github.com/pingcap/failpoint` 715 // 2. ignore import 716 // - _ "github.com/pingcap/failpoint" 717 // - _ `github.com/pingcap/failpoint` 718 // 3. alias import 719 // - alias "github.com/pingcap/failpoint" 720 // - alias `github.com/pingcap/failpoint` 721 // we should trim '"' or '`' before compare it. 722 if strings.Trim(imp.Path.Value, "`\"") == packagePath { 723 files = append(files, path) 724 break 725 } 726 } 727 return nil 728 }) 729 if err != nil { 730 return err 731 } 732 733 for _, file := range files { 734 err := r.RewriteFile(file) 735 if err != nil { 736 return err 737 } 738 } 739 return nil 740 }