github.com/stripe/stripe-go/v76@v76.25.0/stripe_test.go (about) 1 package stripe 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/x509" 7 "encoding/json" 8 "fmt" 9 "io/ioutil" 10 "net/http" 11 "net/http/httptest" 12 "net/url" 13 "regexp" 14 "runtime" 15 "strings" 16 "sync" 17 "sync/atomic" 18 "testing" 19 "time" 20 21 assert "github.com/stretchr/testify/require" 22 ) 23 24 // A shortcut for a leveled logger that spits out all debug information (useful in tests). 25 var debugLeveledLogger = &LeveledLogger{ 26 Level: LevelDebug, 27 } 28 29 // For tests that produce a lot of logging or alarming error logs on a 30 // successful run (thereby making `go test . -test.v` quite noisy), use this 31 // null leveled logger instead of the debug one above. 32 var nullLeveledLogger = &LeveledLogger{ 33 Level: LevelNull, 34 } 35 36 // 37 // --- 38 // 39 40 func TestBearerAuth(t *testing.T) { 41 c := GetBackend(APIBackend).(*BackendImplementation) 42 key := "apiKey" 43 44 req, err := c.NewRequest("", "", key, "", nil) 45 assert.NoError(t, err) 46 47 assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization")) 48 } 49 50 func TestContext(t *testing.T) { 51 c := GetBackend(APIBackend).(*BackendImplementation) 52 p := &Params{Context: context.Background()} 53 54 req, err := c.NewRequest("", "", "", "", p) 55 assert.NoError(t, err) 56 57 // We assume that contexts are sufficiently tested in the standard library 58 // and here we just check that the context sent in to `NewRequest` is 59 // indeed properly set on the request that's returned. 60 assert.Equal(t, p.Context, req.Context()) 61 } 62 63 // Tests client retries. 64 // 65 // You can get pretty good visibility into what's going on by running just this 66 // test on verbose: 67 // 68 // go test . -run TestDo_Retry -test.v 69 func TestDo_Retry(t *testing.T) { 70 type testServerResponse struct { 71 APIResource 72 Message string `json:"message"` 73 } 74 75 message := "Hello, client." 
76 requestNum := 0 77 78 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 79 err := r.ParseForm() 80 assert.NoError(t, err) 81 82 // The body should always be the same with every retry. We've 83 // previously had regressions in this behavior as we switched to HTTP/2 84 // and `Request` became non-reusable, so we want to check it with every 85 // request. 86 assert.Equal(t, "bar", r.Form.Get("foo")) 87 88 switch requestNum { 89 case 0: 90 w.WriteHeader(http.StatusConflict) 91 w.Write([]byte(`{"error":"Conflict (this should be retried)."}`)) 92 93 case 1: 94 response := testServerResponse{Message: message} 95 96 data, err := json.Marshal(response) 97 assert.NoError(t, err) 98 99 _, err = w.Write(data) 100 assert.NoError(t, err) 101 102 default: 103 assert.Fail(t, "Should not have reached request %v", requestNum) 104 } 105 106 requestNum++ 107 })) 108 defer testServer.Close() 109 110 backend := GetBackendWithConfig( 111 APIBackend, 112 &BackendConfig{ 113 LeveledLogger: nullLeveledLogger, 114 MaxNetworkRetries: Int64(5), 115 URL: String(testServer.URL), 116 }, 117 ).(*BackendImplementation) 118 119 // Disable sleeping duration our tests. 120 backend.SetNetworkRetriesSleep(false) 121 122 request, err := backend.NewRequest( 123 http.MethodPost, 124 "/hello", 125 "sk_test_123", 126 "application/x-www-form-urlencoded", 127 nil, 128 ) 129 assert.NoError(t, err) 130 131 bodyBuffer := bytes.NewBufferString("foo=bar") 132 var response testServerResponse 133 err = backend.Do(request, bodyBuffer, &response) 134 135 assert.NoError(t, err) 136 assert.Equal(t, message, response.Message) 137 138 // We should have seen exactly two requests. 
139 assert.Equal(t, 2, requestNum) 140 } 141 142 func TestShouldRetry(t *testing.T) { 143 MaxNetworkRetries := int64(3) 144 145 c := GetBackendWithConfig( 146 APIBackend, 147 &BackendConfig{ 148 MaxNetworkRetries: Int64(MaxNetworkRetries), 149 }, 150 ).(*BackendImplementation) 151 152 // Exceeded maximum number of retries 153 t.Run("DontRetryOnExceededRetries", func(t *testing.T) { 154 shouldRetry, _ := c.shouldRetry( 155 nil, 156 &http.Request{}, 157 &http.Response{}, 158 int(MaxNetworkRetries), 159 ) 160 assert.False(t, shouldRetry) 161 }) 162 163 // Canceled context -- don't retry 164 t.Run("DontRetryOnCanceledContext", func(t *testing.T) { 165 ctx, cancel := context.WithCancel(context.Background()) 166 cancel() 167 req, err := http.NewRequestWithContext(ctx, http.MethodPost, "", nil) 168 assert.NoError(t, err) 169 170 shouldRetry, _ := c.shouldRetry( 171 nil, 172 req, 173 &http.Response{StatusCode: http.StatusOK}, 174 0, 175 ) 176 assert.False(t, shouldRetry) 177 }) 178 179 // Doesn't retry most Stripe errors (they must also match a status code 180 // below to be retried) 181 t.Run("DontRetryOnStripeError", func(t *testing.T) { 182 shouldRetry, _ := c.shouldRetry( 183 &Error{Msg: "An error from Stripe"}, 184 &http.Request{}, 185 &http.Response{StatusCode: http.StatusBadRequest}, 186 0, 187 ) 188 assert.False(t, shouldRetry) 189 }) 190 191 // Don't retry too many redirects. 192 t.Run("DontRetryOnTooManyRedirects", func(t *testing.T) { 193 shouldRetry, _ := c.shouldRetry( 194 &url.Error{Err: fmt.Errorf("stopped after 5 redirects")}, 195 &http.Request{}, 196 nil, 197 0, 198 ) 199 assert.False(t, shouldRetry) 200 }) 201 202 // Don't retry invalid protocol scheme. 
203 t.Run("DontRetryOnInvalidProtocolScheme", func(t *testing.T) { 204 shouldRetry, _ := c.shouldRetry( 205 &url.Error{Err: fmt.Errorf("unsupported protocol scheme")}, 206 &http.Request{}, 207 nil, 208 0, 209 ) 210 assert.False(t, shouldRetry) 211 }) 212 213 // Don't retry TLS certificate validation problems. 214 t.Run("DontRetryOnCertificateError", func(t *testing.T) { 215 shouldRetry, _ := c.shouldRetry( 216 &url.Error{Err: x509.UnknownAuthorityError{}}, 217 &http.Request{}, 218 nil, 219 0, 220 ) 221 assert.False(t, shouldRetry) 222 }) 223 224 // Retries most non-Stripe errors 225 t.Run("RetryOnNonStripeError", func(t *testing.T) { 226 shouldRetry, _ := c.shouldRetry( 227 fmt.Errorf("an error"), 228 &http.Request{}, 229 nil, 230 0, 231 ) 232 assert.True(t, shouldRetry) 233 }) 234 235 // `Stripe-Should-Retry: false` 236 t.Run("DontRetryOnStripeRetryHeaderFalse", func(t *testing.T) { 237 shouldRetry, _ := c.shouldRetry( 238 nil, 239 &http.Request{}, 240 &http.Response{ 241 Header: http.Header(map[string][]string{ 242 "Stripe-Should-Retry": {"false"}, 243 }), 244 // Note we send status 409 here, which would normally be retried 245 StatusCode: http.StatusConflict, 246 }, 247 0, 248 ) 249 assert.False(t, shouldRetry) 250 }) 251 252 // `Stripe-Should-Retry: true` 253 t.Run("RetryOnStripeRetryHeaderTrue", func(t *testing.T) { 254 shouldRetry, _ := c.shouldRetry( 255 nil, 256 &http.Request{}, 257 &http.Response{ 258 Header: http.Header(map[string][]string{ 259 "Stripe-Should-Retry": {"true"}, 260 }), 261 // Note we send status 400 here, which would normally not be 262 // retried 263 StatusCode: http.StatusBadRequest, 264 }, 265 0, 266 ) 267 assert.True(t, shouldRetry) 268 }) 269 270 // 409 Conflict 271 t.Run("RetryOn409Conflict", func(t *testing.T) { 272 shouldRetry, _ := c.shouldRetry( 273 nil, 274 &http.Request{}, 275 &http.Response{StatusCode: http.StatusConflict}, 276 0, 277 ) 278 assert.True(t, shouldRetry) 279 }) 280 281 // 429 Too Many Requests -- retry on lock 
timeout 282 t.Run("RetryOn429TooManyRequestsLockTimeout", func(t *testing.T) { 283 shouldRetry, _ := c.shouldRetry( 284 &Error{Code: ErrorCodeLockTimeout}, 285 &http.Request{}, 286 &http.Response{StatusCode: http.StatusTooManyRequests}, 287 0, 288 ) 289 assert.True(t, shouldRetry) 290 }) 291 292 // 429 Too Many Requests -- don't retry normally 293 t.Run("DontRetryOn429TooManyRequests", func(t *testing.T) { 294 shouldRetry, _ := c.shouldRetry( 295 nil, 296 &http.Request{}, 297 &http.Response{StatusCode: http.StatusTooManyRequests}, 298 0, 299 ) 300 assert.False(t, shouldRetry) 301 }) 302 303 // 500 Internal Server Error -- retry if non-POST 304 t.Run("RetryOn500NonPost", func(t *testing.T) { 305 shouldRetry, _ := c.shouldRetry( 306 nil, 307 &http.Request{Method: http.MethodGet}, 308 &http.Response{StatusCode: http.StatusInternalServerError}, 309 0, 310 ) 311 assert.True(t, shouldRetry) 312 }) 313 314 // 500 Internal Server Error -- don't retry POST 315 t.Run("DontRetryOn500Post", func(t *testing.T) { 316 shouldRetry, _ := c.shouldRetry( 317 nil, 318 &http.Request{Method: http.MethodPost}, 319 &http.Response{StatusCode: http.StatusInternalServerError}, 320 0, 321 ) 322 assert.False(t, shouldRetry) 323 }) 324 325 // 503 Service Unavailable 326 t.Run("RetryOn503ServiceUnavailable", func(t *testing.T) { 327 shouldRetry, _ := c.shouldRetry( 328 nil, 329 &http.Request{}, 330 &http.Response{StatusCode: http.StatusServiceUnavailable}, 331 0, 332 ) 333 assert.True(t, shouldRetry) 334 }) 335 } 336 337 func TestDo_RetryOnTimeout(t *testing.T) { 338 type testServerResponse struct { 339 APIResource 340 Message string `json:"message"` 341 } 342 343 timeout := time.Second 344 var counter uint32 345 346 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 347 atomic.AddUint32(&counter, 1) 348 time.Sleep(timeout) 349 })) 350 defer testServer.Close() 351 352 backend := GetBackendWithConfig( 353 APIBackend, 354 &BackendConfig{ 355 
LeveledLogger: nullLeveledLogger, 356 MaxNetworkRetries: Int64(1), 357 URL: String(testServer.URL), 358 HTTPClient: &http.Client{Timeout: timeout / 2}, 359 }, 360 ).(*BackendImplementation) 361 362 backend.SetNetworkRetriesSleep(false) 363 364 request, err := backend.NewRequest( 365 http.MethodPost, 366 "/hello", 367 "sk_test_123", 368 "application/x-www-form-urlencoded", 369 nil, 370 ) 371 assert.NoError(t, err) 372 373 var body = bytes.NewBufferString("foo=bar") 374 var response testServerResponse 375 376 err = backend.Do(request, body, &response) 377 378 assert.Error(t, err) 379 // timeout should not prevent retry 380 assert.Equal(t, uint32(2), atomic.LoadUint32(&counter)) 381 } 382 383 func TestDo_LastResponsePopulated(t *testing.T) { 384 type testServerResponse struct { 385 APIResource 386 Message string `json:"message"` 387 } 388 389 message := "Hello, client." 390 expectedResponse := testServerResponse{Message: message} 391 rawJSON, err := json.Marshal(expectedResponse) 392 assert.NoError(t, err) 393 394 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 395 w.Header().Set("Idempotency-Key", "key_123") 396 w.Header().Set("Other-Header", "other_header") 397 w.Header().Set("Request-Id", "req_123") 398 399 w.WriteHeader(http.StatusCreated) 400 _, err = w.Write(rawJSON) 401 assert.NoError(t, err) 402 })) 403 defer testServer.Close() 404 405 backend := GetBackendWithConfig( 406 APIBackend, 407 &BackendConfig{ 408 LeveledLogger: debugLeveledLogger, 409 MaxNetworkRetries: Int64(0), 410 URL: String(testServer.URL), 411 }, 412 ).(*BackendImplementation) 413 414 request, err := backend.NewRequest( 415 http.MethodGet, 416 "/hello", 417 "sk_test_123", 418 "application/x-www-form-urlencoded", 419 nil, 420 ) 421 assert.NoError(t, err) 422 423 var resource testServerResponse 424 err = backend.Do(request, nil, &resource) 425 assert.NoError(t, err) 426 assert.Equal(t, message, resource.Message) 427 428 assert.Equal(t, "key_123", 
resource.LastResponse.IdempotencyKey) 429 assert.Equal(t, "other_header", resource.LastResponse.Header.Get("Other-Header")) 430 assert.Equal(t, rawJSON, resource.LastResponse.RawJSON) 431 assert.Equal(t, "req_123", resource.LastResponse.RequestID) 432 assert.Equal(t, 433 fmt.Sprintf("%v %v", http.StatusCreated, http.StatusText(http.StatusCreated)), 434 resource.LastResponse.Status) 435 assert.Equal(t, http.StatusCreated, resource.LastResponse.StatusCode) 436 } 437 438 // Test that telemetry metrics are not sent by default 439 func TestCall_TelemetryDisabled(t *testing.T) { 440 type testServerResponse struct { 441 APIResource 442 Message string `json:"message"` 443 } 444 445 message := "Hello, client." 446 requestNum := 0 447 448 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 449 // none of the requests should include telemetry metrics 450 assert.Equal(t, r.Header.Get("X-Stripe-Client-Telemetry"), "") 451 452 response := testServerResponse{Message: message} 453 454 data, err := json.Marshal(response) 455 assert.NoError(t, err) 456 457 _, err = w.Write(data) 458 assert.NoError(t, err) 459 460 requestNum++ 461 })) 462 defer testServer.Close() 463 464 backend := GetBackendWithConfig( 465 APIBackend, 466 &BackendConfig{ 467 LeveledLogger: debugLeveledLogger, 468 MaxNetworkRetries: Int64(0), 469 URL: String(testServer.URL), 470 }, 471 ).(*BackendImplementation) 472 473 // When telemetry is enabled, the metrics for a request are sent with the 474 // _next_ request via the `X-Stripe-Client-Telemetry header`. To test that 475 // metrics aren't being sent, we need to fire off two requests in sequence. 476 for i := 0; i < 2; i++ { 477 var response testServerResponse 478 err := backend.Call("get", "/hello", "sk_test_xyz", nil, &response) 479 480 assert.NoError(t, err) 481 assert.Equal(t, message, response.Message) 482 } 483 484 // We should have seen exactly two requests. 
485 assert.Equal(t, 2, requestNum) 486 } 487 488 // Test that telemetry metrics are sent on subsequent requests when 489 // EnableTelemetry = true. 490 func TestCall_TelemetryEnabled(t *testing.T) { 491 type testServerResponse struct { 492 APIResource 493 Message string `json:"message"` 494 } 495 496 type requestTelemetry struct { 497 LastRequestMetrics requestMetrics `json:"last_request_metrics"` 498 } 499 500 message := "Hello, client." 501 requestNum := 0 502 503 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 504 requestNum++ 505 506 telemetryStr := r.Header.Get("X-Stripe-Client-Telemetry") 507 switch requestNum { 508 case 1: 509 // the first request should not receive any metrics 510 assert.Equal(t, telemetryStr, "") 511 time.Sleep(21 * time.Millisecond) 512 case 2: 513 assert.True(t, len(telemetryStr) > 0, "telemetryStr should not be empty") 514 515 var telemetry requestTelemetry 516 // the telemetry should properly unmarshal into RequestTelemetry 517 err := json.Unmarshal([]byte(telemetryStr), &telemetry) 518 assert.NoError(t, err) 519 520 // the second request should include the metrics for the first request 521 assert.Equal(t, telemetry.LastRequestMetrics.RequestID, "req_1") 522 assert.True(t, *telemetry.LastRequestMetrics.RequestDurationMS > 20, 523 "request_duration_ms should be > 20ms") 524 525 // The telemetry in the second request should contain the 526 // expected usage 527 assert.Equal(t, telemetry.LastRequestMetrics.Usage, []string{"llama", "bufo"}) 528 default: 529 assert.Fail(t, "Should not have reached request %v", requestNum) 530 } 531 532 w.Header().Set("Request-Id", fmt.Sprintf("req_%d", requestNum)) 533 response := testServerResponse{Message: message} 534 535 data, err := json.Marshal(response) 536 assert.NoError(t, err) 537 538 _, err = w.Write(data) 539 assert.NoError(t, err) 540 })) 541 defer testServer.Close() 542 543 backend := GetBackendWithConfig( 544 APIBackend, 545 &BackendConfig{ 546 
EnableTelemetry: Bool(true), 547 LeveledLogger: debugLeveledLogger, 548 MaxNetworkRetries: Int64(0), 549 URL: String(testServer.URL), 550 }, 551 ).(*BackendImplementation) 552 553 type myCreateParams struct { 554 Params `form:"*"` 555 Foo string `form:"foo"` 556 } 557 params := &myCreateParams{ 558 Foo: "bar", 559 } 560 params.InternalSetUsage([]string{"llama", "bufo"}) 561 for i := 0; i < 2; i++ { 562 var response testServerResponse 563 err := backend.Call("get", "/hello", "sk_test_xyz", params, &response) 564 565 assert.NoError(t, err) 566 assert.Equal(t, message, response.Message) 567 } 568 569 // We should have seen exactly two requests. 570 assert.Equal(t, 2, requestNum) 571 } 572 573 // This test does not perform any super valuable assertions - instead, it checks 574 // that our logic for buffering requestMetrics when EnableTelemetry = true does 575 // not trigger any data races. This test should pass when the -race flag is 576 // passed to `go test`. 577 func TestDo_TelemetryEnabledNoDataRace(t *testing.T) { 578 type testServerResponse struct { 579 APIResource 580 Message string `json:"message"` 581 } 582 583 message := "Hello, client." 
584 var requestNum int32 585 586 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 587 reqID := atomic.AddInt32(&requestNum, 1) 588 589 w.Header().Set("Request-Id", fmt.Sprintf("req_%d", reqID)) 590 response := testServerResponse{Message: message} 591 592 data, err := json.Marshal(response) 593 assert.NoError(t, err) 594 595 _, err = w.Write(data) 596 assert.NoError(t, err) 597 })) 598 defer testServer.Close() 599 600 backend := GetBackendWithConfig( 601 APIBackend, 602 &BackendConfig{ 603 EnableTelemetry: Bool(true), 604 LeveledLogger: nullLeveledLogger, 605 MaxNetworkRetries: Int64(0), 606 URL: String(testServer.URL), 607 }, 608 ).(*BackendImplementation) 609 610 times := 20 // 20 > telemetryBufferSize, so some metrics could be discarded 611 done := make(chan struct{}) 612 613 for i := 0; i < times; i++ { 614 go func() { 615 var response testServerResponse 616 err := backend.Call("get", "/hello", "sk_test_xyz", nil, &response) 617 618 assert.NoError(t, err) 619 assert.Equal(t, message, response.Message) 620 621 done <- struct{}{} 622 }() 623 } 624 625 for i := 0; i < times; i++ { 626 <-done 627 } 628 629 assert.Equal(t, int32(times), requestNum) 630 } 631 632 func TestDo_Redaction(t *testing.T) { 633 type testServerResponse struct { 634 Error *Error `json:"error"` 635 } 636 637 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 638 639 w.WriteHeader(402) 640 data, err := json.Marshal(testServerResponse{Error: &Error{PaymentIntent: &PaymentIntent{ClientSecret: "SHOULDBEREDACTED"}}}) 641 assert.NoError(t, err) 642 643 _, err = w.Write(data) 644 assert.NoError(t, err) 645 646 })) 647 defer testServer.Close() 648 649 var logs bytes.Buffer 650 logger := &LeveledLogger{Level: LevelDebug, stderrOverride: &logs, stdoutOverride: &logs} 651 652 backend := GetBackendWithConfig( 653 APIBackend, 654 &BackendConfig{ 655 EnableTelemetry: Bool(true), 656 LeveledLogger: logger, 657 
MaxNetworkRetries: Int64(0), 658 URL: String(testServer.URL), 659 }, 660 ).(*BackendImplementation) 661 662 request, err := backend.NewRequest( 663 http.MethodGet, 664 "/hello", 665 "sk_test_123", 666 "application/x-www-form-urlencoded", 667 nil, 668 ) 669 assert.NoError(t, err) 670 671 var response Charge 672 err = backend.Do(request, nil, &response) 673 assert.Error(t, err) 674 675 assert.NotContains(t, logs.String(), "SHOULDBEREDACTED") 676 assert.Contains(t, logs.String(), "REDACTED") 677 } 678 679 func TestDoStreaming(t *testing.T) { 680 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 681 w.WriteHeader(200) 682 data := []byte("hello") 683 684 var err error 685 _, err = w.Write(data) 686 assert.NoError(t, err) 687 })) 688 defer testServer.Close() 689 690 var logs bytes.Buffer 691 logger := &LeveledLogger{Level: LevelDebug, stderrOverride: &logs, stdoutOverride: &logs} 692 693 backend := GetBackendWithConfig( 694 APIBackend, 695 &BackendConfig{ 696 EnableTelemetry: Bool(true), 697 LeveledLogger: logger, 698 MaxNetworkRetries: Int64(0), 699 URL: String(testServer.URL), 700 }, 701 ).(*BackendImplementation) 702 703 type streamingResource struct { 704 APIStream 705 } 706 707 response := streamingResource{} 708 err := backend.CallStreaming( 709 http.MethodGet, 710 "/pdf", 711 "sk_test_123", 712 nil, 713 &response, 714 ) 715 assert.NoError(t, err) 716 result, err := ioutil.ReadAll(response.LastResponse.Body) 717 assert.NoError(t, err) 718 err = response.LastResponse.Body.Close() 719 assert.NoError(t, err) 720 assert.Equal(t, "hello", string(result)) 721 } 722 723 func TestDoStreaming_ParsableError(t *testing.T) { 724 type testServerResponse struct { 725 Error *Error `json:"error"` 726 } 727 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 728 w.WriteHeader(400) 729 var data []byte 730 var err error 731 data, err = json.Marshal(testServerResponse{Error: &Error{Msg: "Text 
of error"}}) 732 assert.NoError(t, err) 733 734 _, err = w.Write(data) 735 assert.NoError(t, err) 736 })) 737 defer testServer.Close() 738 739 var logs bytes.Buffer 740 logger := &LeveledLogger{Level: LevelDebug, stderrOverride: &logs, stdoutOverride: &logs} 741 742 backend := GetBackendWithConfig( 743 APIBackend, 744 &BackendConfig{ 745 EnableTelemetry: Bool(true), 746 LeveledLogger: logger, 747 MaxNetworkRetries: Int64(0), 748 URL: String(testServer.URL), 749 }, 750 ).(*BackendImplementation) 751 752 type streamingResource struct { 753 APIStream 754 } 755 756 response := streamingResource{} 757 err := backend.CallStreaming( 758 http.MethodGet, 759 "/pdf", 760 "sk_test_123", 761 nil, 762 &response, 763 ) 764 assert.NotNil(t, err) 765 stripeErr, ok := err.(*Error) 766 assert.True(t, ok) 767 assert.Equal(t, stripeErr.Msg, "Text of error") 768 } 769 770 func TestDoStreaming_UnparsableError(t *testing.T) { 771 testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 772 w.WriteHeader(400) 773 var data []byte 774 var err error 775 data = []byte("{invalid json}") 776 777 _, err = w.Write(data) 778 assert.NoError(t, err) 779 })) 780 defer testServer.Close() 781 782 var logs bytes.Buffer 783 logger := &LeveledLogger{Level: LevelDebug, stderrOverride: &logs, stdoutOverride: &logs} 784 785 backend := GetBackendWithConfig( 786 APIBackend, 787 &BackendConfig{ 788 EnableTelemetry: Bool(true), 789 LeveledLogger: logger, 790 MaxNetworkRetries: Int64(0), 791 URL: String(testServer.URL), 792 }, 793 ).(*BackendImplementation) 794 795 type streamingResource struct { 796 APIStream 797 } 798 799 response := streamingResource{} 800 err := backend.CallStreaming( 801 http.MethodGet, 802 "/pdf", 803 "sk_test_123", 804 nil, 805 &response, 806 ) 807 assert.NotNil(t, err) 808 _, ok := err.(*Error) 809 assert.False(t, ok) 810 assert.True(t, strings.Contains(err.Error(), "Couldn't deserialize JSON")) 811 } 812 813 func TestFormatURLPath(t *testing.T) { 
814 assert.Equal(t, "/v1/resources/1/subresources/2", 815 FormatURLPath("/v1/resources/%s/subresources/%s", "1", "2")) 816 817 // Tests that each parameter is escaped for use in URLs 818 assert.Equal(t, "/v1/resources/%25", 819 FormatURLPath("/v1/resources/%s", "%")) 820 } 821 822 func TestGetBackendWithConfig_Loggers(t *testing.T) { 823 leveledLogger := &LeveledLogger{} 824 825 backend := GetBackendWithConfig( 826 APIBackend, 827 &BackendConfig{ 828 LeveledLogger: leveledLogger, 829 }, 830 ).(*BackendImplementation) 831 832 assert.Equal(t, leveledLogger, backend.LeveledLogger) 833 } 834 835 func TestGetBackendWithConfig_TrimV1Suffix(t *testing.T) { 836 { 837 backend := GetBackendWithConfig( 838 APIBackend, 839 &BackendConfig{ 840 URL: String("https://api.com/v1"), 841 }, 842 ).(*BackendImplementation) 843 844 // The `/v1` suffix has been stripped. 845 assert.Equal(t, "https://api.com", backend.URL) 846 } 847 848 // Also support trimming a `/v1/` with an extra trailing slash which is 849 // probably an often seen mistake. 850 { 851 backend := GetBackendWithConfig( 852 APIBackend, 853 &BackendConfig{ 854 URL: String("https://api.com/v1/"), 855 }, 856 ).(*BackendImplementation) 857 858 assert.Equal(t, "https://api.com", backend.URL) 859 } 860 861 // No-op otherwise. 
862 { 863 backend := GetBackendWithConfig( 864 APIBackend, 865 &BackendConfig{ 866 URL: String("https://api.com"), 867 }, 868 ).(*BackendImplementation) 869 870 assert.Equal(t, "https://api.com", backend.URL) 871 } 872 } 873 874 func TestParseID(t *testing.T) { 875 // JSON string 876 { 877 id, ok := ParseID([]byte(`"ch_123"`)) 878 assert.Equal(t, "ch_123", id) 879 assert.True(t, ok) 880 } 881 882 // JSON object 883 { 884 id, ok := ParseID([]byte(`{"id":"ch_123"}`)) 885 assert.Equal(t, "", id) 886 assert.False(t, ok) 887 } 888 889 // Other JSON scalar (this should never be used, but check the results anyway) 890 { 891 id, ok := ParseID([]byte(`123`)) 892 assert.Equal(t, "", id) 893 assert.False(t, ok) 894 } 895 896 // Edge case that should never happen; found via fuzzing 897 { 898 id, ok := ParseID([]byte(`"`)) 899 assert.Equal(t, "", id) 900 assert.False(t, ok) 901 } 902 } 903 904 // TestMultipleAPICalls will fail the test run if a race condition is thrown while running multiple NewRequest calls. 
905 func TestMultipleAPICalls(t *testing.T) { 906 wg := &sync.WaitGroup{} 907 for i := 0; i < 10; i++ { 908 wg.Add(1) 909 go func() { 910 defer wg.Done() 911 c := GetBackend(APIBackend).(*BackendImplementation) 912 key := "apiKey" 913 914 req, err := c.NewRequest("", "", key, "", nil) 915 assert.NoError(t, err) 916 917 assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization")) 918 }() 919 } 920 wg.Wait() 921 } 922 923 func TestIdempotencyKey(t *testing.T) { 924 c := GetBackend(APIBackend).(*BackendImplementation) 925 p := &Params{IdempotencyKey: String("idempotency-key")} 926 927 req, err := c.NewRequest("", "", "", "", p) 928 assert.NoError(t, err) 929 930 assert.Equal(t, "idempotency-key", req.Header.Get("Idempotency-Key")) 931 } 932 933 func TestNewBackends(t *testing.T) { 934 httpClient := &http.Client{} 935 backends := NewBackends(httpClient) 936 assert.Equal(t, httpClient, backends.API.(*BackendImplementation).HTTPClient) 937 assert.Equal(t, httpClient, backends.Uploads.(*BackendImplementation).HTTPClient) 938 } 939 940 func TestStripeAccount(t *testing.T) { 941 c := GetBackend(APIBackend).(*BackendImplementation) 942 p := &Params{} 943 p.SetStripeAccount("acct_123") 944 945 req, err := c.NewRequest("", "", "", "", p) 946 assert.NoError(t, err) 947 948 assert.Equal(t, "acct_123", req.Header.Get("Stripe-Account")) 949 } 950 951 func TestErrorOnDuplicateMetadata(t *testing.T) { 952 c := GetBackend(APIBackend).(*BackendImplementation) 953 type myParams struct { 954 Params `form:"*"` 955 Metadata map[string]string `form:"metadata"` 956 } 957 958 metadata := map[string]string{"foo": "bar"} 959 resource := APIResource{} 960 err := c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{}, &resource) 961 assert.NoError(t, err) 962 963 err = 964 c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{Metadata: metadata}, &resource) 965 assert.NoError(t, err) 966 967 err = 968 c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{Params: Params{Metadata: 
metadata}}, &resource) 969 assert.NoError(t, err) 970 971 err = 972 c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{Metadata: metadata, Params: Params{Metadata: metadata}}, &resource) 973 assert.Errorf(t, err, "You cannot specify both the (deprecated) .Params.Metadata and .Metadata in myParams") 974 } 975 976 func TestErrorOnDuplicateExpand(t *testing.T) { 977 c := GetBackend(APIBackend).(*BackendImplementation) 978 type myParams struct { 979 Params `form:"*"` 980 Expand []*string `form:"expand"` 981 } 982 983 expand := []*string{String("foo"), String("bar")} 984 resource := APIResource{} 985 err := c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{}, &resource) 986 assert.NoError(t, err) 987 988 err = 989 c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{Expand: expand}, &resource) 990 assert.NoError(t, err) 991 992 err = 993 c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{ 994 Params: Params{Expand: expand}, 995 }, &resource) 996 assert.NoError(t, err) 997 998 err = 999 c.Call("POST", "/v1/customers", "sk_test_xyz", &myParams{ 1000 Expand: expand, Params: Params{Expand: expand}}, &resource) 1001 1002 assert.Errorf(t, err, "You cannot specify both the (deprecated) .Params.Expand and .Expand in myParams") 1003 } 1004 1005 func TestUnmarshalJSONVerbose(t *testing.T) { 1006 type testServerResponse struct { 1007 Message string `json:"message"` 1008 } 1009 1010 backend := GetBackend(APIBackend).(*BackendImplementation) 1011 1012 // Valid JSON 1013 { 1014 type testServerResponse struct { 1015 Message string `json:"message"` 1016 } 1017 1018 var sample testServerResponse 1019 err := backend.UnmarshalJSONVerbose(200, []byte(`{"message":"hello"}`), &sample) 1020 assert.NoError(t, err) 1021 assert.Equal(t, "hello", sample.Message) 1022 } 1023 1024 // Invalid JSON (short) 1025 { 1026 body := `server error` 1027 1028 var sample testServerResponse 1029 err := backend.UnmarshalJSONVerbose(200, []byte(body), &sample) 1030 assert.Regexp(t, 1031 
fmt.Sprintf(`^Couldn't deserialize JSON \(response status: 200, body sample: '%s'\): invalid character`, body),
			err)
	}

	// Invalid JSON (long, and therefore truncated)
	{
		// Assembles a body that's at least as long as our maximum sample.
		// body is ~130 characters * 5.
		bodyText := `this is a really long body that will be truncated when added to the error message to protect against dumping huge responses in logs `
		body := bodyText + bodyText + bodyText + bodyText + bodyText

		var sample testServerResponse
		err := backend.UnmarshalJSONVerbose(200, []byte(body), &sample)
		assert.Regexp(t,
			fmt.Sprintf(`^Couldn't deserialize JSON \(response status: 200, body sample: '%s ...'\): invalid character`, body[0:500]),
			err)
	}
}

