github.com/abhinav/git-pr@v0.6.1-0.20171029234004-54218d68c11b/pr/rebase.go (about)

     1  package pr
     2  
     3  import (
     4  	"container/list"
     5  	"context"
     6  	"fmt"
     7  	"sync"
     8  
     9  	"github.com/abhinav/git-pr/gateway"
    10  	"github.com/abhinav/git-pr/git"
    11  	"github.com/abhinav/git-pr/service"
    12  
    13  	"github.com/google/go-github/github"
    14  	"go.uber.org/multierr"
    15  )
    16  
    17  // Rebase a pull request and its dependencies.
    18  func (s *Service) Rebase(ctx context.Context, req *service.RebaseRequest) (_ *service.RebaseResponse, err error) {
    19  	if len(req.PullRequests) == 0 {
    20  		return &service.RebaseResponse{}, nil
    21  	}
    22  
    23  	// Go back to the original branch after everything is done.
    24  	oldBranch, err := s.git.CurrentBranch()
    25  	if err != nil {
    26  		return nil, err
    27  	}
    28  	defer func(oldBranch string) {
    29  		err = multierr.Append(err, s.git.Checkout(oldBranch))
    30  	}(oldBranch)
    31  
    32  	if err := s.git.Fetch(&gateway.FetchRequest{Remote: "origin"}); err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	// TODO: support remotes besides origin
    37  	baseRef, err := s.git.SHA1("origin/" + req.Base)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  
    42  	rebaser := git.NewBulkRebaser(s.git)
    43  	defer func() {
    44  		err = multierr.Append(err, rebaser.Cleanup())
    45  	}()
    46  
    47  	results, err := s.rebasePullRequests(rebasePRConfig{
    48  		Context:      ctx,
    49  		GitRebaser:   rebaser,
    50  		GitHub:       s.gh,
    51  		Base:         baseRef,
    52  		PullRequests: req.PullRequests,
    53  		Author:       req.Author,
    54  	})
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	// Nothing to do
    60  	if len(results) == 0 {
    61  		return &service.RebaseResponse{}, nil
    62  	}
    63  
    64  	var (
    65  		// Branches to reset to new positions for their remotes after rebasing.
    66  		branchesToReset []string
    67  
    68  		// Branches not updated because their heads were out of date
    69  		branchesNotUpdated []string
    70  
    71  		// Pushes to perform. local ref -> remote branch
    72  		pushes = make(map[string]string)
    73  	)
    74  
    75  	for _, r := range results {
    76  		prBranch := r.PR.Head.GetRef()
    77  		if sha, err := s.git.SHA1(prBranch); err == nil {
    78  			if sha == r.PR.Head.GetSHA() {
    79  				branchesToReset = append(branchesToReset, prBranch)
    80  			} else {
    81  				branchesNotUpdated = append(branchesNotUpdated, prBranch)
    82  			}
    83  		}
    84  		pushes[r.LocalRef] = prBranch
    85  	}
    86  
    87  	if err := s.git.Push(&gateway.PushRequest{
    88  		Remote: "origin",
    89  		Force:  true,
    90  		Refs:   pushes,
    91  	}); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	for _, br := range branchesToReset {
    96  		err = multierr.Append(err, s.git.ResetBranch(br, "origin/"+br))
    97  	}
    98  
    99  	var (
   100  		mu sync.Mutex
   101  		wg sync.WaitGroup
   102  	)
   103  	for _, pr := range req.PullRequests {
   104  		// TODO: --only-mine should apply
   105  		if pr.Base.GetRef() != req.Base {
   106  			wg.Add(1)
   107  			// TODO: fix unbounded goroutine count
   108  			go func(pr *github.PullRequest) {
   109  				defer wg.Done()
   110  				e := s.gh.SetPullRequestBase(ctx, *pr.Number, req.Base)
   111  				if e == nil {
   112  					return
   113  				}
   114  
   115  				mu.Lock()
   116  				err = multierr.Append(err, fmt.Errorf(
   117  					"failed to set base for %v to %q: %v", *pr.HTMLURL, req.Base, e))
   118  				mu.Unlock()
   119  			}(pr)
   120  		}
   121  	}
   122  	wg.Wait()
   123  
   124  	return &service.RebaseResponse{
   125  		BranchesNotUpdated: branchesNotUpdated,
   126  	}, err
   127  }
   128  
