github.com/google/go-github/v69@v69.2.0/github/github_test.go (about) 1 // Copyright 2013 The go-github AUTHORS. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 package github 7 8 import ( 9 "context" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "io" 14 "net/http" 15 "net/http/httptest" 16 "net/url" 17 "os" 18 "path/filepath" 19 "reflect" 20 "strconv" 21 "strings" 22 "testing" 23 "time" 24 25 "github.com/google/go-cmp/cmp" 26 ) 27 28 const ( 29 // baseURLPath is a non-empty Client.BaseURL path to use during tests, 30 // to ensure relative URLs are used for all endpoints. See issue #752. 31 baseURLPath = "/api-v3" 32 ) 33 34 // setup sets up a test HTTP server along with a github.Client that is 35 // configured to talk to that test server. Tests should register handlers on 36 // mux which provide mock responses for the API method being tested. 37 func setup(t *testing.T) (client *Client, mux *http.ServeMux, serverURL string) { 38 t.Helper() 39 // mux is the HTTP request multiplexer used with the test server. 40 mux = http.NewServeMux() 41 42 // We want to ensure that tests catch mistakes where the endpoint URL is 43 // specified as absolute rather than relative. It only makes a difference 44 // when there's a non-empty base URL path. So, use that. See issue #752. 
45 apiHandler := http.NewServeMux() 46 apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) 47 apiHandler.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { 48 fmt.Fprintln(os.Stderr, "FAIL: Client.BaseURL path prefix is not preserved in the request URL:") 49 fmt.Fprintln(os.Stderr) 50 fmt.Fprintln(os.Stderr, "\t"+req.URL.String()) 51 fmt.Fprintln(os.Stderr) 52 fmt.Fprintln(os.Stderr, "\tDid you accidentally use an absolute endpoint URL rather than relative?") 53 fmt.Fprintln(os.Stderr, "\tSee https://github.com/google/go-github/issues/752 for information.") 54 http.Error(w, "Client.BaseURL path prefix is not preserved in the request URL.", http.StatusInternalServerError) 55 }) 56 57 // server is a test HTTP server used to provide mock API responses. 58 server := httptest.NewServer(apiHandler) 59 60 // client is the GitHub client being tested and is 61 // configured to use test server. 62 client = NewClient(nil) 63 url, _ := url.Parse(server.URL + baseURLPath + "/") 64 client.BaseURL = url 65 client.UploadURL = url 66 67 t.Cleanup(server.Close) 68 69 return client, mux, server.URL 70 } 71 72 // openTestFile creates a new file with the given name and content for testing. 73 // In order to ensure the exact file name, this function will create a new temp 74 // directory, and create the file in that directory. The file is automatically 75 // cleaned up after the test. 
76 func openTestFile(t *testing.T, name, content string) *os.File { 77 t.Helper() 78 fname := filepath.Join(t.TempDir(), name) 79 err := os.WriteFile(fname, []byte(content), 0600) 80 if err != nil { 81 t.Fatal(err) 82 } 83 file, err := os.Open(fname) 84 if err != nil { 85 t.Fatal(err) 86 } 87 88 t.Cleanup(func() { file.Close() }) 89 90 return file 91 } 92 93 func testMethod(t *testing.T, r *http.Request, want string) { 94 t.Helper() 95 if got := r.Method; got != want { 96 t.Errorf("Request method: %v, want %v", got, want) 97 } 98 } 99 100 type values map[string]string 101 102 func testFormValues(t *testing.T, r *http.Request, values values) { 103 t.Helper() 104 want := url.Values{} 105 for k, v := range values { 106 want.Set(k, v) 107 } 108 109 assertNilError(t, r.ParseForm()) 110 if got := r.Form; !cmp.Equal(got, want) { 111 t.Errorf("Request parameters: %v, want %v", got, want) 112 } 113 } 114 115 func testHeader(t *testing.T, r *http.Request, header string, want string) { 116 t.Helper() 117 if got := r.Header.Get(header); got != want { 118 t.Errorf("Header.Get(%q) returned %q, want %q", header, got, want) 119 } 120 } 121 122 func testURLParseError(t *testing.T, err error) { 123 t.Helper() 124 if err == nil { 125 t.Errorf("Expected error to be returned") 126 } 127 if err, ok := err.(*url.Error); !ok || err.Op != "parse" { 128 t.Errorf("Expected URL parse error, got %+v", err) 129 } 130 } 131 132 func testBody(t *testing.T, r *http.Request, want string) { 133 t.Helper() 134 b, err := io.ReadAll(r.Body) 135 if err != nil { 136 t.Errorf("Error reading request body: %v", err) 137 } 138 if got := string(b); got != want { 139 t.Errorf("request Body is %s, want %s", got, want) 140 } 141 } 142 143 // Test whether the marshaling of v produces JSON that corresponds 144 // to the want string. 
145 func testJSONMarshal(t *testing.T, v interface{}, want string) { 146 t.Helper() 147 // Unmarshal the wanted JSON, to verify its correctness, and marshal it back 148 // to sort the keys. 149 u := reflect.New(reflect.TypeOf(v)).Interface() 150 if err := json.Unmarshal([]byte(want), &u); err != nil { 151 t.Errorf("Unable to unmarshal JSON for %v: %v", want, err) 152 } 153 w, err := json.MarshalIndent(u, "", " ") 154 if err != nil { 155 t.Errorf("Unable to marshal JSON for %#v", u) 156 } 157 158 // Marshal the target value. 159 got, err := json.MarshalIndent(v, "", " ") 160 if err != nil { 161 t.Errorf("Unable to marshal JSON for %#v", v) 162 } 163 164 if diff := cmp.Diff(string(w), string(got)); diff != "" { 165 t.Errorf("json.Marshal returned:\n%s\nwant:\n%s\ndiff:\n%v", got, w, diff) 166 } 167 } 168 169 // Test whether the v fields have the url tag and the parsing of v 170 // produces query parameters that corresponds to the want string. 171 func testAddURLOptions(t *testing.T, url string, v interface{}, want string) { 172 t.Helper() 173 174 vt := reflect.Indirect(reflect.ValueOf(v)).Type() 175 for i := 0; i < vt.NumField(); i++ { 176 field := vt.Field(i) 177 if alias, ok := field.Tag.Lookup("url"); ok { 178 if alias == "" { 179 t.Errorf("The field %+v has a blank url tag", field) 180 } 181 } else { 182 t.Errorf("The field %+v has no url tag specified", field) 183 } 184 } 185 186 got, err := addOptions(url, v) 187 if err != nil { 188 t.Errorf("Unable to add %#v as query parameters", v) 189 } 190 191 if got != want { 192 t.Errorf("addOptions(%q, %#v) returned %v, want %v", url, v, got, want) 193 } 194 } 195 196 // Test how bad options are handled. Method f under test should 197 // return an error. 
198 func testBadOptions(t *testing.T, methodName string, f func() error) { 199 t.Helper() 200 if methodName == "" { 201 t.Error("testBadOptions: must supply method methodName") 202 } 203 if err := f(); err == nil { 204 t.Errorf("bad options %v err = nil, want error", methodName) 205 } 206 } 207 208 // Test function under NewRequest failure and then s.client.Do failure. 209 // Method f should be a regular call that would normally succeed, but 210 // should return an error when NewRequest or s.client.Do fails. 211 func testNewRequestAndDoFailure(t *testing.T, methodName string, client *Client, f func() (*Response, error)) { 212 testNewRequestAndDoFailureCategory(t, methodName, client, CoreCategory, f) 213 } 214 215 // testNewRequestAndDoFailureCategory works Like testNewRequestAndDoFailure, but allows setting the category. 216 func testNewRequestAndDoFailureCategory(t *testing.T, methodName string, client *Client, category RateLimitCategory, f func() (*Response, error)) { 217 t.Helper() 218 if methodName == "" { 219 t.Error("testNewRequestAndDoFailure: must supply method methodName") 220 } 221 222 client.BaseURL.Path = "" 223 resp, err := f() 224 if resp != nil { 225 t.Errorf("client.BaseURL.Path='' %v resp = %#v, want nil", methodName, resp) 226 } 227 if err == nil { 228 t.Errorf("client.BaseURL.Path='' %v err = nil, want error", methodName) 229 } 230 231 client.BaseURL.Path = "/api-v3/" 232 client.rateLimits[category].Reset.Time = time.Now().Add(10 * time.Minute) 233 resp, err = f() 234 if bypass := resp.Request.Context().Value(BypassRateLimitCheck); bypass != nil { 235 return 236 } 237 if want := http.StatusForbidden; resp == nil || resp.Response.StatusCode != want { 238 if resp != nil { 239 t.Errorf("rate.Reset.Time > now %v resp = %#v, want StatusCode=%v", methodName, resp.Response, want) 240 } else { 241 t.Errorf("rate.Reset.Time > now %v resp = nil, want StatusCode=%v", methodName, want) 242 } 243 } 244 if err == nil { 245 t.Errorf("rate.Reset.Time > now %v 
err = nil, want error", methodName) 246 } 247 } 248 249 // Test that all error response types contain the status code. 250 func testErrorResponseForStatusCode(t *testing.T, code int) { 251 t.Helper() 252 client, mux, _ := setup(t) 253 254 mux.HandleFunc("/repos/o/r/hooks", func(w http.ResponseWriter, r *http.Request) { 255 testMethod(t, r, "GET") 256 w.WriteHeader(code) 257 }) 258 259 ctx := context.Background() 260 _, _, err := client.Repositories.ListHooks(ctx, "o", "r", nil) 261 262 switch e := err.(type) { 263 case *ErrorResponse: 264 case *RateLimitError: 265 case *AbuseRateLimitError: 266 if code != e.Response.StatusCode { 267 t.Error("Error response does not contain status code") 268 } 269 default: 270 t.Error("Unknown error response type") 271 } 272 } 273 274 func assertNoDiff(t *testing.T, want, got interface{}) { 275 t.Helper() 276 if diff := cmp.Diff(want, got); diff != "" { 277 t.Errorf("diff mismatch (-want +got):\n%v", diff) 278 } 279 } 280 281 func assertNilError(t *testing.T, err error) { 282 t.Helper() 283 if err != nil { 284 t.Errorf("unexpected error: %v", err) 285 } 286 } 287 288 func assertWrite(t *testing.T, w io.Writer, data []byte) { 289 t.Helper() 290 _, err := w.Write(data) 291 assertNilError(t, err) 292 } 293 294 func TestNewClient(t *testing.T) { 295 t.Parallel() 296 c := NewClient(nil) 297 298 if got, want := c.BaseURL.String(), defaultBaseURL; got != want { 299 t.Errorf("NewClient BaseURL is %v, want %v", got, want) 300 } 301 if got, want := c.UserAgent, defaultUserAgent; got != want { 302 t.Errorf("NewClient UserAgent is %v, want %v", got, want) 303 } 304 305 c2 := NewClient(nil) 306 if c.client == c2.client { 307 t.Error("NewClient returned same http.Clients, but they should differ") 308 } 309 } 310 311 func TestNewClientWithEnvProxy(t *testing.T) { 312 t.Parallel() 313 client := NewClientWithEnvProxy() 314 if got, want := client.BaseURL.String(), defaultBaseURL; got != want { 315 t.Errorf("NewClient BaseURL is %v, want %v", got, 
want) 316 } 317 } 318 319 func TestClient(t *testing.T) { 320 t.Parallel() 321 c := NewClient(nil) 322 c2 := c.Client() 323 if c.client == c2 { 324 t.Error("Client returned same http.Client, but should be different") 325 } 326 } 327 328 func TestWithAuthToken(t *testing.T) { 329 t.Parallel() 330 token := "gh_test_token" 331 332 validate := func(t *testing.T, c *http.Client, token string) { 333 t.Helper() 334 want := token 335 if want != "" { 336 want = "Bearer " + want 337 } 338 gotReq := false 339 headerVal := "" 340 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 341 gotReq = true 342 headerVal = r.Header.Get("Authorization") 343 })) 344 _, err := c.Get(srv.URL) 345 assertNilError(t, err) 346 if !gotReq { 347 t.Error("request not sent") 348 } 349 if headerVal != want { 350 t.Errorf("Authorization header is %v, want %v", headerVal, want) 351 } 352 } 353 354 t.Run("zero-value Client", func(t *testing.T) { 355 t.Parallel() 356 c := new(Client).WithAuthToken(token) 357 validate(t, c.Client(), token) 358 }) 359 360 t.Run("NewClient", func(t *testing.T) { 361 t.Parallel() 362 httpClient := &http.Client{} 363 client := NewClient(httpClient).WithAuthToken(token) 364 validate(t, client.Client(), token) 365 // make sure the original client isn't setting auth headers now 366 validate(t, httpClient, "") 367 }) 368 369 t.Run("NewTokenClient", func(t *testing.T) { 370 t.Parallel() 371 validate(t, NewTokenClient(context.Background(), token).Client(), token) 372 }) 373 } 374 375 func TestWithEnterpriseURLs(t *testing.T) { 376 t.Parallel() 377 for _, test := range []struct { 378 name string 379 baseURL string 380 wantBaseURL string 381 uploadURL string 382 wantUploadURL string 383 wantErr string 384 }{ 385 { 386 name: "does not modify properly formed URLs", 387 baseURL: "https://custom-url/api/v3/", 388 wantBaseURL: "https://custom-url/api/v3/", 389 uploadURL: "https://custom-upload-url/api/uploads/", 390 wantUploadURL: 
"https://custom-upload-url/api/uploads/", 391 }, 392 { 393 name: "adds trailing slash", 394 baseURL: "https://custom-url/api/v3", 395 wantBaseURL: "https://custom-url/api/v3/", 396 uploadURL: "https://custom-upload-url/api/uploads", 397 wantUploadURL: "https://custom-upload-url/api/uploads/", 398 }, 399 { 400 name: "adds enterprise suffix", 401 baseURL: "https://custom-url/", 402 wantBaseURL: "https://custom-url/api/v3/", 403 uploadURL: "https://custom-upload-url/", 404 wantUploadURL: "https://custom-upload-url/api/uploads/", 405 }, 406 { 407 name: "adds enterprise suffix and trailing slash", 408 baseURL: "https://custom-url", 409 wantBaseURL: "https://custom-url/api/v3/", 410 uploadURL: "https://custom-upload-url", 411 wantUploadURL: "https://custom-upload-url/api/uploads/", 412 }, 413 { 414 name: "bad base URL", 415 baseURL: "bogus\nbase\nURL", 416 uploadURL: "https://custom-upload-url/api/uploads/", 417 wantErr: `invalid control character in URL`, 418 }, 419 { 420 name: "bad upload URL", 421 baseURL: "https://custom-url/api/v3/", 422 uploadURL: "bogus\nupload\nURL", 423 wantErr: `invalid control character in URL`, 424 }, 425 { 426 name: "URL has existing API prefix, adds trailing slash", 427 baseURL: "https://api.custom-url", 428 wantBaseURL: "https://api.custom-url/", 429 uploadURL: "https://api.custom-upload-url", 430 wantUploadURL: "https://api.custom-upload-url/", 431 }, 432 { 433 name: "URL has existing API prefix and trailing slash", 434 baseURL: "https://api.custom-url/", 435 wantBaseURL: "https://api.custom-url/", 436 uploadURL: "https://api.custom-upload-url/", 437 wantUploadURL: "https://api.custom-upload-url/", 438 }, 439 { 440 name: "URL has API subdomain, adds trailing slash", 441 baseURL: "https://catalog.api.custom-url", 442 wantBaseURL: "https://catalog.api.custom-url/", 443 uploadURL: "https://catalog.api.custom-upload-url", 444 wantUploadURL: "https://catalog.api.custom-upload-url/", 445 }, 446 { 447 name: "URL has API subdomain and trailing 
slash", 448 baseURL: "https://catalog.api.custom-url/", 449 wantBaseURL: "https://catalog.api.custom-url/", 450 uploadURL: "https://catalog.api.custom-upload-url/", 451 wantUploadURL: "https://catalog.api.custom-upload-url/", 452 }, 453 { 454 name: "URL is not a proper API subdomain, adds enterprise suffix and slash", 455 baseURL: "https://cloud-api.custom-url", 456 wantBaseURL: "https://cloud-api.custom-url/api/v3/", 457 uploadURL: "https://cloud-api.custom-upload-url", 458 wantUploadURL: "https://cloud-api.custom-upload-url/api/uploads/", 459 }, 460 { 461 name: "URL is not a proper API subdomain, adds enterprise suffix", 462 baseURL: "https://cloud-api.custom-url/", 463 wantBaseURL: "https://cloud-api.custom-url/api/v3/", 464 uploadURL: "https://cloud-api.custom-upload-url/", 465 wantUploadURL: "https://cloud-api.custom-upload-url/api/uploads/", 466 }, 467 } { 468 test := test 469 t.Run(test.name, func(t *testing.T) { 470 t.Parallel() 471 validate := func(c *Client, err error) { 472 t.Helper() 473 if test.wantErr != "" { 474 if err == nil || !strings.Contains(err.Error(), test.wantErr) { 475 t.Fatalf("error does not contain expected string %q: %v", test.wantErr, err) 476 } 477 return 478 } 479 if err != nil { 480 t.Fatalf("got unexpected error: %v", err) 481 } 482 if c.BaseURL.String() != test.wantBaseURL { 483 t.Errorf("BaseURL is %v, want %v", c.BaseURL.String(), test.wantBaseURL) 484 } 485 if c.UploadURL.String() != test.wantUploadURL { 486 t.Errorf("UploadURL is %v, want %v", c.UploadURL.String(), test.wantUploadURL) 487 } 488 } 489 validate(NewClient(nil).WithEnterpriseURLs(test.baseURL, test.uploadURL)) 490 validate(new(Client).WithEnterpriseURLs(test.baseURL, test.uploadURL)) 491 validate(NewEnterpriseClient(test.baseURL, test.uploadURL, nil)) 492 }) 493 } 494 } 495 496 // Ensure that length of Client.rateLimits is the same as number of fields in RateLimits struct. 
497 func TestClient_rateLimits(t *testing.T) { 498 t.Parallel() 499 if got, want := len(Client{}.rateLimits), reflect.TypeOf(RateLimits{}).NumField(); got != want { 500 t.Errorf("len(Client{}.rateLimits) is %v, want %v", got, want) 501 } 502 } 503 504 func TestNewRequest(t *testing.T) { 505 t.Parallel() 506 c := NewClient(nil) 507 508 inURL, outURL := "/foo", defaultBaseURL+"foo" 509 inBody, outBody := &User{Login: Ptr("l")}, `{"login":"l"}`+"\n" 510 req, _ := c.NewRequest("GET", inURL, inBody) 511 512 // test that relative URL was expanded 513 if got, want := req.URL.String(), outURL; got != want { 514 t.Errorf("NewRequest(%q) URL is %v, want %v", inURL, got, want) 515 } 516 517 // test that body was JSON encoded 518 body, _ := io.ReadAll(req.Body) 519 if got, want := string(body), outBody; got != want { 520 t.Errorf("NewRequest(%q) Body is %v, want %v", inBody, got, want) 521 } 522 523 userAgent := req.Header.Get("User-Agent") 524 525 // test that default user-agent is attached to the request 526 if got, want := userAgent, c.UserAgent; got != want { 527 t.Errorf("NewRequest() User-Agent is %v, want %v", got, want) 528 } 529 530 if !strings.Contains(userAgent, Version) { 531 t.Errorf("NewRequest() User-Agent should contain %v, found %v", Version, userAgent) 532 } 533 534 apiVersion := req.Header.Get(headerAPIVersion) 535 if got, want := apiVersion, defaultAPIVersion; got != want { 536 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 537 } 538 539 req, _ = c.NewRequest("GET", inURL, inBody, WithVersion("2022-11-29")) 540 apiVersion = req.Header.Get(headerAPIVersion) 541 if got, want := apiVersion, "2022-11-29"; got != want { 542 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 543 } 544 } 545 546 func TestNewRequest_invalidJSON(t *testing.T) { 547 t.Parallel() 548 c := NewClient(nil) 549 550 type T struct { 551 A map[interface{}]interface{} 552 } 553 _, err := c.NewRequest("GET", ".", &T{}) 554 555 if err 
== nil { 556 t.Error("Expected error to be returned.") 557 } 558 if err, ok := err.(*json.UnsupportedTypeError); !ok { 559 t.Errorf("Expected a JSON error; got %#v.", err) 560 } 561 } 562 563 func TestNewRequest_badURL(t *testing.T) { 564 t.Parallel() 565 c := NewClient(nil) 566 _, err := c.NewRequest("GET", ":", nil) 567 testURLParseError(t, err) 568 } 569 570 func TestNewRequest_badMethod(t *testing.T) { 571 t.Parallel() 572 c := NewClient(nil) 573 if _, err := c.NewRequest("BOGUS\nMETHOD", ".", nil); err == nil { 574 t.Fatal("NewRequest returned nil; expected error") 575 } 576 } 577 578 // ensure that no User-Agent header is set if the client's UserAgent is empty. 579 // This caused a problem with Google's internal http client. 580 func TestNewRequest_emptyUserAgent(t *testing.T) { 581 t.Parallel() 582 c := NewClient(nil) 583 c.UserAgent = "" 584 req, err := c.NewRequest("GET", ".", nil) 585 if err != nil { 586 t.Fatalf("NewRequest returned unexpected error: %v", err) 587 } 588 if _, ok := req.Header["User-Agent"]; ok { 589 t.Fatal("constructed request contains unexpected User-Agent header") 590 } 591 } 592 593 // If a nil body is passed to github.NewRequest, make sure that nil is also 594 // passed to http.NewRequest. In most cases, passing an io.Reader that returns 595 // no content is fine, since there is no difference between an HTTP request 596 // body that is an empty string versus one that is not set at all. However in 597 // certain cases, intermediate systems may treat these differently resulting in 598 // subtle errors. 
599 func TestNewRequest_emptyBody(t *testing.T) { 600 t.Parallel() 601 c := NewClient(nil) 602 req, err := c.NewRequest("GET", ".", nil) 603 if err != nil { 604 t.Fatalf("NewRequest returned unexpected error: %v", err) 605 } 606 if req.Body != nil { 607 t.Fatalf("constructed request contains a non-nil Body") 608 } 609 } 610 611 func TestNewRequest_errorForNoTrailingSlash(t *testing.T) { 612 t.Parallel() 613 tests := []struct { 614 rawurl string 615 wantError bool 616 }{ 617 {rawurl: "https://example.com/api/v3", wantError: true}, 618 {rawurl: "https://example.com/api/v3/", wantError: false}, 619 } 620 c := NewClient(nil) 621 for _, test := range tests { 622 u, err := url.Parse(test.rawurl) 623 if err != nil { 624 t.Fatalf("url.Parse returned unexpected error: %v.", err) 625 } 626 c.BaseURL = u 627 if _, err := c.NewRequest(http.MethodGet, "test", nil); test.wantError && err == nil { 628 t.Fatalf("Expected error to be returned.") 629 } else if !test.wantError && err != nil { 630 t.Fatalf("NewRequest returned unexpected error: %v.", err) 631 } 632 } 633 } 634 635 func TestNewFormRequest(t *testing.T) { 636 t.Parallel() 637 c := NewClient(nil) 638 639 inURL, outURL := "/foo", defaultBaseURL+"foo" 640 form := url.Values{} 641 form.Add("login", "l") 642 inBody, outBody := strings.NewReader(form.Encode()), "login=l" 643 req, _ := c.NewFormRequest(inURL, inBody) 644 645 // test that relative URL was expanded 646 if got, want := req.URL.String(), outURL; got != want { 647 t.Errorf("NewFormRequest(%q) URL is %v, want %v", inURL, got, want) 648 } 649 650 // test that body was form encoded 651 body, _ := io.ReadAll(req.Body) 652 if got, want := string(body), outBody; got != want { 653 t.Errorf("NewFormRequest(%q) Body is %v, want %v", inBody, got, want) 654 } 655 656 // test that default user-agent is attached to the request 657 if got, want := req.Header.Get("User-Agent"), c.UserAgent; got != want { 658 t.Errorf("NewFormRequest() User-Agent is %v, want %v", got, want) 659 } 
660 661 apiVersion := req.Header.Get(headerAPIVersion) 662 if got, want := apiVersion, defaultAPIVersion; got != want { 663 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 664 } 665 666 req, _ = c.NewFormRequest(inURL, inBody, WithVersion("2022-11-29")) 667 apiVersion = req.Header.Get(headerAPIVersion) 668 if got, want := apiVersion, "2022-11-29"; got != want { 669 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 670 } 671 } 672 673 func TestNewFormRequest_badURL(t *testing.T) { 674 t.Parallel() 675 c := NewClient(nil) 676 _, err := c.NewFormRequest(":", nil) 677 testURLParseError(t, err) 678 } 679 680 func TestNewFormRequest_emptyUserAgent(t *testing.T) { 681 t.Parallel() 682 c := NewClient(nil) 683 c.UserAgent = "" 684 req, err := c.NewFormRequest(".", nil) 685 if err != nil { 686 t.Fatalf("NewFormRequest returned unexpected error: %v", err) 687 } 688 if _, ok := req.Header["User-Agent"]; ok { 689 t.Fatal("constructed request contains unexpected User-Agent header") 690 } 691 } 692 693 func TestNewFormRequest_emptyBody(t *testing.T) { 694 t.Parallel() 695 c := NewClient(nil) 696 req, err := c.NewFormRequest(".", nil) 697 if err != nil { 698 t.Fatalf("NewFormRequest returned unexpected error: %v", err) 699 } 700 if req.Body != nil { 701 t.Fatalf("constructed request contains a non-nil Body") 702 } 703 } 704 705 func TestNewFormRequest_errorForNoTrailingSlash(t *testing.T) { 706 t.Parallel() 707 tests := []struct { 708 rawURL string 709 wantError bool 710 }{ 711 {rawURL: "https://example.com/api/v3", wantError: true}, 712 {rawURL: "https://example.com/api/v3/", wantError: false}, 713 } 714 c := NewClient(nil) 715 for _, test := range tests { 716 u, err := url.Parse(test.rawURL) 717 if err != nil { 718 t.Fatalf("url.Parse returned unexpected error: %v.", err) 719 } 720 c.BaseURL = u 721 if _, err := c.NewFormRequest("test", nil); test.wantError && err == nil { 722 t.Fatalf("Expected error to be returned.") 
723 } else if !test.wantError && err != nil { 724 t.Fatalf("NewFormRequest returned unexpected error: %v.", err) 725 } 726 } 727 } 728 729 func TestNewUploadRequest_WithVersion(t *testing.T) { 730 t.Parallel() 731 c := NewClient(nil) 732 req, _ := c.NewUploadRequest("https://example.com/", nil, 0, "") 733 734 apiVersion := req.Header.Get(headerAPIVersion) 735 if got, want := apiVersion, defaultAPIVersion; got != want { 736 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 737 } 738 739 req, _ = c.NewUploadRequest("https://example.com/", nil, 0, "", WithVersion("2022-11-29")) 740 apiVersion = req.Header.Get(headerAPIVersion) 741 if got, want := apiVersion, "2022-11-29"; got != want { 742 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 743 } 744 } 745 746 func TestNewUploadRequest_badURL(t *testing.T) { 747 t.Parallel() 748 c := NewClient(nil) 749 _, err := c.NewUploadRequest(":", nil, 0, "") 750 testURLParseError(t, err) 751 752 const methodName = "NewUploadRequest" 753 testBadOptions(t, methodName, func() (err error) { 754 _, err = c.NewUploadRequest("\n", nil, -1, "\n") 755 return err 756 }) 757 } 758 759 func TestNewUploadRequest_errorForNoTrailingSlash(t *testing.T) { 760 t.Parallel() 761 tests := []struct { 762 rawurl string 763 wantError bool 764 }{ 765 {rawurl: "https://example.com/api/uploads", wantError: true}, 766 {rawurl: "https://example.com/api/uploads/", wantError: false}, 767 } 768 c := NewClient(nil) 769 for _, test := range tests { 770 u, err := url.Parse(test.rawurl) 771 if err != nil { 772 t.Fatalf("url.Parse returned unexpected error: %v.", err) 773 } 774 c.UploadURL = u 775 if _, err = c.NewUploadRequest("test", nil, 0, ""); test.wantError && err == nil { 776 t.Fatalf("Expected error to be returned.") 777 } else if !test.wantError && err != nil { 778 t.Fatalf("NewUploadRequest returned unexpected error: %v.", err) 779 } 780 } 781 } 782 783 func TestResponse_populatePageValues(t 
*testing.T) { 784 t.Parallel() 785 r := http.Response{ 786 Header: http.Header{ 787 "Link": {`<https://api.github.com/?page=1>; rel="first",` + 788 ` <https://api.github.com/?page=2>; rel="prev",` + 789 ` <https://api.github.com/?page=4>; rel="next",` + 790 ` <https://api.github.com/?page=5>; rel="last"`, 791 }, 792 }, 793 } 794 795 response := newResponse(&r) 796 if got, want := response.FirstPage, 1; got != want { 797 t.Errorf("response.FirstPage: %v, want %v", got, want) 798 } 799 if got, want := response.PrevPage, 2; want != got { 800 t.Errorf("response.PrevPage: %v, want %v", got, want) 801 } 802 if got, want := response.NextPage, 4; want != got { 803 t.Errorf("response.NextPage: %v, want %v", got, want) 804 } 805 if got, want := response.LastPage, 5; want != got { 806 t.Errorf("response.LastPage: %v, want %v", got, want) 807 } 808 if got, want := response.NextPageToken, ""; want != got { 809 t.Errorf("response.NextPageToken: %v, want %v", got, want) 810 } 811 } 812 813 func TestResponse_populateSinceValues(t *testing.T) { 814 t.Parallel() 815 r := http.Response{ 816 Header: http.Header{ 817 "Link": {`<https://api.github.com/?since=1>; rel="first",` + 818 ` <https://api.github.com/?since=2>; rel="prev",` + 819 ` <https://api.github.com/?since=4>; rel="next",` + 820 ` <https://api.github.com/?since=5>; rel="last"`, 821 }, 822 }, 823 } 824 825 response := newResponse(&r) 826 if got, want := response.FirstPage, 1; got != want { 827 t.Errorf("response.FirstPage: %v, want %v", got, want) 828 } 829 if got, want := response.PrevPage, 2; want != got { 830 t.Errorf("response.PrevPage: %v, want %v", got, want) 831 } 832 if got, want := response.NextPage, 4; want != got { 833 t.Errorf("response.NextPage: %v, want %v", got, want) 834 } 835 if got, want := response.LastPage, 5; want != got { 836 t.Errorf("response.LastPage: %v, want %v", got, want) 837 } 838 if got, want := response.NextPageToken, ""; want != got { 839 t.Errorf("response.NextPageToken: %v, want %v", got, 
want) 840 } 841 } 842 843 func TestResponse_SinceWithPage(t *testing.T) { 844 t.Parallel() 845 r := http.Response{ 846 Header: http.Header{ 847 "Link": {`<https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=1>; rel="first",` + 848 ` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=2>; rel="prev",` + 849 ` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=4>; rel="next",` + 850 ` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=5>; rel="last"`, 851 }, 852 }, 853 } 854 855 response := newResponse(&r) 856 if got, want := response.FirstPage, 1; got != want { 857 t.Errorf("response.FirstPage: %v, want %v", got, want) 858 } 859 if got, want := response.PrevPage, 2; want != got { 860 t.Errorf("response.PrevPage: %v, want %v", got, want) 861 } 862 if got, want := response.NextPage, 4; want != got { 863 t.Errorf("response.NextPage: %v, want %v", got, want) 864 } 865 if got, want := response.LastPage, 5; want != got { 866 t.Errorf("response.LastPage: %v, want %v", got, want) 867 } 868 if got, want := response.NextPageToken, ""; want != got { 869 t.Errorf("response.NextPageToken: %v, want %v", got, want) 870 } 871 } 872 873 func TestResponse_cursorPagination(t *testing.T) { 874 t.Parallel() 875 r := http.Response{ 876 Header: http.Header{ 877 "Status": {"200 OK"}, 878 "Link": {`<https://api.github.com/resource?per_page=2&page=url-encoded-next-page-token>; rel="next"`}, 879 }, 880 } 881 882 response := newResponse(&r) 883 if got, want := response.FirstPage, 0; got != want { 884 t.Errorf("response.FirstPage: %v, want %v", got, want) 885 } 886 if got, want := response.PrevPage, 0; want != got { 887 t.Errorf("response.PrevPage: %v, want %v", got, want) 888 } 889 if got, want := response.NextPage, 0; want != got { 890 t.Errorf("response.NextPage: %v, want %v", got, want) 891 } 892 if got, want := response.LastPage, 0; want != got { 893 t.Errorf("response.LastPage: %v, want %v", got, want) 894 } 895 if got, want := response.NextPageToken, 
"url-encoded-next-page-token"; want != got {
		t.Errorf("response.NextPageToken: %v, want %v", got, want)
	}

	// cursor-based pagination with "cursor" param
	r = http.Response{
		Header: http.Header{
			"Link": {
				`<https://api.github.com/?cursor=v1_12345678>; rel="next"`,
			},
		},
	}

	response = newResponse(&r)
	if got, want := response.Cursor, "v1_12345678"; got != want {
		t.Errorf("response.Cursor: %v, want %v", got, want)
	}
}

// TestResponse_beforeAfterPagination checks that "before"/"after" query
// parameters found in the Link header populate Response.Before and
// Response.After, while the numeric page fields and NextPageToken stay unset.
func TestResponse_beforeAfterPagination(t *testing.T) {
	t.Parallel()
	link := `<https://api.github.com/?after=a1b2c3&before=>; rel="next",` +
		` <https://api.github.com/?after=&before=>; rel="first",` +
		` <https://api.github.com/?after=&before=d4e5f6>; rel="prev",`
	r := http.Response{Header: http.Header{"Link": {link}}}
	response := newResponse(&r)

	checks := []struct {
		field     string
		got, want any
	}{
		{"Before", response.Before, "d4e5f6"},
		{"After", response.After, "a1b2c3"},
		{"FirstPage", response.FirstPage, 0},
		{"PrevPage", response.PrevPage, 0},
		{"NextPage", response.NextPage, 0},
		{"LastPage", response.LastPage, 0},
		{"NextPageToken", response.NextPageToken, ""},
	}
	for _, c := range checks {
		if c.got != c.want {
			t.Errorf("response.%s: %v, want %v", c.field, c.got, c.want)
		}
	}
}

// TestResponse_populatePageValues_invalid checks that malformed Link entries
// carrying "page" parameters leave every page field at zero.
func TestResponse_populatePageValues_invalid(t *testing.T) {
	t.Parallel()
	wantZero := func(field string, got int) {
		t.Helper()
		if want := 0; got != want {
			t.Errorf("response.%s: %v, want %v", field, got, want)
		}
	}

	// Each entry is malformed: missing rel, non-numeric page, unbracketed URL,
	// no page parameter at all, and an empty page value.
	link := `<https://api.github.com/?page=1>,` +
		`<https://api.github.com/?page=abc>; rel="first",` +
		`https://api.github.com/?page=2; rel="prev",` +
		`<https://api.github.com/>; rel="next",` +
		`<https://api.github.com/?page=>; rel="last"`
	response := newResponse(&http.Response{Header: http.Header{"Link": {link}}})
	wantZero("FirstPage", response.FirstPage)
	wantZero("PrevPage", response.PrevPage)
	wantZero("NextPage", response.NextPage)
	wantZero("LastPage", response.LastPage)

	// more invalid URLs: an invalid percent-escape in the path.
	response = newResponse(&http.Response{
		Header: http.Header{"Link": {`<https://api.github.com/%?page=2>; rel="first"`}},
	})
	wantZero("FirstPage", response.FirstPage)
}

// TestResponse_populateSinceValues_invalid checks that malformed Link entries
// carrying "since" parameters leave every page field at zero.
func TestResponse_populateSinceValues_invalid(t *testing.T) {
	t.Parallel()
	wantZero := func(field string, got int) {
		t.Helper()
		if want := 0; got != want {
			t.Errorf("response.%s: %v, want %v", field, got, want)
		}
	}

	link := `<https://api.github.com/?since=1>,` +
		`<https://api.github.com/?since=abc>; rel="first",` +
		`https://api.github.com/?since=2; rel="prev",` +
		`<https://api.github.com/>; rel="next",` +
		`<https://api.github.com/?since=>; rel="last"`
	response := newResponse(&http.Response{Header: http.Header{"Link": {link}}})
	wantZero("FirstPage", response.FirstPage)
	wantZero("PrevPage", response.PrevPage)
	wantZero("NextPage", response.NextPage)
	wantZero("LastPage", response.LastPage)

	// more invalid URLs: an invalid percent-escape in the path.
	response = newResponse(&http.Response{
		Header: http.Header{"Link": {`<https://api.github.com/%?since=2>; rel="first"`}},
	})
	wantZero("FirstPage", response.FirstPage)
}

// TestDo checks that Client.Do decodes a JSON response body into the value
// passed as its third argument.
func TestDo(t *testing.T) {
	t.Parallel()
	client, mux, _ := setup(t)

	type foo struct {
		A string
	}

	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		testMethod(t, r, "GET")
		fmt.Fprint(w, `{"A":"a"}`)
	})

	req, _ := client.NewRequest("GET", ".", nil)
	got := new(foo)
	_, err := client.Do(context.Background(), req, got)
	assertNilError(t, err)

	want := &foo{"a"}
	if !cmp.Equal(got, want) {
		t.Errorf("Response body = %v, want %v", got, want)
	}
}