// TestUserAgent checks that, with no app info configured, the User-Agent
// header built by NewRequest is exactly "Stripe/v1 GoBindings/<version>".
func TestUserAgent(t *testing.T) {
	c := GetBackend(APIBackend).(*BackendImplementation)

	req, err := c.NewRequest("", "", "", "", nil)
	assert.NoError(t, err)

	// We keep our version constant private to the package, so use a regexp
	// match instead. (`\w` already includes digits, so no separate `\d`.)
	expectedPattern := regexp.MustCompile(`^Stripe/v1 GoBindings/[.\-\w]+$`)

	// assert.Regexp reports the actual header on failure, unlike a bare
	// MatchString + True.
	assert.Regexp(t, expectedPattern, req.Header.Get("User-Agent"))
}

// TestUserAgentWithAppInfo checks that app info registered via SetAppInfo is
// reflected both in the User-Agent header and in the "application" object of
// the JSON-encoded X-Stripe-Client-User-Agent header.
func TestUserAgentWithAppInfo(t *testing.T) {
	appInfo := &AppInfo{
		Name:      "MyAwesomePlugin",
		PartnerID: "partner_1234",
		URL:       "https://myawesomeplugin.info",
		Version:   "1.2.34",
	}
	SetAppInfo(appInfo)
	// SetAppInfo mutates package-level state; reset it so other tests are
	// unaffected.
	defer SetAppInfo(nil)

	c := GetBackend(APIBackend).(*BackendImplementation)

	req, err := c.NewRequest("", "", "", "", nil)
	assert.NoError(t, err)

	//
	// User-Agent
	//

	// We keep our version constant private to the package, so use a regexp
	// match instead. Literal dots in the app version and URL are escaped so
	// they only match an actual '.'.
	expectedPattern := regexp.MustCompile(`^Stripe/v1 GoBindings/[.\-\w]+ MyAwesomePlugin/1\.2\.34 \(https://myawesomeplugin\.info\)$`)

	assert.Regexp(t, expectedPattern, req.Header.Get("User-Agent"))

	//
	// X-Stripe-Client-User-Agent
	//

	encodedUserAgent := req.Header.Get("X-Stripe-Client-User-Agent")
	assert.NotEmpty(t, encodedUserAgent)

	var userAgent map[string]interface{}
	err = json.Unmarshal([]byte(encodedUserAgent), &userAgent)
	assert.NoError(t, err)

	// Checked assertion: a missing or non-object "application" fails the
	// test with context instead of panicking.
	application, ok := userAgent["application"].(map[string]interface{})
	assert.True(t, ok, "expected an application object in %s", encodedUserAgent)

	assert.Equal(t, "MyAwesomePlugin", application["name"])
	assert.Equal(t, "partner_1234", application["partner_id"])
	assert.Equal(t, "https://myawesomeplugin.info", application["url"])
	assert.Equal(t, "1.2.34", application["version"])
}

