github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/pkg/cmd/pr/shared/finder.go (about)

     1  package shared
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"net/url"
     9  	"regexp"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/cli/cli/api"
    15  	remotes "github.com/cli/cli/context"
    16  	"github.com/cli/cli/git"
    17  	"github.com/cli/cli/internal/ghinstance"
    18  	"github.com/cli/cli/internal/ghrepo"
    19  	"github.com/cli/cli/pkg/cmdutil"
    20  	"github.com/cli/cli/pkg/set"
    21  	"github.com/shurcooL/githubv4"
    22  	"github.com/shurcooL/graphql"
    23  	"golang.org/x/sync/errgroup"
    24  )
    25  
// PRFinder resolves a single pull request from a FindOptions description.
// Find returns the pull request together with the repository it was looked up
// in. NewFinder returns the production implementation; NewMockFinder and
// RunCommandFinder provide stubs for tests.
type PRFinder interface {
	Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error)
}
    29  
// progressIndicator is the minimal interface the finder needs to signal that
// API requests are in flight. Find starts it before issuing requests and stops
// it when it returns; a nil progressIndicator disables it.
type progressIndicator interface {
	StartProgressIndicator()
	StopProgressIndicator()
}
    34  
// finder is the PRFinder returned by NewFinder. Its function-valued fields
// supply repository, git, and HTTP access (normally taken from a
// cmdutil.Factory); the trailing fields hold lookup state that Find fills in
// while resolving a selector.
type finder struct {
	// baseRepoFn resolves the repository to search when the selector is not a
	// pull request URL.
	baseRepoFn   func() (ghrepo.Interface, error)
	// branchFn returns the name of the current git branch.
	branchFn     func() (string, error)
	// remotesFn lists git remotes; used to resolve a branch's remote by name.
	remotesFn    func() (remotes.Remotes, error)
	// httpClient constructs the HTTP client used for API requests.
	httpClient   func() (*http.Client, error)
	// branchConfig reads the git configuration of the named branch
	// (git.ReadBranchConfig in production).
	branchConfig func(string) git.BranchConfig
	// progress is optional; when non-nil it is started and stopped around the
	// API requests made by Find.
	progress     progressIndicator

	// repo is the repository the pull request is looked up in.
	repo       ghrepo.Interface
	// prNumber is the PR number to fetch; 0 means search by branchName instead.
	prNumber   int
	// branchName is the head branch (optionally `OWNER:`-prefixed) to search for.
	branchName string
}
    47  
    48  func NewFinder(factory *cmdutil.Factory) PRFinder {
    49  	if runCommandFinder != nil {
    50  		f := runCommandFinder
    51  		runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")}
    52  		return f
    53  	}
    54  
    55  	return &finder{
    56  		baseRepoFn:   factory.BaseRepo,
    57  		branchFn:     factory.Branch,
    58  		remotesFn:    factory.Remotes,
    59  		httpClient:   factory.HttpClient,
    60  		progress:     factory.IOStreams,
    61  		branchConfig: git.ReadBranchConfig,
    62  	}
    63  }
    64  
    65  var runCommandFinder PRFinder
    66  
    67  // RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests.
    68  func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
    69  	finder := NewMockFinder(selector, pr, repo)
    70  	runCommandFinder = finder
    71  	return finder
    72  }
    73  
