github.com/abhinav/git-pr@v0.6.1-0.20171029234004-54218d68c11b/pr/rebase.go (about) 1 package pr 2 3 import ( 4 "container/list" 5 "context" 6 "fmt" 7 "sync" 8 9 "github.com/abhinav/git-pr/gateway" 10 "github.com/abhinav/git-pr/git" 11 "github.com/abhinav/git-pr/service" 12 13 "github.com/google/go-github/github" 14 "go.uber.org/multierr" 15 ) 16 17 // Rebase a pull request and its dependencies. 18 func (s *Service) Rebase(ctx context.Context, req *service.RebaseRequest) (_ *service.RebaseResponse, err error) { 19 if len(req.PullRequests) == 0 { 20 return &service.RebaseResponse{}, nil 21 } 22 23 // Go back to the original branch after everything is done. 24 oldBranch, err := s.git.CurrentBranch() 25 if err != nil { 26 return nil, err 27 } 28 defer func(oldBranch string) { 29 err = multierr.Append(err, s.git.Checkout(oldBranch)) 30 }(oldBranch) 31 32 if err := s.git.Fetch(&gateway.FetchRequest{Remote: "origin"}); err != nil { 33 return nil, err 34 } 35 36 // TODO: support remotes besides origin 37 baseRef, err := s.git.SHA1("origin/" + req.Base) 38 if err != nil { 39 return nil, err 40 } 41 42 rebaser := git.NewBulkRebaser(s.git) 43 defer func() { 44 err = multierr.Append(err, rebaser.Cleanup()) 45 }() 46 47 results, err := s.rebasePullRequests(rebasePRConfig{ 48 Context: ctx, 49 GitRebaser: rebaser, 50 GitHub: s.gh, 51 Base: baseRef, 52 PullRequests: req.PullRequests, 53 Author: req.Author, 54 }) 55 if err != nil { 56 return nil, err 57 } 58 59 // Nothing to do 60 if len(results) == 0 { 61 return &service.RebaseResponse{}, nil 62 } 63 64 var ( 65 // Branches to reset to new positions for their remotes after rebasing. 66 branchesToReset []string 67 68 // Branches not updated because their heads were out of date 69 branchesNotUpdated []string 70 71 // Pushes to perform. local ref -> remote branch 72 pushes = make(map[string]string) 73 ) 74 75 for _, r := range results { 76 prBranch := r.PR.Head.GetRef() 77 if sha, err := s.git.SHA1(prBranch); err == nil { 78 if sha == r.PR.Head.GetSHA() { 79 branchesToReset = append(branchesToReset, prBranch) 80 } else { 81 branchesNotUpdated = append(branchesNotUpdated, prBranch) 82 } 83 } 84 pushes[r.LocalRef] = prBranch 85 } 86 87 if err := s.git.Push(&gateway.PushRequest{ 88 Remote: "origin", 89 Force: true, 90 Refs: pushes, 91 }); err != nil { 92 return nil, err 93 } 94 95 for _, br := range branchesToReset { 96 err = multierr.Append(err, s.git.ResetBranch(br, "origin/"+br)) 97 } 98 99 var ( 100 mu sync.Mutex 101 wg sync.WaitGroup 102 ) 103 for _, pr := range req.PullRequests { 104 // TODO: --only-mine should apply 105 if pr.Base.GetRef() != req.Base { 106 wg.Add(1) 107 // TODO: fix unbounded goroutine count 108 go func(pr *github.PullRequest) { 109 defer wg.Done() 110 e := s.gh.SetPullRequestBase(ctx, *pr.Number, req.Base) 111 if e == nil { 112 return 113 } 114 115 mu.Lock() 116 err = multierr.Append(err, fmt.Errorf( 117 "failed to set base for %v to %q: %v", *pr.HTMLURL, req.Base, e)) 118 mu.Unlock() 119 }(pr) 120 } 121 } 122 wg.Wait() 123 124 return &service.RebaseResponse{ 125 BranchesNotUpdated: branchesNotUpdated, 126 }, err 127 } 128 129 type rebasedPullRequest struct { 130 PR *github.PullRequest 131 132 // We should do, 133 // 134 // git push origin $LocalRef:$Branch 135 // 136 // Where $Branch is pr.Head.GetRef() 137 LocalRef string 138 } 139 140 // Part of the interface of git.BulkRebaser that we need here. 141 type bulkRebaser interface { 142 Err() error 143 Onto(string) git.RebaseHandle 144 } 145 146 type rebasePRConfig struct { 147 // If non-empty, only PRs authored by this user will be considered. 148 Author string 149 150 Context context.Context 151 GitRebaser bulkRebaser 152 GitHub gateway.GitHub 153 Base string 154 PullRequests []*github.PullRequest 155 } 156 157 func rebasePullRequests(cfg rebasePRConfig) (map[int]rebasedPullRequest, error) { 158 v := rebaseVisitor{ 159 rebasePRConfig: &cfg, 160 handle: cfg.GitRebaser.Onto(cfg.Base), 161 mu: new(sync.Mutex), 162 results: list.New(), 163 } 164 165 walkCfg := WalkConfig{Children: getDependentPRs(cfg.Context, cfg.GitHub)} 166 if err := Walk(walkCfg, cfg.PullRequests, v); err != nil { 167 return nil, err 168 } 169 170 if err := cfg.GitRebaser.Err(); err != nil { 171 return nil, err 172 } 173 174 results := make(map[int]rebasedPullRequest, v.results.Len()) 175 for e := v.results.Front(); e != nil; e = e.Next() { 176 p := e.Value.(rebasedPullRequest) 177 results[p.PR.GetNumber()] = p 178 } 179 180 return results, nil 181 } 182 183 func getDependentPRs( 184 ctx context.Context, gh gateway.GitHub, 185 ) func(*github.PullRequest) ([]*github.PullRequest, error) { 186 return func(pr *github.PullRequest) ([]*github.PullRequest, error) { 187 return gh.ListPullRequestsByBase(ctx, pr.Head.GetRef()) 188 } 189 } 190 191 type rebaseVisitor struct { 192 *rebasePRConfig 193 194 mu *sync.Mutex 195 results *list.List // list<rebasedPullRequest> 196 197 handle git.RebaseHandle 198 } 199 200 func (v rebaseVisitor) Visit(pr *github.PullRequest) (Visitor, error) { 201 // Don't rebase if we don't own the PR. 202 if !v.GitHub.IsOwned(v.Context, pr.Head) { 203 // TODO: There is more nuance to this. We should check if we have 204 // write access instead. 205 // TODO: Log if we skip 206 return nil, nil 207 } 208 209 if v.Author != "" && pr.User.GetLogin() != v.Author { 210 // TODO: log skipped PR 211 return nil, nil 212 } 213 214 h := v.handle.Rebase(pr.Base.GetSHA(), pr.Head.GetSHA()) 215 v.mu.Lock() 216 v.results.PushBack(rebasedPullRequest{PR: pr, LocalRef: h.Base()}) 217 v.mu.Unlock() 218 219 // We are operating on a shallow copy of v so we can just modify and 220 // return it. 221 v.handle = h 222 return v, nil 223 }