github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/pkg/cmd/pr/checkout/checkout.go (about) 1 package checkout 2 3 import ( 4 "fmt" 5 "net/http" 6 "os" 7 "os/exec" 8 "strings" 9 10 "github.com/cli/cli/api" 11 "github.com/cli/cli/context" 12 "github.com/cli/cli/git" 13 "github.com/cli/cli/internal/config" 14 "github.com/cli/cli/internal/ghrepo" 15 "github.com/cli/cli/internal/run" 16 "github.com/cli/cli/pkg/cmd/pr/shared" 17 "github.com/cli/cli/pkg/cmdutil" 18 "github.com/cli/cli/pkg/iostreams" 19 "github.com/cli/safeexec" 20 "github.com/spf13/cobra" 21 ) 22 23 type CheckoutOptions struct { 24 HttpClient func() (*http.Client, error) 25 Config func() (config.Config, error) 26 IO *iostreams.IOStreams 27 Remotes func() (context.Remotes, error) 28 Branch func() (string, error) 29 30 Finder shared.PRFinder 31 32 SelectorArg string 33 RecurseSubmodules bool 34 Force bool 35 Detach bool 36 BranchName string 37 } 38 39 func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command { 40 opts := &CheckoutOptions{ 41 IO: f.IOStreams, 42 HttpClient: f.HttpClient, 43 Config: f.Config, 44 Remotes: f.Remotes, 45 Branch: f.Branch, 46 } 47 48 cmd := &cobra.Command{ 49 Use: "checkout {<number> | <url> | <branch>}", 50 Short: "Check out a pull request in git", 51 Args: cmdutil.ExactArgs(1, "argument required"), 52 RunE: func(cmd *cobra.Command, args []string) error { 53 opts.Finder = shared.NewFinder(f) 54 55 if len(args) > 0 { 56 opts.SelectorArg = args[0] 57 } 58 59 if runF != nil { 60 return runF(opts) 61 } 62 return checkoutRun(opts) 63 }, 64 } 65 66 cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout") 67 cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request") 68 cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD") 69 cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)") 70 71 return cmd 72 } 73 74 func checkoutRun(opts *CheckoutOptions) error { 75 findOptions := shared.FindOptions{ 76 Selector: opts.SelectorArg, 77 Fields: []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}, 78 } 79 pr, baseRepo, err := opts.Finder.Find(findOptions) 80 if err != nil { 81 return err 82 } 83 84 cfg, err := opts.Config() 85 if err != nil { 86 return err 87 } 88 protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol") 89 90 remotes, err := opts.Remotes() 91 if err != nil { 92 return err 93 } 94 baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName()) 95 baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol) 96 if baseRemote != nil { 97 baseURLOrName = baseRemote.Name 98 } 99 100 headRemote := baseRemote 101 if pr.HeadRepository == nil { 102 headRemote = nil 103 } else if pr.IsCrossRepository { 104 headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name) 105 } 106 107 if strings.HasPrefix(pr.HeadRefName, "-") { 108 return fmt.Errorf("invalid branch name: %q", pr.HeadRefName) 109 } 110 111 var cmdQueue [][]string 112 113 if headRemote != nil { 114 cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...) 115 } else { 116 httpClient, err := opts.HttpClient() 117 if err != nil { 118 return err 119 } 120 apiClient := api.NewClientFromHTTP(httpClient) 121 122 defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo) 123 if err != nil { 124 return err 125 } 126 cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...) 127 } 128 129 if opts.RecurseSubmodules { 130 cmdQueue = append(cmdQueue, []string{"git", "submodule", "sync", "--recursive"}) 131 cmdQueue = append(cmdQueue, []string{"git", "submodule", "update", "--init", "--recursive"}) 132 } 133 134 err = executeCmds(cmdQueue) 135 if err != nil { 136 return err 137 } 138 139 return nil 140 } 141 142 func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string { 143 var cmds [][]string 144 remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName) 145 146 refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName) 147 if !opts.Detach { 148 refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch) 149 } 150 151 cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec}) 152 153 localBranch := pr.HeadRefName 154 if opts.BranchName != "" { 155 localBranch = opts.BranchName 156 } 157 158 switch { 159 case opts.Detach: 160 cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"}) 161 case localBranchExists(localBranch): 162 cmds = append(cmds, []string{"git", "checkout", localBranch}) 163 if opts.Force { 164 cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) 165 } else { 166 // TODO: check if non-fast-forward and suggest to use `--force` 167 cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) 168 } 169 default: 170 cmds = append(cmds, []string{"git", "checkout", "-b", localBranch, "--track", remoteBranch}) 171 } 172 173 return cmds 174 } 175 176 func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string { 177 var cmds [][]string 178 ref := fmt.Sprintf("refs/pull/%d/head", pr.Number) 179 180 if opts.Detach { 181 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref}) 182 cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"}) 183 return cmds 184 } 185 186 localBranch := pr.HeadRefName 187 if opts.BranchName != "" { 188 localBranch = opts.BranchName 189 } else if pr.HeadRefName == defaultBranch { 190 // avoid naming the new branch the same as the default branch 191 localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch) 192 } 193 194 currentBranch, _ := opts.Branch() 195 if localBranch == currentBranch { 196 // PR head matches currently checked out branch 197 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref}) 198 if opts.Force { 199 cmds = append(cmds, []string{"git", "reset", "--hard", "FETCH_HEAD"}) 200 } else { 201 // TODO: check if non-fast-forward and suggest to use `--force` 202 cmds = append(cmds, []string{"git", "merge", "--ff-only", "FETCH_HEAD"}) 203 } 204 } else { 205 if opts.Force { 206 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"}) 207 } else { 208 // TODO: check if non-fast-forward and suggest to use `--force` 209 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)}) 210 } 211 212 cmds = append(cmds, []string{"git", "checkout", localBranch}) 213 } 214 215 remote := baseURLOrName 216 mergeRef := ref 217 if pr.MaintainerCanModify && pr.HeadRepository != nil { 218 headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost) 219 remote = ghrepo.FormatRemoteURL(headRepo, protocol) 220 mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName) 221 } 222 if missingMergeConfigForBranch(localBranch) { 223 cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", localBranch), remote}) 224 cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef}) 225 } 226 227 return cmds 228 } 229 230 func missingMergeConfigForBranch(b string) bool { 231 mc, err := git.Config(fmt.Sprintf("branch.%s.merge", b)) 232 return err != nil || mc == "" 233 } 234 235 func localBranchExists(b string) bool { 236 _, err := git.ShowRefs("refs/heads/" + b) 237 return err == nil 238 } 239 240 func executeCmds(cmdQueue [][]string) error { 241 for _, args := range cmdQueue { 242 // TODO: reuse the result of this lookup across loop iteration 243 exe, err := safeexec.LookPath(args[0]) 244 if err != nil { 245 return err 246 } 247 cmd := exec.Command(exe, args[1:]...) 248 cmd.Stdout = os.Stdout 249 cmd.Stderr = os.Stderr 250 if err := run.PrepareCmd(cmd).Run(); err != nil { 251 return err 252 } 253 } 254 return nil 255 }