github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/pkg/cmd/repo/sync/sync.go (about)

     1  package sync
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"regexp"
     8  
     9  	"github.com/MakeNowJust/heredoc"
    10  	"github.com/cli/cli/api"
    11  	"github.com/cli/cli/context"
    12  	gitpkg "github.com/cli/cli/git"
    13  	"github.com/cli/cli/internal/ghrepo"
    14  	"github.com/cli/cli/pkg/cmdutil"
    15  	"github.com/cli/cli/pkg/iostreams"
    16  	"github.com/spf13/cobra"
    17  )
    18  
    19  type SyncOptions struct {
    20  	HttpClient func() (*http.Client, error)
    21  	IO         *iostreams.IOStreams
    22  	BaseRepo   func() (ghrepo.Interface, error)
    23  	Remotes    func() (context.Remotes, error)
    24  	Git        gitClient
    25  	DestArg    string
    26  	SrcArg     string
    27  	Branch     string
    28  	Force      bool
    29  }
    30  
    31  func NewCmdSync(f *cmdutil.Factory, runF func(*SyncOptions) error) *cobra.Command {
    32  	opts := SyncOptions{
    33  		HttpClient: f.HttpClient,
    34  		IO:         f.IOStreams,
    35  		BaseRepo:   f.BaseRepo,
    36  		Remotes:    f.Remotes,
    37  		Git:        &gitExecuter{},
    38  	}
    39  
    40  	cmd := &cobra.Command{
    41  		Use:   "sync [<destination-repository>]",
    42  		Short: "Sync a repository",
    43  		Long: heredoc.Docf(`
    44  			Sync destination repository from source repository. Syncing uses the main branch
    45  			of the source repository to update the matching branch on the destination
    46  			repository so they are equal. A fast forward update will be used execept when the
    47  			%[1]s--force%[1]s flag is specified, then the two branches will
    48  			by synced using a hard reset.
    49  
    50  			Without an argument, the local repository is selected as the destination repository.
    51  
    52  			The source repository is the parent of the destination repository by default.
    53  			This can be overridden with the %[1]s--source%[1]s flag.
    54  		`, "`"),
    55  		Example: heredoc.Doc(`
    56  			# Sync local repository from remote parent
    57  			$ gh repo sync
    58  
    59  			# Sync local repository from remote parent on specific branch
    60  			$ gh repo sync --branch v1
    61  
    62  			# Sync remote fork from its parent
    63  			$ gh repo sync owner/cli-fork
    64  
    65  			# Sync remote repository from another remote repository
    66  			$ gh repo sync owner/repo --source owner2/repo2
    67  		`),
    68  		Args: cobra.MaximumNArgs(1),
    69  		RunE: func(c *cobra.Command, args []string) error {
    70  			if len(args) > 0 {
    71  				opts.DestArg = args[0]
    72  			}
    73  			if runF != nil {
    74  				return runF(&opts)
    75  			}
    76  			return syncRun(&opts)
    77  		},
    78  	}
    79  
    80  	cmd.Flags().StringVarP(&opts.SrcArg, "source", "s", "", "Source repository")
    81  	cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Branch to sync (default: main branch)")
    82  	cmd.Flags().BoolVarP(&opts.Force, "force", "", false, "Hard reset the branch of the destination repository to match the source repository")
    83  	return cmd
    84  }
    85  
    86  func syncRun(opts *SyncOptions) error {
    87  	if opts.DestArg == "" {
    88  		return syncLocalRepo(opts)
    89  	} else {
    90  		return syncRemoteRepo(opts)
    91  	}
    92  }
    93  
    94  func syncLocalRepo(opts *SyncOptions) error {
    95  	var srcRepo ghrepo.Interface
    96  
    97  	if opts.SrcArg != "" {
    98  		var err error
    99  		srcRepo, err = ghrepo.FromFullName(opts.SrcArg)
   100  		if err != nil {
   101  			return err
   102  		}
   103  	} else {
   104  		var err error
   105  		srcRepo, err = opts.BaseRepo()
   106  		if err != nil {
   107  			return err
   108  		}
   109  	}
   110  
   111  	// Find remote that matches the srcRepo
   112  	var remote string
   113  	remotes, err := opts.Remotes()
   114  	if err != nil {
   115  		return err
   116  	}
   117  	if r, err := remotes.FindByRepo(srcRepo.RepoOwner(), srcRepo.RepoName()); err == nil {
   118  		remote = r.Name
   119  	} else {
   120  		return fmt.Errorf("can't find corresponding remote for %s", ghrepo.FullName(srcRepo))
   121  	}
   122  
   123  	if opts.Branch == "" {
   124  		httpClient, err := opts.HttpClient()
   125  		if err != nil {
   126  			return err
   127  		}
   128  		apiClient := api.NewClientFromHTTP(httpClient)
   129  		opts.IO.StartProgressIndicator()
   130  		opts.Branch, err = api.RepoDefaultBranch(apiClient, srcRepo)
   131  		opts.IO.StopProgressIndicator()
   132  		if err != nil {
   133  			return err
   134  		}
   135  	}
   136  
   137  	opts.IO.StartProgressIndicator()
   138  	err = executeLocalRepoSync(srcRepo, remote, opts)
   139  	opts.IO.StopProgressIndicator()
   140  	if err != nil {
   141  		if errors.Is(err, divergingError) {
   142  			return fmt.Errorf("can't sync because there are diverging changes; use `--force` to overwrite the destination branch")
   143  		}
   144  		if errors.Is(err, mismatchRemotesError) {
   145  			return fmt.Errorf("can't sync because %s is not tracking %s", opts.Branch, ghrepo.FullName(srcRepo))
   146  		}
   147  		return err
   148  	}
   149  
   150  	if opts.IO.IsStdoutTTY() {
   151  		cs := opts.IO.ColorScheme()
   152  		fmt.Fprintf(opts.IO.Out, "%s Synced the \"%s\" branch from %s to local repository\n",
   153  			cs.SuccessIcon(),
   154  			opts.Branch,
   155  			ghrepo.FullName(srcRepo))
   156  	}
   157  
   158  	return nil
   159  }
   160  
   161  func syncRemoteRepo(opts *SyncOptions) error {
   162  	httpClient, err := opts.HttpClient()
   163  	if err != nil {
   164  		return err
   165  	}
   166  	apiClient := api.NewClientFromHTTP(httpClient)
   167  
   168  	var destRepo, srcRepo ghrepo.Interface
   169  
   170  	destRepo, err = ghrepo.FromFullName(opts.DestArg)
   171  	if err != nil {
   172  		return err
   173  	}
   174  
   175  	if opts.SrcArg == "" {
   176  		opts.IO.StartProgressIndicator()
   177  		srcRepo, err = api.RepoParent(apiClient, destRepo)
   178  		opts.IO.StopProgressIndicator()
   179  		if err != nil {
   180  			return err
   181  		}
   182  		if srcRepo == nil {
   183  			return fmt.Errorf("can't determine source repository for %s because repository is not fork", ghrepo.FullName(destRepo))
   184  		}
   185  	} else {
   186  		srcRepo, err = ghrepo.FromFullName(opts.SrcArg)
   187  		if err != nil {
   188  			return err
   189  		}
   190  	}
   191  
   192  	if destRepo.RepoHost() != srcRepo.RepoHost() {
   193  		return fmt.Errorf("can't sync repositories from different hosts")
   194  	}
   195  
   196  	if opts.Branch == "" {
   197  		opts.IO.StartProgressIndicator()
   198  		opts.Branch, err = api.RepoDefaultBranch(apiClient, srcRepo)
   199  		opts.IO.StopProgressIndicator()
   200  		if err != nil {
   201  			return err
   202  		}
   203  	}
   204  
   205  	opts.IO.StartProgressIndicator()
   206  	err = executeRemoteRepoSync(apiClient, destRepo, srcRepo, opts)
   207  	opts.IO.StopProgressIndicator()
   208  	if err != nil {
   209  		if errors.Is(err, divergingError) {
   210  			return fmt.Errorf("can't sync because there are diverging changes; use `--force` to overwrite the destination branch")
   211  		}
   212  		return err
   213  	}
   214  
   215  	if opts.IO.IsStdoutTTY() {
   216  		cs := opts.IO.ColorScheme()
   217  		fmt.Fprintf(opts.IO.Out, "%s Synced the \"%s\" branch from %s to %s\n",
   218  			cs.SuccessIcon(),
   219  			opts.Branch,
   220  			ghrepo.FullName(srcRepo),
   221  			ghrepo.FullName(destRepo))
   222  	}
   223  
   224  	return nil
   225  }
   226  
   227  var divergingError = errors.New("diverging changes")
   228  var mismatchRemotesError = errors.New("branch remote does not match specified source")
   229  
   230  func executeLocalRepoSync(srcRepo ghrepo.Interface, remote string, opts *SyncOptions) error {
   231  	git := opts.Git
   232  	branch := opts.Branch
   233  	useForce := opts.Force
   234  
   235  	if err := git.Fetch(remote, fmt.Sprintf("refs/heads/%s", branch)); err != nil {
   236  		return err
   237  	}
   238  
   239  	hasLocalBranch := git.HasLocalBranch(branch)
   240  	if hasLocalBranch {
   241  		branchRemote, err := git.BranchRemote(branch)
   242  		if err != nil {
   243  			return err
   244  		}
   245  		if branchRemote != remote {
   246  			return mismatchRemotesError
   247  		}
   248  
   249  		fastForward, err := git.IsAncestor(branch, "FETCH_HEAD")
   250  		if err != nil {
   251  			return err
   252  		}
   253  
   254  		if !fastForward && !useForce {
   255  			return divergingError
   256  		}
   257  		if fastForward && useForce {
   258  			useForce = false
   259  		}
   260  	}
   261  
   262  	currentBranch, err := git.CurrentBranch()
   263  	if err != nil && !errors.Is(err, gitpkg.ErrNotOnAnyBranch) {
   264  		return err
   265  	}
   266  	if currentBranch == branch {
   267  		if isDirty, err := git.IsDirty(); err == nil && isDirty {
   268  			return fmt.Errorf("can't sync because there are local changes; please stash them before trying again")
   269  		} else if err != nil {
   270  			return err
   271  		}
   272  		if useForce {
   273  			if err := git.ResetHard("FETCH_HEAD"); err != nil {
   274  				return err
   275  			}
   276  		} else {
   277  			if err := git.MergeFastForward("FETCH_HEAD"); err != nil {
   278  				return err
   279  			}
   280  		}
   281  	} else {
   282  		if hasLocalBranch {
   283  			if err := git.UpdateBranch(branch, "FETCH_HEAD"); err != nil {
   284  				return err
   285  			}
   286  		} else {
   287  			if err := git.CreateBranch(branch, "FETCH_HEAD", fmt.Sprintf("%s/%s", remote, branch)); err != nil {
   288  				return err
   289  			}
   290  		}
   291  	}
   292  
   293  	return nil
   294  }
   295  
   296  func executeRemoteRepoSync(client *api.Client, destRepo, srcRepo ghrepo.Interface, opts *SyncOptions) error {
   297  	commit, err := latestCommit(client, srcRepo, opts.Branch)
   298  	if err != nil {
   299  		return err
   300  	}
   301  
   302  	// This is not a great way to detect the error returned by the API
   303  	// Unfortunately API returns 422 for multiple reasons
   304  	notFastForwardErrorMessage := regexp.MustCompile(`^Update is not a fast forward$`)
   305  	err = syncFork(client, destRepo, opts.Branch, commit.Object.SHA, opts.Force)
   306  	var httpErr api.HTTPError
   307  	if err != nil && errors.As(err, &httpErr) && notFastForwardErrorMessage.MatchString(httpErr.Message) {
   308  		return divergingError
   309  	}
   310  
   311  	return err
   312  }