github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/cmd/repo/sync/sync.go (about)

     1  package sync
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/MakeNowJust/heredoc"
    10  	"github.com/ungtb10d/cli/v2/api"
    11  	"github.com/ungtb10d/cli/v2/context"
    12  	gitpkg "github.com/ungtb10d/cli/v2/git"
    13  	"github.com/ungtb10d/cli/v2/internal/ghrepo"
    14  	"github.com/ungtb10d/cli/v2/pkg/cmdutil"
    15  	"github.com/ungtb10d/cli/v2/pkg/iostreams"
    16  	"github.com/spf13/cobra"
    17  )
    18  
    19  const (
    20  	notFastForwardErrorMessage     = "Update is not a fast forward"
    21  	branchDoesNotExistErrorMessage = "Reference does not exist"
    22  )
    23  
// SyncOptions carries the injected dependencies and the parsed command-line
// input for `gh repo sync`.
type SyncOptions struct {
	// HttpClient returns the HTTP client used to build the API client.
	HttpClient func() (*http.Client, error)
	// IO provides the progress indicator, TTY detection and output stream.
	IO         *iostreams.IOStreams
	// BaseRepo resolves the repository used as the default source of a local
	// sync when no --source is given.
	BaseRepo   func() (ghrepo.Interface, error)
	// Remotes lists the git remotes; a local sync requires one of them to
	// match the source repository.
	Remotes    func() (context.Remotes, error)
	// Git performs the local git operations (fetch, branch inspection,
	// reset/merge/update) of a local sync.
	Git        gitClient
	// DestArg is the optional positional destination repository (a full name
	// such as owner/repo). Empty means the local repository is synced.
	DestArg    string
	// SrcArg is the value of --source (a full name such as owner/repo).
	// Empty means a default source is used.
	SrcArg     string
	// Branch is the value of --branch. When empty, a local sync fills it in
	// with the source repository's default branch.
	Branch     string
	// Force is the value of --force: overwrite a diverged destination branch
	// instead of failing.
	Force      bool
}
    35  
    36  func NewCmdSync(f *cmdutil.Factory, runF func(*SyncOptions) error) *cobra.Command {
    37  	opts := SyncOptions{
    38  		HttpClient: f.HttpClient,
    39  		IO:         f.IOStreams,
    40  		BaseRepo:   f.BaseRepo,
    41  		Remotes:    f.Remotes,
    42  		Git:        &gitExecuter{client: f.GitClient},
    43  	}
    44  
    45  	cmd := &cobra.Command{
    46  		Use:   "sync [<destination-repository>]",
    47  		Short: "Sync a repository",
    48  		Long: heredoc.Docf(`
    49  			Sync destination repository from source repository. Syncing uses the main branch
    50  			of the source repository to update the matching branch on the destination
    51  			repository so they are equal. A fast forward update will be used except when the
    52  			%[1]s--force%[1]s flag is specified, then the two branches will
    53  			by synced using a hard reset.
    54  
    55  			Without an argument, the local repository is selected as the destination repository.
    56  
    57  			The source repository is the parent of the destination repository by default.
    58  			This can be overridden with the %[1]s--source%[1]s flag.
    59  		`, "`"),
    60  		Example: heredoc.Doc(`
    61  			# Sync local repository from remote parent
    62  			$ gh repo sync
    63  
    64  			# Sync local repository from remote parent on specific branch
    65  			$ gh repo sync --branch v1
    66  
    67  			# Sync remote fork from its parent
    68  			$ gh repo sync owner/cli-fork
    69  
    70  			# Sync remote repository from another remote repository
    71  			$ gh repo sync owner/repo --source owner2/repo2
    72  		`),
    73  		Args: cobra.MaximumNArgs(1),
    74  		RunE: func(c *cobra.Command, args []string) error {
    75  			if len(args) > 0 {
    76  				opts.DestArg = args[0]
    77  			}
    78  			if runF != nil {
    79  				return runF(&opts)
    80  			}
    81  			return syncRun(&opts)
    82  		},
    83  	}
    84  
    85  	cmd.Flags().StringVarP(&opts.SrcArg, "source", "s", "", "Source repository")
    86  	cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Branch to sync (default: main branch)")
    87  	cmd.Flags().BoolVarP(&opts.Force, "force", "", false, "Hard reset the branch of the destination repository to match the source repository")
    88  	return cmd
    89  }
    90  
    91  func syncRun(opts *SyncOptions) error {
    92  	if opts.DestArg == "" {
    93  		return syncLocalRepo(opts)
    94  	} else {
    95  		return syncRemoteRepo(opts)
    96  	}
    97  }
    98  
    99  func syncLocalRepo(opts *SyncOptions) error {
   100  	var srcRepo ghrepo.Interface
   101  
   102  	if opts.SrcArg != "" {
   103  		var err error
   104  		srcRepo, err = ghrepo.FromFullName(opts.SrcArg)
   105  		if err != nil {
   106  			return err
   107  		}
   108  	} else {
   109  		var err error
   110  		srcRepo, err = opts.BaseRepo()
   111  		if err != nil {
   112  			return err
   113  		}
   114  	}
   115  
   116  	// Find remote that matches the srcRepo
   117  	var remote string
   118  	remotes, err := opts.Remotes()
   119  	if err != nil {
   120  		return err
   121  	}
   122  	if r, err := remotes.FindByRepo(srcRepo.RepoOwner(), srcRepo.RepoName()); err == nil {
   123  		remote = r.Name
   124  	} else {
   125  		return fmt.Errorf("can't find corresponding remote for %s", ghrepo.FullName(srcRepo))
   126  	}
   127  
   128  	if opts.Branch == "" {
   129  		httpClient, err := opts.HttpClient()
   130  		if err != nil {
   131  			return err
   132  		}
   133  		apiClient := api.NewClientFromHTTP(httpClient)
   134  		opts.IO.StartProgressIndicator()
   135  		opts.Branch, err = api.RepoDefaultBranch(apiClient, srcRepo)
   136  		opts.IO.StopProgressIndicator()
   137  		if err != nil {
   138  			return err
   139  		}
   140  	}
   141  
   142  	// Git fetch might require input from user, so do it before starting progress indicator.
   143  	if err := opts.Git.Fetch(remote, fmt.Sprintf("refs/heads/%s", opts.Branch)); err != nil {
   144  		return err
   145  	}
   146  
   147  	opts.IO.StartProgressIndicator()
   148  	err = executeLocalRepoSync(srcRepo, remote, opts)
   149  	opts.IO.StopProgressIndicator()
   150  	if err != nil {
   151  		if errors.Is(err, divergingError) {
   152  			return fmt.Errorf("can't sync because there are diverging changes; use `--force` to overwrite the destination branch")
   153  		}
   154  		if errors.Is(err, mismatchRemotesError) {
   155  			return fmt.Errorf("can't sync because %s is not tracking %s", opts.Branch, ghrepo.FullName(srcRepo))
   156  		}
   157  		return err
   158  	}
   159  
   160  	if opts.IO.IsStdoutTTY() {
   161  		cs := opts.IO.ColorScheme()
   162  		fmt.Fprintf(opts.IO.Out, "%s Synced the \"%s\" branch from %s to local repository\n",
   163  			cs.SuccessIcon(),
   164  			opts.Branch,
   165  			ghrepo.FullName(srcRepo))
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  func syncRemoteRepo(opts *SyncOptions) error {
   172  	httpClient, err := opts.HttpClient()
   173  	if err != nil {
   174  		return err
   175  	}
   176  	apiClient := api.NewClientFromHTTP(httpClient)
   177  
   178  	var destRepo, srcRepo ghrepo.Interface
   179  
   180  	destRepo, err = ghrepo.FromFullName(opts.DestArg)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	if opts.SrcArg != "" {
   186  		srcRepo, err = ghrepo.FromFullName(opts.SrcArg)
   187  		if err != nil {
   188  			return err
   189  		}
   190  	}
   191  
   192  	if srcRepo != nil && destRepo.RepoHost() != srcRepo.RepoHost() {
   193  		return fmt.Errorf("can't sync repositories from different hosts")
   194  	}
   195  
   196  	opts.IO.StartProgressIndicator()
   197  	baseBranchLabel, err := executeRemoteRepoSync(apiClient, destRepo, srcRepo, opts)
   198  	opts.IO.StopProgressIndicator()
   199  	if err != nil {
   200  		if errors.Is(err, divergingError) {
   201  			return fmt.Errorf("can't sync because there are diverging changes; use `--force` to overwrite the destination branch")
   202  		}
   203  		return err
   204  	}
   205  
   206  	if opts.IO.IsStdoutTTY() {
   207  		cs := opts.IO.ColorScheme()
   208  		branchName := opts.Branch
   209  		if idx := strings.Index(baseBranchLabel, ":"); idx >= 0 {
   210  			branchName = baseBranchLabel[idx+1:]
   211  		}
   212  		fmt.Fprintf(opts.IO.Out, "%s Synced the \"%s:%s\" branch from \"%s\"\n",
   213  			cs.SuccessIcon(),
   214  			destRepo.RepoOwner(),
   215  			branchName,
   216  			baseBranchLabel)
   217  	}
   218  
   219  	return nil
   220  }
   221  
   222  var divergingError = errors.New("diverging changes")
   223  var mismatchRemotesError = errors.New("branch remote does not match specified source")
   224  
   225  func executeLocalRepoSync(srcRepo ghrepo.Interface, remote string, opts *SyncOptions) error {
   226  	git := opts.Git
   227  	branch := opts.Branch
   228  	useForce := opts.Force
   229  
   230  	hasLocalBranch := git.HasLocalBranch(branch)
   231  	if hasLocalBranch {
   232  		branchRemote, err := git.BranchRemote(branch)
   233  		if err != nil {
   234  			return err
   235  		}
   236  		if branchRemote != remote {
   237  			return mismatchRemotesError
   238  		}
   239  
   240  		fastForward, err := git.IsAncestor(branch, "FETCH_HEAD")
   241  		if err != nil {
   242  			return err
   243  		}
   244  
   245  		if !fastForward && !useForce {
   246  			return divergingError
   247  		}
   248  		if fastForward && useForce {
   249  			useForce = false
   250  		}
   251  	}
   252  
   253  	currentBranch, err := git.CurrentBranch()
   254  	if err != nil && !errors.Is(err, gitpkg.ErrNotOnAnyBranch) {
   255  		return err
   256  	}
   257  	if currentBranch == branch {
   258  		if isDirty, err := git.IsDirty(); err == nil && isDirty {
   259  			return fmt.Errorf("can't sync because there are local changes; please stash them before trying again")
   260  		} else if err != nil {
   261  			return err
   262  		}
   263  		if useForce {
   264  			if err := git.ResetHard("FETCH_HEAD"); err != nil {
   265  				return err
   266  			}
   267  		} else {
   268  			if err := git.MergeFastForward("FETCH_HEAD"); err != nil {
   269  				return err
   270  			}
   271  		}
   272  	} else {
   273  		if hasLocalBranch {
   274  			if err := git.UpdateBranch(branch, "FETCH_HEAD"); err != nil {
   275  				return err
   276  			}
   277  		} else {
   278  			if err := git.CreateBranch(branch, "FETCH_HEAD", fmt.Sprintf("%s/%s", remote, branch)); err != nil {
   279  				return err
   280  			}
   281  		}
   282  	}
   283  
   284  	return nil
   285  }
   286  
   287  func executeRemoteRepoSync(client *api.Client, destRepo, srcRepo ghrepo.Interface, opts *SyncOptions) (string, error) {
   288  	branchName := opts.Branch
   289  	if branchName == "" {
   290  		var err error
   291  		branchName, err = api.RepoDefaultBranch(client, destRepo)
   292  		if err != nil {
   293  			return "", err
   294  		}
   295  	}
   296  
   297  	var apiErr upstreamMergeErr
   298  	if baseBranch, err := triggerUpstreamMerge(client, destRepo, branchName); err == nil {
   299  		return baseBranch, nil
   300  	} else if !errors.As(err, &apiErr) {
   301  		return "", err
   302  	}
   303  
   304  	if srcRepo == nil {
   305  		var err error
   306  		srcRepo, err = api.RepoParent(client, destRepo)
   307  		if err != nil {
   308  			return "", err
   309  		}
   310  		if srcRepo == nil {
   311  			return "", fmt.Errorf("can't determine source repository for %s because repository is not fork", ghrepo.FullName(destRepo))
   312  		}
   313  	}
   314  
   315  	commit, err := latestCommit(client, srcRepo, branchName)
   316  	if err != nil {
   317  		return "", err
   318  	}
   319  
   320  	// This is not a great way to detect the error returned by the API
   321  	// Unfortunately API returns 422 for multiple reasons
   322  	err = syncFork(client, destRepo, branchName, commit.Object.SHA, opts.Force)
   323  	var httpErr api.HTTPError
   324  	if err != nil {
   325  		if errors.As(err, &httpErr) {
   326  			switch httpErr.Message {
   327  			case notFastForwardErrorMessage:
   328  				return "", divergingError
   329  			case branchDoesNotExistErrorMessage:
   330  				return "", fmt.Errorf("%s branch does not exist on %s repository", branchName, ghrepo.FullName(destRepo))
   331  			}
   332  		}
   333  		return "", err
   334  	}
   335  
   336  	return fmt.Sprintf("%s:%s", srcRepo.RepoOwner(), branchName), nil
   337  }