github.com/zppinho/prow@v0.0.0-20240510014325-1738badeb017/cmd/external-plugins/needs-rebase/plugin/plugin.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package plugin
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"strings"
    25  	"sync"
    26  	"time"
    27  
    28  	githubql "github.com/shurcooL/githubv4"
    29  	"github.com/sirupsen/logrus"
    30  	utilerrors "k8s.io/apimachinery/pkg/util/errors"
    31  	"sigs.k8s.io/prow/pkg/config"
    32  	"sigs.k8s.io/prow/pkg/github"
    33  	"sigs.k8s.io/prow/pkg/labels"
    34  	"sigs.k8s.io/prow/pkg/pluginhelp"
    35  	"sigs.k8s.io/prow/pkg/plugins"
    36  )
    37  
const (
	// PluginName is the name of this plugin. It is defined as the
	// needs-rebase label constant, so the plugin name and the label it
	// manages are the same string.
	//
	// needsRebaseMessage is the comment text posted when the label is added;
	// shouldPrune matches on it to select comments for deletion.
	// dependabotRebaseMessage is posted instead of needsRebaseMessage when
	// the PR author's login equals dependabotUser.
	PluginName              = labels.NeedsRebase
	needsRebaseMessage      = "PR needs rebase."
	dependabotRebaseMessage = "rebase"
	dependabotUser          = "dependabot[bot]"
)
    45  
// sleep holds time.Sleep behind a package-level variable; handle calls it
// instead of calling time.Sleep directly.
// NOTE(review): presumably this exists so tests can replace the delay — confirm.
var sleep = time.Sleep
    47  
