github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/cmd/pr/checks/checks.go (about)

     1  package checks
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"time"
     8  
     9  	"github.com/MakeNowJust/heredoc"
    10  	"github.com/ungtb10d/cli/v2/api"
    11  	"github.com/ungtb10d/cli/v2/internal/browser"
    12  	"github.com/ungtb10d/cli/v2/internal/ghrepo"
    13  	"github.com/ungtb10d/cli/v2/internal/text"
    14  	"github.com/ungtb10d/cli/v2/pkg/cmd/pr/shared"
    15  	"github.com/ungtb10d/cli/v2/pkg/cmdutil"
    16  	"github.com/ungtb10d/cli/v2/pkg/iostreams"
    17  	"github.com/spf13/cobra"
    18  )
    19  
    20  const defaultInterval time.Duration = 10 * time.Second
    21  
    22  type ChecksOptions struct {
    23  	HttpClient func() (*http.Client, error)
    24  	IO         *iostreams.IOStreams
    25  	Browser    browser.Browser
    26  
    27  	Finder shared.PRFinder
    28  
    29  	SelectorArg string
    30  	WebMode     bool
    31  	Interval    time.Duration
    32  	Watch       bool
    33  	Required    bool
    34  }
    35  
    36  func NewCmdChecks(f *cmdutil.Factory, runF func(*ChecksOptions) error) *cobra.Command {
    37  	var interval int
    38  	opts := &ChecksOptions{
    39  		HttpClient: f.HttpClient,
    40  		IO:         f.IOStreams,
    41  		Browser:    f.Browser,
    42  		Interval:   defaultInterval,
    43  	}
    44  
    45  	cmd := &cobra.Command{
    46  		Use:   "checks [<number> | <url> | <branch>]",
    47  		Short: "Show CI status for a single pull request",
    48  		Long: heredoc.Doc(`
    49  			Show CI status for a single pull request.
    50  
    51  			Without an argument, the pull request that belongs to the current branch
    52  			is selected.
    53  		`),
    54  		Args: cobra.MaximumNArgs(1),
    55  		RunE: func(cmd *cobra.Command, args []string) error {
    56  			opts.Finder = shared.NewFinder(f)
    57  
    58  			if repoOverride, _ := cmd.Flags().GetString("repo"); repoOverride != "" && len(args) == 0 {
    59  				return cmdutil.FlagErrorf("argument required when using the `--repo` flag")
    60  			}
    61  
    62  			intervalChanged := cmd.Flags().Changed("interval")
    63  			if !opts.Watch && intervalChanged {
    64  				return cmdutil.FlagErrorf("cannot use `--interval` flag without `--watch` flag")
    65  			}
    66  
    67  			if intervalChanged {
    68  				var err error
    69  				opts.Interval, err = time.ParseDuration(fmt.Sprintf("%ds", interval))
    70  				if err != nil {
    71  					return cmdutil.FlagErrorf("could not parse `--interval` flag: %w", err)
    72  				}
    73  			}
    74  
    75  			if len(args) > 0 {
    76  				opts.SelectorArg = args[0]
    77  			}
    78  
    79  			if runF != nil {
    80  				return runF(opts)
    81  			}
    82  
    83  			return checksRun(opts)
    84  		},
    85  	}
    86  
    87  	cmd.Flags().BoolVarP(&opts.WebMode, "web", "w", false, "Open the web browser to show details about checks")
    88  	cmd.Flags().BoolVarP(&opts.Watch, "watch", "", false, "Watch checks until they finish")
    89  	cmd.Flags().IntVarP(&interval, "interval", "i", 10, "Refresh interval in seconds when using `--watch` flag")
    90  	cmd.Flags().BoolVar(&opts.Required, "required", false, "Only show checks that are required")
    91  
    92  	return cmd
    93  }
    94  
    95  func checksRunWebMode(opts *ChecksOptions) error {
    96  	findOptions := shared.FindOptions{
    97  		Selector: opts.SelectorArg,
    98  		Fields:   []string{"number"},
    99  	}
   100  	pr, baseRepo, err := opts.Finder.Find(findOptions)
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	isTerminal := opts.IO.IsStdoutTTY()
   106  	openURL := ghrepo.GenerateRepoURL(baseRepo, "pull/%d/checks", pr.Number)
   107  
   108  	if isTerminal {
   109  		fmt.Fprintf(opts.IO.ErrOut, "Opening %s in your browser.\n", text.DisplayURL(openURL))
   110  	}
   111  
   112  	return opts.Browser.Browse(openURL)
   113  }
   114  
   115  func checksRun(opts *ChecksOptions) error {
   116  	if opts.WebMode {
   117  		return checksRunWebMode(opts)
   118  	}
   119  
   120  	findOptions := shared.FindOptions{
   121  		Selector: opts.SelectorArg,
   122  		Fields:   []string{"number", "headRefName"},
   123  	}
   124  
   125  	var pr *api.PullRequest
   126  	pr, repo, findErr := opts.Finder.Find(findOptions)
   127  	if findErr != nil {
   128  		return findErr
   129  	}
   130  
   131  	client, clientErr := opts.HttpClient()
   132  	if clientErr != nil {
   133  		return clientErr
   134  	}
   135  
   136  	var checks []check
   137  	var counts checkCounts
   138  	var err error
   139  
   140  	checks, counts, err = populateStatusChecks(client, repo, pr, opts.Required)
   141  	if err != nil {
   142  		return err
   143  	}
   144  
   145  	if opts.Watch {
   146  		opts.IO.StartAlternateScreenBuffer()
   147  	} else {
   148  		// Only start pager in non-watch mode
   149  		if err := opts.IO.StartPager(); err == nil {
   150  			defer opts.IO.StopPager()
   151  		} else {
   152  			fmt.Fprintf(opts.IO.ErrOut, "failed to start pager: %v\n", err)
   153  		}
   154  	}
   155  
   156  	// Do not return err until we can StopAlternateScreenBuffer()
   157  	for {
   158  		if counts.Pending != 0 && opts.Watch {
   159  			opts.IO.RefreshScreen()
   160  			cs := opts.IO.ColorScheme()
   161  			fmt.Fprintln(opts.IO.Out, cs.Boldf("Refreshing checks status every %v seconds. Press Ctrl+C to quit.\n", opts.Interval.Seconds()))
   162  		}
   163  
   164  		printSummary(opts.IO, counts)
   165  		err = printTable(opts.IO, checks)
   166  		if err != nil {
   167  			break
   168  		}
   169  
   170  		if counts.Pending == 0 || !opts.Watch {
   171  			break
   172  		}
   173  
   174  		time.Sleep(opts.Interval)
   175  
   176  		checks, counts, err = populateStatusChecks(client, repo, pr, opts.Required)
   177  		if err != nil {
   178  			break
   179  		}
   180  	}
   181  
   182  	opts.IO.StopAlternateScreenBuffer()
   183  	if err != nil {
   184  		return err
   185  	}
   186  
   187  	if opts.Watch {
   188  		// Print final summary to original screen buffer
   189  		printSummary(opts.IO, counts)
   190  		err = printTable(opts.IO, checks)
   191  		if err != nil {
   192  			return err
   193  		}
   194  	}
   195  
   196  	if counts.Failed+counts.Pending > 0 {
   197  		return cmdutil.SilentError
   198  	}
   199  
   200  	return nil
   201  }
   202  
   203  func populateStatusChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest, requiredChecks bool) ([]check, checkCounts, error) {
   204  	apiClient := api.NewClientFromHTTP(client)
   205  
   206  	type response struct {
   207  		Node *api.PullRequest
   208  	}
   209  
   210  	query := fmt.Sprintf(`
   211  	query PullRequestStatusChecks($id: ID!, $endCursor: String) {
   212  		node(id: $id) {
   213  			...on PullRequest {
   214  				%s
   215  			}
   216  		}
   217  	}`, api.RequiredStatusCheckRollupGraphQL("$id", "$endCursor"))
   218  
   219  	variables := map[string]interface{}{
   220  		"id": pr.ID,
   221  	}
   222  
   223  	statusCheckRollup := api.CheckContexts{}
   224  
   225  	for {
   226  		var resp response
   227  		err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
   228  		if err != nil {
   229  			return nil, checkCounts{}, err
   230  		}
   231  
   232  		if len(resp.Node.StatusCheckRollup.Nodes) == 0 {
   233  			return nil, checkCounts{}, errors.New("no commit found on the pull request")
   234  		}
   235  
   236  		result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
   237  		statusCheckRollup.Nodes = append(
   238  			statusCheckRollup.Nodes,
   239  			result.Nodes...,
   240  		)
   241  
   242  		if !result.PageInfo.HasNextPage {
   243  			break
   244  		}
   245  		variables["endCursor"] = result.PageInfo.EndCursor
   246  	}
   247  
   248  	if len(statusCheckRollup.Nodes) == 0 {
   249  		return nil, checkCounts{}, fmt.Errorf("no checks reported on the '%s' branch", pr.HeadRefName)
   250  	}
   251  
   252  	checks, counts := aggregateChecks(statusCheckRollup.Nodes, requiredChecks)
   253  	if len(checks) == 0 && requiredChecks {
   254  		return checks, counts, fmt.Errorf("no required checks reported on the '%s' branch", pr.HeadRefName)
   255  	}
   256  	return checks, counts, nil
   257  }