github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/pkg/cmd/pr/shared/finder.go (about) 1 package shared 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net/http" 8 "net/url" 9 "regexp" 10 "sort" 11 "strconv" 12 "strings" 13 14 "github.com/cli/cli/api" 15 remotes "github.com/cli/cli/context" 16 "github.com/cli/cli/git" 17 "github.com/cli/cli/internal/ghinstance" 18 "github.com/cli/cli/internal/ghrepo" 19 "github.com/cli/cli/pkg/cmdutil" 20 "github.com/cli/cli/pkg/set" 21 "github.com/shurcooL/githubv4" 22 "github.com/shurcooL/graphql" 23 "golang.org/x/sync/errgroup" 24 ) 25 26 type PRFinder interface { 27 Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) 28 } 29 30 type progressIndicator interface { 31 StartProgressIndicator() 32 StopProgressIndicator() 33 } 34 35 type finder struct { 36 baseRepoFn func() (ghrepo.Interface, error) 37 branchFn func() (string, error) 38 remotesFn func() (remotes.Remotes, error) 39 httpClient func() (*http.Client, error) 40 branchConfig func(string) git.BranchConfig 41 progress progressIndicator 42 43 repo ghrepo.Interface 44 prNumber int 45 branchName string 46 } 47 48 func NewFinder(factory *cmdutil.Factory) PRFinder { 49 if runCommandFinder != nil { 50 f := runCommandFinder 51 runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")} 52 return f 53 } 54 55 return &finder{ 56 baseRepoFn: factory.BaseRepo, 57 branchFn: factory.Branch, 58 remotesFn: factory.Remotes, 59 httpClient: factory.HttpClient, 60 progress: factory.IOStreams, 61 branchConfig: git.ReadBranchConfig, 62 } 63 } 64 65 var runCommandFinder PRFinder 66 67 // RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests. 68 func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { 69 finder := NewMockFinder(selector, pr, repo) 70 runCommandFinder = finder 71 return finder 72 } 73 74 type FindOptions struct { 75 // Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or 76 // a PR URL. 77 Selector string 78 // Fields lists the GraphQL fields to fetch for the PullRequest. 79 Fields []string 80 // BaseBranch is the name of the base branch to scope the PR-for-branch lookup to. 81 BaseBranch string 82 // States lists the possible PR states to scope the PR-for-branch lookup to. 83 States []string 84 } 85 86 func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { 87 if len(opts.Fields) == 0 { 88 return nil, nil, errors.New("Find error: no fields specified") 89 } 90 91 if repo, prNumber, err := f.parseURL(opts.Selector); err == nil { 92 f.prNumber = prNumber 93 f.repo = repo 94 } 95 96 if f.repo == nil { 97 repo, err := f.baseRepoFn() 98 if err != nil { 99 return nil, nil, fmt.Errorf("could not determine base repo: %w", err) 100 } 101 f.repo = repo 102 } 103 104 if opts.Selector == "" { 105 if branch, prNumber, err := f.parseCurrentBranch(); err != nil { 106 return nil, nil, err 107 } else if prNumber > 0 { 108 f.prNumber = prNumber 109 } else { 110 f.branchName = branch 111 } 112 } else if f.prNumber == 0 { 113 if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil { 114 f.prNumber = prNumber 115 } else { 116 f.branchName = opts.Selector 117 } 118 } 119 120 httpClient, err := f.httpClient() 121 if err != nil { 122 return nil, nil, err 123 } 124 125 if f.progress != nil { 126 f.progress.StartProgressIndicator() 127 defer f.progress.StopProgressIndicator() 128 } 129 130 fields := set.NewStringSet() 131 fields.AddValues(opts.Fields) 132 numberFieldOnly := fields.Len() == 1 && fields.Contains("number") 133 fields.Add("id") // for additional preload queries below 134 135 var pr *api.PullRequest 136 if f.prNumber > 0 { 137 if numberFieldOnly { 138 // avoid hitting the API if we already have all the information 139 return &api.PullRequest{Number: f.prNumber}, f.repo, nil 140 } 141 pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice()) 142 } else { 143 pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice()) 144 } 145 if err != nil { 146 return pr, f.repo, err 147 } 148 149 g, _ := errgroup.WithContext(context.Background()) 150 if fields.Contains("reviews") { 151 g.Go(func() error { 152 return preloadPrReviews(httpClient, f.repo, pr) 153 }) 154 } 155 if fields.Contains("comments") { 156 g.Go(func() error { 157 return preloadPrComments(httpClient, f.repo, pr) 158 }) 159 } 160 if fields.Contains("statusCheckRollup") { 161 g.Go(func() error { 162 return preloadPrChecks(httpClient, f.repo, pr) 163 }) 164 } 165 166 return pr, f.repo, g.Wait() 167 } 168 169 var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) 170 171 func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) { 172 if prURL == "" { 173 return nil, 0, fmt.Errorf("invalid URL: %q", prURL) 174 } 175 176 u, err := url.Parse(prURL) 177 if err != nil { 178 return nil, 0, err 179 } 180 181 if u.Scheme != "https" && u.Scheme != "http" { 182 return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme) 183 } 184 185 m := pullURLRE.FindStringSubmatch(u.Path) 186 if m == nil { 187 return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL) 188 } 189 190 repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname()) 191 prNumber, _ := strconv.Atoi(m[3]) 192 return repo, prNumber, nil 193 } 194 195 var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`) 196 197 func (f *finder) parseCurrentBranch() (string, int, error) { 198 prHeadRef, err := f.branchFn() 199 if err != nil { 200 return "", 0, err 201 } 202 203 branchConfig := f.branchConfig(prHeadRef) 204 205 // the branch is configured to merge a special PR head ref 206 if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { 207 prNumber, _ := strconv.Atoi(m[1]) 208 return "", prNumber, nil 209 } 210 211 var branchOwner string 212 if branchConfig.RemoteURL != nil { 213 // the branch merges from a remote specified by URL 214 if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { 215 branchOwner = r.RepoOwner() 216 } 217 } else if branchConfig.RemoteName != "" { 218 // the branch merges from a remote specified by name 219 rem, _ := f.remotesFn() 220 if r, err := rem.FindByName(branchConfig.RemoteName); err == nil { 221 branchOwner = r.RepoOwner() 222 } 223 } 224 225 if branchOwner != "" { 226 if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { 227 prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") 228 } 229 // prepend `OWNER:` if this branch is pushed to a fork 230 if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) { 231 prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef) 232 } 233 } 234 235 return prHeadRef, 0, nil 236 } 237 238 func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { 239 type response struct { 240 Repository struct { 241 PullRequest api.PullRequest 242 } 243 } 244 245 query := fmt.Sprintf(` 246 query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) { 247 repository(owner: $owner, name: $repo) { 248 pullRequest(number: $pr_number) {%s} 249 } 250 }`, api.PullRequestGraphQL(fields)) 251 252 variables := map[string]interface{}{ 253 "owner": repo.RepoOwner(), 254 "repo": repo.RepoName(), 255 "pr_number": number, 256 } 257 258 var resp response 259 client := api.NewClientFromHTTP(httpClient) 260 err := client.GraphQL(repo.RepoHost(), query, variables, &resp) 261 if err != nil { 262 return nil, err 263 } 264 265 return &resp.Repository.PullRequest, nil 266 } 267 268 func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) { 269 type response struct { 270 Repository struct { 271 PullRequests struct { 272 Nodes []api.PullRequest 273 } 274 } 275 } 276 277 fieldSet := set.NewStringSet() 278 fieldSet.AddValues(fields) 279 // these fields are required for filtering below 280 fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"}) 281 282 query := fmt.Sprintf(` 283 query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) { 284 repository(owner: $owner, name: $repo) { 285 pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) { 286 nodes {%s} 287 } 288 } 289 }`, api.PullRequestGraphQL(fieldSet.ToSlice())) 290 291 branchWithoutOwner := headBranch 292 if idx := strings.Index(headBranch, ":"); idx >= 0 { 293 branchWithoutOwner = headBranch[idx+1:] 294 } 295 296 variables := map[string]interface{}{ 297 "owner": repo.RepoOwner(), 298 "repo": repo.RepoName(), 299 "headRefName": branchWithoutOwner, 300 "states": stateFilters, 301 } 302 303 var resp response 304 client := api.NewClientFromHTTP(httpClient) 305 err := client.GraphQL(repo.RepoHost(), query, variables, &resp) 306 if err != nil { 307 return nil, err 308 } 309 310 prs := resp.Repository.PullRequests.Nodes 311 sort.SliceStable(prs, func(a, b int) bool { 312 return prs[a].State == "OPEN" && prs[b].State != "OPEN" 313 }) 314 315 for _, pr := range prs { 316 if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) { 317 return &pr, nil 318 } 319 } 320 321 return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)} 322 } 323 324 func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 325 if !pr.Reviews.PageInfo.HasNextPage { 326 return nil 327 } 328 329 type response struct { 330 Node struct { 331 PullRequest struct { 332 Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"` 333 } `graphql:"...on PullRequest"` 334 } `graphql:"node(id: $id)"` 335 } 336 337 variables := map[string]interface{}{ 338 "id": githubv4.ID(pr.ID), 339 "endCursor": githubv4.String(pr.Reviews.PageInfo.EndCursor), 340 } 341 342 gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), httpClient) 343 344 for { 345 var query response 346 err := gql.QueryNamed(context.Background(), "ReviewsForPullRequest", &query, variables) 347 if err != nil { 348 return err 349 } 350 351 pr.Reviews.Nodes = append(pr.Reviews.Nodes, query.Node.PullRequest.Reviews.Nodes...) 352 pr.Reviews.TotalCount = len(pr.Reviews.Nodes) 353 354 if !query.Node.PullRequest.Reviews.PageInfo.HasNextPage { 355 break 356 } 357 variables["endCursor"] = githubv4.String(query.Node.PullRequest.Reviews.PageInfo.EndCursor) 358 } 359 360 pr.Reviews.PageInfo.HasNextPage = false 361 return nil 362 } 363 364 func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 365 if !pr.Comments.PageInfo.HasNextPage { 366 return nil 367 } 368 369 type response struct { 370 Node struct { 371 PullRequest struct { 372 Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"` 373 } `graphql:"...on PullRequest"` 374 } `graphql:"node(id: $id)"` 375 } 376 377 variables := map[string]interface{}{ 378 "id": githubv4.ID(pr.ID), 379 "endCursor": githubv4.String(pr.Comments.PageInfo.EndCursor), 380 } 381 382 gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), client) 383 384 for { 385 var query response 386 err := gql.QueryNamed(context.Background(), "CommentsForPullRequest", &query, variables) 387 if err != nil { 388 return err 389 } 390 391 pr.Comments.Nodes = append(pr.Comments.Nodes, query.Node.PullRequest.Comments.Nodes...) 392 pr.Comments.TotalCount = len(pr.Comments.Nodes) 393 394 if !query.Node.PullRequest.Comments.PageInfo.HasNextPage { 395 break 396 } 397 variables["endCursor"] = githubv4.String(query.Node.PullRequest.Comments.PageInfo.EndCursor) 398 } 399 400 pr.Comments.PageInfo.HasNextPage = false 401 return nil 402 } 403 404 func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 405 if len(pr.StatusCheckRollup.Nodes) == 0 { 406 return nil 407 } 408 statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts 409 if !statusCheckRollup.PageInfo.HasNextPage { 410 return nil 411 } 412 413 endCursor := statusCheckRollup.PageInfo.EndCursor 414 415 type response struct { 416 Node *api.PullRequest 417 } 418 419 query := fmt.Sprintf(` 420 query PullRequestStatusChecks($id: ID!, $endCursor: String!) { 421 node(id: $id) { 422 ...on PullRequest { 423 %s 424 } 425 } 426 }`, api.StatusCheckRollupGraphQL("$endCursor")) 427 428 variables := map[string]interface{}{ 429 "id": pr.ID, 430 } 431 432 apiClient := api.NewClientFromHTTP(client) 433 for { 434 variables["endCursor"] = endCursor 435 var resp response 436 err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp) 437 if err != nil { 438 return err 439 } 440 441 result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts 442 statusCheckRollup.Nodes = append( 443 statusCheckRollup.Nodes, 444 result.Nodes..., 445 ) 446 447 if !result.PageInfo.HasNextPage { 448 break 449 } 450 endCursor = result.PageInfo.EndCursor 451 } 452 453 statusCheckRollup.PageInfo.HasNextPage = false 454 return nil 455 } 456 457 type NotFoundError struct { 458 error 459 } 460 461 func (err *NotFoundError) Unwrap() error { 462 return err.error 463 } 464 465 func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { 466 var err error 467 if pr == nil { 468 err = &NotFoundError{errors.New("no pull requests found")} 469 } 470 return &mockFinder{ 471 expectSelector: selector, 472 pr: pr, 473 repo: repo, 474 err: err, 475 } 476 } 477 478 type mockFinder struct { 479 called bool 480 expectSelector string 481 expectFields []string 482 pr *api.PullRequest 483 repo ghrepo.Interface 484 err error 485 } 486 487 func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { 488 if m.err != nil { 489 return nil, nil, m.err 490 } 491 if m.expectSelector != opts.Selector { 492 return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector) 493 } 494 if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) { 495 return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields) 496 } 497 if m.called { 498 return nil, nil, errors.New("mockFinder used more than once") 499 } 500 m.called = true 501 502 if m.pr.HeadRepositoryOwner.Login == "" { 503 // pose as same-repo PR by default 504 m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner() 505 } 506 507 return m.pr, m.repo, nil 508 } 509 510 func (m *mockFinder) ExpectFields(fields []string) { 511 m.expectFields = fields 512 } 513 514 func isEqualSet(a, b []string) bool { 515 if len(a) != len(b) { 516 return false 517 } 518 519 aCopy := make([]string, len(a)) 520 copy(aCopy, a) 521 bCopy := make([]string, len(b)) 522 copy(bCopy, b) 523 sort.Strings(aCopy) 524 sort.Strings(bCopy) 525 526 for i := range aCopy { 527 if aCopy[i] != bCopy[i] { 528 return false 529 } 530 } 531 return true 532 }