// TestStripeClientUserAgent checks a few stable fields of the JSON-encoded
// X-Stripe-Client-User-Agent header when no app info is configured.
func TestStripeClientUserAgent(t *testing.T) {
	c := GetBackend(APIBackend).(*BackendImplementation)

	req, err := c.NewRequest("", "", "", "", nil)
	assert.NoError(t, err)

	encodedUserAgent := req.Header.Get("X-Stripe-Client-User-Agent")
	assert.NotEmpty(t, encodedUserAgent)

	var userAgent map[string]string
	err = json.Unmarshal([]byte(encodedUserAgent), &userAgent)
	assert.NoError(t, err)

	//
	// Just test a few headers that we know to be stable.
	//

	assert.Empty(t, userAgent["application"])
	assert.Equal(t, "go", userAgent["lang"])
	assert.Equal(t, runtime.Version(), userAgent["lang_version"])

	// Anywhere these tests are running can reasonably be expected to have a
	// `uname` to run, so do this basic check. This must inspect the "uname"
	// field: "lang_version" holds runtime.Version() and can never equal
	// UnknownPlatform, so checking it would assert nothing.
	assert.NotEqual(t, UnknownPlatform, userAgent["uname"])
}

// TestStripeClientUserAgentWithAppInfo checks that app info registered via
// SetAppInfo appears as the "application" object of the
// X-Stripe-Client-User-Agent header.
func TestStripeClientUserAgentWithAppInfo(t *testing.T) {
	appInfo := &AppInfo{
		Name:    "MyAwesomePlugin",
		URL:     "https://myawesomeplugin.info",
		Version: "1.2.34",
	}
	SetAppInfo(appInfo)
	// SetAppInfo mutates package-level state; reset it so other tests are
	// unaffected.
	defer SetAppInfo(nil)

	c := GetBackend(APIBackend).(*BackendImplementation)

	req, err := c.NewRequest("", "", "", "", nil)
	assert.NoError(t, err)

	encodedUserAgent := req.Header.Get("X-Stripe-Client-User-Agent")
	assert.NotEmpty(t, encodedUserAgent)

	var userAgent map[string]interface{}
	err = json.Unmarshal([]byte(encodedUserAgent), &userAgent)
	assert.NoError(t, err)

	// Checked assertion: fail with context rather than panic.
	decodedAppInfo, ok := userAgent["application"].(map[string]interface{})
	assert.True(t, ok, "expected an application object in %s", encodedUserAgent)

	assert.Equal(t, appInfo.Name, decodedAppInfo["name"])
	assert.Equal(t, appInfo.URL, decodedAppInfo["url"])
	assert.Equal(t, appInfo.Version, decodedAppInfo["version"])
}

