github.com/google/go-github/v68@v68.0.0/github/github_test.go (about) 1 // Copyright 2013 The go-github AUTHORS. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 package github 7 8 import ( 9 "context" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "io" 14 "net/http" 15 "net/http/httptest" 16 "net/url" 17 "os" 18 "path/filepath" 19 "reflect" 20 "strconv" 21 "strings" 22 "testing" 23 "time" 24 25 "github.com/google/go-cmp/cmp" 26 ) 27 28 const ( 29 // baseURLPath is a non-empty Client.BaseURL path to use during tests, 30 // to ensure relative URLs are used for all endpoints. See issue #752. 31 baseURLPath = "/api-v3" 32 ) 33 34 // setup sets up a test HTTP server along with a github.Client that is 35 // configured to talk to that test server. Tests should register handlers on 36 // mux which provide mock responses for the API method being tested. 37 func setup(t *testing.T) (client *Client, mux *http.ServeMux, serverURL string) { 38 t.Helper() 39 // mux is the HTTP request multiplexer used with the test server. 40 mux = http.NewServeMux() 41 42 // We want to ensure that tests catch mistakes where the endpoint URL is 43 // specified as absolute rather than relative. It only makes a difference 44 // when there's a non-empty base URL path. So, use that. See issue #752. 
45 apiHandler := http.NewServeMux() 46 apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) 47 apiHandler.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { 48 fmt.Fprintln(os.Stderr, "FAIL: Client.BaseURL path prefix is not preserved in the request URL:") 49 fmt.Fprintln(os.Stderr) 50 fmt.Fprintln(os.Stderr, "\t"+req.URL.String()) 51 fmt.Fprintln(os.Stderr) 52 fmt.Fprintln(os.Stderr, "\tDid you accidentally use an absolute endpoint URL rather than relative?") 53 fmt.Fprintln(os.Stderr, "\tSee https://github.com/google/go-github/issues/752 for information.") 54 http.Error(w, "Client.BaseURL path prefix is not preserved in the request URL.", http.StatusInternalServerError) 55 }) 56 57 // server is a test HTTP server used to provide mock API responses. 58 server := httptest.NewServer(apiHandler) 59 60 // client is the GitHub client being tested and is 61 // configured to use test server. 62 client = NewClient(nil) 63 url, _ := url.Parse(server.URL + baseURLPath + "/") 64 client.BaseURL = url 65 client.UploadURL = url 66 67 t.Cleanup(server.Close) 68 69 return client, mux, server.URL 70 } 71 72 // openTestFile creates a new file with the given name and content for testing. 73 // In order to ensure the exact file name, this function will create a new temp 74 // directory, and create the file in that directory. The file is automatically 75 // cleaned up after the test. 
76 func openTestFile(t *testing.T, name, content string) *os.File { 77 t.Helper() 78 fname := filepath.Join(t.TempDir(), name) 79 err := os.WriteFile(fname, []byte(content), 0600) 80 if err != nil { 81 t.Fatal(err) 82 } 83 file, err := os.Open(fname) 84 if err != nil { 85 t.Fatal(err) 86 } 87 88 t.Cleanup(func() { file.Close() }) 89 90 return file 91 } 92 93 func testMethod(t *testing.T, r *http.Request, want string) { 94 t.Helper() 95 if got := r.Method; got != want { 96 t.Errorf("Request method: %v, want %v", got, want) 97 } 98 } 99 100 type values map[string]string 101 102 func testFormValues(t *testing.T, r *http.Request, values values) { 103 t.Helper() 104 want := url.Values{} 105 for k, v := range values { 106 want.Set(k, v) 107 } 108 109 assertNilError(t, r.ParseForm()) 110 if got := r.Form; !cmp.Equal(got, want) { 111 t.Errorf("Request parameters: %v, want %v", got, want) 112 } 113 } 114 115 func testHeader(t *testing.T, r *http.Request, header string, want string) { 116 t.Helper() 117 if got := r.Header.Get(header); got != want { 118 t.Errorf("Header.Get(%q) returned %q, want %q", header, got, want) 119 } 120 } 121 122 func testURLParseError(t *testing.T, err error) { 123 t.Helper() 124 if err == nil { 125 t.Errorf("Expected error to be returned") 126 } 127 if err, ok := err.(*url.Error); !ok || err.Op != "parse" { 128 t.Errorf("Expected URL parse error, got %+v", err) 129 } 130 } 131 132 func testBody(t *testing.T, r *http.Request, want string) { 133 t.Helper() 134 b, err := io.ReadAll(r.Body) 135 if err != nil { 136 t.Errorf("Error reading request body: %v", err) 137 } 138 if got := string(b); got != want { 139 t.Errorf("request Body is %s, want %s", got, want) 140 } 141 } 142 143 // Test whether the marshaling of v produces JSON that corresponds 144 // to the want string. 
145 func testJSONMarshal(t *testing.T, v interface{}, want string) { 146 t.Helper() 147 // Unmarshal the wanted JSON, to verify its correctness, and marshal it back 148 // to sort the keys. 149 u := reflect.New(reflect.TypeOf(v)).Interface() 150 if err := json.Unmarshal([]byte(want), &u); err != nil { 151 t.Errorf("Unable to unmarshal JSON for %v: %v", want, err) 152 } 153 w, err := json.MarshalIndent(u, "", " ") 154 if err != nil { 155 t.Errorf("Unable to marshal JSON for %#v", u) 156 } 157 158 // Marshal the target value. 159 got, err := json.MarshalIndent(v, "", " ") 160 if err != nil { 161 t.Errorf("Unable to marshal JSON for %#v", v) 162 } 163 164 if diff := cmp.Diff(string(w), string(got)); diff != "" { 165 t.Errorf("json.Marshal returned:\n%s\nwant:\n%s\ndiff:\n%v", got, w, diff) 166 } 167 } 168 169 // Test whether the v fields have the url tag and the parsing of v 170 // produces query parameters that corresponds to the want string. 171 func testAddURLOptions(t *testing.T, url string, v interface{}, want string) { 172 t.Helper() 173 174 vt := reflect.Indirect(reflect.ValueOf(v)).Type() 175 for i := 0; i < vt.NumField(); i++ { 176 field := vt.Field(i) 177 if alias, ok := field.Tag.Lookup("url"); ok { 178 if alias == "" { 179 t.Errorf("The field %+v has a blank url tag", field) 180 } 181 } else { 182 t.Errorf("The field %+v has no url tag specified", field) 183 } 184 } 185 186 got, err := addOptions(url, v) 187 if err != nil { 188 t.Errorf("Unable to add %#v as query parameters", v) 189 } 190 191 if got != want { 192 t.Errorf("addOptions(%q, %#v) returned %v, want %v", url, v, got, want) 193 } 194 } 195 196 // Test how bad options are handled. Method f under test should 197 // return an error. 
198 func testBadOptions(t *testing.T, methodName string, f func() error) { 199 t.Helper() 200 if methodName == "" { 201 t.Error("testBadOptions: must supply method methodName") 202 } 203 if err := f(); err == nil { 204 t.Errorf("bad options %v err = nil, want error", methodName) 205 } 206 } 207 208 // Test function under NewRequest failure and then s.client.Do failure. 209 // Method f should be a regular call that would normally succeed, but 210 // should return an error when NewRequest or s.client.Do fails. 211 func testNewRequestAndDoFailure(t *testing.T, methodName string, client *Client, f func() (*Response, error)) { 212 testNewRequestAndDoFailureCategory(t, methodName, client, CoreCategory, f) 213 } 214 215 // testNewRequestAndDoFailureCategory works Like testNewRequestAndDoFailure, but allows setting the category. 216 func testNewRequestAndDoFailureCategory(t *testing.T, methodName string, client *Client, category RateLimitCategory, f func() (*Response, error)) { 217 t.Helper() 218 if methodName == "" { 219 t.Error("testNewRequestAndDoFailure: must supply method methodName") 220 } 221 222 client.BaseURL.Path = "" 223 resp, err := f() 224 if resp != nil { 225 t.Errorf("client.BaseURL.Path='' %v resp = %#v, want nil", methodName, resp) 226 } 227 if err == nil { 228 t.Errorf("client.BaseURL.Path='' %v err = nil, want error", methodName) 229 } 230 231 client.BaseURL.Path = "/api-v3/" 232 client.rateLimits[category].Reset.Time = time.Now().Add(10 * time.Minute) 233 resp, err = f() 234 if bypass := resp.Request.Context().Value(bypassRateLimitCheck); bypass != nil { 235 return 236 } 237 if want := http.StatusForbidden; resp == nil || resp.Response.StatusCode != want { 238 if resp != nil { 239 t.Errorf("rate.Reset.Time > now %v resp = %#v, want StatusCode=%v", methodName, resp.Response, want) 240 } else { 241 t.Errorf("rate.Reset.Time > now %v resp = nil, want StatusCode=%v", methodName, want) 242 } 243 } 244 if err == nil { 245 t.Errorf("rate.Reset.Time > now %v 
err = nil, want error", methodName) 246 } 247 } 248 249 // Test that all error response types contain the status code. 250 func testErrorResponseForStatusCode(t *testing.T, code int) { 251 t.Helper() 252 client, mux, _ := setup(t) 253 254 mux.HandleFunc("/repos/o/r/hooks", func(w http.ResponseWriter, r *http.Request) { 255 testMethod(t, r, "GET") 256 w.WriteHeader(code) 257 }) 258 259 ctx := context.Background() 260 _, _, err := client.Repositories.ListHooks(ctx, "o", "r", nil) 261 262 switch e := err.(type) { 263 case *ErrorResponse: 264 case *RateLimitError: 265 case *AbuseRateLimitError: 266 if code != e.Response.StatusCode { 267 t.Error("Error response does not contain status code") 268 } 269 default: 270 t.Error("Unknown error response type") 271 } 272 } 273 274 func assertNoDiff(t *testing.T, want, got interface{}) { 275 t.Helper() 276 if diff := cmp.Diff(want, got); diff != "" { 277 t.Errorf("diff mismatch (-want +got):\n%v", diff) 278 } 279 } 280 281 func assertNilError(t *testing.T, err error) { 282 t.Helper() 283 if err != nil { 284 t.Errorf("unexpected error: %v", err) 285 } 286 } 287 288 func assertWrite(t *testing.T, w io.Writer, data []byte) { 289 t.Helper() 290 _, err := w.Write(data) 291 assertNilError(t, err) 292 } 293 294 func TestNewClient(t *testing.T) { 295 t.Parallel() 296 c := NewClient(nil) 297 298 if got, want := c.BaseURL.String(), defaultBaseURL; got != want { 299 t.Errorf("NewClient BaseURL is %v, want %v", got, want) 300 } 301 if got, want := c.UserAgent, defaultUserAgent; got != want { 302 t.Errorf("NewClient UserAgent is %v, want %v", got, want) 303 } 304 305 c2 := NewClient(nil) 306 if c.client == c2.client { 307 t.Error("NewClient returned same http.Clients, but they should differ") 308 } 309 } 310 311 func TestNewClientWithEnvProxy(t *testing.T) { 312 t.Parallel() 313 client := NewClientWithEnvProxy() 314 if got, want := client.BaseURL.String(), defaultBaseURL; got != want { 315 t.Errorf("NewClient BaseURL is %v, want %v", got, 
want) 316 } 317 } 318 319 func TestClient(t *testing.T) { 320 t.Parallel() 321 c := NewClient(nil) 322 c2 := c.Client() 323 if c.client == c2 { 324 t.Error("Client returned same http.Client, but should be different") 325 } 326 } 327 328 func TestWithAuthToken(t *testing.T) { 329 t.Parallel() 330 token := "gh_test_token" 331 332 validate := func(t *testing.T, c *http.Client, token string) { 333 t.Helper() 334 want := token 335 if want != "" { 336 want = "Bearer " + want 337 } 338 gotReq := false 339 headerVal := "" 340 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 341 gotReq = true 342 headerVal = r.Header.Get("Authorization") 343 })) 344 _, err := c.Get(srv.URL) 345 assertNilError(t, err) 346 if !gotReq { 347 t.Error("request not sent") 348 } 349 if headerVal != want { 350 t.Errorf("Authorization header is %v, want %v", headerVal, want) 351 } 352 } 353 354 t.Run("zero-value Client", func(t *testing.T) { 355 t.Parallel() 356 c := new(Client).WithAuthToken(token) 357 validate(t, c.Client(), token) 358 }) 359 360 t.Run("NewClient", func(t *testing.T) { 361 t.Parallel() 362 httpClient := &http.Client{} 363 client := NewClient(httpClient).WithAuthToken(token) 364 validate(t, client.Client(), token) 365 // make sure the original client isn't setting auth headers now 366 validate(t, httpClient, "") 367 }) 368 369 t.Run("NewTokenClient", func(t *testing.T) { 370 t.Parallel() 371 validate(t, NewTokenClient(context.Background(), token).Client(), token) 372 }) 373 } 374 375 func TestWithEnterpriseURLs(t *testing.T) { 376 t.Parallel() 377 for _, test := range []struct { 378 name string 379 baseURL string 380 wantBaseURL string 381 uploadURL string 382 wantUploadURL string 383 wantErr string 384 }{ 385 { 386 name: "does not modify properly formed URLs", 387 baseURL: "https://custom-url/api/v3/", 388 wantBaseURL: "https://custom-url/api/v3/", 389 uploadURL: "https://custom-upload-url/api/uploads/", 390 wantUploadURL: 
"https://custom-upload-url/api/uploads/", 391 }, 392 { 393 name: "adds trailing slash", 394 baseURL: "https://custom-url/api/v3", 395 wantBaseURL: "https://custom-url/api/v3/", 396 uploadURL: "https://custom-upload-url/api/uploads", 397 wantUploadURL: "https://custom-upload-url/api/uploads/", 398 }, 399 { 400 name: "adds enterprise suffix", 401 baseURL: "https://custom-url/", 402 wantBaseURL: "https://custom-url/api/v3/", 403 uploadURL: "https://custom-upload-url/", 404 wantUploadURL: "https://custom-upload-url/api/uploads/", 405 }, 406 { 407 name: "adds enterprise suffix and trailing slash", 408 baseURL: "https://custom-url", 409 wantBaseURL: "https://custom-url/api/v3/", 410 uploadURL: "https://custom-upload-url", 411 wantUploadURL: "https://custom-upload-url/api/uploads/", 412 }, 413 { 414 name: "bad base URL", 415 baseURL: "bogus\nbase\nURL", 416 uploadURL: "https://custom-upload-url/api/uploads/", 417 wantErr: `invalid control character in URL`, 418 }, 419 { 420 name: "bad upload URL", 421 baseURL: "https://custom-url/api/v3/", 422 uploadURL: "bogus\nupload\nURL", 423 wantErr: `invalid control character in URL`, 424 }, 425 { 426 name: "URL has existing API prefix, adds trailing slash", 427 baseURL: "https://api.custom-url", 428 wantBaseURL: "https://api.custom-url/", 429 uploadURL: "https://api.custom-upload-url", 430 wantUploadURL: "https://api.custom-upload-url/", 431 }, 432 { 433 name: "URL has existing API prefix and trailing slash", 434 baseURL: "https://api.custom-url/", 435 wantBaseURL: "https://api.custom-url/", 436 uploadURL: "https://api.custom-upload-url/", 437 wantUploadURL: "https://api.custom-upload-url/", 438 }, 439 { 440 name: "URL has API subdomain, adds trailing slash", 441 baseURL: "https://catalog.api.custom-url", 442 wantBaseURL: "https://catalog.api.custom-url/", 443 uploadURL: "https://catalog.api.custom-upload-url", 444 wantUploadURL: "https://catalog.api.custom-upload-url/", 445 }, 446 { 447 name: "URL has API subdomain and trailing 
slash", 448 baseURL: "https://catalog.api.custom-url/", 449 wantBaseURL: "https://catalog.api.custom-url/", 450 uploadURL: "https://catalog.api.custom-upload-url/", 451 wantUploadURL: "https://catalog.api.custom-upload-url/", 452 }, 453 { 454 name: "URL is not a proper API subdomain, adds enterprise suffix and slash", 455 baseURL: "https://cloud-api.custom-url", 456 wantBaseURL: "https://cloud-api.custom-url/api/v3/", 457 uploadURL: "https://cloud-api.custom-upload-url", 458 wantUploadURL: "https://cloud-api.custom-upload-url/api/uploads/", 459 }, 460 { 461 name: "URL is not a proper API subdomain, adds enterprise suffix", 462 baseURL: "https://cloud-api.custom-url/", 463 wantBaseURL: "https://cloud-api.custom-url/api/v3/", 464 uploadURL: "https://cloud-api.custom-upload-url/", 465 wantUploadURL: "https://cloud-api.custom-upload-url/api/uploads/", 466 }, 467 } { 468 test := test 469 t.Run(test.name, func(t *testing.T) { 470 t.Parallel() 471 validate := func(c *Client, err error) { 472 t.Helper() 473 if test.wantErr != "" { 474 if err == nil || !strings.Contains(err.Error(), test.wantErr) { 475 t.Fatalf("error does not contain expected string %q: %v", test.wantErr, err) 476 } 477 return 478 } 479 if err != nil { 480 t.Fatalf("got unexpected error: %v", err) 481 } 482 if c.BaseURL.String() != test.wantBaseURL { 483 t.Errorf("BaseURL is %v, want %v", c.BaseURL.String(), test.wantBaseURL) 484 } 485 if c.UploadURL.String() != test.wantUploadURL { 486 t.Errorf("UploadURL is %v, want %v", c.UploadURL.String(), test.wantUploadURL) 487 } 488 } 489 validate(NewClient(nil).WithEnterpriseURLs(test.baseURL, test.uploadURL)) 490 validate(new(Client).WithEnterpriseURLs(test.baseURL, test.uploadURL)) 491 validate(NewEnterpriseClient(test.baseURL, test.uploadURL, nil)) 492 }) 493 } 494 } 495 496 // Ensure that length of Client.rateLimits is the same as number of fields in RateLimits struct. 
497 func TestClient_rateLimits(t *testing.T) { 498 t.Parallel() 499 if got, want := len(Client{}.rateLimits), reflect.TypeOf(RateLimits{}).NumField(); got != want { 500 t.Errorf("len(Client{}.rateLimits) is %v, want %v", got, want) 501 } 502 } 503 504 func TestNewRequest(t *testing.T) { 505 t.Parallel() 506 c := NewClient(nil) 507 508 inURL, outURL := "/foo", defaultBaseURL+"foo" 509 inBody, outBody := &User{Login: Ptr("l")}, `{"login":"l"}`+"\n" 510 req, _ := c.NewRequest("GET", inURL, inBody) 511 512 // test that relative URL was expanded 513 if got, want := req.URL.String(), outURL; got != want { 514 t.Errorf("NewRequest(%q) URL is %v, want %v", inURL, got, want) 515 } 516 517 // test that body was JSON encoded 518 body, _ := io.ReadAll(req.Body) 519 if got, want := string(body), outBody; got != want { 520 t.Errorf("NewRequest(%q) Body is %v, want %v", inBody, got, want) 521 } 522 523 userAgent := req.Header.Get("User-Agent") 524 525 // test that default user-agent is attached to the request 526 if got, want := userAgent, c.UserAgent; got != want { 527 t.Errorf("NewRequest() User-Agent is %v, want %v", got, want) 528 } 529 530 if !strings.Contains(userAgent, Version) { 531 t.Errorf("NewRequest() User-Agent should contain %v, found %v", Version, userAgent) 532 } 533 534 apiVersion := req.Header.Get(headerAPIVersion) 535 if got, want := apiVersion, defaultAPIVersion; got != want { 536 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 537 } 538 539 req, _ = c.NewRequest("GET", inURL, inBody, WithVersion("2022-11-29")) 540 apiVersion = req.Header.Get(headerAPIVersion) 541 if got, want := apiVersion, "2022-11-29"; got != want { 542 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 543 } 544 } 545 546 func TestNewRequest_invalidJSON(t *testing.T) { 547 t.Parallel() 548 c := NewClient(nil) 549 550 type T struct { 551 A map[interface{}]interface{} 552 } 553 _, err := c.NewRequest("GET", ".", &T{}) 554 555 if err 
== nil { 556 t.Error("Expected error to be returned.") 557 } 558 if err, ok := err.(*json.UnsupportedTypeError); !ok { 559 t.Errorf("Expected a JSON error; got %#v.", err) 560 } 561 } 562 563 func TestNewRequest_badURL(t *testing.T) { 564 t.Parallel() 565 c := NewClient(nil) 566 _, err := c.NewRequest("GET", ":", nil) 567 testURLParseError(t, err) 568 } 569 570 func TestNewRequest_badMethod(t *testing.T) { 571 t.Parallel() 572 c := NewClient(nil) 573 if _, err := c.NewRequest("BOGUS\nMETHOD", ".", nil); err == nil { 574 t.Fatal("NewRequest returned nil; expected error") 575 } 576 } 577 578 // ensure that no User-Agent header is set if the client's UserAgent is empty. 579 // This caused a problem with Google's internal http client. 580 func TestNewRequest_emptyUserAgent(t *testing.T) { 581 t.Parallel() 582 c := NewClient(nil) 583 c.UserAgent = "" 584 req, err := c.NewRequest("GET", ".", nil) 585 if err != nil { 586 t.Fatalf("NewRequest returned unexpected error: %v", err) 587 } 588 if _, ok := req.Header["User-Agent"]; ok { 589 t.Fatal("constructed request contains unexpected User-Agent header") 590 } 591 } 592 593 // If a nil body is passed to github.NewRequest, make sure that nil is also 594 // passed to http.NewRequest. In most cases, passing an io.Reader that returns 595 // no content is fine, since there is no difference between an HTTP request 596 // body that is an empty string versus one that is not set at all. However in 597 // certain cases, intermediate systems may treat these differently resulting in 598 // subtle errors. 
599 func TestNewRequest_emptyBody(t *testing.T) { 600 t.Parallel() 601 c := NewClient(nil) 602 req, err := c.NewRequest("GET", ".", nil) 603 if err != nil { 604 t.Fatalf("NewRequest returned unexpected error: %v", err) 605 } 606 if req.Body != nil { 607 t.Fatalf("constructed request contains a non-nil Body") 608 } 609 } 610 611 func TestNewRequest_errorForNoTrailingSlash(t *testing.T) { 612 t.Parallel() 613 tests := []struct { 614 rawurl string 615 wantError bool 616 }{ 617 {rawurl: "https://example.com/api/v3", wantError: true}, 618 {rawurl: "https://example.com/api/v3/", wantError: false}, 619 } 620 c := NewClient(nil) 621 for _, test := range tests { 622 u, err := url.Parse(test.rawurl) 623 if err != nil { 624 t.Fatalf("url.Parse returned unexpected error: %v.", err) 625 } 626 c.BaseURL = u 627 if _, err := c.NewRequest(http.MethodGet, "test", nil); test.wantError && err == nil { 628 t.Fatalf("Expected error to be returned.") 629 } else if !test.wantError && err != nil { 630 t.Fatalf("NewRequest returned unexpected error: %v.", err) 631 } 632 } 633 } 634 635 func TestNewFormRequest(t *testing.T) { 636 t.Parallel() 637 c := NewClient(nil) 638 639 inURL, outURL := "/foo", defaultBaseURL+"foo" 640 form := url.Values{} 641 form.Add("login", "l") 642 inBody, outBody := strings.NewReader(form.Encode()), "login=l" 643 req, _ := c.NewFormRequest(inURL, inBody) 644 645 // test that relative URL was expanded 646 if got, want := req.URL.String(), outURL; got != want { 647 t.Errorf("NewFormRequest(%q) URL is %v, want %v", inURL, got, want) 648 } 649 650 // test that body was form encoded 651 body, _ := io.ReadAll(req.Body) 652 if got, want := string(body), outBody; got != want { 653 t.Errorf("NewFormRequest(%q) Body is %v, want %v", inBody, got, want) 654 } 655 656 // test that default user-agent is attached to the request 657 if got, want := req.Header.Get("User-Agent"), c.UserAgent; got != want { 658 t.Errorf("NewFormRequest() User-Agent is %v, want %v", got, want) 659 } 
660 661 apiVersion := req.Header.Get(headerAPIVersion) 662 if got, want := apiVersion, defaultAPIVersion; got != want { 663 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 664 } 665 666 req, _ = c.NewFormRequest(inURL, inBody, WithVersion("2022-11-29")) 667 apiVersion = req.Header.Get(headerAPIVersion) 668 if got, want := apiVersion, "2022-11-29"; got != want { 669 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 670 } 671 } 672 673 func TestNewFormRequest_badURL(t *testing.T) { 674 t.Parallel() 675 c := NewClient(nil) 676 _, err := c.NewFormRequest(":", nil) 677 testURLParseError(t, err) 678 } 679 680 func TestNewFormRequest_emptyUserAgent(t *testing.T) { 681 t.Parallel() 682 c := NewClient(nil) 683 c.UserAgent = "" 684 req, err := c.NewFormRequest(".", nil) 685 if err != nil { 686 t.Fatalf("NewFormRequest returned unexpected error: %v", err) 687 } 688 if _, ok := req.Header["User-Agent"]; ok { 689 t.Fatal("constructed request contains unexpected User-Agent header") 690 } 691 } 692 693 func TestNewFormRequest_emptyBody(t *testing.T) { 694 t.Parallel() 695 c := NewClient(nil) 696 req, err := c.NewFormRequest(".", nil) 697 if err != nil { 698 t.Fatalf("NewFormRequest returned unexpected error: %v", err) 699 } 700 if req.Body != nil { 701 t.Fatalf("constructed request contains a non-nil Body") 702 } 703 } 704 705 func TestNewFormRequest_errorForNoTrailingSlash(t *testing.T) { 706 t.Parallel() 707 tests := []struct { 708 rawURL string 709 wantError bool 710 }{ 711 {rawURL: "https://example.com/api/v3", wantError: true}, 712 {rawURL: "https://example.com/api/v3/", wantError: false}, 713 } 714 c := NewClient(nil) 715 for _, test := range tests { 716 u, err := url.Parse(test.rawURL) 717 if err != nil { 718 t.Fatalf("url.Parse returned unexpected error: %v.", err) 719 } 720 c.BaseURL = u 721 if _, err := c.NewFormRequest("test", nil); test.wantError && err == nil { 722 t.Fatalf("Expected error to be returned.") 
723 } else if !test.wantError && err != nil { 724 t.Fatalf("NewFormRequest returned unexpected error: %v.", err) 725 } 726 } 727 } 728 729 func TestNewUploadRequest_WithVersion(t *testing.T) { 730 t.Parallel() 731 c := NewClient(nil) 732 req, _ := c.NewUploadRequest("https://example.com/", nil, 0, "") 733 734 apiVersion := req.Header.Get(headerAPIVersion) 735 if got, want := apiVersion, defaultAPIVersion; got != want { 736 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 737 } 738 739 req, _ = c.NewUploadRequest("https://example.com/", nil, 0, "", WithVersion("2022-11-29")) 740 apiVersion = req.Header.Get(headerAPIVersion) 741 if got, want := apiVersion, "2022-11-29"; got != want { 742 t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) 743 } 744 } 745 746 func TestNewUploadRequest_badURL(t *testing.T) { 747 t.Parallel() 748 c := NewClient(nil) 749 _, err := c.NewUploadRequest(":", nil, 0, "") 750 testURLParseError(t, err) 751 752 const methodName = "NewUploadRequest" 753 testBadOptions(t, methodName, func() (err error) { 754 _, err = c.NewUploadRequest("\n", nil, -1, "\n") 755 return err 756 }) 757 } 758 759 func TestNewUploadRequest_errorForNoTrailingSlash(t *testing.T) { 760 t.Parallel() 761 tests := []struct { 762 rawurl string 763 wantError bool 764 }{ 765 {rawurl: "https://example.com/api/uploads", wantError: true}, 766 {rawurl: "https://example.com/api/uploads/", wantError: false}, 767 } 768 c := NewClient(nil) 769 for _, test := range tests { 770 u, err := url.Parse(test.rawurl) 771 if err != nil { 772 t.Fatalf("url.Parse returned unexpected error: %v.", err) 773 } 774 c.UploadURL = u 775 if _, err = c.NewUploadRequest("test", nil, 0, ""); test.wantError && err == nil { 776 t.Fatalf("Expected error to be returned.") 777 } else if !test.wantError && err != nil { 778 t.Fatalf("NewUploadRequest returned unexpected error: %v.", err) 779 } 780 } 781 } 782 783 func TestResponse_populatePageValues(t 
*testing.T) { 784 t.Parallel() 785 r := http.Response{ 786 Header: http.Header{ 787 "Link": {`<https://api.github.com/?page=1>; rel="first",` + 788 ` <https://api.github.com/?page=2>; rel="prev",` + 789 ` <https://api.github.com/?page=4>; rel="next",` + 790 ` <https://api.github.com/?page=5>; rel="last"`, 791 }, 792 }, 793 } 794 795 response := newResponse(&r) 796 if got, want := response.FirstPage, 1; got != want { 797 t.Errorf("response.FirstPage: %v, want %v", got, want) 798 } 799 if got, want := response.PrevPage, 2; want != got { 800 t.Errorf("response.PrevPage: %v, want %v", got, want) 801 } 802 if got, want := response.NextPage, 4; want != got { 803 t.Errorf("response.NextPage: %v, want %v", got, want) 804 } 805 if got, want := response.LastPage, 5; want != got { 806 t.Errorf("response.LastPage: %v, want %v", got, want) 807 } 808 if got, want := response.NextPageToken, ""; want != got { 809 t.Errorf("response.NextPageToken: %v, want %v", got, want) 810 } 811 } 812 813 func TestResponse_populateSinceValues(t *testing.T) { 814 t.Parallel() 815 r := http.Response{ 816 Header: http.Header{ 817 "Link": {`<https://api.github.com/?since=1>; rel="first",` + 818 ` <https://api.github.com/?since=2>; rel="prev",` + 819 ` <https://api.github.com/?since=4>; rel="next",` + 820 ` <https://api.github.com/?since=5>; rel="last"`, 821 }, 822 }, 823 } 824 825 response := newResponse(&r) 826 if got, want := response.FirstPage, 1; got != want { 827 t.Errorf("response.FirstPage: %v, want %v", got, want) 828 } 829 if got, want := response.PrevPage, 2; want != got { 830 t.Errorf("response.PrevPage: %v, want %v", got, want) 831 } 832 if got, want := response.NextPage, 4; want != got { 833 t.Errorf("response.NextPage: %v, want %v", got, want) 834 } 835 if got, want := response.LastPage, 5; want != got { 836 t.Errorf("response.LastPage: %v, want %v", got, want) 837 } 838 if got, want := response.NextPageToken, ""; want != got { 839 t.Errorf("response.NextPageToken: %v, want %v", got, 
want) 840 } 841 } 842 843 func TestResponse_SinceWithPage(t *testing.T) { 844 t.Parallel() 845 r := http.Response{ 846 Header: http.Header{ 847 "Link": {`<https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=1>; rel="first",` + 848 ` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=2>; rel="prev",` + 849 ` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=4>; rel="next",` + 850 ` <https://api.github.com/?since=2021-12-04T10%3A43%3A42Z&page=5>; rel="last"`, 851 }, 852 }, 853 } 854 855 response := newResponse(&r) 856 if got, want := response.FirstPage, 1; got != want { 857 t.Errorf("response.FirstPage: %v, want %v", got, want) 858 } 859 if got, want := response.PrevPage, 2; want != got { 860 t.Errorf("response.PrevPage: %v, want %v", got, want) 861 } 862 if got, want := response.NextPage, 4; want != got { 863 t.Errorf("response.NextPage: %v, want %v", got, want) 864 } 865 if got, want := response.LastPage, 5; want != got { 866 t.Errorf("response.LastPage: %v, want %v", got, want) 867 } 868 if got, want := response.NextPageToken, ""; want != got { 869 t.Errorf("response.NextPageToken: %v, want %v", got, want) 870 } 871 } 872 873 func TestResponse_cursorPagination(t *testing.T) { 874 t.Parallel() 875 r := http.Response{ 876 Header: http.Header{ 877 "Status": {"200 OK"}, 878 "Link": {`<https://api.github.com/resource?per_page=2&page=url-encoded-next-page-token>; rel="next"`}, 879 }, 880 } 881 882 response := newResponse(&r) 883 if got, want := response.FirstPage, 0; got != want { 884 t.Errorf("response.FirstPage: %v, want %v", got, want) 885 } 886 if got, want := response.PrevPage, 0; want != got { 887 t.Errorf("response.PrevPage: %v, want %v", got, want) 888 } 889 if got, want := response.NextPage, 0; want != got { 890 t.Errorf("response.NextPage: %v, want %v", got, want) 891 } 892 if got, want := response.LastPage, 0; want != got { 893 t.Errorf("response.LastPage: %v, want %v", got, want) 894 } 895 if got, want := response.NextPageToken, 
"url-encoded-next-page-token"; want != got {
		t.Errorf("response.NextPageToken: %v, want %v", got, want)
	}

	// cursor-based pagination with "cursor" param
	r = http.Response{
		Header: http.Header{
			"Link": {
				`<https://api.github.com/?cursor=v1_12345678>; rel="next"`,
			},
		},
	}

	response = newResponse(&r)
	if got, want := response.Cursor, "v1_12345678"; got != want {
		t.Errorf("response.Cursor: %v, want %v", got, want)
	}
}

// TestResponse_beforeAfterPagination checks that the "before" and "after"
// cursors are read from the Link header, and that cursor-only links leave all
// page-number fields and NextPageToken at their zero values.
func TestResponse_beforeAfterPagination(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/?after=a1b2c3&before=>; rel="next",` +
				` <https://api.github.com/?after=&before=>; rel="first",` +
				` <https://api.github.com/?after=&before=d4e5f6>; rel="prev",`,
			},
		},
	}

	response := newResponse(&r)
	if got, want := response.Before, "d4e5f6"; got != want {
		t.Errorf("response.Before: %v, want %v", got, want)
	}
	if got, want := response.After, "a1b2c3"; got != want {
		t.Errorf("response.After: %v, want %v", got, want)
	}
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 0; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 0; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 0; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}
	if got, want := response.NextPageToken, ""; got != want {
		t.Errorf("response.NextPageToken: %v, want %v", got, want)
	}
}