// rebasedPullRequest records a pull request that was rebased locally,
// together with the local ref that holds its rebased commits.
type rebasedPullRequest struct {
	// PR is the pull request that was rebased.
	PR *github.PullRequest

	// LocalRef is the local ref containing the rebased head of PR. To
	// publish the rebase we should do,
	//
	// 	git push origin $LocalRef:$Branch
	//
	// Where $Branch is pr.Head.GetRef()
	LocalRef string
}
   139  
// bulkRebaser is the part of the interface of git.BulkRebaser that we need
// here. Declaring it locally presumably lets callers substitute a fake
// rebaser (e.g. in tests) — NOTE(review): confirm.
type bulkRebaser interface {
	// Err reports a failure from the rebases requested so far, if any.
	// rebasePullRequests checks it after walking all pull requests.
	Err() error

	// Onto starts a chain of rebases on top of the given base (a commit
	// SHA when called from Rebase).
	Onto(string) git.RebaseHandle
}
   145  
// rebasePRConfig holds the inputs to rebasePullRequests.
type rebasePRConfig struct {
	// If non-empty, only PRs authored by this user will be considered.
	Author string

	// Context is passed to the GitHub calls made while walking the pull
	// requests. GitRebaser performs the rebases. GitHub is used to check
	// ownership of pull request heads and to list dependent pull requests.
	// Base is what the pull requests are rebased onto (Rebase passes the
	// SHA of origin/<base branch>). PullRequests are the pull requests the
	// walk starts from.
	Context      context.Context
	GitRebaser   bulkRebaser
	GitHub       gateway.GitHub
	Base         string
	PullRequests []*github.PullRequest
}
   156  
   157  func rebasePullRequests(cfg rebasePRConfig) (map[int]rebasedPullRequest, error) {
   158  	v := rebaseVisitor{
   159  		rebasePRConfig: &cfg,
   160  		handle:         cfg.GitRebaser.Onto(cfg.Base),
   161  		mu:             new(sync.Mutex),
   162  		results:        list.New(),
   163  	}
   164  
   165  	walkCfg := WalkConfig{Children: getDependentPRs(cfg.Context, cfg.GitHub)}
   166  	if err := Walk(walkCfg, cfg.PullRequests, v); err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	if err := cfg.GitRebaser.Err(); err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	results := make(map[int]rebasedPullRequest, v.results.Len())
   175  	for e := v.results.Front(); e != nil; e = e.Next() {
   176  		p := e.Value.(rebasedPullRequest)
   177  		results[p.PR.GetNumber()] = p
   178  	}
   179  
   180  	return results, nil
   181  }
   182  
   183  func getDependentPRs(
   184  	ctx context.Context, gh gateway.GitHub,
   185  ) func(*github.PullRequest) ([]*github.PullRequest, error) {
   186  	return func(pr *github.PullRequest) ([]*github.PullRequest, error) {
   187  		return gh.ListPullRequestsByBase(ctx, pr.Head.GetRef())
   188  	}
   189  }
   190  
// rebaseVisitor is the Visitor that rebasePullRequests hands to Walk. Each
// pull request it accepts is rebased on top of handle and recorded in
// results.
type rebaseVisitor struct {
	*rebasePRConfig

	// mu guards results. Both are pointers so that the copies made by Visit
	// (which has a value receiver) share the same lock and the same list.
	mu      *sync.Mutex
	results *list.List // list<rebasedPullRequest>

	// handle is the rebase position that the next accepted pull request is
	// rebased onto.
	handle git.RebaseHandle
}
   199  
   200  func (v rebaseVisitor) Visit(pr *github.PullRequest) (Visitor, error) {
   201  	// Don't rebase if we don't own the PR.
   202  	if !v.GitHub.IsOwned(v.Context, pr.Head) {
   203  		// TODO: There is more nuance to this. We should check if we have
   204  		// write access instead.
   205  		// TODO: Log if we skip
   206  		return nil, nil
   207  	}
   208  
   209  	if v.Author != "" && pr.User.GetLogin() != v.Author {
   210  		// TODO: log skipped PR
   211  		return nil, nil
   212  	}
   213  
   214  	h := v.handle.Rebase(pr.Base.GetSHA(), pr.Head.GetSHA())
   215  	v.mu.Lock()
   216  	v.results.PushBack(rebasedPullRequest{PR: pr, LocalRef: h.Base()})
   217  	v.mu.Unlock()
   218  
   219  	// We are operating on a shallow copy of v so we can just modify and
   220  	// return it.
   221  	v.handle = h
   222  	return v, nil
   223  }