github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/cmd/pr/status/http.go (about)

     1  package status
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"strings"
     7  
     8  	"github.com/ungtb10d/cli/v2/api"
     9  	"github.com/ungtb10d/cli/v2/internal/ghinstance"
    10  	"github.com/ungtb10d/cli/v2/internal/ghrepo"
    11  	"github.com/ungtb10d/cli/v2/pkg/set"
    12  )
    13  
    14  type requestOptions struct {
    15  	CurrentPR      int
    16  	HeadRef        string
    17  	Username       string
    18  	Fields         []string
    19  	ConflictStatus bool
    20  }
    21  
    22  type pullRequestsPayload struct {
    23  	ViewerCreated   api.PullRequestAndTotalCount
    24  	ReviewRequested api.PullRequestAndTotalCount
    25  	CurrentPR       *api.PullRequest
    26  	DefaultBranch   string
    27  }
    28  
    29  func pullRequestStatus(httpClient *http.Client, repo ghrepo.Interface, options requestOptions) (*pullRequestsPayload, error) {
    30  	apiClient := api.NewClientFromHTTP(httpClient)
    31  	type edges struct {
    32  		TotalCount int
    33  		Edges      []struct {
    34  			Node api.PullRequest
    35  		}
    36  	}
    37  
    38  	type response struct {
    39  		Repository struct {
    40  			DefaultBranchRef struct {
    41  				Name string
    42  			}
    43  			PullRequests edges
    44  			PullRequest  *api.PullRequest
    45  		}
    46  		ViewerCreated   edges
    47  		ReviewRequested edges
    48  	}
    49  
    50  	var fragments string
    51  	if len(options.Fields) > 0 {
    52  		fields := set.NewStringSet()
    53  		fields.AddValues(options.Fields)
    54  		// these are always necessary to find the PR for the current branch
    55  		fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"})
    56  		gr := api.PullRequestGraphQL(fields.ToSlice())
    57  		fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr)
    58  	} else {
    59  		var err error
    60  		fragments, err = pullRequestFragment(repo.RepoHost(), options.ConflictStatus)
    61  		if err != nil {
    62  			return nil, err
    63  		}
    64  	}
    65  
    66  	queryPrefix := `
    67  	query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
    68  		repository(owner: $owner, name: $repo) {
    69  			defaultBranchRef {
    70  				name
    71  			}
    72  			pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) {
    73  				totalCount
    74  				edges {
    75  					node {
    76  						...prWithReviews
    77  					}
    78  				}
    79  			}
    80  		}
    81  	`
    82  	if options.CurrentPR > 0 {
    83  		queryPrefix = `
    84  		query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
    85  			repository(owner: $owner, name: $repo) {
    86  				defaultBranchRef {
    87  					name
    88  				}
    89  				pullRequest(number: $number) {
    90  					...prWithReviews
    91  					baseRef {
    92  						branchProtectionRule {
    93  							requiredApprovingReviewCount
    94  						}
    95  					}
    96  				}
    97  			}
    98  		`
    99  	}
   100  
   101  	query := fragments + queryPrefix + `
   102        viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
   103         totalCount: issueCount
   104          edges {
   105            node {
   106              ...prWithReviews
   107            }
   108          }
   109        }
   110        reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) {
   111          totalCount: issueCount
   112          edges {
   113            node {
   114              ...pr
   115            }
   116          }
   117        }
   118      }
   119  	`
   120  
   121  	currentUsername := options.Username
   122  	if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
   123  		var err error
   124  		currentUsername, err = api.CurrentLoginName(apiClient, repo.RepoHost())
   125  		if err != nil {
   126  			return nil, err
   127  		}
   128  	}
   129  
   130  	viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
   131  	reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)
   132  
   133  	currentPRHeadRef := options.HeadRef
   134  	branchWithoutOwner := currentPRHeadRef
   135  	if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
   136  		branchWithoutOwner = currentPRHeadRef[idx+1:]
   137  	}
   138  
   139  	variables := map[string]interface{}{
   140  		"viewerQuery":   viewerQuery,
   141  		"reviewerQuery": reviewerQuery,
   142  		"owner":         repo.RepoOwner(),
   143  		"repo":          repo.RepoName(),
   144  		"headRefName":   branchWithoutOwner,
   145  		"number":        options.CurrentPR,
   146  	}
   147  
   148  	var resp response
   149  	err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	var viewerCreated []api.PullRequest
   155  	for _, edge := range resp.ViewerCreated.Edges {
   156  		viewerCreated = append(viewerCreated, edge.Node)
   157  	}
   158  
   159  	var reviewRequested []api.PullRequest
   160  	for _, edge := range resp.ReviewRequested.Edges {
   161  		reviewRequested = append(reviewRequested, edge.Node)
   162  	}
   163  
   164  	var currentPR = resp.Repository.PullRequest
   165  	if currentPR == nil {
   166  		for _, edge := range resp.Repository.PullRequests.Edges {
   167  			if edge.Node.HeadLabel() == currentPRHeadRef {
   168  				currentPR = &edge.Node
   169  				break // Take the most recent PR for the current branch
   170  			}
   171  		}
   172  	}
   173  
   174  	payload := pullRequestsPayload{
   175  		ViewerCreated: api.PullRequestAndTotalCount{
   176  			PullRequests: viewerCreated,
   177  			TotalCount:   resp.ViewerCreated.TotalCount,
   178  		},
   179  		ReviewRequested: api.PullRequestAndTotalCount{
   180  			PullRequests: reviewRequested,
   181  			TotalCount:   resp.ReviewRequested.TotalCount,
   182  		},
   183  		CurrentPR:     currentPR,
   184  		DefaultBranch: resp.Repository.DefaultBranchRef.Name,
   185  	}
   186  
   187  	return &payload, nil
   188  }
   189  
   190  func pullRequestFragment(hostname string, conflictStatus bool) (string, error) {
   191  	fields := []string{
   192  		"number", "title", "state", "url", "isDraft", "isCrossRepository",
   193  		"headRefName", "headRepositoryOwner", "mergeStateStatus",
   194  		"statusCheckRollup", "requiresStrictStatusChecks",
   195  	}
   196  
   197  	if conflictStatus {
   198  		fields = append(fields, "mergeable")
   199  	}
   200  	reviewFields := []string{"reviewDecision", "latestReviews"}
   201  	fragments := fmt.Sprintf(`
   202  	fragment pr on PullRequest {%s}
   203  	fragment prWithReviews on PullRequest {...pr,%s}
   204  	`, api.PullRequestGraphQL(fields), api.PullRequestGraphQL(reviewFields))
   205  	return fragments, nil
   206  }