// TestResponse_populatePageValues_invalid feeds malformed "page" Link entries
// (missing rel, non-numeric page, missing angle brackets, no page, empty page,
// unparsable URL) and expects every page-number field to stay zero.
func TestResponse_populatePageValues_invalid(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/?page=1>,` +
				`<https://api.github.com/?page=abc>; rel="first",` +
				`https://api.github.com/?page=2; rel="prev",` +
				`<https://api.github.com/>; rel="next",` +
				`<https://api.github.com/?page=>; rel="last"`,
			},
		},
	}

	response := newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 0; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 0; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 0; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}

	// more invalid URLs
	r = http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/%?page=2>; rel="first"`},
		},
	}

	response = newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
}

// TestResponse_populateSinceValues_invalid mirrors the "page" test above for
// the "since" query parameter: malformed entries must leave all page-number
// fields at zero.
func TestResponse_populateSinceValues_invalid(t *testing.T) {
	t.Parallel()
	r := http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/?since=1>,` +
				`<https://api.github.com/?since=abc>; rel="first",` +
				`https://api.github.com/?since=2; rel="prev",` +
				`<https://api.github.com/>; rel="next",` +
				`<https://api.github.com/?since=>; rel="last"`,
			},
		},
	}

	response := newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
	if got, want := response.PrevPage, 0; got != want {
		t.Errorf("response.PrevPage: %v, want %v", got, want)
	}
	if got, want := response.NextPage, 0; got != want {
		t.Errorf("response.NextPage: %v, want %v", got, want)
	}
	if got, want := response.LastPage, 0; got != want {
		t.Errorf("response.LastPage: %v, want %v", got, want)
	}

	// more invalid URLs
	r = http.Response{
		Header: http.Header{
			"Link": {`<https://api.github.com/%?since=2>; rel="first"`},
		},
	}

	response = newResponse(&r)
	if got, want := response.FirstPage, 0; got != want {
		t.Errorf("response.FirstPage: %v, want %v", got, want)
	}
}

