github.com/pingcap/failpoint@v0.0.0-20240412033321-fd0796e60f86/failpoint-toolexec/main.go (about) 1 // Copyright 2024 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package main 16 17 import ( 18 "fmt" 19 "log" 20 "os" 21 "os/exec" 22 "path/filepath" 23 "runtime" 24 "strings" 25 26 "github.com/pingcap/errors" 27 "github.com/pingcap/failpoint/code" 28 "golang.org/x/mod/modfile" 29 ) 30 31 var logger = log.New(os.Stderr, "[failpoint-toolexec]", log.LstdFlags) 32 33 func main() { 34 if len(os.Args) < 2 { 35 return 36 } 37 goCmd, buildArgs := os.Args[1], os.Args[2:] 38 goCmdBase := filepath.Base(goCmd) 39 if runtime.GOOS == "windows" { 40 goCmdBase = strings.TrimSuffix(goCmd, ".exe") 41 } 42 43 if strings.ToLower(goCmdBase) == "compile" { 44 if err := injectFailpoint(&buildArgs); err != nil { 45 logger.Println("failed to inject failpoint", err) 46 } 47 } 48 49 cmd := exec.Command(goCmd, buildArgs...) 50 cmd.Stdout = os.Stdout 51 cmd.Stderr = os.Stderr 52 53 if err := cmd.Run(); err != nil { 54 logger.Println("failed to run command", err) 55 } 56 } 57 58 func injectFailpoint(argsP *[]string) error { 59 callersModule, err := findCallersModule() 60 if err != nil { 61 return err 62 } 63 64 // ref https://pkg.go.dev/cmd/compile#hdr-Command_Line 65 var module string 66 args := *argsP 67 for i, arg := range args { 68 if arg == "-p" { 69 if i+1 < len(args) { 70 module = args[i+1] 71 } 72 break 73 } 74 } 75 if !strings.HasPrefix(module, callersModule) && module != "main" { 76 return nil 77 } 78 79 fileIndices := make([]int, 0, len(args)) 80 for i, arg := range args { 81 // find the golang source files of the caller's package 82 if strings.HasSuffix(arg, ".go") && !inSDKOrMod(arg) { 83 fileIndices = append(fileIndices, i) 84 } 85 } 86 87 needExtraFile := false 88 writer := &code.Rewriter{} 89 writer.SetAllowNotChecked(true) 90 for _, idx := range fileIndices { 91 needExtraFile = injectFailpointForFile(writer, &args[idx], module) || needExtraFile 92 } 93 if needExtraFile { 94 newFile := filepath.Join(tmpFolder, module, "failpoint_toolexec_extra.go") 95 if err := writeExtraFile(newFile, writer.GetCurrentFile().Name.Name, module); err != nil { 96 return err 97 } 98 *argsP = append(args, newFile) 99 } 100 return nil 101 } 102 103 // ref https://github.com/golang/go/blob/bdd27c4debfb51fe42df0c0532c1c747777b7a32/src/cmd/go/internal/modload/init.go#L1511 104 func findCallersModule() (string, error) { 105 cwd, err := os.Getwd() 106 if err != nil { 107 return "", err 108 } 109 dir := filepath.Clean(cwd) 110 111 // Look for enclosing go.mod. 112 for { 113 goModPath := filepath.Join(dir, "go.mod") 114 if fi, err := os.Stat(goModPath); err == nil && !fi.IsDir() { 115 data, err := os.ReadFile(goModPath) 116 if err != nil { 117 return "", err 118 } 119 f, err := modfile.ParseLax(goModPath, data, nil) 120 if err != nil { 121 return "", err 122 } 123 return f.Module.Mod.Path, err 124 } 125 d := filepath.Dir(dir) 126 if d == dir { 127 break 128 } 129 dir = d 130 } 131 return "", errors.New("go.mod file not found") 132 } 133 134 var goModCache = os.Getenv("GOMODCACHE") 135 var goRoot = runtime.GOROOT() 136 137 func inSDKOrMod(path string) bool { 138 absPath, err := filepath.Abs(path) 139 if err != nil { 140 logger.Println("failed to get absolute path", err) 141 return false 142 } 143 144 if goModCache != "" && strings.HasPrefix(absPath, goModCache) { 145 return true 146 } 147 if strings.HasPrefix(absPath, goRoot) { 148 return true 149 } 150 return false 151 } 152 153 var tmpFolder = filepath.Join(os.TempDir(), "failpoint-toolexec") 154 155 func injectFailpointForFile(w *code.Rewriter, file *string, module string) bool { 156 newFile := filepath.Join(tmpFolder, module, filepath.Base(*file)) 157 newFileDir := filepath.Dir(newFile) 158 if err := os.MkdirAll(newFileDir, 0700); err != nil { 159 logger.Println("failed to create temp folder", err) 160 return false 161 } 162 f, err := os.OpenFile(newFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 163 if err != nil { 164 logger.Println("failed to open temp file", err) 165 return false 166 } 167 defer f.Close() 168 w.SetOutput(f) 169 170 if err := w.RewriteFile(*file); err != nil { 171 logger.Println("failed to rewrite file", err) 172 return false 173 } 174 if !w.GetRewritten() { 175 return false 176 } 177 *file = newFile 178 return true 179 } 180 181 func writeExtraFile(filePath, packageName, module string) error { 182 bindingContent := fmt.Sprintf(` 183 package %s 184 185 func %s(name string) string { 186 return "%s/" + name 187 } 188 `, packageName, code.ExtendPkgName, module) 189 return os.WriteFile(filePath, []byte(bindingContent), 0644) 190 }