// FindOptions describes the pull request to look up and the data to fetch for
// it.
type FindOptions struct {
	// Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or
	// a PR URL. An empty Selector means the pull request for the current git branch.
	Selector string
	// Fields lists the GraphQL fields to fetch for the PullRequest. At least
	// one field is required; Find returns an error otherwise.
	Fields []string
	// BaseBranch is the name of the base branch to scope the PR-for-branch lookup to.
	// Empty means any base branch matches.
	BaseBranch string
	// States lists the possible PR states to scope the PR-for-branch lookup to.
	// It is passed through as the `states` filter of the GraphQL query.
	States []string
}
    85  
    86  func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
    87  	if len(opts.Fields) == 0 {
    88  		return nil, nil, errors.New("Find error: no fields specified")
    89  	}
    90  
    91  	if repo, prNumber, err := f.parseURL(opts.Selector); err == nil {
    92  		f.prNumber = prNumber
    93  		f.repo = repo
    94  	}
    95  
    96  	if f.repo == nil {
    97  		repo, err := f.baseRepoFn()
    98  		if err != nil {
    99  			return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
   100  		}
   101  		f.repo = repo
   102  	}
   103  
   104  	if opts.Selector == "" {
   105  		if branch, prNumber, err := f.parseCurrentBranch(); err != nil {
   106  			return nil, nil, err
   107  		} else if prNumber > 0 {
   108  			f.prNumber = prNumber
   109  		} else {
   110  			f.branchName = branch
   111  		}
   112  	} else if f.prNumber == 0 {
   113  		if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil {
   114  			f.prNumber = prNumber
   115  		} else {
   116  			f.branchName = opts.Selector
   117  		}
   118  	}
   119  
   120  	httpClient, err := f.httpClient()
   121  	if err != nil {
   122  		return nil, nil, err
   123  	}
   124  
   125  	if f.progress != nil {
   126  		f.progress.StartProgressIndicator()
   127  		defer f.progress.StopProgressIndicator()
   128  	}
   129  
   130  	fields := set.NewStringSet()
   131  	fields.AddValues(opts.Fields)
   132  	numberFieldOnly := fields.Len() == 1 && fields.Contains("number")
   133  	fields.Add("id") // for additional preload queries below
   134  
   135  	var pr *api.PullRequest
   136  	if f.prNumber > 0 {
   137  		if numberFieldOnly {
   138  			// avoid hitting the API if we already have all the information
   139  			return &api.PullRequest{Number: f.prNumber}, f.repo, nil
   140  		}
   141  		pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice())
   142  	} else {
   143  		pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice())
   144  	}
   145  	if err != nil {
   146  		return pr, f.repo, err
   147  	}
   148  
   149  	g, _ := errgroup.WithContext(context.Background())
   150  	if fields.Contains("reviews") {
   151  		g.Go(func() error {
   152  			return preloadPrReviews(httpClient, f.repo, pr)
   153  		})
   154  	}
   155  	if fields.Contains("comments") {
   156  		g.Go(func() error {
   157  			return preloadPrComments(httpClient, f.repo, pr)
   158  		})
   159  	}
   160  	if fields.Contains("statusCheckRollup") {
   161  		g.Go(func() error {
   162  			return preloadPrChecks(httpClient, f.repo, pr)
   163  		})
   164  	}
   165  
   166  	return pr, f.repo, g.Wait()
   167  }
   168  
   169  var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
   170  
   171  func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
   172  	if prURL == "" {
   173  		return nil, 0, fmt.Errorf("invalid URL: %q", prURL)
   174  	}
   175  
   176  	u, err := url.Parse(prURL)
   177  	if err != nil {
   178  		return nil, 0, err
   179  	}
   180  
   181  	if u.Scheme != "https" && u.Scheme != "http" {
   182  		return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme)
   183  	}
   184  
   185  	m := pullURLRE.FindStringSubmatch(u.Path)
   186  	if m == nil {
   187  		return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL)
   188  	}
   189  
   190  	repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname())
   191  	prNumber, _ := strconv.Atoi(m[3])
   192  	return repo, prNumber, nil
   193  }
   194  
   195  var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)
   196  
   197  func (f *finder) parseCurrentBranch() (string, int, error) {
   198  	prHeadRef, err := f.branchFn()
   199  	if err != nil {
   200  		return "", 0, err
   201  	}
   202  
   203  	branchConfig := f.branchConfig(prHeadRef)
   204  
   205  	// the branch is configured to merge a special PR head ref
   206  	if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
   207  		prNumber, _ := strconv.Atoi(m[1])
   208  		return "", prNumber, nil
   209  	}
   210  
   211  	var branchOwner string
   212  	if branchConfig.RemoteURL != nil {
   213  		// the branch merges from a remote specified by URL
   214  		if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
   215  			branchOwner = r.RepoOwner()
   216  		}
   217  	} else if branchConfig.RemoteName != "" {
   218  		// the branch merges from a remote specified by name
   219  		rem, _ := f.remotesFn()
   220  		if r, err := rem.FindByName(branchConfig.RemoteName); err == nil {
   221  			branchOwner = r.RepoOwner()
   222  		}
   223  	}
   224  
   225  	if branchOwner != "" {
   226  		if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
   227  			prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
   228  		}
   229  		// prepend `OWNER:` if this branch is pushed to a fork
   230  		if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) {
   231  			prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef)
   232  		}
   233  	}
   234  
   235  	return prHeadRef, 0, nil
   236  }
   237  
   238  func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
   239  	type response struct {
   240  		Repository struct {
   241  			PullRequest api.PullRequest
   242  		}
   243  	}
   244  
   245  	query := fmt.Sprintf(`
   246  	query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
   247  		repository(owner: $owner, name: $repo) {
   248  			pullRequest(number: $pr_number) {%s}
   249  		}
   250  	}`, api.PullRequestGraphQL(fields))
   251  
   252  	variables := map[string]interface{}{
   253  		"owner":     repo.RepoOwner(),
   254  		"repo":      repo.RepoName(),
   255  		"pr_number": number,
   256  	}
   257  
   258  	var resp response
   259  	client := api.NewClientFromHTTP(httpClient)
   260  	err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
   261  	if err != nil {
   262  		return nil, err
   263  	}
   264  
   265  	return &resp.Repository.PullRequest, nil
   266  }
   267  
   268  func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) {
   269  	type response struct {
   270  		Repository struct {
   271  			PullRequests struct {
   272  				Nodes []api.PullRequest
   273  			}
   274  		}
   275  	}
   276  
   277  	fieldSet := set.NewStringSet()
   278  	fieldSet.AddValues(fields)
   279  	// these fields are required for filtering below
   280  	fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"})
   281  
   282  	query := fmt.Sprintf(`
   283  	query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
   284  		repository(owner: $owner, name: $repo) {
   285  			pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
   286  				nodes {%s}
   287  			}
   288  		}
   289  	}`, api.PullRequestGraphQL(fieldSet.ToSlice()))
   290  
   291  	branchWithoutOwner := headBranch
   292  	if idx := strings.Index(headBranch, ":"); idx >= 0 {
   293  		branchWithoutOwner = headBranch[idx+1:]
   294  	}
   295  
   296  	variables := map[string]interface{}{
   297  		"owner":       repo.RepoOwner(),
   298  		"repo":        repo.RepoName(),
   299  		"headRefName": branchWithoutOwner,
   300  		"states":      stateFilters,
   301  	}
   302  
   303  	var resp response
   304  	client := api.NewClientFromHTTP(httpClient)
   305  	err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
   306  	if err != nil {
   307  		return nil, err
   308  	}
   309  
   310  	prs := resp.Repository.PullRequests.Nodes
   311  	sort.SliceStable(prs, func(a, b int) bool {
   312  		return prs[a].State == "OPEN" && prs[b].State != "OPEN"
   313  	})
   314  
   315  	for _, pr := range prs {
   316  		if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) {
   317  			return &pr, nil
   318  		}
   319  	}
   320  
   321  	return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
   322  }
   323  
   324  func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
   325  	if !pr.Reviews.PageInfo.HasNextPage {
   326  		return nil
   327  	}
   328  
   329  	type response struct {
   330  		Node struct {
   331  			PullRequest struct {
   332  				Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"`
   333  			} `graphql:"...on PullRequest"`
   334  		} `graphql:"node(id: $id)"`
   335  	}
   336  
   337  	variables := map[string]interface{}{
   338  		"id":        githubv4.ID(pr.ID),
   339  		"endCursor": githubv4.String(pr.Reviews.PageInfo.EndCursor),
   340  	}
   341  
   342  	gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), httpClient)
   343  
   344  	for {
   345  		var query response
   346  		err := gql.QueryNamed(context.Background(), "ReviewsForPullRequest", &query, variables)
   347  		if err != nil {
   348  			return err
   349  		}
   350  
   351  		pr.Reviews.Nodes = append(pr.Reviews.Nodes, query.Node.PullRequest.Reviews.Nodes...)
   352  		pr.Reviews.TotalCount = len(pr.Reviews.Nodes)
   353  
   354  		if !query.Node.PullRequest.Reviews.PageInfo.HasNextPage {
   355  			break
   356  		}
   357  		variables["endCursor"] = githubv4.String(query.Node.PullRequest.Reviews.PageInfo.EndCursor)
   358  	}
   359  
   360  	pr.Reviews.PageInfo.HasNextPage = false
   361  	return nil
   362  }
   363  
   364  func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
   365  	if !pr.Comments.PageInfo.HasNextPage {
   366  		return nil
   367  	}
   368  
   369  	type response struct {
   370  		Node struct {
   371  			PullRequest struct {
   372  				Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"`
   373  			} `graphql:"...on PullRequest"`
   374  		} `graphql:"node(id: $id)"`
   375  	}
   376  
   377  	variables := map[string]interface{}{
   378  		"id":        githubv4.ID(pr.ID),
   379  		"endCursor": githubv4.String(pr.Comments.PageInfo.EndCursor),
   380  	}
   381  
   382  	gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), client)
   383  
   384  	for {
   385  		var query response
   386  		err := gql.QueryNamed(context.Background(), "CommentsForPullRequest", &query, variables)
   387  		if err != nil {
   388  			return err
   389  		}
   390  
   391  		pr.Comments.Nodes = append(pr.Comments.Nodes, query.Node.PullRequest.Comments.Nodes...)
   392  		pr.Comments.TotalCount = len(pr.Comments.Nodes)
   393  
   394  		if !query.Node.PullRequest.Comments.PageInfo.HasNextPage {
   395  			break
   396  		}
   397  		variables["endCursor"] = githubv4.String(query.Node.PullRequest.Comments.PageInfo.EndCursor)
   398  	}
   399  
   400  	pr.Comments.PageInfo.HasNextPage = false
   401  	return nil
   402  }
   403  
   404  func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
   405  	if len(pr.StatusCheckRollup.Nodes) == 0 {
   406  		return nil
   407  	}
   408  	statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
   409  	if !statusCheckRollup.PageInfo.HasNextPage {
   410  		return nil
   411  	}
   412  
   413  	endCursor := statusCheckRollup.PageInfo.EndCursor
   414  
   415  	type response struct {
   416  		Node *api.PullRequest
   417  	}
   418  
   419  	query := fmt.Sprintf(`
   420  	query PullRequestStatusChecks($id: ID!, $endCursor: String!) {
   421  		node(id: $id) {
   422  			...on PullRequest {
   423  				%s
   424  			}
   425  		}
   426  	}`, api.StatusCheckRollupGraphQL("$endCursor"))
   427  
   428  	variables := map[string]interface{}{
   429  		"id": pr.ID,
   430  	}
   431  
   432  	apiClient := api.NewClientFromHTTP(client)
   433  	for {
   434  		variables["endCursor"] = endCursor
   435  		var resp response
   436  		err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
   437  		if err != nil {
   438  			return err
   439  		}
   440  
   441  		result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
   442  		statusCheckRollup.Nodes = append(
   443  			statusCheckRollup.Nodes,
   444  			result.Nodes...,
   445  		)
   446  
   447  		if !result.PageInfo.HasNextPage {
   448  			break
   449  		}
   450  		endCursor = result.PageInfo.EndCursor
   451  	}
   452  
   453  	statusCheckRollup.PageInfo.HasNextPage = false
   454  	return nil
   455  }
   456  
   457  type NotFoundError struct {
   458  	error
   459  }
   460  
   461  func (err *NotFoundError) Unwrap() error {
   462  	return err.error
   463  }
   464  
   465  func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
   466  	var err error
   467  	if pr == nil {
   468  		err = &NotFoundError{errors.New("no pull requests found")}
   469  	}
   470  	return &mockFinder{
   471  		expectSelector: selector,
   472  		pr:             pr,
   473  		repo:           repo,
   474  		err:            err,
   475  	}
   476  }
   477  