// TestDo checks that a successful JSON response is decoded into the value
// passed to Client.Do.
func TestDo(t *testing.T) {
	t.Parallel()
	client, mux, _ := setup(t)

	type foo struct {
		A string
	}

	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		testMethod(t, r, "GET")
		fmt.Fprint(w, `{"A":"a"}`)
	})

	req, _ := client.NewRequest("GET", ".", nil)
	body := new(foo)
	ctx := context.Background()
	_, err := client.Do(ctx, req, body)
	assertNilError(t, err)

	want := &foo{"a"}
	if !cmp.Equal(body, want) {
		t.Errorf("Response body = %v, want %v", body, want)
	}
}

// TestDo_nilContext checks that Client.Do rejects a nil context with
// errNonNilContext.
func TestDo_nilContext(t *testing.T) {
	t.Parallel()
	client, _, _ := setup(t)

	req, _ := client.NewRequest("GET", ".", nil)
	// A nil context is passed deliberately: rejecting it is the behavior under test.
	_, err := client.Do(nil, req, nil)

	if !errors.Is(err, errNonNilContext) {
		t.Errorf("Expected context must be non-nil error, got %v", err)
	}
}

// TestDo_httpError checks that a 400 response yields both an error and a
// Response carrying the status code.
func TestDo_httpError(t *testing.T) {
	t.Parallel()
	client, mux, _ := setup(t)

	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Bad Request", http.StatusBadRequest)
	})

	req, _ := client.NewRequest("GET", ".", nil)
	ctx := context.Background()
	resp, err := client.Do(ctx, req, nil)

	if err == nil {
		t.Fatal("Expected HTTP 400 error, got no error.")
	}
	// Fail cleanly, rather than panic, if a regression drops the response.
	if resp == nil {
		t.Fatal("Expected a response to be returned alongside the error.")
	}
	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected HTTP 400 error, got %d status code.", resp.StatusCode)
	}
}