// TestDo_nilContext checks that Client.Do rejects a nil context with
// errNonNilContext.
func TestDo_nilContext(t *testing.T) {
	t.Parallel()
	client, _, _ := setup(t)

	req, _ := client.NewRequest("GET", ".", nil)
	//nolint:staticcheck // a nil context is passed deliberately to exercise the guard.
	if _, err := client.Do(nil, req, nil); !errors.Is(err, errNonNilContext) {
		t.Errorf("Expected context must be non-nil error")
	}
}

// TestDo_httpError checks that a 400 response yields both an error and the
// response carrying that status code.
func TestDo_httpError(t *testing.T) {
	t.Parallel()
	client, mux, _ := setup(t)

	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Bad Request", http.StatusBadRequest)
	})

	req, _ := client.NewRequest("GET", ".", nil)
	resp, err := client.Do(context.Background(), req, nil)
	if err == nil {
		t.Fatal("Expected HTTP 400 error, got no error.")
	}
	if got := resp.StatusCode; got != http.StatusBadRequest {
		t.Errorf("Expected HTTP 400 error, got %d status code.", got)
	}
}

// Test handling of an error caused by the internal http client's Do()
// function. A redirect loop is pretty unlikely to occur within the GitHub
// API, but does allow us to exercise the right code path.
1089 func TestDo_redirectLoop(t *testing.T) { 1090 t.Parallel() 1091 client, mux, _ := setup(t) 1092 1093 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1094 http.Redirect(w, r, baseURLPath, http.StatusFound) 1095 }) 1096 1097 req, _ := client.NewRequest("GET", ".", nil) 1098 ctx := context.Background() 1099 _, err := client.Do(ctx, req, nil) 1100 1101 if err == nil { 1102 t.Error("Expected error to be returned.") 1103 } 1104 if err, ok := err.(*url.Error); !ok { 1105 t.Errorf("Expected a URL error; got %#v.", err) 1106 } 1107 } 1108 1109 func TestDo_preservesResponseInHTTPError(t *testing.T) { 1110 t.Parallel() 1111 client, mux, _ := setup(t) 1112 1113 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1114 w.Header().Set("Content-Type", "application/json") 1115 w.WriteHeader(http.StatusNotFound) 1116 fmt.Fprintf(w, `{ 1117 "message": "Resource not found", 1118 "documentation_url": "https://docs.github.com/rest/reference/repos#get-a-repository" 1119 }`) 1120 }) 1121 1122 req, _ := client.NewRequest("GET", ".", nil) 1123 var resp *Response 1124 var data interface{} 1125 resp, err := client.Do(context.Background(), req, &data) 1126 1127 if err == nil { 1128 t.Fatal("Expected error response") 1129 } 1130 1131 // Verify error type and access to status code 1132 errResp, ok := err.(*ErrorResponse) 1133 if !ok { 1134 t.Fatalf("Expected *ErrorResponse error, got %T", err) 1135 } 1136 1137 // Verify status code is accessible from both Response and ErrorResponse 1138 if resp == nil { 1139 t.Fatal("Expected response to be returned even with error") 1140 } 1141 if got, want := resp.StatusCode, http.StatusNotFound; got != want { 1142 t.Errorf("Response status = %d, want %d", got, want) 1143 } 1144 if got, want := errResp.Response.StatusCode, http.StatusNotFound; got != want { 1145 t.Errorf("Error response status = %d, want %d", got, want) 1146 } 1147 1148 // Verify error contains proper message 1149 if !strings.Contains(errResp.Message, 
"Resource not found") { 1150 t.Errorf("Error message = %q, want to contain 'Resource not found'", errResp.Message) 1151 } 1152 } 1153 1154 // Test that an error caused by the internal http client's Do() function 1155 // does not leak the client secret. 1156 func TestDo_sanitizeURL(t *testing.T) { 1157 t.Parallel() 1158 tp := &UnauthenticatedRateLimitedTransport{ 1159 ClientID: "id", 1160 ClientSecret: "secret", 1161 } 1162 unauthedClient := NewClient(tp.Client()) 1163 unauthedClient.BaseURL = &url.URL{Scheme: "http", Host: "127.0.0.1:0", Path: "/"} // Use port 0 on purpose to trigger a dial TCP error, expect to get "dial tcp 127.0.0.1:0: connect: can't assign requested address". 1164 req, err := unauthedClient.NewRequest("GET", ".", nil) 1165 if err != nil { 1166 t.Fatalf("NewRequest returned unexpected error: %v", err) 1167 } 1168 ctx := context.Background() 1169 _, err = unauthedClient.Do(ctx, req, nil) 1170 if err == nil { 1171 t.Fatal("Expected error to be returned.") 1172 } 1173 if strings.Contains(err.Error(), "client_secret=secret") { 1174 t.Errorf("Do error contains secret, should be redacted:\n%q", err) 1175 } 1176 } 1177 1178 func TestDo_rateLimit(t *testing.T) { 1179 t.Parallel() 1180 client, mux, _ := setup(t) 1181 1182 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1183 w.Header().Set(headerRateLimit, "60") 1184 w.Header().Set(headerRateRemaining, "59") 1185 w.Header().Set(headerRateUsed, "1") 1186 w.Header().Set(headerRateReset, "1372700873") 1187 w.Header().Set(headerRateResource, "core") 1188 }) 1189 1190 req, _ := client.NewRequest("GET", ".", nil) 1191 ctx := context.Background() 1192 resp, err := client.Do(ctx, req, nil) 1193 if err != nil { 1194 t.Errorf("Do returned unexpected error: %v", err) 1195 } 1196 if got, want := resp.Rate.Limit, 60; got != want { 1197 t.Errorf("Client rate limit = %v, want %v", got, want) 1198 } 1199 if got, want := resp.Rate.Remaining, 59; got != want { 1200 t.Errorf("Client rate remaining = %v, 
want %v", got, want) 1201 } 1202 if got, want := resp.Rate.Used, 1; got != want { 1203 t.Errorf("Client rate used = %v, want %v", got, want) 1204 } 1205 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1206 if !resp.Rate.Reset.UTC().Equal(reset) { 1207 t.Errorf("Client rate reset = %v, want %v", resp.Rate.Reset.UTC(), reset) 1208 } 1209 if got, want := resp.Rate.Resource, "core"; got != want { 1210 t.Errorf("Client rate resource = %v, want %v", got, want) 1211 } 1212 } 1213 1214 func TestDo_rateLimitCategory(t *testing.T) { 1215 t.Parallel() 1216 tests := []struct { 1217 method string 1218 url string 1219 category RateLimitCategory 1220 }{ 1221 { 1222 method: http.MethodGet, 1223 url: "/", 1224 category: CoreCategory, 1225 }, 1226 { 1227 method: http.MethodGet, 1228 url: "/search/issues?q=rate", 1229 category: SearchCategory, 1230 }, 1231 { 1232 method: http.MethodGet, 1233 url: "/graphql", 1234 category: GraphqlCategory, 1235 }, 1236 { 1237 method: http.MethodPost, 1238 url: "/app-manifests/code/conversions", 1239 category: IntegrationManifestCategory, 1240 }, 1241 { 1242 method: http.MethodGet, 1243 url: "/app-manifests/code/conversions", 1244 category: CoreCategory, // only POST requests are in the integration manifest category 1245 }, 1246 { 1247 method: http.MethodPut, 1248 url: "/repos/google/go-github/import", 1249 category: SourceImportCategory, 1250 }, 1251 { 1252 method: http.MethodGet, 1253 url: "/repos/google/go-github/import", 1254 category: CoreCategory, // only PUT requests are in the source import category 1255 }, 1256 { 1257 method: http.MethodPost, 1258 url: "/repos/google/go-github/code-scanning/sarifs", 1259 category: CodeScanningUploadCategory, 1260 }, 1261 { 1262 method: http.MethodGet, 1263 url: "/scim/v2/organizations/ORG/Users", 1264 category: ScimCategory, 1265 }, 1266 { 1267 method: http.MethodPost, 1268 url: "/repos/google/go-github/dependency-graph/snapshots", 1269 category: DependencySnapshotsCategory, 1270 }, 1271 { 
1272 method: http.MethodGet, 1273 url: "/search/code?q=rate", 1274 category: CodeSearchCategory, 1275 }, 1276 { 1277 method: http.MethodGet, 1278 url: "/orgs/google/audit-log", 1279 category: AuditLogCategory, 1280 }, 1281 // missing a check for actionsRunnerRegistrationCategory: API not found 1282 } 1283 1284 for _, tt := range tests { 1285 if got, want := GetRateLimitCategory(tt.method, tt.url), tt.category; got != want { 1286 t.Errorf("expecting category %v, found %v", got, want) 1287 } 1288 } 1289 } 1290 1291 // Ensure rate limit is still parsed, even for error responses. 1292 func TestDo_rateLimit_errorResponse(t *testing.T) { 1293 t.Parallel() 1294 client, mux, _ := setup(t) 1295 1296 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1297 w.Header().Set(headerRateLimit, "60") 1298 w.Header().Set(headerRateRemaining, "59") 1299 w.Header().Set(headerRateUsed, "1") 1300 w.Header().Set(headerRateReset, "1372700873") 1301 w.Header().Set(headerRateResource, "core") 1302 http.Error(w, "Bad Request", 400) 1303 }) 1304 1305 req, _ := client.NewRequest("GET", ".", nil) 1306 ctx := context.Background() 1307 resp, err := client.Do(ctx, req, nil) 1308 if err == nil { 1309 t.Error("Expected error to be returned.") 1310 } 1311 if _, ok := err.(*RateLimitError); ok { 1312 t.Errorf("Did not expect a *RateLimitError error; got %#v.", err) 1313 } 1314 if got, want := resp.Rate.Limit, 60; got != want { 1315 t.Errorf("Client rate limit = %v, want %v", got, want) 1316 } 1317 if got, want := resp.Rate.Remaining, 59; got != want { 1318 t.Errorf("Client rate remaining = %v, want %v", got, want) 1319 } 1320 if got, want := resp.Rate.Used, 1; got != want { 1321 t.Errorf("Client rate used = %v, want %v", got, want) 1322 } 1323 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1324 if !resp.Rate.Reset.UTC().Equal(reset) { 1325 t.Errorf("Client rate reset = %v, want %v", resp.Rate.Reset, reset) 1326 } 1327 if got, want := resp.Rate.Resource, "core"; got != 
want { 1328 t.Errorf("Client rate resource = %v, want %v", got, want) 1329 } 1330 } 1331 1332 // Ensure *RateLimitError is returned when API rate limit is exceeded. 1333 func TestDo_rateLimit_rateLimitError(t *testing.T) { 1334 t.Parallel() 1335 client, mux, _ := setup(t) 1336 1337 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1338 w.Header().Set(headerRateLimit, "60") 1339 w.Header().Set(headerRateRemaining, "0") 1340 w.Header().Set(headerRateUsed, "60") 1341 w.Header().Set(headerRateReset, "1372700873") 1342 w.Header().Set(headerRateResource, "core") 1343 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1344 w.WriteHeader(http.StatusForbidden) 1345 fmt.Fprintln(w, `{ 1346 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1347 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1348 }`) 1349 }) 1350 1351 req, _ := client.NewRequest("GET", ".", nil) 1352 ctx := context.Background() 1353 _, err := client.Do(ctx, req, nil) 1354 1355 if err == nil { 1356 t.Error("Expected error to be returned.") 1357 } 1358 rateLimitErr, ok := err.(*RateLimitError) 1359 if !ok { 1360 t.Fatalf("Expected a *RateLimitError error; got %#v.", err) 1361 } 1362 if got, want := rateLimitErr.Rate.Limit, 60; got != want { 1363 t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) 1364 } 1365 if got, want := rateLimitErr.Rate.Remaining, 0; got != want { 1366 t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want) 1367 } 1368 if got, want := rateLimitErr.Rate.Used, 60; got != want { 1369 t.Errorf("rateLimitErr rate used = %v, want %v", got, want) 1370 } 1371 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1372 if !rateLimitErr.Rate.Reset.UTC().Equal(reset) { 1373 t.Errorf("rateLimitErr rate reset = %v, want %v", rateLimitErr.Rate.Reset.UTC(), 
reset) 1374 } 1375 if got, want := rateLimitErr.Rate.Resource, "core"; got != want { 1376 t.Errorf("rateLimitErr rate resource = %v, want %v", got, want) 1377 } 1378 } 1379 1380 // Ensure a network call is not made when it's known that API rate limit is still exceeded. 1381 func TestDo_rateLimit_noNetworkCall(t *testing.T) { 1382 t.Parallel() 1383 client, mux, _ := setup(t) 1384 1385 reset := time.Now().UTC().Add(time.Minute).Round(time.Second) // Rate reset is a minute from now, with 1 second precision. 1386 1387 mux.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { 1388 w.Header().Set(headerRateLimit, "60") 1389 w.Header().Set(headerRateRemaining, "0") 1390 w.Header().Set(headerRateUsed, "60") 1391 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1392 w.Header().Set(headerRateResource, "core") 1393 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1394 w.WriteHeader(http.StatusForbidden) 1395 fmt.Fprintln(w, `{ 1396 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1397 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1398 }`) 1399 }) 1400 1401 madeNetworkCall := false 1402 mux.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { 1403 madeNetworkCall = true 1404 }) 1405 1406 // First request is made, and it makes the client aware of rate reset time being in the future. 1407 req, _ := client.NewRequest("GET", "first", nil) 1408 ctx := context.Background() 1409 _, err := client.Do(ctx, req, nil) 1410 if err == nil { 1411 t.Error("Expected error to be returned.") 1412 } 1413 1414 // Second request should not cause a network call to be made, since client can predict a rate limit error. 
1415 req, _ = client.NewRequest("GET", "second", nil) 1416 _, err = client.Do(ctx, req, nil) 1417 1418 if madeNetworkCall { 1419 t.Fatal("Network call was made, even though rate limit is known to still be exceeded.") 1420 } 1421 1422 if err == nil { 1423 t.Error("Expected error to be returned.") 1424 } 1425 rateLimitErr, ok := err.(*RateLimitError) 1426 if !ok { 1427 t.Fatalf("Expected a *RateLimitError error; got %#v.", err) 1428 } 1429 if got, want := rateLimitErr.Rate.Limit, 60; got != want { 1430 t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) 1431 } 1432 if got, want := rateLimitErr.Rate.Remaining, 0; got != want { 1433 t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want) 1434 } 1435 if got, want := rateLimitErr.Rate.Used, 60; got != want { 1436 t.Errorf("rateLimitErr rate used = %v, want %v", got, want) 1437 } 1438 if !rateLimitErr.Rate.Reset.UTC().Equal(reset) { 1439 t.Errorf("rateLimitErr rate reset = %v, want %v", rateLimitErr.Rate.Reset.UTC(), reset) 1440 } 1441 if got, want := rateLimitErr.Rate.Resource, "core"; got != want { 1442 t.Errorf("rateLimitErr rate resource = %v, want %v", got, want) 1443 } 1444 } 1445 1446 // Ignore rate limit headers if the response was served from cache. 1447 func TestDo_rateLimit_ignoredFromCache(t *testing.T) { 1448 t.Parallel() 1449 client, mux, _ := setup(t) 1450 1451 reset := time.Now().UTC().Add(time.Minute).Round(time.Second) // Rate reset is a minute from now, with 1 second precision. 1452 1453 // By adding the X-From-Cache header we pretend this is served from a cache. 
1454 mux.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { 1455 w.Header().Set("X-From-Cache", "1") 1456 w.Header().Set(headerRateLimit, "60") 1457 w.Header().Set(headerRateRemaining, "0") 1458 w.Header().Set(headerRateUsed, "60") 1459 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1460 w.Header().Set(headerRateResource, "core") 1461 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1462 w.WriteHeader(http.StatusForbidden) 1463 fmt.Fprintln(w, `{ 1464 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1465 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1466 }`) 1467 }) 1468 1469 madeNetworkCall := false 1470 mux.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { 1471 madeNetworkCall = true 1472 }) 1473 1474 // First request is made so afterwards we can check the returned rate limit headers were ignored. 1475 req, _ := client.NewRequest("GET", "first", nil) 1476 ctx := context.Background() 1477 _, err := client.Do(ctx, req, nil) 1478 if err == nil { 1479 t.Error("Expected error to be returned.") 1480 } 1481 1482 // Second request should not by hindered by rate limits. 1483 req, _ = client.NewRequest("GET", "second", nil) 1484 _, err = client.Do(ctx, req, nil) 1485 1486 if err != nil { 1487 t.Fatalf("Second request failed, even though the rate limits from the cache should've been ignored: %v", err) 1488 } 1489 if !madeNetworkCall { 1490 t.Fatal("Network call was not made, even though the rate limits from the cache should've been ignored") 1491 } 1492 } 1493 1494 // Ensure sleeps until the rate limit is reset when the client is rate limited. 
1495 func TestDo_rateLimit_sleepUntilResponseResetLimit(t *testing.T) { 1496 t.Parallel() 1497 client, mux, _ := setup(t) 1498 1499 reset := time.Now().UTC().Add(time.Second) 1500 1501 var firstRequest = true 1502 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1503 if firstRequest { 1504 firstRequest = false 1505 w.Header().Set(headerRateLimit, "60") 1506 w.Header().Set(headerRateRemaining, "0") 1507 w.Header().Set(headerRateUsed, "60") 1508 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1509 w.Header().Set(headerRateResource, "core") 1510 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1511 w.WriteHeader(http.StatusForbidden) 1512 fmt.Fprintln(w, `{ 1513 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1514 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1515 }`) 1516 return 1517 } 1518 w.Header().Set(headerRateLimit, "5000") 1519 w.Header().Set(headerRateRemaining, "5000") 1520 w.Header().Set(headerRateUsed, "0") 1521 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1522 w.Header().Set(headerRateResource, "core") 1523 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1524 w.WriteHeader(http.StatusOK) 1525 fmt.Fprintln(w, `{}`) 1526 }) 1527 1528 req, _ := client.NewRequest("GET", ".", nil) 1529 ctx := context.Background() 1530 resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1531 if err != nil { 1532 t.Errorf("Do returned unexpected error: %v", err) 1533 } 1534 if got, want := resp.StatusCode, http.StatusOK; got != want { 1535 t.Errorf("Response status code = %v, want %v", got, want) 1536 } 1537 } 1538 1539 // Ensure tries to sleep until the rate limit is reset when the client is rate limited, but only once. 
1540 func TestDo_rateLimit_sleepUntilResponseResetLimitRetryOnce(t *testing.T) { 1541 t.Parallel() 1542 client, mux, _ := setup(t) 1543 1544 reset := time.Now().UTC().Add(time.Second) 1545 1546 requestCount := 0 1547 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1548 requestCount++ 1549 w.Header().Set(headerRateLimit, "60") 1550 w.Header().Set(headerRateRemaining, "0") 1551 w.Header().Set(headerRateUsed, "60") 1552 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1553 w.Header().Set(headerRateResource, "core") 1554 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1555 w.WriteHeader(http.StatusForbidden) 1556 fmt.Fprintln(w, `{ 1557 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1558 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1559 }`) 1560 }) 1561 1562 req, _ := client.NewRequest("GET", ".", nil) 1563 ctx := context.Background() 1564 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1565 if err == nil { 1566 t.Error("Expected error to be returned.") 1567 } 1568 if got, want := requestCount, 2; got != want { 1569 t.Errorf("Expected 2 requests, got %d", got) 1570 } 1571 } 1572 1573 // Ensure a network call is not made when it's known that API rate limit is still exceeded. 
1574 func TestDo_rateLimit_sleepUntilClientResetLimit(t *testing.T) { 1575 t.Parallel() 1576 client, mux, _ := setup(t) 1577 1578 reset := time.Now().UTC().Add(time.Second) 1579 client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} 1580 requestCount := 0 1581 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1582 requestCount++ 1583 w.Header().Set(headerRateLimit, "5000") 1584 w.Header().Set(headerRateRemaining, "5000") 1585 w.Header().Set(headerRateUsed, "0") 1586 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1587 w.Header().Set(headerRateResource, "core") 1588 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1589 w.WriteHeader(http.StatusOK) 1590 fmt.Fprintln(w, `{}`) 1591 }) 1592 req, _ := client.NewRequest("GET", ".", nil) 1593 ctx := context.Background() 1594 resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1595 if err != nil { 1596 t.Errorf("Do returned unexpected error: %v", err) 1597 } 1598 if got, want := resp.StatusCode, http.StatusOK; got != want { 1599 t.Errorf("Response status code = %v, want %v", got, want) 1600 } 1601 if got, want := requestCount, 1; got != want { 1602 t.Errorf("Expected 1 request, got %d", got) 1603 } 1604 } 1605 1606 // Ensure sleep is aborted when the context is cancelled. 1607 func TestDo_rateLimit_abortSleepContextCancelled(t *testing.T) { 1608 t.Parallel() 1609 client, mux, _ := setup(t) 1610 1611 // We use a 1 minute reset time to ensure the sleep is not completed. 
1612 reset := time.Now().UTC().Add(time.Minute) 1613 requestCount := 0 1614 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1615 requestCount++ 1616 w.Header().Set(headerRateLimit, "60") 1617 w.Header().Set(headerRateRemaining, "0") 1618 w.Header().Set(headerRateUsed, "60") 1619 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1620 w.Header().Set(headerRateResource, "core") 1621 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1622 w.WriteHeader(http.StatusForbidden) 1623 fmt.Fprintln(w, `{ 1624 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1625 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1626 }`) 1627 }) 1628 1629 req, _ := client.NewRequest("GET", ".", nil) 1630 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 1631 defer cancel() 1632 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1633 if !errors.Is(err, context.DeadlineExceeded) { 1634 t.Error("Expected context deadline exceeded error.") 1635 } 1636 if got, want := requestCount, 1; got != want { 1637 t.Errorf("Expected 1 requests, got %d", got) 1638 } 1639 } 1640 1641 // Ensure sleep is aborted when the context is cancelled on initial request. 
1642 func TestDo_rateLimit_abortSleepContextCancelledClientLimit(t *testing.T) { 1643 t.Parallel() 1644 client, mux, _ := setup(t) 1645 1646 reset := time.Now().UTC().Add(time.Minute) 1647 client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} 1648 requestCount := 0 1649 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1650 requestCount++ 1651 w.Header().Set(headerRateLimit, "5000") 1652 w.Header().Set(headerRateRemaining, "5000") 1653 w.Header().Set(headerRateUsed, "0") 1654 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1655 w.Header().Set(headerRateResource, "core") 1656 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1657 w.WriteHeader(http.StatusOK) 1658 fmt.Fprintln(w, `{}`) 1659 }) 1660 req, _ := client.NewRequest("GET", ".", nil) 1661 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 1662 defer cancel() 1663 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1664 rateLimitError, ok := err.(*RateLimitError) 1665 if !ok { 1666 t.Fatalf("Expected a *rateLimitError error; got %#v.", err) 1667 } 1668 if got, wantSuffix := rateLimitError.Message, "Context cancelled while waiting for rate limit to reset until"; !strings.HasPrefix(got, wantSuffix) { 1669 t.Errorf("Expected request to be prevented because context cancellation, got: %v.", got) 1670 } 1671 if got, want := requestCount, 0; got != want { 1672 t.Errorf("Expected 1 requests, got %d", got) 1673 } 1674 } 1675 1676 // Ensure *AbuseRateLimitError is returned when the response indicates that 1677 // the client has triggered an abuse detection mechanism. 
1678 func TestDo_rateLimit_abuseRateLimitError(t *testing.T) { 1679 t.Parallel() 1680 client, mux, _ := setup(t) 1681 1682 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1683 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1684 w.WriteHeader(http.StatusForbidden) 1685 // When the abuse rate limit error is of the "temporarily blocked from content creation" type, 1686 // there is no "Retry-After" header. 1687 fmt.Fprintln(w, `{ 1688 "message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.", 1689 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1690 }`) 1691 }) 1692 1693 req, _ := client.NewRequest("GET", ".", nil) 1694 ctx := context.Background() 1695 _, err := client.Do(ctx, req, nil) 1696 1697 if err == nil { 1698 t.Error("Expected error to be returned.") 1699 } 1700 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1701 if !ok { 1702 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1703 } 1704 if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { 1705 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1706 } 1707 } 1708 1709 // Ensure *AbuseRateLimitError is returned when the response indicates that 1710 // the client has triggered an abuse detection mechanism on GitHub Enterprise. 1711 func TestDo_rateLimit_abuseRateLimitErrorEnterprise(t *testing.T) { 1712 t.Parallel() 1713 client, mux, _ := setup(t) 1714 1715 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1716 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1717 w.WriteHeader(http.StatusForbidden) 1718 // When the abuse rate limit error is of the "temporarily blocked from content creation" type, 1719 // there is no "Retry-After" header. 
1720 // This response returns a documentation url like the one returned for GitHub Enterprise, this 1721 // url changes between versions but follows roughly the same format. 1722 fmt.Fprintln(w, `{ 1723 "message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.", 1724 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1725 }`) 1726 }) 1727 1728 req, _ := client.NewRequest("GET", ".", nil) 1729 ctx := context.Background() 1730 _, err := client.Do(ctx, req, nil) 1731 1732 if err == nil { 1733 t.Error("Expected error to be returned.") 1734 } 1735 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1736 if !ok { 1737 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1738 } 1739 if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { 1740 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1741 } 1742 } 1743 1744 // Ensure *AbuseRateLimitError.RetryAfter is parsed correctly for the Retry-After header. 1745 func TestDo_rateLimit_abuseRateLimitError_retryAfter(t *testing.T) { 1746 t.Parallel() 1747 client, mux, _ := setup(t) 1748 1749 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1750 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1751 w.Header().Set(headerRetryAfter, "123") // Retry after value of 123 seconds. 
1752 w.WriteHeader(http.StatusForbidden) 1753 fmt.Fprintln(w, `{ 1754 "message": "You have triggered an abuse detection mechanism ...", 1755 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1756 }`) 1757 }) 1758 1759 req, _ := client.NewRequest("GET", ".", nil) 1760 ctx := context.Background() 1761 _, err := client.Do(ctx, req, nil) 1762 1763 if err == nil { 1764 t.Error("Expected error to be returned.") 1765 } 1766 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1767 if !ok { 1768 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1769 } 1770 if abuseRateLimitErr.RetryAfter == nil { 1771 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1772 } 1773 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; got != want { 1774 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1775 } 1776 1777 // expect prevention of a following request 1778 if _, err = client.Do(ctx, req, nil); err == nil { 1779 t.Error("Expected error to be returned.") 1780 } 1781 abuseRateLimitErr, ok = err.(*AbuseRateLimitError) 1782 if !ok { 1783 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1784 } 1785 if abuseRateLimitErr.RetryAfter == nil { 1786 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1787 } 1788 // the saved duration might be a bit smaller than Retry-After because the duration is calculated from the expected end-of-cooldown time 1789 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1790 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1791 } 1792 if got, wantSuffix := abuseRateLimitErr.Message, "not making remote request."; !strings.HasSuffix(got, wantSuffix) { 1793 t.Errorf("Expected request to be prevented because of secondary rate limit, got: %v.", got) 1794 } 1795 } 1796 1797 // Ensure *AbuseRateLimitError.RetryAfter is parsed correctly for the x-ratelimit-reset header. 
1798 func TestDo_rateLimit_abuseRateLimitError_xRateLimitReset(t *testing.T) { 1799 t.Parallel() 1800 client, mux, _ := setup(t) 1801 1802 // x-ratelimit-reset value of 123 seconds into the future. 1803 blockUntil := time.Now().Add(time.Duration(123) * time.Second).Unix() 1804 1805 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1806 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1807 w.Header().Set(headerRateReset, strconv.Itoa(int(blockUntil))) 1808 w.Header().Set(headerRateRemaining, "1") // set remaining to a value > 0 to distinct from a primary rate limit 1809 w.WriteHeader(http.StatusForbidden) 1810 fmt.Fprintln(w, `{ 1811 "message": "You have triggered an abuse detection mechanism ...", 1812 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1813 }`) 1814 }) 1815 1816 req, _ := client.NewRequest("GET", ".", nil) 1817 ctx := context.Background() 1818 _, err := client.Do(ctx, req, nil) 1819 1820 if err == nil { 1821 t.Error("Expected error to be returned.") 1822 } 1823 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1824 if !ok { 1825 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1826 } 1827 if abuseRateLimitErr.RetryAfter == nil { 1828 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1829 } 1830 // the retry after value might be a bit smaller than the original duration because the duration is calculated from the expected end-of-cooldown time 1831 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1832 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1833 } 1834 1835 // expect prevention of a following request 1836 if _, err = client.Do(ctx, req, nil); err == nil { 1837 t.Error("Expected error to be returned.") 1838 } 1839 abuseRateLimitErr, ok = err.(*AbuseRateLimitError) 1840 if !ok { 1841 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1842 } 1843 if 
abuseRateLimitErr.RetryAfter == nil { 1844 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1845 } 1846 // the saved duration might be a bit smaller than Retry-After because the duration is calculated from the expected end-of-cooldown time 1847 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1848 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1849 } 1850 if got, wantSuffix := abuseRateLimitErr.Message, "not making remote request."; !strings.HasSuffix(got, wantSuffix) { 1851 t.Errorf("Expected request to be prevented because of secondary rate limit, got: %v.", got) 1852 } 1853 } 1854 1855 // Ensure *AbuseRateLimitError.RetryAfter respect a max duration if specified. 1856 func TestDo_rateLimit_abuseRateLimitError_maxDuration(t *testing.T) { 1857 t.Parallel() 1858 client, mux, _ := setup(t) 1859 // specify a max retry after duration of 1 min 1860 client.MaxSecondaryRateLimitRetryAfterDuration = 60 * time.Second 1861 1862 // x-ratelimit-reset value of 1h into the future, to make sure we are way over the max wait time duration. 
1863 blockUntil := time.Now().Add(1 * time.Hour).Unix() 1864 1865 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1866 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1867 w.Header().Set(headerRateReset, strconv.Itoa(int(blockUntil))) 1868 w.Header().Set(headerRateRemaining, "1") // set remaining to a value > 0 to distinct from a primary rate limit 1869 w.WriteHeader(http.StatusForbidden) 1870 fmt.Fprintln(w, `{ 1871 "message": "You have triggered an abuse detection mechanism ...", 1872 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1873 }`) 1874 }) 1875 1876 req, _ := client.NewRequest("GET", ".", nil) 1877 ctx := context.Background() 1878 _, err := client.Do(ctx, req, nil) 1879 1880 if err == nil { 1881 t.Error("Expected error to be returned.") 1882 } 1883 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1884 if !ok { 1885 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1886 } 1887 if abuseRateLimitErr.RetryAfter == nil { 1888 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1889 } 1890 // check that the retry after is set to be the max allowed duration 1891 if got, want := *abuseRateLimitErr.RetryAfter, client.MaxSecondaryRateLimitRetryAfterDuration; got != want { 1892 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1893 } 1894 } 1895 1896 func TestDo_noContent(t *testing.T) { 1897 t.Parallel() 1898 client, mux, _ := setup(t) 1899 1900 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1901 w.WriteHeader(http.StatusNoContent) 1902 }) 1903 1904 var body json.RawMessage 1905 1906 req, _ := client.NewRequest("GET", ".", nil) 1907 ctx := context.Background() 1908 _, err := client.Do(ctx, req, &body) 1909 if err != nil { 1910 t.Fatalf("Do returned unexpected error: %v", err) 1911 } 1912 } 1913 1914 func TestBareDoUntilFound_redirectLoop(t *testing.T) { 1915 t.Parallel() 1916 client, mux, _ := setup(t) 
1917 1918 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1919 http.Redirect(w, r, baseURLPath, http.StatusMovedPermanently) 1920 }) 1921 1922 req, _ := client.NewRequest("GET", ".", nil) 1923 ctx := context.Background() 1924 _, _, err := client.bareDoUntilFound(ctx, req, 1) 1925 1926 if err == nil { 1927 t.Error("Expected error to be returned.") 1928 } 1929 var rerr *RedirectionError 1930 if !errors.As(err, &rerr) { 1931 t.Errorf("Expected a Redirection error; got %#v.", err) 1932 } 1933 } 1934 1935 func TestBareDoUntilFound_UnexpectedRedirection(t *testing.T) { 1936 t.Parallel() 1937 client, mux, _ := setup(t) 1938 1939 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1940 http.Redirect(w, r, baseURLPath, http.StatusSeeOther) 1941 }) 1942 1943 req, _ := client.NewRequest("GET", ".", nil) 1944 ctx := context.Background() 1945 _, _, err := client.bareDoUntilFound(ctx, req, 1) 1946 1947 if err == nil { 1948 t.Error("Expected error to be returned.") 1949 } 1950 var rerr *RedirectionError 1951 if !errors.As(err, &rerr) { 1952 t.Errorf("Expected a Redirection error; got %#v.", err) 1953 } 1954 } 1955 1956 func TestSanitizeURL(t *testing.T) { 1957 t.Parallel() 1958 tests := []struct { 1959 in, want string 1960 }{ 1961 {"/?a=b", "/?a=b"}, 1962 {"/?a=b&client_secret=secret", "/?a=b&client_secret=REDACTED"}, 1963 {"/?a=b&client_id=id&client_secret=secret", "/?a=b&client_id=id&client_secret=REDACTED"}, 1964 } 1965 1966 for _, tt := range tests { 1967 inURL, _ := url.Parse(tt.in) 1968 want, _ := url.Parse(tt.want) 1969 1970 if got := sanitizeURL(inURL); !cmp.Equal(got, want) { 1971 t.Errorf("sanitizeURL(%v) returned %v, want %v", tt.in, got, want) 1972 } 1973 } 1974 } 1975 1976 func TestCheckResponse(t *testing.T) { 1977 t.Parallel() 1978 res := &http.Response{ 1979 Request: &http.Request{}, 1980 StatusCode: http.StatusBadRequest, 1981 Body: io.NopCloser(strings.NewReader(`{"message":"m", 1982 "errors": [{"resource": "r", "field": "f", 
"code": "c"}], 1983 "block": {"reason": "dmca", "created_at": "2016-03-17T15:39:46Z"}}`)), 1984 } 1985 err := CheckResponse(res).(*ErrorResponse) 1986 1987 if err == nil { 1988 t.Errorf("Expected error response.") 1989 } 1990 1991 want := &ErrorResponse{ 1992 Response: res, 1993 Message: "m", 1994 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 1995 Block: &ErrorBlock{ 1996 Reason: "dmca", 1997 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 1998 }, 1999 } 2000 if !errors.Is(err, want) { 2001 t.Errorf("Error = %#v, want %#v", err, want) 2002 } 2003 } 2004 2005 func TestCheckResponse_RateLimit(t *testing.T) { 2006 t.Parallel() 2007 res := &http.Response{ 2008 Request: &http.Request{}, 2009 StatusCode: http.StatusForbidden, 2010 Header: http.Header{}, 2011 Body: io.NopCloser(strings.NewReader(`{"message":"m", 2012 "documentation_url": "url"}`)), 2013 } 2014 res.Header.Set(headerRateLimit, "60") 2015 res.Header.Set(headerRateRemaining, "0") 2016 res.Header.Set(headerRateUsed, "1") 2017 res.Header.Set(headerRateReset, "243424") 2018 res.Header.Set(headerRateResource, "core") 2019 2020 err := CheckResponse(res).(*RateLimitError) 2021 2022 if err == nil { 2023 t.Errorf("Expected error response.") 2024 } 2025 2026 want := &RateLimitError{ 2027 Rate: parseRate(res), 2028 Response: res, 2029 Message: "m", 2030 } 2031 if !errors.Is(err, want) { 2032 t.Errorf("Error = %#v, want %#v", err, want) 2033 } 2034 } 2035 2036 func TestCheckResponse_AbuseRateLimit(t *testing.T) { 2037 t.Parallel() 2038 res := &http.Response{ 2039 Request: &http.Request{}, 2040 StatusCode: http.StatusForbidden, 2041 Body: io.NopCloser(strings.NewReader(`{"message":"m", 2042 "documentation_url": "docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"}`)), 2043 } 2044 err := CheckResponse(res).(*AbuseRateLimitError) 2045 2046 if err == nil { 2047 t.Errorf("Expected error response.") 2048 } 2049 2050 want := &AbuseRateLimitError{ 2051 
Response: res, 2052 Message: "m", 2053 } 2054 if !errors.Is(err, want) { 2055 t.Errorf("Error = %#v, want %#v", err, want) 2056 } 2057 } 2058 2059 func TestCheckResponse_RedirectionError(t *testing.T) { 2060 t.Parallel() 2061 urlStr := "/foo/bar" 2062 2063 res := &http.Response{ 2064 Request: &http.Request{}, 2065 StatusCode: http.StatusFound, 2066 Header: http.Header{}, 2067 Body: io.NopCloser(strings.NewReader(``)), 2068 } 2069 res.Header.Set("Location", urlStr) 2070 err := CheckResponse(res).(*RedirectionError) 2071 2072 if err == nil { 2073 t.Errorf("Expected error response.") 2074 } 2075 2076 wantedURL, parseErr := url.Parse(urlStr) 2077 if parseErr != nil { 2078 t.Errorf("Error parsing fixture url: %v", parseErr) 2079 } 2080 2081 want := &RedirectionError{ 2082 Response: res, 2083 StatusCode: http.StatusFound, 2084 Location: wantedURL, 2085 } 2086 if !errors.Is(err, want) { 2087 t.Errorf("Error = %#v, want %#v", err, want) 2088 } 2089 } 2090 2091 func TestCompareHttpResponse(t *testing.T) { 2092 t.Parallel() 2093 testcases := map[string]struct { 2094 h1 *http.Response 2095 h2 *http.Response 2096 expected bool 2097 }{ 2098 "both are nil": { 2099 expected: true, 2100 }, 2101 "both are non nil - same StatusCode": { 2102 expected: true, 2103 h1: &http.Response{StatusCode: 200}, 2104 h2: &http.Response{StatusCode: 200}, 2105 }, 2106 "both are non nil - different StatusCode": { 2107 expected: false, 2108 h1: &http.Response{StatusCode: 200}, 2109 h2: &http.Response{StatusCode: 404}, 2110 }, 2111 "one is nil, other is not": { 2112 expected: false, 2113 h2: &http.Response{}, 2114 }, 2115 } 2116 2117 for name, tc := range testcases { 2118 tc := tc 2119 t.Run(name, func(t *testing.T) { 2120 t.Parallel() 2121 v := compareHTTPResponse(tc.h1, tc.h2) 2122 if tc.expected != v { 2123 t.Errorf("Expected %t, got %t for (%#v, %#v)", tc.expected, v, tc.h1, tc.h2) 2124 } 2125 }) 2126 } 2127 } 2128 2129 func TestErrorResponse_Is(t *testing.T) { 2130 t.Parallel() 2131 err := 
&ErrorResponse{ 2132 Response: &http.Response{}, 2133 Message: "m", 2134 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2135 Block: &ErrorBlock{ 2136 Reason: "r", 2137 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2138 }, 2139 DocumentationURL: "https://github.com", 2140 } 2141 testcases := map[string]struct { 2142 wantSame bool 2143 otherError error 2144 }{ 2145 "errors are same": { 2146 wantSame: true, 2147 otherError: &ErrorResponse{ 2148 Response: &http.Response{}, 2149 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2150 Message: "m", 2151 Block: &ErrorBlock{ 2152 Reason: "r", 2153 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2154 }, 2155 DocumentationURL: "https://github.com", 2156 }, 2157 }, 2158 "errors have different values - Message": { 2159 wantSame: false, 2160 otherError: &ErrorResponse{ 2161 Response: &http.Response{}, 2162 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2163 Message: "m1", 2164 Block: &ErrorBlock{ 2165 Reason: "r", 2166 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2167 }, 2168 DocumentationURL: "https://github.com", 2169 }, 2170 }, 2171 "errors have different values - DocumentationURL": { 2172 wantSame: false, 2173 otherError: &ErrorResponse{ 2174 Response: &http.Response{}, 2175 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2176 Message: "m", 2177 Block: &ErrorBlock{ 2178 Reason: "r", 2179 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2180 }, 2181 DocumentationURL: "https://google.com", 2182 }, 2183 }, 2184 "errors have different values - Response is nil": { 2185 wantSame: false, 2186 otherError: &ErrorResponse{ 2187 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2188 Message: "m", 2189 Block: &ErrorBlock{ 2190 Reason: "r", 2191 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2192 }, 2193 DocumentationURL: 
"https://github.com", 2194 }, 2195 }, 2196 "errors have different values - Errors": { 2197 wantSame: false, 2198 otherError: &ErrorResponse{ 2199 Response: &http.Response{}, 2200 Errors: []Error{{Resource: "r1", Field: "f1", Code: "c1"}}, 2201 Message: "m", 2202 Block: &ErrorBlock{ 2203 Reason: "r", 2204 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2205 }, 2206 DocumentationURL: "https://github.com", 2207 }, 2208 }, 2209 "errors have different values - Errors have different length": { 2210 wantSame: false, 2211 otherError: &ErrorResponse{ 2212 Response: &http.Response{}, 2213 Errors: []Error{}, 2214 Message: "m", 2215 Block: &ErrorBlock{ 2216 Reason: "r", 2217 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2218 }, 2219 DocumentationURL: "https://github.com", 2220 }, 2221 }, 2222 "errors have different values - Block - one is nil, other is not": { 2223 wantSame: false, 2224 otherError: &ErrorResponse{ 2225 Response: &http.Response{}, 2226 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2227 Message: "m", 2228 DocumentationURL: "https://github.com", 2229 }, 2230 }, 2231 "errors have different values - Block - different Reason": { 2232 wantSame: false, 2233 otherError: &ErrorResponse{ 2234 Response: &http.Response{}, 2235 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2236 Message: "m", 2237 Block: &ErrorBlock{ 2238 Reason: "r1", 2239 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2240 }, 2241 DocumentationURL: "https://github.com", 2242 }, 2243 }, 2244 "errors have different values - Block - different CreatedAt #1": { 2245 wantSame: false, 2246 otherError: &ErrorResponse{ 2247 Response: &http.Response{}, 2248 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2249 Message: "m", 2250 Block: &ErrorBlock{ 2251 Reason: "r", 2252 CreatedAt: nil, 2253 }, 2254 DocumentationURL: "https://github.com", 2255 }, 2256 }, 2257 "errors have different 
values - Block - different CreatedAt #2": { 2258 wantSame: false, 2259 otherError: &ErrorResponse{ 2260 Response: &http.Response{}, 2261 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2262 Message: "m", 2263 Block: &ErrorBlock{ 2264 Reason: "r", 2265 CreatedAt: &Timestamp{time.Date(2017, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2266 }, 2267 DocumentationURL: "https://github.com", 2268 }, 2269 }, 2270 "errors have different types": { 2271 wantSame: false, 2272 otherError: errors.New("github"), 2273 }, 2274 } 2275 2276 for name, tc := range testcases { 2277 tc := tc 2278 t.Run(name, func(t *testing.T) { 2279 t.Parallel() 2280 if tc.wantSame != err.Is(tc.otherError) { 2281 t.Errorf("Error = %#v, want %#v", err, tc.otherError) 2282 } 2283 }) 2284 } 2285 } 2286 2287 func TestRateLimitError_Is(t *testing.T) { 2288 t.Parallel() 2289 err := &RateLimitError{ 2290 Response: &http.Response{}, 2291 Message: "Github", 2292 } 2293 testcases := map[string]struct { 2294 wantSame bool 2295 err *RateLimitError 2296 otherError error 2297 }{ 2298 "errors are same": { 2299 wantSame: true, 2300 err: err, 2301 otherError: &RateLimitError{ 2302 Response: &http.Response{}, 2303 Message: "Github", 2304 }, 2305 }, 2306 "errors are same - Response is nil": { 2307 wantSame: true, 2308 err: &RateLimitError{ 2309 Message: "Github", 2310 }, 2311 otherError: &RateLimitError{ 2312 Message: "Github", 2313 }, 2314 }, 2315 "errors have different values - Rate": { 2316 wantSame: false, 2317 err: err, 2318 otherError: &RateLimitError{ 2319 Rate: Rate{Limit: 10}, 2320 Response: &http.Response{}, 2321 Message: "Gitlab", 2322 }, 2323 }, 2324 "errors have different values - Response is nil": { 2325 wantSame: false, 2326 err: err, 2327 otherError: &RateLimitError{ 2328 Message: "Github", 2329 }, 2330 }, 2331 "errors have different values - StatusCode": { 2332 wantSame: false, 2333 err: err, 2334 otherError: &RateLimitError{ 2335 Response: &http.Response{StatusCode: 200}, 2336 Message: "Github", 
2337 }, 2338 }, 2339 "errors have different types": { 2340 wantSame: false, 2341 err: err, 2342 otherError: errors.New("github"), 2343 }, 2344 } 2345 2346 for name, tc := range testcases { 2347 tc := tc 2348 t.Run(name, func(t *testing.T) { 2349 t.Parallel() 2350 if tc.wantSame != tc.err.Is(tc.otherError) { 2351 t.Errorf("Error = %#v, want %#v", tc.err, tc.otherError) 2352 } 2353 }) 2354 } 2355 } 2356 2357 func TestAbuseRateLimitError_Is(t *testing.T) { 2358 t.Parallel() 2359 t1 := 1 * time.Second 2360 t2 := 2 * time.Second 2361 err := &AbuseRateLimitError{ 2362 Response: &http.Response{}, 2363 Message: "Github", 2364 RetryAfter: &t1, 2365 } 2366 testcases := map[string]struct { 2367 wantSame bool 2368 err *AbuseRateLimitError 2369 otherError error 2370 }{ 2371 "errors are same": { 2372 wantSame: true, 2373 err: err, 2374 otherError: &AbuseRateLimitError{ 2375 Response: &http.Response{}, 2376 Message: "Github", 2377 RetryAfter: &t1, 2378 }, 2379 }, 2380 "errors are same - Response is nil": { 2381 wantSame: true, 2382 err: &AbuseRateLimitError{ 2383 Message: "Github", 2384 RetryAfter: &t1, 2385 }, 2386 otherError: &AbuseRateLimitError{ 2387 Message: "Github", 2388 RetryAfter: &t1, 2389 }, 2390 }, 2391 "errors have different values - Message": { 2392 wantSame: false, 2393 err: err, 2394 otherError: &AbuseRateLimitError{ 2395 Response: &http.Response{}, 2396 Message: "Gitlab", 2397 RetryAfter: nil, 2398 }, 2399 }, 2400 "errors have different values - RetryAfter": { 2401 wantSame: false, 2402 err: err, 2403 otherError: &AbuseRateLimitError{ 2404 Response: &http.Response{}, 2405 Message: "Github", 2406 RetryAfter: &t2, 2407 }, 2408 }, 2409 "errors have different values - Response is nil": { 2410 wantSame: false, 2411 err: err, 2412 otherError: &AbuseRateLimitError{ 2413 Message: "Github", 2414 RetryAfter: &t1, 2415 }, 2416 }, 2417 "errors have different values - StatusCode": { 2418 wantSame: false, 2419 err: err, 2420 otherError: &AbuseRateLimitError{ 2421 Response: 
&http.Response{StatusCode: 200}, 2422 Message: "Github", 2423 RetryAfter: &t1, 2424 }, 2425 }, 2426 "errors have different types": { 2427 wantSame: false, 2428 err: err, 2429 otherError: errors.New("github"), 2430 }, 2431 } 2432 2433 for name, tc := range testcases { 2434 tc := tc 2435 t.Run(name, func(t *testing.T) { 2436 t.Parallel() 2437 if tc.wantSame != tc.err.Is(tc.otherError) { 2438 t.Errorf("Error = %#v, want %#v", tc.err, tc.otherError) 2439 } 2440 }) 2441 } 2442 } 2443 2444 func TestAcceptedError_Is(t *testing.T) { 2445 t.Parallel() 2446 err := &AcceptedError{Raw: []byte("Github")} 2447 testcases := map[string]struct { 2448 wantSame bool 2449 otherError error 2450 }{ 2451 "errors are same": { 2452 wantSame: true, 2453 otherError: &AcceptedError{Raw: []byte("Github")}, 2454 }, 2455 "errors have different values": { 2456 wantSame: false, 2457 otherError: &AcceptedError{Raw: []byte("Gitlab")}, 2458 }, 2459 "errors have different types": { 2460 wantSame: false, 2461 otherError: errors.New("github"), 2462 }, 2463 } 2464 2465 for name, tc := range testcases { 2466 tc := tc 2467 t.Run(name, func(t *testing.T) { 2468 t.Parallel() 2469 if tc.wantSame != err.Is(tc.otherError) { 2470 t.Errorf("Error = %#v, want %#v", err, tc.otherError) 2471 } 2472 }) 2473 } 2474 } 2475 2476 // Ensure that we properly handle API errors that do not contain a response body. 
2477 func TestCheckResponse_noBody(t *testing.T) { 2478 t.Parallel() 2479 res := &http.Response{ 2480 Request: &http.Request{}, 2481 StatusCode: http.StatusBadRequest, 2482 Body: io.NopCloser(strings.NewReader("")), 2483 } 2484 err := CheckResponse(res).(*ErrorResponse) 2485 2486 if err == nil { 2487 t.Errorf("Expected error response.") 2488 } 2489 2490 want := &ErrorResponse{ 2491 Response: res, 2492 } 2493 if !errors.Is(err, want) { 2494 t.Errorf("Error = %#v, want %#v", err, want) 2495 } 2496 } 2497 2498 func TestCheckResponse_unexpectedErrorStructure(t *testing.T) { 2499 t.Parallel() 2500 httpBody := `{"message":"m", "errors": ["error 1"]}` 2501 res := &http.Response{ 2502 Request: &http.Request{}, 2503 StatusCode: http.StatusBadRequest, 2504 Body: io.NopCloser(strings.NewReader(httpBody)), 2505 } 2506 err := CheckResponse(res).(*ErrorResponse) 2507 2508 if err == nil { 2509 t.Errorf("Expected error response.") 2510 } 2511 2512 want := &ErrorResponse{ 2513 Response: res, 2514 Message: "m", 2515 Errors: []Error{{Message: "error 1"}}, 2516 } 2517 if !errors.Is(err, want) { 2518 t.Errorf("Error = %#v, want %#v", err, want) 2519 } 2520 data, err2 := io.ReadAll(err.Response.Body) 2521 if err2 != nil { 2522 t.Fatalf("failed to read response body: %v", err) 2523 } 2524 if got := string(data); got != httpBody { 2525 t.Errorf("ErrorResponse.Response.Body = %q, want %q", got, httpBody) 2526 } 2527 } 2528 2529 func TestParseBooleanResponse_true(t *testing.T) { 2530 t.Parallel() 2531 result, err := parseBoolResponse(nil) 2532 if err != nil { 2533 t.Errorf("parseBoolResponse returned error: %+v", err) 2534 } 2535 2536 if want := true; result != want { 2537 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2538 } 2539 } 2540 2541 func TestParseBooleanResponse_false(t *testing.T) { 2542 t.Parallel() 2543 v := &ErrorResponse{Response: &http.Response{StatusCode: http.StatusNotFound}} 2544 result, err := parseBoolResponse(v) 2545 if err != nil { 2546 
t.Errorf("parseBoolResponse returned error: %+v", err) 2547 } 2548 2549 if want := false; result != want { 2550 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2551 } 2552 } 2553 2554 func TestParseBooleanResponse_error(t *testing.T) { 2555 t.Parallel() 2556 v := &ErrorResponse{Response: &http.Response{StatusCode: http.StatusBadRequest}} 2557 result, err := parseBoolResponse(v) 2558 2559 if err == nil { 2560 t.Errorf("Expected error to be returned.") 2561 } 2562 2563 if want := false; result != want { 2564 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2565 } 2566 } 2567 2568 func TestErrorResponse_Error(t *testing.T) { 2569 t.Parallel() 2570 res := &http.Response{Request: &http.Request{}} 2571 err := ErrorResponse{Message: "m", Response: res} 2572 if err.Error() == "" { 2573 t.Errorf("Expected non-empty ErrorResponse.Error()") 2574 } 2575 2576 // dont panic if request is nil 2577 res = &http.Response{} 2578 err = ErrorResponse{Message: "m", Response: res} 2579 if err.Error() == "" { 2580 t.Errorf("Expected non-empty ErrorResponse.Error()") 2581 } 2582 2583 // dont panic if response is nil 2584 err = ErrorResponse{Message: "m"} 2585 if err.Error() == "" { 2586 t.Errorf("Expected non-empty ErrorResponse.Error()") 2587 } 2588 } 2589 2590 func TestError_Error(t *testing.T) { 2591 t.Parallel() 2592 err := Error{} 2593 if err.Error() == "" { 2594 t.Errorf("Expected non-empty Error.Error()") 2595 } 2596 } 2597 2598 func TestSetCredentialsAsHeaders(t *testing.T) { 2599 t.Parallel() 2600 req := new(http.Request) 2601 id, secret := "id", "secret" 2602 modifiedRequest := setCredentialsAsHeaders(req, id, secret) 2603 2604 actualID, actualSecret, ok := modifiedRequest.BasicAuth() 2605 if !ok { 2606 t.Errorf("request does not contain basic credentials") 2607 } 2608 2609 if actualID != id { 2610 t.Errorf("id is %s, want %s", actualID, id) 2611 } 2612 2613 if actualSecret != secret { 2614 t.Errorf("secret is %s, want %s", actualSecret, 
secret) 2615 } 2616 } 2617 2618 func TestUnauthenticatedRateLimitedTransport(t *testing.T) { 2619 t.Parallel() 2620 client, mux, _ := setup(t) 2621 2622 clientID, clientSecret := "id", "secret" 2623 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 2624 id, secret, ok := r.BasicAuth() 2625 if !ok { 2626 t.Errorf("request does not contain basic auth credentials") 2627 } 2628 if id != clientID { 2629 t.Errorf("request contained basic auth username %q, want %q", id, clientID) 2630 } 2631 if secret != clientSecret { 2632 t.Errorf("request contained basic auth password %q, want %q", secret, clientSecret) 2633 } 2634 }) 2635 2636 tp := &UnauthenticatedRateLimitedTransport{ 2637 ClientID: clientID, 2638 ClientSecret: clientSecret, 2639 } 2640 unauthedClient := NewClient(tp.Client()) 2641 unauthedClient.BaseURL = client.BaseURL 2642 req, _ := unauthedClient.NewRequest("GET", ".", nil) 2643 ctx := context.Background() 2644 _, err := unauthedClient.Do(ctx, req, nil) 2645 assertNilError(t, err) 2646 } 2647 2648 func TestUnauthenticatedRateLimitedTransport_missingFields(t *testing.T) { 2649 t.Parallel() 2650 // missing ClientID 2651 tp := &UnauthenticatedRateLimitedTransport{ 2652 ClientSecret: "secret", 2653 } 2654 _, err := tp.RoundTrip(nil) 2655 if err == nil { 2656 t.Errorf("Expected error to be returned") 2657 } 2658 2659 // missing ClientSecret 2660 tp = &UnauthenticatedRateLimitedTransport{ 2661 ClientID: "id", 2662 } 2663 _, err = tp.RoundTrip(nil) 2664 if err == nil { 2665 t.Errorf("Expected error to be returned") 2666 } 2667 } 2668 2669 func TestUnauthenticatedRateLimitedTransport_transport(t *testing.T) { 2670 t.Parallel() 2671 // default transport 2672 tp := &UnauthenticatedRateLimitedTransport{ 2673 ClientID: "id", 2674 ClientSecret: "secret", 2675 } 2676 if tp.transport() != http.DefaultTransport { 2677 t.Errorf("Expected http.DefaultTransport to be used.") 2678 } 2679 2680 // custom transport 2681 tp = &UnauthenticatedRateLimitedTransport{ 2682 
ClientID: "id", 2683 ClientSecret: "secret", 2684 Transport: &http.Transport{}, 2685 } 2686 if tp.transport() == http.DefaultTransport { 2687 t.Errorf("Expected custom transport to be used.") 2688 } 2689 } 2690 2691 func TestBasicAuthTransport(t *testing.T) { 2692 t.Parallel() 2693 client, mux, _ := setup(t) 2694 2695 username, password, otp := "u", "p", "123456" 2696 2697 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 2698 u, p, ok := r.BasicAuth() 2699 if !ok { 2700 t.Errorf("request does not contain basic auth credentials") 2701 } 2702 if u != username { 2703 t.Errorf("request contained basic auth username %q, want %q", u, username) 2704 } 2705 if p != password { 2706 t.Errorf("request contained basic auth password %q, want %q", p, password) 2707 } 2708 if got, want := r.Header.Get(headerOTP), otp; got != want { 2709 t.Errorf("request contained OTP %q, want %q", got, want) 2710 } 2711 }) 2712 2713 tp := &BasicAuthTransport{ 2714 Username: username, 2715 Password: password, 2716 OTP: otp, 2717 } 2718 basicAuthClient := NewClient(tp.Client()) 2719 basicAuthClient.BaseURL = client.BaseURL 2720 req, _ := basicAuthClient.NewRequest("GET", ".", nil) 2721 ctx := context.Background() 2722 _, err := basicAuthClient.Do(ctx, req, nil) 2723 assertNilError(t, err) 2724 } 2725 2726 func TestBasicAuthTransport_transport(t *testing.T) { 2727 t.Parallel() 2728 // default transport 2729 tp := &BasicAuthTransport{} 2730 if tp.transport() != http.DefaultTransport { 2731 t.Errorf("Expected http.DefaultTransport to be used.") 2732 } 2733 2734 // custom transport 2735 tp = &BasicAuthTransport{ 2736 Transport: &http.Transport{}, 2737 } 2738 if tp.transport() == http.DefaultTransport { 2739 t.Errorf("Expected custom transport to be used.") 2740 } 2741 } 2742 2743 func TestFormatRateReset(t *testing.T) { 2744 t.Parallel() 2745 d := 120*time.Minute + 12*time.Second 2746 got := formatRateReset(d) 2747 want := "[rate reset in 120m12s]" 2748 if got != want { 2749 
t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2750 } 2751 2752 d = 14*time.Minute + 2*time.Second 2753 got = formatRateReset(d) 2754 want = "[rate reset in 14m02s]" 2755 if got != want { 2756 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2757 } 2758 2759 d = 2*time.Minute + 2*time.Second 2760 got = formatRateReset(d) 2761 want = "[rate reset in 2m02s]" 2762 if got != want { 2763 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2764 } 2765 2766 d = 12 * time.Second 2767 got = formatRateReset(d) 2768 want = "[rate reset in 12s]" 2769 if got != want { 2770 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2771 } 2772 2773 d = -1 * (2*time.Hour + 2*time.Second) 2774 got = formatRateReset(d) 2775 want = "[rate limit was reset 120m02s ago]" 2776 if got != want { 2777 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2778 } 2779 } 2780 2781 func TestNestedStructAccessorNoPanic(t *testing.T) { 2782 t.Parallel() 2783 issue := &Issue{User: nil} 2784 got := issue.GetUser().GetPlan().GetName() 2785 want := "" 2786 if got != want { 2787 t.Errorf("Issues.Get.GetUser().GetPlan().GetName() returned %+v, want %+v", got, want) 2788 } 2789 } 2790 2791 func TestTwoFactorAuthError(t *testing.T) { 2792 t.Parallel() 2793 u, err := url.Parse("https://example.com") 2794 if err != nil { 2795 t.Fatal(err) 2796 } 2797 2798 e := &TwoFactorAuthError{ 2799 Response: &http.Response{ 2800 Request: &http.Request{Method: "PUT", URL: u}, 2801 StatusCode: http.StatusTooManyRequests, 2802 }, 2803 Message: "<msg>", 2804 } 2805 if got, want := e.Error(), "PUT https://example.com: 429 <msg> []"; got != want { 2806 t.Errorf("TwoFactorAuthError = %q, want %q", got, want) 2807 } 2808 } 2809 2810 func TestRateLimitError(t *testing.T) { 2811 t.Parallel() 2812 u, err := url.Parse("https://example.com") 2813 if err != nil { 2814 t.Fatal(err) 2815 } 2816 2817 r := &RateLimitError{ 2818 Response: &http.Response{ 2819 Request: &http.Request{Method: "PUT", URL: u}, 
2820 StatusCode: http.StatusTooManyRequests, 2821 }, 2822 Message: "<msg>", 2823 } 2824 if got, want := r.Error(), "PUT https://example.com: 429 <msg> [rate limit was reset"; !strings.Contains(got, want) { 2825 t.Errorf("RateLimitError = %q, want %q", got, want) 2826 } 2827 } 2828 2829 func TestAcceptedError(t *testing.T) { 2830 t.Parallel() 2831 a := &AcceptedError{} 2832 if got, want := a.Error(), "try again later"; !strings.Contains(got, want) { 2833 t.Errorf("AcceptedError = %q, want %q", got, want) 2834 } 2835 } 2836 2837 func TestAbuseRateLimitError(t *testing.T) { 2838 t.Parallel() 2839 u, err := url.Parse("https://example.com") 2840 if err != nil { 2841 t.Fatal(err) 2842 } 2843 2844 r := &AbuseRateLimitError{ 2845 Response: &http.Response{ 2846 Request: &http.Request{Method: "PUT", URL: u}, 2847 StatusCode: http.StatusTooManyRequests, 2848 }, 2849 Message: "<msg>", 2850 } 2851 if got, want := r.Error(), "PUT https://example.com: 429 <msg>"; got != want { 2852 t.Errorf("AbuseRateLimitError = %q, want %q", got, want) 2853 } 2854 } 2855 2856 func TestAddOptions_QueryValues(t *testing.T) { 2857 t.Parallel() 2858 if _, err := addOptions("yo", ""); err == nil { 2859 t.Error("addOptions err = nil, want error") 2860 } 2861 } 2862 2863 func TestBareDo_returnsOpenBody(t *testing.T) { 2864 t.Parallel() 2865 client, mux, _ := setup(t) 2866 2867 expectedBody := "Hello from the other side !" 
2868 2869 mux.HandleFunc("/test-url", func(w http.ResponseWriter, r *http.Request) { 2870 testMethod(t, r, "GET") 2871 fmt.Fprint(w, expectedBody) 2872 }) 2873 2874 ctx := context.Background() 2875 req, err := client.NewRequest("GET", "test-url", nil) 2876 if err != nil { 2877 t.Fatalf("client.NewRequest returned error: %v", err) 2878 } 2879 2880 resp, err := client.BareDo(ctx, req) 2881 if err != nil { 2882 t.Fatalf("client.BareDo returned error: %v", err) 2883 } 2884 2885 got, err := io.ReadAll(resp.Body) 2886 if err != nil { 2887 t.Fatalf("io.ReadAll returned error: %v", err) 2888 } 2889 if string(got) != expectedBody { 2890 t.Fatalf("Expected %q, got %q", expectedBody, string(got)) 2891 } 2892 if err := resp.Body.Close(); err != nil { 2893 t.Fatalf("resp.Body.Close() returned error: %v", err) 2894 } 2895 } 2896 2897 func TestErrorResponse_Marshal(t *testing.T) { 2898 t.Parallel() 2899 testJSONMarshal(t, &ErrorResponse{}, "{}") 2900 2901 u := &ErrorResponse{ 2902 Message: "msg", 2903 Errors: []Error{ 2904 { 2905 Resource: "res", 2906 Field: "f", 2907 Code: "c", 2908 Message: "msg", 2909 }, 2910 }, 2911 Block: &ErrorBlock{ 2912 Reason: "reason", 2913 CreatedAt: &Timestamp{referenceTime}, 2914 }, 2915 DocumentationURL: "doc", 2916 } 2917 2918 want := `{ 2919 "message": "msg", 2920 "errors": [ 2921 { 2922 "resource": "res", 2923 "field": "f", 2924 "code": "c", 2925 "message": "msg" 2926 } 2927 ], 2928 "block": { 2929 "reason": "reason", 2930 "created_at": ` + referenceTimeStr + ` 2931 }, 2932 "documentation_url": "doc" 2933 }` 2934 2935 testJSONMarshal(t, u, want) 2936 } 2937 2938 func TestErrorBlock_Marshal(t *testing.T) { 2939 t.Parallel() 2940 testJSONMarshal(t, &ErrorBlock{}, "{}") 2941 2942 u := &ErrorBlock{ 2943 Reason: "reason", 2944 CreatedAt: &Timestamp{referenceTime}, 2945 } 2946 2947 want := `{ 2948 "reason": "reason", 2949 "created_at": ` + referenceTimeStr + ` 2950 }` 2951 2952 testJSONMarshal(t, u, want) 2953 } 2954 2955 func 
TestRateLimitError_Marshal(t *testing.T) { 2956 t.Parallel() 2957 testJSONMarshal(t, &RateLimitError{}, "{}") 2958 2959 u := &RateLimitError{ 2960 Rate: Rate{ 2961 Limit: 1, 2962 Remaining: 1, 2963 Reset: Timestamp{referenceTime}, 2964 }, 2965 Message: "msg", 2966 } 2967 2968 want := `{ 2969 "Rate": { 2970 "limit": 1, 2971 "remaining": 1, 2972 "reset": ` + referenceTimeStr + ` 2973 }, 2974 "message": "msg" 2975 }` 2976 2977 testJSONMarshal(t, u, want) 2978 } 2979 2980 func TestAbuseRateLimitError_Marshal(t *testing.T) { 2981 t.Parallel() 2982 testJSONMarshal(t, &AbuseRateLimitError{}, "{}") 2983 2984 u := &AbuseRateLimitError{ 2985 Message: "msg", 2986 } 2987 2988 want := `{ 2989 "message": "msg" 2990 }` 2991 2992 testJSONMarshal(t, u, want) 2993 } 2994 2995 func TestError_Marshal(t *testing.T) { 2996 t.Parallel() 2997 testJSONMarshal(t, &Error{}, "{}") 2998 2999 u := &Error{ 3000 Resource: "res", 3001 Field: "field", 3002 Code: "code", 3003 Message: "msg", 3004 } 3005 3006 want := `{ 3007 "resource": "res", 3008 "field": "field", 3009 "code": "code", 3010 "message": "msg" 3011 }` 3012 3013 testJSONMarshal(t, u, want) 3014 } 3015 3016 func TestParseTokenExpiration(t *testing.T) { 3017 t.Parallel() 3018 tests := []struct { 3019 header string 3020 want Timestamp 3021 }{ 3022 { 3023 header: "", 3024 want: Timestamp{}, 3025 }, 3026 { 3027 header: "this is a garbage", 3028 want: Timestamp{}, 3029 }, 3030 { 3031 header: "2021-09-03 02:34:04 UTC", 3032 want: Timestamp{time.Date(2021, time.September, 3, 2, 34, 4, 0, time.UTC)}, 3033 }, 3034 { 3035 header: "2021-09-03 14:34:04 UTC", 3036 want: Timestamp{time.Date(2021, time.September, 3, 14, 34, 4, 0, time.UTC)}, 3037 }, 3038 // Some tokens include the timezone offset instead of the timezone. 
3039 // https://github.com/google/go-github/issues/2649 3040 { 3041 header: "2023-04-26 20:23:26 +0200", 3042 want: Timestamp{time.Date(2023, time.April, 26, 18, 23, 26, 0, time.UTC)}, 3043 }, 3044 } 3045 3046 for _, tt := range tests { 3047 res := &http.Response{ 3048 Request: &http.Request{}, 3049 Header: http.Header{}, 3050 } 3051 3052 res.Header.Set(headerTokenExpiration, tt.header) 3053 exp := parseTokenExpiration(res) 3054 if !exp.Equal(tt.want) { 3055 t.Errorf("parseTokenExpiration of %q\nreturned %#v\n want %#v", tt.header, exp, tt.want) 3056 } 3057 } 3058 } 3059 3060 func TestClientCopy_leak_transport(t *testing.T) { 3061 t.Parallel() 3062 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 3063 w.Header().Set("Content-Type", "application/json") 3064 accessToken := r.Header.Get("Authorization") 3065 _, _ = fmt.Fprintf(w, `{"login": "%s"}`, accessToken) 3066 })) 3067 clientPreconfiguredWithURLs, err := NewClient(nil).WithEnterpriseURLs(srv.URL, srv.URL) 3068 if err != nil { 3069 t.Fatal(err) 3070 } 3071 3072 aliceClient := clientPreconfiguredWithURLs.WithAuthToken("alice") 3073 bobClient := clientPreconfiguredWithURLs.WithAuthToken("bob") 3074 3075 alice, _, err := aliceClient.Users.Get(context.Background(), "") 3076 if err != nil { 3077 t.Fatal(err) 3078 } 3079 3080 assertNoDiff(t, "Bearer alice", alice.GetLogin()) 3081 3082 bob, _, err := bobClient.Users.Get(context.Background(), "") 3083 if err != nil { 3084 t.Fatal(err) 3085 } 3086 3087 assertNoDiff(t, "Bearer bob", bob.GetLogin()) 3088 } 3089 3090 func TestPtr(t *testing.T) { 3091 t.Parallel() 3092 equal := func(t *testing.T, want, got any) { 3093 t.Helper() 3094 if !reflect.DeepEqual(want, got) { 3095 t.Errorf("want %#v, got %#v", want, got) 3096 } 3097 } 3098 3099 equal(t, true, *Ptr(true)) 3100 equal(t, int(10), *Ptr(int(10))) 3101 equal(t, int64(-10), *Ptr(int64(-10))) 3102 equal(t, "str", *Ptr("str")) 3103 } 3104 3105 func 
TestDeploymentProtectionRuleEvent_GetRunID(t *testing.T) { 3106 t.Parallel() 3107 3108 var want int64 = 123456789 3109 url := "repos/dummy-org/dummy-repo/actions/runs/123456789/deployment_protection_rule" 3110 3111 e := DeploymentProtectionRuleEvent{ 3112 DeploymentCallbackURL: &url, 3113 } 3114 3115 got, _ := e.GetRunID() 3116 if got != want { 3117 t.Errorf("want %#v, got %#v", want, got) 3118 } 3119 3120 want = -1 3121 url = "repos/dummy-org/dummy-repo/actions/runs/abc123/deployment_protection_rule" 3122 3123 got, err := e.GetRunID() 3124 if err == nil { 3125 t.Errorf("Expected error to be returned") 3126 } 3127 3128 if got != want { 3129 t.Errorf("want %#v, got %#v", want, got) 3130 } 3131 }