gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/credentials/sts/sts_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package sts
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io/ioutil"
    28  	"strings"
    29  	"testing"
    30  	"time"
    31  
    32  	"gitee.com/ks-custle/core-gm/gmhttp/httputil"
    33  
    34  	http "gitee.com/ks-custle/core-gm/gmhttp"
    35  	"gitee.com/ks-custle/core-gm/grpc/credentials"
    36  	icredentials "gitee.com/ks-custle/core-gm/grpc/internal/credentials"
    37  	"gitee.com/ks-custle/core-gm/grpc/internal/grpctest"
    38  	"gitee.com/ks-custle/core-gm/grpc/internal/testutils"
    39  	"gitee.com/ks-custle/core-gm/x509"
    40  	"github.com/google/go-cmp/cmp"
    41  )
    42  
    43  const (
    44  	requestedTokenType      = "urn:ietf:params:oauth:token-type:access-token"
    45  	actorTokenPath          = "/var/run/secrets/token.jwt"
    46  	actorTokenType          = "urn:ietf:params:oauth:token-type:refresh_token"
    47  	actorTokenContents      = "actorToken.jwt.contents"
    48  	accessTokenContents     = "access_token"
    49  	subjectTokenPath        = "/var/run/secrets/token.jwt"
    50  	subjectTokenType        = "urn:ietf:params:oauth:token-type:id_token"
    51  	subjectTokenContents    = "subjectToken.jwt.contents"
    52  	serviceURI              = "http://localhost"
    53  	exampleResource         = "https://backend.example.com/api"
    54  	exampleAudience         = "example-backend-service"
    55  	testScope               = "https://www.googleapis.com/auth/monitoring"
    56  	defaultTestTimeout      = 1 * time.Second
    57  	defaultTestShortTimeout = 10 * time.Millisecond
    58  )
    59  
    60  var (
    61  	goodOptions = Options{
    62  		TokenExchangeServiceURI: serviceURI,
    63  		Audience:                exampleAudience,
    64  		RequestedTokenType:      requestedTokenType,
    65  		SubjectTokenPath:        subjectTokenPath,
    66  		SubjectTokenType:        subjectTokenType,
    67  	}
    68  	goodRequestParams = &requestParameters{
    69  		GrantType:          tokenExchangeGrantType,
    70  		Audience:           exampleAudience,
    71  		Scope:              defaultCloudPlatformScope,
    72  		RequestedTokenType: requestedTokenType,
    73  		SubjectToken:       subjectTokenContents,
    74  		SubjectTokenType:   subjectTokenType,
    75  	}
    76  	goodMetadata = map[string]string{
    77  		"Authorization": fmt.Sprintf("Bearer %s", accessTokenContents),
    78  	}
    79  )
    80  
    81  type s struct {
    82  	grpctest.Tester
    83  }
    84  
    85  func Test(t *testing.T) {
    86  	grpctest.RunSubTests(t, s{})
    87  }
    88  
    89  // A struct that implements AuthInfo interface and added to the context passed
    90  // to GetRequestMetadata from tests.
    91  type testAuthInfo struct {
    92  	credentials.CommonAuthInfo
    93  }
    94  
    95  func (ta testAuthInfo) AuthType() string {
    96  	return "testAuthInfo"
    97  }
    98  
    99  func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
   100  	auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}}
   101  	ri := credentials.RequestInfo{
   102  		Method:   "testInfo",
   103  		AuthInfo: auth,
   104  	}
   105  	return icredentials.NewRequestInfoContext(ctx, ri)
   106  }
   107  
   108  // errReader implements the io.Reader interface and returns an error from the
   109  // Read method.
   110  type errReader struct{}
   111  
   112  func (r errReader) Read(b []byte) (n int, err error) {
   113  	return 0, errors.New("read error")
   114  }
   115  
   116  // We need a function to construct the response instead of simply declaring it
   117  // as a variable since the the response body will be consumed by the
   118  // credentials, and therefore we will need a new one everytime.
   119  func makeGoodResponse() *http.Response {
   120  	respJSON, _ := json.Marshal(responseParameters{
   121  		AccessToken:     accessTokenContents,
   122  		IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
   123  		TokenType:       "Bearer",
   124  		ExpiresIn:       3600,
   125  	})
   126  	respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
   127  	return &http.Response{
   128  		Status:     "200 OK",
   129  		StatusCode: http.StatusOK,
   130  		Body:       respBody,
   131  	}
   132  }
   133  
   134  // Overrides the http.Client with a fakeClient which sends a good response.
   135  func overrideHTTPClientGood() (*testutils.FakeHTTPClient, func()) {
   136  	fc := &testutils.FakeHTTPClient{
   137  		ReqChan:  testutils.NewChannel(),
   138  		RespChan: testutils.NewChannel(),
   139  	}
   140  	fc.RespChan.Send(makeGoodResponse())
   141  
   142  	origMakeHTTPDoer := makeHTTPDoer
   143  	makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
   144  	return fc, func() { makeHTTPDoer = origMakeHTTPDoer }
   145  }
   146  
   147  // Overrides the http.Client with the provided fakeClient.
   148  func overrideHTTPClient(fc *testutils.FakeHTTPClient) func() {
   149  	origMakeHTTPDoer := makeHTTPDoer
   150  	makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
   151  	return func() { makeHTTPDoer = origMakeHTTPDoer }
   152  }
   153  
   154  // Overrides the subject token read to return a const which we can compare in
   155  // our tests.
   156  func overrideSubjectTokenGood() func() {
   157  	origReadSubjectTokenFrom := readSubjectTokenFrom
   158  	readSubjectTokenFrom = func(path string) ([]byte, error) {
   159  		return []byte(subjectTokenContents), nil
   160  	}
   161  	return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
   162  }
   163  
   164  // Overrides the subject token read to always return an error.
   165  func overrideSubjectTokenError() func() {
   166  	origReadSubjectTokenFrom := readSubjectTokenFrom
   167  	readSubjectTokenFrom = func(path string) ([]byte, error) {
   168  		return nil, errors.New("error reading subject token")
   169  	}
   170  	return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
   171  }
   172  
   173  // Overrides the actor token read to return a const which we can compare in
   174  // our tests.
   175  func overrideActorTokenGood() func() {
   176  	origReadActorTokenFrom := readActorTokenFrom
   177  	readActorTokenFrom = func(path string) ([]byte, error) {
   178  		return []byte(actorTokenContents), nil
   179  	}
   180  	return func() { readActorTokenFrom = origReadActorTokenFrom }
   181  }
   182  
   183  // Overrides the actor token read to always return an error.
   184  func overrideActorTokenError() func() {
   185  	origReadActorTokenFrom := readActorTokenFrom
   186  	readActorTokenFrom = func(path string) ([]byte, error) {
   187  		return nil, errors.New("error reading actor token")
   188  	}
   189  	return func() { readActorTokenFrom = origReadActorTokenFrom }
   190  }
   191  
   192  // compareRequest compares the http.Request received in the test with the
   193  // expected requestParameters specified in wantReqParams.
   194  func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
   195  	jsonBody, err := json.Marshal(wantReqParams)
   196  	if err != nil {
   197  		return err
   198  	}
   199  	wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody))
   200  	if err != nil {
   201  		return fmt.Errorf("failed to create http request: %v", err)
   202  	}
   203  	wantReq.Header.Set("Content-Type", "application/json")
   204  
   205  	wantR, err := httputil.DumpRequestOut(wantReq, true)
   206  	if err != nil {
   207  		return err
   208  	}
   209  	gotR, err := httputil.DumpRequestOut(gotRequest, true)
   210  	if err != nil {
   211  		return err
   212  	}
   213  	if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" {
   214  		return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
   215  	}
   216  	return nil
   217  }
   218  
   219  // receiveAndCompareRequest waits for a request to be sent out by the
   220  // credentials implementation using the fakeHTTPClient and compares it to an
   221  // expected goodRequest. This is expected to be called in a separate goroutine
   222  // by the tests. So, any errors encountered are pushed to an error channel
   223  // which is monitored by the test.
   224  func receiveAndCompareRequest(ReqChan *testutils.Channel, errCh chan error) {
   225  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   226  	defer cancel()
   227  
   228  	val, err := ReqChan.Receive(ctx)
   229  	if err != nil {
   230  		errCh <- err
   231  		return
   232  	}
   233  	req := val.(*http.Request)
   234  	if err := compareRequest(req, goodRequestParams); err != nil {
   235  		errCh <- err
   236  		return
   237  	}
   238  	errCh <- nil
   239  }
   240  
   241  // TestGetRequestMetadataSuccess verifies the successful case of sending an
   242  // token exchange request and processing the response.
   243  func (s) TestGetRequestMetadataSuccess(t *testing.T) {
   244  	defer overrideSubjectTokenGood()()
   245  	fc, cancel := overrideHTTPClientGood()
   246  	defer cancel()
   247  
   248  	creds, err := NewCredentials(goodOptions)
   249  	if err != nil {
   250  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   251  	}
   252  
   253  	errCh := make(chan error, 1)
   254  	go receiveAndCompareRequest(fc.ReqChan, errCh)
   255  
   256  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   257  	defer cancel()
   258  
   259  	gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
   260  	if err != nil {
   261  		t.Fatalf("creds.GetRequestMetadata() = %v", err)
   262  	}
   263  	if !cmp.Equal(gotMetadata, goodMetadata) {
   264  		t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
   265  	}
   266  	if err := <-errCh; err != nil {
   267  		t.Fatal(err)
   268  	}
   269  
   270  	// Make another call to get request metadata and this should return contents
   271  	// from the cache. This will fail if the credentials tries to send a fresh
   272  	// request here since we have not configured our fakeClient to return any
   273  	// response on retries.
   274  	gotMetadata, err = creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
   275  	if err != nil {
   276  		t.Fatalf("creds.GetRequestMetadata() = %v", err)
   277  	}
   278  	if !cmp.Equal(gotMetadata, goodMetadata) {
   279  		t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
   280  	}
   281  }
   282  
   283  // TestGetRequestMetadataBadSecurityLevel verifies the case where the
   284  // securityLevel specified in the context passed to GetRequestMetadata is not
   285  // sufficient.
   286  func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
   287  	defer overrideSubjectTokenGood()()
   288  
   289  	creds, err := NewCredentials(goodOptions)
   290  	if err != nil {
   291  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   292  	}
   293  
   294  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   295  	defer cancel()
   296  	gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.IntegrityOnly), "")
   297  	if err == nil {
   298  		t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
   299  	}
   300  }
   301  
   302  // TestGetRequestMetadataCacheExpiry verifies the case where the cached access
   303  // token has expired, and the credentials implementation will have to send a
   304  // fresh token exchange request.
   305  func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
   306  	const expiresInSecs = 1
   307  	defer overrideSubjectTokenGood()()
   308  	fc := &testutils.FakeHTTPClient{
   309  		ReqChan:  testutils.NewChannel(),
   310  		RespChan: testutils.NewChannel(),
   311  	}
   312  	defer overrideHTTPClient(fc)()
   313  
   314  	creds, err := NewCredentials(goodOptions)
   315  	if err != nil {
   316  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   317  	}
   318  
   319  	// The fakeClient is configured to return an access_token with a one second
   320  	// expiry. So, in the second iteration, the credentials will find the cache
   321  	// entry, but that would have expired, and therefore we expect it to send
   322  	// out a fresh request.
   323  	for i := 0; i < 2; i++ {
   324  		errCh := make(chan error, 1)
   325  		go receiveAndCompareRequest(fc.ReqChan, errCh)
   326  
   327  		respJSON, _ := json.Marshal(responseParameters{
   328  			AccessToken:     accessTokenContents,
   329  			IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
   330  			TokenType:       "Bearer",
   331  			ExpiresIn:       expiresInSecs,
   332  		})
   333  		respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
   334  		resp := &http.Response{
   335  			Status:     "200 OK",
   336  			StatusCode: http.StatusOK,
   337  			Body:       respBody,
   338  		}
   339  		fc.RespChan.Send(resp)
   340  
   341  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   342  		defer cancel()
   343  		gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
   344  		if err != nil {
   345  			t.Fatalf("creds.GetRequestMetadata() = %v", err)
   346  		}
   347  		if !cmp.Equal(gotMetadata, goodMetadata) {
   348  			t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
   349  		}
   350  		if err := <-errCh; err != nil {
   351  			t.Fatal(err)
   352  		}
   353  		time.Sleep(expiresInSecs * time.Second)
   354  	}
   355  }
   356  
   357  // TestGetRequestMetadataBadResponses verifies the scenario where the token
   358  // exchange server returns bad responses.
   359  func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
   360  	tests := []struct {
   361  		name     string
   362  		response *http.Response
   363  	}{
   364  		{
   365  			name: "bad JSON",
   366  			response: &http.Response{
   367  				Status:     "200 OK",
   368  				StatusCode: http.StatusOK,
   369  				Body:       ioutil.NopCloser(strings.NewReader("not JSON")),
   370  			},
   371  		},
   372  		{
   373  			name: "no access token",
   374  			response: &http.Response{
   375  				Status:     "200 OK",
   376  				StatusCode: http.StatusOK,
   377  				Body:       ioutil.NopCloser(strings.NewReader("{}")),
   378  			},
   379  		},
   380  	}
   381  
   382  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   383  	defer cancel()
   384  	for _, test := range tests {
   385  		t.Run(test.name, func(t *testing.T) {
   386  			defer overrideSubjectTokenGood()()
   387  
   388  			fc := &testutils.FakeHTTPClient{
   389  				ReqChan:  testutils.NewChannel(),
   390  				RespChan: testutils.NewChannel(),
   391  			}
   392  			defer overrideHTTPClient(fc)()
   393  
   394  			creds, err := NewCredentials(goodOptions)
   395  			if err != nil {
   396  				t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   397  			}
   398  
   399  			errCh := make(chan error, 1)
   400  			go receiveAndCompareRequest(fc.ReqChan, errCh)
   401  
   402  			fc.RespChan.Send(test.response)
   403  			if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
   404  				t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
   405  			}
   406  			if err := <-errCh; err != nil {
   407  				t.Fatal(err)
   408  			}
   409  		})
   410  	}
   411  }
   412  
   413  // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the
   414  // attempt to read the subjectToken fails.
   415  func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
   416  	defer overrideSubjectTokenError()()
   417  	fc, cancel := overrideHTTPClientGood()
   418  	defer cancel()
   419  
   420  	creds, err := NewCredentials(goodOptions)
   421  	if err != nil {
   422  		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
   423  	}
   424  
   425  	errCh := make(chan error, 1)
   426  	go func() {
   427  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   428  		defer cancel()
   429  		if _, err := fc.ReqChan.Receive(ctx); err != context.DeadlineExceeded {
   430  			errCh <- err
   431  			return
   432  		}
   433  		errCh <- nil
   434  	}()
   435  
   436  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   437  	defer cancel()
   438  	if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
   439  		t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
   440  	}
   441  	if err := <-errCh; err != nil {
   442  		t.Fatal(err)
   443  	}
   444  }
   445  
   446  func (s) TestNewCredentials(t *testing.T) {
   447  	tests := []struct {
   448  		name           string
   449  		opts           Options
   450  		errSystemRoots bool
   451  		wantErr        bool
   452  	}{
   453  		{
   454  			name: "invalid options - empty subjectTokenPath",
   455  			opts: Options{
   456  				TokenExchangeServiceURI: serviceURI,
   457  			},
   458  			wantErr: true,
   459  		},
   460  		{
   461  			name:           "invalid system root certs",
   462  			opts:           goodOptions,
   463  			errSystemRoots: true,
   464  			wantErr:        true,
   465  		},
   466  		{
   467  			name: "good case",
   468  			opts: goodOptions,
   469  		},
   470  	}
   471  
   472  	for _, test := range tests {
   473  		t.Run(test.name, func(t *testing.T) {
   474  			if test.errSystemRoots {
   475  				oldSystemRoots := loadSystemCertPool
   476  				loadSystemCertPool = func() (*x509.CertPool, error) {
   477  					return nil, errors.New("failed to load system cert pool")
   478  				}
   479  				defer func() {
   480  					loadSystemCertPool = oldSystemRoots
   481  				}()
   482  			}
   483  
   484  			creds, err := NewCredentials(test.opts)
   485  			if (err != nil) != test.wantErr {
   486  				t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
   487  			}
   488  			if err == nil {
   489  				if !creds.RequireTransportSecurity() {
   490  					t.Errorf("creds.RequireTransportSecurity() returned false")
   491  				}
   492  			}
   493  		})
   494  	}
   495  }
   496  
   497  func (s) TestValidateOptions(t *testing.T) {
   498  	tests := []struct {
   499  		name          string
   500  		opts          Options
   501  		wantErrPrefix string
   502  	}{
   503  		{
   504  			name:          "empty token exchange service URI",
   505  			opts:          Options{},
   506  			wantErrPrefix: "empty token_exchange_service_uri in options",
   507  		},
   508  		{
   509  			name: "invalid URI",
   510  			opts: Options{
   511  				TokenExchangeServiceURI: "\tI'm a bad URI\n",
   512  			},
   513  			wantErrPrefix: "invalid control character in URL",
   514  		},
   515  		{
   516  			name: "unsupported scheme",
   517  			opts: Options{
   518  				TokenExchangeServiceURI: "unix:///path/to/socket",
   519  			},
   520  			wantErrPrefix: "scheme is not supported",
   521  		},
   522  		{
   523  			name: "empty subjectTokenPath",
   524  			opts: Options{
   525  				TokenExchangeServiceURI: serviceURI,
   526  			},
   527  			wantErrPrefix: "required field SubjectTokenPath is not specified",
   528  		},
   529  		{
   530  			name: "empty subjectTokenType",
   531  			opts: Options{
   532  				TokenExchangeServiceURI: serviceURI,
   533  				SubjectTokenPath:        subjectTokenPath,
   534  			},
   535  			wantErrPrefix: "required field SubjectTokenType is not specified",
   536  		},
   537  		{
   538  			name: "good options",
   539  			opts: goodOptions,
   540  		},
   541  	}
   542  
   543  	for _, test := range tests {
   544  		t.Run(test.name, func(t *testing.T) {
   545  			err := validateOptions(test.opts)
   546  			if (err != nil) != (test.wantErrPrefix != "") {
   547  				t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
   548  			}
   549  			if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) {
   550  				t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
   551  			}
   552  		})
   553  	}
   554  }
   555  
   556  func (s) TestConstructRequest(t *testing.T) {
   557  	tests := []struct {
   558  		name                string
   559  		opts                Options
   560  		subjectTokenReadErr bool
   561  		actorTokenReadErr   bool
   562  		wantReqParams       *requestParameters
   563  		wantErr             bool
   564  	}{
   565  		{
   566  			name:                "subject token read failure",
   567  			subjectTokenReadErr: true,
   568  			opts:                goodOptions,
   569  			wantErr:             true,
   570  		},
   571  		{
   572  			name:              "actor token read failure",
   573  			actorTokenReadErr: true,
   574  			opts: Options{
   575  				TokenExchangeServiceURI: serviceURI,
   576  				Audience:                exampleAudience,
   577  				RequestedTokenType:      requestedTokenType,
   578  				SubjectTokenPath:        subjectTokenPath,
   579  				SubjectTokenType:        subjectTokenType,
   580  				ActorTokenPath:          actorTokenPath,
   581  				ActorTokenType:          actorTokenType,
   582  			},
   583  			wantErr: true,
   584  		},
   585  		{
   586  			name:          "default cloud platform scope",
   587  			opts:          goodOptions,
   588  			wantReqParams: goodRequestParams,
   589  		},
   590  		{
   591  			name: "all good",
   592  			opts: Options{
   593  				TokenExchangeServiceURI: serviceURI,
   594  				Resource:                exampleResource,
   595  				Audience:                exampleAudience,
   596  				Scope:                   testScope,
   597  				RequestedTokenType:      requestedTokenType,
   598  				SubjectTokenPath:        subjectTokenPath,
   599  				SubjectTokenType:        subjectTokenType,
   600  				ActorTokenPath:          actorTokenPath,
   601  				ActorTokenType:          actorTokenType,
   602  			},
   603  			wantReqParams: &requestParameters{
   604  				GrantType:          tokenExchangeGrantType,
   605  				Resource:           exampleResource,
   606  				Audience:           exampleAudience,
   607  				Scope:              testScope,
   608  				RequestedTokenType: requestedTokenType,
   609  				SubjectToken:       subjectTokenContents,
   610  				SubjectTokenType:   subjectTokenType,
   611  				ActorToken:         actorTokenContents,
   612  				ActorTokenType:     actorTokenType,
   613  			},
   614  		},
   615  	}
   616  
   617  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   618  	defer cancel()
   619  	for _, test := range tests {
   620  		t.Run(test.name, func(t *testing.T) {
   621  			if test.subjectTokenReadErr {
   622  				defer overrideSubjectTokenError()()
   623  			} else {
   624  				defer overrideSubjectTokenGood()()
   625  			}
   626  
   627  			if test.actorTokenReadErr {
   628  				defer overrideActorTokenError()()
   629  			} else {
   630  				defer overrideActorTokenGood()()
   631  			}
   632  
   633  			gotRequest, err := constructRequest(ctx, test.opts)
   634  			if (err != nil) != test.wantErr {
   635  				t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
   636  			}
   637  			if test.wantErr {
   638  				return
   639  			}
   640  			if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
   641  				t.Fatal(err)
   642  			}
   643  		})
   644  	}
   645  }
   646  
   647  func (s) TestSendRequest(t *testing.T) {
   648  	defer overrideSubjectTokenGood()()
   649  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   650  	defer cancel()
   651  	req, err := constructRequest(ctx, goodOptions)
   652  	if err != nil {
   653  		t.Fatal(err)
   654  	}
   655  
   656  	tests := []struct {
   657  		name    string
   658  		resp    *http.Response
   659  		respErr error
   660  		wantErr bool
   661  	}{
   662  		{
   663  			name:    "client error",
   664  			respErr: errors.New("http.Client.Do failed"),
   665  			wantErr: true,
   666  		},
   667  		{
   668  			name: "bad response body",
   669  			resp: &http.Response{
   670  				Status:     "200 OK",
   671  				StatusCode: http.StatusOK,
   672  				Body:       ioutil.NopCloser(errReader{}),
   673  			},
   674  			wantErr: true,
   675  		},
   676  		{
   677  			name: "nonOK status code",
   678  			resp: &http.Response{
   679  				Status:     "400 BadRequest",
   680  				StatusCode: http.StatusBadRequest,
   681  				Body:       ioutil.NopCloser(strings.NewReader("")),
   682  			},
   683  			wantErr: true,
   684  		},
   685  		{
   686  			name: "good case",
   687  			resp: makeGoodResponse(),
   688  		},
   689  	}
   690  
   691  	for _, test := range tests {
   692  		t.Run(test.name, func(t *testing.T) {
   693  			client := &testutils.FakeHTTPClient{
   694  				ReqChan:  testutils.NewChannel(),
   695  				RespChan: testutils.NewChannel(),
   696  				Err:      test.respErr,
   697  			}
   698  			client.RespChan.Send(test.resp)
   699  			_, err := sendRequest(client, req)
   700  			if (err != nil) != test.wantErr {
   701  				t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
   702  			}
   703  		})
   704  	}
   705  }
   706  
   707  func (s) TestTokenInfoFromResponse(t *testing.T) {
   708  	noAccessToken, _ := json.Marshal(responseParameters{
   709  		IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
   710  		TokenType:       "Bearer",
   711  		ExpiresIn:       3600,
   712  	})
   713  	goodResponse, _ := json.Marshal(responseParameters{
   714  		IssuedTokenType: requestedTokenType,
   715  		AccessToken:     accessTokenContents,
   716  		TokenType:       "Bearer",
   717  		ExpiresIn:       3600,
   718  	})
   719  
   720  	tests := []struct {
   721  		name          string
   722  		respBody      []byte
   723  		wantTokenInfo *tokenInfo
   724  		wantErr       bool
   725  	}{
   726  		{
   727  			name:     "bad JSON",
   728  			respBody: []byte("not JSON"),
   729  			wantErr:  true,
   730  		},
   731  		{
   732  			name:     "empty response",
   733  			respBody: []byte(""),
   734  			wantErr:  true,
   735  		},
   736  		{
   737  			name:     "non-empty response with no access token",
   738  			respBody: noAccessToken,
   739  			wantErr:  true,
   740  		},
   741  		{
   742  			name:     "good response",
   743  			respBody: goodResponse,
   744  			wantTokenInfo: &tokenInfo{
   745  				tokenType: "Bearer",
   746  				token:     accessTokenContents,
   747  			},
   748  		},
   749  	}
   750  
   751  	for _, test := range tests {
   752  		t.Run(test.name, func(t *testing.T) {
   753  			gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
   754  			if (err != nil) != test.wantErr {
   755  				t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr)
   756  			}
   757  			if test.wantErr {
   758  				return
   759  			}
   760  			// Can't do a cmp.Equal on the whole struct since the expiryField
   761  			// is populated based on time.Now().
   762  			if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
   763  				t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
   764  			}
   765  		})
   766  	}
   767  }