// githubClient lists the GitHub client methods this plugin calls: reading a
// PR and its labels, checking mergeability, adding/removing labels, creating
// and deleting comments, and running GraphQL queries.
// NOTE(review): presumably kept as an interface so tests can supply a fake —
// confirm against the test files.
type githubClient interface {
	GetIssueLabels(org, repo string, number int) ([]github.Label, error)
	CreateCommentWithContext(ctx context.Context, org, repo string, number int, comment string) error
	BotUserChecker() (func(candidate string) bool, error)
	AddLabelWithContext(ctx context.Context, org, repo string, number int, label string) error
	RemoveLabelWithContext(ctx context.Context, org, repo string, number int, label string) error
	IsMergeable(org, repo string, number int, sha string) (bool, error)
	DeleteStaleCommentsWithContext(ctx context.Context, org, repo string, number int, comments []github.IssueComment, isStale func(github.IssueComment) bool) error
	QueryWithGitHubAppsSupport(ctx context.Context, q interface{}, vars map[string]interface{}, org string) error
	GetPullRequest(org, repo string, number int) (*github.PullRequest, error)
}
    59  
    60  // HelpProvider constructs the PluginHelp for this plugin that takes into account enabled repositories.
    61  // HelpProvider defines the type for function that construct the PluginHelp for plugins.
    62  func HelpProvider(_ []config.OrgRepo) (*pluginhelp.PluginHelp, error) {
    63  	return &pluginhelp.PluginHelp{
    64  			Description: `The needs-rebase plugin manages the '` + labels.NeedsRebase + `' label by removing it from Pull Requests that are mergeable and adding it to those which are not.
    65  The plugin reacts to commit changes on PRs in addition to periodically scanning all open PRs for any changes to mergeability that could have resulted from changes in other PRs.`,
    66  		},
    67  		nil
    68  }
    69  
    70  // HandlePullRequestEvent handles a GitHub pull request event and adds or removes a
    71  // "needs-rebase" label based on whether the GitHub api considers the PR mergeable
    72  func HandlePullRequestEvent(log *logrus.Entry, ghc githubClient, pre *github.PullRequestEvent) error {
    73  	if pre.Action != github.PullRequestActionOpened && pre.Action != github.PullRequestActionSynchronize && pre.Action != github.PullRequestActionReopened {
    74  		return nil
    75  	}
    76  	return handle(log, ghc, &pre.PullRequest)
    77  }
    78  
    79  // HandleIssueCommentEvent handles a GitHub issue comment event and adds or removes a
    80  // "needs-rebase" label if the issue is a PR based on whether the GitHub api considers
    81  // the PR mergeable
    82  func HandleIssueCommentEvent(log *logrus.Entry, ghc githubClient, ice *github.IssueCommentEvent, cache *Cache) error {
    83  	if !ice.Issue.IsPullRequest() {
    84  		return nil
    85  	}
    86  
    87  	if cache.validTime > 0 && cache.Get(ice.Issue.ID) {
    88  		return nil
    89  	}
    90  
    91  	pr, err := ghc.GetPullRequest(ice.Repo.Owner.Login, ice.Repo.Name, ice.Issue.Number)
    92  	if err != nil {
    93  		return err
    94  	}
    95  	err = handle(log, ghc, pr)
    96  
    97  	if cache.validTime > 0 && err == nil {
    98  		cache.Set(ice.Issue.ID)
    99  	}
   100  
   101  	return err
   102  }
   103  
   104  // handle handles a GitHub PR to determine if the "needs-rebase"
   105  // label needs to be added or removed. It depends on GitHub mergeability check
   106  // to decide the need for a rebase.
   107  func handle(log *logrus.Entry, ghc githubClient, pr *github.PullRequest) error {
   108  	if pr.State != github.PullRequestStateOpen {
   109  		return nil
   110  	}
   111  	// Before checking mergeability wait a few seconds to give github a chance to calculate it.
   112  	// This initial delay prevents us from always wasting the first API token.
   113  	sleep(time.Second * 5)
   114  
   115  	org := pr.Base.Repo.Owner.Login
   116  	repo := pr.Base.Repo.Name
   117  	number := pr.Number
   118  	sha := pr.Head.SHA
   119  	*log = *log.WithFields(logrus.Fields{
   120  		github.OrgLogField:  org,
   121  		github.RepoLogField: repo,
   122  		github.PrLogField:   number,
   123  		"head-sha":          sha,
   124  	})
   125  
   126  	mergeable, err := ghc.IsMergeable(org, repo, number, sha)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	issueLabels, err := ghc.GetIssueLabels(org, repo, number)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	hasLabel := github.HasLabel(labels.NeedsRebase, issueLabels)
   135  	return takeAction(ghc, org, repo, number, pr.User.Login, hasLabel, mergeable)
   136  }
   137  
