github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/cmd/pr/checkout/checkout.go (about)

     1  package checkout
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/ungtb10d/cli/v2/api"
    10  	cliContext "github.com/ungtb10d/cli/v2/context"
    11  	"github.com/ungtb10d/cli/v2/git"
    12  	"github.com/ungtb10d/cli/v2/internal/config"
    13  	"github.com/ungtb10d/cli/v2/internal/ghrepo"
    14  	"github.com/ungtb10d/cli/v2/pkg/cmd/pr/shared"
    15  	"github.com/ungtb10d/cli/v2/pkg/cmdutil"
    16  	"github.com/ungtb10d/cli/v2/pkg/iostreams"
    17  	"github.com/spf13/cobra"
    18  )
    19  
// CheckoutOptions holds the injected dependencies and the parsed
// command-line inputs for `pr checkout`.
type CheckoutOptions struct {
	// HttpClient returns the HTTP client used to build an API client; it is
	// only needed when no local remote exists for the PR's head repository.
	HttpClient func() (*http.Client, error)
	// GitClient builds and runs the git commands that perform the checkout.
	GitClient  *git.Client
	// Config supplies CLI configuration; the git_protocol setting is read
	// from it to format remote URLs.
	Config     func() (config.Config, error)
	// IO holds the command's I/O streams (taken from the factory).
	IO         *iostreams.IOStreams
	// Remotes returns the local git remotes, which are searched by
	// owner/name for the base and head repositories.
	Remotes    func() (cliContext.Remotes, error)
	// Branch returns the name of the currently checked-out branch.
	Branch     func() (string, error)

	// Finder resolves SelectorArg to a pull request and its base repository.
	Finder shared.PRFinder

	// SelectorArg is the PR number, URL, or branch given on the command line.
	SelectorArg       string
	// RecurseSubmodules (--recurse-submodules) syncs and updates submodules
	// recursively after the checkout.
	RecurseSubmodules bool
	// Force (--force) resets an already-existing local branch to the PR's
	// state instead of only fast-forwarding it.
	Force             bool
	// Detach (--detach) checks out the fetched PR head as a detached HEAD.
	Detach            bool
	// BranchName (--branch) overrides the local branch name that would
	// otherwise be derived from the PR's head branch.
	BranchName        string
}
    36  
    37  func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command {
    38  	opts := &CheckoutOptions{
    39  		IO:         f.IOStreams,
    40  		HttpClient: f.HttpClient,
    41  		GitClient:  f.GitClient,
    42  		Config:     f.Config,
    43  		Remotes:    f.Remotes,
    44  		Branch:     f.Branch,
    45  	}
    46  
    47  	cmd := &cobra.Command{
    48  		Use:   "checkout {<number> | <url> | <branch>}",
    49  		Short: "Check out a pull request in git",
    50  		Args:  cmdutil.ExactArgs(1, "argument required"),
    51  		RunE: func(cmd *cobra.Command, args []string) error {
    52  			opts.Finder = shared.NewFinder(f)
    53  
    54  			if len(args) > 0 {
    55  				opts.SelectorArg = args[0]
    56  			}
    57  
    58  			if runF != nil {
    59  				return runF(opts)
    60  			}
    61  			return checkoutRun(opts)
    62  		},
    63  	}
    64  
    65  	cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout")
    66  	cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request")
    67  	cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD")
    68  	cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)")
    69  
    70  	return cmd
    71  }
    72  
    73  func checkoutRun(opts *CheckoutOptions) error {
    74  	findOptions := shared.FindOptions{
    75  		Selector: opts.SelectorArg,
    76  		Fields:   []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"},
    77  	}
    78  	pr, baseRepo, err := opts.Finder.Find(findOptions)
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	cfg, err := opts.Config()
    84  	if err != nil {
    85  		return err
    86  	}
    87  	protocol, _ := cfg.GetOrDefault(baseRepo.RepoHost(), "git_protocol")
    88  
    89  	remotes, err := opts.Remotes()
    90  	if err != nil {
    91  		return err
    92  	}
    93  	baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName())
    94  	baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol)
    95  	if baseRemote != nil {
    96  		baseURLOrName = baseRemote.Name
    97  	}
    98  
    99  	headRemote := baseRemote
   100  	if pr.HeadRepository == nil {
   101  		headRemote = nil
   102  	} else if pr.IsCrossRepository {
   103  		headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
   104  	}
   105  
   106  	if strings.HasPrefix(pr.HeadRefName, "-") {
   107  		return fmt.Errorf("invalid branch name: %q", pr.HeadRefName)
   108  	}
   109  
   110  	var cmdQueue [][]string
   111  
   112  	if headRemote != nil {
   113  		cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...)
   114  	} else {
   115  		httpClient, err := opts.HttpClient()
   116  		if err != nil {
   117  			return err
   118  		}
   119  		apiClient := api.NewClientFromHTTP(httpClient)
   120  
   121  		defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo)
   122  		if err != nil {
   123  			return err
   124  		}
   125  		cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...)
   126  	}
   127  
   128  	if opts.RecurseSubmodules {
   129  		cmdQueue = append(cmdQueue, []string{"submodule", "sync", "--recursive"})
   130  		cmdQueue = append(cmdQueue, []string{"submodule", "update", "--init", "--recursive"})
   131  	}
   132  
   133  	err = executeCmds(opts.GitClient, cmdQueue)
   134  	if err != nil {
   135  		return err
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  func cmdsForExistingRemote(remote *cliContext.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string {
   142  	var cmds [][]string
   143  	remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName)
   144  
   145  	refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName)
   146  	if !opts.Detach {
   147  		refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch)
   148  	}
   149  
   150  	cmds = append(cmds, []string{"fetch", remote.Name, refSpec})
   151  
   152  	localBranch := pr.HeadRefName
   153  	if opts.BranchName != "" {
   154  		localBranch = opts.BranchName
   155  	}
   156  
   157  	switch {
   158  	case opts.Detach:
   159  		cmds = append(cmds, []string{"checkout", "--detach", "FETCH_HEAD"})
   160  	case localBranchExists(opts.GitClient, localBranch):
   161  		cmds = append(cmds, []string{"checkout", localBranch})
   162  		if opts.Force {
   163  			cmds = append(cmds, []string{"reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
   164  		} else {
   165  			// TODO: check if non-fast-forward and suggest to use `--force`
   166  			cmds = append(cmds, []string{"merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
   167  		}
   168  	default:
   169  		cmds = append(cmds, []string{"checkout", "-b", localBranch, "--track", remoteBranch})
   170  	}
   171  
   172  	return cmds
   173  }
   174  
   175  func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string {
   176  	var cmds [][]string
   177  	ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)
   178  
   179  	if opts.Detach {
   180  		cmds = append(cmds, []string{"fetch", baseURLOrName, ref})
   181  		cmds = append(cmds, []string{"checkout", "--detach", "FETCH_HEAD"})
   182  		return cmds
   183  	}
   184  
   185  	localBranch := pr.HeadRefName
   186  	if opts.BranchName != "" {
   187  		localBranch = opts.BranchName
   188  	} else if pr.HeadRefName == defaultBranch {
   189  		// avoid naming the new branch the same as the default branch
   190  		localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch)
   191  	}
   192  
   193  	currentBranch, _ := opts.Branch()
   194  	if localBranch == currentBranch {
   195  		// PR head matches currently checked out branch
   196  		cmds = append(cmds, []string{"fetch", baseURLOrName, ref})
   197  		if opts.Force {
   198  			cmds = append(cmds, []string{"reset", "--hard", "FETCH_HEAD"})
   199  		} else {
   200  			// TODO: check if non-fast-forward and suggest to use `--force`
   201  			cmds = append(cmds, []string{"merge", "--ff-only", "FETCH_HEAD"})
   202  		}
   203  	} else {
   204  		if opts.Force {
   205  			cmds = append(cmds, []string{"fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"})
   206  		} else {
   207  			// TODO: check if non-fast-forward and suggest to use `--force`
   208  			cmds = append(cmds, []string{"fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)})
   209  		}
   210  
   211  		cmds = append(cmds, []string{"checkout", localBranch})
   212  	}
   213  
   214  	remote := baseURLOrName
   215  	mergeRef := ref
   216  	if pr.MaintainerCanModify && pr.HeadRepository != nil {
   217  		headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost)
   218  		remote = ghrepo.FormatRemoteURL(headRepo, protocol)
   219  		mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
   220  	}
   221  	if missingMergeConfigForBranch(opts.GitClient, localBranch) {
   222  		// .remote is needed for `git pull` to work
   223  		// .pushRemote is needed for `git push` to work, if user has set `remote.pushDefault`.
   224  		// see https://git-scm.com/docs/git-config#Documentation/git-config.txt-branchltnamegtremote
   225  		cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.remote", localBranch), remote})
   226  		cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.pushRemote", localBranch), remote})
   227  		cmds = append(cmds, []string{"config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef})
   228  	}
   229  
   230  	return cmds
   231  }
   232  
   233  func missingMergeConfigForBranch(client *git.Client, b string) bool {
   234  	mc, err := client.Config(context.Background(), fmt.Sprintf("branch.%s.merge", b))
   235  	return err != nil || mc == ""
   236  }
   237  
   238  func localBranchExists(client *git.Client, b string) bool {
   239  	_, err := client.ShowRefs(context.Background(), []string{"refs/heads/" + b})
   240  	return err == nil
   241  }
   242  
   243  func executeCmds(client *git.Client, cmdQueue [][]string) error {
   244  	for _, args := range cmdQueue {
   245  		// TODO: Use AuthenticatedCommand
   246  		cmd, err := client.Command(context.Background(), args...)
   247  		if err != nil {
   248  			return err
   249  		}
   250  		if err := cmd.Run(); err != nil {
   251  			return err
   252  		}
   253  	}
   254  	return nil
   255  }