// Test handling of an error caused by the internal http client's Do()
// function. A redirect loop is pretty unlikely to occur within the GitHub
// API, but does allow us to exercise the right code path.
1089 func TestDo_redirectLoop(t *testing.T) { 1090 t.Parallel() 1091 client, mux, _ := setup(t) 1092 1093 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1094 http.Redirect(w, r, baseURLPath, http.StatusFound) 1095 }) 1096 1097 req, _ := client.NewRequest("GET", ".", nil) 1098 ctx := context.Background() 1099 _, err := client.Do(ctx, req, nil) 1100 1101 if err == nil { 1102 t.Error("Expected error to be returned.") 1103 } 1104 if err, ok := err.(*url.Error); !ok { 1105 t.Errorf("Expected a URL error; got %#v.", err) 1106 } 1107 } 1108 1109 func TestDo_preservesResponseInHTTPError(t *testing.T) { 1110 t.Parallel() 1111 client, mux, _ := setup(t) 1112 1113 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1114 w.Header().Set("Content-Type", "application/json") 1115 w.WriteHeader(http.StatusNotFound) 1116 fmt.Fprintf(w, `{ 1117 "message": "Resource not found", 1118 "documentation_url": "https://docs.github.com/rest/reference/repos#get-a-repository" 1119 }`) 1120 }) 1121 1122 req, _ := client.NewRequest("GET", ".", nil) 1123 var resp *Response 1124 var data interface{} 1125 resp, err := client.Do(context.Background(), req, &data) 1126 1127 if err == nil { 1128 t.Fatal("Expected error response") 1129 } 1130 1131 // Verify error type and access to status code 1132 errResp, ok := err.(*ErrorResponse) 1133 if !ok { 1134 t.Fatalf("Expected *ErrorResponse error, got %T", err) 1135 } 1136 1137 // Verify status code is accessible from both Response and ErrorResponse 1138 if resp == nil { 1139 t.Fatal("Expected response to be returned even with error") 1140 } 1141 if got, want := resp.StatusCode, http.StatusNotFound; got != want { 1142 t.Errorf("Response status = %d, want %d", got, want) 1143 } 1144 if got, want := errResp.Response.StatusCode, http.StatusNotFound; got != want { 1145 t.Errorf("Error response status = %d, want %d", got, want) 1146 } 1147 1148 // Verify error contains proper message 1149 if !strings.Contains(errResp.Message, 
"Resource not found") { 1150 t.Errorf("Error message = %q, want to contain 'Resource not found'", errResp.Message) 1151 } 1152 } 1153 1154 // Test that an error caused by the internal http client's Do() function 1155 // does not leak the client secret. 1156 func TestDo_sanitizeURL(t *testing.T) { 1157 t.Parallel() 1158 tp := &UnauthenticatedRateLimitedTransport{ 1159 ClientID: "id", 1160 ClientSecret: "secret", 1161 } 1162 unauthedClient := NewClient(tp.Client()) 1163 unauthedClient.BaseURL = &url.URL{Scheme: "http", Host: "127.0.0.1:0", Path: "/"} // Use port 0 on purpose to trigger a dial TCP error, expect to get "dial tcp 127.0.0.1:0: connect: can't assign requested address". 1164 req, err := unauthedClient.NewRequest("GET", ".", nil) 1165 if err != nil { 1166 t.Fatalf("NewRequest returned unexpected error: %v", err) 1167 } 1168 ctx := context.Background() 1169 _, err = unauthedClient.Do(ctx, req, nil) 1170 if err == nil { 1171 t.Fatal("Expected error to be returned.") 1172 } 1173 if strings.Contains(err.Error(), "client_secret=secret") { 1174 t.Errorf("Do error contains secret, should be redacted:\n%q", err) 1175 } 1176 } 1177 1178 func TestDo_rateLimit(t *testing.T) { 1179 t.Parallel() 1180 client, mux, _ := setup(t) 1181 1182 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1183 w.Header().Set(headerRateLimit, "60") 1184 w.Header().Set(headerRateRemaining, "59") 1185 w.Header().Set(headerRateReset, "1372700873") 1186 }) 1187 1188 req, _ := client.NewRequest("GET", ".", nil) 1189 ctx := context.Background() 1190 resp, err := client.Do(ctx, req, nil) 1191 if err != nil { 1192 t.Errorf("Do returned unexpected error: %v", err) 1193 } 1194 if got, want := resp.Rate.Limit, 60; got != want { 1195 t.Errorf("Client rate limit = %v, want %v", got, want) 1196 } 1197 if got, want := resp.Rate.Remaining, 59; got != want { 1198 t.Errorf("Client rate remaining = %v, want %v", got, want) 1199 } 1200 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, 
time.UTC) 1201 if !resp.Rate.Reset.UTC().Equal(reset) { 1202 t.Errorf("Client rate reset = %v, want %v", resp.Rate.Reset.UTC(), reset) 1203 } 1204 } 1205 1206 func TestDo_rateLimitCategory(t *testing.T) { 1207 t.Parallel() 1208 tests := []struct { 1209 method string 1210 url string 1211 category RateLimitCategory 1212 }{ 1213 { 1214 method: http.MethodGet, 1215 url: "/", 1216 category: CoreCategory, 1217 }, 1218 { 1219 method: http.MethodGet, 1220 url: "/search/issues?q=rate", 1221 category: SearchCategory, 1222 }, 1223 { 1224 method: http.MethodGet, 1225 url: "/graphql", 1226 category: GraphqlCategory, 1227 }, 1228 { 1229 method: http.MethodPost, 1230 url: "/app-manifests/code/conversions", 1231 category: IntegrationManifestCategory, 1232 }, 1233 { 1234 method: http.MethodGet, 1235 url: "/app-manifests/code/conversions", 1236 category: CoreCategory, // only POST requests are in the integration manifest category 1237 }, 1238 { 1239 method: http.MethodPut, 1240 url: "/repos/google/go-github/import", 1241 category: SourceImportCategory, 1242 }, 1243 { 1244 method: http.MethodGet, 1245 url: "/repos/google/go-github/import", 1246 category: CoreCategory, // only PUT requests are in the source import category 1247 }, 1248 { 1249 method: http.MethodPost, 1250 url: "/repos/google/go-github/code-scanning/sarifs", 1251 category: CodeScanningUploadCategory, 1252 }, 1253 { 1254 method: http.MethodGet, 1255 url: "/scim/v2/organizations/ORG/Users", 1256 category: ScimCategory, 1257 }, 1258 { 1259 method: http.MethodPost, 1260 url: "/repos/google/go-github/dependency-graph/snapshots", 1261 category: DependencySnapshotsCategory, 1262 }, 1263 { 1264 method: http.MethodGet, 1265 url: "/search/code?q=rate", 1266 category: CodeSearchCategory, 1267 }, 1268 { 1269 method: http.MethodGet, 1270 url: "/orgs/google/audit-log", 1271 category: AuditLogCategory, 1272 }, 1273 // missing a check for actionsRunnerRegistrationCategory: API not found 1274 } 1275 1276 for _, tt := range tests { 1277 
if got, want := GetRateLimitCategory(tt.method, tt.url), tt.category; got != want { 1278 t.Errorf("expecting category %v, found %v", got, want) 1279 } 1280 } 1281 } 1282 1283 // Ensure rate limit is still parsed, even for error responses. 1284 func TestDo_rateLimit_errorResponse(t *testing.T) { 1285 t.Parallel() 1286 client, mux, _ := setup(t) 1287 1288 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1289 w.Header().Set(headerRateLimit, "60") 1290 w.Header().Set(headerRateRemaining, "59") 1291 w.Header().Set(headerRateReset, "1372700873") 1292 http.Error(w, "Bad Request", 400) 1293 }) 1294 1295 req, _ := client.NewRequest("GET", ".", nil) 1296 ctx := context.Background() 1297 resp, err := client.Do(ctx, req, nil) 1298 if err == nil { 1299 t.Error("Expected error to be returned.") 1300 } 1301 if _, ok := err.(*RateLimitError); ok { 1302 t.Errorf("Did not expect a *RateLimitError error; got %#v.", err) 1303 } 1304 if got, want := resp.Rate.Limit, 60; got != want { 1305 t.Errorf("Client rate limit = %v, want %v", got, want) 1306 } 1307 if got, want := resp.Rate.Remaining, 59; got != want { 1308 t.Errorf("Client rate remaining = %v, want %v", got, want) 1309 } 1310 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1311 if !resp.Rate.Reset.UTC().Equal(reset) { 1312 t.Errorf("Client rate reset = %v, want %v", resp.Rate.Reset, reset) 1313 } 1314 } 1315 1316 // Ensure *RateLimitError is returned when API rate limit is exceeded. 1317 func TestDo_rateLimit_rateLimitError(t *testing.T) { 1318 t.Parallel() 1319 client, mux, _ := setup(t) 1320 1321 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1322 w.Header().Set(headerRateLimit, "60") 1323 w.Header().Set(headerRateRemaining, "0") 1324 w.Header().Set(headerRateReset, "1372700873") 1325 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1326 w.WriteHeader(http.StatusForbidden) 1327 fmt.Fprintln(w, `{ 1328 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. 
(But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1329 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1330 }`) 1331 }) 1332 1333 req, _ := client.NewRequest("GET", ".", nil) 1334 ctx := context.Background() 1335 _, err := client.Do(ctx, req, nil) 1336 1337 if err == nil { 1338 t.Error("Expected error to be returned.") 1339 } 1340 rateLimitErr, ok := err.(*RateLimitError) 1341 if !ok { 1342 t.Fatalf("Expected a *RateLimitError error; got %#v.", err) 1343 } 1344 if got, want := rateLimitErr.Rate.Limit, 60; got != want { 1345 t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) 1346 } 1347 if got, want := rateLimitErr.Rate.Remaining, 0; got != want { 1348 t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want) 1349 } 1350 reset := time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC) 1351 if !rateLimitErr.Rate.Reset.UTC().Equal(reset) { 1352 t.Errorf("rateLimitErr rate reset = %v, want %v", rateLimitErr.Rate.Reset.UTC(), reset) 1353 } 1354 } 1355 1356 // Ensure a network call is not made when it's known that API rate limit is still exceeded. 1357 func TestDo_rateLimit_noNetworkCall(t *testing.T) { 1358 t.Parallel() 1359 client, mux, _ := setup(t) 1360 1361 reset := time.Now().UTC().Add(time.Minute).Round(time.Second) // Rate reset is a minute from now, with 1 second precision. 1362 1363 mux.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { 1364 w.Header().Set(headerRateLimit, "60") 1365 w.Header().Set(headerRateRemaining, "0") 1366 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1367 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1368 w.WriteHeader(http.StatusForbidden) 1369 fmt.Fprintln(w, `{ 1370 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. 
Check out the documentation for more details.)", 1371 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1372 }`) 1373 }) 1374 1375 madeNetworkCall := false 1376 mux.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { 1377 madeNetworkCall = true 1378 }) 1379 1380 // First request is made, and it makes the client aware of rate reset time being in the future. 1381 req, _ := client.NewRequest("GET", "first", nil) 1382 ctx := context.Background() 1383 _, err := client.Do(ctx, req, nil) 1384 if err == nil { 1385 t.Error("Expected error to be returned.") 1386 } 1387 1388 // Second request should not cause a network call to be made, since client can predict a rate limit error. 1389 req, _ = client.NewRequest("GET", "second", nil) 1390 _, err = client.Do(ctx, req, nil) 1391 1392 if madeNetworkCall { 1393 t.Fatal("Network call was made, even though rate limit is known to still be exceeded.") 1394 } 1395 1396 if err == nil { 1397 t.Error("Expected error to be returned.") 1398 } 1399 rateLimitErr, ok := err.(*RateLimitError) 1400 if !ok { 1401 t.Fatalf("Expected a *RateLimitError error; got %#v.", err) 1402 } 1403 if got, want := rateLimitErr.Rate.Limit, 60; got != want { 1404 t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) 1405 } 1406 if got, want := rateLimitErr.Rate.Remaining, 0; got != want { 1407 t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want) 1408 } 1409 if !rateLimitErr.Rate.Reset.UTC().Equal(reset) { 1410 t.Errorf("rateLimitErr rate reset = %v, want %v", rateLimitErr.Rate.Reset.UTC(), reset) 1411 } 1412 } 1413 1414 // Ignore rate limit headers if the response was served from cache. 1415 func TestDo_rateLimit_ignoredFromCache(t *testing.T) { 1416 t.Parallel() 1417 client, mux, _ := setup(t) 1418 1419 reset := time.Now().UTC().Add(time.Minute).Round(time.Second) // Rate reset is a minute from now, with 1 second precision. 
1420 1421 // By adding the X-From-Cache header we pretend this is served from a cache. 1422 mux.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { 1423 w.Header().Set("X-From-Cache", "1") 1424 w.Header().Set(headerRateLimit, "60") 1425 w.Header().Set(headerRateRemaining, "0") 1426 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1427 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1428 w.WriteHeader(http.StatusForbidden) 1429 fmt.Fprintln(w, `{ 1430 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1431 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1432 }`) 1433 }) 1434 1435 madeNetworkCall := false 1436 mux.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { 1437 madeNetworkCall = true 1438 }) 1439 1440 // First request is made so afterwards we can check the returned rate limit headers were ignored. 1441 req, _ := client.NewRequest("GET", "first", nil) 1442 ctx := context.Background() 1443 _, err := client.Do(ctx, req, nil) 1444 if err == nil { 1445 t.Error("Expected error to be returned.") 1446 } 1447 1448 // Second request should not by hindered by rate limits. 1449 req, _ = client.NewRequest("GET", "second", nil) 1450 _, err = client.Do(ctx, req, nil) 1451 1452 if err != nil { 1453 t.Fatalf("Second request failed, even though the rate limits from the cache should've been ignored: %v", err) 1454 } 1455 if !madeNetworkCall { 1456 t.Fatal("Network call was not made, even though the rate limits from the cache should've been ignored") 1457 } 1458 } 1459 1460 // Ensure sleeps until the rate limit is reset when the client is rate limited. 
1461 func TestDo_rateLimit_sleepUntilResponseResetLimit(t *testing.T) { 1462 t.Parallel() 1463 client, mux, _ := setup(t) 1464 1465 reset := time.Now().UTC().Add(time.Second) 1466 1467 var firstRequest = true 1468 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1469 if firstRequest { 1470 firstRequest = false 1471 w.Header().Set(headerRateLimit, "60") 1472 w.Header().Set(headerRateRemaining, "0") 1473 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1474 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1475 w.WriteHeader(http.StatusForbidden) 1476 fmt.Fprintln(w, `{ 1477 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1478 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1479 }`) 1480 return 1481 } 1482 w.Header().Set(headerRateLimit, "5000") 1483 w.Header().Set(headerRateRemaining, "5000") 1484 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1485 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1486 w.WriteHeader(http.StatusOK) 1487 fmt.Fprintln(w, `{}`) 1488 }) 1489 1490 req, _ := client.NewRequest("GET", ".", nil) 1491 ctx := context.Background() 1492 resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1493 if err != nil { 1494 t.Errorf("Do returned unexpected error: %v", err) 1495 } 1496 if got, want := resp.StatusCode, http.StatusOK; got != want { 1497 t.Errorf("Response status code = %v, want %v", got, want) 1498 } 1499 } 1500 1501 // Ensure tries to sleep until the rate limit is reset when the client is rate limited, but only once. 
1502 func TestDo_rateLimit_sleepUntilResponseResetLimitRetryOnce(t *testing.T) { 1503 t.Parallel() 1504 client, mux, _ := setup(t) 1505 1506 reset := time.Now().UTC().Add(time.Second) 1507 1508 requestCount := 0 1509 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1510 requestCount++ 1511 w.Header().Set(headerRateLimit, "60") 1512 w.Header().Set(headerRateRemaining, "0") 1513 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1514 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1515 w.WriteHeader(http.StatusForbidden) 1516 fmt.Fprintln(w, `{ 1517 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1518 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1519 }`) 1520 }) 1521 1522 req, _ := client.NewRequest("GET", ".", nil) 1523 ctx := context.Background() 1524 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1525 if err == nil { 1526 t.Error("Expected error to be returned.") 1527 } 1528 if got, want := requestCount, 2; got != want { 1529 t.Errorf("Expected 2 requests, got %d", got) 1530 } 1531 } 1532 1533 // Ensure a network call is not made when it's known that API rate limit is still exceeded. 
1534 func TestDo_rateLimit_sleepUntilClientResetLimit(t *testing.T) { 1535 t.Parallel() 1536 client, mux, _ := setup(t) 1537 1538 reset := time.Now().UTC().Add(time.Second) 1539 client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} 1540 requestCount := 0 1541 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1542 requestCount++ 1543 w.Header().Set(headerRateLimit, "5000") 1544 w.Header().Set(headerRateRemaining, "5000") 1545 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1546 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1547 w.WriteHeader(http.StatusOK) 1548 fmt.Fprintln(w, `{}`) 1549 }) 1550 req, _ := client.NewRequest("GET", ".", nil) 1551 ctx := context.Background() 1552 resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1553 if err != nil { 1554 t.Errorf("Do returned unexpected error: %v", err) 1555 } 1556 if got, want := resp.StatusCode, http.StatusOK; got != want { 1557 t.Errorf("Response status code = %v, want %v", got, want) 1558 } 1559 if got, want := requestCount, 1; got != want { 1560 t.Errorf("Expected 1 request, got %d", got) 1561 } 1562 } 1563 1564 // Ensure sleep is aborted when the context is cancelled. 1565 func TestDo_rateLimit_abortSleepContextCancelled(t *testing.T) { 1566 t.Parallel() 1567 client, mux, _ := setup(t) 1568 1569 // We use a 1 minute reset time to ensure the sleep is not completed. 
1570 reset := time.Now().UTC().Add(time.Minute) 1571 requestCount := 0 1572 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1573 requestCount++ 1574 w.Header().Set(headerRateLimit, "60") 1575 w.Header().Set(headerRateRemaining, "0") 1576 w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) 1577 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1578 w.WriteHeader(http.StatusForbidden) 1579 fmt.Fprintln(w, `{ 1580 "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", 1581 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1582 }`) 1583 }) 1584 1585 req, _ := client.NewRequest("GET", ".", nil) 1586 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 1587 defer cancel() 1588 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1589 if !errors.Is(err, context.DeadlineExceeded) { 1590 t.Error("Expected context deadline exceeded error.") 1591 } 1592 if got, want := requestCount, 1; got != want { 1593 t.Errorf("Expected 1 requests, got %d", got) 1594 } 1595 } 1596 1597 // Ensure sleep is aborted when the context is cancelled on initial request. 
1598 func TestDo_rateLimit_abortSleepContextCancelledClientLimit(t *testing.T) { 1599 t.Parallel() 1600 client, mux, _ := setup(t) 1601 1602 reset := time.Now().UTC().Add(time.Minute) 1603 client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} 1604 requestCount := 0 1605 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1606 requestCount++ 1607 w.Header().Set(headerRateLimit, "5000") 1608 w.Header().Set(headerRateRemaining, "5000") 1609 w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) 1610 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1611 w.WriteHeader(http.StatusOK) 1612 fmt.Fprintln(w, `{}`) 1613 }) 1614 req, _ := client.NewRequest("GET", ".", nil) 1615 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 1616 defer cancel() 1617 _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) 1618 rateLimitError, ok := err.(*RateLimitError) 1619 if !ok { 1620 t.Fatalf("Expected a *rateLimitError error; got %#v.", err) 1621 } 1622 if got, wantSuffix := rateLimitError.Message, "Context cancelled while waiting for rate limit to reset until"; !strings.HasPrefix(got, wantSuffix) { 1623 t.Errorf("Expected request to be prevented because context cancellation, got: %v.", got) 1624 } 1625 if got, want := requestCount, 0; got != want { 1626 t.Errorf("Expected 1 requests, got %d", got) 1627 } 1628 } 1629 1630 // Ensure *AbuseRateLimitError is returned when the response indicates that 1631 // the client has triggered an abuse detection mechanism. 
1632 func TestDo_rateLimit_abuseRateLimitError(t *testing.T) { 1633 t.Parallel() 1634 client, mux, _ := setup(t) 1635 1636 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1637 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1638 w.WriteHeader(http.StatusForbidden) 1639 // When the abuse rate limit error is of the "temporarily blocked from content creation" type, 1640 // there is no "Retry-After" header. 1641 fmt.Fprintln(w, `{ 1642 "message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.", 1643 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1644 }`) 1645 }) 1646 1647 req, _ := client.NewRequest("GET", ".", nil) 1648 ctx := context.Background() 1649 _, err := client.Do(ctx, req, nil) 1650 1651 if err == nil { 1652 t.Error("Expected error to be returned.") 1653 } 1654 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1655 if !ok { 1656 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1657 } 1658 if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { 1659 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1660 } 1661 } 1662 1663 // Ensure *AbuseRateLimitError is returned when the response indicates that 1664 // the client has triggered an abuse detection mechanism on GitHub Enterprise. 1665 func TestDo_rateLimit_abuseRateLimitErrorEnterprise(t *testing.T) { 1666 t.Parallel() 1667 client, mux, _ := setup(t) 1668 1669 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1670 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1671 w.WriteHeader(http.StatusForbidden) 1672 // When the abuse rate limit error is of the "temporarily blocked from content creation" type, 1673 // there is no "Retry-After" header. 
1674 // This response returns a documentation url like the one returned for GitHub Enterprise, this 1675 // url changes between versions but follows roughly the same format. 1676 fmt.Fprintln(w, `{ 1677 "message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.", 1678 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1679 }`) 1680 }) 1681 1682 req, _ := client.NewRequest("GET", ".", nil) 1683 ctx := context.Background() 1684 _, err := client.Do(ctx, req, nil) 1685 1686 if err == nil { 1687 t.Error("Expected error to be returned.") 1688 } 1689 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1690 if !ok { 1691 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1692 } 1693 if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { 1694 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1695 } 1696 } 1697 1698 // Ensure *AbuseRateLimitError.RetryAfter is parsed correctly for the Retry-After header. 1699 func TestDo_rateLimit_abuseRateLimitError_retryAfter(t *testing.T) { 1700 t.Parallel() 1701 client, mux, _ := setup(t) 1702 1703 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1704 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1705 w.Header().Set(headerRetryAfter, "123") // Retry after value of 123 seconds. 
1706 w.WriteHeader(http.StatusForbidden) 1707 fmt.Fprintln(w, `{ 1708 "message": "You have triggered an abuse detection mechanism ...", 1709 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1710 }`) 1711 }) 1712 1713 req, _ := client.NewRequest("GET", ".", nil) 1714 ctx := context.Background() 1715 _, err := client.Do(ctx, req, nil) 1716 1717 if err == nil { 1718 t.Error("Expected error to be returned.") 1719 } 1720 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1721 if !ok { 1722 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1723 } 1724 if abuseRateLimitErr.RetryAfter == nil { 1725 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1726 } 1727 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; got != want { 1728 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1729 } 1730 1731 // expect prevention of a following request 1732 if _, err = client.Do(ctx, req, nil); err == nil { 1733 t.Error("Expected error to be returned.") 1734 } 1735 abuseRateLimitErr, ok = err.(*AbuseRateLimitError) 1736 if !ok { 1737 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1738 } 1739 if abuseRateLimitErr.RetryAfter == nil { 1740 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1741 } 1742 // the saved duration might be a bit smaller than Retry-After because the duration is calculated from the expected end-of-cooldown time 1743 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1744 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1745 } 1746 if got, wantSuffix := abuseRateLimitErr.Message, "not making remote request."; !strings.HasSuffix(got, wantSuffix) { 1747 t.Errorf("Expected request to be prevented because of secondary rate limit, got: %v.", got) 1748 } 1749 } 1750 1751 // Ensure *AbuseRateLimitError.RetryAfter is parsed correctly for the x-ratelimit-reset header. 
1752 func TestDo_rateLimit_abuseRateLimitError_xRateLimitReset(t *testing.T) { 1753 t.Parallel() 1754 client, mux, _ := setup(t) 1755 1756 // x-ratelimit-reset value of 123 seconds into the future. 1757 blockUntil := time.Now().Add(time.Duration(123) * time.Second).Unix() 1758 1759 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1760 w.Header().Set("Content-Type", "application/json; charset=utf-8") 1761 w.Header().Set(headerRateReset, strconv.Itoa(int(blockUntil))) 1762 w.Header().Set(headerRateRemaining, "1") // set remaining to a value > 0 to distinct from a primary rate limit 1763 w.WriteHeader(http.StatusForbidden) 1764 fmt.Fprintln(w, `{ 1765 "message": "You have triggered an abuse detection mechanism ...", 1766 "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" 1767 }`) 1768 }) 1769 1770 req, _ := client.NewRequest("GET", ".", nil) 1771 ctx := context.Background() 1772 _, err := client.Do(ctx, req, nil) 1773 1774 if err == nil { 1775 t.Error("Expected error to be returned.") 1776 } 1777 abuseRateLimitErr, ok := err.(*AbuseRateLimitError) 1778 if !ok { 1779 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1780 } 1781 if abuseRateLimitErr.RetryAfter == nil { 1782 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1783 } 1784 // the retry after value might be a bit smaller than the original duration because the duration is calculated from the expected end-of-cooldown time 1785 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1786 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1787 } 1788 1789 // expect prevention of a following request 1790 if _, err = client.Do(ctx, req, nil); err == nil { 1791 t.Error("Expected error to be returned.") 1792 } 1793 abuseRateLimitErr, ok = err.(*AbuseRateLimitError) 1794 if !ok { 1795 t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) 1796 } 1797 if 
abuseRateLimitErr.RetryAfter == nil { 1798 t.Fatalf("abuseRateLimitErr RetryAfter is nil, expected not-nil") 1799 } 1800 // the saved duration might be a bit smaller than Retry-After because the duration is calculated from the expected end-of-cooldown time 1801 if got, want := *abuseRateLimitErr.RetryAfter, 123*time.Second; want-got > 1*time.Second { 1802 t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) 1803 } 1804 if got, wantSuffix := abuseRateLimitErr.Message, "not making remote request."; !strings.HasSuffix(got, wantSuffix) { 1805 t.Errorf("Expected request to be prevented because of secondary rate limit, got: %v.", got) 1806 } 1807 } 1808 1809 func TestDo_noContent(t *testing.T) { 1810 t.Parallel() 1811 client, mux, _ := setup(t) 1812 1813 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1814 w.WriteHeader(http.StatusNoContent) 1815 }) 1816 1817 var body json.RawMessage 1818 1819 req, _ := client.NewRequest("GET", ".", nil) 1820 ctx := context.Background() 1821 _, err := client.Do(ctx, req, &body) 1822 if err != nil { 1823 t.Fatalf("Do returned unexpected error: %v", err) 1824 } 1825 } 1826 1827 func TestSanitizeURL(t *testing.T) { 1828 t.Parallel() 1829 tests := []struct { 1830 in, want string 1831 }{ 1832 {"/?a=b", "/?a=b"}, 1833 {"/?a=b&client_secret=secret", "/?a=b&client_secret=REDACTED"}, 1834 {"/?a=b&client_id=id&client_secret=secret", "/?a=b&client_id=id&client_secret=REDACTED"}, 1835 } 1836 1837 for _, tt := range tests { 1838 inURL, _ := url.Parse(tt.in) 1839 want, _ := url.Parse(tt.want) 1840 1841 if got := sanitizeURL(inURL); !cmp.Equal(got, want) { 1842 t.Errorf("sanitizeURL(%v) returned %v, want %v", tt.in, got, want) 1843 } 1844 } 1845 } 1846 1847 func TestCheckResponse(t *testing.T) { 1848 t.Parallel() 1849 res := &http.Response{ 1850 Request: &http.Request{}, 1851 StatusCode: http.StatusBadRequest, 1852 Body: io.NopCloser(strings.NewReader(`{"message":"m", 1853 "errors": [{"resource": "r", "field": "f", 
"code": "c"}], 1854 "block": {"reason": "dmca", "created_at": "2016-03-17T15:39:46Z"}}`)), 1855 } 1856 err := CheckResponse(res).(*ErrorResponse) 1857 1858 if err == nil { 1859 t.Errorf("Expected error response.") 1860 } 1861 1862 want := &ErrorResponse{ 1863 Response: res, 1864 Message: "m", 1865 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 1866 Block: &ErrorBlock{ 1867 Reason: "dmca", 1868 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 1869 }, 1870 } 1871 if !errors.Is(err, want) { 1872 t.Errorf("Error = %#v, want %#v", err, want) 1873 } 1874 } 1875 1876 func TestCheckResponse_RateLimit(t *testing.T) { 1877 t.Parallel() 1878 res := &http.Response{ 1879 Request: &http.Request{}, 1880 StatusCode: http.StatusForbidden, 1881 Header: http.Header{}, 1882 Body: io.NopCloser(strings.NewReader(`{"message":"m", 1883 "documentation_url": "url"}`)), 1884 } 1885 res.Header.Set(headerRateLimit, "60") 1886 res.Header.Set(headerRateRemaining, "0") 1887 res.Header.Set(headerRateReset, "243424") 1888 1889 err := CheckResponse(res).(*RateLimitError) 1890 1891 if err == nil { 1892 t.Errorf("Expected error response.") 1893 } 1894 1895 want := &RateLimitError{ 1896 Rate: parseRate(res), 1897 Response: res, 1898 Message: "m", 1899 } 1900 if !errors.Is(err, want) { 1901 t.Errorf("Error = %#v, want %#v", err, want) 1902 } 1903 } 1904 1905 func TestCheckResponse_AbuseRateLimit(t *testing.T) { 1906 t.Parallel() 1907 res := &http.Response{ 1908 Request: &http.Request{}, 1909 StatusCode: http.StatusForbidden, 1910 Body: io.NopCloser(strings.NewReader(`{"message":"m", 1911 "documentation_url": "docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"}`)), 1912 } 1913 err := CheckResponse(res).(*AbuseRateLimitError) 1914 1915 if err == nil { 1916 t.Errorf("Expected error response.") 1917 } 1918 1919 want := &AbuseRateLimitError{ 1920 Response: res, 1921 Message: "m", 1922 } 1923 if !errors.Is(err, want) { 1924 t.Errorf("Error = 
%#v, want %#v", err, want) 1925 } 1926 } 1927 1928 func TestCompareHttpResponse(t *testing.T) { 1929 t.Parallel() 1930 testcases := map[string]struct { 1931 h1 *http.Response 1932 h2 *http.Response 1933 expected bool 1934 }{ 1935 "both are nil": { 1936 expected: true, 1937 }, 1938 "both are non nil - same StatusCode": { 1939 expected: true, 1940 h1: &http.Response{StatusCode: 200}, 1941 h2: &http.Response{StatusCode: 200}, 1942 }, 1943 "both are non nil - different StatusCode": { 1944 expected: false, 1945 h1: &http.Response{StatusCode: 200}, 1946 h2: &http.Response{StatusCode: 404}, 1947 }, 1948 "one is nil, other is not": { 1949 expected: false, 1950 h2: &http.Response{}, 1951 }, 1952 } 1953 1954 for name, tc := range testcases { 1955 tc := tc 1956 t.Run(name, func(t *testing.T) { 1957 t.Parallel() 1958 v := compareHTTPResponse(tc.h1, tc.h2) 1959 if tc.expected != v { 1960 t.Errorf("Expected %t, got %t for (%#v, %#v)", tc.expected, v, tc.h1, tc.h2) 1961 } 1962 }) 1963 } 1964 } 1965 1966 func TestErrorResponse_Is(t *testing.T) { 1967 t.Parallel() 1968 err := &ErrorResponse{ 1969 Response: &http.Response{}, 1970 Message: "m", 1971 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 1972 Block: &ErrorBlock{ 1973 Reason: "r", 1974 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 1975 }, 1976 DocumentationURL: "https://github.com", 1977 } 1978 testcases := map[string]struct { 1979 wantSame bool 1980 otherError error 1981 }{ 1982 "errors are same": { 1983 wantSame: true, 1984 otherError: &ErrorResponse{ 1985 Response: &http.Response{}, 1986 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 1987 Message: "m", 1988 Block: &ErrorBlock{ 1989 Reason: "r", 1990 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 1991 }, 1992 DocumentationURL: "https://github.com", 1993 }, 1994 }, 1995 "errors have different values - Message": { 1996 wantSame: false, 1997 otherError: &ErrorResponse{ 1998 Response: 
&http.Response{}, 1999 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2000 Message: "m1", 2001 Block: &ErrorBlock{ 2002 Reason: "r", 2003 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2004 }, 2005 DocumentationURL: "https://github.com", 2006 }, 2007 }, 2008 "errors have different values - DocumentationURL": { 2009 wantSame: false, 2010 otherError: &ErrorResponse{ 2011 Response: &http.Response{}, 2012 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2013 Message: "m", 2014 Block: &ErrorBlock{ 2015 Reason: "r", 2016 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2017 }, 2018 DocumentationURL: "https://google.com", 2019 }, 2020 }, 2021 "errors have different values - Response is nil": { 2022 wantSame: false, 2023 otherError: &ErrorResponse{ 2024 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2025 Message: "m", 2026 Block: &ErrorBlock{ 2027 Reason: "r", 2028 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2029 }, 2030 DocumentationURL: "https://github.com", 2031 }, 2032 }, 2033 "errors have different values - Errors": { 2034 wantSame: false, 2035 otherError: &ErrorResponse{ 2036 Response: &http.Response{}, 2037 Errors: []Error{{Resource: "r1", Field: "f1", Code: "c1"}}, 2038 Message: "m", 2039 Block: &ErrorBlock{ 2040 Reason: "r", 2041 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2042 }, 2043 DocumentationURL: "https://github.com", 2044 }, 2045 }, 2046 "errors have different values - Errors have different length": { 2047 wantSame: false, 2048 otherError: &ErrorResponse{ 2049 Response: &http.Response{}, 2050 Errors: []Error{}, 2051 Message: "m", 2052 Block: &ErrorBlock{ 2053 Reason: "r", 2054 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2055 }, 2056 DocumentationURL: "https://github.com", 2057 }, 2058 }, 2059 "errors have different values - Block - one is nil, other is 
not": { 2060 wantSame: false, 2061 otherError: &ErrorResponse{ 2062 Response: &http.Response{}, 2063 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2064 Message: "m", 2065 DocumentationURL: "https://github.com", 2066 }, 2067 }, 2068 "errors have different values - Block - different Reason": { 2069 wantSame: false, 2070 otherError: &ErrorResponse{ 2071 Response: &http.Response{}, 2072 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2073 Message: "m", 2074 Block: &ErrorBlock{ 2075 Reason: "r1", 2076 CreatedAt: &Timestamp{time.Date(2016, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2077 }, 2078 DocumentationURL: "https://github.com", 2079 }, 2080 }, 2081 "errors have different values - Block - different CreatedAt #1": { 2082 wantSame: false, 2083 otherError: &ErrorResponse{ 2084 Response: &http.Response{}, 2085 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2086 Message: "m", 2087 Block: &ErrorBlock{ 2088 Reason: "r", 2089 CreatedAt: nil, 2090 }, 2091 DocumentationURL: "https://github.com", 2092 }, 2093 }, 2094 "errors have different values - Block - different CreatedAt #2": { 2095 wantSame: false, 2096 otherError: &ErrorResponse{ 2097 Response: &http.Response{}, 2098 Errors: []Error{{Resource: "r", Field: "f", Code: "c"}}, 2099 Message: "m", 2100 Block: &ErrorBlock{ 2101 Reason: "r", 2102 CreatedAt: &Timestamp{time.Date(2017, time.March, 17, 15, 39, 46, 0, time.UTC)}, 2103 }, 2104 DocumentationURL: "https://github.com", 2105 }, 2106 }, 2107 "errors have different types": { 2108 wantSame: false, 2109 otherError: errors.New("Github"), 2110 }, 2111 } 2112 2113 for name, tc := range testcases { 2114 tc := tc 2115 t.Run(name, func(t *testing.T) { 2116 t.Parallel() 2117 if tc.wantSame != err.Is(tc.otherError) { 2118 t.Errorf("Error = %#v, want %#v", err, tc.otherError) 2119 } 2120 }) 2121 } 2122 } 2123 2124 func TestRateLimitError_Is(t *testing.T) { 2125 t.Parallel() 2126 err := &RateLimitError{ 2127 Response: &http.Response{}, 2128 Message: 
"Github", 2129 } 2130 testcases := map[string]struct { 2131 wantSame bool 2132 err *RateLimitError 2133 otherError error 2134 }{ 2135 "errors are same": { 2136 wantSame: true, 2137 err: err, 2138 otherError: &RateLimitError{ 2139 Response: &http.Response{}, 2140 Message: "Github", 2141 }, 2142 }, 2143 "errors are same - Response is nil": { 2144 wantSame: true, 2145 err: &RateLimitError{ 2146 Message: "Github", 2147 }, 2148 otherError: &RateLimitError{ 2149 Message: "Github", 2150 }, 2151 }, 2152 "errors have different values - Rate": { 2153 wantSame: false, 2154 err: err, 2155 otherError: &RateLimitError{ 2156 Rate: Rate{Limit: 10}, 2157 Response: &http.Response{}, 2158 Message: "Gitlab", 2159 }, 2160 }, 2161 "errors have different values - Response is nil": { 2162 wantSame: false, 2163 err: err, 2164 otherError: &RateLimitError{ 2165 Message: "Github", 2166 }, 2167 }, 2168 "errors have different values - StatusCode": { 2169 wantSame: false, 2170 err: err, 2171 otherError: &RateLimitError{ 2172 Response: &http.Response{StatusCode: 200}, 2173 Message: "Github", 2174 }, 2175 }, 2176 "errors have different types": { 2177 wantSame: false, 2178 err: err, 2179 otherError: errors.New("Github"), 2180 }, 2181 } 2182 2183 for name, tc := range testcases { 2184 tc := tc 2185 t.Run(name, func(t *testing.T) { 2186 t.Parallel() 2187 if tc.wantSame != tc.err.Is(tc.otherError) { 2188 t.Errorf("Error = %#v, want %#v", tc.err, tc.otherError) 2189 } 2190 }) 2191 } 2192 } 2193 2194 func TestAbuseRateLimitError_Is(t *testing.T) { 2195 t.Parallel() 2196 t1 := 1 * time.Second 2197 t2 := 2 * time.Second 2198 err := &AbuseRateLimitError{ 2199 Response: &http.Response{}, 2200 Message: "Github", 2201 RetryAfter: &t1, 2202 } 2203 testcases := map[string]struct { 2204 wantSame bool 2205 err *AbuseRateLimitError 2206 otherError error 2207 }{ 2208 "errors are same": { 2209 wantSame: true, 2210 err: err, 2211 otherError: &AbuseRateLimitError{ 2212 Response: &http.Response{}, 2213 Message: 
"Github", 2214 RetryAfter: &t1, 2215 }, 2216 }, 2217 "errors are same - Response is nil": { 2218 wantSame: true, 2219 err: &AbuseRateLimitError{ 2220 Message: "Github", 2221 RetryAfter: &t1, 2222 }, 2223 otherError: &AbuseRateLimitError{ 2224 Message: "Github", 2225 RetryAfter: &t1, 2226 }, 2227 }, 2228 "errors have different values - Message": { 2229 wantSame: false, 2230 err: err, 2231 otherError: &AbuseRateLimitError{ 2232 Response: &http.Response{}, 2233 Message: "Gitlab", 2234 RetryAfter: nil, 2235 }, 2236 }, 2237 "errors have different values - RetryAfter": { 2238 wantSame: false, 2239 err: err, 2240 otherError: &AbuseRateLimitError{ 2241 Response: &http.Response{}, 2242 Message: "Github", 2243 RetryAfter: &t2, 2244 }, 2245 }, 2246 "errors have different values - Response is nil": { 2247 wantSame: false, 2248 err: err, 2249 otherError: &AbuseRateLimitError{ 2250 Message: "Github", 2251 RetryAfter: &t1, 2252 }, 2253 }, 2254 "errors have different values - StatusCode": { 2255 wantSame: false, 2256 err: err, 2257 otherError: &AbuseRateLimitError{ 2258 Response: &http.Response{StatusCode: 200}, 2259 Message: "Github", 2260 RetryAfter: &t1, 2261 }, 2262 }, 2263 "errors have different types": { 2264 wantSame: false, 2265 err: err, 2266 otherError: errors.New("Github"), 2267 }, 2268 } 2269 2270 for name, tc := range testcases { 2271 tc := tc 2272 t.Run(name, func(t *testing.T) { 2273 t.Parallel() 2274 if tc.wantSame != tc.err.Is(tc.otherError) { 2275 t.Errorf("Error = %#v, want %#v", tc.err, tc.otherError) 2276 } 2277 }) 2278 } 2279 } 2280 2281 func TestAcceptedError_Is(t *testing.T) { 2282 t.Parallel() 2283 err := &AcceptedError{Raw: []byte("Github")} 2284 testcases := map[string]struct { 2285 wantSame bool 2286 otherError error 2287 }{ 2288 "errors are same": { 2289 wantSame: true, 2290 otherError: &AcceptedError{Raw: []byte("Github")}, 2291 }, 2292 "errors have different values": { 2293 wantSame: false, 2294 otherError: &AcceptedError{Raw: []byte("Gitlab")}, 2295 
}, 2296 "errors have different types": { 2297 wantSame: false, 2298 otherError: errors.New("Github"), 2299 }, 2300 } 2301 2302 for name, tc := range testcases { 2303 tc := tc 2304 t.Run(name, func(t *testing.T) { 2305 t.Parallel() 2306 if tc.wantSame != err.Is(tc.otherError) { 2307 t.Errorf("Error = %#v, want %#v", err, tc.otherError) 2308 } 2309 }) 2310 } 2311 } 2312 2313 // Ensure that we properly handle API errors that do not contain a response body. 2314 func TestCheckResponse_noBody(t *testing.T) { 2315 t.Parallel() 2316 res := &http.Response{ 2317 Request: &http.Request{}, 2318 StatusCode: http.StatusBadRequest, 2319 Body: io.NopCloser(strings.NewReader("")), 2320 } 2321 err := CheckResponse(res).(*ErrorResponse) 2322 2323 if err == nil { 2324 t.Errorf("Expected error response.") 2325 } 2326 2327 want := &ErrorResponse{ 2328 Response: res, 2329 } 2330 if !errors.Is(err, want) { 2331 t.Errorf("Error = %#v, want %#v", err, want) 2332 } 2333 } 2334 2335 func TestCheckResponse_unexpectedErrorStructure(t *testing.T) { 2336 t.Parallel() 2337 httpBody := `{"message":"m", "errors": ["error 1"]}` 2338 res := &http.Response{ 2339 Request: &http.Request{}, 2340 StatusCode: http.StatusBadRequest, 2341 Body: io.NopCloser(strings.NewReader(httpBody)), 2342 } 2343 err := CheckResponse(res).(*ErrorResponse) 2344 2345 if err == nil { 2346 t.Errorf("Expected error response.") 2347 } 2348 2349 want := &ErrorResponse{ 2350 Response: res, 2351 Message: "m", 2352 Errors: []Error{{Message: "error 1"}}, 2353 } 2354 if !errors.Is(err, want) { 2355 t.Errorf("Error = %#v, want %#v", err, want) 2356 } 2357 data, err2 := io.ReadAll(err.Response.Body) 2358 if err2 != nil { 2359 t.Fatalf("failed to read response body: %v", err) 2360 } 2361 if got := string(data); got != httpBody { 2362 t.Errorf("ErrorResponse.Response.Body = %q, want %q", got, httpBody) 2363 } 2364 } 2365 2366 func TestParseBooleanResponse_true(t *testing.T) { 2367 t.Parallel() 2368 result, err := parseBoolResponse(nil) 
2369 if err != nil { 2370 t.Errorf("parseBoolResponse returned error: %+v", err) 2371 } 2372 2373 if want := true; result != want { 2374 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2375 } 2376 } 2377 2378 func TestParseBooleanResponse_false(t *testing.T) { 2379 t.Parallel() 2380 v := &ErrorResponse{Response: &http.Response{StatusCode: http.StatusNotFound}} 2381 result, err := parseBoolResponse(v) 2382 if err != nil { 2383 t.Errorf("parseBoolResponse returned error: %+v", err) 2384 } 2385 2386 if want := false; result != want { 2387 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2388 } 2389 } 2390 2391 func TestParseBooleanResponse_error(t *testing.T) { 2392 t.Parallel() 2393 v := &ErrorResponse{Response: &http.Response{StatusCode: http.StatusBadRequest}} 2394 result, err := parseBoolResponse(v) 2395 2396 if err == nil { 2397 t.Errorf("Expected error to be returned.") 2398 } 2399 2400 if want := false; result != want { 2401 t.Errorf("parseBoolResponse returned %+v, want: %+v", result, want) 2402 } 2403 } 2404 2405 func TestErrorResponse_Error(t *testing.T) { 2406 t.Parallel() 2407 res := &http.Response{Request: &http.Request{}} 2408 err := ErrorResponse{Message: "m", Response: res} 2409 if err.Error() == "" { 2410 t.Errorf("Expected non-empty ErrorResponse.Error()") 2411 } 2412 2413 // dont panic if request is nil 2414 res = &http.Response{} 2415 err = ErrorResponse{Message: "m", Response: res} 2416 if err.Error() == "" { 2417 t.Errorf("Expected non-empty ErrorResponse.Error()") 2418 } 2419 2420 // dont panic if response is nil 2421 err = ErrorResponse{Message: "m"} 2422 if err.Error() == "" { 2423 t.Errorf("Expected non-empty ErrorResponse.Error()") 2424 } 2425 } 2426 2427 func TestError_Error(t *testing.T) { 2428 t.Parallel() 2429 err := Error{} 2430 if err.Error() == "" { 2431 t.Errorf("Expected non-empty Error.Error()") 2432 } 2433 } 2434 2435 func TestSetCredentialsAsHeaders(t *testing.T) { 2436 t.Parallel() 2437 req 
:= new(http.Request) 2438 id, secret := "id", "secret" 2439 modifiedRequest := setCredentialsAsHeaders(req, id, secret) 2440 2441 actualID, actualSecret, ok := modifiedRequest.BasicAuth() 2442 if !ok { 2443 t.Errorf("request does not contain basic credentials") 2444 } 2445 2446 if actualID != id { 2447 t.Errorf("id is %s, want %s", actualID, id) 2448 } 2449 2450 if actualSecret != secret { 2451 t.Errorf("secret is %s, want %s", actualSecret, secret) 2452 } 2453 } 2454 2455 func TestUnauthenticatedRateLimitedTransport(t *testing.T) { 2456 t.Parallel() 2457 client, mux, _ := setup(t) 2458 2459 clientID, clientSecret := "id", "secret" 2460 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 2461 id, secret, ok := r.BasicAuth() 2462 if !ok { 2463 t.Errorf("request does not contain basic auth credentials") 2464 } 2465 if id != clientID { 2466 t.Errorf("request contained basic auth username %q, want %q", id, clientID) 2467 } 2468 if secret != clientSecret { 2469 t.Errorf("request contained basic auth password %q, want %q", secret, clientSecret) 2470 } 2471 }) 2472 2473 tp := &UnauthenticatedRateLimitedTransport{ 2474 ClientID: clientID, 2475 ClientSecret: clientSecret, 2476 } 2477 unauthedClient := NewClient(tp.Client()) 2478 unauthedClient.BaseURL = client.BaseURL 2479 req, _ := unauthedClient.NewRequest("GET", ".", nil) 2480 ctx := context.Background() 2481 _, err := unauthedClient.Do(ctx, req, nil) 2482 assertNilError(t, err) 2483 } 2484 2485 func TestUnauthenticatedRateLimitedTransport_missingFields(t *testing.T) { 2486 t.Parallel() 2487 // missing ClientID 2488 tp := &UnauthenticatedRateLimitedTransport{ 2489 ClientSecret: "secret", 2490 } 2491 _, err := tp.RoundTrip(nil) 2492 if err == nil { 2493 t.Errorf("Expected error to be returned") 2494 } 2495 2496 // missing ClientSecret 2497 tp = &UnauthenticatedRateLimitedTransport{ 2498 ClientID: "id", 2499 } 2500 _, err = tp.RoundTrip(nil) 2501 if err == nil { 2502 t.Errorf("Expected error to be returned") 
2503 } 2504 } 2505 2506 func TestUnauthenticatedRateLimitedTransport_transport(t *testing.T) { 2507 t.Parallel() 2508 // default transport 2509 tp := &UnauthenticatedRateLimitedTransport{ 2510 ClientID: "id", 2511 ClientSecret: "secret", 2512 } 2513 if tp.transport() != http.DefaultTransport { 2514 t.Errorf("Expected http.DefaultTransport to be used.") 2515 } 2516 2517 // custom transport 2518 tp = &UnauthenticatedRateLimitedTransport{ 2519 ClientID: "id", 2520 ClientSecret: "secret", 2521 Transport: &http.Transport{}, 2522 } 2523 if tp.transport() == http.DefaultTransport { 2524 t.Errorf("Expected custom transport to be used.") 2525 } 2526 } 2527 2528 func TestBasicAuthTransport(t *testing.T) { 2529 t.Parallel() 2530 client, mux, _ := setup(t) 2531 2532 username, password, otp := "u", "p", "123456" 2533 2534 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 2535 u, p, ok := r.BasicAuth() 2536 if !ok { 2537 t.Errorf("request does not contain basic auth credentials") 2538 } 2539 if u != username { 2540 t.Errorf("request contained basic auth username %q, want %q", u, username) 2541 } 2542 if p != password { 2543 t.Errorf("request contained basic auth password %q, want %q", p, password) 2544 } 2545 if got, want := r.Header.Get(headerOTP), otp; got != want { 2546 t.Errorf("request contained OTP %q, want %q", got, want) 2547 } 2548 }) 2549 2550 tp := &BasicAuthTransport{ 2551 Username: username, 2552 Password: password, 2553 OTP: otp, 2554 } 2555 basicAuthClient := NewClient(tp.Client()) 2556 basicAuthClient.BaseURL = client.BaseURL 2557 req, _ := basicAuthClient.NewRequest("GET", ".", nil) 2558 ctx := context.Background() 2559 _, err := basicAuthClient.Do(ctx, req, nil) 2560 assertNilError(t, err) 2561 } 2562 2563 func TestBasicAuthTransport_transport(t *testing.T) { 2564 t.Parallel() 2565 // default transport 2566 tp := &BasicAuthTransport{} 2567 if tp.transport() != http.DefaultTransport { 2568 t.Errorf("Expected http.DefaultTransport to be used.") 
2569 } 2570 2571 // custom transport 2572 tp = &BasicAuthTransport{ 2573 Transport: &http.Transport{}, 2574 } 2575 if tp.transport() == http.DefaultTransport { 2576 t.Errorf("Expected custom transport to be used.") 2577 } 2578 } 2579 2580 func TestFormatRateReset(t *testing.T) { 2581 t.Parallel() 2582 d := 120*time.Minute + 12*time.Second 2583 got := formatRateReset(d) 2584 want := "[rate reset in 120m12s]" 2585 if got != want { 2586 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2587 } 2588 2589 d = 14*time.Minute + 2*time.Second 2590 got = formatRateReset(d) 2591 want = "[rate reset in 14m02s]" 2592 if got != want { 2593 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2594 } 2595 2596 d = 2*time.Minute + 2*time.Second 2597 got = formatRateReset(d) 2598 want = "[rate reset in 2m02s]" 2599 if got != want { 2600 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2601 } 2602 2603 d = 12 * time.Second 2604 got = formatRateReset(d) 2605 want = "[rate reset in 12s]" 2606 if got != want { 2607 t.Errorf("Format is wrong. got: %v, want: %v", got, want) 2608 } 2609 2610 d = -1 * (2*time.Hour + 2*time.Second) 2611 got = formatRateReset(d) 2612 want = "[rate limit was reset 120m02s ago]" 2613 if got != want { 2614 t.Errorf("Format is wrong. 
got: %v, want: %v", got, want) 2615 } 2616 } 2617 2618 func TestNestedStructAccessorNoPanic(t *testing.T) { 2619 t.Parallel() 2620 issue := &Issue{User: nil} 2621 got := issue.GetUser().GetPlan().GetName() 2622 want := "" 2623 if got != want { 2624 t.Errorf("Issues.Get.GetUser().GetPlan().GetName() returned %+v, want %+v", got, want) 2625 } 2626 } 2627 2628 func TestTwoFactorAuthError(t *testing.T) { 2629 t.Parallel() 2630 u, err := url.Parse("https://example.com") 2631 if err != nil { 2632 t.Fatal(err) 2633 } 2634 2635 e := &TwoFactorAuthError{ 2636 Response: &http.Response{ 2637 Request: &http.Request{Method: "PUT", URL: u}, 2638 StatusCode: http.StatusTooManyRequests, 2639 }, 2640 Message: "<msg>", 2641 } 2642 if got, want := e.Error(), "PUT https://example.com: 429 <msg> []"; got != want { 2643 t.Errorf("TwoFactorAuthError = %q, want %q", got, want) 2644 } 2645 } 2646 2647 func TestRateLimitError(t *testing.T) { 2648 t.Parallel() 2649 u, err := url.Parse("https://example.com") 2650 if err != nil { 2651 t.Fatal(err) 2652 } 2653 2654 r := &RateLimitError{ 2655 Response: &http.Response{ 2656 Request: &http.Request{Method: "PUT", URL: u}, 2657 StatusCode: http.StatusTooManyRequests, 2658 }, 2659 Message: "<msg>", 2660 } 2661 if got, want := r.Error(), "PUT https://example.com: 429 <msg> [rate limit was reset"; !strings.Contains(got, want) { 2662 t.Errorf("RateLimitError = %q, want %q", got, want) 2663 } 2664 } 2665 2666 func TestAcceptedError(t *testing.T) { 2667 t.Parallel() 2668 a := &AcceptedError{} 2669 if got, want := a.Error(), "try again later"; !strings.Contains(got, want) { 2670 t.Errorf("AcceptedError = %q, want %q", got, want) 2671 } 2672 } 2673 2674 func TestAbuseRateLimitError(t *testing.T) { 2675 t.Parallel() 2676 u, err := url.Parse("https://example.com") 2677 if err != nil { 2678 t.Fatal(err) 2679 } 2680 2681 r := &AbuseRateLimitError{ 2682 Response: &http.Response{ 2683 Request: &http.Request{Method: "PUT", URL: u}, 2684 StatusCode: 
http.StatusTooManyRequests, 2685 }, 2686 Message: "<msg>", 2687 } 2688 if got, want := r.Error(), "PUT https://example.com: 429 <msg>"; got != want { 2689 t.Errorf("AbuseRateLimitError = %q, want %q", got, want) 2690 } 2691 } 2692 2693 func TestAddOptions_QueryValues(t *testing.T) { 2694 t.Parallel() 2695 if _, err := addOptions("yo", ""); err == nil { 2696 t.Error("addOptions err = nil, want error") 2697 } 2698 } 2699 2700 func TestBareDo_returnsOpenBody(t *testing.T) { 2701 t.Parallel() 2702 client, mux, _ := setup(t) 2703 2704 expectedBody := "Hello from the other side !" 2705 2706 mux.HandleFunc("/test-url", func(w http.ResponseWriter, r *http.Request) { 2707 testMethod(t, r, "GET") 2708 fmt.Fprint(w, expectedBody) 2709 }) 2710 2711 ctx := context.Background() 2712 req, err := client.NewRequest("GET", "test-url", nil) 2713 if err != nil { 2714 t.Fatalf("client.NewRequest returned error: %v", err) 2715 } 2716 2717 resp, err := client.BareDo(ctx, req) 2718 if err != nil { 2719 t.Fatalf("client.BareDo returned error: %v", err) 2720 } 2721 2722 got, err := io.ReadAll(resp.Body) 2723 if err != nil { 2724 t.Fatalf("io.ReadAll returned error: %v", err) 2725 } 2726 if string(got) != expectedBody { 2727 t.Fatalf("Expected %q, got %q", expectedBody, string(got)) 2728 } 2729 if err := resp.Body.Close(); err != nil { 2730 t.Fatalf("resp.Body.Close() returned error: %v", err) 2731 } 2732 } 2733 2734 func TestErrorResponse_Marshal(t *testing.T) { 2735 t.Parallel() 2736 testJSONMarshal(t, &ErrorResponse{}, "{}") 2737 2738 u := &ErrorResponse{ 2739 Message: "msg", 2740 Errors: []Error{ 2741 { 2742 Resource: "res", 2743 Field: "f", 2744 Code: "c", 2745 Message: "msg", 2746 }, 2747 }, 2748 Block: &ErrorBlock{ 2749 Reason: "reason", 2750 CreatedAt: &Timestamp{referenceTime}, 2751 }, 2752 DocumentationURL: "doc", 2753 } 2754 2755 want := `{ 2756 "message": "msg", 2757 "errors": [ 2758 { 2759 "resource": "res", 2760 "field": "f", 2761 "code": "c", 2762 "message": "msg" 2763 } 2764 
], 2765 "block": { 2766 "reason": "reason", 2767 "created_at": ` + referenceTimeStr + ` 2768 }, 2769 "documentation_url": "doc" 2770 }` 2771 2772 testJSONMarshal(t, u, want) 2773 } 2774 2775 func TestErrorBlock_Marshal(t *testing.T) { 2776 t.Parallel() 2777 testJSONMarshal(t, &ErrorBlock{}, "{}") 2778 2779 u := &ErrorBlock{ 2780 Reason: "reason", 2781 CreatedAt: &Timestamp{referenceTime}, 2782 } 2783 2784 want := `{ 2785 "reason": "reason", 2786 "created_at": ` + referenceTimeStr + ` 2787 }` 2788 2789 testJSONMarshal(t, u, want) 2790 } 2791 2792 func TestRateLimitError_Marshal(t *testing.T) { 2793 t.Parallel() 2794 testJSONMarshal(t, &RateLimitError{}, "{}") 2795 2796 u := &RateLimitError{ 2797 Rate: Rate{ 2798 Limit: 1, 2799 Remaining: 1, 2800 Reset: Timestamp{referenceTime}, 2801 }, 2802 Message: "msg", 2803 } 2804 2805 want := `{ 2806 "Rate": { 2807 "limit": 1, 2808 "remaining": 1, 2809 "reset": ` + referenceTimeStr + ` 2810 }, 2811 "message": "msg" 2812 }` 2813 2814 testJSONMarshal(t, u, want) 2815 } 2816 2817 func TestAbuseRateLimitError_Marshal(t *testing.T) { 2818 t.Parallel() 2819 testJSONMarshal(t, &AbuseRateLimitError{}, "{}") 2820 2821 u := &AbuseRateLimitError{ 2822 Message: "msg", 2823 } 2824 2825 want := `{ 2826 "message": "msg" 2827 }` 2828 2829 testJSONMarshal(t, u, want) 2830 } 2831 2832 func TestError_Marshal(t *testing.T) { 2833 t.Parallel() 2834 testJSONMarshal(t, &Error{}, "{}") 2835 2836 u := &Error{ 2837 Resource: "res", 2838 Field: "field", 2839 Code: "code", 2840 Message: "msg", 2841 } 2842 2843 want := `{ 2844 "resource": "res", 2845 "field": "field", 2846 "code": "code", 2847 "message": "msg" 2848 }` 2849 2850 testJSONMarshal(t, u, want) 2851 } 2852 2853 func TestParseTokenExpiration(t *testing.T) { 2854 t.Parallel() 2855 tests := []struct { 2856 header string 2857 want Timestamp 2858 }{ 2859 { 2860 header: "", 2861 want: Timestamp{}, 2862 }, 2863 { 2864 header: "this is a garbage", 2865 want: Timestamp{}, 2866 }, 2867 { 2868 header: 
"2021-09-03 02:34:04 UTC", 2869 want: Timestamp{time.Date(2021, time.September, 3, 2, 34, 4, 0, time.UTC)}, 2870 }, 2871 { 2872 header: "2021-09-03 14:34:04 UTC", 2873 want: Timestamp{time.Date(2021, time.September, 3, 14, 34, 4, 0, time.UTC)}, 2874 }, 2875 // Some tokens include the timezone offset instead of the timezone. 2876 // https://github.com/google/go-github/issues/2649 2877 { 2878 header: "2023-04-26 20:23:26 +0200", 2879 want: Timestamp{time.Date(2023, time.April, 26, 18, 23, 26, 0, time.UTC)}, 2880 }, 2881 } 2882 2883 for _, tt := range tests { 2884 res := &http.Response{ 2885 Request: &http.Request{}, 2886 Header: http.Header{}, 2887 } 2888 2889 res.Header.Set(headerTokenExpiration, tt.header) 2890 exp := parseTokenExpiration(res) 2891 if !exp.Equal(tt.want) { 2892 t.Errorf("parseTokenExpiration of %q\nreturned %#v\n want %#v", tt.header, exp, tt.want) 2893 } 2894 } 2895 } 2896 2897 func TestClientCopy_leak_transport(t *testing.T) { 2898 t.Parallel() 2899 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 2900 w.Header().Set("Content-Type", "application/json") 2901 accessToken := r.Header.Get("Authorization") 2902 _, _ = fmt.Fprintf(w, `{"login": "%s"}`, accessToken) 2903 })) 2904 clientPreconfiguredWithURLs, err := NewClient(nil).WithEnterpriseURLs(srv.URL, srv.URL) 2905 if err != nil { 2906 t.Fatal(err) 2907 } 2908 2909 aliceClient := clientPreconfiguredWithURLs.WithAuthToken("alice") 2910 bobClient := clientPreconfiguredWithURLs.WithAuthToken("bob") 2911 2912 alice, _, err := aliceClient.Users.Get(context.Background(), "") 2913 if err != nil { 2914 t.Fatal(err) 2915 } 2916 2917 assertNoDiff(t, "Bearer alice", alice.GetLogin()) 2918 2919 bob, _, err := bobClient.Users.Get(context.Background(), "") 2920 if err != nil { 2921 t.Fatal(err) 2922 } 2923 2924 assertNoDiff(t, "Bearer bob", bob.GetLogin()) 2925 } 2926 2927 func TestPtr(t *testing.T) { 2928 t.Parallel() 2929 equal := func(t *testing.T, want, got any) { 
2930 t.Helper() 2931 if !reflect.DeepEqual(want, got) { 2932 t.Errorf("want %#v, got %#v", want, got) 2933 } 2934 } 2935 2936 equal(t, true, *Ptr(true)) 2937 equal(t, int(10), *Ptr(int(10))) 2938 equal(t, int64(-10), *Ptr(int64(-10))) 2939 equal(t, "str", *Ptr("str")) 2940 }