google.golang.org/grpc@v1.74.2/credentials/sts/sts_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package sts
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"crypto/x509"
    25  	"encoding/json"
    26  	"errors"
    27  	"fmt"
    28  	"io"
    29  	"net/http"
    30  	"net/http/httputil"
    31  	"strings"
    32  	"testing"
    33  	"time"
    34  
    35  	"github.com/google/go-cmp/cmp"
    36  
    37  	"google.golang.org/grpc/credentials"
    38  	"google.golang.org/grpc/internal/grpctest"
    39  	"google.golang.org/grpc/internal/testutils"
    40  )
    41  
    42  const (
    43  	requestedTokenType      = "urn:ietf:params:oauth:token-type:access-token"
    44  	actorTokenPath          = "/var/run/secrets/token.jwt"
    45  	actorTokenType          = "urn:ietf:params:oauth:token-type:refresh_token"
    46  	actorTokenContents      = "actorToken.jwt.contents"
    47  	accessTokenContents     = "access_token"
    48  	subjectTokenPath        = "/var/run/secrets/token.jwt"
    49  	subjectTokenType        = "urn:ietf:params:oauth:token-type:id_token"
    50  	subjectTokenContents    = "subjectToken.jwt.contents"
    51  	serviceURI              = "http://localhost"
    52  	exampleResource         = "https://backend.example.com/api"
    53  	exampleAudience         = "example-backend-service"
    54  	testScope               = "https://www.googleapis.com/auth/monitoring"
    55  	defaultTestTimeout      = 1 * time.Second
    56  	defaultTestShortTimeout = 10 * time.Millisecond
    57  )
    58  
    59  var (
    60  	goodOptions = Options{
    61  		TokenExchangeServiceURI: serviceURI,
    62  		Audience:                exampleAudience,
    63  		RequestedTokenType:      requestedTokenType,
    64  		SubjectTokenPath:        subjectTokenPath,
    65  		SubjectTokenType:        subjectTokenType,
    66  	}
    67  	goodRequestParams = &requestParameters{
    68  		GrantType:          tokenExchangeGrantType,
    69  		Audience:           exampleAudience,
    70  		Scope:              defaultCloudPlatformScope,
    71  		RequestedTokenType: requestedTokenType,
    72  		SubjectToken:       subjectTokenContents,
    73  		SubjectTokenType:   subjectTokenType,
    74  	}
    75  	goodMetadata = map[string]string{
    76  		"Authorization": fmt.Sprintf("Bearer %s", accessTokenContents),
    77  	}
    78  )
    79  
    80  type s struct {
    81  	grpctest.Tester
    82  }
    83  
    84  func Test(t *testing.T) {
    85  	grpctest.RunSubTests(t, s{})
    86  }
    87  
    88  // A struct that implements AuthInfo interface and added to the context passed
    89  // to GetRequestMetadata from tests.
    90  type testAuthInfo struct {
    91  	credentials.CommonAuthInfo
    92  }
    93  
    94  func (ta testAuthInfo) AuthType() string {
    95  	return "testAuthInfo"
    96  }
    97  
    98  func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
    99  	auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}}
   100  	ri := credentials.RequestInfo{
   101  		Method:   "testInfo",
   102  		AuthInfo: auth,
   103  	}
   104  	return credentials.NewContextWithRequestInfo(ctx, ri)
   105  }
   106  
   107  // errReader implements the io.Reader interface and returns an error from the
   108  // Read method.
   109  type errReader struct{}
   110  
   111  func (r errReader) Read([]byte) (n int, err error) {
   112  	return 0, errors.New("read error")
   113  }
   114  
   115  // We need a function to construct the response instead of simply declaring it
   116  // as a variable since the response body will be consumed by the
   117  // credentials, and therefore we will need a new one everytime.
   118  func makeGoodResponse() *http.Response {
   119  	respJSON, _ := json.Marshal(responseParameters{
   120  		AccessToken:     accessTokenContents,
   121  		IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
   122  		TokenType:       "Bearer",
   123  		ExpiresIn:       3600,
   124  	})
   125  	respBody := io.NopCloser(bytes.NewReader(respJSON))
   126  	return &http.Response{
   127  		Status:     "200 OK",
   128  		StatusCode: http.StatusOK,
   129  		Body:       respBody,
   130  	}
   131  }
   132  
   133  // Overrides the http.Client with a fakeClient which sends a good response.
   134  func overrideHTTPClientGood() (*testutils.FakeHTTPClient, func()) {
   135  	fc := &testutils.FakeHTTPClient{
   136  		ReqChan:  testutils.NewChannel(),
   137  		RespChan: testutils.NewChannel(),
   138  	}
   139  	fc.RespChan.Send(makeGoodResponse())
   140  
   141  	origMakeHTTPDoer := makeHTTPDoer
   142  	makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
   143  	return fc, func() { makeHTTPDoer = origMakeHTTPDoer }
   144  }
   145  
   146  // Overrides the http.Client with the provided fakeClient.
   147  func overrideHTTPClient(fc *testutils.FakeHTTPClient) func() {
   148  	origMakeHTTPDoer := makeHTTPDoer
   149  	makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
   150  	return func() { makeHTTPDoer = origMakeHTTPDoer }
   151  }
   152  
   153  // Overrides the subject token read to return a const which we can compare in
   154  // our tests.
   155  func overrideSubjectTokenGood() func() {
   156  	origReadSubjectTokenFrom := readSubjectTokenFrom
   157  	readSubjectTokenFrom = func(string) ([]byte, error) {
   158  		return []byte(subjectTokenContents), nil
   159  	}
   160  	return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
   161  }
   162  
   163  // Overrides the subject token read to always return an error.
   164  func overrideSubjectTokenError() func() {
   165  	origReadSubjectTokenFrom := readSubjectTokenFrom
   166  	readSubjectTokenFrom = func(string) ([]byte, error) {
   167  		return nil, errors.New("error reading subject token")
   168  	}
   169  	return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
   170  }
   171  
   172  // Overrides the actor token read to return a const which we can compare in
   173  // our tests.
   174  func overrideActorTokenGood() func() {
   175  	origReadActorTokenFrom := readActorTokenFrom
   176  	readActorTokenFrom = func(string) ([]byte, error) {
   177  		return []byte(actorTokenContents), nil
   178  	}
   179  	return func() { readActorTokenFrom = origReadActorTokenFrom }
   180  }
   181  
   182  // Overrides the actor token read to always return an error.
   183  func overrideActorTokenError() func() {
   184  	origReadActorTokenFrom := readActorTokenFrom
   185  	readActorTokenFrom = func(string) ([]byte, error) {
   186  		return nil, errors.New("error reading actor token")
   187  	}
   188  	return func() { readActorTokenFrom = origReadActorTokenFrom }
   189  }
   190  
   191  // compareRequest compares the http.Request received in the test with the
   192  // expected requestParameters specified in wantReqParams.
   193  func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
   194  	jsonBody, err := json.Marshal(wantReqParams)
   195  	if err != nil {
   196  		return err
   197  	}
   198  	wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody))
   199  	if err != nil {
   200  		return fmt.Errorf("failed to create http request: %v", err)
   201  	}
   202  	wantReq.Header.Set("Content-Type", "application/json")
   203  
   204  	wantR, err := httputil.DumpRequestOut(wantReq, true)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	gotR, err := httputil.DumpRequestOut(gotRequest, true)
   209  	if err != nil {
   210  		return err
   211  	}
   212  	if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" {
   213  		return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
   214  	}
   215  	return nil
   216  }
   217  
   218  // receiveAndCompareRequest waits for a request to be sent out by the
   219  // credentials implementation using the fakeHTTPClient and compares it to an
   220  // expected goodRequest. This is expected to be called in a separate goroutine
   221  // by the tests. So, any errors encountered are pushed to an error channel
   222  // which is monitored by the test.
   223  func receiveAndCompareRequest(ReqChan *testutils.Channel, errCh chan error) {
   224  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   225  	defer cancel()
   226  
   227  	val, err := ReqChan.Receive(ctx)
   228  	if err != nil {
   229  		errCh <- err
   230  		return
   231  	}
   232  	req := val.(*http.Request)
   233  	if err := compareRequest(req, goodRequestParams); err != nil {
   234  		errCh <- err
   235  		return
   236  	}
   237  	errCh <- nil
   238  }
   239  
   240  // TestGetRequestMetadataSuccess verifies the successful case of sending an
   241  // token exchange request and processing the response.
   242  func (s) TestGetRequestMetadataSuccess(t *testing.T) {
   243  	defer overrideSubjectTokenGood()()
   244  	fc, cancel := overrideHTTPClientGood()
   245  	defer cancel()
   246  
   247  	creds, err := NewCredentials(goodOptions)
   248  	if err != nil {
   249  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   250  	}
   251  
   252  	errCh := make(chan error, 1)
   253  	go receiveAndCompareRequest(fc.ReqChan, errCh)
   254  
   255  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   256  	defer cancel()
   257  
   258  	gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
   259  	if err != nil {
   260  		t.Fatalf("creds.GetRequestMetadata() = %v", err)
   261  	}
   262  	if !cmp.Equal(gotMetadata, goodMetadata) {
   263  		t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
   264  	}
   265  	if err := <-errCh; err != nil {
   266  		t.Fatal(err)
   267  	}
   268  
   269  	// Make another call to get request metadata and this should return contents
   270  	// from the cache. This will fail if the credentials tries to send a fresh
   271  	// request here since we have not configured our fakeClient to return any
   272  	// response on retries.
   273  	gotMetadata, err = creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
   274  	if err != nil {
   275  		t.Fatalf("creds.GetRequestMetadata() = %v", err)
   276  	}
   277  	if !cmp.Equal(gotMetadata, goodMetadata) {
   278  		t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
   279  	}
   280  }
   281  
   282  // TestGetRequestMetadataBadSecurityLevel verifies the case where the
   283  // securityLevel specified in the context passed to GetRequestMetadata is not
   284  // sufficient.
   285  func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
   286  	defer overrideSubjectTokenGood()()
   287  
   288  	creds, err := NewCredentials(goodOptions)
   289  	if err != nil {
   290  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   291  	}
   292  
   293  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   294  	defer cancel()
   295  	gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.IntegrityOnly), "")
   296  	if err == nil {
   297  		t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
   298  	}
   299  }
   300  
   301  // TestGetRequestMetadataCacheExpiry verifies the case where the cached access
   302  // token has expired, and the credentials implementation will have to send a
   303  // fresh token exchange request.
   304  func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
   305  	const expiresInSecs = 1
   306  	defer overrideSubjectTokenGood()()
   307  	fc := &testutils.FakeHTTPClient{
   308  		ReqChan:  testutils.NewChannel(),
   309  		RespChan: testutils.NewChannel(),
   310  	}
   311  	defer overrideHTTPClient(fc)()
   312  
   313  	creds, err := NewCredentials(goodOptions)
   314  	if err != nil {
   315  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   316  	}
   317  
   318  	// The fakeClient is configured to return an access_token with a one second
   319  	// expiry. So, in the second iteration, the credentials will find the cache
   320  	// entry, but that would have expired, and therefore we expect it to send
   321  	// out a fresh request.
   322  	for i := 0; i < 2; i++ {
   323  		errCh := make(chan error, 1)
   324  		go receiveAndCompareRequest(fc.ReqChan, errCh)
   325  
   326  		respJSON, _ := json.Marshal(responseParameters{
   327  			AccessToken:     accessTokenContents,
   328  			IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
   329  			TokenType:       "Bearer",
   330  			ExpiresIn:       expiresInSecs,
   331  		})
   332  		respBody := io.NopCloser(bytes.NewReader(respJSON))
   333  		resp := &http.Response{
   334  			Status:     "200 OK",
   335  			StatusCode: http.StatusOK,
   336  			Body:       respBody,
   337  		}
   338  		fc.RespChan.Send(resp)
   339  
   340  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   341  		defer cancel()
   342  		gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
   343  		if err != nil {
   344  			t.Fatalf("creds.GetRequestMetadata() = %v", err)
   345  		}
   346  		if !cmp.Equal(gotMetadata, goodMetadata) {
   347  			t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
   348  		}
   349  		if err := <-errCh; err != nil {
   350  			t.Fatal(err)
   351  		}
   352  		time.Sleep(expiresInSecs * time.Second)
   353  	}
   354  }
   355  
   356  // TestGetRequestMetadataBadResponses verifies the scenario where the token
   357  // exchange server returns bad responses.
   358  func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
   359  	tests := []struct {
   360  		name     string
   361  		response *http.Response
   362  	}{
   363  		{
   364  			name: "bad JSON",
   365  			response: &http.Response{
   366  				Status:     "200 OK",
   367  				StatusCode: http.StatusOK,
   368  				Body:       io.NopCloser(strings.NewReader("not JSON")),
   369  			},
   370  		},
   371  		{
   372  			name: "no access token",
   373  			response: &http.Response{
   374  				Status:     "200 OK",
   375  				StatusCode: http.StatusOK,
   376  				Body:       io.NopCloser(strings.NewReader("{}")),
   377  			},
   378  		},
   379  	}
   380  
   381  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   382  	defer cancel()
   383  	for _, test := range tests {
   384  		t.Run(test.name, func(t *testing.T) {
   385  			defer overrideSubjectTokenGood()()
   386  
   387  			fc := &testutils.FakeHTTPClient{
   388  				ReqChan:  testutils.NewChannel(),
   389  				RespChan: testutils.NewChannel(),
   390  			}
   391  			defer overrideHTTPClient(fc)()
   392  
   393  			creds, err := NewCredentials(goodOptions)
   394  			if err != nil {
   395  				t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   396  			}
   397  
   398  			errCh := make(chan error, 1)
   399  			go receiveAndCompareRequest(fc.ReqChan, errCh)
   400  
   401  			fc.RespChan.Send(test.response)
   402  			if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
   403  				t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
   404  			}
   405  			if err := <-errCh; err != nil {
   406  				t.Fatal(err)
   407  			}
   408  		})
   409  	}
   410  }
   411  
   412  // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the
   413  // attempt to read the subjectToken fails.
   414  func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
   415  	defer overrideSubjectTokenError()()
   416  	fc, cancel := overrideHTTPClientGood()
   417  	defer cancel()
   418  
   419  	creds, err := NewCredentials(goodOptions)
   420  	if err != nil {
   421  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   422  	}
   423  
   424  	errCh := make(chan error, 1)
   425  	go func() {
   426  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   427  		defer cancel()
   428  		if _, err := fc.ReqChan.Receive(ctx); err != context.DeadlineExceeded {
   429  			errCh <- err
   430  			return
   431  		}
   432  		errCh <- nil
   433  	}()
   434  
   435  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   436  	defer cancel()
   437  	if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
   438  		t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
   439  	}
   440  	if err := <-errCh; err != nil {
   441  		t.Fatal(err)
   442  	}
   443  }
   444  
   445  func (s) TestNewCredentials(t *testing.T) {
   446  	tests := []struct {
   447  		name           string
   448  		opts           Options
   449  		errSystemRoots bool
   450  		wantErr        bool
   451  	}{
   452  		{
   453  			name: "invalid options - empty subjectTokenPath",
   454  			opts: Options{
   455  				TokenExchangeServiceURI: serviceURI,
   456  			},
   457  			wantErr: true,
   458  		},
   459  		{
   460  			name:           "invalid system root certs",
   461  			opts:           goodOptions,
   462  			errSystemRoots: true,
   463  			wantErr:        true,
   464  		},
   465  		{
   466  			name: "good case",
   467  			opts: goodOptions,
   468  		},
   469  	}
   470  
   471  	for _, test := range tests {
   472  		t.Run(test.name, func(t *testing.T) {
   473  			if test.errSystemRoots {
   474  				oldSystemRoots := loadSystemCertPool
   475  				loadSystemCertPool = func() (*x509.CertPool, error) {
   476  					return nil, errors.New("failed to load system cert pool")
   477  				}
   478  				defer func() {
   479  					loadSystemCertPool = oldSystemRoots
   480  				}()
   481  			}
   482  
   483  			creds, err := NewCredentials(test.opts)
   484  			if (err != nil) != test.wantErr {
   485  				t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
   486  			}
   487  			if err == nil {
   488  				if !creds.RequireTransportSecurity() {
   489  					t.Errorf("creds.RequireTransportSecurity() returned false")
   490  				}
   491  			}
   492  		})
   493  	}
   494  }
   495  
   496  func (s) TestValidateOptions(t *testing.T) {
   497  	tests := []struct {
   498  		name          string
   499  		opts          Options
   500  		wantErrPrefix string
   501  	}{
   502  		{
   503  			name:          "empty token exchange service URI",
   504  			opts:          Options{},
   505  			wantErrPrefix: "empty token_exchange_service_uri in options",
   506  		},
   507  		{
   508  			name: "invalid URI",
   509  			opts: Options{
   510  				TokenExchangeServiceURI: "\tI'm a bad URI\n",
   511  			},
   512  			wantErrPrefix: "invalid control character in URL",
   513  		},
   514  		{
   515  			name: "unsupported scheme",
   516  			opts: Options{
   517  				TokenExchangeServiceURI: "unix:///path/to/socket",
   518  			},
   519  			wantErrPrefix: "scheme is not supported",
   520  		},
   521  		{
   522  			name: "empty subjectTokenPath",
   523  			opts: Options{
   524  				TokenExchangeServiceURI: serviceURI,
   525  			},
   526  			wantErrPrefix: "required field SubjectTokenPath is not specified",
   527  		},
   528  		{
   529  			name: "empty subjectTokenType",
   530  			opts: Options{
   531  				TokenExchangeServiceURI: serviceURI,
   532  				SubjectTokenPath:        subjectTokenPath,
   533  			},
   534  			wantErrPrefix: "required field SubjectTokenType is not specified",
   535  		},
   536  		{
   537  			name: "good options",
   538  			opts: goodOptions,
   539  		},
   540  	}
   541  
   542  	for _, test := range tests {
   543  		t.Run(test.name, func(t *testing.T) {
   544  			err := validateOptions(test.opts)
   545  			if (err != nil) != (test.wantErrPrefix != "") {
   546  				t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
   547  			}
   548  			if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) {
   549  				t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
   550  			}
   551  		})
   552  	}
   553  }
   554  
   555  func (s) TestConstructRequest(t *testing.T) {
   556  	tests := []struct {
   557  		name                string
   558  		opts                Options
   559  		subjectTokenReadErr bool
   560  		actorTokenReadErr   bool
   561  		wantReqParams       *requestParameters
   562  		wantErr             bool
   563  	}{
   564  		{
   565  			name:                "subject token read failure",
   566  			subjectTokenReadErr: true,
   567  			opts:                goodOptions,
   568  			wantErr:             true,
   569  		},
   570  		{
   571  			name:              "actor token read failure",
   572  			actorTokenReadErr: true,
   573  			opts: Options{
   574  				TokenExchangeServiceURI: serviceURI,
   575  				Audience:                exampleAudience,
   576  				RequestedTokenType:      requestedTokenType,
   577  				SubjectTokenPath:        subjectTokenPath,
   578  				SubjectTokenType:        subjectTokenType,
   579  				ActorTokenPath:          actorTokenPath,
   580  				ActorTokenType:          actorTokenType,
   581  			},
   582  			wantErr: true,
   583  		},
   584  		{
   585  			name:          "default cloud platform scope",
   586  			opts:          goodOptions,
   587  			wantReqParams: goodRequestParams,
   588  		},
   589  		{
   590  			name: "all good",
   591  			opts: Options{
   592  				TokenExchangeServiceURI: serviceURI,
   593  				Resource:                exampleResource,
   594  				Audience:                exampleAudience,
   595  				Scope:                   testScope,
   596  				RequestedTokenType:      requestedTokenType,
   597  				SubjectTokenPath:        subjectTokenPath,
   598  				SubjectTokenType:        subjectTokenType,
   599  				ActorTokenPath:          actorTokenPath,
   600  				ActorTokenType:          actorTokenType,
   601  			},
   602  			wantReqParams: &requestParameters{
   603  				GrantType:          tokenExchangeGrantType,
   604  				Resource:           exampleResource,
   605  				Audience:           exampleAudience,
   606  				Scope:              testScope,
   607  				RequestedTokenType: requestedTokenType,
   608  				SubjectToken:       subjectTokenContents,
   609  				SubjectTokenType:   subjectTokenType,
   610  				ActorToken:         actorTokenContents,
   611  				ActorTokenType:     actorTokenType,
   612  			},
   613  		},
   614  	}
   615  
   616  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   617  	defer cancel()
   618  	for _, test := range tests {
   619  		t.Run(test.name, func(t *testing.T) {
   620  			if test.subjectTokenReadErr {
   621  				defer overrideSubjectTokenError()()
   622  			} else {
   623  				defer overrideSubjectTokenGood()()
   624  			}
   625  
   626  			if test.actorTokenReadErr {
   627  				defer overrideActorTokenError()()
   628  			} else {
   629  				defer overrideActorTokenGood()()
   630  			}
   631  
   632  			gotRequest, err := constructRequest(ctx, test.opts)
   633  			if (err != nil) != test.wantErr {
   634  				t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
   635  			}
   636  			if test.wantErr {
   637  				return
   638  			}
   639  			if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
   640  				t.Fatal(err)
   641  			}
   642  		})
   643  	}
   644  }
   645  
   646  func (s) TestSendRequest(t *testing.T) {
   647  	defer overrideSubjectTokenGood()()
   648  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   649  	defer cancel()
   650  	req, err := constructRequest(ctx, goodOptions)
   651  	if err != nil {
   652  		t.Fatal(err)
   653  	}
   654  
   655  	tests := []struct {
   656  		name    string
   657  		resp    *http.Response
   658  		respErr error
   659  		wantErr bool
   660  	}{
   661  		{
   662  			name:    "client error",
   663  			respErr: errors.New("http.Client.Do failed"),
   664  			wantErr: true,
   665  		},
   666  		{
   667  			name: "bad response body",
   668  			resp: &http.Response{
   669  				Status:     "200 OK",
   670  				StatusCode: http.StatusOK,
   671  				Body:       io.NopCloser(errReader{}),
   672  			},
   673  			wantErr: true,
   674  		},
   675  		{
   676  			name: "nonOK status code",
   677  			resp: &http.Response{
   678  				Status:     "400 BadRequest",
   679  				StatusCode: http.StatusBadRequest,
   680  				Body:       io.NopCloser(strings.NewReader("")),
   681  			},
   682  			wantErr: true,
   683  		},
   684  		{
   685  			name: "good case",
   686  			resp: makeGoodResponse(),
   687  		},
   688  	}
   689  
   690  	for _, test := range tests {
   691  		t.Run(test.name, func(t *testing.T) {
   692  			client := &testutils.FakeHTTPClient{
   693  				ReqChan:  testutils.NewChannel(),
   694  				RespChan: testutils.NewChannel(),
   695  				Err:      test.respErr,
   696  			}
   697  			client.RespChan.Send(test.resp)
   698  			_, err := sendRequest(client, req)
   699  			if (err != nil) != test.wantErr {
   700  				t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
   701  			}
   702  		})
   703  	}
   704  }
   705  
   706  func (s) TestTokenInfoFromResponse(t *testing.T) {
   707  	noAccessToken, _ := json.Marshal(responseParameters{
   708  		IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
   709  		TokenType:       "Bearer",
   710  		ExpiresIn:       3600,
   711  	})
   712  	goodResponse, _ := json.Marshal(responseParameters{
   713  		IssuedTokenType: requestedTokenType,
   714  		AccessToken:     accessTokenContents,
   715  		TokenType:       "Bearer",
   716  		ExpiresIn:       3600,
   717  	})
   718  
   719  	tests := []struct {
   720  		name          string
   721  		respBody      []byte
   722  		wantTokenInfo *tokenInfo
   723  		wantErr       bool
   724  	}{
   725  		{
   726  			name:     "bad JSON",
   727  			respBody: []byte("not JSON"),
   728  			wantErr:  true,
   729  		},
   730  		{
   731  			name:     "empty response",
   732  			respBody: []byte(""),
   733  			wantErr:  true,
   734  		},
   735  		{
   736  			name:     "non-empty response with no access token",
   737  			respBody: noAccessToken,
   738  			wantErr:  true,
   739  		},
   740  		{
   741  			name:     "good response",
   742  			respBody: goodResponse,
   743  			wantTokenInfo: &tokenInfo{
   744  				tokenType: "Bearer",
   745  				token:     accessTokenContents,
   746  			},
   747  		},
   748  	}
   749  
   750  	for _, test := range tests {
   751  		t.Run(test.name, func(t *testing.T) {
   752  			gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
   753  			if (err != nil) != test.wantErr {
   754  				t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr)
   755  			}
   756  			if test.wantErr {
   757  				return
   758  			}
   759  			// Can't do a cmp.Equal on the whole struct since the expiryField
   760  			// is populated based on time.Now().
   761  			if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
   762  				t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
   763  			}
   764  		})
   765  	}
   766  }