github.com/pingcap/failpoint@v0.0.0-20240412033321-fd0796e60f86/failpoint-toolexec/main.go (about)

     1  // Copyright 2024 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
import (
	"fmt"
	"go/build"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"

	"github.com/pingcap/errors"
	"github.com/pingcap/failpoint/code"
	"golang.org/x/mod/modfile"
)
    30  
    31  var logger = log.New(os.Stderr, "[failpoint-toolexec]", log.LstdFlags)
    32  
    33  func main() {
    34  	if len(os.Args) < 2 {
    35  		return
    36  	}
    37  	goCmd, buildArgs := os.Args[1], os.Args[2:]
    38  	goCmdBase := filepath.Base(goCmd)
    39  	if runtime.GOOS == "windows" {
    40  		goCmdBase = strings.TrimSuffix(goCmd, ".exe")
    41  	}
    42  
    43  	if strings.ToLower(goCmdBase) == "compile" {
    44  		if err := injectFailpoint(&buildArgs); err != nil {
    45  			logger.Println("failed to inject failpoint", err)
    46  		}
    47  	}
    48  
    49  	cmd := exec.Command(goCmd, buildArgs...)
    50  	cmd.Stdout = os.Stdout
    51  	cmd.Stderr = os.Stderr
    52  
    53  	if err := cmd.Run(); err != nil {
    54  		logger.Println("failed to run command", err)
    55  	}
    56  }
    57  
    58  func injectFailpoint(argsP *[]string) error {
    59  	callersModule, err := findCallersModule()
    60  	if err != nil {
    61  		return err
    62  	}
    63  
    64  	// ref https://pkg.go.dev/cmd/compile#hdr-Command_Line
    65  	var module string
    66  	args := *argsP
    67  	for i, arg := range args {
    68  		if arg == "-p" {
    69  			if i+1 < len(args) {
    70  				module = args[i+1]
    71  			}
    72  			break
    73  		}
    74  	}
    75  	if !strings.HasPrefix(module, callersModule) && module != "main" {
    76  		return nil
    77  	}
    78  
    79  	fileIndices := make([]int, 0, len(args))
    80  	for i, arg := range args {
    81  		// find the golang source files of the caller's package
    82  		if strings.HasSuffix(arg, ".go") && !inSDKOrMod(arg) {
    83  			fileIndices = append(fileIndices, i)
    84  		}
    85  	}
    86  
    87  	needExtraFile := false
    88  	writer := &code.Rewriter{}
    89  	writer.SetAllowNotChecked(true)
    90  	for _, idx := range fileIndices {
    91  		needExtraFile = injectFailpointForFile(writer, &args[idx], module) || needExtraFile
    92  	}
    93  	if needExtraFile {
    94  		newFile := filepath.Join(tmpFolder, module, "failpoint_toolexec_extra.go")
    95  		if err := writeExtraFile(newFile, writer.GetCurrentFile().Name.Name, module); err != nil {
    96  			return err
    97  		}
    98  		*argsP = append(args, newFile)
    99  	}
   100  	return nil
   101  }
   102  
   103  // ref https://github.com/golang/go/blob/bdd27c4debfb51fe42df0c0532c1c747777b7a32/src/cmd/go/internal/modload/init.go#L1511
   104  func findCallersModule() (string, error) {
   105  	cwd, err := os.Getwd()
   106  	if err != nil {
   107  		return "", err
   108  	}
   109  	dir := filepath.Clean(cwd)
   110  
   111  	// Look for enclosing go.mod.
   112  	for {
   113  		goModPath := filepath.Join(dir, "go.mod")
   114  		if fi, err := os.Stat(goModPath); err == nil && !fi.IsDir() {
   115  			data, err := os.ReadFile(goModPath)
   116  			if err != nil {
   117  				return "", err
   118  			}
   119  			f, err := modfile.ParseLax(goModPath, data, nil)
   120  			if err != nil {
   121  				return "", err
   122  			}
   123  			return f.Module.Mod.Path, err
   124  		}
   125  		d := filepath.Dir(dir)
   126  		if d == dir {
   127  			break
   128  		}
   129  		dir = d
   130  	}
   131  	return "", errors.New("go.mod file not found")
   132  }
   133  
   134  var goModCache = os.Getenv("GOMODCACHE")
   135  var goRoot = runtime.GOROOT()
   136  
   137  func inSDKOrMod(path string) bool {
   138  	absPath, err := filepath.Abs(path)
   139  	if err != nil {
   140  		logger.Println("failed to get absolute path", err)
   141  		return false
   142  	}
   143  
   144  	if goModCache != "" && strings.HasPrefix(absPath, goModCache) {
   145  		return true
   146  	}
   147  	if strings.HasPrefix(absPath, goRoot) {
   148  		return true
   149  	}
   150  	return false
   151  }
   152  
   153  var tmpFolder = filepath.Join(os.TempDir(), "failpoint-toolexec")
   154  
   155  func injectFailpointForFile(w *code.Rewriter, file *string, module string) bool {
   156  	newFile := filepath.Join(tmpFolder, module, filepath.Base(*file))
   157  	newFileDir := filepath.Dir(newFile)
   158  	if err := os.MkdirAll(newFileDir, 0700); err != nil {
   159  		logger.Println("failed to create temp folder", err)
   160  		return false
   161  	}
   162  	f, err := os.OpenFile(newFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
   163  	if err != nil {
   164  		logger.Println("failed to open temp file", err)
   165  		return false
   166  	}
   167  	defer f.Close()
   168  	w.SetOutput(f)
   169  
   170  	if err := w.RewriteFile(*file); err != nil {
   171  		logger.Println("failed to rewrite file", err)
   172  		return false
   173  	}
   174  	if !w.GetRewritten() {
   175  		return false
   176  	}
   177  	*file = newFile
   178  	return true
   179  }
   180  
   181  func writeExtraFile(filePath, packageName, module string) error {
   182  	bindingContent := fmt.Sprintf(`
   183  package %s
   184  
   185  func %s(name string) string {
   186  	return "%s/" + name
   187  }
   188  `, packageName, code.ExtendPkgName, module)
   189  	return os.WriteFile(filePath, []byte(bindingContent), 0644)
   190  }