github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/citogo/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"flag"
     8  	"fmt"
     9  	"os"
    10  	"os/exec"
    11  	"path/filepath"
    12  	"strings"
    13  	"time"
    14  
    15  	"golang.org/x/sync/errgroup"
    16  
    17  	"github.com/keybase/client/go/citogo/types"
    18  )
    19  
    20  type opts struct {
    21  	Flakes               int
    22  	Fails                int
    23  	Prefix               string
    24  	S3Bucket             string
    25  	ReportLambdaFunction string
    26  	DirBasename          string
    27  	BuildID              string
    28  	Branch               string
    29  	Parallel             int
    30  	Preserve             bool
    31  	BuildURL             string
    32  	NoCompile            bool
    33  	TestBinary           string
    34  	Timeout              string
    35  	Pause                time.Duration
    36  }
    37  
    38  func logError(f string, args ...interface{}) {
    39  	s := fmt.Sprintf(f, args...)
    40  	if s[len(s)-1] != '\n' {
    41  		s += "\n"
    42  	}
    43  	fmt.Fprintf(os.Stderr, "%s", s)
    44  }
    45  
    46  type runner struct {
    47  	opts   opts
    48  	flakes []string
    49  	fails  []string
    50  	tests  []string
    51  }
    52  
    53  func convertBreakingChars(s string) string {
    54  	// replace either the unix or the DOS directory marker
    55  	// with an underscore, so as not to break the directory
    56  	// structure of where we are storing the log
    57  	s = strings.ReplaceAll(s, "/", "_")
    58  	s = strings.ReplaceAll(s, "\\", "_")
    59  	s = strings.ReplaceAll(s, "-", "_")
    60  	return s
    61  }
    62  
    63  func (r *runner) parseArgs() (err error) {
    64  	flag.IntVar(&r.opts.Flakes, "flakes", 3, "number of allowed flakes")
    65  	flag.IntVar(&r.opts.Fails, "fails", -1, "number of fails allowed before quitting")
    66  	flag.IntVar(&r.opts.Parallel, "parallel", 1, "number of tests to run in parallel")
    67  	flag.StringVar(&r.opts.Prefix, "prefix", "", "test set prefix")
    68  	flag.StringVar(&r.opts.S3Bucket, "s3bucket", "", "AWS S3 bucket to write failures to")
    69  	flag.StringVar(&r.opts.ReportLambdaFunction, "report-lambda-function", "", "lambda function to report results to")
    70  	flag.StringVar(&r.opts.BuildID, "build-id", "", "build ID of the current build")
    71  	flag.StringVar(&r.opts.Branch, "branch", "", "the branch of the current build")
    72  	flag.BoolVar(&r.opts.Preserve, "preserve", false, "preserve test binary after done")
    73  	flag.StringVar(&r.opts.BuildURL, "build-url", "", "URL for this build (in CI mainly)")
    74  	flag.BoolVar(&r.opts.NoCompile, "no-compile", false, "specify flag if you've pre-compiled the test")
    75  	flag.StringVar(&r.opts.TestBinary, "test-binary", "", "specify the test binary to run")
    76  	flag.StringVar(&r.opts.Timeout, "timeout", "60s", "timeout (in seconds) for any one individual test")
    77  	flag.DurationVar(&r.opts.Pause, "pause", 0, "pause duration between each test (default 0)")
    78  	flag.Parse()
    79  	var d string
    80  	d, err = os.Getwd()
    81  	if err != nil {
    82  		return err
    83  	}
    84  	r.opts.DirBasename = filepath.Base(d)
    85  	return nil
    86  }
    87  
    88  func (r *runner) compile() error {
    89  	if r.opts.NoCompile {
    90  		return nil
    91  	}
    92  	fmt.Printf("CMPL: %s\n", r.testerName())
    93  	cmd := exec.Command("go", "test", "-c")
    94  	cmd.Stdout = os.Stdout
    95  	cmd.Stderr = os.Stderr
    96  	return cmd.Run()
    97  }
    98  
    99  func filter(v []string) []string {
   100  	var ret []string
   101  	for _, s := range v {
   102  		if s != "" {
   103  			ret = append(ret, s)
   104  		}
   105  	}
   106  	return ret
   107  }
   108  
   109  func (r *runner) testerName() string {
   110  	if r.opts.TestBinary != "" {
   111  		return r.opts.TestBinary
   112  	}
   113  	return fmt.Sprintf(".%c%s.test", os.PathSeparator, r.opts.DirBasename)
   114  }
   115  
   116  func (r *runner) listTests() error {
   117  	cmd := exec.Command(r.testerName(), "-test.list", ".")
   118  	var out bytes.Buffer
   119  	cmd.Stdout = &out
   120  	err := cmd.Run()
   121  	if err != nil {
   122  		return err
   123  	}
   124  	r.tests = filter(strings.Split(out.String(), "\n"))
   125  	return nil
   126  }
   127  
   128  func (r *runner) flushTestLogs(test string, log bytes.Buffer) (string, error) {
   129  	logName := fmt.Sprintf("citogo-%s-%s-%s-%s", convertBreakingChars(r.opts.Branch),
   130  		convertBreakingChars(r.opts.BuildID), convertBreakingChars(r.opts.Prefix), test)
   131  	if r.opts.S3Bucket != "" {
   132  		return r.flushLogsToS3(logName, log)
   133  	}
   134  	return r.flushTestLogsToTemp(logName, log)
   135  }
   136  
   137  func (r *runner) flushLogsToS3(logName string, log bytes.Buffer) (string, error) {
   138  	return s3put(&log, r.opts.S3Bucket, logName)
   139  }
   140  
   141  func (r *runner) flushTestLogsToTemp(logName string, log bytes.Buffer) (string, error) {
   142  	tmpfile, err := os.CreateTemp("", fmt.Sprintf("%s-", logName))
   143  	if err != nil {
   144  		return "", err
   145  	}
   146  	_, err = tmpfile.Write(log.Bytes())
   147  	if err != nil {
   148  		return "", err
   149  	}
   150  	err = tmpfile.Close()
   151  	if err != nil {
   152  		return "", err
   153  	}
   154  	return fmt.Sprintf("see log: %s", tmpfile.Name()), nil
   155  }
   156  
   157  func (r *runner) report(result types.TestResult) {
   158  	if r.opts.ReportLambdaFunction == "" {
   159  		return
   160  	}
   161  
   162  	b, err := json.Marshal(result)
   163  	if err != nil {
   164  		logError("error marshalling result: %s", err.Error())
   165  		return
   166  	}
   167  
   168  	err = lambdaInvoke(r.opts.ReportLambdaFunction, b)
   169  	if err != nil {
   170  		logError("error reporting flake: %s", err.Error())
   171  	}
   172  }
   173  
   174  func (r *runner) newTestResult(outcome types.Outcome, testName string, where string) types.TestResult {
   175  	return types.TestResult{
   176  		Outcome:  outcome,
   177  		TestName: testName,
   178  		Where:    where,
   179  		Branch:   r.opts.Branch,
   180  		BuildID:  r.opts.BuildID,
   181  		Prefix:   r.opts.Prefix,
   182  		BuildURL: r.opts.BuildURL,
   183  	}
   184  }
   185  
   186  func (r *runner) runTest(test string) error {
   187  	canRerun := len(r.flakes) < r.opts.Flakes
   188  	outcome, where, err := r.runTestOnce(test, false /* isRerun */, canRerun)
   189  	if err != nil {
   190  		return err
   191  	}
   192  	if outcome == types.OutcomeSuccess {
   193  		return nil
   194  	}
   195  	if len(r.flakes) >= r.opts.Flakes {
   196  		return errTestFailed
   197  	}
   198  	outcome2, _, err2 := r.runTestOnce(test, true /* isRerun */, false /* canRerun */)
   199  	if err2 != nil {
   200  		return err2
   201  	}
   202  	switch outcome2 {
   203  	case types.OutcomeFail:
   204  		return errTestFailed
   205  	case types.OutcomeSuccess:
   206  		r.report(r.newTestResult(types.OutcomeFlake, test, where))
   207  		r.flakes = append(r.flakes, test)
   208  	}
   209  	return nil
   210  }
   211  
   212  var errTestFailed = errors.New("test failed")
   213  
   214  // runTestOnce only returns an error if there was a problem with the test
   215  // harness code; it does not return an error if the test failed.
   216  func (r *runner) runTestOnce(test string, isRerun bool, canRerun bool) (outcome types.Outcome, where string, err error) {
   217  	defer func() {
   218  		logOutcome := outcome
   219  		if outcome == types.OutcomeFail && canRerun {
   220  			logOutcome = types.OutcomeFlake
   221  		}
   222  		fmt.Printf("%s: %s %s\n", logOutcome.Abbrv(), test, where)
   223  		if logOutcome != types.OutcomeFlake && r.opts.Branch == "master" && err == nil {
   224  			r.report(r.newTestResult(logOutcome, test, where))
   225  		}
   226  	}()
   227  
   228  	cmd := exec.Command(r.testerName(), "-test.run", "^"+test+"$", "-test.timeout", r.opts.Timeout)
   229  	if isRerun {
   230  		cmd.Env = append(os.Environ(), "CITOGO_FLAKE_RERUN=1")
   231  	}
   232  	var combined bytes.Buffer
   233  	cmd.Stdout = &combined
   234  	cmd.Stderr = &combined
   235  	testErr := cmd.Run()
   236  	if testErr != nil {
   237  		err = errTestFailed
   238  
   239  		var flushErr error
   240  		where, flushErr := r.flushTestLogs(test, combined)
   241  		if flushErr != nil {
   242  			return types.OutcomeFail, "", flushErr
   243  		}
   244  		return types.OutcomeFail, where, nil
   245  	}
   246  	return types.OutcomeSuccess, "", nil
   247  }
   248  
   249  func (r *runner) runTestFixError(t string) error {
   250  	err := r.runTest(t)
   251  	if err == nil {
   252  		return nil
   253  	}
   254  	if err != errTestFailed {
   255  		return err
   256  	}
   257  	r.fails = append(r.fails, t)
   258  	if r.opts.Fails < 0 {
   259  		// We have an infinite fail budget, so keep plowing through
   260  		// failed tests. This test run is still going to fail.
   261  		return nil
   262  	}
   263  	if r.opts.Fails >= len(r.fails) {
   264  		// We've failed less than our budget, so we can still keep going.
   265  		// This test run is still going to fail.
   266  		return nil
   267  	}
   268  	// We ate up our fail budget.
   269  	return err
   270  }
   271  
   272  func (r *runner) runTests() error {
   273  	var eg errgroup.Group
   274  	q := make(chan string, len(r.tests))
   275  	for i := 0; i < r.opts.Parallel; i++ {
   276  		eg.Go(func() error {
   277  			for f := range q {
   278  				err := r.runTestFixError(f)
   279  				if err != nil {
   280  					return err
   281  				}
   282  				if r.opts.Pause > 0 {
   283  					time.Sleep(r.opts.Pause)
   284  				}
   285  			}
   286  			return nil
   287  		})
   288  	}
   289  	for _, f := range r.tests {
   290  		q <- f
   291  	}
   292  	close(q)
   293  	return eg.Wait()
   294  }
   295  
   296  func (r *runner) cleanup() {
   297  	if r.opts.Preserve || r.opts.NoCompile {
   298  		return
   299  	}
   300  	n := r.testerName()
   301  	err := os.Remove(n)
   302  	fmt.Printf("RMOV: %s\n", n)
   303  	if err != nil {
   304  		logError("could not remove %s: %s", n, err.Error())
   305  	}
   306  }
   307  
   308  func (r *runner) debugStartup() {
   309  	dir, _ := os.Getwd()
   310  	fmt.Printf("WDIR: %s\n", dir)
   311  }
   312  
   313  func (r *runner) testExists() (bool, error) {
   314  	f := r.testerName()
   315  	info, err := os.Stat(f)
   316  	if os.IsNotExist(err) {
   317  		return false, nil
   318  	}
   319  	if err != nil {
   320  		return false, err
   321  	}
   322  	if info.Mode().IsRegular() {
   323  		return true, nil
   324  	}
   325  	return false, fmt.Errorf("%s: file of wrong type", f)
   326  
   327  }
   328  
   329  func (r *runner) run() error {
   330  	start := time.Now()
   331  	err := r.parseArgs()
   332  	if err != nil {
   333  		return err
   334  	}
   335  
   336  	r.debugStartup()
   337  	err = r.compile()
   338  	if err != nil {
   339  		return err
   340  	}
   341  	exists, err := r.testExists()
   342  	if exists {
   343  		err = r.listTests()
   344  		if err != nil {
   345  			return err
   346  		}
   347  		err = r.runTests()
   348  		r.cleanup()
   349  	}
   350  	end := time.Now()
   351  	diff := end.Sub(start)
   352  	fmt.Printf("DONE: in %s\n", diff)
   353  	if err != nil {
   354  		return err
   355  	}
   356  	if len(r.fails) > 0 {
   357  		// If we have more than 15 tests, repeat at the end which tests failed,
   358  		// so we don't have to scroll all the way up.
   359  		if len(r.tests) > 15 {
   360  			for _, t := range r.fails {
   361  				fmt.Printf("FAIL: %s\n", t)
   362  			}
   363  		}
   364  		return fmt.Errorf("RED!: %d total tests failed", len(r.fails))
   365  	}
   366  	return nil
   367  }
   368  
   369  func main2() error {
   370  	runner := runner{}
   371  	return runner.run()
   372  }
   373  
   374  func main() {
   375  	err := main2()
   376  	if err != nil {
   377  		logError(err.Error())
   378  		fmt.Printf("EXIT: 2\n")
   379  		os.Exit(2)
   380  	}
   381  	fmt.Printf("EXIT: 0\n")
   382  	os.Exit(0)
   383  }