github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/cmd/pr/shared/finder.go (about) 1 package shared 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net/http" 8 "net/url" 9 "regexp" 10 "sort" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/ungtb10d/cli/v2/api" 16 remotes "github.com/ungtb10d/cli/v2/context" 17 "github.com/ungtb10d/cli/v2/git" 18 fd "github.com/ungtb10d/cli/v2/internal/featuredetection" 19 "github.com/ungtb10d/cli/v2/internal/ghrepo" 20 "github.com/ungtb10d/cli/v2/pkg/cmdutil" 21 "github.com/ungtb10d/cli/v2/pkg/set" 22 "github.com/shurcooL/githubv4" 23 "golang.org/x/sync/errgroup" 24 ) 25 26 type PRFinder interface { 27 Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) 28 } 29 30 type progressIndicator interface { 31 StartProgressIndicator() 32 StopProgressIndicator() 33 } 34 35 type finder struct { 36 baseRepoFn func() (ghrepo.Interface, error) 37 branchFn func() (string, error) 38 remotesFn func() (remotes.Remotes, error) 39 httpClient func() (*http.Client, error) 40 branchConfig func(string) git.BranchConfig 41 progress progressIndicator 42 43 repo ghrepo.Interface 44 prNumber int 45 branchName string 46 } 47 48 func NewFinder(factory *cmdutil.Factory) PRFinder { 49 if runCommandFinder != nil { 50 f := runCommandFinder 51 runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")} 52 return f 53 } 54 55 return &finder{ 56 baseRepoFn: factory.BaseRepo, 57 branchFn: factory.Branch, 58 remotesFn: factory.Remotes, 59 httpClient: factory.HttpClient, 60 progress: factory.IOStreams, 61 branchConfig: func(s string) git.BranchConfig { 62 return factory.GitClient.ReadBranchConfig(context.Background(), s) 63 }, 64 } 65 } 66 67 var runCommandFinder PRFinder 68 69 // RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests. 70 func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { 71 finder := NewMockFinder(selector, pr, repo) 72 runCommandFinder = finder 73 return finder 74 } 75 76 type FindOptions struct { 77 // Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or 78 // a PR URL. 79 Selector string 80 // Fields lists the GraphQL fields to fetch for the PullRequest. 81 Fields []string 82 // BaseBranch is the name of the base branch to scope the PR-for-branch lookup to. 83 BaseBranch string 84 // States lists the possible PR states to scope the PR-for-branch lookup to. 85 States []string 86 } 87 88 func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { 89 if len(opts.Fields) == 0 { 90 return nil, nil, errors.New("Find error: no fields specified") 91 } 92 93 if repo, prNumber, err := f.parseURL(opts.Selector); err == nil { 94 f.prNumber = prNumber 95 f.repo = repo 96 } 97 98 if f.repo == nil { 99 repo, err := f.baseRepoFn() 100 if err != nil { 101 return nil, nil, fmt.Errorf("could not determine base repo: %w", err) 102 } 103 f.repo = repo 104 } 105 106 if opts.Selector == "" { 107 if branch, prNumber, err := f.parseCurrentBranch(); err != nil { 108 return nil, nil, err 109 } else if prNumber > 0 { 110 f.prNumber = prNumber 111 } else { 112 f.branchName = branch 113 } 114 } else if f.prNumber == 0 { 115 // If opts.Selector is a valid number then assume it is the 116 // PR number unless opts.BaseBranch is specified. This is a 117 // special case for PR create command which will always want 118 // to assume that a numerical selector is a branch name rather 119 // than PR number. 120 prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")) 121 if opts.BaseBranch == "" && err == nil { 122 f.prNumber = prNumber 123 } else { 124 f.branchName = opts.Selector 125 } 126 } 127 128 httpClient, err := f.httpClient() 129 if err != nil { 130 return nil, nil, err 131 } 132 133 // TODO(josebalius): Should we be guarding here? 134 if f.progress != nil { 135 f.progress.StartProgressIndicator() 136 defer f.progress.StopProgressIndicator() 137 } 138 139 fields := set.NewStringSet() 140 fields.AddValues(opts.Fields) 141 numberFieldOnly := fields.Len() == 1 && fields.Contains("number") 142 fields.Add("id") // for additional preload queries below 143 144 if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") { 145 cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) 146 detector := fd.NewDetector(cachedClient, f.repo.RepoHost()) 147 prFeatures, err := detector.PullRequestFeatures() 148 if err != nil { 149 return nil, nil, err 150 } 151 if !prFeatures.MergeQueue { 152 fields.Remove("isInMergeQueue") 153 fields.Remove("isMergeQueueEnabled") 154 } 155 } 156 157 var pr *api.PullRequest 158 if f.prNumber > 0 { 159 if numberFieldOnly { 160 // avoid hitting the API if we already have all the information 161 return &api.PullRequest{Number: f.prNumber}, f.repo, nil 162 } 163 pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice()) 164 } else { 165 pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice()) 166 } 167 if err != nil { 168 return pr, f.repo, err 169 } 170 171 g, _ := errgroup.WithContext(context.Background()) 172 if fields.Contains("reviews") { 173 g.Go(func() error { 174 return preloadPrReviews(httpClient, f.repo, pr) 175 }) 176 } 177 if fields.Contains("comments") { 178 g.Go(func() error { 179 return preloadPrComments(httpClient, f.repo, pr) 180 }) 181 } 182 if fields.Contains("statusCheckRollup") { 183 g.Go(func() error { 184 return preloadPrChecks(httpClient, f.repo, pr) 185 }) 186 } 187 188 return pr, f.repo, g.Wait() 189 } 190 191 var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) 192 193 func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) { 194 if prURL == "" { 195 return nil, 0, fmt.Errorf("invalid URL: %q", prURL) 196 } 197 198 u, err := url.Parse(prURL) 199 if err != nil { 200 return nil, 0, err 201 } 202 203 if u.Scheme != "https" && u.Scheme != "http" { 204 return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme) 205 } 206 207 m := pullURLRE.FindStringSubmatch(u.Path) 208 if m == nil { 209 return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL) 210 } 211 212 repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname()) 213 prNumber, _ := strconv.Atoi(m[3]) 214 return repo, prNumber, nil 215 } 216 217 var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`) 218 219 func (f *finder) parseCurrentBranch() (string, int, error) { 220 prHeadRef, err := f.branchFn() 221 if err != nil { 222 return "", 0, err 223 } 224 225 branchConfig := f.branchConfig(prHeadRef) 226 227 // the branch is configured to merge a special PR head ref 228 if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { 229 prNumber, _ := strconv.Atoi(m[1]) 230 return "", prNumber, nil 231 } 232 233 var branchOwner string 234 if branchConfig.RemoteURL != nil { 235 // the branch merges from a remote specified by URL 236 if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { 237 branchOwner = r.RepoOwner() 238 } 239 } else if branchConfig.RemoteName != "" { 240 // the branch merges from a remote specified by name 241 rem, _ := f.remotesFn() 242 if r, err := rem.FindByName(branchConfig.RemoteName); err == nil { 243 branchOwner = r.RepoOwner() 244 } 245 } 246 247 if branchOwner != "" { 248 if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { 249 prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") 250 } 251 // prepend `OWNER:` if this branch is pushed to a fork 252 if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) { 253 prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef) 254 } 255 } 256 257 return prHeadRef, 0, nil 258 } 259 260 func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { 261 type response struct { 262 Repository struct { 263 PullRequest api.PullRequest 264 } 265 } 266 267 query := fmt.Sprintf(` 268 query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) { 269 repository(owner: $owner, name: $repo) { 270 pullRequest(number: $pr_number) {%s} 271 } 272 }`, api.PullRequestGraphQL(fields)) 273 274 variables := map[string]interface{}{ 275 "owner": repo.RepoOwner(), 276 "repo": repo.RepoName(), 277 "pr_number": number, 278 } 279 280 var resp response 281 client := api.NewClientFromHTTP(httpClient) 282 err := client.GraphQL(repo.RepoHost(), query, variables, &resp) 283 if err != nil { 284 return nil, err 285 } 286 287 return &resp.Repository.PullRequest, nil 288 } 289 290 func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) { 291 type response struct { 292 Repository struct { 293 PullRequests struct { 294 Nodes []api.PullRequest 295 } 296 DefaultBranchRef struct { 297 Name string 298 } 299 } 300 } 301 302 fieldSet := set.NewStringSet() 303 fieldSet.AddValues(fields) 304 // these fields are required for filtering below 305 fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"}) 306 307 query := fmt.Sprintf(` 308 query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) { 309 repository(owner: $owner, name: $repo) { 310 pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) { 311 nodes {%s} 312 } 313 defaultBranchRef { name } 314 } 315 }`, api.PullRequestGraphQL(fieldSet.ToSlice())) 316 317 branchWithoutOwner := headBranch 318 if idx := strings.Index(headBranch, ":"); idx >= 0 { 319 branchWithoutOwner = headBranch[idx+1:] 320 } 321 322 variables := map[string]interface{}{ 323 "owner": repo.RepoOwner(), 324 "repo": repo.RepoName(), 325 "headRefName": branchWithoutOwner, 326 "states": stateFilters, 327 } 328 329 var resp response 330 client := api.NewClientFromHTTP(httpClient) 331 err := client.GraphQL(repo.RepoHost(), query, variables, &resp) 332 if err != nil { 333 return nil, err 334 } 335 336 prs := resp.Repository.PullRequests.Nodes 337 sort.SliceStable(prs, func(a, b int) bool { 338 return prs[a].State == "OPEN" && prs[b].State != "OPEN" 339 }) 340 341 for _, pr := range prs { 342 if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) && (pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranch) { 343 return &pr, nil 344 } 345 } 346 347 return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)} 348 } 349 350 func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 351 if !pr.Reviews.PageInfo.HasNextPage { 352 return nil 353 } 354 355 type response struct { 356 Node struct { 357 PullRequest struct { 358 Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"` 359 } `graphql:"...on PullRequest"` 360 } `graphql:"node(id: $id)"` 361 } 362 363 variables := map[string]interface{}{ 364 "id": githubv4.ID(pr.ID), 365 "endCursor": githubv4.String(pr.Reviews.PageInfo.EndCursor), 366 } 367 368 gql := api.NewClientFromHTTP(httpClient) 369 370 for { 371 var query response 372 err := gql.Query(repo.RepoHost(), "ReviewsForPullRequest", &query, variables) 373 if err != nil { 374 return err 375 } 376 377 pr.Reviews.Nodes = append(pr.Reviews.Nodes, query.Node.PullRequest.Reviews.Nodes...) 378 pr.Reviews.TotalCount = len(pr.Reviews.Nodes) 379 380 if !query.Node.PullRequest.Reviews.PageInfo.HasNextPage { 381 break 382 } 383 variables["endCursor"] = githubv4.String(query.Node.PullRequest.Reviews.PageInfo.EndCursor) 384 } 385 386 pr.Reviews.PageInfo.HasNextPage = false 387 return nil 388 } 389 390 func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 391 if !pr.Comments.PageInfo.HasNextPage { 392 return nil 393 } 394 395 type response struct { 396 Node struct { 397 PullRequest struct { 398 Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"` 399 } `graphql:"...on PullRequest"` 400 } `graphql:"node(id: $id)"` 401 } 402 403 variables := map[string]interface{}{ 404 "id": githubv4.ID(pr.ID), 405 "endCursor": githubv4.String(pr.Comments.PageInfo.EndCursor), 406 } 407 408 gql := api.NewClientFromHTTP(client) 409 410 for { 411 var query response 412 err := gql.Query(repo.RepoHost(), "CommentsForPullRequest", &query, variables) 413 if err != nil { 414 return err 415 } 416 417 pr.Comments.Nodes = append(pr.Comments.Nodes, query.Node.PullRequest.Comments.Nodes...) 418 pr.Comments.TotalCount = len(pr.Comments.Nodes) 419 420 if !query.Node.PullRequest.Comments.PageInfo.HasNextPage { 421 break 422 } 423 variables["endCursor"] = githubv4.String(query.Node.PullRequest.Comments.PageInfo.EndCursor) 424 } 425 426 pr.Comments.PageInfo.HasNextPage = false 427 return nil 428 } 429 430 func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 431 if len(pr.StatusCheckRollup.Nodes) == 0 { 432 return nil 433 } 434 statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts 435 if !statusCheckRollup.PageInfo.HasNextPage { 436 return nil 437 } 438 439 endCursor := statusCheckRollup.PageInfo.EndCursor 440 441 type response struct { 442 Node *api.PullRequest 443 } 444 445 query := fmt.Sprintf(` 446 query PullRequestStatusChecks($id: ID!, $endCursor: String!) { 447 node(id: $id) { 448 ...on PullRequest { 449 %s 450 } 451 } 452 }`, api.StatusCheckRollupGraphQL("$endCursor")) 453 454 variables := map[string]interface{}{ 455 "id": pr.ID, 456 } 457 458 apiClient := api.NewClientFromHTTP(client) 459 for { 460 variables["endCursor"] = endCursor 461 var resp response 462 err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp) 463 if err != nil { 464 return err 465 } 466 467 result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts 468 statusCheckRollup.Nodes = append( 469 statusCheckRollup.Nodes, 470 result.Nodes..., 471 ) 472 473 if !result.PageInfo.HasNextPage { 474 break 475 } 476 endCursor = result.PageInfo.EndCursor 477 } 478 479 statusCheckRollup.PageInfo.HasNextPage = false 480 return nil 481 } 482 483 type NotFoundError struct { 484 error 485 } 486 487 func (err *NotFoundError) Unwrap() error { 488 return err.error 489 } 490 491 func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { 492 var err error 493 if pr == nil { 494 err = &NotFoundError{errors.New("no pull requests found")} 495 } 496 return &mockFinder{ 497 expectSelector: selector, 498 pr: pr, 499 repo: repo, 500 err: err, 501 } 502 } 503 504 type mockFinder struct { 505 called bool 506 expectSelector string 507 expectFields []string 508 pr *api.PullRequest 509 repo ghrepo.Interface 510 err error 511 } 512 513 func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { 514 if m.err != nil { 515 return nil, nil, m.err 516 } 517 if m.expectSelector != opts.Selector { 518 return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector) 519 } 520 if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) { 521 return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields) 522 } 523 if m.called { 524 return nil, nil, errors.New("mockFinder used more than once") 525 } 526 m.called = true 527 528 if m.pr.HeadRepositoryOwner.Login == "" { 529 // pose as same-repo PR by default 530 m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner() 531 } 532 533 return m.pr, m.repo, nil 534 } 535 536 func (m *mockFinder) ExpectFields(fields []string) { 537 m.expectFields = fields 538 } 539 540 func isEqualSet(a, b []string) bool { 541 if len(a) != len(b) { 542 return false 543 } 544 545 aCopy := make([]string, len(a)) 546 copy(aCopy, a) 547 bCopy := make([]string, len(b)) 548 copy(bCopy, b) 549 sort.Strings(aCopy) 550 sort.Strings(bCopy) 551 552 for i := range aCopy { 553 if aCopy[i] != bCopy[i] { 554 return false 555 } 556 } 557 return true 558 }