google.golang.org/grpc@v1.74.2/credentials/sts/sts_test.go (about) 1 /* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package sts 20 21 import ( 22 "bytes" 23 "context" 24 "crypto/x509" 25 "encoding/json" 26 "errors" 27 "fmt" 28 "io" 29 "net/http" 30 "net/http/httputil" 31 "strings" 32 "testing" 33 "time" 34 35 "github.com/google/go-cmp/cmp" 36 37 "google.golang.org/grpc/credentials" 38 "google.golang.org/grpc/internal/grpctest" 39 "google.golang.org/grpc/internal/testutils" 40 ) 41 42 const ( 43 requestedTokenType = "urn:ietf:params:oauth:token-type:access-token" 44 actorTokenPath = "/var/run/secrets/token.jwt" 45 actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token" 46 actorTokenContents = "actorToken.jwt.contents" 47 accessTokenContents = "access_token" 48 subjectTokenPath = "/var/run/secrets/token.jwt" 49 subjectTokenType = "urn:ietf:params:oauth:token-type:id_token" 50 subjectTokenContents = "subjectToken.jwt.contents" 51 serviceURI = "http://localhost" 52 exampleResource = "https://backend.example.com/api" 53 exampleAudience = "example-backend-service" 54 testScope = "https://www.googleapis.com/auth/monitoring" 55 defaultTestTimeout = 1 * time.Second 56 defaultTestShortTimeout = 10 * time.Millisecond 57 ) 58 59 var ( 60 goodOptions = Options{ 61 TokenExchangeServiceURI: serviceURI, 62 Audience: exampleAudience, 63 RequestedTokenType: requestedTokenType, 64 SubjectTokenPath: subjectTokenPath, 65 SubjectTokenType: subjectTokenType, 66 } 67 goodRequestParams = &requestParameters{ 68 GrantType: tokenExchangeGrantType, 69 Audience: exampleAudience, 70 Scope: defaultCloudPlatformScope, 71 RequestedTokenType: requestedTokenType, 72 SubjectToken: subjectTokenContents, 73 SubjectTokenType: subjectTokenType, 74 } 75 goodMetadata = map[string]string{ 76 "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents), 77 } 78 ) 79 80 type s struct { 81 grpctest.Tester 82 } 83 84 func Test(t *testing.T) { 85 grpctest.RunSubTests(t, s{}) 86 } 87 88 // A struct that implements AuthInfo interface and added to the context passed 89 // to GetRequestMetadata from tests. 90 type testAuthInfo struct { 91 credentials.CommonAuthInfo 92 } 93 94 func (ta testAuthInfo) AuthType() string { 95 return "testAuthInfo" 96 } 97 98 func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context { 99 auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}} 100 ri := credentials.RequestInfo{ 101 Method: "testInfo", 102 AuthInfo: auth, 103 } 104 return credentials.NewContextWithRequestInfo(ctx, ri) 105 } 106 107 // errReader implements the io.Reader interface and returns an error from the 108 // Read method. 109 type errReader struct{} 110 111 func (r errReader) Read([]byte) (n int, err error) { 112 return 0, errors.New("read error") 113 } 114 115 // We need a function to construct the response instead of simply declaring it 116 // as a variable since the response body will be consumed by the 117 // credentials, and therefore we will need a new one everytime. 118 func makeGoodResponse() *http.Response { 119 respJSON, _ := json.Marshal(responseParameters{ 120 AccessToken: accessTokenContents, 121 IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", 122 TokenType: "Bearer", 123 ExpiresIn: 3600, 124 }) 125 respBody := io.NopCloser(bytes.NewReader(respJSON)) 126 return &http.Response{ 127 Status: "200 OK", 128 StatusCode: http.StatusOK, 129 Body: respBody, 130 } 131 } 132 133 // Overrides the http.Client with a fakeClient which sends a good response. 134 func overrideHTTPClientGood() (*testutils.FakeHTTPClient, func()) { 135 fc := &testutils.FakeHTTPClient{ 136 ReqChan: testutils.NewChannel(), 137 RespChan: testutils.NewChannel(), 138 } 139 fc.RespChan.Send(makeGoodResponse()) 140 141 origMakeHTTPDoer := makeHTTPDoer 142 makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } 143 return fc, func() { makeHTTPDoer = origMakeHTTPDoer } 144 } 145 146 // Overrides the http.Client with the provided fakeClient. 147 func overrideHTTPClient(fc *testutils.FakeHTTPClient) func() { 148 origMakeHTTPDoer := makeHTTPDoer 149 makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } 150 return func() { makeHTTPDoer = origMakeHTTPDoer } 151 } 152 153 // Overrides the subject token read to return a const which we can compare in 154 // our tests. 155 func overrideSubjectTokenGood() func() { 156 origReadSubjectTokenFrom := readSubjectTokenFrom 157 readSubjectTokenFrom = func(string) ([]byte, error) { 158 return []byte(subjectTokenContents), nil 159 } 160 return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } 161 } 162 163 // Overrides the subject token read to always return an error. 164 func overrideSubjectTokenError() func() { 165 origReadSubjectTokenFrom := readSubjectTokenFrom 166 readSubjectTokenFrom = func(string) ([]byte, error) { 167 return nil, errors.New("error reading subject token") 168 } 169 return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } 170 } 171 172 // Overrides the actor token read to return a const which we can compare in 173 // our tests. 174 func overrideActorTokenGood() func() { 175 origReadActorTokenFrom := readActorTokenFrom 176 readActorTokenFrom = func(string) ([]byte, error) { 177 return []byte(actorTokenContents), nil 178 } 179 return func() { readActorTokenFrom = origReadActorTokenFrom } 180 } 181 182 // Overrides the actor token read to always return an error. 183 func overrideActorTokenError() func() { 184 origReadActorTokenFrom := readActorTokenFrom 185 readActorTokenFrom = func(string) ([]byte, error) { 186 return nil, errors.New("error reading actor token") 187 } 188 return func() { readActorTokenFrom = origReadActorTokenFrom } 189 } 190 191 // compareRequest compares the http.Request received in the test with the 192 // expected requestParameters specified in wantReqParams. 193 func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error { 194 jsonBody, err := json.Marshal(wantReqParams) 195 if err != nil { 196 return err 197 } 198 wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody)) 199 if err != nil { 200 return fmt.Errorf("failed to create http request: %v", err) 201 } 202 wantReq.Header.Set("Content-Type", "application/json") 203 204 wantR, err := httputil.DumpRequestOut(wantReq, true) 205 if err != nil { 206 return err 207 } 208 gotR, err := httputil.DumpRequestOut(gotRequest, true) 209 if err != nil { 210 return err 211 } 212 if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" { 213 return fmt.Errorf("sts request diff (-want +got):\n%s", diff) 214 } 215 return nil 216 } 217 218 // receiveAndCompareRequest waits for a request to be sent out by the 219 // credentials implementation using the fakeHTTPClient and compares it to an 220 // expected goodRequest. This is expected to be called in a separate goroutine 221 // by the tests. So, any errors encountered are pushed to an error channel 222 // which is monitored by the test. 223 func receiveAndCompareRequest(ReqChan *testutils.Channel, errCh chan error) { 224 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 225 defer cancel() 226 227 val, err := ReqChan.Receive(ctx) 228 if err != nil { 229 errCh <- err 230 return 231 } 232 req := val.(*http.Request) 233 if err := compareRequest(req, goodRequestParams); err != nil { 234 errCh <- err 235 return 236 } 237 errCh <- nil 238 } 239 240 // TestGetRequestMetadataSuccess verifies the successful case of sending an 241 // token exchange request and processing the response. 242 func (s) TestGetRequestMetadataSuccess(t *testing.T) { 243 defer overrideSubjectTokenGood()() 244 fc, cancel := overrideHTTPClientGood() 245 defer cancel() 246 247 creds, err := NewCredentials(goodOptions) 248 if err != nil { 249 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) 250 } 251 252 errCh := make(chan error, 1) 253 go receiveAndCompareRequest(fc.ReqChan, errCh) 254 255 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 256 defer cancel() 257 258 gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "") 259 if err != nil { 260 t.Fatalf("creds.GetRequestMetadata() = %v", err) 261 } 262 if !cmp.Equal(gotMetadata, goodMetadata) { 263 t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) 264 } 265 if err := <-errCh; err != nil { 266 t.Fatal(err) 267 } 268 269 // Make another call to get request metadata and this should return contents 270 // from the cache. This will fail if the credentials tries to send a fresh 271 // request here since we have not configured our fakeClient to return any 272 // response on retries. 273 gotMetadata, err = creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "") 274 if err != nil { 275 t.Fatalf("creds.GetRequestMetadata() = %v", err) 276 } 277 if !cmp.Equal(gotMetadata, goodMetadata) { 278 t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) 279 } 280 } 281 282 // TestGetRequestMetadataBadSecurityLevel verifies the case where the 283 // securityLevel specified in the context passed to GetRequestMetadata is not 284 // sufficient. 285 func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) { 286 defer overrideSubjectTokenGood()() 287 288 creds, err := NewCredentials(goodOptions) 289 if err != nil { 290 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) 291 } 292 293 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 294 defer cancel() 295 gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.IntegrityOnly), "") 296 if err == nil { 297 t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata) 298 } 299 } 300 301 // TestGetRequestMetadataCacheExpiry verifies the case where the cached access 302 // token has expired, and the credentials implementation will have to send a 303 // fresh token exchange request. 304 func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) { 305 const expiresInSecs = 1 306 defer overrideSubjectTokenGood()() 307 fc := &testutils.FakeHTTPClient{ 308 ReqChan: testutils.NewChannel(), 309 RespChan: testutils.NewChannel(), 310 } 311 defer overrideHTTPClient(fc)() 312 313 creds, err := NewCredentials(goodOptions) 314 if err != nil { 315 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) 316 } 317 318 // The fakeClient is configured to return an access_token with a one second 319 // expiry. So, in the second iteration, the credentials will find the cache 320 // entry, but that would have expired, and therefore we expect it to send 321 // out a fresh request. 322 for i := 0; i < 2; i++ { 323 errCh := make(chan error, 1) 324 go receiveAndCompareRequest(fc.ReqChan, errCh) 325 326 respJSON, _ := json.Marshal(responseParameters{ 327 AccessToken: accessTokenContents, 328 IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", 329 TokenType: "Bearer", 330 ExpiresIn: expiresInSecs, 331 }) 332 respBody := io.NopCloser(bytes.NewReader(respJSON)) 333 resp := &http.Response{ 334 Status: "200 OK", 335 StatusCode: http.StatusOK, 336 Body: respBody, 337 } 338 fc.RespChan.Send(resp) 339 340 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 341 defer cancel() 342 gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "") 343 if err != nil { 344 t.Fatalf("creds.GetRequestMetadata() = %v", err) 345 } 346 if !cmp.Equal(gotMetadata, goodMetadata) { 347 t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) 348 } 349 if err := <-errCh; err != nil { 350 t.Fatal(err) 351 } 352 time.Sleep(expiresInSecs * time.Second) 353 } 354 } 355 356 // TestGetRequestMetadataBadResponses verifies the scenario where the token 357 // exchange server returns bad responses. 358 func (s) TestGetRequestMetadataBadResponses(t *testing.T) { 359 tests := []struct { 360 name string 361 response *http.Response 362 }{ 363 { 364 name: "bad JSON", 365 response: &http.Response{ 366 Status: "200 OK", 367 StatusCode: http.StatusOK, 368 Body: io.NopCloser(strings.NewReader("not JSON")), 369 }, 370 }, 371 { 372 name: "no access token", 373 response: &http.Response{ 374 Status: "200 OK", 375 StatusCode: http.StatusOK, 376 Body: io.NopCloser(strings.NewReader("{}")), 377 }, 378 }, 379 } 380 381 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 382 defer cancel() 383 for _, test := range tests { 384 t.Run(test.name, func(t *testing.T) { 385 defer overrideSubjectTokenGood()() 386 387 fc := &testutils.FakeHTTPClient{ 388 ReqChan: testutils.NewChannel(), 389 RespChan: testutils.NewChannel(), 390 } 391 defer overrideHTTPClient(fc)() 392 393 creds, err := NewCredentials(goodOptions) 394 if err != nil { 395 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) 396 } 397 398 errCh := make(chan error, 1) 399 go receiveAndCompareRequest(fc.ReqChan, errCh) 400 401 fc.RespChan.Send(test.response) 402 if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil { 403 t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") 404 } 405 if err := <-errCh; err != nil { 406 t.Fatal(err) 407 } 408 }) 409 } 410 } 411 412 // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the 413 // attempt to read the subjectToken fails. 414 func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { 415 defer overrideSubjectTokenError()() 416 fc, cancel := overrideHTTPClientGood() 417 defer cancel() 418 419 creds, err := NewCredentials(goodOptions) 420 if err != nil { 421 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) 422 } 423 424 errCh := make(chan error, 1) 425 go func() { 426 ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) 427 defer cancel() 428 if _, err := fc.ReqChan.Receive(ctx); err != context.DeadlineExceeded { 429 errCh <- err 430 return 431 } 432 errCh <- nil 433 }() 434 435 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 436 defer cancel() 437 if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil { 438 t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") 439 } 440 if err := <-errCh; err != nil { 441 t.Fatal(err) 442 } 443 } 444 445 func (s) TestNewCredentials(t *testing.T) { 446 tests := []struct { 447 name string 448 opts Options 449 errSystemRoots bool 450 wantErr bool 451 }{ 452 { 453 name: "invalid options - empty subjectTokenPath", 454 opts: Options{ 455 TokenExchangeServiceURI: serviceURI, 456 }, 457 wantErr: true, 458 }, 459 { 460 name: "invalid system root certs", 461 opts: goodOptions, 462 errSystemRoots: true, 463 wantErr: true, 464 }, 465 { 466 name: "good case", 467 opts: goodOptions, 468 }, 469 } 470 471 for _, test := range tests { 472 t.Run(test.name, func(t *testing.T) { 473 if test.errSystemRoots { 474 oldSystemRoots := loadSystemCertPool 475 loadSystemCertPool = func() (*x509.CertPool, error) { 476 return nil, errors.New("failed to load system cert pool") 477 } 478 defer func() { 479 loadSystemCertPool = oldSystemRoots 480 }() 481 } 482 483 creds, err := NewCredentials(test.opts) 484 if (err != nil) != test.wantErr { 485 t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr) 486 } 487 if err == nil { 488 if !creds.RequireTransportSecurity() { 489 t.Errorf("creds.RequireTransportSecurity() returned false") 490 } 491 } 492 }) 493 } 494 } 495 496 func (s) TestValidateOptions(t *testing.T) { 497 tests := []struct { 498 name string 499 opts Options 500 wantErrPrefix string 501 }{ 502 { 503 name: "empty token exchange service URI", 504 opts: Options{}, 505 wantErrPrefix: "empty token_exchange_service_uri in options", 506 }, 507 { 508 name: "invalid URI", 509 opts: Options{ 510 TokenExchangeServiceURI: "\tI'm a bad URI\n", 511 }, 512 wantErrPrefix: "invalid control character in URL", 513 }, 514 { 515 name: "unsupported scheme", 516 opts: Options{ 517 TokenExchangeServiceURI: "unix:///path/to/socket", 518 }, 519 wantErrPrefix: "scheme is not supported", 520 }, 521 { 522 name: "empty subjectTokenPath", 523 opts: Options{ 524 TokenExchangeServiceURI: serviceURI, 525 }, 526 wantErrPrefix: "required field SubjectTokenPath is not specified", 527 }, 528 { 529 name: "empty subjectTokenType", 530 opts: Options{ 531 TokenExchangeServiceURI: serviceURI, 532 SubjectTokenPath: subjectTokenPath, 533 }, 534 wantErrPrefix: "required field SubjectTokenType is not specified", 535 }, 536 { 537 name: "good options", 538 opts: goodOptions, 539 }, 540 } 541 542 for _, test := range tests { 543 t.Run(test.name, func(t *testing.T) { 544 err := validateOptions(test.opts) 545 if (err != nil) != (test.wantErrPrefix != "") { 546 t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) 547 } 548 if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) { 549 t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) 550 } 551 }) 552 } 553 } 554 555 func (s) TestConstructRequest(t *testing.T) { 556 tests := []struct { 557 name string 558 opts Options 559 subjectTokenReadErr bool 560 actorTokenReadErr bool 561 wantReqParams *requestParameters 562 wantErr bool 563 }{ 564 { 565 name: "subject token read failure", 566 subjectTokenReadErr: true, 567 opts: goodOptions, 568 wantErr: true, 569 }, 570 { 571 name: "actor token read failure", 572 actorTokenReadErr: true, 573 opts: Options{ 574 TokenExchangeServiceURI: serviceURI, 575 Audience: exampleAudience, 576 RequestedTokenType: requestedTokenType, 577 SubjectTokenPath: subjectTokenPath, 578 SubjectTokenType: subjectTokenType, 579 ActorTokenPath: actorTokenPath, 580 ActorTokenType: actorTokenType, 581 }, 582 wantErr: true, 583 }, 584 { 585 name: "default cloud platform scope", 586 opts: goodOptions, 587 wantReqParams: goodRequestParams, 588 }, 589 { 590 name: "all good", 591 opts: Options{ 592 TokenExchangeServiceURI: serviceURI, 593 Resource: exampleResource, 594 Audience: exampleAudience, 595 Scope: testScope, 596 RequestedTokenType: requestedTokenType, 597 SubjectTokenPath: subjectTokenPath, 598 SubjectTokenType: subjectTokenType, 599 ActorTokenPath: actorTokenPath, 600 ActorTokenType: actorTokenType, 601 }, 602 wantReqParams: &requestParameters{ 603 GrantType: tokenExchangeGrantType, 604 Resource: exampleResource, 605 Audience: exampleAudience, 606 Scope: testScope, 607 RequestedTokenType: requestedTokenType, 608 SubjectToken: subjectTokenContents, 609 SubjectTokenType: subjectTokenType, 610 ActorToken: actorTokenContents, 611 ActorTokenType: actorTokenType, 612 }, 613 }, 614 } 615 616 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 617 defer cancel() 618 for _, test := range tests { 619 t.Run(test.name, func(t *testing.T) { 620 if test.subjectTokenReadErr { 621 defer overrideSubjectTokenError()() 622 } else { 623 defer overrideSubjectTokenGood()() 624 } 625 626 if test.actorTokenReadErr { 627 defer overrideActorTokenError()() 628 } else { 629 defer overrideActorTokenGood()() 630 } 631 632 gotRequest, err := constructRequest(ctx, test.opts) 633 if (err != nil) != test.wantErr { 634 t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr) 635 } 636 if test.wantErr { 637 return 638 } 639 if err := compareRequest(gotRequest, test.wantReqParams); err != nil { 640 t.Fatal(err) 641 } 642 }) 643 } 644 } 645 646 func (s) TestSendRequest(t *testing.T) { 647 defer overrideSubjectTokenGood()() 648 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 649 defer cancel() 650 req, err := constructRequest(ctx, goodOptions) 651 if err != nil { 652 t.Fatal(err) 653 } 654 655 tests := []struct { 656 name string 657 resp *http.Response 658 respErr error 659 wantErr bool 660 }{ 661 { 662 name: "client error", 663 respErr: errors.New("http.Client.Do failed"), 664 wantErr: true, 665 }, 666 { 667 name: "bad response body", 668 resp: &http.Response{ 669 Status: "200 OK", 670 StatusCode: http.StatusOK, 671 Body: io.NopCloser(errReader{}), 672 }, 673 wantErr: true, 674 }, 675 { 676 name: "nonOK status code", 677 resp: &http.Response{ 678 Status: "400 BadRequest", 679 StatusCode: http.StatusBadRequest, 680 Body: io.NopCloser(strings.NewReader("")), 681 }, 682 wantErr: true, 683 }, 684 { 685 name: "good case", 686 resp: makeGoodResponse(), 687 }, 688 } 689 690 for _, test := range tests { 691 t.Run(test.name, func(t *testing.T) { 692 client := &testutils.FakeHTTPClient{ 693 ReqChan: testutils.NewChannel(), 694 RespChan: testutils.NewChannel(), 695 Err: test.respErr, 696 } 697 client.RespChan.Send(test.resp) 698 _, err := sendRequest(client, req) 699 if (err != nil) != test.wantErr { 700 t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr) 701 } 702 }) 703 } 704 } 705 706 func (s) TestTokenInfoFromResponse(t *testing.T) { 707 noAccessToken, _ := json.Marshal(responseParameters{ 708 IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", 709 TokenType: "Bearer", 710 ExpiresIn: 3600, 711 }) 712 goodResponse, _ := json.Marshal(responseParameters{ 713 IssuedTokenType: requestedTokenType, 714 AccessToken: accessTokenContents, 715 TokenType: "Bearer", 716 ExpiresIn: 3600, 717 }) 718 719 tests := []struct { 720 name string 721 respBody []byte 722 wantTokenInfo *tokenInfo 723 wantErr bool 724 }{ 725 { 726 name: "bad JSON", 727 respBody: []byte("not JSON"), 728 wantErr: true, 729 }, 730 { 731 name: "empty response", 732 respBody: []byte(""), 733 wantErr: true, 734 }, 735 { 736 name: "non-empty response with no access token", 737 respBody: noAccessToken, 738 wantErr: true, 739 }, 740 { 741 name: "good response", 742 respBody: goodResponse, 743 wantTokenInfo: &tokenInfo{ 744 tokenType: "Bearer", 745 token: accessTokenContents, 746 }, 747 }, 748 } 749 750 for _, test := range tests { 751 t.Run(test.name, func(t *testing.T) { 752 gotTokenInfo, err := tokenInfoFromResponse(test.respBody) 753 if (err != nil) != test.wantErr { 754 t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr) 755 } 756 if test.wantErr { 757 return 758 } 759 // Can't do a cmp.Equal on the whole struct since the expiryField 760 // is populated based on time.Now(). 761 if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token { 762 t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo) 763 } 764 }) 765 } 766 }