// TestResponseToError checks that ResponseToError decodes a wrapped API error
// payload into an *Error carrying the response's request ID and status code,
// and that a card-type error exposes a *CardError (with decline code) in Err.
func TestResponseToError(t *testing.T) {
	c := GetBackend(APIBackend).(*BackendImplementation)

	// A test response that includes a status code and request ID.
	res := &http.Response{
		Header: http.Header{
			"Request-Id": []string{"request-id"},
		},
		StatusCode: 402,
	}

	// An error that contains expected fields which we're going to serialize to
	// JSON and inject into our conversion function.
	expectedErr := &Error{
		Code:  ErrorCodeMissing,
		Msg:   "That card was declined",
		Param: "expiry_date",
		Type:  ErrorTypeCard,
	}
	// Named `payload` (not `bytes`) so the imported bytes package isn't
	// shadowed.
	payload, err := json.Marshal(expectedErr)
	assert.NoError(t, err)

	// Unpack the error that we just serialized so that we can inject a
	// type-specific field into it ("decline_code"). This will show up in a
	// field on a special CardError type which is attached to the common
	// Error.
	var raw map[string]string
	err = json.Unmarshal(payload, &raw)
	assert.NoError(t, err)

	expectedDeclineCode := DeclineCodeInvalidCVC
	raw["decline_code"] = string(expectedDeclineCode)
	payload, err = json.Marshal(raw)
	assert.NoError(t, err)

	// A generic Golang error.
	err = c.ResponseToError(res, wrapError(payload))
	assert.Error(t, err)

	// An error containing Stripe-specific fields that we cast back from the
	// generic Golang error. The checked form fails the test cleanly instead
	// of panicking if the concrete type is unexpected.
	stripeErr, ok := err.(*Error)
	assert.True(t, ok, "expected *Error, got %T", err)

	assert.Equal(t, expectedErr.Code, stripeErr.Code)
	assert.Equal(t, expectedErr.Msg, stripeErr.Msg)
	assert.Equal(t, expectedErr.Param, stripeErr.Param)
	assert.Equal(t, res.Header.Get("Request-Id"), stripeErr.RequestID)
	assert.Equal(t, res.StatusCode, stripeErr.HTTPStatusCode)
	assert.Equal(t, expectedErr.Type, stripeErr.Type)
	assert.Equal(t, expectedDeclineCode, stripeErr.DeclineCode)

	// Not exhaustive, but verify LastResponse is basically working as
	// expected.
	assert.Equal(t, res.Header.Get("Request-Id"), stripeErr.LastResponse.RequestID)
	assert.Equal(t, res.StatusCode, stripeErr.LastResponse.StatusCode)

	// Just a bogus type coercion to demonstrate how this code might be
	// written. Because we've assigned ErrorTypeCard as the error's type, Err
	// should always come out as a CardError.
	_, ok = stripeErr.Err.(*InvalidRequestError)
	assert.False(t, ok)

	cardErr, ok := stripeErr.Err.(*CardError)
	assert.True(t, ok, "expected *CardError, got %T", stripeErr.Err)

	// For backwards compatibility, `DeclineCode` is also set on the
	// `CardError` structure.
	assert.Equal(t, expectedDeclineCode, cardErr.DeclineCode)
}

