github.com/google/go-github/v71@v71.0.0/github/github_test.go (about) 1 // Copyright 2013 The go-github AUTHORS. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 package github 7 8 import ( 9 "context" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "io" 14 "net/http" 15 "net/http/httptest" 16 "net/url" 17 "os" 18 "path/filepath" 19 "reflect" 20 "strconv" 21 "strings" 22 "testing" 23 "time" 24 25 "github.com/google/go-cmp/cmp" 26 ) 27 28 const ( 29 // baseURLPath is a non-empty Client.BaseURL path to use during tests, 30 // to ensure relative URLs are used for all endpoints. See issue #752. 31 baseURLPath = "/api-v3" 32 ) 33 34 // setup sets up a test HTTP server along with a github.Client that is 35 // configured to talk to that test server. Tests should register handlers on 36 // mux which provide mock responses for the API method being tested. 37 func setup(t *testing.T) (client *Client, mux *http.ServeMux, serverURL string) { 38 t.Helper() 39 // mux is the HTTP request multiplexer used with the test server. 40 mux = http.NewServeMux() 41 42 // We want to ensure that tests catch mistakes where the endpoint URL is 43 // specified as absolute rather than relative. It only makes a difference 44 // when there's a non-empty base URL path. So, use that. See issue #752. 
45 apiHandler := http.NewServeMux() 46 apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) 47 apiHandler.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { 48 fmt.Fprintln(os.Stderr, "FAIL: Client.BaseURL path prefix is not preserved in the request URL:") 49 fmt.Fprintln(os.Stderr) 50 fmt.Fprintln(os.Stderr, "\t"+req.URL.String()) 51 fmt.Fprintln(os.Stderr) 52 fmt.Fprintln(os.Stderr, "\tDid you accidentally use an absolute endpoint URL rather than relative?") 53 fmt.Fprintln(os.Stderr, "\tSee https://github.com/google/go-github/issues/752 for information.") 54 http.Error(w, "Client.BaseURL path prefix is not preserved in the request URL.", http.StatusInternalServerError) 55 }) 56 57 // server is a test HTTP server used to provide mock API responses. 58 server := httptest.NewServer(apiHandler) 59 60 // Create a custom transport with isolated connection pool 61 transport := &http.Transport{ 62 // Controls connection reuse - false allows reuse, true forces new connections for each request 63 DisableKeepAlives: false, 64 // Maximum concurrent connections per host (active + idle) 65 MaxConnsPerHost: 10, 66 // Maximum idle connections maintained per host for reuse 67 MaxIdleConnsPerHost: 5, 68 // Maximum total idle connections across all hosts 69 MaxIdleConns: 20, 70 // How long an idle connection remains in the pool before being closed 71 IdleConnTimeout: 20 * time.Second, 72 } 73 74 // Create HTTP client with the isolated transport 75 httpClient := &http.Client{ 76 Transport: transport, 77 Timeout: 30 * time.Second, 78 } 79 // client is the GitHub client being tested and is 80 // configured to use test server. 81 client = NewClient(httpClient) 82 83 url, _ := url.Parse(server.URL + baseURLPath + "/") 84 client.BaseURL = url 85 client.UploadURL = url 86 87 t.Cleanup(server.Close) 88 89 return client, mux, server.URL 90 } 91 92 // openTestFile creates a new file with the given name and content for testing. 
93 // In order to ensure the exact file name, this function will create a new temp 94 // directory, and create the file in that directory. The file is automatically 95 // cleaned up after the test. 96 func openTestFile(t *testing.T, name, content string) *os.File { 97 t.Helper() 98 fname := filepath.Join(t.TempDir(), name) 99 err := os.WriteFile(fname, []byte(content), 0600) 100 if err != nil { 101 t.Fatal(err) 102 } 103 file, err := os.Open(fname) 104 if err != nil { 105 t.Fatal(err) 106 } 107 108 t.Cleanup(func() { file.Close() }) 109 110 return file 111 } 112 113 func testMethod(t *testing.T, r *http.Request, want string) { 114 t.Helper() 115 if got := r.Method; got != want { 116 t.Errorf("Request method: %v, want %v", got, want) 117 } 118 } 119 120 type values map[string]string 121 122 func testFormValues(t *testing.T, r *http.Request, values values) { 123 t.Helper() 124 want := url.Values{} 125 for k, v := range values { 126 want.Set(k, v) 127 } 128 129 assertNilError(t, r.ParseForm()) 130 if got := r.Form; !cmp.Equal(got, want) { 131 t.Errorf("Request parameters: %v, want %v", got, want) 132 } 133 } 134 135 func testHeader(t *testing.T, r *http.Request, header string, want string) { 136 t.Helper() 137 if got := r.Header.Get(header); got != want { 138 t.Errorf("Header.Get(%q) returned %q, want %q", header, got, want) 139 } 140 } 141 142 func testURLParseError(t *testing.T, err error) { 143 t.Helper() 144 if err == nil { 145 t.Errorf("Expected error to be returned") 146 } 147 if err, ok := err.(*url.Error); !ok || err.Op != "parse" { 148 t.Errorf("Expected URL parse error, got %+v", err) 149 } 150 } 151 152 func testBody(t *testing.T, r *http.Request, want string) { 153 t.Helper() 154 b, err := io.ReadAll(r.Body) 155 if err != nil { 156 t.Errorf("Error reading request body: %v", err) 157 } 158 if got := string(b); got != want { 159 t.Errorf("request Body is %s, want %s", got, want) 160 } 161 } 162 163 // Test whether the marshaling of v produces JSON that 
corresponds 164 // to the want string. 165 func testJSONMarshal(t *testing.T, v interface{}, want string) { 166 t.Helper() 167 // Unmarshal the wanted JSON, to verify its correctness, and marshal it back 168 // to sort the keys. 169 u := reflect.New(reflect.TypeOf(v)).Interface() 170 if err := json.Unmarshal([]byte(want), &u); err != nil { 171 t.Errorf("Unable to unmarshal JSON for %v: %v", want, err) 172 } 173 w, err := json.MarshalIndent(u, "", " ") 174 if err != nil { 175 t.Errorf("Unable to marshal JSON for %#v", u) 176 } 177 178 // Marshal the target value. 179 got, err := json.MarshalIndent(v, "", " ") 180 if err != nil { 181 t.Errorf("Unable to marshal JSON for %#v", v) 182 } 183 184 if diff := cmp.Diff(string(w), string(got)); diff != "" { 185 t.Errorf("json.Marshal returned:\n%s\nwant:\n%s\ndiff:\n%v", got, w, diff) 186 } 187 } 188 189 // Test whether the v fields have the url tag and the parsing of v 190 // produces query parameters that corresponds to the want string. 191 func testAddURLOptions(t *testing.T, url string, v interface{}, want string) { 192 t.Helper() 193 194 vt := reflect.Indirect(reflect.ValueOf(v)).Type() 195 for i := 0; i < vt.NumField(); i++ { 196 field := vt.Field(i) 197 if alias, ok := field.Tag.Lookup("url"); ok { 198 if alias == "" { 199 t.Errorf("The field %+v has a blank url tag", field) 200 } 201 } else { 202 t.Errorf("The field %+v has no url tag specified", field) 203 } 204 } 205 206 got, err := addOptions(url, v) 207 if err != nil { 208 t.Errorf("Unable to add %#v as query parameters", v) 209 } 210 211 if got != want { 212 t.Errorf("addOptions(%q, %#v) returned %v, want %v", url, v, got, want) 213 } 214 } 215 216 // Test how bad options are handled. Method f under test should 217 // return an error. 
218 func testBadOptions(t *testing.T, methodName string, f func() error) { 219 t.Helper() 220 if methodName == "" { 221 t.Error("testBadOptions: must supply method methodName") 222 } 223 if err := f(); err == nil { 224 t.Errorf("bad options %v err = nil, want error", methodName) 225 } 226 } 227 228 // Test function under NewRequest failure and then s.client.Do failure. 229 // Method f should be a regular call that would normally succeed, but 230 // should return an error when NewRequest or s.client.Do fails. 231 func testNewRequestAndDoFailure(t *testing.T, methodName string, client *Client, f func() (*Response, error)) { 232 testNewRequestAndDoFailureCategory(t, methodName, client, CoreCategory, f) 233 } 234 235 // testNewRequestAndDoFailureCategory works Like testNewRequestAndDoFailure, but allows setting the category. 236 func testNewRequestAndDoFailureCategory(t *testing.T, methodName string, client *Client, category RateLimitCategory, f func() (*Response, error)) { 237 t.Helper() 238 if methodName == "" { 239 t.Error("testNewRequestAndDoFailure: must supply method methodName") 240 } 241 242 client.BaseURL.Path = "" 243 resp, err := f() 244 if resp != nil { 245 t.Errorf("client.BaseURL.Path='' %v resp = %#v, want nil", methodName, resp) 246 } 247 if err == nil { 248 t.Errorf("client.BaseURL.Path='' %v err = nil, want error", methodName) 249 } 250 251 client.BaseURL.Path = "/api-v3/" 252 client.rateLimits[category].Reset.Time = time.Now().Add(10 * time.Minute) 253 resp, err = f() 254 if bypass := resp.Request.Context().Value(BypassRateLimitCheck); bypass != nil { 255 return 256 } 257 if want := http.StatusForbidden; resp == nil || resp.Response.StatusCode != want { 258 if resp != nil { 259 t.Errorf("rate.Reset.Time > now %v resp = %#v, want StatusCode=%v", methodName, resp.Response, want) 260 } else { 261 t.Errorf("rate.Reset.Time > now %v resp = nil, want StatusCode=%v", methodName, want) 262 } 263 } 264 if err == nil { 265 t.Errorf("rate.Reset.Time > now %v 
err = nil, want error", methodName) 266 } 267 } 268 269 // Test that all error response types contain the status code. 270 func testErrorResponseForStatusCode(t *testing.T, code int) { 271 t.Helper() 272 client, mux, _ := setup(t) 273 274 mux.HandleFunc("/repos/o/r/hooks", func(w http.ResponseWriter, r *http.Request) { 275 testMethod(t, r, "GET") 276 w.WriteHeader(code) 277 }) 278 279 ctx := context.Background() 280 _, _, err := client.Repositories.ListHooks(ctx, "o", "r", nil) 281 282 switch e := err.(type) { 283 case *ErrorResponse: 284 case *RateLimitError: 285 case *AbuseRateLimitError: 286 if code != e.Response.StatusCode { 287 t.Error("Error response does not contain status code") 288 } 289 default: 290 t.Error("Unknown error response type") 291 } 292 } 293 294 func assertNoDiff(t *testing.T, want, got interface{}) { 295 t.Helper() 296 if diff := cmp.Diff(want, got); diff != "" { 297 t.Errorf("diff mismatch (-want +got):\n%v", diff) 298 } 299 } 300 301 func assertNilError(t *testing.T, err error) { 302 t.Helper() 303 if err != nil { 304 t.Errorf("unexpected error: %v", err) 305 } 306 } 307 308 func assertWrite(t *testing.T, w io.Writer, data []byte) { 309 t.Helper() 310 _, err := w.Write(data) 311 assertNilError(t, err) 312 } 313 314 func TestNewClient(t *testing.T) { 315 t.Parallel() 316 c := NewClient(nil) 317 318 if got, want := c.BaseURL.String(), defaultBaseURL; got != want { 319 t.Errorf("NewClient BaseURL is %v, want %v", got, want) 320 } 321 if got, want := c.UserAgent, defaultUserAgent; got != want { 322 t.Errorf("NewClient UserAgent is %v, want %v", got, want) 323 } 324 325 c2 := NewClient(nil) 326 if c.client == c2.client { 327 t.Error("NewClient returned same http.Clients, but they should differ") 328 } 329 } 330 331 func TestNewClientWithEnvProxy(t *testing.T) { 332 t.Parallel() 333 client := NewClientWithEnvProxy() 334 if got, want := client.BaseURL.String(), defaultBaseURL; got != want { 335 t.Errorf("NewClient BaseURL is %v, want %v", got, 
want) 336 } 337 } 338 339 func TestClient(t *testing.T) { 340 t.Parallel() 341 c := NewClient(nil) 342 c2 := c.Client() 343 if c.client == c2 { 344 t.Error("Client returned same http.Client, but should be different") 345 } 346 } 347 348 func TestWithAuthToken(t *testing.T) { 349 t.Parallel() 350 token := "gh_test_token" 351 352 validate := func(t *testing.T, c *http.Client, token string) { 353 t.Helper() 354 want := token 355 if want != "" { 356 want = "Bearer " + want 357 } 358 gotReq := false 359 headerVal := "" 360 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 361 gotReq = true 362 headerVal = r.Header.Get("Authorization") 363 })) 364 _, err := c.Get(srv.URL) 365 assertNilError(t, err) 366 if !gotReq { 367 t.Error("request not sent") 368 } 369 if headerVal != want { 370 t.Errorf("Authorization header is %v, want %v", headerVal, want) 371 } 372 } 373 374 t.Run("zero-value Client", func(t *testing.T) { 375 t.Parallel() 376 c := new(Client).WithAuthToken(token) 377 validate(t, c.Client(), token) 378 }) 379 380 t.Run("NewClient", func(t *testing.T) { 381 t.Parallel() 382 httpClient := &http.Client{} 383 client := NewClient(httpClient).WithAuthToken(token) 384 validate(t, client.Client(), token) 385 // make sure the original client isn't setting auth headers now 386 validate(t, httpClient, "") 387 }) 388 389 t.Run("NewTokenClient", func(t *testing.T) { 390 t.Parallel() 391 validate(t, NewTokenClient(context.Background(), token).Client(), token) 392 }) 393 } 394 395 func TestWithEnterpriseURLs(t *testing.T) { 396 t.Parallel() 397 for _, test := range []struct { 398 name string 399 baseURL string 400 wantBaseURL string 401 uploadURL string 402 wantUploadURL string 403 wantErr string 404 }{ 405 { 406 name: "does not modify properly formed URLs", 407 baseURL: "https://custom-url/api/v3/", 408 wantBaseURL: "https://custom-url/api/v3/", 409 uploadURL: "https://custom-upload-url/api/uploads/", 410 wantUploadURL: 
"https://custom-upload-url/api/uploads/", 411 }, 412 { 413 name: "adds trailing slash", 414 baseURL: "https://custom-url/api/v3", 415 wantBaseURL: "https://custom-url/api/v3/", 416 uploadURL: "https://custom-upload-url/api/uploads", 417 wantUploadURL: "https://custom-upload-url/api/uploads/", 418 }, 419 { 420 name: "adds enterprise suffix", 421 baseURL: "https://custom-url/", 422 wantBaseURL: "https://custom-url/api/v3/", 423 uploadURL: "https://custom-upload-url/", 424 wantUploadURL: "https://custom-upload-url/api/uploads/", 425 }, 426 { 427 name: "adds enterprise suffix and trailing slash", 428 baseURL: "https://custom-url", 429 wantBaseURL: "https://custom-url/api/v3/", 430 uploadURL: "https://custom-upload-url", 431 wantUploadURL: "https://custom-upload-url/api/uploads/", 432 }, 433 { 434 name: "bad base URL", 435 baseURL: "bogus\nbase\nURL", 436 uploadURL: "https://custom-upload-url/api/uploads/", 437 wantErr: `invalid control character in URL`, 438 }, 439 { 440 name: "bad upload URL", 441 baseURL: "https://custom-url/api/v3/", 442 uploadURL: "bogus\nupload\nURL", 443 wantErr: `invalid control character in URL`, 444 }, 445 { 446 name: "URL has existing API prefix, adds trailing slash", 447 baseURL: "https://api.custom-url", 448 wantBaseURL: "https://api.custom-url/", 449 uploadURL: "https://api.custom-upload-url", 450 wantUploadURL: "https://api.custom-upload-url/", 451 }, 452 { 453 name: "URL has existing API prefix and trailing slash", 454 baseURL: "https://api.custom-url/", 455 wantBaseURL: "https://api.custom-url/", 456 uploadURL: "https://api.custom-upload-url/", 457 wantUploadURL: "https://api.custom-upload-url/", 458 }, 459 { 460 name: "URL has API subdomain, adds trailing slash", 461 baseURL: "https://catalog.api.custom-url", 462 wantBaseURL: "https://catalog.api.custom-url/", 463 uploadURL: "https://catalog.api.custom-upload-url", 464 wantUploadURL: "https://catalog.api.custom-upload-url/", 465 }, 466 { 467 name: "URL has API subdomain and trailing 
slash", 468 baseURL: "https://catalog.api.custom-url/", 469 wantBaseURL: "https://catalog.api.custom-url/", 470 uploadURL: "https://catalog.api.custom-upload-url/", 471 wantUploadURL: "https://catalog.api.custom-upload-url/", 472 }, 473 { 474 name: "URL is not a proper API subdomain, adds enterprise suffix and slash", 475 baseURL: "https://cloud-api.custom-url", 476 wantBaseURL: "https://cloud-api.custom-url/api/v3/", 477 uploadURL: "https://cloud-api.custom-upload-url", 478 wantUploadURL: "https://cloud-api.custom-upload-url/api/uploads/", 479 }, 480 { 481 name: "URL is not a proper API subdomain, adds enterprise suffix", 482 baseURL: "https://cloud-api.custom-url/", 483 wantBaseURL: "https://cloud-api.custom-url/api/v3/", 484 uploadURL: "https://cloud-api.custom-upload-url/", 485 wantUploadURL: "https://cloud-api.custom-upload-url/api/uploads/", 486 }, 487 } { 488 t.Run(test.name, func(t *testing.T) { 489 t.Parallel() 490 validate := func(c *Client, err error) { 491 t.Helper() 492 if test.wantErr != "" { 493 if err == nil || !strings.Contains(err.Error(), test.wantErr) { 494 t.Fatalf("error does not contain expected string %q: %v", test.wantErr, err) 495 } 496 return 497 } 498 if err != nil { 499 t.Fatalf("got unexpected error: %v", err) 500 } 501 if c.BaseURL.String() != test.wantBaseURL { 502 t.Errorf("BaseURL is %v, want %v", c.BaseURL.String(), test.wantBaseURL) 503 } 504 if c.UploadURL.String() != test.wantUploadURL { 505 t.Errorf("UploadURL is %v, want %v", c.UploadURL.String(), test.wantUploadURL) 506 } 507 } 508 validate(NewClient(nil).WithEnterpriseURLs(test.baseURL, test.uploadURL)) 509 validate(new(Client).WithEnterpriseURLs(test.baseURL, test.uploadURL)) 510 validate(NewEnterpriseClient(test.baseURL, test.uploadURL, nil)) 511 }) 512 } 513 } 514 515 // Ensure that length of Client.rateLimits is the same as number of fields in RateLimits struct. 
516 func TestClient_rateLimits(t *testing.T) { 517 t.Parallel() 518 if got, want := len(Client{}.rateLimits), reflect.TypeOf(RateLimits{}).NumField(); got != want { 519 t.Errorf("len(Client{}.rateLimits) is %v, want %v", got, want) 520 } 521 } 522 523 func TestNewRequest(t *testing.T) { 524 t.Parallel() 525 c := NewClient(nil) 526 527 inURL, outURL := "/foo", defaultBaseURL+"foo" 528 inBody, outBody := &User{Login: Ptr("l")}, `{"login":"l"}`+"\n" 529 req, _ := c.NewRequest("GET", inURL, inBody) 530 531 // test that relative URL was expanded 532 if got, want := req.URL.String(), outURL; got != want { 533 t.Errorf("NewRequest(%q) URL is %v, want %v", inURL, got, want) 534 } 535 536 // test that body was JSON encoded 537 body, _ := io.ReadAll(req.Body) 538 if got, want := string(body), outBody; got != want { 539 t.Errorf("NewRequest(%q) Body is %v, want %v", inBody, got, want) 540 } 541 542 userAgent := req.Header.Get("User-Agent") 543 544 // test that default user-agent is attached to the request 545 if got, want := userAgent, c.UserAgent; got != want { 546 t.Errorf("NewRequest() User-Agent is %v, want %v", got, want) 547 } 548 549 if !strings.Contains(userAgent, Version) { 550 t.Errorf("NewRequest() User-Agent should contain %v, found %v", Version, userAgent) 551 } 552 553 apiVersion := req.Header.Get(headerAPIVersion) 554 if got, want := apiVersion, defaultAPIVersion; got != want { 555 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 556 } 557 558 req, _ = c.NewRequest("GET", inURL, inBody, WithVersion("2022-11-29")) 559 apiVersion = req.Header.Get(headerAPIVersion) 560 if got, want := apiVersion, "2022-11-29"; got != want { 561 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 562 } 563 } 564 565 func TestNewRequest_invalidJSON(t *testing.T) { 566 t.Parallel() 567 c := NewClient(nil) 568 569 type T struct { 570 A map[interface{}]interface{} 571 } 572 _, err := c.NewRequest("GET", ".", &T{}) 573 574 if err 
== nil { 575 t.Error("Expected error to be returned.") 576 } 577 if err, ok := err.(*json.UnsupportedTypeError); !ok { 578 t.Errorf("Expected a JSON error; got %#v.", err) 579 } 580 } 581 582 func TestNewRequest_badURL(t *testing.T) { 583 t.Parallel() 584 c := NewClient(nil) 585 _, err := c.NewRequest("GET", ":", nil) 586 testURLParseError(t, err) 587 } 588 589 func TestNewRequest_badMethod(t *testing.T) { 590 t.Parallel() 591 c := NewClient(nil) 592 if _, err := c.NewRequest("BOGUS\nMETHOD", ".", nil); err == nil { 593 t.Fatal("NewRequest returned nil; expected error") 594 } 595 } 596 597 // ensure that no User-Agent header is set if the client's UserAgent is empty. 598 // This caused a problem with Google's internal http client. 599 func TestNewRequest_emptyUserAgent(t *testing.T) { 600 t.Parallel() 601 c := NewClient(nil) 602 c.UserAgent = "" 603 req, err := c.NewRequest("GET", ".", nil) 604 if err != nil { 605 t.Fatalf("NewRequest returned unexpected error: %v", err) 606 } 607 if _, ok := req.Header["User-Agent"]; ok { 608 t.Fatal("constructed request contains unexpected User-Agent header") 609 } 610 } 611 612 // If a nil body is passed to github.NewRequest, make sure that nil is also 613 // passed to http.NewRequest. In most cases, passing an io.Reader that returns 614 // no content is fine, since there is no difference between an HTTP request 615 // body that is an empty string versus one that is not set at all. However in 616 // certain cases, intermediate systems may treat these differently resulting in 617 // subtle errors. 
618 func TestNewRequest_emptyBody(t *testing.T) { 619 t.Parallel() 620 c := NewClient(nil) 621 req, err := c.NewRequest("GET", ".", nil) 622 if err != nil { 623 t.Fatalf("NewRequest returned unexpected error: %v", err) 624 } 625 if req.Body != nil { 626 t.Fatalf("constructed request contains a non-nil Body") 627 } 628 } 629 630 func TestNewRequest_errorForNoTrailingSlash(t *testing.T) { 631 t.Parallel() 632 tests := []struct { 633 rawurl string 634 wantError bool 635 }{ 636 {rawurl: "https://example.com/api/v3", wantError: true}, 637 {rawurl: "https://example.com/api/v3/", wantError: false}, 638 } 639 c := NewClient(nil) 640 for _, test := range tests { 641 u, err := url.Parse(test.rawurl) 642 if err != nil { 643 t.Fatalf("url.Parse returned unexpected error: %v.", err) 644 } 645 c.BaseURL = u 646 if _, err := c.NewRequest(http.MethodGet, "test", nil); test.wantError && err == nil { 647 t.Fatalf("Expected error to be returned.") 648 } else if !test.wantError && err != nil { 649 t.Fatalf("NewRequest returned unexpected error: %v.", err) 650 } 651 } 652 } 653 654 func TestNewFormRequest(t *testing.T) { 655 t.Parallel() 656 c := NewClient(nil) 657 658 inURL, outURL := "/foo", defaultBaseURL+"foo" 659 form := url.Values{} 660 form.Add("login", "l") 661 inBody, outBody := strings.NewReader(form.Encode()), "login=l" 662 req, _ := c.NewFormRequest(inURL, inBody) 663 664 // test that relative URL was expanded 665 if got, want := req.URL.String(), outURL; got != want { 666 t.Errorf("NewFormRequest(%q) URL is %v, want %v", inURL, got, want) 667 } 668 669 // test that body was form encoded 670 body, _ := io.ReadAll(req.Body) 671 if got, want := string(body), outBody; got != want { 672 t.Errorf("NewFormRequest(%q) Body is %v, want %v", inBody, got, want) 673 } 674 675 // test that default user-agent is attached to the request 676 if got, want := req.Header.Get("User-Agent"), c.UserAgent; got != want { 677 t.Errorf("NewFormRequest() User-Agent is %v, want %v", got, want) 678 } 
679 680 apiVersion := req.Header.Get(headerAPIVersion) 681 if got, want := apiVersion, defaultAPIVersion; got != want { 682 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 683 } 684 685 req, _ = c.NewFormRequest(inURL, inBody, WithVersion("2022-11-29")) 686 apiVersion = req.Header.Get(headerAPIVersion) 687 if got, want := apiVersion, "2022-11-29"; got != want { 688 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 689 } 690 } 691 692 func TestNewFormRequest_badURL(t *testing.T) { 693 t.Parallel() 694 c := NewClient(nil) 695 _, err := c.NewFormRequest(":", nil) 696 testURLParseError(t, err) 697 } 698 699 func TestNewFormRequest_emptyUserAgent(t *testing.T) { 700 t.Parallel() 701 c := NewClient(nil) 702 c.UserAgent = "" 703 req, err := c.NewFormRequest(".", nil) 704 if err != nil { 705 t.Fatalf("NewFormRequest returned unexpected error: %v", err) 706 } 707 if _, ok := req.Header["User-Agent"]; ok { 708 t.Fatal("constructed request contains unexpected User-Agent header") 709 } 710 } 711 712 func TestNewFormRequest_emptyBody(t *testing.T) { 713 t.Parallel() 714 c := NewClient(nil) 715 req, err := c.NewFormRequest(".", nil) 716 if err != nil { 717 t.Fatalf("NewFormRequest returned unexpected error: %v", err) 718 } 719 if req.Body != nil { 720 t.Fatalf("constructed request contains a non-nil Body") 721 } 722 } 723 724 func TestNewFormRequest_errorForNoTrailingSlash(t *testing.T) { 725 t.Parallel() 726 tests := []struct { 727 rawURL string 728 wantError bool 729 }{ 730 {rawURL: "https://example.com/api/v3", wantError: true}, 731 {rawURL: "https://example.com/api/v3/", wantError: false}, 732 } 733 c := NewClient(nil) 734 for _, test := range tests { 735 u, err := url.Parse(test.rawURL) 736 if err != nil { 737 t.Fatalf("url.Parse returned unexpected error: %v.", err) 738 } 739 c.BaseURL = u 740 if _, err := c.NewFormRequest("test", nil); test.wantError && err == nil { 741 t.Fatalf("Expected error to be returned.") 
742 } else if !test.wantError && err != nil { 743 t.Fatalf("NewFormRequest returned unexpected error: %v.", err) 744 } 745 } 746 } 747 748 func TestNewUploadRequest_WithVersion(t *testing.T) { 749 t.Parallel() 750 c := NewClient(nil) 751 req, _ := c.NewUploadRequest("https://example.com/", nil, 0, "") 752 753 apiVersion := req.Header.Get(headerAPIVersion) 754 if got, want := apiVersion, defaultAPIVersion; got != want { 755 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 756 } 757 758 req, _ = c.NewUploadRequest("https://example.com/", nil, 0, "", WithVersion("2022-11-29")) 759 apiVersion = req.Header.Get(headerAPIVersion) 760 if got, want := apiVersion, "2022-11-29"; got != want { 761 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 762 } 763 } 764 765 func TestNewUploadRequest_badURL(t *testing.T) { 766 t.Parallel() 767 c := NewClient(nil) 768 _, err := c.NewUploadRequest(":", nil, 0, "") 769 testURLParseError(t, err) 770 771 const methodName = "NewUploadRequest" 772 testBadOptions(t, methodName, func() (err error) { 773 _, err = c.NewUploadRequest("\n", nil, -1, "\n") 774 return err 775 }) 776 } 777 778 func TestNewUploadRequest_errorForNoTrailingSlash(t *testing.T) { 779 t.Parallel() 780 tests := []struct { 781 rawurl string 782 wantError bool 783 }{ 784 {rawurl: "https://example.com/api/uploads", wantError: true}, 785 {rawurl: "https://example.com/api/uploads/", wantError: false}, 786 } 787 c := NewClient(nil) 788 for _, test := range tests { 789 u, err := url.Parse(test.rawurl) 790 if err != nil { 791 t.Fatalf("url.Parse returned unexpected error: %v.", err) 792 } 793 c.UploadURL = u 794 if _, err = c.NewUploadRequest("test", nil, 0, ""); test.wantError && err == nil { 795 t.Fatalf("Expected error to be returned.") 796 } else if !test.wantError && err != nil { 797 t.Fatalf("NewUploadRequest returned unexpected error: %v.", err) 798 } 799 } 800 } 801 802 func TestResponse_populatePageValues(t 
*testing.T) { 803 t.Parallel() 804 r := http.Response{ 805 Header: http.Header{ 806 "Link": {`<https://api.github.com/?page=1>; rel="first",` + 807 ` <https://api.github.com/?page=2>; rel="prev",` + 808 ` <https://api.github.com/?page=4>; rel="next",` + 809 ` <https://api.github.com/?page=5>; rel="last"`, 810 }, 811 }, 812 } 813 814 response := newResponse(&r) 815 if got, want := response.FirstPage, 1; got != want { 816 t.Errorf("response.FirstPage: %v, want %v", got, want) 817 } 818 if got, want := response.PrevPage, 2; want != got { 819 t.Errorf("response.PrevPage: %v, want %v", got, want) 820 } 821 if got, want := response.NextPage, 4; want != got { 822 t.Errorf("response.NextPage: %v, want %v", got, want) 823 } 824 if got, want := response.LastPage, 5; want != got { 825 t.Errorf("response.LastPage: %v, want %v", got, want) 826 } 827 if got, want := response.NextPageToken, ""; want != got { 828 t.Errorf("response.NextPageToken: %v, want %v", got, want) 829 } 830 } 831 832 func TestResponse_populateSinceValues(t *testing.T) { 833 t.Parallel() 834 r := http.Response{ 835 Header: http.Header{ 836 "Link": {`<https://api.github.com/?since=1>; rel="first",` + 837 ` <https://api.github.com/?since=2>; rel="prev",` + 838 ` <https://api.github.com/?since=4>; rel="next",` + 839 ` <https://api.github.com/?since=5>; rel="last"`, 840 }, 841 }, 842 } 843 844 response := newResponse(&r) 845 if got, want := response.FirstPage, 1; got != want { 846 t.Errorf("response.FirstPage: %v, want %v", got, want) 847 } 848 if got, want := response.PrevPage, 2; want != got { 849 t.Errorf("response.PrevPage: %v, want %v", got, want) 850 } 851 if got, want := response.NextPage, 4; want != got { 852 t.Errorf("response.NextPage: %v, want %v", got, want) 853 } 854 if got, want := response.LastPage, 5; want != got { 855 t.Errorf("response.LastPage: %v, want %v", got, want) 856 } 857 if got, want := response.NextPageToken, ""; want != got { 858 t.Errorf("response.NextPageToken: %v, want %v", got, 
want)
	}
}

// TestResponse_SinceWithPage checks that page numbers are still extracted
// from Link header URLs that also carry a "since" query parameter.
func TestResponse_SinceWithPage(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=1>; rel="first",` +
				` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=2>; rel="prev",` +
				` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=4>; rel="next",` +
				` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=5>; rel="last"`,
			},
		},
	}

	response := newResponse(&r)
	if got, want := response.FirstPage, 1; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 2; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 4; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 5; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}
	if got, want := response.NextPageToken, ""; got != want {
		t.Errorf("response.NextPageToken: %v, want %v", got, want)
	}
}

// TestResponse_cursorPagination checks that a non-numeric "page" value is
// exposed as NextPageToken (leaving the numeric page fields at zero) and
// that a "cursor" query parameter populates Response.Cursor.
func TestResponse_cursorPagination(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Status": {"200 OK"},
			"Link":   {`<https://api.github.com/resource?per_page=2&page=url-encoded-next-page-token>; rel="next"`},
		},
	}

	response := newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 0; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 0; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 0; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}
	if got, want := response.NextPageToken, "url-encoded-next-page-token"; got != want {
		t.Errorf("response.NextPageToken: %v, want %v", got, want)
	}

	// cursor-based pagination with "cursor" param
	r = http.Response{
		Header: http.Header{
			"Link": {
				`<https://api.github.com/?cursor=v1_12345678>; rel="next"`,
			},
		},
	}

	response = newResponse(&r)
	if got, want := response.Cursor, "v1_12345678"; got != want {
		t.Errorf("response.Cursor: %v, want %v", got, want)
	}
}