// searchQueryPrefix is the leading part of every search query built by
// constructQueries; it restricts results to open pull requests in
// non-archived repositories.
const searchQueryPrefix = "archived:false is:pr is:open"
   139  
   140  // HandleAll checks all orgs and repos that enabled this plugin for open PRs to
   141  // determine if the "needs-rebase" label needs to be added or removed. It
   142  // depends on GitHub's mergeability check to decide the need for a rebase.
   143  func HandleAll(log *logrus.Entry, ghc githubClient, config *plugins.Configuration, usesAppsAuth bool, issueCache *Cache) error {
   144  	if issueCache.validTime > 0 {
   145  		issueCache.Flush()
   146  	}
   147  
   148  	log.Info("Checking all PRs.")
   149  	orgs, repos := config.EnabledReposForExternalPlugin(PluginName)
   150  	if len(orgs) == 0 && len(repos) == 0 {
   151  		log.Warnf("No repos have been configured for the %s plugin", PluginName)
   152  		return nil
   153  	}
   154  
   155  	var prs []pullRequest
   156  	var errs []error
   157  	for org, queries := range constructQueries(log, time.Now(), orgs, repos, usesAppsAuth) {
   158  		// Do _not_ parallelize this. It will trigger GitHubs abuse detection and we don't really care anyways except
   159  		// when developing.
   160  		for _, query := range queries {
   161  			found, err := search(context.Background(), log, ghc, query, org)
   162  			prs = append(prs, found...)
   163  			errs = append(errs, err)
   164  		}
   165  	}
   166  	if err := utilerrors.NewAggregate(errs); err != nil {
   167  		if len(prs) == 0 {
   168  			return err
   169  		}
   170  		log.WithError(err).Error("Encountered errors when querying GitHub but will process received results anyways")
   171  	}
   172  	log.WithField("prs_found_count", len(prs)).Debug("Processing all found PRs")
   173  
   174  	for _, pr := range prs {
   175  		// Skip PRs that are calculating mergeability or are not open. They will be updated by event or next loop.
   176  		if pr.Mergeable == githubql.MergeableStateUnknown || pr.State != githubql.PullRequestStateOpen {
   177  			continue
   178  		}
   179  		org := string(pr.Repository.Owner.Login)
   180  		repo := string(pr.Repository.Name)
   181  		num := int(pr.Number)
   182  		var hasLabel bool
   183  		for _, label := range pr.Labels.Nodes {
   184  			if label.Name == labels.NeedsRebase {
   185  				hasLabel = true
   186  				break
   187  			}
   188  		}
   189  		l := log.WithFields(logrus.Fields{
   190  			"org":       org,
   191  			"repo":      repo,
   192  			"pr":        num,
   193  			"mergeable": pr.Mergeable,
   194  			"has_label": hasLabel,
   195  		})
   196  		l.Debug("Processing PR")
   197  		err := takeAction(
   198  			ghc,
   199  			org,
   200  			repo,
   201  			num,
   202  			string(pr.Author.Login),
   203  			hasLabel,
   204  			pr.Mergeable == githubql.MergeableStateMergeable,
   205  		)
   206  		if err != nil {
   207  			l.WithError(err).Error("Error handling PR.")
   208  		}
   209  	}
   210  	return nil
   211  }
   212  
   213  // takeAction adds or removes the "needs-rebase" label based on the current
   214  // state of the PR (hasLabel and mergeable). It also handles adding and
   215  // removing GitHub comments notifying the PR author that a rebase is needed.
   216  func takeAction(ghc githubClient, org, repo string, num int, author string, hasLabel, mergeable bool) error {
   217  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
   218  	defer cancel()
   219  
   220  	// Swallow context.DeadlineExceeded errors, they are expected to happen when we get throttled
   221  	if err := takeActionWithContext(ctx, ghc, org, repo, num, author, hasLabel, mergeable); err != nil && !errors.Is(err, context.DeadlineExceeded) {
   222  		return err
   223  	}
   224  	return nil
   225  }
   226  
   227  func takeActionWithContext(ctx context.Context, ghc githubClient, org, repo string, num int, author string, hasLabel, mergeable bool) error {
   228  	if !mergeable && !hasLabel {
   229  		if err := ghc.AddLabelWithContext(ctx, org, repo, num, labels.NeedsRebase); err != nil {
   230  			return fmt.Errorf("failed to add %q label: %w", labels.NeedsRebase, err)
   231  		}
   232  		var msg string
   233  		if author == dependabotUser {
   234  			msg = plugins.FormatSimpleResponse(dependabotRebaseMessage)
   235  		} else {
   236  			msg = plugins.FormatSimpleResponse(needsRebaseMessage)
   237  		}
   238  		return ghc.CreateCommentWithContext(ctx, org, repo, num, msg)
   239  	} else if mergeable && hasLabel {
   240  		// remove label and prune comment
   241  		if err := ghc.RemoveLabelWithContext(ctx, org, repo, num, labels.NeedsRebase); err != nil {
   242  			return fmt.Errorf("failed to remove %q label: %w", labels.NeedsRebase, err)
   243  		}
   244  		botUserChecker, err := ghc.BotUserChecker()
   245  		if err != nil {
   246  			return err
   247  		}
   248  		return ghc.DeleteStaleCommentsWithContext(ctx, org, repo, num, nil, shouldPrune(botUserChecker))
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  func shouldPrune(isBot func(string) bool) func(github.IssueComment) bool {
   255  	return func(ic github.IssueComment) bool {
   256  		return isBot(ic.User.Login) &&
   257  			strings.Contains(ic.Body, needsRebaseMessage)
   258  	}
   259  }
   260  
   261  func search(ctx context.Context, log *logrus.Entry, ghc githubClient, q, org string) ([]pullRequest, error) {
   262  	var ret []pullRequest
   263  	vars := map[string]interface{}{
   264  		"query":        githubql.String(q),
   265  		"searchCursor": (*githubql.String)(nil),
   266  	}
   267  	var totalCost int
   268  	var remaining int
   269  	requestStart := time.Now()
   270  	var pageCount int
   271  	for {
   272  		pageCount++
   273  		sq := searchQuery{}
   274  		if err := ghc.QueryWithGitHubAppsSupport(ctx, &sq, vars, org); err != nil {
   275  			return nil, err
   276  		}
   277  		totalCost += int(sq.RateLimit.Cost)
   278  		remaining = int(sq.RateLimit.Remaining)
   279  		for _, n := range sq.Search.Nodes {
   280  			ret = append(ret, n.PullRequest)
   281  		}
   282  		if !sq.Search.PageInfo.HasNextPage {
   283  			break
   284  		}
   285  		vars["searchCursor"] = githubql.NewString(sq.Search.PageInfo.EndCursor)
   286  	}
   287  	log = log.WithFields(logrus.Fields{
   288  		"query":          q,
   289  		"duration":       time.Since(requestStart).String(),
   290  		"pr_found_count": len(ret),
   291  		"search_pages":   pageCount,
   292  		"cost":           totalCost,
   293  		"remaining":      remaining,
   294  	})
   295  	log.Debug("Finished query")
   296  
   297  	// https://github.community/t/graphql-github-api-how-to-get-more-than-1000-pull-requests/13838/10
   298  	if len(ret) == 1000 {
   299  		log.Warning("Query returned 1k PRs, which is the max number of results per query allowed by GitHub. This indicates that we were not able to process all PRs.")
   300  	}
   301  	return ret, nil
   302  }
   303  
// pullRequest is the set of GraphQL PullRequest fields requested by the
// periodic search: number, author login, repository name and owner, labels,
// mergeability and state. HandleAll reads all of them.
// See: https://developer.github.com/v4/object/pullrequest/.
type pullRequest struct {
	Number githubql.Int
	Author struct {
		Login githubql.String
	}
	Repository struct {
		Name  githubql.String
		Owner struct {
			Login githubql.String
		}
	}
	// Only the first 100 labels of a PR are requested.
	Labels struct {
		Nodes []struct {
			Name githubql.String
		}
	} `graphql:"labels(first:100)"`
	Mergeable githubql.MergeableState
	State     githubql.PullRequestState
}
   324  
// searchQuery is the GraphQL query issued by search. It requests rate-limit
// information together with one page of up to 100 search results; the
// $query and $searchCursor variables are supplied by search.
// See: https://developer.github.com/v4/query/.
type searchQuery struct {
	// RateLimit carries the cost of this request and the remaining budget;
	// search only uses them for logging.
	RateLimit struct {
		Cost      githubql.Int
		Remaining githubql.Int
	}
	Search struct {
		// PageInfo drives pagination: search stops when HasNextPage is false
		// and otherwise passes EndCursor as the next $searchCursor.
		PageInfo struct {
			HasNextPage githubql.Boolean
			EndCursor   githubql.String
		}
		Nodes []struct {
			PullRequest pullRequest `graphql:"... on PullRequest"`
		}
	} `graphql:"search(type: ISSUE, first: 100, after: $searchCursor, query: $query)"`
}
   341  
   342  // constructQueries constructs the v4 queries for the peridic scan.
   343  // It returns a map[org][]query.
   344  func constructQueries(log *logrus.Entry, now time.Time, orgs, repos []string, usesGitHubAppsAuth bool) map[string][]string {
   345  	result := map[string][]string{}
   346  
   347  	// GitHub hard caps queries at 1k results, so always do one query per org and one for
   348  	// all repos. Ref: https://github.community/t/graphql-github-api-how-to-get-more-than-1000-pull-requests/13838/11
   349  	for _, org := range orgs {
   350  		// https://img.17qq.com/images/crqhcuueqhx.jpeg
   351  		if org == "kubernetes" {
   352  			result[org] = append(result[org], searchQueryPrefix+` org:"kubernetes" -repo:"kubernetes/kubernetes"`)
   353  
   354  			// Sharding by creation time > 2 months ago gives us around 50% of PRs per query (585 for the newer ones, 538 for the older ones when testing)
   355  			twoMonthsAgoISO8601 := now.Add(-2 * 30 * 24 * time.Hour).Format("2006-01-02")
   356  			result[org] = append(result[org], searchQueryPrefix+` repo:"kubernetes/kubernetes" created:>=`+twoMonthsAgoISO8601)
   357  			result[org] = append(result[org], searchQueryPrefix+` repo:"kubernetes/kubernetes" created:<`+twoMonthsAgoISO8601)
   358  		} else {
   359  			result[org] = append(result[org], searchQueryPrefix+` org:"`+org+`"`)
   360  		}
   361  	}
   362  
   363  	reposQueries := map[string]*bytes.Buffer{}
   364  	for _, repo := range repos {
   365  		slashSplit := strings.Split(repo, "/")
   366  		if n := len(slashSplit); n != 2 {
   367  			log.WithField("repo", repo).Warn("Found repo that was not in org/repo format, ignoring...")
   368  			continue
   369  		}
   370  		org := slashSplit[0]
   371  		if _, hasOrgQuery := result[org]; hasOrgQuery {
   372  			log.WithField("repo", repo).Warn("Plugin was enabled for repo even though it is already enabled for the org, ignoring...")
   373  			continue
   374  		}
   375  		var b *bytes.Buffer
   376  		if usesGitHubAppsAuth {
   377  			if reposQueries[org] == nil {
   378  				reposQueries[org] = bytes.NewBufferString(searchQueryPrefix)
   379  			}
   380  			b = reposQueries[org]
   381  		} else {
   382  			if reposQueries[""] == nil {
   383  				reposQueries[""] = bytes.NewBufferString(searchQueryPrefix)
   384  			}
   385  			b = reposQueries[""]
   386  		}
   387  		fmt.Fprintf(b, " repo:\"%s\"", repo)
   388  	}
   389  	for org, repoQuery := range reposQueries {
   390  		result[org] = append(result[org], repoQuery.String())
   391  	}
   392  
   393  	return result
   394  }
   395  
   396  type timeNow func() time.Time
   397  
   398  type Cache struct {
   399  	cache       map[int]time.Time
   400  	validTime   time.Duration
   401  	currentTime timeNow
   402  	mutex       sync.Mutex
   403  }
   404  
   405  func (cache *Cache) Get(key int) bool {
   406  	cache.mutex.Lock()
   407  	defer cache.mutex.Unlock()
   408  
   409  	insertTime := cache.cache[key]
   410  	curTime := cache.currentTime()
   411  	age := curTime.Sub(insertTime)
   412  	if age > 0 && age < cache.validTime {
   413  		return true
   414  	}
   415  	delete(cache.cache, key)
   416  	return false
   417  }
   418  
   419  func (cache *Cache) Set(key int) {
   420  	if cache.validTime > 0 {
   421  		cache.mutex.Lock()
   422  		defer cache.mutex.Unlock()
   423  
   424  		cache.cache[key] = cache.currentTime()
   425  	}
   426  }
   427  
   428  func (cache *Cache) Flush() {
   429  	cache.mutex.Lock()
   430  	defer cache.mutex.Unlock()
   431  	cache.cache = make(map[int]time.Time)
   432  }
   433  
   434  func NewCache(validTime int) *Cache {
   435  	return &Cache{
   436  		cache:       make(map[int]time.Time),
   437  		validTime:   time.Second * time.Duration(validTime),
   438  		currentTime: func() time.Time { return time.Now() },
   439  	}
   440  }