github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/utils/batsee/main.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build !windows
    16  // +build !windows
    17  
    18  package main
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"io"
    24  	"os"
    25  	"os/exec"
    26  	"os/signal"
    27  	"path/filepath"
    28  	"strings"
    29  	"sync"
    30  	"syscall"
    31  	"time"
    32  
    33  	"github.com/fatih/color"
    34  
    35  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    36  	"github.com/dolthub/dolt/go/libraries/utils/argparser"
    37  )
    38  
    39  var batseeDoc = cli.CommandDocumentationContent{
    40  	ShortDesc: `Run the Bats Tests concurrently`,
    41  	LongDesc: `From within the integration-test/bats directory, run the bats tests concurrently.
    42  Output for each test is written to a file in the batsee_output directory.
    43  Example:  batsee -t 42 --max-time 1h15m -r 2 --only types.bats,foreign-keys.bats`,
    44  	Synopsis: []string{
    45  		`[-t threads] [-o dir] [--skip-slow] [--max-time time] [--only test1,test2,...]`,
    46  	},
    47  }
    48  
    49  const (
    50  	threadsFlag  = "threads"
    51  	outputDir    = "output"
    52  	skipSlowFlag = "skip-slow"
    53  	maxTimeFlag  = "max-time"
    54  	onlyFLag     = "only"
    55  	retriesFLag  = "retries"
    56  )
    57  
    58  func buildArgParser() *argparser.ArgParser {
    59  	ap := argparser.NewArgParserWithMaxArgs("batsee", 0)
    60  	ap.SupportsInt(threadsFlag, "t", "threads", "Number of tests to execute in parallel. Defaults to 12")
    61  	ap.SupportsString(outputDir, "o", "directory", "Directory to write output to. Defaults to 'batsee_results'")
    62  	ap.SupportsFlag(skipSlowFlag, "s", "Skip slow tests. This is a static list of test we know are slow, may grow stale.")
    63  	ap.SupportsString(maxTimeFlag, "", "duration", "Maximum time to run tests. Defaults to 30m")
    64  	ap.SupportsString(onlyFLag, "", "", "Only run the specified test, or tests (comma separated)")
    65  	ap.SupportsInt(retriesFLag, "r", "retries", "Number of times to retry a failed test. Defaults to 1")
    66  	return ap
    67  }
    68  
    69  type batsResult struct {
    70  	runtime time.Duration
    71  	path    string
    72  	err     error
    73  	skipped bool
    74  	aborted bool
    75  }
    76  
    77  // list of slow commands. These tend to run more than 5-7 min, so we want to run them first.
    78  var slowCommands = map[string]bool{
    79  	"types.bats":                 true,
    80  	"keyless.bats":               true,
    81  	"index-on-writes.bats":       true,
    82  	"constraint-violations.bats": true,
    83  	"foreign-keys.bats":          true,
    84  	"index.bats":                 true,
    85  	"sql-server.bats":            true,
    86  	"index-on-writes-2.bats":     true,
    87  	"sql.bats":                   true,
    88  	"remotes.bats":               true,
    89  }
    90  
    91  type config struct {
    92  	threads  int
    93  	output   string
    94  	duration time.Duration
    95  	skipSlow bool
    96  	limitTo  map[string]bool
    97  	retries  int
    98  }
    99  
   100  func buildConfig(apr *argparser.ArgParseResults) config {
   101  	threads, hasThreads := apr.GetInt(threadsFlag)
   102  	if !hasThreads {
   103  		threads = 12
   104  	}
   105  
   106  	output, hasOutput := apr.GetValue(outputDir)
   107  	if !hasOutput {
   108  		output = "batsee_results"
   109  	}
   110  
   111  	durationInput, hasDuration := apr.GetValue(maxTimeFlag)
   112  	if !hasDuration {
   113  		durationInput = "30m"
   114  	}
   115  	duration, err := time.ParseDuration(durationInput)
   116  	if err != nil {
   117  		cli.Println("Error parsing duration:", err)
   118  		os.Exit(1)
   119  	}
   120  
   121  	skipSlow := apr.Contains(skipSlowFlag)
   122  
   123  	limitTo := map[string]bool{}
   124  	runOnlyStr, hasRunOnly := apr.GetValue(onlyFLag)
   125  	if hasRunOnly {
   126  		for _, test := range strings.Split(runOnlyStr, ",") {
   127  			test = strings.TrimSpace(test)
   128  			limitTo[test] = true
   129  		}
   130  	}
   131  
   132  	retries, hasRetries := apr.GetInt(retriesFLag)
   133  	if !hasRetries {
   134  		retries = 1
   135  	}
   136  
   137  	return config{
   138  		threads:  threads,
   139  		output:   output,
   140  		duration: duration,
   141  		skipSlow: skipSlow,
   142  		limitTo:  limitTo,
   143  		retries:  retries,
   144  	}
   145  }
   146  
   147  func main() {
   148  	ap := buildArgParser()
   149  	help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString("batsee", batseeDoc, ap))
   150  	args := os.Args[1:]
   151  	apr := cli.ParseArgsOrDie(ap, args, help)
   152  
   153  	config := buildConfig(apr)
   154  
   155  	startTime := time.Now()
   156  
   157  	cwd, err := os.Getwd()
   158  	if err != nil {
   159  		cli.Println("Error getting current working directory:", err)
   160  		os.Exit(1)
   161  	}
   162  	// This is pretty restrictive. Loosen this up. TODO
   163  	if filepath.Base(cwd) != "bats" || filepath.Base(filepath.Dir(cwd)) != "integration-tests" {
   164  		cli.Println("Current working directory is not integration-tests/bats")
   165  		os.Exit(1)
   166  	}
   167  
   168  	// Get a list of all files in this directory which end in ".bats"
   169  	files, err := os.ReadDir(cwd)
   170  	if err != nil {
   171  		cli.Println("Error reading directory:", err)
   172  		os.Exit(1)
   173  	}
   174  
   175  	workQueue := []string{}
   176  	// Insert the slow tests first
   177  	for key, _ := range slowCommands {
   178  		if !config.skipSlow {
   179  			workQueue = append(workQueue, key)
   180  		}
   181  	}
   182  	// Then insert the rest of the tests
   183  	for _, file := range files {
   184  		if !file.IsDir() && filepath.Ext(file.Name()) == ".bats" {
   185  			if _, ok := slowCommands[file.Name()]; !ok {
   186  				workQueue = append(workQueue, file.Name())
   187  			}
   188  		}
   189  	}
   190  
   191  	jobs := make(chan string, len(workQueue))
   192  	results := make(chan batsResult, len(workQueue))
   193  
   194  	ctx := context.Background()
   195  	ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
   196  	defer stop()
   197  	ctx, cancel := context.WithTimeout(ctx, config.duration)
   198  	defer cancel()
   199  
   200  	var wg sync.WaitGroup
   201  	for i := 0; i < config.threads; i++ {
   202  		go func() {
   203  			wg.Add(1)
   204  			defer wg.Done()
   205  			worker(jobs, results, ctx, config)
   206  		}()
   207  	}
   208  
   209  	for _, job := range workQueue {
   210  		jobs <- job
   211  	}
   212  	close(jobs)
   213  
   214  	cli.Println(fmt.Sprintf("Waiting for workers (%d) to finish", config.threads))
   215  	comprehensiveWait(ctx, &wg)
   216  
   217  	close(results)
   218  
   219  	exitStatus := printResults(results)
   220  	cli.Println(fmt.Sprintf("BATS Executor Exemplar completed in: %s with a status of %d", durationStr(time.Since(startTime)), exitStatus))
   221  	os.Exit(exitStatus)
   222  }
   223  
   224  func comprehensiveWait(ctx context.Context, wg *sync.WaitGroup) {
   225  	wgChan := make(chan struct{})
   226  	go func() {
   227  		wg.Wait()
   228  		close(wgChan)
   229  	}()
   230  
   231  	prematureExit := false
   232  	select {
   233  	case <-ctx.Done():
   234  		prematureExit = true
   235  		break
   236  	case <-wgChan:
   237  	}
   238  
   239  	if prematureExit {
   240  		// Still need to wait for workers to finish. They got the signal, but we will panic if we don't let them finish.
   241  		<-wgChan
   242  	}
   243  }
   244  
   245  func printResults(results <-chan batsResult) int {
   246  	// Note that color control characters batch formatting, so we build these status strings all to be the same length
   247  	// so they will produce the right results when included below.
   248  	passStr := color.GreenString(fmt.Sprintf("%20s", "PASSED"))
   249  	failStr := color.RedString(fmt.Sprintf("%20s", "FAILED"))
   250  	skippedStr := color.YellowString(fmt.Sprintf("%20s", "SKIPPED"))
   251  	skippedNoTimeStr := color.YellowString(fmt.Sprintf("%20s", "SKIPPED (no time)"))
   252  	terminatedStr := color.RedString(fmt.Sprintf("%20s", "TERMINATED"))
   253  
   254  	failedQ := []batsResult{}
   255  	skippedQ := []batsResult{}
   256  	for result := range results {
   257  		if result.skipped {
   258  			skippedQ = append(skippedQ, result)
   259  			continue
   260  		}
   261  
   262  		if result.err != nil {
   263  			failedQ = append(failedQ, result)
   264  		} else {
   265  			cli.Println(fmt.Sprintf("%s %-40s (time: %s)", passStr, result.path, durationStr(result.runtime)))
   266  		}
   267  	}
   268  	for _, result := range skippedQ {
   269  		reason := skippedStr
   270  		if result.aborted {
   271  			reason = skippedNoTimeStr
   272  		}
   273  		cli.Println(fmt.Sprintf("%s %-40s (time:NA)", reason, result.path))
   274  	}
   275  
   276  	exitStatus := 0
   277  	for _, result := range failedQ {
   278  		reason := failStr
   279  		if result.aborted {
   280  			reason = terminatedStr
   281  		}
   282  		cli.Println(fmt.Sprintf("%s %-40s (time:%s)", reason, result.path, durationStr(result.runtime)))
   283  		exitStatus = 1
   284  	}
   285  	return exitStatus
   286  }
   287  
   288  func durationStr(duration time.Duration) string {
   289  	return fmt.Sprintf("%02d:%02d", int(duration.Minutes()), int(duration.Seconds())%60)
   290  }
   291  
   292  func worker(jobs <-chan string, results chan<- batsResult, ctx context.Context, config config) {
   293  	for job := range jobs {
   294  		runBats(job, results, ctx, config)
   295  	}
   296  }
   297  
   298  // runBats runs a single bats test and sends the result to the results channel. Stdout and stderr are written to files
   299  // in the batsee_results directory in the CWD, and the error is written to the result.err field.
   300  func runBats(path string, resultChan chan<- batsResult, ctx context.Context, cfg config) {
   301  	cmd := exec.CommandContext(ctx, "bats", path)
   302  	// Set the process group ID so that we can kill the entire process tree if it runs too long. We need to differenciate
   303  	// process group of the sub process from this one, because kill the primary process if we don't.
   304  	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
   305  	cmd.Env = append(os.Environ(), fmt.Sprintf("DOLT_TEST_RETRIES=%d", cfg.retries))
   306  
   307  	result := batsResult{path: path}
   308  
   309  	if cfg.limitTo != nil && len(cfg.limitTo) != 0 && !cfg.limitTo[path] {
   310  		result.skipped = true
   311  		resultChan <- result
   312  		return
   313  	}
   314  
   315  	// ensure cfg.output exists, and create it if it doesn't
   316  	if _, err := os.Stat(cfg.output); os.IsNotExist(err) {
   317  		err = os.Mkdir(cfg.output, 0755)
   318  		if err != nil {
   319  			cli.Println("Error creating output directory:", err.Error())
   320  			result.err = err
   321  			resultChan <- result
   322  			return
   323  		}
   324  	}
   325  
   326  	startTime := time.Now()
   327  
   328  	outPath := fmt.Sprintf("%s/%s.stdout.log", cfg.output, path)
   329  	output, err := os.Create(outPath)
   330  	if err != nil {
   331  		cli.Println("Error creating stdout log:", err.Error())
   332  		result.err = err
   333  	}
   334  	defer output.Close()
   335  	stdout, err := cmd.StdoutPipe()
   336  	if err != nil {
   337  		cli.Println("Error creating stdout pipe:", err.Error())
   338  		result.err = err
   339  	}
   340  
   341  	errPath := fmt.Sprintf("%s/%s.stderr.log", cfg.output, path)
   342  	errput, err := os.Create(errPath)
   343  	if err != nil {
   344  		cli.Println("Error creating stderr log:", err.Error())
   345  		result.err = err
   346  	}
   347  	defer errput.Close()
   348  
   349  	stderr, err := cmd.StderrPipe()
   350  	if err != nil {
   351  		cli.Println("Error creating stderr pipe:", err.Error())
   352  		result.err = err
   353  	}
   354  
   355  	if result.err == nil {
   356  		// All systems go!
   357  		err = cmd.Start()
   358  		if err != nil {
   359  			if ctx.Err() == context.DeadlineExceeded || ctx.Err() == context.Canceled {
   360  				result.aborted = true
   361  				result.skipped = true
   362  			} else {
   363  				cli.Println("Error starting command:", err.Error())
   364  			}
   365  			result.err = err
   366  		}
   367  	}
   368  
   369  	if cmd.Process != nil {
   370  		// Process started. Now we may have things to clean up if things go sideways.
   371  		// do this as a goroutines so that we can tail the output files while tests are running.
   372  		go io.Copy(output, stdout)
   373  		go io.Copy(errput, stderr)
   374  		pgroup := -1 * cmd.Process.Pid
   375  
   376  		err = cmd.Wait()
   377  		if err != nil {
   378  			if ctx.Err() == context.DeadlineExceeded || ctx.Err() == context.Canceled {
   379  				// Kill entire process group with fire
   380  				syscall.Kill(pgroup, syscall.SIGKILL)
   381  				result.aborted = true
   382  			}
   383  			// command completed with a non-0 exit code. This is "normal", so not writing to output. It will be captured
   384  			// as part of the summary.
   385  			result.err = err
   386  		}
   387  	}
   388  	result.runtime = time.Since(startTime)
   389  	resultChan <- result
   390  	return
   391  }