// TestResponse_beforeAfterPagination checks that "before"/"after" query
// parameters are surfaced as Before/After while page-based fields stay empty.
func TestResponse_beforeAfterPagination(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/?after=a1b2c3&before=>; rel="next",` +
				` <https://api.github.com/?after=&before=>; rel="first",` +
				` <https://api.github.com/?after=&before=d4e5f6>; rel="prev",`,
			},
		},
	}

	response := newResponse(&r)
	if got, want := response.Before, "d4e5f6"; got != want {
		t.Errorf("response.Before: %v, want %v", got, want)
	}
	if got, want := response.After, "a1b2c3"; got != want {
		t.Errorf("response.After: %v, want %v", got, want)
	}
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 0; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 0; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 0; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}
	if got, want := response.NextPageToken, ""; got != want {
		t.Errorf("response.NextPageToken: %v, want %v", got, want)
	}
}

// TestResponse_populatePageValues_invalid checks that malformed Link entries
// (missing rel, non-numeric or empty page, missing angle brackets, unparsable
// URL) leave every page field at zero.
func TestResponse_populatePageValues_invalid(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/?page=1>,` +
				`<https://api.github.com/?page=abc>; rel="first",` +
				`https://api.github.com/?page=2; rel="prev",` +
				`<https://api.github.com/>; rel="next",` +
				`<https://api.github.com/?page=>; rel="last"`,
			},
		},
	}

	response := newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 0; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 0; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 0; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}

	// more invalid URLs
	r = http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/%?page=2>; rel="first"`},
		},
	}

	response = newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
}

// TestResponse_populateSinceValues_invalid is the "since" counterpart of
// TestResponse_populatePageValues_invalid: malformed entries must leave every
// page field at zero.
func TestResponse_populateSinceValues_invalid(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/?since=1>,` +
				`<https://api.github.com/?since=abc>; rel="first",` +
				`https://api.github.com/?since=2; rel="prev",` +
				`<https://api.github.com/>; rel="next",` +
				`<https://api.github.com/?since=>; rel="last"`,
			},
		},
	}

	response := newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 0; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 0; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 0; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}

	// more invalid URLs
	r = http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/%?since=2>; rel="first"`},
		},
	}

	response = newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
}

// TestDo checks that Client.Do decodes a JSON response body into v.
func TestDo(t *testing.T) {
	t.Parallel()
	client, mux, _ := setup(t)

	type foo struct {
		A string
	}

	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		testMethod(t, r, "GET")
		fmt.Fprint(w, `{"A":"a"}`)
	})

	req, err := client.NewRequest("GET", ".", nil)
	if err != nil {
		t.Fatalf("NewRequest returned unexpected error: %v", err)
	}
	body := new(foo)
	ctx := context.Background()
	_, err = client.Do(ctx, req, body)
	assertNilError(t, err)

	want := &foo{"a"}
	if !cmp.Equal(body, want) {
		t.Errorf("Response body = %v, want %v", body, want)
	}
}

// TestDo_nilContext checks that Client.Do rejects a nil context with
// errNonNilContext.
func TestDo_nilContext(t *testing.T) {
	t.Parallel()
	client, _, _ := setup(t)

	req, _ := client.NewRequest("GET", ".", nil)
	// Passing nil is deliberate: rejecting it is the behavior under test.
	_, err := client.Do(nil, req, nil)

	if !errors.Is(err, errNonNilContext) {
		t.Errorf("Do with nil context returned error %v, want %v", err, errNonNilContext)
	}
}

// TestDo_httpError checks that a non-2xx status yields an error and that the
// response (with its status code) is still returned to the caller.
func TestDo_httpError(t *testing.T) {
	t.Parallel()
	client, mux, _ := setup(t)

	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Bad Request", http.StatusBadRequest)
	})

	req, _ := client.NewRequest("GET", ".", nil)
	ctx := context.Background()
	resp, err := client.Do(ctx, req, nil)

	if err == nil {
		t.Fatal("Expected HTTP 400 error, got no error.")
	}
	// Guard against a nil response so a regression fails the test instead of panicking.
	if resp == nil {
		t.Fatal("Expected a response to be returned along with the HTTP 400 error.")
	}
	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected HTTP 400 error, got %d status code.", resp.StatusCode)
	}
}

// Test handling of an error caused by the internal http client's Do()
// function.
A redirect loop is pretty unlikely to occur within the GitHub 1107 // API, but does allow us to exercise the right code path. 1108 func TestDo_redirectLoop(t *testing.T) { 1109 t.Parallel() 1110 client, mux, _ := setup(t) 1111 1112 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1113 http.Redirect(w, r, baseURLPath, http.StatusFound) 1114 }) 1115 1116 req, _ := client.NewRequest("GET", ".", nil) 1117 ctx := context.Background() 1118 _, err := client.Do(ctx, req, nil) 1119 1120 if err == nil { 1121 t.Error("Expected error to be returned.") 1122 } 1123 if err, ok := err.(*url.Error); !ok { 1124 t.Errorf("Expected a URL error; got %#v.", err) 1125 } 1126 } 1127 1128 func TestDo_preservesResponseInHTTPError(t *testing.T) { 1129 t.Parallel() 1130 client, mux, _ := setup(t) 1131 1132 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1133 w.Header().Set("Content-Type", "application/json") 1134 w.WriteHeader(http.StatusNotFound) 1135 fmt.Fprintf(w, `{ 1136 "message": "Resource not found", 1137 "documentation_url": "https://docs.github.com/rest/reference/repos#get-a-repository" 1138 }`) 1139 }) 1140 1141 req, _ := client.NewRequest("GET", ".", nil) 1142 var resp *Response 1143 var data interface{} 1144 resp, err := client.Do(context.Background(), req, &data) 1145 1146 if err == nil { 1147 t.Fatal("Expected error response") 1148 } 1149 1150 // Verify error type and access to status code 1151 errResp, ok := err.(*ErrorResponse) 1152 if !ok { 1153 t.Fatalf("Expected *ErrorResponse error, got %T", err) 1154 } 1155 1156 // Verify status code is accessible from both Response and ErrorResponse 1157 if resp == nil { 1158 t.Fatal("Expected response to be returned even with error") 1159 } 1160 if got, want := resp.StatusCode, http.StatusNotFound; got != want { 1161 t.Errorf("Response status = %d, want %d", got, want) 1162 } 1163 if got, want := errResp.Response.StatusCode, http.StatusNotFound; got != want { 1164 t.Errorf("Error response status = 
%d, want %d", got, want) 1165 } 1166 1167 // Verify error contains proper message 1168 if !strings.Contains(errResp.Message, "Resource not found") { 1169 t.Errorf("Error message = %q, want to contain 'Resource not found'", errResp.Message) 1170 } 1171 } 1172 1173 // Test that an error caused by the internal http client's Do() function 1174 // does not leak the client secret. 1175 func TestDo_sanitizeURL(t *testing.T) { 1176 t.Parallel() 1177 tp := &UnauthenticatedRateLimitedTransport{ 1178 ClientID: "id", 1179 ClientSecret: "secret", 1180 } 1181 unauthedClient := NewClient(tp.Client()) 1182 unauthedClient.BaseURL = &url.URL{Scheme: "http", Host: "127.0.0.1:0", Path: "/"} // Use port 0 on purpose to trigger a dial TCP error, expect to get "dial tcp 127.0.0.1:0: connect: can't assign requested address". 1183 req, err := unauthedClient.NewRequest("GET", ".", nil) 1184 if err != nil { 1185 t.Fatalf("NewRequest returned unexpected error: %v", err) 1186 } 1187 ctx := context.Background() 1188 _, err = unauthedClient.Do(ctx, req, nil) 1189 if err == nil { 1190 t.Fatal("Expected error to be returned.") 1191 } 1192 if strings.Contains(err.Error(), "client_secret=secret") { 1193 t.Errorf("Do error contains secret, should be redacted:\n%q", err) 1194 } 1195 } 1196 1197 func TestDo_rateLimit(t *testing.T) { 1198 t.Parallel() 1199 client, mux, _ := setup(t) 1200 1201 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1202 w.Header().Set(headerRateLimit, "60") 1203 w.Header().Set(headerRateRemaining, "59") 1204 w.Header().Set(headerRateUsed, "1") 1205 w.Header().Set(headerRateReset, "1372700873") 1206 w.Header().Set(headerRateResource, "core") 1207 }) 1208 1209 req, _ := client.NewRequest("GET", ".", nil) 1210 ctx := context.Background() 1211 resp, err := client.Do(ctx, req, nil) 1212 if err != nil { 1213 t.Errorf("Do returned unexpected error: %v", err) 1214 } 1215 if got, want := resp.Rate.Limit, 60; got != want { 1216 t.Errorf("Client rate limit = %v, want 
%v", got, want) 1217 } 1218 if got, want := resp.Rate.Remaining, 59; got != want { 1219 t.Errorf("Client rate remaining = %v, want %v", got, want) 1220 } 1221 if got, want := resp.Rate.Used, 1; got != want { 1222 t.Errorf("Client rate used = %v, want %v", got, want) 1223 } 1224 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1225 if !resp.Rate.Reset.UTC().Equal(reset) { 1226 t.Errorf("Client rate reset = %v, want %v", resp.Rate.Reset.UTC(), reset) 1227 } 1228 if got, want := resp.Rate.Resource, "core"; got != want { 1229 t.Errorf("Client rate resource = %v, want %v", got, want) 1230 } 1231 } 1232 1233 func TestDo_rateLimitCategory(t *testing.T) { 1234 t.Parallel() 1235 tests := []struct { 1236 method string 1237 url string 1238 category RateLimitCategory 1239 }{ 1240 { 1241 method: http.MethodGet, 1242 url: "/", 1243 category: CoreCategory, 1244 }, 1245 { 1246 method: http.MethodGet, 1247 url: "/search/issues?q=rate", 1248 category: SearchCategory, 1249 }, 1250 { 1251 method: http.MethodGet, 1252 url: "/graphql", 1253 category: GraphqlCategory, 1254 }, 1255 { 1256 method: http.MethodPost, 1257 url: "/app-manifests/code/conversions", 1258 category: IntegrationManifestCategory, 1259 }, 1260 { 1261 method: http.MethodGet, 1262 url: "/app-manifests/code/conversions", 1263 category: CoreCategory, // only POST requests are in the integration manifest category 1264 }, 1265 { 1266 method: http.MethodPut, 1267 url: "/repos/google/go-github/import", 1268 category: SourceImportCategory, 1269 }, 1270 { 1271 method: http.MethodGet, 1272 url: "/repos/google/go-github/import", 1273 category: CoreCategory, // only PUT requests are in the source import category 1274 }, 1275 { 1276 method: http.MethodPost, 1277 url: "/repos/google/go-github/code-scanning/sarifs", 1278 category: CodeScanningUploadCategory, 1279 }, 1280 { 1281 method: http.MethodGet, 1282 url: "/scim/v2/organizations/ORG/Users", 1283 category: ScimCategory, 1284 }, 1285 { 1286 method: http.MethodPost, 
1287 url: "/repos/google/go-github/dependency-graph/snapshots", 1288 category: DependencySnapshotsCategory, 1289 }, 1290 { 1291 method: http.MethodGet, 1292 url: "/search/code?q=rate", 1293 category: CodeSearchCategory, 1294 }, 1295 { 1296 method: http.MethodGet, 1297 url: "/orgs/google/audit-log", 1298 category: AuditLogCategory, 1299 }, 1300 // missing a check for actionsRunnerRegistrationCategory: API not found 1301 } 1302 1303 for _, tt := range tests { 1304 if got, want := GetRateLimitCategory(tt.method, tt.url), tt.category; got != want { 1305 t.Errorf("expecting category %v, found %v", got, want) 1306 } 1307 } 1308 } 1309 1310 // Ensure rate limit is still parsed, even for error responses. 1311 func TestDo_rateLimit_errorResponse(t *testing.T) { 1312 t.Parallel() 1313 client, mux, _ := setup(t) 1314 1315 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1316 w.Header().Set(headerRateLimit, "60") 1317 w.Header().Set(headerRateRemaining, "59") 1318 w.Header().Set(headerRateUsed, "1") 1319 w.Header().Set(headerRateReset, "1372700873") 1320 w.Header().Set(headerRateResource, "core") 1321 http.Error(w, "Bad Request", 400) 1322 }) 1323 1324 req, _ := client.NewRequest("GET", ".", nil) 1325 ctx := context.Background() 1326 resp, err := client.Do(ctx, req, nil) 1327 if err == nil { 1328 t.Error("Expected error to be returned.") 1329 } 1330 if _, ok := err.(*RateLimitError); ok { 1331 t.Errorf("Did not expect a *RateLimitError error; got %#v.", err) 1332 } 1333 if got, want := resp.Rate.Limit, 60; got != want { 1334 t.Errorf("Client rate limit = %v, want %v", got, want) 1335 } 1336 if got, want := resp.Rate.Remaining, 59; got != want { 1337 t.Errorf("Client rate remaining = %v, want %v", got, want) 1338 } 1339 if got, want := resp.Rate.Used, 1; got != want { 1340 t.Errorf("Client rate used = %v, want %v", got, want) 1341 } 1342 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1343 if !resp.Rate.Reset.UTC().Equal(reset) { 1344 
t.Errorf("Client rate reset = %v, want %v", resp.Rate.Reset, reset) 1345 } 1346 if got, want := resp.Rate.Resource, "core"; got != want { 1347 t.Errorf("Client rate resource = %v, want %v", got, want) 1348 } 1349 } 1350 1351 // Ensure *RateLimitError is returned when API rate limit is exceeded. 1352 func TestDo_rateLimit_rateLimitError(t *testing.T) { 1353 t.Parallel() 1354 client, mux, _ := setup(t) 1355 1356 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1357 w.Header().Set(headerRateLimit, "60") 1358 w.Header().Set(headerRateRemaining, "0") 1359 w.Header().Set(headerRateUsed, "60") 1360 w.Header().Set(headerRateReset, "1372700873") 1361 w.Header().Set(headerRateResource, "core") 1362 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1363 w.WriteHeader(http.StatusForbidden) 1364 fmt.Fprintln(w, `{ 1365 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1366 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1367 }`) 1368 }) 1369 1370 req, _ := client.NewRequest("GET", ".", nil) 1371 ctx := context.Background() 1372 _, err := client.Do(ctx, req, nil) 1373 1374 if err == nil { 1375 t.Error("Expected error to be returned.") 1376 } 1377 rateLimitErr, ok := err.(*RateLimitError) 1378 if !ok { 1379 t.Fatalf("Expected a *RateLimitError error; got %#v.", err) 1380 } 1381 if got, want := rateLimitErr.Rate.Limit, 60; got != want { 1382 t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) 1383 } 1384 if got, want := rateLimitErr.Rate.Remaining, 0; got != want { 1385 t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want) 1386 } 1387 if got, want := rateLimitErr.Rate.Used, 60; got != want { 1388 t.Errorf("rateLimitErr rate used = %v, want %v", got, want) 1389 } 1390 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1391 if 
!rateLimitErr.Rate.Reset.UTC().Equal(reset) { 1392 t.Errorf("rateLimitErr rate reset = %v, want %v", rateLimitErr.Rate.Reset.UTC(), reset) 1393 } 1394 if got, want := rateLimitErr.Rate.Resource, "core"; got != want { 1395 t.Errorf("rateLimitErr rate resource = %v, want %v", got, want) 1396 } 1397 } 1398 1399 // Ensure a network call is not made when it's known that API rate limit is still exceeded. 1400 func TestDo_rateLimit_noNetworkCall(t *testing.T) { 1401 t.Parallel() 1402 client, mux, _ := setup(t) 1403 1404 reset := time.Now().UTC().Add(time.Minute).Round(time.Second) // Rate reset is a minute from now, with 1 second precision. 1405 1406 mux.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { 1407 w.Header().Set(headerRateLimit, "60") 1408 w.Header().Set(headerRateRemaining, "0") 1409 w.Header().Set(headerRateUsed, "60") 1410 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1411 w.Header().Set(headerRateResource, "core") 1412 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1413 w.WriteHeader(http.StatusForbidden) 1414 fmt.Fprintln(w, `{ 1415 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1416 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1417 }`) 1418 }) 1419 1420 madeNetworkCall := false 1421 mux.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { 1422 madeNetworkCall = true 1423 }) 1424 1425 // First request is made, and it makes the client aware of rate reset time being in the future. 1426 req, _ := client.NewRequest("GET", "first", nil) 1427 ctx := context.Background() 1428 _, err := client.Do(ctx, req, nil) 1429 if err == nil { 1430 t.Error("Expected error to be returned.") 1431 } 1432 1433 // Second request should not cause a network call to be made, since client can predict a rate limit error. 
1434 req, _ = client.NewRequest("GET", "second", nil) 1435 _, err = client.Do(ctx, req, nil) 1436 1437 if madeNetworkCall { 1438 t.Fatal("Network call was made, even though rate limit is known to still be exceeded.") 1439 } 1440 1441 if err == nil { 1442 t.Error("Expected error to be returned.") 1443 } 1444 rateLimitErr, ok := err.(*RateLimitError) 1445 if !ok { 1446 t.Fatalf("Expected a *RateLimitError error; got %#v.", err) 1447 } 1448 if got, want := rateLimitErr.Rate.Limit, 60; got != want { 1449 t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) 1450 } 1451 if got, want := rateLimitErr.Rate.Remaining, 0; got != want { 1452 t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want) 1453 } 1454 if got, want := rateLimitErr.Rate.Used, 60; got != want { 1455 t.Errorf("rateLimitErr rate used = %v, want %v", got, want) 1456 } 1457 if !rateLimitErr.Rate.Reset.UTC().Equal(reset) { 1458 t.Errorf("rateLimitErr rate reset = %v, want %v", rateLimitErr.Rate.Reset.UTC(), reset) 1459 } 1460 if got, want := rateLimitErr.Rate.Resource, "core"; got != want { 1461 t.Errorf("rateLimitErr rate resource = %v, want %v", got, want) 1462 } 1463 } 1464 1465 // Ignore rate limit headers if the response was served from cache. 1466 func TestDo_rateLimit_ignoredFromCache(t *testing.T) { 1467 t.Parallel() 1468 client, mux, _ := setup(t) 1469 1470 reset := time.Now().UTC().Add(time.Minute).Round(time.Second) // Rate reset is a minute from now, with 1 second precision. 1471 1472 // By adding the X-From-Cache header we pretend this is served from a cache. 
1473 mux.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { 1474 w.Header().Set("X-From-Cache", "1") 1475 w.Header().Set(headerRateLimit, "60") 1476 w.Header().Set(headerRateRemaining, "0") 1477 w.Header().Set(headerRateUsed, "60") 1478 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1479 w.Header().Set(headerRateResource, "core") 1480 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1481 w.WriteHeader(http.StatusForbidden) 1482 fmt.Fprintln(w, `{ 1483 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1484 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1485 }`) 1486 }) 1487 1488 madeNetworkCall := false 1489 mux.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { 1490 madeNetworkCall = true 1491 }) 1492 1493 // First request is made so afterwards we can check the returned rate limit headers were ignored. 1494 req, _ := client.NewRequest("GET", "first", nil) 1495 ctx := context.Background() 1496 _, err := client.Do(ctx, req, nil) 1497 if err == nil { 1498 t.Error("Expected error to be returned.") 1499 } 1500 1501 // Second request should not by hindered by rate limits. 1502 req, _ = client.NewRequest("GET", "second", nil) 1503 _, err = client.Do(ctx, req, nil) 1504 1505 if err != nil { 1506 t.Fatalf("Second request failed, even though the rate limits from the cache should've been ignored: %v", err) 1507 } 1508 if !madeNetworkCall { 1509 t.Fatal("Network call was not made, even though the rate limits from the cache should've been ignored") 1510 } 1511 } 1512 1513 // Ensure sleeps until the rate limit is reset when the client is rate limited. 
1514 func TestDo_rateLimit_sleepUntilResponseResetLimit(t *testing.T) { 1515 t.Parallel() 1516 client, mux, _ := setup(t) 1517 1518 reset := time.Now().UTC().Add(time.Second) 1519 1520 var firstRequest = true 1521 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1522 if firstRequest { 1523 firstRequest = false 1524 w.Header().Set(headerRateLimit, "60") 1525 w.Header().Set(headerRateRemaining, "0") 1526 w.Header().Set(headerRateUsed, "60") 1527 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1528 w.Header().Set(headerRateResource, "core") 1529 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1530 w.WriteHeader(http.StatusForbidden) 1531 fmt.Fprintln(w, `{ 1532 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1533 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1534 }`) 1535 return 1536 } 1537 w.Header().Set(headerRateLimit, "5000") 1538 w.Header().Set(headerRateRemaining, "5000") 1539 w.Header().Set(headerRateUsed, "0") 1540 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1541 w.Header().Set(headerRateResource, "core") 1542 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1543 w.WriteHeader(http.StatusOK) 1544 fmt.Fprintln(w, `{}`) 1545 }) 1546 1547 req, _ := client.NewRequest("GET", ".", nil) 1548 ctx := context.Background() 1549 resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1550 if err != nil { 1551 t.Errorf("Do returned unexpected error: %v", err) 1552 } 1553 if got, want := resp.StatusCode, http.StatusOK; got != want { 1554 t.Errorf("Response status code = %v, want %v", got, want) 1555 } 1556 } 1557 1558 // Ensure tries to sleep until the rate limit is reset when the client is rate limited, but only once. 
1559 func TestDo_rateLimit_sleepUntilResponseResetLimitRetryOnce(t *testing.T) { 1560 t.Parallel() 1561 client, mux, _ := setup(t) 1562 1563 reset := time.Now().UTC().Add(time.Second) 1564 1565 requestCount := 0 1566 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1567 requestCount++ 1568 w.Header().Set(headerRateLimit, "60") 1569 w.Header().Set(headerRateRemaining, "0") 1570 w.Header().Set(headerRateUsed, "60") 1571 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1572 w.Header().Set(headerRateResource, "core") 1573 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1574 w.WriteHeader(http.StatusForbidden) 1575 fmt.Fprintln(w, `{ 1576 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1577 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1578 }`) 1579 }) 1580 1581 req, _ := client.NewRequest("GET", ".", nil) 1582 ctx := context.Background() 1583 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1584 if err == nil { 1585 t.Error("Expected error to be returned.") 1586 } 1587 if got, want := requestCount, 2; got != want { 1588 t.Errorf("Expected 2 requests, got %d", got) 1589 } 1590 } 1591 1592 // Ensure a network call is not made when it's known that API rate limit is still exceeded. 
1593 func TestDo_rateLimit_sleepUntilClientResetLimit(t *testing.T) { 1594 t.Parallel() 1595 client, mux, _ := setup(t) 1596 1597 reset := time.Now().UTC().Add(time.Second) 1598 client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} 1599 requestCount := 0 1600 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1601 requestCount++ 1602 w.Header().Set(headerRateLimit, "5000") 1603 w.Header().Set(headerRateRemaining, "5000") 1604 w.Header().Set(headerRateUsed, "0") 1605 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1606 w.Header().Set(headerRateResource, "core") 1607 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1608 w.WriteHeader(http.StatusOK) 1609 fmt.Fprintln(w, `{}`) 1610 }) 1611 req, _ := client.NewRequest("GET", ".", nil) 1612 ctx := context.Background() 1613 resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1614 if err != nil { 1615 t.Errorf("Do returned unexpected error: %v", err) 1616 } 1617 if got, want := resp.StatusCode, http.StatusOK; got != want { 1618 t.Errorf("Response status code = %v, want %v", got, want) 1619 } 1620 if got, want := requestCount, 1; got != want { 1621 t.Errorf("Expected 1 request, got %d", got) 1622 } 1623 } 1624 1625 // Ensure sleep is aborted when the context is cancelled. 1626 func TestDo_rateLimit_abortSleepContextCancelled(t *testing.T) { 1627 t.Parallel() 1628 client, mux, _ := setup(t) 1629 1630 // We use a 1 minute reset time to ensure the sleep is not completed. 
1631 reset := time.Now().UTC().Add(time.Minute) 1632 requestCount := 0 1633 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1634 requestCount++ 1635 w.Header().Set(headerRateLimit, "60") 1636 w.Header().Set(headerRateRemaining, "0") 1637 w.Header().Set(headerRateUsed, "60") 1638 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1639 w.Header().Set(headerRateResource, "core") 1640 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1641 w.WriteHeader(http.StatusForbidden) 1642 fmt.Fprintln(w, `{ 1643 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1644 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1645 }`) 1646 }) 1647 1648 req, _ := client.NewRequest("GET", ".", nil) 1649 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 1650 defer cancel() 1651 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1652 if !errors.Is(err, context.DeadlineExceeded) { 1653 t.Error("Expected context deadline exceeded error.") 1654 } 1655 if got, want := requestCount, 1; got != want { 1656 t.Errorf("Expected 1 requests, got %d", got) 1657 } 1658 } 1659 1660 // Ensure sleep is aborted when the context is cancelled on initial request. 
1661 func TestDo_rateLimit_abortSleepContextCancelledClientLimit(t *testing.T) { 1662 t.Parallel() 1663 client, mux, _ := setup(t) 1664 1665 reset := time.Now().UTC().Add(time.Minute) 1666 client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} 1667 requestCount := 0 1668 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1669 requestCount++ 1670 w.Header().Set(headerRateLimit, "5000") 1671 w.Header().Set(headerRateRemaining, "5000") 1672 w.Header().Set(headerRateUsed, "0") 1673 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1674 w.Header().Set(headerRateResource, "core") 1675 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1676 w.WriteHeader(http.StatusOK) 1677 fmt.Fprintln(w, `{}`) 1678 }) 1679 req, _ := client.NewRequest("GET", ".", nil) 1680 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 1681 defer cancel() 1682 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1683 rateLimitError, ok := err.(*RateLimitError) 1684 if !ok { 1685 t.Fatalf("Expected a *rateLimitError error; got %#v.", err) 1686 } 1687 if got, wantSuffix := rateLimitError.Message, "Context cancelled while waiting for rate limit to reset until"; !strings.HasPrefix(got, wantSuffix) { 1688 t.Errorf("Expected request to be prevented because context cancellation, got: %v.", got) 1689 } 1690 if got, want := requestCount, 0; got != want { 1691 t.Errorf("Expected 1 requests, got %d", got) 1692 } 1693 } 1694 1695 // Ensure *AbuseRateLimitError is returned when the response indicates that 1696 // the client has triggered an abuse detection mechanism. 
1697 func TestDo_rateLimit_abuseRateLimitError(t *testing.T) { 1698 t.Parallel() 1699 client, mux, _ := setup(t) 1700 1701 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1702 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1703 w.WriteHeader(http.StatusForbidden) 1704 // When the abuse rate limit error is of the "temporarily blocked from content creation" type, 1705 // there is no "Retry-After" header. 1706 fmt.Fprintln(w, `{ 1707 "message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.", 1708 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1709 }`) 1710 }) 1711 1712 req, _ := client.NewRequest("GET", ".", nil) 1713 ctx := context.Background() 1714 _, err := client.Do(ctx, req, nil) 1715 1716 if err == nil { 1717 t.Error("Expected error to be returned.") 1718 } 1719 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1720 if !ok { 1721 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1722 } 1723 if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { 1724 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1725 } 1726 } 1727 1728 // Ensure *AbuseRateLimitError is returned when the response indicates that 1729 // the client has triggered an abuse detection mechanism on GitHub Enterprise. 1730 func TestDo_rateLimit_abuseRateLimitErrorEnterprise(t *testing.T) { 1731 t.Parallel() 1732 client, mux, _ := setup(t) 1733 1734 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1735 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1736 w.WriteHeader(http.StatusForbidden) 1737 // When the abuse rate limit error is of the "temporarily blocked from content creation" type, 1738 // there is no "Retry-After" header. 
1739 // This response returns a documentation url like the one returned for GitHub Enterprise, this 1740 // url changes between versions but follows roughly the same format. 1741 fmt.Fprintln(w, `{ 1742 "message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.", 1743 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1744 }`) 1745 }) 1746 1747 req, _ := client.NewRequest("GET", ".", nil) 1748 ctx := context.Background() 1749 _, err := client.Do(ctx, req, nil) 1750 1751 if err == nil { 1752 t.Error("Expected error to be returned.") 1753 } 1754 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1755 if !ok { 1756 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1757 } 1758 if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { 1759 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1760 } 1761 } 1762 1763 // Ensure *AbuseRateLimitError.RetryAfter is parsed correctly for the Retry-After header. 1764 func TestDo_rateLimit_abuseRateLimitError_retryAfter(t *testing.T) { 1765 t.Parallel() 1766 client, mux, _ := setup(t) 1767 1768 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1769 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1770 w.Header().Set(headerRetryAfter, "123") // Retry after value of 123 seconds. 
1771 w.WriteHeader(http.StatusForbidden) 1772 fmt.Fprintln(w, `{ 1773 "message": "You have triggered an abuse detection mechanism ...", 1774 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1775 }`) 1776 }) 1777 1778 req, _ := client.NewRequest("GET", ".", nil) 1779 ctx := context.Background() 1780 _, err := client.Do(ctx, req, nil) 1781 1782 if err == nil { 1783 t.Error("Expected error to be returned.") 1784 } 1785 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1786 if !ok { 1787 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1788 } 1789 if abuseRateLimitErr.RetryAfter == nil { 1790 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1791 } 1792 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; got != want { 1793 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1794 } 1795 1796 // expect prevention of a following request 1797 if _, err = client.Do(ctx, req, nil); err == nil { 1798 t.Error("Expected error to be returned.") 1799 } 1800 abuseRateLimitErr, ok = err.(*AbuseRateLimitError) 1801 if !ok { 1802 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1803 } 1804 if abuseRateLimitErr.RetryAfter == nil { 1805 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1806 } 1807 // the saved duration might be a bit smaller than Retry-After because the duration is calculated from the expected end-of-cooldown time 1808 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1809 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1810 } 1811 if got, wantSuffix := abuseRateLimitErr.Message, "not making remote request."; !strings.HasSuffix(got, wantSuffix) { 1812 t.Errorf("Expected request to be prevented because of secondary rate limit, got: %v.", got) 1813 } 1814 } 1815 1816 // Ensure *AbuseRateLimitError.RetryAfter is parsed correctly for the x-ratelimit-reset header. 
1817 func TestDo_rateLimit_abuseRateLimitError_xRateLimitReset(t *testing.T) { 1818 t.Parallel() 1819 client, mux, _ := setup(t) 1820 1821 // x-ratelimit-reset value of 123 seconds into the future. 1822 blockUntil := time.Now().Add(time.Duration(123) * time.Second).Unix() 1823 1824 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1825 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1826 w.Header().Set(headerRateReset, strconv.Itoa(int(blockUntil))) 1827 w.Header().Set(headerRateRemaining, "1") // set remaining to a value > 0 to distinct from a primary rate limit 1828 w.WriteHeader(http.StatusForbidden) 1829 fmt.Fprintln(w, `{ 1830 "message": "You have triggered an abuse detection mechanism ...", 1831 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1832 }`) 1833 }) 1834 1835 req, _ := client.NewRequest("GET", ".", nil) 1836 ctx := context.Background() 1837 _, err := client.Do(ctx, req, nil) 1838 1839 if err == nil { 1840 t.Error("Expected error to be returned.") 1841 } 1842 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1843 if !ok { 1844 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1845 } 1846 if abuseRateLimitErr.RetryAfter == nil { 1847 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1848 } 1849 // the retry after value might be a bit smaller than the original duration because the duration is calculated from the expected end-of-cooldown time 1850 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1851 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1852 } 1853 1854 // expect prevention of a following request 1855 if _, err = client.Do(ctx, req, nil); err == nil { 1856 t.Error("Expected error to be returned.") 1857 } 1858 abuseRateLimitErr, ok = err.(*AbuseRateLimitError) 1859 if !ok { 1860 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1861 } 1862 if 
abuseRateLimitErr.RetryAfter == nil { 1863 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1864 } 1865 // the saved duration might be a bit smaller than Retry-After because the duration is calculated from the expected end-of-cooldown time 1866 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1867 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1868 } 1869 if got, wantSuffix := abuseRateLimitErr.Message, "not making remote request."; !strings.HasSuffix(got, wantSuffix) { 1870 t.Errorf("Expected request to be prevented because of secondary rate limit, got: %v.", got) 1871 } 1872 } 1873 1874 // Ensure *AbuseRateLimitError.RetryAfter respect a max duration if specified. 1875 func TestDo_rateLimit_abuseRateLimitError_maxDuration(t *testing.T) { 1876 t.Parallel() 1877 client, mux, _ := setup(t) 1878 // specify a max retry after duration of 1 min 1879 client.MaxSecondaryRateLimitRetryAfterDuration = 60 * time.Second 1880 1881 // x-ratelimit-reset value of 1h into the future, to make sure we are way over the max wait time duration. 
1882 blockUntil := time.Now().Add(1 * time.Hour).Unix() 1883 1884 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1885 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1886 w.Header().Set(headerRateReset, strconv.Itoa(int(blockUntil))) 1887 w.Header().Set(headerRateRemaining, "1") // set remaining to a value > 0 to distinct from a primary rate limit 1888 w.WriteHeader(http.StatusForbidden) 1889 fmt.Fprintln(w, `{ 1890 "message": "You have triggered an abuse detection mechanism ...", 1891 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1892 }`) 1893 }) 1894 1895 req, _ := client.NewRequest("GET", ".", nil) 1896 ctx := context.Background() 1897 _, err := client.Do(ctx, req, nil) 1898 1899 if err == nil { 1900 t.Error("Expected error to be returned.") 1901 } 1902 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1903 if !ok { 1904 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1905 } 1906 if abuseRateLimitErr.RetryAfter == nil { 1907 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1908 } 1909 // check that the retry after is set to be the max allowed duration 1910 if got, want := *abuseRateLimitErr.RetryAfter, client.MaxSecondaryRateLimitRetryAfterDuration; got != want { 1911 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1912 } 1913 } 1914 1915 func TestDo_noContent(t *testing.T) { 1916 t.Parallel() 1917 client, mux, _ := setup(t) 1918 1919 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1920 w.WriteHeader(http.StatusNoContent) 1921 }) 1922 1923 var body json.RawMessage 1924 1925 req, _ := client.NewRequest("GET", ".", nil) 1926 ctx := context.Background() 1927 _, err := client.Do(ctx, req, &body) 1928 if err != nil { 1929 t.Fatalf("Do returned unexpected error: %v", err) 1930 } 1931 } 1932 1933 func TestBareDoUntilFound_redirectLoop(t *testing.T) { 1934 t.Parallel() 1935 client, mux, _ := setup(t) 
1936 1937 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1938 http.Redirect(w, r, baseURLPath, http.StatusMovedPermanently) 1939 }) 1940 1941 req, _ := client.NewRequest("GET", ".", nil) 1942 ctx := context.Background() 1943 _, _, err := client.bareDoUntilFound(ctx, req, 1) 1944 1945 if err == nil { 1946 t.Error("Expected error to be returned.") 1947 } 1948 var rerr *RedirectionError 1949 if !errors.As(err, &rerr) { 1950 t.Errorf("Expected a Redirection error; got %#v.", err) 1951 } 1952 } 1953 1954 func TestBareDoUntilFound_UnexpectedRedirection(t *testing.T) { 1955 t.Parallel() 1956 client, mux, _ := setup(t) 1957 1958 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1959 http.Redirect(w, r, baseURLPath, http.StatusSeeOther) 1960 }) 1961 1962 req, _ := client.NewRequest("GET", ".", nil) 1963 ctx := context.Background() 1964 _, _, err := client.bareDoUntilFound(ctx, req, 1) 1965 1966 if err == nil { 1967 t.Error("Expected error to be returned.") 1968 } 1969 var rerr *RedirectionError 1970 if !errors.As(err, &rerr) { 1971 t.Errorf("Expected a Redirection error; got %#v.", err) 1972 } 1973 } 1974 1975 func TestSanitizeURL(t *testing.T) { 1976 t.Parallel() 1977 tests := []struct { 1978 in, want string 1979 }{ 1980 {"/?a=b", "/?a=b"}, 1981 {"/?a=b&client_secret=secret", "/?a=b&client_secret=REDACTED"}, 1982 {"/?a=b&client_id=id&client_secret=secret", "/?a=b&client_id=id&client_secret=REDACTED"}, 1983 } 1984 1985 for _, tt := range tests { 1986 inURL, _ := url.Parse(tt.in) 1987 want, _ := url.Parse(tt.want) 1988 1989 if got := sanitizeURL(inURL); !cmp.Equal(got, want) { 1990 t.Errorf("sanitizeURL(%v) returned %v, want %v", tt.in, got, want) 1991 } 1992 } 1993 } 1994 1995 func TestCheckResponse(t *testing.T) { 1996 t.Parallel() 1997 res := &http.Response{ 1998 Request: &http.Request{}, 1999 StatusCode: http.StatusBadRequest, 2000 Body: io.NopCloser(strings.NewReader(`{"message":"m", 2001 "errors": [{"resource": "r", "field": "f", 
"code": "c"}], 2002 "block": {"reason": "dmca", "created_at": "2016-03-17T15:39:46Z"}}`)), 2003 } 2004 err := CheckResponse(res).(*ErrorResponse) 2005 2006 if err == nil { 2007 t.Errorf("Expected error response.") 2008 } 2009 2010 want := &ErrorResponse{ 2011 Response: res, 2012 Message: "m", 2013 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2014 Block: &ErrorBlock{ 2015 Reason: "dmca", 2016 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2017 }, 2018 } 2019 if !errors.Is(err, want) { 2020 t.Errorf("Error = %#v, want %#v", err, want) 2021 } 2022 } 2023 2024 func TestCheckResponse_RateLimit(t *testing.T) { 2025 t.Parallel() 2026 res := &http.Response{ 2027 Request: &http.Request{}, 2028 StatusCode: http.StatusForbidden, 2029 Header: http.Header{}, 2030 Body: io.NopCloser(strings.NewReader(`{"message":"m", 2031 "documentation_url": "url"}`)), 2032 } 2033 res.Header.Set(headerRateLimit, "60") 2034 res.Header.Set(headerRateRemaining, "0") 2035 res.Header.Set(headerRateUsed, "1") 2036 res.Header.Set(headerRateReset, "243424") 2037 res.Header.Set(headerRateResource, "core") 2038 2039 err := CheckResponse(res).(*RateLimitError) 2040 2041 if err == nil { 2042 t.Errorf("Expected error response.") 2043 } 2044 2045 want := &RateLimitError{ 2046 Rate: parseRate(res), 2047 Response: res, 2048 Message: "m", 2049 } 2050 if !errors.Is(err, want) { 2051 t.Errorf("Error = %#v, want %#v", err, want) 2052 } 2053 } 2054 2055 func TestCheckResponse_AbuseRateLimit(t *testing.T) { 2056 t.Parallel() 2057 res := &http.Response{ 2058 Request: &http.Request{}, 2059 StatusCode: http.StatusForbidden, 2060 Body: io.NopCloser(strings.NewReader(`{"message":"m", 2061 "documentation_url": "docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"}`)), 2062 } 2063 err := CheckResponse(res).(*AbuseRateLimitError) 2064 2065 if err == nil { 2066 t.Errorf("Expected error response.") 2067 } 2068 2069 want := &AbuseRateLimitError{ 2070 
Response: res, 2071 Message: "m", 2072 } 2073 if !errors.Is(err, want) { 2074 t.Errorf("Error = %#v, want %#v", err, want) 2075 } 2076 } 2077 2078 func TestCheckResponse_RedirectionError(t *testing.T) { 2079 t.Parallel() 2080 urlStr := "/foo/bar" 2081 2082 res := &http.Response{ 2083 Request: &http.Request{}, 2084 StatusCode: http.StatusFound, 2085 Header: http.Header{}, 2086 Body: io.NopCloser(strings.NewReader(``)), 2087 } 2088 res.Header.Set("Location", urlStr) 2089 err := CheckResponse(res).(*RedirectionError) 2090 2091 if err == nil { 2092 t.Errorf("Expected error response.") 2093 } 2094 2095 wantedURL, parseErr := url.Parse(urlStr) 2096 if parseErr != nil { 2097 t.Errorf("Error parsing fixture url: %v", parseErr) 2098 } 2099 2100 want := &RedirectionError{ 2101 Response: res, 2102 StatusCode: http.StatusFound, 2103 Location: wantedURL, 2104 } 2105 if !errors.Is(err, want) { 2106 t.Errorf("Error = %#v, want %#v", err, want) 2107 } 2108 } 2109 2110 func TestCompareHttpResponse(t *testing.T) { 2111 t.Parallel() 2112 testcases := map[string]struct { 2113 h1 *http.Response 2114 h2 *http.Response 2115 expected bool 2116 }{ 2117 "both are nil": { 2118 expected: true, 2119 }, 2120 "both are non nil - same StatusCode": { 2121 expected: true, 2122 h1: &http.Response{StatusCode: 200}, 2123 h2: &http.Response{StatusCode: 200}, 2124 }, 2125 "both are non nil - different StatusCode": { 2126 expected: false, 2127 h1: &http.Response{StatusCode: 200}, 2128 h2: &http.Response{StatusCode: 404}, 2129 }, 2130 "one is nil, other is not": { 2131 expected: false, 2132 h2: &http.Response{}, 2133 }, 2134 } 2135 2136 for name, tc := range testcases { 2137 t.Run(name, func(t *testing.T) { 2138 t.Parallel() 2139 v := compareHTTPResponse(tc.h1, tc.h2) 2140 if tc.expected != v { 2141 t.Errorf("Expected %t, got %t for (%#v, %#v)", tc.expected, v, tc.h1, tc.h2) 2142 } 2143 }) 2144 } 2145 } 2146 2147 func TestErrorResponse_Is(t *testing.T) { 2148 t.Parallel() 2149 err := &ErrorResponse{ 2150 
Response: &http.Response{}, 2151 Message: "m", 2152 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2153 Block: &ErrorBlock{ 2154 Reason: "r", 2155 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2156 }, 2157 DocumentationURL: "https://github.com", 2158 } 2159 testcases := map[string]struct { 2160 wantSame bool 2161 otherError error 2162 }{ 2163 "errors are same": { 2164 wantSame: true, 2165 otherError: &ErrorResponse{ 2166 Response: &http.Response{}, 2167 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2168 Message: "m", 2169 Block: &ErrorBlock{ 2170 Reason: "r", 2171 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2172 }, 2173 DocumentationURL: "https://github.com", 2174 }, 2175 }, 2176 "errors have different values - Message": { 2177 wantSame: false, 2178 otherError: &ErrorResponse{ 2179 Response: &http.Response{}, 2180 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2181 Message: "m1", 2182 Block: &ErrorBlock{ 2183 Reason: "r", 2184 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2185 }, 2186 DocumentationURL: "https://github.com", 2187 }, 2188 }, 2189 "errors have different values - DocumentationURL": { 2190 wantSame: false, 2191 otherError: &ErrorResponse{ 2192 Response: &http.Response{}, 2193 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2194 Message: "m", 2195 Block: &ErrorBlock{ 2196 Reason: "r", 2197 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2198 }, 2199 DocumentationURL: "https://google.com", 2200 }, 2201 }, 2202 "errors have different values - Response is nil": { 2203 wantSame: false, 2204 otherError: &ErrorResponse{ 2205 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2206 Message: "m", 2207 Block: &ErrorBlock{ 2208 Reason: "r", 2209 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2210 }, 2211 DocumentationURL: "https://github.com", 2212 
}, 2213 }, 2214 "errors have different values - Errors": { 2215 wantSame: false, 2216 otherError: &ErrorResponse{ 2217 Response: &http.Response{}, 2218 Errors: []Error{{Resource: "r1", Field: "f1", Code: "c1"}}, 2219 Message: "m", 2220 Block: &ErrorBlock{ 2221 Reason: "r", 2222 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2223 }, 2224 DocumentationURL: "https://github.com", 2225 }, 2226 }, 2227 "errors have different values - Errors have different length": { 2228 wantSame: false, 2229 otherError: &ErrorResponse{ 2230 Response: &http.Response{}, 2231 Errors: []Error{}, 2232 Message: "m", 2233 Block: &ErrorBlock{ 2234 Reason: "r", 2235 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2236 }, 2237 DocumentationURL: "https://github.com", 2238 }, 2239 }, 2240 "errors have different values - Block - one is nil, other is not": { 2241 wantSame: false, 2242 otherError: &ErrorResponse{ 2243 Response: &http.Response{}, 2244 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2245 Message: "m", 2246 DocumentationURL: "https://github.com", 2247 }, 2248 }, 2249 "errors have different values - Block - different Reason": { 2250 wantSame: false, 2251 otherError: &ErrorResponse{ 2252 Response: &http.Response{}, 2253 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2254 Message: "m", 2255 Block: &ErrorBlock{ 2256 Reason: "r1", 2257 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2258 }, 2259 DocumentationURL: "https://github.com", 2260 }, 2261 }, 2262 "errors have different values - Block - different CreatedAt #1": { 2263 wantSame: false, 2264 otherError: &ErrorResponse{ 2265 Response: &http.Response{}, 2266 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2267 Message: "m", 2268 Block: &ErrorBlock{ 2269 Reason: "r", 2270 CreatedAt: nil, 2271 }, 2272 DocumentationURL: "https://github.com", 2273 }, 2274 }, 2275 "errors have different values - Block - different 
CreatedAt #2": { 2276 wantSame: false, 2277 otherError: &ErrorResponse{ 2278 Response: &http.Response{}, 2279 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2280 Message: "m", 2281 Block: &ErrorBlock{ 2282 Reason: "r", 2283 CreatedAt: &Timestamp{time.Date(2017, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2284 }, 2285 DocumentationURL: "https://github.com", 2286 }, 2287 }, 2288 "errors have different types": { 2289 wantSame: false, 2290 otherError: errors.New("github"), 2291 }, 2292 } 2293 2294 for name, tc := range testcases { 2295 t.Run(name, func(t *testing.T) { 2296 t.Parallel() 2297 if tc.wantSame != err.Is(tc.otherError) { 2298 t.Errorf("Error = %#v, want %#v", err, tc.otherError) 2299 } 2300 }) 2301 } 2302 } 2303 2304 func TestRateLimitError_Is(t *testing.T) { 2305 t.Parallel() 2306 err := &RateLimitError{ 2307 Response: &http.Response{}, 2308 Message: "Github", 2309 } 2310 testcases := map[string]struct { 2311 wantSame bool 2312 err *RateLimitError 2313 otherError error 2314 }{ 2315 "errors are same": { 2316 wantSame: true, 2317 err: err, 2318 otherError: &RateLimitError{ 2319 Response: &http.Response{}, 2320 Message: "Github", 2321 }, 2322 }, 2323 "errors are same - Response is nil": { 2324 wantSame: true, 2325 err: &RateLimitError{ 2326 Message: "Github", 2327 }, 2328 otherError: &RateLimitError{ 2329 Message: "Github", 2330 }, 2331 }, 2332 "errors have different values - Rate": { 2333 wantSame: false, 2334 err: err, 2335 otherError: &RateLimitError{ 2336 Rate: Rate{Limit: 10}, 2337 Response: &http.Response{}, 2338 Message: "Gitlab", 2339 }, 2340 }, 2341 "errors have different values - Response is nil": { 2342 wantSame: false, 2343 err: err, 2344 otherError: &RateLimitError{ 2345 Message: "Github", 2346 }, 2347 }, 2348 "errors have different values - StatusCode": { 2349 wantSame: false, 2350 err: err, 2351 otherError: &RateLimitError{ 2352 Response: &http.Response{StatusCode: 200}, 2353 Message: "Github", 2354 }, 2355 }, 2356 "errors have 
different types": { 2357 wantSame: false, 2358 err: err, 2359 otherError: errors.New("github"), 2360 }, 2361 } 2362 2363 for name, tc := range testcases { 2364 t.Run(name, func(t *testing.T) { 2365 t.Parallel() 2366 if tc.wantSame != tc.err.Is(tc.otherError) { 2367 t.Errorf("Error = %#v, want %#v", tc.err, tc.otherError) 2368 } 2369 }) 2370 } 2371 } 2372 2373 func TestAbuseRateLimitError_Is(t *testing.T) { 2374 t.Parallel() 2375 t1 := 1 * time.Second 2376 t2 := 2 * time.Second 2377 err := &AbuseRateLimitError{ 2378 Response: &http.Response{}, 2379 Message: "Github", 2380 RetryAfter: &t1, 2381 } 2382 testcases := map[string]struct { 2383 wantSame bool 2384 err *AbuseRateLimitError 2385 otherError error 2386 }{ 2387 "errors are same": { 2388 wantSame: true, 2389 err: err, 2390 otherError: &AbuseRateLimitError{ 2391 Response: &http.Response{}, 2392 Message: "Github", 2393 RetryAfter: &t1, 2394 }, 2395 }, 2396 "errors are same - Response is nil": { 2397 wantSame: true, 2398 err: &AbuseRateLimitError{ 2399 Message: "Github", 2400 RetryAfter: &t1, 2401 }, 2402 otherError: &AbuseRateLimitError{ 2403 Message: "Github", 2404 RetryAfter: &t1, 2405 }, 2406 }, 2407 "errors have different values - Message": { 2408 wantSame: false, 2409 err: err, 2410 otherError: &AbuseRateLimitError{ 2411 Response: &http.Response{}, 2412 Message: "Gitlab", 2413 RetryAfter: nil, 2414 }, 2415 }, 2416 "errors have different values - RetryAfter": { 2417 wantSame: false, 2418 err: err, 2419 otherError: &AbuseRateLimitError{ 2420 Response: &http.Response{}, 2421 Message: "Github", 2422 RetryAfter: &t2, 2423 }, 2424 }, 2425 "errors have different values - Response is nil": { 2426 wantSame: false, 2427 err: err, 2428 otherError: &AbuseRateLimitError{ 2429 Message: "Github", 2430 RetryAfter: &t1, 2431 }, 2432 }, 2433 "errors have different values - StatusCode": { 2434 wantSame: false, 2435 err: err, 2436 otherError: &AbuseRateLimitError{ 2437 Response: &http.Response{StatusCode: 200}, 2438 Message: 
"Github", 2439 RetryAfter: &t1, 2440 }, 2441 }, 2442 "errors have different types": { 2443 wantSame: false, 2444 err: err, 2445 otherError: errors.New("github"), 2446 }, 2447 } 2448 2449 for name, tc := range testcases { 2450 t.Run(name, func(t *testing.T) { 2451 t.Parallel() 2452 if tc.wantSame != tc.err.Is(tc.otherError) { 2453 t.Errorf("Error = %#v, want %#v", tc.err, tc.otherError) 2454 } 2455 }) 2456 } 2457 } 2458 2459 func TestAcceptedError_Is(t *testing.T) { 2460 t.Parallel() 2461 err := &AcceptedError{Raw: []byte("Github")} 2462 testcases := map[string]struct { 2463 wantSame bool 2464 otherError error 2465 }{ 2466 "errors are same": { 2467 wantSame: true, 2468 otherError: &AcceptedError{Raw: []byte("Github")}, 2469 }, 2470 "errors have different values": { 2471 wantSame: false, 2472 otherError: &AcceptedError{Raw: []byte("Gitlab")}, 2473 }, 2474 "errors have different types": { 2475 wantSame: false, 2476 otherError: errors.New("github"), 2477 }, 2478 } 2479 2480 for name, tc := range testcases { 2481 t.Run(name, func(t *testing.T) { 2482 t.Parallel() 2483 if tc.wantSame != err.Is(tc.otherError) { 2484 t.Errorf("Error = %#v, want %#v", err, tc.otherError) 2485 } 2486 }) 2487 } 2488 } 2489 2490 // Ensure that we properly handle API errors that do not contain a response body. 
2491 func TestCheckResponse_noBody(t *testing.T) { 2492 t.Parallel() 2493 res := &http.Response{ 2494 Request: &http.Request{}, 2495 StatusCode: http.StatusBadRequest, 2496 Body: io.NopCloser(strings.NewReader("")), 2497 } 2498 err := CheckResponse(res).(*ErrorResponse) 2499 2500 if err == nil { 2501 t.Errorf("Expected error response.") 2502 } 2503 2504 want := &ErrorResponse{ 2505 Response: res, 2506 } 2507 if !errors.Is(err, want) { 2508 t.Errorf("Error = %#v, want %#v", err, want) 2509 } 2510 } 2511 2512 func TestCheckResponse_unexpectedErrorStructure(t *testing.T) { 2513 t.Parallel() 2514 httpBody := `{"message":"m", "errors": ["error 1"]}` 2515 res := &http.Response{ 2516 Request: &http.Request{}, 2517 StatusCode: http.StatusBadRequest, 2518 Body: io.NopCloser(strings.NewReader(httpBody)), 2519 } 2520 err := CheckResponse(res).(*ErrorResponse) 2521 2522 if err == nil { 2523 t.Errorf("Expected error response.") 2524 } 2525 2526 want := &ErrorResponse{ 2527 Response: res, 2528 Message: "m", 2529 Errors: []Error{{Message: "error 1"}}, 2530 } 2531 if !errors.Is(err, want) { 2532 t.Errorf("Error = %#v, want %#v", err, want) 2533 } 2534 data, err2 := io.ReadAll(err.Response.Body) 2535 if err2 != nil { 2536 t.Fatalf("failed to read response body: %v", err) 2537 } 2538 if got := string(data); got != httpBody { 2539 t.Errorf("ErrorResponse.Response.Body = %q, want %q", got, httpBody) 2540 } 2541 } 2542 2543 func TestParseBooleanResponse_true(t *testing.T) { 2544 t.Parallel() 2545 result, err := parseBoolResponse(nil) 2546 if err != nil { 2547 t.Errorf("parseBoolResponse returned error: %+v", err) 2548 } 2549 2550 if want := true; result != want { 2551 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2552 } 2553 } 2554 2555 func TestParseBooleanResponse_false(t *testing.T) { 2556 t.Parallel() 2557 v := &ErrorResponse{Response: &http.Response{StatusCode: http.StatusNotFound}} 2558 result, err := parseBoolResponse(v) 2559 if err != nil { 2560 
t.Errorf("parseBoolResponse returned error: %+v", err) 2561 } 2562 2563 if want := false; result != want { 2564 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2565 } 2566 } 2567 2568 func TestParseBooleanResponse_error(t *testing.T) { 2569 t.Parallel() 2570 v := &ErrorResponse{Response: &http.Response{StatusCode: http.StatusBadRequest}} 2571 result, err := parseBoolResponse(v) 2572 2573 if err == nil { 2574 t.Errorf("Expected error to be returned.") 2575 } 2576 2577 if want := false; result != want { 2578 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2579 } 2580 } 2581 2582 func TestErrorResponse_Error(t *testing.T) { 2583 t.Parallel() 2584 res := &http.Response{Request: &http.Request{}} 2585 err := ErrorResponse{Message: "m", Response: res} 2586 if err.Error() == "" { 2587 t.Errorf("Expected non-empty ErrorResponse.Error()") 2588 } 2589 2590 // dont panic if request is nil 2591 res = &http.Response{} 2592 err = ErrorResponse{Message: "m", Response: res} 2593 if err.Error() == "" { 2594 t.Errorf("Expected non-empty ErrorResponse.Error()") 2595 } 2596 2597 // dont panic if response is nil 2598 err = ErrorResponse{Message: "m"} 2599 if err.Error() == "" { 2600 t.Errorf("Expected non-empty ErrorResponse.Error()") 2601 } 2602 } 2603 2604 func TestError_Error(t *testing.T) { 2605 t.Parallel() 2606 err := Error{} 2607 if err.Error() == "" { 2608 t.Errorf("Expected non-empty Error.Error()") 2609 } 2610 } 2611 2612 func TestSetCredentialsAsHeaders(t *testing.T) { 2613 t.Parallel() 2614 req := new(http.Request) 2615 id, secret := "id", "secret" 2616 modifiedRequest := setCredentialsAsHeaders(req, id, secret) 2617 2618 actualID, actualSecret, ok := modifiedRequest.BasicAuth() 2619 if !ok { 2620 t.Errorf("request does not contain basic credentials") 2621 } 2622 2623 if actualID != id { 2624 t.Errorf("id is %s, want %s", actualID, id) 2625 } 2626 2627 if actualSecret != secret { 2628 t.Errorf("secret is %s, want %s", actualSecret, 
secret) 2629 } 2630 } 2631 2632 func TestUnauthenticatedRateLimitedTransport(t *testing.T) { 2633 t.Parallel() 2634 client, mux, _ := setup(t) 2635 2636 clientID, clientSecret := "id", "secret" 2637 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 2638 id, secret, ok := r.BasicAuth() 2639 if !ok { 2640 t.Errorf("request does not contain basic auth credentials") 2641 } 2642 if id != clientID { 2643 t.Errorf("request contained basic auth username %q, want %q", id, clientID) 2644 } 2645 if secret != clientSecret { 2646 t.Errorf("request contained basic auth password %q, want %q", secret, clientSecret) 2647 } 2648 }) 2649 2650 tp := &UnauthenticatedRateLimitedTransport{ 2651 ClientID: clientID, 2652 ClientSecret: clientSecret, 2653 } 2654 unauthedClient := NewClient(tp.Client()) 2655 unauthedClient.BaseURL = client.BaseURL 2656 req, _ := unauthedClient.NewRequest("GET", ".", nil) 2657 ctx := context.Background() 2658 _, err := unauthedClient.Do(ctx, req, nil) 2659 assertNilError(t, err) 2660 } 2661 2662 func TestUnauthenticatedRateLimitedTransport_missingFields(t *testing.T) { 2663 t.Parallel() 2664 // missing ClientID 2665 tp := &UnauthenticatedRateLimitedTransport{ 2666 ClientSecret: "secret", 2667 } 2668 _, err := tp.RoundTrip(nil) 2669 if err == nil { 2670 t.Errorf("Expected error to be returned") 2671 } 2672 2673 // missing ClientSecret 2674 tp = &UnauthenticatedRateLimitedTransport{ 2675 ClientID: "id", 2676 } 2677 _, err = tp.RoundTrip(nil) 2678 if err == nil { 2679 t.Errorf("Expected error to be returned") 2680 } 2681 } 2682 2683 func TestUnauthenticatedRateLimitedTransport_transport(t *testing.T) { 2684 t.Parallel() 2685 // default transport 2686 tp := &UnauthenticatedRateLimitedTransport{ 2687 ClientID: "id", 2688 ClientSecret: "secret", 2689 } 2690 if tp.transport() != http.DefaultTransport { 2691 t.Errorf("Expected http.DefaultTransport to be used.") 2692 } 2693 2694 // custom transport 2695 tp = &UnauthenticatedRateLimitedTransport{ 2696 
ClientID: "id", 2697 ClientSecret: "secret", 2698 Transport: &http.Transport{}, 2699 } 2700 if tp.transport() == http.DefaultTransport { 2701 t.Errorf("Expected custom transport to be used.") 2702 } 2703 } 2704 2705 func TestBasicAuthTransport(t *testing.T) { 2706 t.Parallel() 2707 client, mux, _ := setup(t) 2708 2709 username, password, otp := "u", "p", "123456" 2710 2711 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 2712 u, p, ok := r.BasicAuth() 2713 if !ok { 2714 t.Errorf("request does not contain basic auth credentials") 2715 } 2716 if u != username { 2717 t.Errorf("request contained basic auth username %q, want %q", u, username) 2718 } 2719 if p != password { 2720 t.Errorf("request contained basic auth password %q, want %q", p, password) 2721 } 2722 if got, want := r.Header.Get(headerOTP), otp; got != want { 2723 t.Errorf("request contained OTP %q, want %q", got, want) 2724 } 2725 }) 2726 2727 tp := &BasicAuthTransport{ 2728 Username: username, 2729 Password: password, 2730 OTP: otp, 2731 } 2732 basicAuthClient := NewClient(tp.Client()) 2733 basicAuthClient.BaseURL = client.BaseURL 2734 req, _ := basicAuthClient.NewRequest("GET", ".", nil) 2735 ctx := context.Background() 2736 _, err := basicAuthClient.Do(ctx, req, nil) 2737 assertNilError(t, err) 2738 } 2739 2740 func TestBasicAuthTransport_transport(t *testing.T) { 2741 t.Parallel() 2742 // default transport 2743 tp := &BasicAuthTransport{} 2744 if tp.transport() != http.DefaultTransport { 2745 t.Errorf("Expected http.DefaultTransport to be used.") 2746 } 2747 2748 // custom transport 2749 tp = &BasicAuthTransport{ 2750 Transport: &http.Transport{}, 2751 } 2752 if tp.transport() == http.DefaultTransport { 2753 t.Errorf("Expected custom transport to be used.") 2754 } 2755 } 2756 2757 func TestFormatRateReset(t *testing.T) { 2758 t.Parallel() 2759 d := 120*time.Minute + 12*time.Second 2760 got := formatRateReset(d) 2761 want := "[rate reset in 120m12s]" 2762 if got != want { 2763 
t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2764 } 2765 2766 d = 14*time.Minute + 2*time.Second 2767 got = formatRateReset(d) 2768 want = "[rate reset in 14m02s]" 2769 if got != want { 2770 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2771 } 2772 2773 d = 2*time.Minute + 2*time.Second 2774 got = formatRateReset(d) 2775 want = "[rate reset in 2m02s]" 2776 if got != want { 2777 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2778 } 2779 2780 d = 12 * time.Second 2781 got = formatRateReset(d) 2782 want = "[rate reset in 12s]" 2783 if got != want { 2784 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2785 } 2786 2787 d = -1 * (2*time.Hour + 2*time.Second) 2788 got = formatRateReset(d) 2789 want = "[rate limit was reset 120m02s ago]" 2790 if got != want { 2791 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2792 } 2793 } 2794 2795 func TestNestedStructAccessorNoPanic(t *testing.T) { 2796 t.Parallel() 2797 issue := &Issue{User: nil} 2798 got := issue.GetUser().GetPlan().GetName() 2799 want := "" 2800 if got != want { 2801 t.Errorf("Issues.Get.GetUser().GetPlan().GetName() returned %+v, want %+v", got, want) 2802 } 2803 } 2804 2805 func TestTwoFactorAuthError(t *testing.T) { 2806 t.Parallel() 2807 u, err := url.Parse("https://example.com") 2808 if err != nil { 2809 t.Fatal(err) 2810 } 2811 2812 e := &TwoFactorAuthError{ 2813 Response: &http.Response{ 2814 Request: &http.Request{Method: "PUT", URL: u}, 2815 StatusCode: http.StatusTooManyRequests, 2816 }, 2817 Message: "<msg>", 2818 } 2819 if got, want := e.Error(), "PUT https://example.com: 429 <msg> []"; got != want { 2820 t.Errorf("TwoFactorAuthError = %q, want %q", got, want) 2821 } 2822 } 2823 2824 func TestRateLimitError(t *testing.T) { 2825 t.Parallel() 2826 u, err := url.Parse("https://example.com") 2827 if err != nil { 2828 t.Fatal(err) 2829 } 2830 2831 r := &RateLimitError{ 2832 Response: &http.Response{ 2833 Request: &http.Request{Method: "PUT", URL: u}, 
2834 StatusCode: http.StatusTooManyRequests, 2835 }, 2836 Message: "<msg>", 2837 } 2838 if got, want := r.Error(), "PUT https://example.com: 429 <msg> [rate limit was reset"; !strings.Contains(got, want) { 2839 t.Errorf("RateLimitError = %q, want %q", got, want) 2840 } 2841 } 2842 2843 func TestAcceptedError(t *testing.T) { 2844 t.Parallel() 2845 a := &AcceptedError{} 2846 if got, want := a.Error(), "try again later"; !strings.Contains(got, want) { 2847 t.Errorf("AcceptedError = %q, want %q", got, want) 2848 } 2849 } 2850 2851 func TestAbuseRateLimitError(t *testing.T) { 2852 t.Parallel() 2853 u, err := url.Parse("https://example.com") 2854 if err != nil { 2855 t.Fatal(err) 2856 } 2857 2858 r := &AbuseRateLimitError{ 2859 Response: &http.Response{ 2860 Request: &http.Request{Method: "PUT", URL: u}, 2861 StatusCode: http.StatusTooManyRequests, 2862 }, 2863 Message: "<msg>", 2864 } 2865 if got, want := r.Error(), "PUT https://example.com: 429 <msg>"; got != want { 2866 t.Errorf("AbuseRateLimitError = %q, want %q", got, want) 2867 } 2868 } 2869 2870 func TestAddOptions_QueryValues(t *testing.T) { 2871 t.Parallel() 2872 if _, err := addOptions("yo", ""); err == nil { 2873 t.Error("addOptions err = nil, want error") 2874 } 2875 } 2876 2877 func TestBareDo_returnsOpenBody(t *testing.T) { 2878 t.Parallel() 2879 client, mux, _ := setup(t) 2880 2881 expectedBody := "Hello from the other side !" 
2882 2883 mux.HandleFunc("/test-url", func(w http.ResponseWriter, r *http.Request) { 2884 testMethod(t, r, "GET") 2885 fmt.Fprint(w, expectedBody) 2886 }) 2887 2888 ctx := context.Background() 2889 req, err := client.NewRequest("GET", "test-url", nil) 2890 if err != nil { 2891 t.Fatalf("client.NewRequest returned error: %v", err) 2892 } 2893 2894 resp, err := client.BareDo(ctx, req) 2895 if err != nil { 2896 t.Fatalf("client.BareDo returned error: %v", err) 2897 } 2898 2899 got, err := io.ReadAll(resp.Body) 2900 if err != nil { 2901 t.Fatalf("io.ReadAll returned error: %v", err) 2902 } 2903 if string(got) != expectedBody { 2904 t.Fatalf("Expected %q, got %q", expectedBody, string(got)) 2905 } 2906 if err := resp.Body.Close(); err != nil { 2907 t.Fatalf("resp.Body.Close() returned error: %v", err) 2908 } 2909 } 2910 2911 func TestErrorResponse_Marshal(t *testing.T) { 2912 t.Parallel() 2913 testJSONMarshal(t, &ErrorResponse{}, "{}") 2914 2915 u := &ErrorResponse{ 2916 Message: "msg", 2917 Errors: []Error{ 2918 { 2919 Resource: "res", 2920 Field: "f", 2921 Code: "c", 2922 Message: "msg", 2923 }, 2924 }, 2925 Block: &ErrorBlock{ 2926 Reason: "reason", 2927 CreatedAt: &Timestamp{referenceTime}, 2928 }, 2929 DocumentationURL: "doc", 2930 } 2931 2932 want := `{ 2933 "message": "msg", 2934 "errors": [ 2935 { 2936 "resource": "res", 2937 "field": "f", 2938 "code": "c", 2939 "message": "msg" 2940 } 2941 ], 2942 "block": { 2943 "reason": "reason", 2944 "created_at": ` + referenceTimeStr + ` 2945 }, 2946 "documentation_url": "doc" 2947 }` 2948 2949 testJSONMarshal(t, u, want) 2950 } 2951 2952 func TestErrorBlock_Marshal(t *testing.T) { 2953 t.Parallel() 2954 testJSONMarshal(t, &ErrorBlock{}, "{}") 2955 2956 u := &ErrorBlock{ 2957 Reason: "reason", 2958 CreatedAt: &Timestamp{referenceTime}, 2959 } 2960 2961 want := `{ 2962 "reason": "reason", 2963 "created_at": ` + referenceTimeStr + ` 2964 }` 2965 2966 testJSONMarshal(t, u, want) 2967 } 2968 2969 func 
TestRateLimitError_Marshal(t *testing.T) { 2970 t.Parallel() 2971 testJSONMarshal(t, &RateLimitError{}, "{}") 2972 2973 u := &RateLimitError{ 2974 Rate: Rate{ 2975 Limit: 1, 2976 Remaining: 1, 2977 Reset: Timestamp{referenceTime}, 2978 }, 2979 Message: "msg", 2980 } 2981 2982 want := `{ 2983 "Rate": { 2984 "limit": 1, 2985 "remaining": 1, 2986 "reset": ` + referenceTimeStr + ` 2987 }, 2988 "message": "msg" 2989 }` 2990 2991 testJSONMarshal(t, u, want) 2992 } 2993 2994 func TestAbuseRateLimitError_Marshal(t *testing.T) { 2995 t.Parallel() 2996 testJSONMarshal(t, &AbuseRateLimitError{}, "{}") 2997 2998 u := &AbuseRateLimitError{ 2999 Message: "msg", 3000 } 3001 3002 want := `{ 3003 "message": "msg" 3004 }` 3005 3006 testJSONMarshal(t, u, want) 3007 } 3008 3009 func TestError_Marshal(t *testing.T) { 3010 t.Parallel() 3011 testJSONMarshal(t, &Error{}, "{}") 3012 3013 u := &Error{ 3014 Resource: "res", 3015 Field: "field", 3016 Code: "code", 3017 Message: "msg", 3018 } 3019 3020 want := `{ 3021 "resource": "res", 3022 "field": "field", 3023 "code": "code", 3024 "message": "msg" 3025 }` 3026 3027 testJSONMarshal(t, u, want) 3028 } 3029 3030 func TestParseTokenExpiration(t *testing.T) { 3031 t.Parallel() 3032 tests := []struct { 3033 header string 3034 want Timestamp 3035 }{ 3036 { 3037 header: "", 3038 want: Timestamp{}, 3039 }, 3040 { 3041 header: "this is a garbage", 3042 want: Timestamp{}, 3043 }, 3044 { 3045 header: "2021-09-03 02:34:04 UTC", 3046 want: Timestamp{time.Date(2021, time.September, 3, 2, 34, 4, 0, time.UTC)}, 3047 }, 3048 { 3049 header: "2021-09-03 14:34:04 UTC", 3050 want: Timestamp{time.Date(2021, time.September, 3, 14, 34, 4, 0, time.UTC)}, 3051 }, 3052 // Some tokens include the timezone offset instead of the timezone. 
3053 // https://github.com/google/go-github/issues/2649 3054 { 3055 header: "2023-04-26 20:23:26 +0200", 3056 want: Timestamp{time.Date(2023, time.April, 26, 18, 23, 26, 0, time.UTC)}, 3057 }, 3058 } 3059 3060 for _, tt := range tests { 3061 res := &http.Response{ 3062 Request: &http.Request{}, 3063 Header: http.Header{}, 3064 } 3065 3066 res.Header.Set(headerTokenExpiration, tt.header) 3067 exp := parseTokenExpiration(res) 3068 if !exp.Equal(tt.want) { 3069 t.Errorf("parseTokenExpiration of %q\nreturned %#v\n want %#v", tt.header, exp, tt.want) 3070 } 3071 } 3072 } 3073 3074 func TestClientCopy_leak_transport(t *testing.T) { 3075 t.Parallel() 3076 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 3077 w.Header().Set("Content-Type", "application/json") 3078 accessToken := r.Header.Get("Authorization") 3079 _, _ = fmt.Fprintf(w, `{"login": "%s"}`, accessToken) 3080 })) 3081 clientPreconfiguredWithURLs, err := NewClient(nil).WithEnterpriseURLs(srv.URL, srv.URL) 3082 if err != nil { 3083 t.Fatal(err) 3084 } 3085 3086 aliceClient := clientPreconfiguredWithURLs.WithAuthToken("alice") 3087 bobClient := clientPreconfiguredWithURLs.WithAuthToken("bob") 3088 3089 alice, _, err := aliceClient.Users.Get(context.Background(), "") 3090 if err != nil { 3091 t.Fatal(err) 3092 } 3093 3094 assertNoDiff(t, "Bearer alice", alice.GetLogin()) 3095 3096 bob, _, err := bobClient.Users.Get(context.Background(), "") 3097 if err != nil { 3098 t.Fatal(err) 3099 } 3100 3101 assertNoDiff(t, "Bearer bob", bob.GetLogin()) 3102 } 3103 3104 func TestPtr(t *testing.T) { 3105 t.Parallel() 3106 equal := func(t *testing.T, want, got any) { 3107 t.Helper() 3108 if !reflect.DeepEqual(want, got) { 3109 t.Errorf("want %#v, got %#v", want, got) 3110 } 3111 } 3112 3113 equal(t, true, *Ptr(true)) 3114 equal(t, int(10), *Ptr(int(10))) 3115 equal(t, int64(-10), *Ptr(int64(-10))) 3116 equal(t, "str", *Ptr("str")) 3117 } 3118 3119 func 
TestDeploymentProtectionRuleEvent_GetRunID(t *testing.T) { 3120 t.Parallel() 3121 3122 var want int64 = 123456789 3123 url := "https://api.github.com/repos/dummy-org/dummy-repo/actions/runs/123456789/deployment_protection_rule" 3124 3125 e := DeploymentProtectionRuleEvent{ 3126 DeploymentCallbackURL: &url, 3127 } 3128 3129 got, _ := e.GetRunID() 3130 if got != want { 3131 t.Errorf("want %#v, got %#v", want, got) 3132 } 3133 3134 want = 123456789 3135 url = "repos/dummy-org/dummy-repo/actions/runs/123456789/deployment_protection_rule" 3136 3137 e = DeploymentProtectionRuleEvent{ 3138 DeploymentCallbackURL: &url, 3139 } 3140 3141 got, _ = e.GetRunID() 3142 if got != want { 3143 t.Errorf("want %#v, got %#v", want, got) 3144 } 3145 3146 want = -1 3147 url = "https://api.github.com/repos/dummy-org/dummy-repo/actions/runs/abc123/deployment_protection_rule" 3148 got, err := e.GetRunID() 3149 if err == nil { 3150 t.Errorf("Expected error to be returned") 3151 } 3152 3153 if got != want { 3154 t.Errorf("want %#v, got %#v", want, got) 3155 } 3156 }