// TestStringSlice checks that StringSlice returns pointers to each input
// element in order and an empty result for nil input.
func TestStringSlice(t *testing.T) {
	input := []string{"a", "b", "c"}
	result := StringSlice(input)

	// Check the length first so a short result fails instead of panicking
	// on the indexing below.
	assert.Len(t, result, len(input))
	assert.Equal(t, "a", *result[0])
	assert.Equal(t, "b", *result[1])
	assert.Equal(t, "c", *result[2])

	assert.Empty(t, StringSlice(nil))
}

// TestInt64Slice checks that Int64Slice returns pointers to each input
// element in order and an empty result for nil input.
func TestInt64Slice(t *testing.T) {
	input := []int64{8, 7, 6}
	result := Int64Slice(input)

	assert.Len(t, result, len(input))
	assert.Equal(t, int64(8), *result[0])
	assert.Equal(t, int64(7), *result[1])
	assert.Equal(t, int64(6), *result[2])

	assert.Empty(t, Int64Slice(nil))
}

// TestFloat64Slice checks that Float64Slice returns pointers to each input
// element in order and an empty result for nil input.
func TestFloat64Slice(t *testing.T) {
	input := []float64{8, 7, 6}
	result := Float64Slice(input)

	assert.Len(t, result, len(input))
	assert.Equal(t, float64(8), *result[0])
	assert.Equal(t, float64(7), *result[1])
	assert.Equal(t, float64(6), *result[2])

	assert.Empty(t, Float64Slice(nil))
}

// TestBoolSlice checks that BoolSlice returns pointers to each input element
// in order and an empty result for nil input.
func TestBoolSlice(t *testing.T) {
	input := []bool{true, false, true, false}
	result := BoolSlice(input)

	assert.Len(t, result, len(input))
	assert.True(t, *result[0])
	assert.False(t, *result[1])
	assert.True(t, *result[2])
	assert.False(t, *result[3])

	assert.Empty(t, BoolSlice(nil))
}

//
// ---
//

// wrapError builds the JSON body of an error response from Stripe, which comes
// wrapped in a JSON object with a single field of "error". serialized is
// inserted verbatim, so it must already be valid JSON.
func wrapError(serialized []byte) []byte {
	return []byte(`{"error":` + string(serialized) + `}`)
}