sigs.k8s.io/prow@v0.0.0-20240503223140-c5e374dc7eb1/pkg/tide/github_test.go (about) 1 /* 2 Copyright 2019 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package tide 18 19 import ( 20 "context" 21 "errors" 22 "reflect" 23 "testing" 24 "time" 25 26 githubql "github.com/shurcooL/githubv4" 27 "github.com/sirupsen/logrus" 28 "k8s.io/apimachinery/pkg/api/equality" 29 "k8s.io/apimachinery/pkg/util/diff" 30 "sigs.k8s.io/prow/pkg/config" 31 "sigs.k8s.io/prow/pkg/git/types" 32 "sigs.k8s.io/prow/pkg/github" 33 ) 34 35 func TestSearch(t *testing.T) { 36 const q = "random search string" 37 now := time.Now() 38 earlier := now.Add(-5 * time.Hour) 39 makePRs := func(numbers ...int) []PullRequest { 40 var prs []PullRequest 41 for _, n := range numbers { 42 prs = append(prs, PullRequest{Number: githubql.Int(n)}) 43 } 44 return prs 45 } 46 makeQuery := func(more bool, cursor string, numbers ...int) searchQuery { 47 var sq searchQuery 48 sq.Search.PageInfo.HasNextPage = githubql.Boolean(more) 49 sq.Search.PageInfo.EndCursor = githubql.String(cursor) 50 for _, pr := range makePRs(numbers...) { 51 sq.Search.Nodes = append(sq.Search.Nodes, PRNode{pr}) 52 } 53 return sq 54 } 55 56 cases := []struct { 57 name string 58 start time.Time 59 end time.Time 60 q string 61 cursors []*githubql.String 62 sqs []searchQuery 63 errs []error 64 expected []PullRequest 65 err bool 66 }{ 67 { 68 name: "single page works", 69 start: earlier, 70 end: now, 71 q: datedQuery(q, earlier, now), 72 cursors: []*githubql.String{nil}, 73 sqs: []searchQuery{ 74 makeQuery(false, "", 1, 2), 75 }, 76 errs: []error{nil}, 77 expected: makePRs(1, 2), 78 }, 79 { 80 name: "fail on first page", 81 start: earlier, 82 end: now, 83 q: datedQuery(q, earlier, now), 84 cursors: []*githubql.String{nil}, 85 sqs: []searchQuery{ 86 {}, 87 }, 88 errs: []error{errors.New("injected error")}, 89 err: true, 90 }, 91 { 92 name: "set minimum start time", 93 start: time.Time{}, 94 end: now, 95 q: datedQuery(q, floor(time.Time{}), now), 96 cursors: []*githubql.String{nil}, 97 sqs: []searchQuery{ 98 makeQuery(false, "", 1, 2), 99 }, 100 errs: []error{nil}, 101 expected: makePRs(1, 2), 102 }, 103 { 104 name: "can handle multiple pages of results", 105 start: earlier, 106 end: now, 107 q: datedQuery(q, earlier, now), 108 cursors: []*githubql.String{ 109 nil, 110 githubql.NewString("first"), 111 githubql.NewString("second"), 112 }, 113 sqs: []searchQuery{ 114 makeQuery(true, "first", 1, 2), 115 makeQuery(true, "second", 3, 4), 116 makeQuery(false, "", 5, 6), 117 }, 118 errs: []error{nil, nil, nil}, 119 expected: makePRs(1, 2, 3, 4, 5, 6), 120 }, 121 { 122 name: "return partial results on later page failure", 123 start: earlier, 124 end: now, 125 q: datedQuery(q, earlier, now), 126 cursors: []*githubql.String{ 127 nil, 128 githubql.NewString("first"), 129 }, 130 sqs: []searchQuery{ 131 makeQuery(true, "first", 1, 2), 132 {}, 133 }, 134 errs: []error{nil, errors.New("second page error")}, 135 expected: makePRs(1, 2), 136 err: true, 137 }, 138 } 139 140 for _, tc := range cases { 141 t.Run(tc.name, func(t *testing.T) { 142 client := &GitHubProvider{} 143 var i int 144 querier := func(_ context.Context, result interface{}, actual map[string]interface{}, _ string) error { 145 expected := map[string]interface{}{ 146 "query": githubql.String(tc.q), 147 "searchCursor": tc.cursors[i], 148 } 149 if !equality.Semantic.DeepEqual(expected, actual) { 150 t.Errorf("call %d vars do not match:\n%s", i, diff.ObjectReflectDiff(expected, actual)) 151 } 152 ret := result.(*searchQuery) 153 err := tc.errs[i] 154 sq := tc.sqs[i] 155 i++ 156 if err != nil { 157 return err 158 } 159 *ret = sq 160 return nil 161 } 162 prs, err := client.search(querier, logrus.WithField("test", tc.name), q, tc.start, tc.end, "") 163 switch { 164 case err != nil: 165 if !tc.err { 166 t.Errorf("unexpected error: %v", err) 167 } 168 case tc.err: 169 t.Errorf("failed to receive expected error") 170 } 171 172 if !reflect.DeepEqual(tc.expected, prs) { 173 t.Errorf("prs do not match:\n%s", diff.ObjectReflectDiff(tc.expected, prs)) 174 } 175 }) 176 } 177 } 178 179 func TestPrepareMergeDetails(t *testing.T) { 180 pr := PullRequest{ 181 Number: githubql.Int(1), 182 Mergeable: githubql.MergeableStateMergeable, 183 HeadRefOID: githubql.String("SHA"), 184 Title: "my commit title", 185 Body: "my commit body", 186 } 187 188 testCases := []struct { 189 name string 190 tpl config.TideMergeCommitTemplate 191 pr PullRequest 192 mergeMethod types.PullRequestMergeType 193 expected github.MergeDetails 194 }{{ 195 name: "No commit template", 196 tpl: config.TideMergeCommitTemplate{}, 197 pr: pr, 198 mergeMethod: "merge", 199 expected: github.MergeDetails{ 200 SHA: "SHA", 201 MergeMethod: "merge", 202 }, 203 }, { 204 name: "No commit template fields", 205 tpl: config.TideMergeCommitTemplate{ 206 Title: nil, 207 Body: nil, 208 }, 209 pr: pr, 210 mergeMethod: "merge", 211 expected: github.MergeDetails{ 212 SHA: "SHA", 213 MergeMethod: "merge", 214 }, 215 }, { 216 name: "Static commit template", 217 tpl: config.TideMergeCommitTemplate{ 218 Title: getTemplate("CommitTitle", "static title"), 219 Body: getTemplate("CommitBody", "static body"), 220 }, 221 pr: pr, 222 mergeMethod: "merge", 223 expected: github.MergeDetails{ 224 SHA: "SHA", 225 MergeMethod: "merge", 226 CommitTitle: "static title", 227 CommitMessage: "static body", 228 }, 229 }, { 230 name: "Commit template uses PullRequest fields", 231 tpl: config.TideMergeCommitTemplate{ 232 Title: getTemplate("CommitTitle", "{{ .Number }}: {{ .Title }}"), 233 Body: getTemplate("CommitBody", "{{ .HeadRefOID }} - {{ .Body }}"), 234 }, 235 pr: pr, 236 mergeMethod: "merge", 237 expected: github.MergeDetails{ 238 SHA: "SHA", 239 MergeMethod: "merge", 240 CommitTitle: "1: my commit title", 241 CommitMessage: "SHA - my commit body", 242 }, 243 }, { 244 name: "Commit template uses nonexistent fields", 245 tpl: config.TideMergeCommitTemplate{ 246 Title: getTemplate("CommitTitle", "{{ .Hello }}"), 247 Body: getTemplate("CommitBody", "{{ .World }}"), 248 }, 249 pr: pr, 250 mergeMethod: "merge", 251 expected: github.MergeDetails{ 252 SHA: "SHA", 253 MergeMethod: "merge", 254 }, 255 }} 256 257 for _, test := range testCases { 258 t.Run(test.name, func(t *testing.T) { 259 cfg := &config.Config{} 260 cfgAgent := &config.Agent{} 261 cfgAgent.Set(cfg) 262 provider := &GitHubProvider{ 263 cfg: cfgAgent.Config, 264 ghc: &fgc{}, 265 logger: logrus.WithContext(context.Background()), 266 } 267 268 actual := provider.prepareMergeDetails(test.tpl, *CodeReviewCommonFromPullRequest(&test.pr), test.mergeMethod) 269 270 if !reflect.DeepEqual(actual, test.expected) { 271 t.Errorf("Case %s failed: expected %+v, got %+v", test.name, test.expected, actual) 272 } 273 }) 274 } 275 } 276 277 func TestHeadContexts(t *testing.T) { 278 type commitContext struct { 279 // one context per commit for testing 280 context string 281 sha string 282 } 283 284 win := "win" 285 lose := "lose" 286 headSHA := "head" 287 testCases := []struct { 288 name string 289 commitContexts []commitContext 290 expectAPICall bool 291 expectChecksAPICall bool 292 }{ 293 { 294 name: "first commit is head", 295 commitContexts: []commitContext{ 296 {context: win, sha: headSHA}, 297 {context: lose, sha: "other"}, 298 {context: lose, sha: "sha"}, 299 }, 300 }, 301 { 302 name: "last commit is head", 303 commitContexts: []commitContext{ 304 {context: lose, sha: "shaaa"}, 305 {context: lose, sha: "other"}, 306 {context: win, sha: headSHA}, 307 }, 308 }, 309 { 310 name: "no commit is head, falling back to v3 api and getting context via status api", 311 commitContexts: []commitContext{ 312 {context: lose, sha: "shaaa"}, 313 {context: lose, sha: "other"}, 314 {context: lose, sha: "sha"}, 315 }, 316 expectAPICall: true, 317 }, 318 { 319 name: "no commit is head, falling back to v3 api and getting context via checks api", 320 commitContexts: []commitContext{ 321 {context: lose, sha: "shaaa"}, 322 {context: lose, sha: "other"}, 323 {context: lose, sha: "sha"}, 324 }, 325 expectAPICall: true, 326 expectChecksAPICall: true, 327 }, 328 } 329 330 for _, tc := range testCases { 331 t.Run(tc.name, func(t *testing.T) { 332 t.Logf("Running test case %q", tc.name) 333 fgc := &fgc{} 334 if !tc.expectChecksAPICall { 335 fgc.combinedStatus = map[string]string{win: string(githubql.StatusStateSuccess)} 336 } else { 337 fgc.checkRuns = &github.CheckRunList{CheckRuns: []github.CheckRun{ 338 {Name: win, Status: "completed", Conclusion: "neutral"}, 339 }} 340 } 341 if tc.expectAPICall { 342 fgc.expectedSHA = headSHA 343 } 344 provider := &GitHubProvider{ 345 ghc: fgc, 346 logger: logrus.WithField("component", "tide"), 347 } 348 pr := &PullRequest{HeadRefOID: githubql.String(headSHA)} 349 for _, ctx := range tc.commitContexts { 350 commit := Commit{ 351 Status: struct{ Contexts []Context }{ 352 Contexts: []Context{ 353 { 354 Context: githubql.String(ctx.context), 355 }, 356 }, 357 }, 358 OID: githubql.String(ctx.sha), 359 } 360 pr.Commits.Nodes = append(pr.Commits.Nodes, struct{ Commit Commit }{commit}) 361 } 362 363 contexts, err := provider.headContexts(CodeReviewCommonFromPullRequest(pr)) 364 if err != nil { 365 t.Fatalf("Unexpected error from headContexts: %v", err) 366 } 367 if len(contexts) != 1 || string(contexts[0].Context) != win { 368 t.Errorf("Expected exactly 1 %q context, but got: %#v", win, contexts) 369 } 370 }) 371 } 372 }