github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/github-pull-request-make/main.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // This utility detects new tests added in a given pull request, and runs them
    12  // under stress in our CI infrastructure.
    13  //
    14  // Note that this program will directly exec `make`, so there is no need to
    15  // process its output. See build/teamcity-test{,race}.sh for usage examples.
    16  //
// Note that our CI infrastructure has no notion of "pull requests", forcing
// the approach taken here to be quite brute-force with respect to its use of
// the GitHub API.
    20  package main
    21  
    22  import (
    23  	"bufio"
    24  	"bytes"
    25  	"context"
    26  	"fmt"
    27  	"go/build"
    28  	"io"
    29  	"log"
    30  	"net/http"
    31  	"os"
    32  	"os/exec"
    33  	"path/filepath"
    34  	"regexp"
    35  	"strings"
    36  	"time"
    37  
    38  	_ "github.com/cockroachdb/cockroach/pkg/testutils/buildutil"
    39  	"github.com/google/go-github/github"
    40  	"golang.org/x/oauth2"
    41  )
    42  
    43  const githubAPITokenEnv = "GITHUB_API_TOKEN"
    44  const teamcityVCSNumberEnv = "BUILD_VCS_NUMBER"
    45  const makeTargetEnv = "TARGET"
    46  
    47  // https://github.com/golang/go/blob/go1.7.3/src/cmd/go/test.go#L1260:L1262
    48  //
    49  // It is a Test (say) if there is a character after Test that is not a lower-case letter.
    50  // We don't want TesticularCancer.
    51  const goTestStr = `func (Test[^a-z]\w*)\(.*\*testing\.TB?\) {$`
    52  const goBenchmarkStr = `func (Benchmark[^a-z]\w*)\(.*\*testing\.T?B\) {$`
    53  
    54  var currentGoTestRE = regexp.MustCompile(`.*` + goTestStr)
    55  var currentGoBenchmarkRE = regexp.MustCompile(`.*` + goBenchmarkStr)
    56  var newGoTestRE = regexp.MustCompile(`^\+\s*` + goTestStr)
    57  var newGoBenchmarkRE = regexp.MustCompile(`^\+\s*` + goBenchmarkStr)
    58  
    59  type pkg struct {
    60  	tests, benchmarks []string
    61  }
    62  
    63  // pkgsFromDiff parses a git-style diff and returns a mapping from directories
    64  // to tests and benchmarks added in those directories in the given diff.
    65  func pkgsFromDiff(r io.Reader) (map[string]pkg, error) {
    66  	const newFilePrefix = "+++ b/"
    67  	const replacement = "$1"
    68  
    69  	pkgs := make(map[string]pkg)
    70  
    71  	var curPkgName string
    72  	var curTestName string
    73  	var curBenchmarkName string
    74  	var inPrefix bool
    75  	for reader := bufio.NewReader(r); ; {
    76  		line, isPrefix, err := reader.ReadLine()
    77  		switch {
    78  		case err == nil:
    79  		case err == io.EOF:
    80  			return pkgs, nil
    81  		default:
    82  			return nil, err
    83  		}
    84  		// Ignore generated files a la embedded.go.
    85  		if isPrefix {
    86  			inPrefix = true
    87  			continue
    88  		} else if inPrefix {
    89  			inPrefix = false
    90  			continue
    91  		}
    92  
    93  		switch {
    94  		case bytes.HasPrefix(line, []byte(newFilePrefix)):
    95  			curPkgName = filepath.Dir(string(bytes.TrimPrefix(line, []byte(newFilePrefix))))
    96  		case newGoTestRE.Match(line):
    97  			curPkg := pkgs[curPkgName]
    98  			curPkg.tests = append(curPkg.tests, string(newGoTestRE.ReplaceAll(line, []byte(replacement))))
    99  			pkgs[curPkgName] = curPkg
   100  		case newGoBenchmarkRE.Match(line):
   101  			curPkg := pkgs[curPkgName]
   102  			curPkg.benchmarks = append(curPkg.benchmarks, string(newGoBenchmarkRE.ReplaceAll(line, []byte(replacement))))
   103  			pkgs[curPkgName] = curPkg
   104  		case currentGoTestRE.Match(line):
   105  			curTestName = ""
   106  			curBenchmarkName = ""
   107  			if !bytes.HasPrefix(line, []byte{'-'}) {
   108  				curTestName = string(currentGoTestRE.ReplaceAll(line, []byte(replacement)))
   109  			}
   110  		case currentGoBenchmarkRE.Match(line):
   111  			curTestName = ""
   112  			curBenchmarkName = ""
   113  			if !bytes.HasPrefix(line, []byte{'-'}) {
   114  				curBenchmarkName = string(currentGoBenchmarkRE.ReplaceAll(line, []byte(replacement)))
   115  			}
   116  		case bytes.HasPrefix(line, []byte{'-'}) && bytes.Contains(line, []byte(".Skip")):
   117  			if curPkgName != "" {
   118  				switch {
   119  				case len(curTestName) > 0:
   120  					if !(curPkgName == "build" && curTestName == "TestStyle") {
   121  						curPkg := pkgs[curPkgName]
   122  						curPkg.tests = append(curPkg.tests, curTestName)
   123  						pkgs[curPkgName] = curPkg
   124  					}
   125  				case len(curBenchmarkName) > 0:
   126  					curPkg := pkgs[curPkgName]
   127  					curPkg.benchmarks = append(curPkg.benchmarks, curBenchmarkName)
   128  					pkgs[curPkgName] = curPkg
   129  				}
   130  			}
   131  		}
   132  	}
   133  }
   134  
   135  func findPullRequest(
   136  	ctx context.Context, client *github.Client, org, repo, sha string,
   137  ) *github.PullRequest {
   138  	opts := &github.PullRequestListOptions{
   139  		ListOptions: github.ListOptions{PerPage: 100},
   140  	}
   141  	for {
   142  		pulls, resp, err := client.PullRequests.List(ctx, org, repo, opts)
   143  		if err != nil {
   144  			log.Fatal(err)
   145  		}
   146  
   147  		for _, pull := range pulls {
   148  			if *pull.Head.SHA == sha {
   149  				return pull
   150  			}
   151  		}
   152  
   153  		if resp.NextPage == 0 {
   154  			return nil
   155  		}
   156  		opts.Page = resp.NextPage
   157  	}
   158  }
   159  
   160  func ghClient(ctx context.Context) *github.Client {
   161  	var httpClient *http.Client
   162  	if token, ok := os.LookupEnv(githubAPITokenEnv); ok {
   163  		httpClient = oauth2.NewClient(ctx, oauth2.StaticTokenSource(
   164  			&oauth2.Token{AccessToken: token},
   165  		))
   166  	} else {
   167  		log.Printf("GitHub API token environment variable %s is not set", githubAPITokenEnv)
   168  	}
   169  	return github.NewClient(httpClient)
   170  }
   171  
   172  func getDiff(
   173  	ctx context.Context, client *github.Client, org, repo string, prNum int,
   174  ) (string, error) {
   175  	diff, _, err := client.PullRequests.GetRaw(
   176  		ctx,
   177  		org,
   178  		repo,
   179  		prNum,
   180  		github.RawOptions{Type: github.Diff},
   181  	)
   182  	return diff, err
   183  }
   184  
   185  func main() {
   186  	sha, ok := os.LookupEnv(teamcityVCSNumberEnv)
   187  	if !ok {
   188  		log.Fatalf("VCS number environment variable %s is not set", teamcityVCSNumberEnv)
   189  	}
   190  
   191  	target, ok := os.LookupEnv(makeTargetEnv)
   192  	if !ok {
   193  		log.Fatalf("make target variable %s is not set", makeTargetEnv)
   194  	}
   195  
   196  	const org = "cockroachdb"
   197  	const repo = "cockroach"
   198  
   199  	crdb, err := build.Import(fmt.Sprintf("github.com/%s/%s", org, repo), "", build.FindOnly)
   200  	if err != nil {
   201  		log.Fatal(err)
   202  	}
   203  
   204  	ctx := context.Background()
   205  	client := ghClient(ctx)
   206  
   207  	currentPull := findPullRequest(ctx, client, org, repo, sha)
   208  	if currentPull == nil {
   209  		log.Printf("SHA %s not found in open pull requests, skipping stress", sha)
   210  		return
   211  	}
   212  
   213  	diff, err := getDiff(ctx, client, org, repo, *currentPull.Number)
   214  	if err != nil {
   215  		log.Fatal(err)
   216  	}
   217  
   218  	if target == "checkdeps" {
   219  		var vendorChanged bool
   220  		for _, path := range []string{"Gopkg.lock", "vendor"} {
   221  			if strings.Contains(diff, fmt.Sprintf("\n--- a/%[1]s\n+++ b/%[1]s\n", path)) {
   222  				vendorChanged = true
   223  				break
   224  			}
   225  		}
   226  		if vendorChanged {
   227  			cmd := exec.Command("dep", "ensure", "-v")
   228  			cmd.Dir = crdb.Dir
   229  			cmd.Stdout = os.Stdout
   230  			cmd.Stderr = os.Stderr
   231  			log.Println(cmd.Args)
   232  			if err := cmd.Run(); err != nil {
   233  				log.Fatal(err)
   234  			}
   235  
   236  			// Check for diffs.
   237  			var foundDiff bool
   238  			for _, dir := range []string{filepath.Join(crdb.Dir, "vendor"), crdb.Dir} {
   239  				cmd := exec.Command("git", "diff")
   240  				cmd.Dir = dir
   241  				log.Println(cmd.Dir, cmd.Args)
   242  				if output, err := cmd.CombinedOutput(); err != nil {
   243  					log.Fatalf("%s: %s", err, string(output))
   244  				} else if len(output) > 0 {
   245  					foundDiff = true
   246  					log.Printf("unexpected diff:\n%s", output)
   247  				}
   248  			}
   249  			if foundDiff {
   250  				os.Exit(1)
   251  			}
   252  		}
   253  	} else {
   254  		pkgs, err := pkgsFromDiff(strings.NewReader(diff))
   255  		if err != nil {
   256  			log.Fatal(err)
   257  		}
   258  		if len(pkgs) > 0 {
   259  			for name, pkg := range pkgs {
   260  				// 20 minutes total seems OK, but at least 2 minutes per test.
   261  				// This should be reduced. See #46941.
   262  				duration := (20 * time.Minute) / time.Duration(len(pkgs))
   263  				minDuration := (2 * time.Minute) * time.Duration(len(pkg.tests))
   264  				if duration < minDuration {
   265  					duration = minDuration
   266  				}
   267  				// Use a timeout shorter than the duration so that hanging tests don't
   268  				// get a free pass.
   269  				timeout := (3 * duration) / 4
   270  
   271  				tests := "-"
   272  				if len(pkg.tests) > 0 {
   273  					tests = "(" + strings.Join(pkg.tests, "$$|") + "$$)"
   274  				}
   275  
   276  				cmd := exec.Command(
   277  					"make",
   278  					target,
   279  					fmt.Sprintf("PKG=./%s", name),
   280  					fmt.Sprintf("TESTS=%s", tests),
   281  					fmt.Sprintf("TESTTIMEOUT=%s", timeout),
   282  					fmt.Sprintf("GOTESTFLAGS=-json"), // allow TeamCity to parse failures
   283  					fmt.Sprintf("STRESSFLAGS=-stderr -maxfails 1 -maxtime %s", duration),
   284  				)
   285  				cmd.Env = append(os.Environ(), "COCKROACH_NIGHTLY_STRESS=true")
   286  				cmd.Dir = crdb.Dir
   287  				cmd.Stdout = os.Stdout
   288  				cmd.Stderr = os.Stderr
   289  				log.Println(cmd.Args)
   290  				if err := cmd.Run(); err != nil {
   291  					log.Fatal(err)
   292  				}
   293  			}
   294  		}
   295  	}
   296  }