github.com/abhinav/git-pr@v0.6.1-0.20171029234004-54218d68c11b/pr/rebase_test.go (about)

     1  package pr
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/abhinav/git-pr/gateway"
    11  	"github.com/abhinav/git-pr/gateway/gatewaytest"
    12  	"github.com/abhinav/git-pr/git"
    13  	"github.com/abhinav/git-pr/git/gittest"
    14  	"github.com/abhinav/git-pr/service"
    15  
    16  	"github.com/golang/mock/gomock"
    17  	"github.com/google/go-github/github"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  func TestServiceRebase(t *testing.T) {
    23  	type testCase struct {
    24  		Desc string
    25  
    26  		Request service.RebaseRequest
    27  
    28  		// Values to return from rebasePullRequests
    29  		RebasePRsResult []rebasedPullRequest
    30  		RebasePRsError  error
    31  
    32  		// Skip common mock setup if true.
    33  		SkipCommon bool
    34  
    35  		// SHA1 hashes for branches that may be queried. May be empty or
    36  		// partial if SetupGit is handling this.
    37  		SHA1Hashes map[string]string // branch name -> sha1 hash
    38  
    39  		// Branches which don't have a local version. SHA lookup for these
    40  		// will fail.
    41  		SHA1Failures []string
    42  
    43  		// If present, these may be used for more complicated setup on the
    44  		// mocks.
    45  		SetupGit    func(*gatewaytest.MockGit)
    46  		SetupGitHub func(*gatewaytest.MockGitHub)
    47  
    48  		// Expected Git.ResetBranch calls. May be empty or partial if SetupGit
    49  		// is handling this.
    50  		WantBranchResets []string // branch name -> ref
    51  
    52  		// Expected items in Push(). May be empty if SetupGitHub is handling
    53  		// this.
    54  		WantPushes map[string]string // local ref -> branch name
    55  
    56  		// List of pull requests for which we expect the PR base to change to
    57  		// Request.Base. May be empty or partial if SetupGitHub is handling
    58  		// this.
    59  		WantBaseChanges []int
    60  
    61  		WantResponse service.RebaseResponse
    62  		WantErrors   []string
    63  	}
    64  
    65  	tests := []testCase{
    66  		{
    67  			Desc:         "empty",
    68  			Request:      service.RebaseRequest{Base: "foo"},
    69  			WantResponse: service.RebaseResponse{},
    70  			SkipCommon:   true,
    71  		},
    72  		func() (tt testCase) {
    73  			tt.Desc = "single"
    74  
    75  			pr := &github.PullRequest{
    76  				Number:  github.Int(1),
    77  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
    78  				Base: &github.PullRequestBranch{
    79  					Ref: github.String("master"),
    80  				},
    81  				Head: &github.PullRequestBranch{
    82  					SHA: github.String("headsha"),
    83  					Ref: github.String("myfeature"),
    84  				},
    85  			}
    86  
    87  			tt.Request = service.RebaseRequest{
    88  				Base:         "master",
    89  				PullRequests: []*github.PullRequest{pr},
    90  			}
    91  			tt.RebasePRsResult = []rebasedPullRequest{
    92  				{PR: pr, LocalRef: "git-pr/rebase/headsha"},
    93  			}
    94  			tt.SHA1Hashes = map[string]string{"myfeature": "headsha"}
    95  
    96  			tt.WantPushes = map[string]string{"git-pr/rebase/headsha": "myfeature"}
    97  			tt.WantBranchResets = []string{"myfeature"}
    98  
    99  			return
   100  		}(),
   101  		{
   102  			Desc: "no rebases",
   103  			Request: service.RebaseRequest{
   104  				Base: "master",
   105  				PullRequests: []*github.PullRequest{
   106  					&github.PullRequest{
   107  						Number:  github.Int(1),
   108  						HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   109  						Base: &github.PullRequestBranch{
   110  							Ref: github.String("master"),
   111  						},
   112  						Head: &github.PullRequestBranch{
   113  							SHA: github.String("headsha"),
   114  							Ref: github.String("myfeature"),
   115  						},
   116  					},
   117  				},
   118  			},
   119  			RebasePRsResult: []rebasedPullRequest{},
   120  		},
   121  		func() (tt testCase) {
   122  			tt.Desc = "single base change"
   123  
   124  			pr := &github.PullRequest{
   125  				Number:  github.Int(1),
   126  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   127  				Base: &github.PullRequestBranch{
   128  					Ref: github.String("dev"),
   129  				},
   130  				Head: &github.PullRequestBranch{
   131  					SHA: github.String("somesha"),
   132  					Ref: github.String("myfeature"),
   133  				},
   134  			}
   135  
   136  			tt.Request = service.RebaseRequest{
   137  				Base:         "master",
   138  				PullRequests: []*github.PullRequest{pr},
   139  			}
   140  			tt.RebasePRsResult = []rebasedPullRequest{
   141  				{PR: pr, LocalRef: "git-pr/rebase/somesha"},
   142  			}
   143  			tt.SHA1Hashes = map[string]string{"myfeature": "differentsha"}
   144  
   145  			tt.WantPushes = map[string]string{"git-pr/rebase/somesha": "myfeature"}
   146  			tt.WantBaseChanges = []int{1}
   147  			tt.WantResponse = service.RebaseResponse{
   148  				BranchesNotUpdated: []string{"myfeature"},
   149  			}
   150  
   151  			return
   152  		}(),
   153  		func() (tt testCase) {
   154  			tt.Desc = "multiple"
   155  
   156  			pr1 := &github.PullRequest{
   157  				Number:  github.Int(1),
   158  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   159  				Base: &github.PullRequestBranch{
   160  					Ref: github.String("dev"),
   161  				},
   162  				Head: &github.PullRequestBranch{
   163  					SHA: github.String("sha1"),
   164  					Ref: github.String("feature-1"),
   165  				},
   166  			}
   167  
   168  			pr2 := &github.PullRequest{
   169  				Number:  github.Int(2),
   170  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/2"),
   171  				Base: &github.PullRequestBranch{
   172  					Ref: github.String("master"),
   173  				},
   174  				Head: &github.PullRequestBranch{
   175  					SHA: github.String("sha2"),
   176  					Ref: github.String("feature-2"),
   177  				},
   178  			}
   179  
   180  			pr3 := &github.PullRequest{
   181  				Number:  github.Int(3),
   182  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/3"),
   183  				Base: &github.PullRequestBranch{
   184  					Ref: github.String("master"),
   185  				},
   186  				Head: &github.PullRequestBranch{
   187  					SHA: github.String("sha3"),
   188  					Ref: github.String("feature-3"),
   189  				},
   190  			}
   191  
   192  			tt.Request = service.RebaseRequest{
   193  				Base:         "dev",
   194  				PullRequests: []*github.PullRequest{pr1, pr2, pr3},
   195  			}
   196  			tt.RebasePRsResult = []rebasedPullRequest{
   197  				{PR: pr1, LocalRef: "git-pr/rebase/sha1"},
   198  				{PR: pr2, LocalRef: "git-pr/rebase/sha2"},
   199  				{PR: pr3, LocalRef: "git-pr/rebase/sha3"},
   200  			}
   201  			tt.SHA1Hashes = map[string]string{
   202  				"feature-1": "sha1",
   203  				"feature-2": "sha2",
   204  				"feature-3": "not-sha3",
   205  			}
   206  			tt.WantPushes = map[string]string{
   207  				"git-pr/rebase/sha1": "feature-1",
   208  				"git-pr/rebase/sha2": "feature-2",
   209  				"git-pr/rebase/sha3": "feature-3",
   210  			}
   211  			tt.WantBaseChanges = []int{2, 3}
   212  			tt.WantBranchResets = []string{"feature-1", "feature-2"}
   213  			tt.WantResponse = service.RebaseResponse{
   214  				BranchesNotUpdated: []string{"feature-3"},
   215  			}
   216  
   217  			return
   218  		}(),
   219  		func() (tt testCase) {
   220  			tt.Desc = "simple stack"
   221  
   222  			// dev -> feature-1 -> feature-2 -> feature-3
   223  
   224  			pr := &github.PullRequest{
   225  				Number:  github.Int(1),
   226  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   227  				Base: &github.PullRequestBranch{
   228  					Ref: github.String("dev"),
   229  				},
   230  				Head: &github.PullRequestBranch{
   231  					SHA: github.String("sha1"),
   232  					Ref: github.String("feature-1"),
   233  				},
   234  			}
   235  
   236  			tt.Request = service.RebaseRequest{
   237  				Base:         "dev",
   238  				PullRequests: []*github.PullRequest{pr},
   239  			}
   240  			tt.RebasePRsResult = []rebasedPullRequest{
   241  				{PR: pr, LocalRef: "git-pr/rebase/sha1"},
   242  				{
   243  					LocalRef: "git-pr/rebase/sha2",
   244  					PR: &github.PullRequest{
   245  						Number:  github.Int(2),
   246  						HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/2"),
   247  						Base: &github.PullRequestBranch{
   248  							Ref: github.String("feature-1"),
   249  						},
   250  						Head: &github.PullRequestBranch{
   251  							SHA: github.String("sha2"),
   252  							Ref: github.String("feature-2"),
   253  						},
   254  					},
   255  				},
   256  				{
   257  					LocalRef: "git-pr/rebase/sha3",
   258  					PR: &github.PullRequest{
   259  						Number:  github.Int(3),
   260  						HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/3"),
   261  						Base: &github.PullRequestBranch{
   262  							Ref: github.String("feature-2"),
   263  						},
   264  						Head: &github.PullRequestBranch{
   265  							SHA: github.String("sha3"),
   266  							Ref: github.String("feature-3"),
   267  						},
   268  					},
   269  				},
   270  			}
   271  
   272  			tt.SHA1Hashes = map[string]string{
   273  				"feature-1": "not-sha1",
   274  				"feature-3": "sha3",
   275  			}
   276  			tt.SHA1Failures = []string{"feature-2"}
   277  
   278  			tt.WantBranchResets = []string{"feature-3"}
   279  			tt.WantPushes = map[string]string{
   280  				"git-pr/rebase/sha1": "feature-1",
   281  				"git-pr/rebase/sha2": "feature-2",
   282  				"git-pr/rebase/sha3": "feature-3",
   283  			}
   284  			tt.WantResponse = service.RebaseResponse{
   285  				BranchesNotUpdated: []string{"feature-1"},
   286  			}
   287  
   288  			return
   289  		}(),
   290  		func() (tt testCase) {
   291  			tt.Desc = "graph"
   292  
   293  			// dev-----------.
   294  			//  |             \
   295  			//  +-> feature-1  +-> feature-2 -> feature-3
   296  			//  |                    |
   297  			//  +-> feature-4        +-> feature-5
   298  			//                                |
   299  			//                                +-> feature-6
   300  
   301  			pr1 := &github.PullRequest{
   302  				Number:  github.Int(1),
   303  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   304  				Base: &github.PullRequestBranch{
   305  					Ref: github.String("dev"),
   306  				},
   307  				Head: &github.PullRequestBranch{
   308  					SHA: github.String("sha1"),
   309  					Ref: github.String("feature-1"),
   310  				},
   311  			}
   312  
   313  			pr2 := &github.PullRequest{
   314  				Number:  github.Int(2),
   315  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/2"),
   316  				Base: &github.PullRequestBranch{
   317  					Ref: github.String("master"),
   318  				},
   319  				Head: &github.PullRequestBranch{
   320  					SHA: github.String("sha2"),
   321  					Ref: github.String("feature-2"),
   322  				},
   323  			}
   324  
   325  			tt.Request = service.RebaseRequest{
   326  				Base:         "dev",
   327  				PullRequests: []*github.PullRequest{pr1, pr2},
   328  			}
   329  			tt.RebasePRsResult = []rebasedPullRequest{
   330  				{PR: pr1, LocalRef: "git-pr/rebase/sha1"},
   331  				{PR: pr2, LocalRef: "git-pr/rebase/sha2"},
   332  				{
   333  					LocalRef: "git-pr/rebase/sha3",
   334  					PR: &github.PullRequest{
   335  						Number:  github.Int(3),
   336  						HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/3"),
   337  						Base: &github.PullRequestBranch{
   338  							Ref: github.String("feature-2"),
   339  						},
   340  						Head: &github.PullRequestBranch{
   341  							SHA: github.String("sha3"),
   342  							Ref: github.String("feature-3"),
   343  						},
   344  					},
   345  				},
   346  				{
   347  					LocalRef: "git-pr/rebase/sha4",
   348  					PR: &github.PullRequest{
   349  						Number:  github.Int(4),
   350  						HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/4"),
   351  						Base: &github.PullRequestBranch{
   352  							Ref: github.String("feature-1"),
   353  						},
   354  						Head: &github.PullRequestBranch{
   355  							SHA: github.String("sha4"),
   356  							Ref: github.String("feature-4"),
   357  						},
   358  					},
   359  				},
   360  				{
   361  					LocalRef: "git-pr/rebase/sha5",
   362  					PR: &github.PullRequest{
   363  						Number:  github.Int(5),
   364  						HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/5"),
   365  						Base: &github.PullRequestBranch{
   366  							Ref: github.String("feature-2"),
   367  						},
   368  						Head: &github.PullRequestBranch{
   369  							SHA: github.String("sha5"),
   370  							Ref: github.String("feature-5"),
   371  						},
   372  					},
   373  				},
   374  				{
   375  					LocalRef: "git-pr/rebase/sha6",
   376  					PR: &github.PullRequest{
   377  						Number:  github.Int(6),
   378  						HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/6"),
   379  						Base: &github.PullRequestBranch{
   380  							Ref: github.String("feature-5"),
   381  						},
   382  						Head: &github.PullRequestBranch{
   383  							SHA: github.String("sha6"),
   384  							Ref: github.String("feature-6"),
   385  						},
   386  					},
   387  				},
   388  			}
   389  
   390  			tt.SHA1Hashes = map[string]string{
   391  				"feature-1": "sha1",
   392  				"feature-2": "not-sha2",
   393  				"feature-4": "not-sha4",
   394  				"feature-5": "sha5",
   395  			}
   396  			tt.SHA1Failures = []string{"feature-3", "feature-6"}
   397  
   398  			tt.WantBranchResets = []string{"feature-1", "feature-5"}
   399  			tt.WantPushes = map[string]string{
   400  				"git-pr/rebase/sha1": "feature-1",
   401  				"git-pr/rebase/sha2": "feature-2",
   402  				"git-pr/rebase/sha3": "feature-3",
   403  				"git-pr/rebase/sha4": "feature-4",
   404  				"git-pr/rebase/sha5": "feature-5",
   405  				"git-pr/rebase/sha6": "feature-6",
   406  			}
   407  			tt.WantBaseChanges = []int{2}
   408  			tt.WantResponse = service.RebaseResponse{
   409  				BranchesNotUpdated: []string{"feature-2", "feature-4"},
   410  			}
   411  
   412  			return
   413  		}(),
   414  		{
   415  			Desc: "current branch error",
   416  			Request: service.RebaseRequest{
   417  				Base: "derp",
   418  				PullRequests: []*github.PullRequest{
   419  					{}, // doesn't matter
   420  				},
   421  			},
   422  			SkipCommon: true,
   423  			SetupGit: func(git *gatewaytest.MockGit) {
   424  				git.EXPECT().CurrentBranch().
   425  					Return("", errors.New("not a git repository"))
   426  			},
   427  			WantErrors: []string{"not a git repository"},
   428  		},
   429  		{
   430  			Desc: "fetch error",
   431  			Request: service.RebaseRequest{
   432  				Base: "derp",
   433  				PullRequests: []*github.PullRequest{
   434  					{}, // doesn't matter
   435  				},
   436  			},
   437  			SkipCommon: true,
   438  			SetupGit: func(git *gatewaytest.MockGit) {
   439  				git.EXPECT().CurrentBranch().Return("master", nil)
   440  				git.EXPECT().Fetch(&gateway.FetchRequest{
   441  					Remote: "origin",
   442  				}).Return(errors.New("remote origin doesn't exist"))
   443  				git.EXPECT().Checkout("master").Return(nil)
   444  			},
   445  			WantErrors: []string{"remote origin doesn't exist"},
   446  		},
   447  		{
   448  			Desc: "fetch error",
   449  			Request: service.RebaseRequest{
   450  				Base: "derp",
   451  				PullRequests: []*github.PullRequest{
   452  					{}, // doesn't matter
   453  				},
   454  			},
   455  			SkipCommon: true,
   456  			SetupGit: func(git *gatewaytest.MockGit) {
   457  				git.EXPECT().CurrentBranch().Return("master", nil)
   458  				git.EXPECT().Fetch(&gateway.FetchRequest{Remote: "origin"}).Return(nil)
   459  
   460  				git.EXPECT().SHA1("origin/derp").
   461  					Return("", errors.New("could not find ref origin/derp"))
   462  
   463  				git.EXPECT().Checkout("master").Return(nil)
   464  			},
   465  			WantErrors: []string{"could not find ref origin/derp"},
   466  		},
   467  		{
   468  			Desc: "rebase error",
   469  			Request: service.RebaseRequest{
   470  				Base: "derp",
   471  				PullRequests: []*github.PullRequest{
   472  					{}, // doesn't matter
   473  				},
   474  			},
   475  			RebasePRsError: errors.New("could not rebase stuff"),
   476  			WantErrors:     []string{"could not rebase stuff"},
   477  		},
   478  		func() (tt testCase) {
   479  			tt.Desc = "push error"
   480  
   481  			pr := &github.PullRequest{
   482  				Number:  github.Int(1),
   483  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   484  				Base: &github.PullRequestBranch{
   485  					Ref: github.String("master"),
   486  				},
   487  				Head: &github.PullRequestBranch{
   488  					SHA: github.String("headsha"),
   489  					Ref: github.String("myfeature"),
   490  				},
   491  			}
   492  
   493  			tt.Request = service.RebaseRequest{
   494  				Base:         "master",
   495  				PullRequests: []*github.PullRequest{pr},
   496  			}
   497  			tt.RebasePRsResult = []rebasedPullRequest{
   498  				{PR: pr, LocalRef: "git-pr/rebase/headsha"},
   499  			}
   500  			tt.SHA1Hashes = map[string]string{"myfeature": "headsha"}
   501  
   502  			tt.SetupGit = func(git *gatewaytest.MockGit) {
   503  				git.EXPECT().Push(gomock.Any()).
   504  					Return(errors.New("remote timed out"))
   505  			}
   506  			tt.WantErrors = []string{"remote timed out"}
   507  
   508  			return
   509  		}(),
   510  		func() (tt testCase) {
   511  			tt.Desc = "update base error"
   512  
   513  			pr := &github.PullRequest{
   514  				Number:  github.Int(1),
   515  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   516  				Base: &github.PullRequestBranch{
   517  					Ref: github.String("master"),
   518  				},
   519  				Head: &github.PullRequestBranch{
   520  					SHA: github.String("headsha"),
   521  					Ref: github.String("myfeature"),
   522  				},
   523  			}
   524  
   525  			tt.Request = service.RebaseRequest{
   526  				Base:         "dev",
   527  				PullRequests: []*github.PullRequest{pr},
   528  			}
   529  			tt.RebasePRsResult = []rebasedPullRequest{
   530  				{PR: pr, LocalRef: "git-pr/rebase/headsha"},
   531  			}
   532  			tt.SHA1Hashes = map[string]string{"myfeature": "headsha"}
   533  
   534  			tt.WantPushes = map[string]string{"git-pr/rebase/headsha": "myfeature"}
   535  			tt.WantBranchResets = []string{"myfeature"}
   536  
   537  			tt.SetupGitHub = func(gh *gatewaytest.MockGitHub) {
   538  				gh.EXPECT().SetPullRequestBase(gomock.Any(), 1, "dev").
   539  					Return(errors.New("unauthorized operation"))
   540  			}
   541  			tt.WantErrors = []string{"unauthorized operation"}
   542  
   543  			return
   544  		}(),
   545  	}
   546  
   547  	for _, tt := range tests {
   548  		t.Run(tt.Desc, func(t *testing.T) {
   549  			mockCtrl := gomock.NewController(t)
   550  			defer mockCtrl.Finish()
   551  
   552  			git := gatewaytest.NewMockGit(mockCtrl)
   553  			gh := gatewaytest.NewMockGitHub(mockCtrl)
   554  
   555  			if !tt.SkipCommon {
   556  				git.EXPECT().CurrentBranch().Return("oldbranch", nil)
   557  				git.EXPECT().Checkout("oldbranch").Return(nil)
   558  				git.EXPECT().Fetch(&gateway.FetchRequest{Remote: "origin"}).Return(nil)
   559  				git.EXPECT().SHA1("origin/"+tt.Request.Base).Return("originbasesha", nil)
   560  			}
   561  
   562  			for _, branch := range tt.WantBranchResets {
   563  				git.EXPECT().ResetBranch(branch, "origin/"+branch).Return(nil)
   564  			}
   565  
   566  			if len(tt.WantPushes) > 0 {
   567  				git.EXPECT().Push(&gateway.PushRequest{
   568  					Remote: "origin",
   569  					Force:  true,
   570  					Refs:   tt.WantPushes,
   571  				}).Return(nil)
   572  			}
   573  
   574  			for branch, sha := range tt.SHA1Hashes {
   575  				git.EXPECT().SHA1(branch).Return(sha, nil)
   576  			}
   577  			for _, branch := range tt.SHA1Failures {
   578  				git.EXPECT().SHA1(branch).
   579  					Return("", fmt.Errorf("unknown branch %q", branch))
   580  			}
   581  
   582  			for _, prNum := range tt.WantBaseChanges {
   583  				gh.EXPECT().
   584  					SetPullRequestBase(gomock.Any(), prNum, tt.Request.Base).
   585  					Return(nil)
   586  			}
   587  
   588  			if tt.SetupGit != nil {
   589  				tt.SetupGit(git)
   590  			}
   591  			if tt.SetupGitHub != nil {
   592  				tt.SetupGitHub(gh)
   593  			}
   594  
   595  			service := NewService(ServiceConfig{
   596  				Git:    git,
   597  				GitHub: gh,
   598  			})
   599  			service.rebasePullRequests = fakeRebasePullRequests(
   600  				tt.RebasePRsResult, tt.RebasePRsError)
   601  
   602  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   603  			defer cancel()
   604  
   605  			res, err := service.Rebase(ctx, &tt.Request)
   606  			if len(tt.WantErrors) > 0 {
   607  				require.Error(t, err, "expected failure")
   608  				for _, msg := range tt.WantErrors {
   609  					assert.Contains(t, err.Error(), msg)
   610  				}
   611  				return
   612  			}
   613  
   614  			require.NoError(t, err, "expected success")
   615  
   616  			wantBranchesNotUpdated := make(map[string]struct{}, len(tt.WantResponse.BranchesNotUpdated))
   617  			for _, br := range tt.WantResponse.BranchesNotUpdated {
   618  				wantBranchesNotUpdated[br] = struct{}{}
   619  			}
   620  
   621  			gotBranchesNotUpdated := make(map[string]struct{}, len(res.BranchesNotUpdated))
   622  			for _, br := range res.BranchesNotUpdated {
   623  				gotBranchesNotUpdated[br] = struct{}{}
   624  			}
   625  
   626  			assert.Equal(t, wantBranchesNotUpdated, gotBranchesNotUpdated,
   627  				"BranchesNotUpdated must match")
   628  		})
   629  	}
   630  }
   631  
   632  func fakeRebasePullRequests(
   633  	results []rebasedPullRequest, err error,
   634  ) func(rebasePRConfig) (map[int]rebasedPullRequest, error) {
   635  	// For convenience, build the map from the list rather than manually in
   636  	// the test case.
   637  	resultMap := make(map[int]rebasedPullRequest)
   638  	for _, r := range results {
   639  		resultMap[r.PR.GetNumber()] = r
   640  	}
   641  
   642  	return func(rebasePRConfig) (map[int]rebasedPullRequest, error) {
   643  		return resultMap, err
   644  	}
   645  }
   646  
   647  type fakeRebase struct {
   648  	// Range of commits to rebase
   649  	FromRef string
   650  	ToRef   string
   651  
   652  	// ToRef after the rebase. This will be the Base of the returned by
   653  	// the RebaseHandle.
   654  	GiveRef string
   655  
   656  	// Rebases expected on the returned handle
   657  	WantRebases []fakeRebase
   658  }
   659  
   660  func setupFakeRebases(ctrl *gomock.Controller, h *gittest.MockRebaseHandle, rebases []fakeRebase) {
   661  	for _, r := range rebases {
   662  		newH := gittest.NewMockRebaseHandle(ctrl)
   663  		newH.EXPECT().Base().Return(r.GiveRef)
   664  		setupFakeRebases(ctrl, newH, r.WantRebases)
   665  
   666  		h.EXPECT().Rebase(r.FromRef, r.ToRef).Return(newH)
   667  	}
   668  }
   669  
   670  func TestRebasePullRequests(t *testing.T) {
   671  	type testCase struct {
   672  		Desc string
   673  
   674  		Author       string
   675  		Base         string
   676  		PullRequests []*github.PullRequest
   677  
   678  		// Dependents of different pull requests. May be partial or empty if
   679  		// part of the work is done by SetupGitHub.
   680  		Dependents map[string][]*github.PullRequest // base branch -> PRs
   681  
   682  		// Whether the given branches are owned by the current repo or not.
   683  		// May be partial or empty if part of the work is done by SetupGitHub.
   684  		BranchOwnership map[string]bool
   685  
   686  		// Customize the GitHub gateway mock.
   687  		SetupGitHub func(*gatewaytest.MockGitHub)
   688  
   689  		// Whether the bulkRebaser fails
   690  		RebaserError error
   691  
   692  		// Rebases expected on the base branch.
   693  		WantRebases []fakeRebase
   694  
   695  		WantResults []rebasedPullRequest
   696  		WantErrors  []string
   697  	}
   698  
   699  	tests := []testCase{
   700  		{Desc: "empty", Base: "shrug"},
   701  		func() (tt testCase) {
   702  			tt.Desc = "single"
   703  
   704  			pr := &github.PullRequest{
   705  				Number:  github.Int(1),
   706  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   707  				Base: &github.PullRequestBranch{
   708  					SHA: github.String("basesha"),
   709  				},
   710  				Head: &github.PullRequestBranch{
   711  					SHA: github.String("headsha"),
   712  					Ref: github.String("feature-1"),
   713  				},
   714  			}
   715  
   716  			tt.Base = "origin/master"
   717  			tt.PullRequests = []*github.PullRequest{pr}
   718  			tt.BranchOwnership = map[string]bool{
   719  				"feature-1": true,
   720  			}
   721  			tt.Dependents = map[string][]*github.PullRequest{"feature-1": {}}
   722  			tt.WantRebases = []fakeRebase{
   723  				{
   724  					FromRef: "basesha",
   725  					ToRef:   "headsha",
   726  					GiveRef: "newsha",
   727  				},
   728  			}
   729  			tt.WantResults = []rebasedPullRequest{
   730  				{LocalRef: "newsha", PR: pr},
   731  			}
   732  
   733  			return
   734  		}(),
   735  		func() (tt testCase) {
   736  			tt.Desc = "single wrong author"
   737  
   738  			pr := &github.PullRequest{
   739  				Number:  github.Int(1),
   740  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   741  				Base: &github.PullRequestBranch{
   742  					SHA: github.String("basesha"),
   743  				},
   744  				User: &github.User{Login: github.String("probablynotarealusername")},
   745  				Head: &github.PullRequestBranch{
   746  					SHA: github.String("headsha"),
   747  					Ref: github.String("feature-1"),
   748  				},
   749  			}
   750  
   751  			tt.Author = "abhinav"
   752  			tt.Base = "origin/master"
   753  			tt.PullRequests = []*github.PullRequest{pr}
   754  			tt.BranchOwnership = map[string]bool{"feature-1": true}
   755  
   756  			return
   757  		}(),
   758  		func() (tt testCase) {
   759  			tt.Desc = "github dependents failure"
   760  
   761  			pr := &github.PullRequest{
   762  				Number:  github.Int(1),
   763  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   764  				Base: &github.PullRequestBranch{
   765  					SHA: github.String("basesha"),
   766  				},
   767  				Head: &github.PullRequestBranch{
   768  					SHA: github.String("headsha"),
   769  					Ref: github.String("feature-1"),
   770  				},
   771  			}
   772  
   773  			tt.Base = "origin/master"
   774  			tt.PullRequests = []*github.PullRequest{pr}
   775  			tt.BranchOwnership = map[string]bool{"feature-1": true}
   776  
   777  			tt.SetupGitHub = func(gh *gatewaytest.MockGitHub) {
   778  				gh.EXPECT().ListPullRequestsByBase(gomock.Any(), "feature-1").
   779  					Return(nil, errors.New("great sadness"))
   780  			}
   781  
   782  			tt.RebaserError = errors.New("great sadness")
   783  			tt.WantRebases = []fakeRebase{
   784  				{
   785  					FromRef: "basesha",
   786  					ToRef:   "headsha",
   787  					GiveRef: "newsha",
   788  				},
   789  			}
   790  			tt.WantErrors = []string{"great sadness"}
   791  
   792  			return
   793  		}(),
   794  		func() (tt testCase) {
   795  			tt.Desc = "rebase failure"
   796  
   797  			pr := &github.PullRequest{
   798  				Number:  github.Int(1),
   799  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   800  				Base: &github.PullRequestBranch{
   801  					SHA: github.String("basesha"),
   802  				},
   803  				Head: &github.PullRequestBranch{
   804  					SHA: github.String("headsha"),
   805  					Ref: github.String("feature-1"),
   806  				},
   807  			}
   808  
   809  			tt.Base = "origin/master"
   810  			tt.PullRequests = []*github.PullRequest{pr}
   811  			tt.BranchOwnership = map[string]bool{
   812  				"feature-1": true,
   813  			}
   814  			tt.Dependents = map[string][]*github.PullRequest{"feature-1": {}}
   815  			tt.RebaserError = errors.New("great sadness")
   816  			tt.WantRebases = []fakeRebase{
   817  				{
   818  					FromRef: "basesha",
   819  					ToRef:   "headsha",
   820  					GiveRef: "newsha",
   821  				},
   822  			}
   823  			tt.WantErrors = []string{"great sadness"}
   824  
   825  			return
   826  		}(),
   827  		func() (tt testCase) {
   828  			tt.Desc = "single not owned"
   829  
   830  			pr := &github.PullRequest{
   831  				Number:  github.Int(1),
   832  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   833  				Base: &github.PullRequestBranch{
   834  					SHA: github.String("basesha"),
   835  				},
   836  				Head: &github.PullRequestBranch{
   837  					SHA: github.String("headsha"),
   838  					Ref: github.String("feature-1"),
   839  				},
   840  			}
   841  
   842  			tt.Base = "origin/master"
   843  			tt.PullRequests = []*github.PullRequest{pr}
   844  			tt.BranchOwnership = map[string]bool{
   845  				"feature-1": false,
   846  			}
   847  			tt.WantRebases = []fakeRebase{}
   848  			tt.WantResults = []rebasedPullRequest{}
   849  
   850  			return
   851  		}(),
   852  		func() (tt testCase) {
   853  			tt.Desc = "stack"
   854  
   855  			pr1 := &github.PullRequest{
   856  				Number:  github.Int(1),
   857  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   858  				Base: &github.PullRequestBranch{
   859  					SHA: github.String("mastersha"),
   860  				},
   861  				Head: &github.PullRequestBranch{
   862  					SHA: github.String("sha1"),
   863  					Ref: github.String("feature-1"),
   864  				},
   865  			}
   866  
   867  			pr2 := &github.PullRequest{
   868  				Number:  github.Int(2),
   869  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/2"),
   870  				Base: &github.PullRequestBranch{
   871  					SHA: github.String("sha1"),
   872  				},
   873  				Head: &github.PullRequestBranch{
   874  					SHA: github.String("sha2"),
   875  					Ref: github.String("feature-2"),
   876  				},
   877  			}
   878  
   879  			pr3 := &github.PullRequest{
   880  				Number:  github.Int(3),
   881  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/3"),
   882  				Base: &github.PullRequestBranch{
   883  					SHA: github.String("sha2"),
   884  				},
   885  				Head: &github.PullRequestBranch{
   886  					SHA: github.String("sha3"),
   887  					Ref: github.String("feature-3"),
   888  				},
   889  			}
   890  
   891  			tt.Base = "origin/master"
   892  			tt.PullRequests = []*github.PullRequest{pr1}
   893  			tt.Dependents = map[string][]*github.PullRequest{
   894  				"feature-1": []*github.PullRequest{pr2},
   895  				"feature-2": []*github.PullRequest{pr3},
   896  				"feature-3": []*github.PullRequest{},
   897  			}
   898  			tt.BranchOwnership = map[string]bool{
   899  				"feature-1": true,
   900  				"feature-2": true,
   901  				"feature-3": true,
   902  			}
   903  
   904  			tt.WantRebases = []fakeRebase{
   905  				{
   906  					FromRef: "mastersha",
   907  					ToRef:   "sha1",
   908  					GiveRef: "newsha1",
   909  					WantRebases: []fakeRebase{
   910  						{
   911  							FromRef: "sha1",
   912  							ToRef:   "sha2",
   913  							GiveRef: "newsha2",
   914  							WantRebases: []fakeRebase{
   915  								{FromRef: "sha2", ToRef: "sha3", GiveRef: "newsha3"},
   916  							},
   917  						},
   918  					},
   919  				},
   920  			}
   921  			tt.WantResults = []rebasedPullRequest{
   922  				{LocalRef: "newsha1", PR: pr1},
   923  				{LocalRef: "newsha2", PR: pr2},
   924  				{LocalRef: "newsha3", PR: pr3},
   925  			}
   926  
   927  			return
   928  		}(),
   929  		func() (tt testCase) {
   930  			tt.Desc = "stack partly not owned"
   931  
   932  			pr1 := &github.PullRequest{
   933  				Number:  github.Int(1),
   934  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   935  				Base: &github.PullRequestBranch{
   936  					SHA: github.String("mastersha"),
   937  				},
   938  				Head: &github.PullRequestBranch{
   939  					SHA: github.String("sha1"),
   940  					Ref: github.String("feature-1"),
   941  				},
   942  			}
   943  
   944  			pr2 := &github.PullRequest{
   945  				Number:  github.Int(2),
   946  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/2"),
   947  				Base: &github.PullRequestBranch{
   948  					SHA: github.String("sha1"),
   949  				},
   950  				Head: &github.PullRequestBranch{
   951  					SHA: github.String("sha2"),
   952  					Ref: github.String("feature-2"),
   953  				},
   954  			}
   955  
   956  			tt.Base = "origin/master"
   957  			tt.PullRequests = []*github.PullRequest{pr1}
   958  			tt.Dependents = map[string][]*github.PullRequest{
   959  				"feature-1": []*github.PullRequest{pr2},
   960  			}
   961  			tt.BranchOwnership = map[string]bool{
   962  				"feature-1": true,
   963  				"feature-2": false,
   964  			}
   965  
   966  			tt.WantRebases = []fakeRebase{
   967  				{
   968  					FromRef: "mastersha",
   969  					ToRef:   "sha1",
   970  					GiveRef: "newsha1",
   971  				},
   972  			}
   973  			tt.WantResults = []rebasedPullRequest{
   974  				{LocalRef: "newsha1", PR: pr1},
   975  			}
   976  
   977  			return
   978  		}(),
   979  		func() (tt testCase) {
   980  			tt.Desc = "stack partly wrong user"
   981  
   982  			pr1 := &github.PullRequest{
   983  				Number:  github.Int(1),
   984  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/1"),
   985  				User:    &github.User{Login: github.String("abhinav")},
   986  				Base: &github.PullRequestBranch{
   987  					SHA: github.String("mastersha"),
   988  				},
   989  				Head: &github.PullRequestBranch{
   990  					SHA: github.String("sha1"),
   991  					Ref: github.String("feature-1"),
   992  				},
   993  			}
   994  
   995  			pr2 := &github.PullRequest{
   996  				Number:  github.Int(2),
   997  				HTMLURL: github.String("http://github.com/abhinav/git-pr/pulls/2"),
   998  				User:    &github.User{Login: github.String("probablynotarealusername")},
   999  				Base: &github.PullRequestBranch{
  1000  					SHA: github.String("sha1"),
  1001  				},
  1002  				Head: &github.PullRequestBranch{
  1003  					SHA: github.String("sha2"),
  1004  					Ref: github.String("feature-2"),
  1005  				},
  1006  			}
  1007  
  1008  			tt.Author = "abhinav"
  1009  			tt.Base = "origin/master"
  1010  			tt.PullRequests = []*github.PullRequest{pr1}
  1011  			tt.Dependents = map[string][]*github.PullRequest{
  1012  				"feature-1": []*github.PullRequest{pr2},
  1013  			}
  1014  			tt.BranchOwnership = map[string]bool{
  1015  				"feature-1": true,
  1016  				"feature-2": true,
  1017  			}
  1018  
  1019  			tt.WantRebases = []fakeRebase{
  1020  				{
  1021  					FromRef: "mastersha",
  1022  					ToRef:   "sha1",
  1023  					GiveRef: "newsha1",
  1024  				},
  1025  			}
  1026  			tt.WantResults = []rebasedPullRequest{
  1027  				{LocalRef: "newsha1", PR: pr1},
  1028  			}
  1029  
  1030  			return
  1031  		}(),
  1032  	}
  1033  
  1034  	for _, tt := range tests {
  1035  		t.Run(tt.Desc, func(t *testing.T) {
  1036  			mockCtrl := gomock.NewController(t)
  1037  			defer mockCtrl.Finish()
  1038  
  1039  			rebaser := newMockBulkRebaser(mockCtrl)
  1040  			gh := gatewaytest.NewMockGitHub(mockCtrl)
  1041  
  1042  			for branch, deps := range tt.Dependents {
  1043  				gh.EXPECT().
  1044  					ListPullRequestsByBase(gomock.Any(), branch).
  1045  					Return(deps, nil)
  1046  			}
  1047  
  1048  			for br, owned := range tt.BranchOwnership {
  1049  				gh.EXPECT().
  1050  					IsOwned(gomock.Any(), prBranchMatcher(br)).
  1051  					Return(owned)
  1052  			}
  1053  
  1054  			if tt.SetupGitHub != nil {
  1055  				tt.SetupGitHub(gh)
  1056  			}
  1057  
  1058  			mockHandle := gittest.NewMockRebaseHandle(mockCtrl)
  1059  			setupFakeRebases(mockCtrl, mockHandle, tt.WantRebases)
  1060  			rebaser.EXPECT().Onto(tt.Base).Return(mockHandle)
  1061  			rebaser.EXPECT().Err().Return(tt.RebaserError).AnyTimes()
  1062  
  1063  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  1064  			defer cancel()
  1065  
  1066  			results, err := rebasePullRequests(rebasePRConfig{
  1067  				Context:      ctx,
  1068  				GitRebaser:   rebaser,
  1069  				GitHub:       gh,
  1070  				Author:       tt.Author,
  1071  				Base:         tt.Base,
  1072  				PullRequests: tt.PullRequests,
  1073  			})
  1074  
  1075  			if len(tt.WantErrors) > 0 {
  1076  				require.Error(t, err, "expected failure")
  1077  				for _, msg := range tt.WantErrors {
  1078  					assert.Contains(t, err.Error(), msg)
  1079  				}
  1080  				return
  1081  			}
  1082  
  1083  			require.NoError(t, err, "expected success")
  1084  
  1085  			wantResults := make(map[int]rebasedPullRequest)
  1086  			for _, r := range tt.WantResults {
  1087  				wantResults[r.PR.GetNumber()] = r
  1088  			}
  1089  			assert.Equal(t, wantResults, results)
  1090  		})
  1091  	}
  1092  }
  1093  
  1094  type mockBulkRebaser struct {
  1095  	ctrl *gomock.Controller
  1096  }
  1097  
  1098  var _ bulkRebaser = (*mockBulkRebaser)(nil)
  1099  
  1100  func newMockBulkRebaser(ctrl *gomock.Controller) *mockBulkRebaser {
  1101  	return &mockBulkRebaser{ctrl: ctrl}
  1102  }
  1103  
  1104  func (m *mockBulkRebaser) Err() error {
  1105  	results := m.ctrl.Call(m, "Err")
  1106  	err, _ := results[0].(error)
  1107  	return err
  1108  }
  1109  
  1110  func (m *mockBulkRebaser) Onto(name string) git.RebaseHandle {
  1111  	results := m.ctrl.Call(m, "Onto", name)
  1112  	h, _ := results[0].(git.RebaseHandle)
  1113  	return h
  1114  }
  1115  
  1116  func (m *mockBulkRebaser) EXPECT() _mockBulkRebaserRecorder {
  1117  	return _mockBulkRebaserRecorder{m: m, ctrl: m.ctrl}
  1118  }
  1119  
  1120  type _mockBulkRebaserRecorder struct {
  1121  	m    *mockBulkRebaser
  1122  	ctrl *gomock.Controller
  1123  }
  1124  
  1125  func (r _mockBulkRebaserRecorder) Err() *gomock.Call {
  1126  	return r.ctrl.RecordCall(r.m, "Err")
  1127  }
  1128  
  1129  func (r _mockBulkRebaserRecorder) Onto(name interface{}) *gomock.Call {
  1130  	return r.ctrl.RecordCall(r.m, "Onto", name)
  1131  }
  1132  
  1133  // Matches *github.PullRequestBranch objects with the given branch.
  1134  type prBranchMatcher string
  1135  
  1136  var _ gomock.Matcher = prBranchMatcher("")
  1137  
  1138  func (m prBranchMatcher) String() string {
  1139  	return fmt.Sprintf("pull request branch %q", string(m))
  1140  }
  1141  
  1142  func (m prBranchMatcher) Matches(x interface{}) bool {
  1143  	b, ok := x.(*github.PullRequestBranch)
  1144  	if !ok {
  1145  		return false
  1146  	}
  1147  
  1148  	return b.GetRef() == string(m)
  1149  }