github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/cmd/pr/shared/finder.go (about)

     1  package shared
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"net/url"
     9  	"regexp"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/ungtb10d/cli/v2/api"
    16  	remotes "github.com/ungtb10d/cli/v2/context"
    17  	"github.com/ungtb10d/cli/v2/git"
    18  	fd "github.com/ungtb10d/cli/v2/internal/featuredetection"
    19  	"github.com/ungtb10d/cli/v2/internal/ghrepo"
    20  	"github.com/ungtb10d/cli/v2/pkg/cmdutil"
    21  	"github.com/ungtb10d/cli/v2/pkg/set"
    22  	"github.com/shurcooL/githubv4"
    23  	"golang.org/x/sync/errgroup"
    24  )
    25  
    26  type PRFinder interface {
    27  	Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error)
    28  }
    29  
    30  type progressIndicator interface {
    31  	StartProgressIndicator()
    32  	StopProgressIndicator()
    33  }
    34  
    35  type finder struct {
    36  	baseRepoFn   func() (ghrepo.Interface, error)
    37  	branchFn     func() (string, error)
    38  	remotesFn    func() (remotes.Remotes, error)
    39  	httpClient   func() (*http.Client, error)
    40  	branchConfig func(string) git.BranchConfig
    41  	progress     progressIndicator
    42  
    43  	repo       ghrepo.Interface
    44  	prNumber   int
    45  	branchName string
    46  }
    47  
    48  func NewFinder(factory *cmdutil.Factory) PRFinder {
    49  	if runCommandFinder != nil {
    50  		f := runCommandFinder
    51  		runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")}
    52  		return f
    53  	}
    54  
    55  	return &finder{
    56  		baseRepoFn: factory.BaseRepo,
    57  		branchFn:   factory.Branch,
    58  		remotesFn:  factory.Remotes,
    59  		httpClient: factory.HttpClient,
    60  		progress:   factory.IOStreams,
    61  		branchConfig: func(s string) git.BranchConfig {
    62  			return factory.GitClient.ReadBranchConfig(context.Background(), s)
    63  		},
    64  	}
    65  }
    66  
    67  var runCommandFinder PRFinder
    68  
    69  // RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests.
    70  func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
    71  	finder := NewMockFinder(selector, pr, repo)
    72  	runCommandFinder = finder
    73  	return finder
    74  }
    75  
    76  type FindOptions struct {
    77  	// Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or
    78  	// a PR URL.
    79  	Selector string
    80  	// Fields lists the GraphQL fields to fetch for the PullRequest.
    81  	Fields []string
    82  	// BaseBranch is the name of the base branch to scope the PR-for-branch lookup to.
    83  	BaseBranch string
    84  	// States lists the possible PR states to scope the PR-for-branch lookup to.
    85  	States []string
    86  }
    87  
    88  func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
    89  	if len(opts.Fields) == 0 {
    90  		return nil, nil, errors.New("Find error: no fields specified")
    91  	}
    92  
    93  	if repo, prNumber, err := f.parseURL(opts.Selector); err == nil {
    94  		f.prNumber = prNumber
    95  		f.repo = repo
    96  	}
    97  
    98  	if f.repo == nil {
    99  		repo, err := f.baseRepoFn()
   100  		if err != nil {
   101  			return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
   102  		}
   103  		f.repo = repo
   104  	}
   105  
   106  	if opts.Selector == "" {
   107  		if branch, prNumber, err := f.parseCurrentBranch(); err != nil {
   108  			return nil, nil, err
   109  		} else if prNumber > 0 {
   110  			f.prNumber = prNumber
   111  		} else {
   112  			f.branchName = branch
   113  		}
   114  	} else if f.prNumber == 0 {
   115  		// If opts.Selector is a valid number then assume it is the
   116  		// PR number unless opts.BaseBranch is specified. This is a
   117  		// special case for PR create command which will always want
   118  		// to assume that a numerical selector is a branch name rather
   119  		// than PR number.
   120  		prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#"))
   121  		if opts.BaseBranch == "" && err == nil {
   122  			f.prNumber = prNumber
   123  		} else {
   124  			f.branchName = opts.Selector
   125  		}
   126  	}
   127  
   128  	httpClient, err := f.httpClient()
   129  	if err != nil {
   130  		return nil, nil, err
   131  	}
   132  
   133  	// TODO(josebalius): Should we be guarding here?
   134  	if f.progress != nil {
   135  		f.progress.StartProgressIndicator()
   136  		defer f.progress.StopProgressIndicator()
   137  	}
   138  
   139  	fields := set.NewStringSet()
   140  	fields.AddValues(opts.Fields)
   141  	numberFieldOnly := fields.Len() == 1 && fields.Contains("number")
   142  	fields.Add("id") // for additional preload queries below
   143  
   144  	if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") {
   145  		cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24)
   146  		detector := fd.NewDetector(cachedClient, f.repo.RepoHost())
   147  		prFeatures, err := detector.PullRequestFeatures()
   148  		if err != nil {
   149  			return nil, nil, err
   150  		}
   151  		if !prFeatures.MergeQueue {
   152  			fields.Remove("isInMergeQueue")
   153  			fields.Remove("isMergeQueueEnabled")
   154  		}
   155  	}
   156  
   157  	var pr *api.PullRequest
   158  	if f.prNumber > 0 {
   159  		if numberFieldOnly {
   160  			// avoid hitting the API if we already have all the information
   161  			return &api.PullRequest{Number: f.prNumber}, f.repo, nil
   162  		}
   163  		pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice())
   164  	} else {
   165  		pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice())
   166  	}
   167  	if err != nil {
   168  		return pr, f.repo, err
   169  	}
   170  
   171  	g, _ := errgroup.WithContext(context.Background())
   172  	if fields.Contains("reviews") {
   173  		g.Go(func() error {
   174  			return preloadPrReviews(httpClient, f.repo, pr)
   175  		})
   176  	}
   177  	if fields.Contains("comments") {
   178  		g.Go(func() error {
   179  			return preloadPrComments(httpClient, f.repo, pr)
   180  		})
   181  	}
   182  	if fields.Contains("statusCheckRollup") {
   183  		g.Go(func() error {
   184  			return preloadPrChecks(httpClient, f.repo, pr)
   185  		})
   186  	}
   187  
   188  	return pr, f.repo, g.Wait()
   189  }
   190  
   191  var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
   192  
   193  func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
   194  	if prURL == "" {
   195  		return nil, 0, fmt.Errorf("invalid URL: %q", prURL)
   196  	}
   197  
   198  	u, err := url.Parse(prURL)
   199  	if err != nil {
   200  		return nil, 0, err
   201  	}
   202  
   203  	if u.Scheme != "https" && u.Scheme != "http" {
   204  		return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme)
   205  	}
   206  
   207  	m := pullURLRE.FindStringSubmatch(u.Path)
   208  	if m == nil {
   209  		return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL)
   210  	}
   211  
   212  	repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname())
   213  	prNumber, _ := strconv.Atoi(m[3])
   214  	return repo, prNumber, nil
   215  }
   216  
   217  var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)
   218  
   219  func (f *finder) parseCurrentBranch() (string, int, error) {
   220  	prHeadRef, err := f.branchFn()
   221  	if err != nil {
   222  		return "", 0, err
   223  	}
   224  
   225  	branchConfig := f.branchConfig(prHeadRef)
   226  
   227  	// the branch is configured to merge a special PR head ref
   228  	if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
   229  		prNumber, _ := strconv.Atoi(m[1])
   230  		return "", prNumber, nil
   231  	}
   232  
   233  	var branchOwner string
   234  	if branchConfig.RemoteURL != nil {
   235  		// the branch merges from a remote specified by URL
   236  		if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
   237  			branchOwner = r.RepoOwner()
   238  		}
   239  	} else if branchConfig.RemoteName != "" {
   240  		// the branch merges from a remote specified by name
   241  		rem, _ := f.remotesFn()
   242  		if r, err := rem.FindByName(branchConfig.RemoteName); err == nil {
   243  			branchOwner = r.RepoOwner()
   244  		}
   245  	}
   246  
   247  	if branchOwner != "" {
   248  		if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
   249  			prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
   250  		}
   251  		// prepend `OWNER:` if this branch is pushed to a fork
   252  		if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) {
   253  			prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef)
   254  		}
   255  	}
   256  
   257  	return prHeadRef, 0, nil
   258  }
   259  
   260  func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
   261  	type response struct {
   262  		Repository struct {
   263  			PullRequest api.PullRequest
   264  		}
   265  	}
   266  
   267  	query := fmt.Sprintf(`
   268  	query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
   269  		repository(owner: $owner, name: $repo) {
   270  			pullRequest(number: $pr_number) {%s}
   271  		}
   272  	}`, api.PullRequestGraphQL(fields))
   273  
   274  	variables := map[string]interface{}{
   275  		"owner":     repo.RepoOwner(),
   276  		"repo":      repo.RepoName(),
   277  		"pr_number": number,
   278  	}
   279  
   280  	var resp response
   281  	client := api.NewClientFromHTTP(httpClient)
   282  	err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	return &resp.Repository.PullRequest, nil
   288  }
   289  
   290  func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) {
   291  	type response struct {
   292  		Repository struct {
   293  			PullRequests struct {
   294  				Nodes []api.PullRequest
   295  			}
   296  			DefaultBranchRef struct {
   297  				Name string
   298  			}
   299  		}
   300  	}
   301  
   302  	fieldSet := set.NewStringSet()
   303  	fieldSet.AddValues(fields)
   304  	// these fields are required for filtering below
   305  	fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"})
   306  
   307  	query := fmt.Sprintf(`
   308  	query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
   309  		repository(owner: $owner, name: $repo) {
   310  			pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
   311  				nodes {%s}
   312  			}
   313  			defaultBranchRef { name }
   314  		}
   315  	}`, api.PullRequestGraphQL(fieldSet.ToSlice()))
   316  
   317  	branchWithoutOwner := headBranch
   318  	if idx := strings.Index(headBranch, ":"); idx >= 0 {
   319  		branchWithoutOwner = headBranch[idx+1:]
   320  	}
   321  
   322  	variables := map[string]interface{}{
   323  		"owner":       repo.RepoOwner(),
   324  		"repo":        repo.RepoName(),
   325  		"headRefName": branchWithoutOwner,
   326  		"states":      stateFilters,
   327  	}
   328  
   329  	var resp response
   330  	client := api.NewClientFromHTTP(httpClient)
   331  	err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  
   336  	prs := resp.Repository.PullRequests.Nodes
   337  	sort.SliceStable(prs, func(a, b int) bool {
   338  		return prs[a].State == "OPEN" && prs[b].State != "OPEN"
   339  	})
   340  
   341  	for _, pr := range prs {
   342  		if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) && (pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranch) {
   343  			return &pr, nil
   344  		}
   345  	}
   346  
   347  	return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
   348  }
   349  
   350  func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
   351  	if !pr.Reviews.PageInfo.HasNextPage {
   352  		return nil
   353  	}
   354  
   355  	type response struct {
   356  		Node struct {
   357  			PullRequest struct {
   358  				Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"`
   359  			} `graphql:"...on PullRequest"`
   360  		} `graphql:"node(id: $id)"`
   361  	}
   362  
   363  	variables := map[string]interface{}{
   364  		"id":        githubv4.ID(pr.ID),
   365  		"endCursor": githubv4.String(pr.Reviews.PageInfo.EndCursor),
   366  	}
   367  
   368  	gql := api.NewClientFromHTTP(httpClient)
   369  
   370  	for {
   371  		var query response
   372  		err := gql.Query(repo.RepoHost(), "ReviewsForPullRequest", &query, variables)
   373  		if err != nil {
   374  			return err
   375  		}
   376  
   377  		pr.Reviews.Nodes = append(pr.Reviews.Nodes, query.Node.PullRequest.Reviews.Nodes...)
   378  		pr.Reviews.TotalCount = len(pr.Reviews.Nodes)
   379  
   380  		if !query.Node.PullRequest.Reviews.PageInfo.HasNextPage {
   381  			break
   382  		}
   383  		variables["endCursor"] = githubv4.String(query.Node.PullRequest.Reviews.PageInfo.EndCursor)
   384  	}
   385  
   386  	pr.Reviews.PageInfo.HasNextPage = false
   387  	return nil
   388  }
   389  
   390  func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
   391  	if !pr.Comments.PageInfo.HasNextPage {
   392  		return nil
   393  	}
   394  
   395  	type response struct {
   396  		Node struct {
   397  			PullRequest struct {
   398  				Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"`
   399  			} `graphql:"...on PullRequest"`
   400  		} `graphql:"node(id: $id)"`
   401  	}
   402  
   403  	variables := map[string]interface{}{
   404  		"id":        githubv4.ID(pr.ID),
   405  		"endCursor": githubv4.String(pr.Comments.PageInfo.EndCursor),
   406  	}
   407  
   408  	gql := api.NewClientFromHTTP(client)
   409  
   410  	for {
   411  		var query response
   412  		err := gql.Query(repo.RepoHost(), "CommentsForPullRequest", &query, variables)
   413  		if err != nil {
   414  			return err
   415  		}
   416  
   417  		pr.Comments.Nodes = append(pr.Comments.Nodes, query.Node.PullRequest.Comments.Nodes...)
   418  		pr.Comments.TotalCount = len(pr.Comments.Nodes)
   419  
   420  		if !query.Node.PullRequest.Comments.PageInfo.HasNextPage {
   421  			break
   422  		}
   423  		variables["endCursor"] = githubv4.String(query.Node.PullRequest.Comments.PageInfo.EndCursor)
   424  	}
   425  
   426  	pr.Comments.PageInfo.HasNextPage = false
   427  	return nil
   428  }
   429  
   430  func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
   431  	if len(pr.StatusCheckRollup.Nodes) == 0 {
   432  		return nil
   433  	}
   434  	statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
   435  	if !statusCheckRollup.PageInfo.HasNextPage {
   436  		return nil
   437  	}
   438  
   439  	endCursor := statusCheckRollup.PageInfo.EndCursor
   440  
   441  	type response struct {
   442  		Node *api.PullRequest
   443  	}
   444  
   445  	query := fmt.Sprintf(`
   446  	query PullRequestStatusChecks($id: ID!, $endCursor: String!) {
   447  		node(id: $id) {
   448  			...on PullRequest {
   449  				%s
   450  			}
   451  		}
   452  	}`, api.StatusCheckRollupGraphQL("$endCursor"))
   453  
   454  	variables := map[string]interface{}{
   455  		"id": pr.ID,
   456  	}
   457  
   458  	apiClient := api.NewClientFromHTTP(client)
   459  	for {
   460  		variables["endCursor"] = endCursor
   461  		var resp response
   462  		err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
   463  		if err != nil {
   464  			return err
   465  		}
   466  
   467  		result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
   468  		statusCheckRollup.Nodes = append(
   469  			statusCheckRollup.Nodes,
   470  			result.Nodes...,
   471  		)
   472  
   473  		if !result.PageInfo.HasNextPage {
   474  			break
   475  		}
   476  		endCursor = result.PageInfo.EndCursor
   477  	}
   478  
   479  	statusCheckRollup.PageInfo.HasNextPage = false
   480  	return nil
   481  }
   482  
   483  type NotFoundError struct {
   484  	error
   485  }
   486  
   487  func (err *NotFoundError) Unwrap() error {
   488  	return err.error
   489  }
   490  
   491  func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
   492  	var err error
   493  	if pr == nil {
   494  		err = &NotFoundError{errors.New("no pull requests found")}
   495  	}
   496  	return &mockFinder{
   497  		expectSelector: selector,
   498  		pr:             pr,
   499  		repo:           repo,
   500  		err:            err,
   501  	}
   502  }
   503  
   504  type mockFinder struct {
   505  	called         bool
   506  	expectSelector string
   507  	expectFields   []string
   508  	pr             *api.PullRequest
   509  	repo           ghrepo.Interface
   510  	err            error
   511  }
   512  
   513  func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
   514  	if m.err != nil {
   515  		return nil, nil, m.err
   516  	}
   517  	if m.expectSelector != opts.Selector {
   518  		return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
   519  	}
   520  	if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) {
   521  		return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
   522  	}
   523  	if m.called {
   524  		return nil, nil, errors.New("mockFinder used more than once")
   525  	}
   526  	m.called = true
   527  
   528  	if m.pr.HeadRepositoryOwner.Login == "" {
   529  		// pose as same-repo PR by default
   530  		m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner()
   531  	}
   532  
   533  	return m.pr, m.repo, nil
   534  }
   535  
   536  func (m *mockFinder) ExpectFields(fields []string) {
   537  	m.expectFields = fields
   538  }
   539  
   540  func isEqualSet(a, b []string) bool {
   541  	if len(a) != len(b) {
   542  		return false
   543  	}
   544  
   545  	aCopy := make([]string, len(a))
   546  	copy(aCopy, a)
   547  	bCopy := make([]string, len(b))
   548  	copy(bCopy, b)
   549  	sort.Strings(aCopy)
   550  	sort.Strings(bCopy)
   551  
   552  	for i := range aCopy {
   553  		if aCopy[i] != bCopy[i] {
   554  			return false
   555  		}
   556  	}
   557  	return true
   558  }