sigs.k8s.io/prow@v0.0.0-20240503223140-c5e374dc7eb1/pkg/tide/github_test.go (about)

     1  /*
     2  Copyright 2019 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package tide
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"reflect"
    23  	"testing"
    24  	"time"
    25  
    26  	githubql "github.com/shurcooL/githubv4"
    27  	"github.com/sirupsen/logrus"
    28  	"k8s.io/apimachinery/pkg/api/equality"
    29  	"k8s.io/apimachinery/pkg/util/diff"
    30  	"sigs.k8s.io/prow/pkg/config"
    31  	"sigs.k8s.io/prow/pkg/git/types"
    32  	"sigs.k8s.io/prow/pkg/github"
    33  )
    34  
    35  func TestSearch(t *testing.T) {
    36  	const q = "random search string"
    37  	now := time.Now()
    38  	earlier := now.Add(-5 * time.Hour)
    39  	makePRs := func(numbers ...int) []PullRequest {
    40  		var prs []PullRequest
    41  		for _, n := range numbers {
    42  			prs = append(prs, PullRequest{Number: githubql.Int(n)})
    43  		}
    44  		return prs
    45  	}
    46  	makeQuery := func(more bool, cursor string, numbers ...int) searchQuery {
    47  		var sq searchQuery
    48  		sq.Search.PageInfo.HasNextPage = githubql.Boolean(more)
    49  		sq.Search.PageInfo.EndCursor = githubql.String(cursor)
    50  		for _, pr := range makePRs(numbers...) {
    51  			sq.Search.Nodes = append(sq.Search.Nodes, PRNode{pr})
    52  		}
    53  		return sq
    54  	}
    55  
    56  	cases := []struct {
    57  		name     string
    58  		start    time.Time
    59  		end      time.Time
    60  		q        string
    61  		cursors  []*githubql.String
    62  		sqs      []searchQuery
    63  		errs     []error
    64  		expected []PullRequest
    65  		err      bool
    66  	}{
    67  		{
    68  			name:    "single page works",
    69  			start:   earlier,
    70  			end:     now,
    71  			q:       datedQuery(q, earlier, now),
    72  			cursors: []*githubql.String{nil},
    73  			sqs: []searchQuery{
    74  				makeQuery(false, "", 1, 2),
    75  			},
    76  			errs:     []error{nil},
    77  			expected: makePRs(1, 2),
    78  		},
    79  		{
    80  			name:    "fail on first page",
    81  			start:   earlier,
    82  			end:     now,
    83  			q:       datedQuery(q, earlier, now),
    84  			cursors: []*githubql.String{nil},
    85  			sqs: []searchQuery{
    86  				{},
    87  			},
    88  			errs: []error{errors.New("injected error")},
    89  			err:  true,
    90  		},
    91  		{
    92  			name:    "set minimum start time",
    93  			start:   time.Time{},
    94  			end:     now,
    95  			q:       datedQuery(q, floor(time.Time{}), now),
    96  			cursors: []*githubql.String{nil},
    97  			sqs: []searchQuery{
    98  				makeQuery(false, "", 1, 2),
    99  			},
   100  			errs:     []error{nil},
   101  			expected: makePRs(1, 2),
   102  		},
   103  		{
   104  			name:  "can handle multiple pages of results",
   105  			start: earlier,
   106  			end:   now,
   107  			q:     datedQuery(q, earlier, now),
   108  			cursors: []*githubql.String{
   109  				nil,
   110  				githubql.NewString("first"),
   111  				githubql.NewString("second"),
   112  			},
   113  			sqs: []searchQuery{
   114  				makeQuery(true, "first", 1, 2),
   115  				makeQuery(true, "second", 3, 4),
   116  				makeQuery(false, "", 5, 6),
   117  			},
   118  			errs:     []error{nil, nil, nil},
   119  			expected: makePRs(1, 2, 3, 4, 5, 6),
   120  		},
   121  		{
   122  			name:  "return partial results on later page failure",
   123  			start: earlier,
   124  			end:   now,
   125  			q:     datedQuery(q, earlier, now),
   126  			cursors: []*githubql.String{
   127  				nil,
   128  				githubql.NewString("first"),
   129  			},
   130  			sqs: []searchQuery{
   131  				makeQuery(true, "first", 1, 2),
   132  				{},
   133  			},
   134  			errs:     []error{nil, errors.New("second page error")},
   135  			expected: makePRs(1, 2),
   136  			err:      true,
   137  		},
   138  	}
   139  
   140  	for _, tc := range cases {
   141  		t.Run(tc.name, func(t *testing.T) {
   142  			client := &GitHubProvider{}
   143  			var i int
   144  			querier := func(_ context.Context, result interface{}, actual map[string]interface{}, _ string) error {
   145  				expected := map[string]interface{}{
   146  					"query":        githubql.String(tc.q),
   147  					"searchCursor": tc.cursors[i],
   148  				}
   149  				if !equality.Semantic.DeepEqual(expected, actual) {
   150  					t.Errorf("call %d vars do not match:\n%s", i, diff.ObjectReflectDiff(expected, actual))
   151  				}
   152  				ret := result.(*searchQuery)
   153  				err := tc.errs[i]
   154  				sq := tc.sqs[i]
   155  				i++
   156  				if err != nil {
   157  					return err
   158  				}
   159  				*ret = sq
   160  				return nil
   161  			}
   162  			prs, err := client.search(querier, logrus.WithField("test", tc.name), q, tc.start, tc.end, "")
   163  			switch {
   164  			case err != nil:
   165  				if !tc.err {
   166  					t.Errorf("unexpected error: %v", err)
   167  				}
   168  			case tc.err:
   169  				t.Errorf("failed to receive expected error")
   170  			}
   171  
   172  			if !reflect.DeepEqual(tc.expected, prs) {
   173  				t.Errorf("prs do not match:\n%s", diff.ObjectReflectDiff(tc.expected, prs))
   174  			}
   175  		})
   176  	}
   177  }
   178  
   179  func TestPrepareMergeDetails(t *testing.T) {
   180  	pr := PullRequest{
   181  		Number:     githubql.Int(1),
   182  		Mergeable:  githubql.MergeableStateMergeable,
   183  		HeadRefOID: githubql.String("SHA"),
   184  		Title:      "my commit title",
   185  		Body:       "my commit body",
   186  	}
   187  
   188  	testCases := []struct {
   189  		name        string
   190  		tpl         config.TideMergeCommitTemplate
   191  		pr          PullRequest
   192  		mergeMethod types.PullRequestMergeType
   193  		expected    github.MergeDetails
   194  	}{{
   195  		name:        "No commit template",
   196  		tpl:         config.TideMergeCommitTemplate{},
   197  		pr:          pr,
   198  		mergeMethod: "merge",
   199  		expected: github.MergeDetails{
   200  			SHA:         "SHA",
   201  			MergeMethod: "merge",
   202  		},
   203  	}, {
   204  		name: "No commit template fields",
   205  		tpl: config.TideMergeCommitTemplate{
   206  			Title: nil,
   207  			Body:  nil,
   208  		},
   209  		pr:          pr,
   210  		mergeMethod: "merge",
   211  		expected: github.MergeDetails{
   212  			SHA:         "SHA",
   213  			MergeMethod: "merge",
   214  		},
   215  	}, {
   216  		name: "Static commit template",
   217  		tpl: config.TideMergeCommitTemplate{
   218  			Title: getTemplate("CommitTitle", "static title"),
   219  			Body:  getTemplate("CommitBody", "static body"),
   220  		},
   221  		pr:          pr,
   222  		mergeMethod: "merge",
   223  		expected: github.MergeDetails{
   224  			SHA:           "SHA",
   225  			MergeMethod:   "merge",
   226  			CommitTitle:   "static title",
   227  			CommitMessage: "static body",
   228  		},
   229  	}, {
   230  		name: "Commit template uses PullRequest fields",
   231  		tpl: config.TideMergeCommitTemplate{
   232  			Title: getTemplate("CommitTitle", "{{ .Number }}: {{ .Title }}"),
   233  			Body:  getTemplate("CommitBody", "{{ .HeadRefOID }} - {{ .Body }}"),
   234  		},
   235  		pr:          pr,
   236  		mergeMethod: "merge",
   237  		expected: github.MergeDetails{
   238  			SHA:           "SHA",
   239  			MergeMethod:   "merge",
   240  			CommitTitle:   "1: my commit title",
   241  			CommitMessage: "SHA - my commit body",
   242  		},
   243  	}, {
   244  		name: "Commit template uses nonexistent fields",
   245  		tpl: config.TideMergeCommitTemplate{
   246  			Title: getTemplate("CommitTitle", "{{ .Hello }}"),
   247  			Body:  getTemplate("CommitBody", "{{ .World }}"),
   248  		},
   249  		pr:          pr,
   250  		mergeMethod: "merge",
   251  		expected: github.MergeDetails{
   252  			SHA:         "SHA",
   253  			MergeMethod: "merge",
   254  		},
   255  	}}
   256  
   257  	for _, test := range testCases {
   258  		t.Run(test.name, func(t *testing.T) {
   259  			cfg := &config.Config{}
   260  			cfgAgent := &config.Agent{}
   261  			cfgAgent.Set(cfg)
   262  			provider := &GitHubProvider{
   263  				cfg:    cfgAgent.Config,
   264  				ghc:    &fgc{},
   265  				logger: logrus.WithContext(context.Background()),
   266  			}
   267  
   268  			actual := provider.prepareMergeDetails(test.tpl, *CodeReviewCommonFromPullRequest(&test.pr), test.mergeMethod)
   269  
   270  			if !reflect.DeepEqual(actual, test.expected) {
   271  				t.Errorf("Case %s failed: expected %+v, got %+v", test.name, test.expected, actual)
   272  			}
   273  		})
   274  	}
   275  }
   276  
   277  func TestHeadContexts(t *testing.T) {
   278  	type commitContext struct {
   279  		// one context per commit for testing
   280  		context string
   281  		sha     string
   282  	}
   283  
   284  	win := "win"
   285  	lose := "lose"
   286  	headSHA := "head"
   287  	testCases := []struct {
   288  		name                string
   289  		commitContexts      []commitContext
   290  		expectAPICall       bool
   291  		expectChecksAPICall bool
   292  	}{
   293  		{
   294  			name: "first commit is head",
   295  			commitContexts: []commitContext{
   296  				{context: win, sha: headSHA},
   297  				{context: lose, sha: "other"},
   298  				{context: lose, sha: "sha"},
   299  			},
   300  		},
   301  		{
   302  			name: "last commit is head",
   303  			commitContexts: []commitContext{
   304  				{context: lose, sha: "shaaa"},
   305  				{context: lose, sha: "other"},
   306  				{context: win, sha: headSHA},
   307  			},
   308  		},
   309  		{
   310  			name: "no commit is head, falling back to v3 api and getting context via status api",
   311  			commitContexts: []commitContext{
   312  				{context: lose, sha: "shaaa"},
   313  				{context: lose, sha: "other"},
   314  				{context: lose, sha: "sha"},
   315  			},
   316  			expectAPICall: true,
   317  		},
   318  		{
   319  			name: "no commit is head, falling back to v3 api and getting context via checks api",
   320  			commitContexts: []commitContext{
   321  				{context: lose, sha: "shaaa"},
   322  				{context: lose, sha: "other"},
   323  				{context: lose, sha: "sha"},
   324  			},
   325  			expectAPICall:       true,
   326  			expectChecksAPICall: true,
   327  		},
   328  	}
   329  
   330  	for _, tc := range testCases {
   331  		t.Run(tc.name, func(t *testing.T) {
   332  			t.Logf("Running test case %q", tc.name)
   333  			fgc := &fgc{}
   334  			if !tc.expectChecksAPICall {
   335  				fgc.combinedStatus = map[string]string{win: string(githubql.StatusStateSuccess)}
   336  			} else {
   337  				fgc.checkRuns = &github.CheckRunList{CheckRuns: []github.CheckRun{
   338  					{Name: win, Status: "completed", Conclusion: "neutral"},
   339  				}}
   340  			}
   341  			if tc.expectAPICall {
   342  				fgc.expectedSHA = headSHA
   343  			}
   344  			provider := &GitHubProvider{
   345  				ghc:    fgc,
   346  				logger: logrus.WithField("component", "tide"),
   347  			}
   348  			pr := &PullRequest{HeadRefOID: githubql.String(headSHA)}
   349  			for _, ctx := range tc.commitContexts {
   350  				commit := Commit{
   351  					Status: struct{ Contexts []Context }{
   352  						Contexts: []Context{
   353  							{
   354  								Context: githubql.String(ctx.context),
   355  							},
   356  						},
   357  					},
   358  					OID: githubql.String(ctx.sha),
   359  				}
   360  				pr.Commits.Nodes = append(pr.Commits.Nodes, struct{ Commit Commit }{commit})
   361  			}
   362  
   363  			contexts, err := provider.headContexts(CodeReviewCommonFromPullRequest(pr))
   364  			if err != nil {
   365  				t.Fatalf("Unexpected error from headContexts: %v", err)
   366  			}
   367  			if len(contexts) != 1 || string(contexts[0].Context) != win {
   368  				t.Errorf("Expected exactly 1 %q context, but got: %#v", win, contexts)
   369  			}
   370  		})
   371  	}
   372  }