github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/pkg/cmd/pr/checkout/checkout.go (about)

     1  package checkout
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"os"
     7  	"os/exec"
     8  	"strings"
     9  
    10  	"github.com/cli/cli/api"
    11  	"github.com/cli/cli/context"
    12  	"github.com/cli/cli/git"
    13  	"github.com/cli/cli/internal/config"
    14  	"github.com/cli/cli/internal/ghrepo"
    15  	"github.com/cli/cli/internal/run"
    16  	"github.com/cli/cli/pkg/cmd/pr/shared"
    17  	"github.com/cli/cli/pkg/cmdutil"
    18  	"github.com/cli/cli/pkg/iostreams"
    19  	"github.com/cli/safeexec"
    20  	"github.com/spf13/cobra"
    21  )
    22  
    23  type CheckoutOptions struct {
    24  	HttpClient func() (*http.Client, error)
    25  	Config     func() (config.Config, error)
    26  	IO         *iostreams.IOStreams
    27  	Remotes    func() (context.Remotes, error)
    28  	Branch     func() (string, error)
    29  
    30  	Finder shared.PRFinder
    31  
    32  	SelectorArg       string
    33  	RecurseSubmodules bool
    34  	Force             bool
    35  	Detach            bool
    36  	BranchName        string
    37  }
    38  
    39  func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command {
    40  	opts := &CheckoutOptions{
    41  		IO:         f.IOStreams,
    42  		HttpClient: f.HttpClient,
    43  		Config:     f.Config,
    44  		Remotes:    f.Remotes,
    45  		Branch:     f.Branch,
    46  	}
    47  
    48  	cmd := &cobra.Command{
    49  		Use:   "checkout {<number> | <url> | <branch>}",
    50  		Short: "Check out a pull request in git",
    51  		Args:  cmdutil.ExactArgs(1, "argument required"),
    52  		RunE: func(cmd *cobra.Command, args []string) error {
    53  			opts.Finder = shared.NewFinder(f)
    54  
    55  			if len(args) > 0 {
    56  				opts.SelectorArg = args[0]
    57  			}
    58  
    59  			if runF != nil {
    60  				return runF(opts)
    61  			}
    62  			return checkoutRun(opts)
    63  		},
    64  	}
    65  
    66  	cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout")
    67  	cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request")
    68  	cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD")
    69  	cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)")
    70  
    71  	return cmd
    72  }
    73  
    74  func checkoutRun(opts *CheckoutOptions) error {
    75  	findOptions := shared.FindOptions{
    76  		Selector: opts.SelectorArg,
    77  		Fields:   []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"},
    78  	}
    79  	pr, baseRepo, err := opts.Finder.Find(findOptions)
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	cfg, err := opts.Config()
    85  	if err != nil {
    86  		return err
    87  	}
    88  	protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol")
    89  
    90  	remotes, err := opts.Remotes()
    91  	if err != nil {
    92  		return err
    93  	}
    94  	baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName())
    95  	baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol)
    96  	if baseRemote != nil {
    97  		baseURLOrName = baseRemote.Name
    98  	}
    99  
   100  	headRemote := baseRemote
   101  	if pr.HeadRepository == nil {
   102  		headRemote = nil
   103  	} else if pr.IsCrossRepository {
   104  		headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
   105  	}
   106  
   107  	if strings.HasPrefix(pr.HeadRefName, "-") {
   108  		return fmt.Errorf("invalid branch name: %q", pr.HeadRefName)
   109  	}
   110  
   111  	var cmdQueue [][]string
   112  
   113  	if headRemote != nil {
   114  		cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...)
   115  	} else {
   116  		httpClient, err := opts.HttpClient()
   117  		if err != nil {
   118  			return err
   119  		}
   120  		apiClient := api.NewClientFromHTTP(httpClient)
   121  
   122  		defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo)
   123  		if err != nil {
   124  			return err
   125  		}
   126  		cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...)
   127  	}
   128  
   129  	if opts.RecurseSubmodules {
   130  		cmdQueue = append(cmdQueue, []string{"git", "submodule", "sync", "--recursive"})
   131  		cmdQueue = append(cmdQueue, []string{"git", "submodule", "update", "--init", "--recursive"})
   132  	}
   133  
   134  	err = executeCmds(cmdQueue)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string {
   143  	var cmds [][]string
   144  	remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName)
   145  
   146  	refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName)
   147  	if !opts.Detach {
   148  		refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch)
   149  	}
   150  
   151  	cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec})
   152  
   153  	localBranch := pr.HeadRefName
   154  	if opts.BranchName != "" {
   155  		localBranch = opts.BranchName
   156  	}
   157  
   158  	switch {
   159  	case opts.Detach:
   160  		cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
   161  	case localBranchExists(localBranch):
   162  		cmds = append(cmds, []string{"git", "checkout", localBranch})
   163  		if opts.Force {
   164  			cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
   165  		} else {
   166  			// TODO: check if non-fast-forward and suggest to use `--force`
   167  			cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
   168  		}
   169  	default:
   170  		cmds = append(cmds, []string{"git", "checkout", "-b", localBranch, "--track", remoteBranch})
   171  	}
   172  
   173  	return cmds
   174  }
   175  
   176  func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string {
   177  	var cmds [][]string
   178  	ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)
   179  
   180  	if opts.Detach {
   181  		cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
   182  		cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
   183  		return cmds
   184  	}
   185  
   186  	localBranch := pr.HeadRefName
   187  	if opts.BranchName != "" {
   188  		localBranch = opts.BranchName
   189  	} else if pr.HeadRefName == defaultBranch {
   190  		// avoid naming the new branch the same as the default branch
   191  		localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch)
   192  	}
   193  
   194  	currentBranch, _ := opts.Branch()
   195  	if localBranch == currentBranch {
   196  		// PR head matches currently checked out branch
   197  		cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
   198  		if opts.Force {
   199  			cmds = append(cmds, []string{"git", "reset", "--hard", "FETCH_HEAD"})
   200  		} else {
   201  			// TODO: check if non-fast-forward and suggest to use `--force`
   202  			cmds = append(cmds, []string{"git", "merge", "--ff-only", "FETCH_HEAD"})
   203  		}
   204  	} else {
   205  		if opts.Force {
   206  			cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"})
   207  		} else {
   208  			// TODO: check if non-fast-forward and suggest to use `--force`
   209  			cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)})
   210  		}
   211  
   212  		cmds = append(cmds, []string{"git", "checkout", localBranch})
   213  	}
   214  
   215  	remote := baseURLOrName
   216  	mergeRef := ref
   217  	if pr.MaintainerCanModify && pr.HeadRepository != nil {
   218  		headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost)
   219  		remote = ghrepo.FormatRemoteURL(headRepo, protocol)
   220  		mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
   221  	}
   222  	if missingMergeConfigForBranch(localBranch) {
   223  		cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", localBranch), remote})
   224  		cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef})
   225  	}
   226  
   227  	return cmds
   228  }
   229  
   230  func missingMergeConfigForBranch(b string) bool {
   231  	mc, err := git.Config(fmt.Sprintf("branch.%s.merge", b))
   232  	return err != nil || mc == ""
   233  }
   234  
   235  func localBranchExists(b string) bool {
   236  	_, err := git.ShowRefs("refs/heads/" + b)
   237  	return err == nil
   238  }
   239  
   240  func executeCmds(cmdQueue [][]string) error {
   241  	for _, args := range cmdQueue {
   242  		// TODO: reuse the result of this lookup across loop iteration
   243  		exe, err := safeexec.LookPath(args[0])
   244  		if err != nil {
   245  			return err
   246  		}
   247  		cmd := exec.Command(exe, args[1:]...)
   248  		cmd.Stdout = os.Stdout
   249  		cmd.Stderr = os.Stderr
   250  		if err := run.PrepareCmd(cmd).Run(); err != nil {
   251  			return err
   252  		}
   253  	}
   254  	return nil
   255  }