github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/cmd/pr/checkout/checkout.go (about) 1 package checkout 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 "strings" 8 9 "github.com/ungtb10d/cli/v2/api" 10 cliContext "github.com/ungtb10d/cli/v2/context" 11 "github.com/ungtb10d/cli/v2/git" 12 "github.com/ungtb10d/cli/v2/internal/config" 13 "github.com/ungtb10d/cli/v2/internal/ghrepo" 14 "github.com/ungtb10d/cli/v2/pkg/cmd/pr/shared" 15 "github.com/ungtb10d/cli/v2/pkg/cmdutil" 16 "github.com/ungtb10d/cli/v2/pkg/iostreams" 17 "github.com/spf13/cobra" 18 ) 19 20 type CheckoutOptions struct { 21 HttpClient func() (*http.Client, error) 22 GitClient *git.Client 23 Config func() (config.Config, error) 24 IO *iostreams.IOStreams 25 Remotes func() (cliContext.Remotes, error) 26 Branch func() (string, error) 27 28 Finder shared.PRFinder 29 30 SelectorArg string 31 RecurseSubmodules bool 32 Force bool 33 Detach bool 34 BranchName string 35 } 36 37 func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command { 38 opts := &CheckoutOptions{ 39 IO: f.IOStreams, 40 HttpClient: f.HttpClient, 41 GitClient: f.GitClient, 42 Config: f.Config, 43 Remotes: f.Remotes, 44 Branch: f.Branch, 45 } 46 47 cmd := &cobra.Command{ 48 Use: "checkout {<number> | <url> | <branch>}", 49 Short: "Check out a pull request in git", 50 Args: cmdutil.ExactArgs(1, "argument required"), 51 RunE: func(cmd *cobra.Command, args []string) error { 52 opts.Finder = shared.NewFinder(f) 53 54 if len(args) > 0 { 55 opts.SelectorArg = args[0] 56 } 57 58 if runF != nil { 59 return runF(opts) 60 } 61 return checkoutRun(opts) 62 }, 63 } 64 65 cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout") 66 cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request") 67 cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD") 68 cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)") 69 70 return cmd 71 } 72 73 func checkoutRun(opts *CheckoutOptions) error { 74 findOptions := shared.FindOptions{ 75 Selector: opts.SelectorArg, 76 Fields: []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}, 77 } 78 pr, baseRepo, err := opts.Finder.Find(findOptions) 79 if err != nil { 80 return err 81 } 82 83 cfg, err := opts.Config() 84 if err != nil { 85 return err 86 } 87 protocol, _ := cfg.GetOrDefault(baseRepo.RepoHost(), "git_protocol") 88 89 remotes, err := opts.Remotes() 90 if err != nil { 91 return err 92 } 93 baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName()) 94 baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol) 95 if baseRemote != nil { 96 baseURLOrName = baseRemote.Name 97 } 98 99 headRemote := baseRemote 100 if pr.HeadRepository == nil { 101 headRemote = nil 102 } else if pr.IsCrossRepository { 103 headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name) 104 } 105 106 if strings.HasPrefix(pr.HeadRefName, "-") { 107 return fmt.Errorf("invalid branch name: %q", pr.HeadRefName) 108 } 109 110 var cmdQueue [][]string 111 112 if headRemote != nil { 113 cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...) 114 } else { 115 httpClient, err := opts.HttpClient() 116 if err != nil { 117 return err 118 } 119 apiClient := api.NewClientFromHTTP(httpClient) 120 121 defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo) 122 if err != nil { 123 return err 124 } 125 cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...) 126 } 127 128 if opts.RecurseSubmodules { 129 cmdQueue = append(cmdQueue, []string{"submodule", "sync", "--recursive"}) 130 cmdQueue = append(cmdQueue, []string{"submodule", "update", "--init", "--recursive"}) 131 } 132 133 err = executeCmds(opts.GitClient, cmdQueue) 134 if err != nil { 135 return err 136 } 137 138 return nil 139 } 140 141 func cmdsForExistingRemote(remote *cliContext.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string { 142 var cmds [][]string 143 remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName) 144 145 refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName) 146 if !opts.Detach { 147 refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch) 148 } 149 150 cmds = append(cmds, []string{"fetch", remote.Name, refSpec}) 151 152 localBranch := pr.HeadRefName 153 if opts.BranchName != "" { 154 localBranch = opts.BranchName 155 } 156 157 switch { 158 case opts.Detach: 159 cmds = append(cmds, []string{"checkout", "--detach", "FETCH_HEAD"}) 160 case localBranchExists(opts.GitClient, localBranch): 161 cmds = append(cmds, []string{"checkout", localBranch}) 162 if opts.Force { 163 cmds = append(cmds, []string{"reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) 164 } else { 165 // TODO: check if non-fast-forward and suggest to use `--force` 166 cmds = append(cmds, []string{"merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) 167 } 168 default: 169 cmds = append(cmds, []string{"checkout", "-b", localBranch, "--track", remoteBranch}) 170 } 171 172 return cmds 173 } 174 175 func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string { 176 var cmds [][]string 177 ref := fmt.Sprintf("refs/pull/%d/head", pr.Number) 178 179 if opts.Detach { 180 cmds = append(cmds, []string{"fetch", baseURLOrName, ref}) 181 cmds = append(cmds, []string{"checkout", "--detach", "FETCH_HEAD"}) 182 return cmds 183 } 184 185 localBranch := pr.HeadRefName 186 if opts.BranchName != "" { 187 localBranch = opts.BranchName 188 } else if pr.HeadRefName == defaultBranch { 189 // avoid naming the new branch the same as the default branch 190 localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch) 191 } 192 193 currentBranch, _ := opts.Branch() 194 if localBranch == currentBranch { 195 // PR head matches currently checked out branch 196 cmds = append(cmds, []string{"fetch", baseURLOrName, ref}) 197 if opts.Force { 198 cmds = append(cmds, []string{"reset", "--hard", "FETCH_HEAD"}) 199 } else { 200 // TODO: check if non-fast-forward and suggest to use `--force` 201 cmds = append(cmds, []string{"merge", "--ff-only", "FETCH_HEAD"}) 202 } 203 } else { 204 if opts.Force { 205 cmds = append(cmds, []string{"fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"}) 206 } else { 207 // TODO: check if non-fast-forward and suggest to use `--force` 208 cmds = append(cmds, []string{"fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)}) 209 } 210 211 cmds = append(cmds, []string{"checkout", localBranch}) 212 } 213 214 remote := baseURLOrName 215 mergeRef := ref 216 if pr.MaintainerCanModify && pr.HeadRepository != nil { 217 headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost) 218 remote = ghrepo.FormatRemoteURL(headRepo, protocol) 219 mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName) 220 } 221 if missingMergeConfigForBranch(opts.GitClient, localBranch) { 222 // .remote is needed for `git pull` to work 223 // .pushRemote is needed for `git push` to work, if user has set `remote.pushDefault`. 224 // see https://git-scm.com/docs/git-config#Documentation/git-config.txt-branchltnamegtremote 225 cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.remote", localBranch), remote}) 226 cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.pushRemote", localBranch), remote}) 227 cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef}) 228 } 229 230 return cmds 231 } 232 233 func missingMergeConfigForBranch(client *git.Client, b string) bool { 234 mc, err := client.Config(context.Background(), fmt.Sprintf("branch.%s.merge", b)) 235 return err != nil || mc == "" 236 } 237 238 func localBranchExists(client *git.Client, b string) bool { 239 _, err := client.ShowRefs(context.Background(), []string{"refs/heads/" + b}) 240 return err == nil 241 } 242 243 func executeCmds(client *git.Client, cmdQueue [][]string) error { 244 for _, args := range cmdQueue { 245 // TODO: Use AuthenticatedCommand 246 cmd, err := client.Command(context.Background(), args...) 247 if err != nil { 248 return err 249 } 250 if err := cmd.Run(); err != nil { 251 return err 252 } 253 } 254 return nil 255 }