// mockFinder is a PRFinder stub for tests. Its Find returns a canned pull
// request and repository after checking the call against its expectations.
type mockFinder struct {
	// called records whether Find has already succeeded once.
	called         bool
	// expectSelector is the selector Find must be called with.
	expectSelector string
	// expectFields, when non-empty, is the field list Find must be called with
	// (order does not matter).
	expectFields   []string
	// pr and repo are the canned results returned by Find.
	pr             *api.PullRequest
	repo           ghrepo.Interface
	// err, when set, is returned by every Find call before any other check.
	err            error
}
   486  
   487  func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
   488  	if m.err != nil {
   489  		return nil, nil, m.err
   490  	}
   491  	if m.expectSelector != opts.Selector {
   492  		return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
   493  	}
   494  	if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) {
   495  		return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
   496  	}
   497  	if m.called {
   498  		return nil, nil, errors.New("mockFinder used more than once")
   499  	}
   500  	m.called = true
   501  
   502  	if m.pr.HeadRepositoryOwner.Login == "" {
   503  		// pose as same-repo PR by default
   504  		m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner()
   505  	}
   506  
   507  	return m.pr, m.repo, nil
   508  }
   509  
   510  func (m *mockFinder) ExpectFields(fields []string) {
   511  	m.expectFields = fields
   512  }
   513  
   514  func isEqualSet(a, b []string) bool {
   515  	if len(a) != len(b) {
   516  		return false
   517  	}
   518  
   519  	aCopy := make([]string, len(a))
   520  	copy(aCopy, a)
   521  	bCopy := make([]string, len(b))
   522  	copy(bCopy, b)
   523  	sort.Strings(aCopy)
   524  	sort.Strings(bCopy)
   525  
   526  	for i := range aCopy {
   527  		if aCopy[i] != bCopy[i] {
   528  			return false
   529  		}
   530  	}
   531  	return true
   532  }