github.com/wanddynosios/cli/v8@v8.7.9-0.20240221182337-1a92e3a7017f/api/cloudcontroller/wrapper/uaa_authentication_test.go (about)

     1  package wrapper_test
     2  
     3  import (
     4  	"errors"
     5  	"io/ioutil"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"code.cloudfoundry.org/cli/api/uaa"
    11  
    12  	"github.com/SermoDigital/jose/crypto"
    13  	"github.com/SermoDigital/jose/jws"
    14  
    15  	"code.cloudfoundry.org/cli/api/cloudcontroller/ccerror"
    16  
    17  	"code.cloudfoundry.org/cli/api/cloudcontroller"
    18  	"code.cloudfoundry.org/cli/api/cloudcontroller/cloudcontrollerfakes"
    19  	. "code.cloudfoundry.org/cli/api/cloudcontroller/wrapper"
    20  	"code.cloudfoundry.org/cli/api/cloudcontroller/wrapper/wrapperfakes"
    21  	"code.cloudfoundry.org/cli/api/uaa/wrapper/util"
    22  
    23  	. "github.com/onsi/ginkgo"
    24  	. "github.com/onsi/gomega"
    25  )
    26  
    27  var _ = Describe("UAA Authentication", func() {
    28  	var (
    29  		fakeConnection *cloudcontrollerfakes.FakeConnection
    30  		fakeClient     *wrapperfakes.FakeUAAClient
    31  		inMemoryCache  *util.InMemoryCache
    32  
    33  		wrapper cloudcontroller.Connection
    34  		request *cloudcontroller.Request
    35  		inner   *UAAAuthentication
    36  	)
    37  
    38  	BeforeEach(func() {
    39  		fakeConnection = new(cloudcontrollerfakes.FakeConnection)
    40  		fakeClient = new(wrapperfakes.FakeUAAClient)
    41  		inMemoryCache = util.NewInMemoryTokenCache()
    42  		inner = NewUAAAuthentication(fakeClient, inMemoryCache)
    43  		wrapper = inner.Wrap(fakeConnection)
    44  
    45  		request = &cloudcontroller.Request{
    46  			Request: &http.Request{
    47  				Header: http.Header{},
    48  			},
    49  		}
    50  	})
    51  
    52  	Describe("Make", func() {
    53  		When("no tokens are set", func() {
    54  			BeforeEach(func() {
    55  				inMemoryCache.SetAccessToken("")
    56  				inMemoryCache.SetRefreshToken("")
    57  			})
    58  
    59  			It("does not attempt to refresh the token", func() {
    60  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
    61  			})
    62  		})
    63  
    64  		When("the access token is invalid", func() {
    65  			var (
    66  				executeErr error
    67  			)
    68  			BeforeEach(func() {
    69  				inMemoryCache.SetAccessToken("Bearer some.invalid.token")
    70  				inMemoryCache.SetRefreshToken("some refresh token")
    71  				executeErr = wrapper.Make(request, nil)
    72  			})
    73  
    74  			It("should refresh the token", func() {
    75  				Expect(executeErr).ToNot(HaveOccurred())
    76  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(1))
    77  			})
    78  		})
    79  
    80  		When("the access token is valid", func() {
    81  			var (
    82  				accessToken string
    83  			)
    84  
    85  			BeforeEach(func() {
    86  				var err error
    87  				accessToken, err = buildTokenString(time.Now().AddDate(0, 0, 1))
    88  				Expect(err).ToNot(HaveOccurred())
    89  				inMemoryCache.SetAccessToken(accessToken)
    90  			})
    91  
    92  			It("does not refresh the token", func() {
    93  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
    94  			})
    95  
    96  			It("adds authentication headers", func() {
    97  				err := wrapper.Make(request, nil)
    98  				Expect(err).ToNot(HaveOccurred())
    99  
   100  				Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   101  				authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
   102  				headers := authenticatedRequest.Header
   103  				Expect(headers["Authorization"]).To(ConsistOf([]string{accessToken}))
   104  			})
   105  
   106  			When("the request already has headers", func() {
   107  				It("preserves existing headers", func() {
   108  					request.Header.Add("Existing", "header")
   109  					err := wrapper.Make(request, nil)
   110  					Expect(err).ToNot(HaveOccurred())
   111  
   112  					Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   113  					authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
   114  					headers := authenticatedRequest.Header
   115  					Expect(headers["Existing"]).To(ConsistOf([]string{"header"}))
   116  				})
   117  			})
   118  
   119  			When("the wrapped connection returns nil", func() {
   120  				It("returns nil", func() {
   121  					fakeConnection.MakeReturns(nil)
   122  
   123  					err := wrapper.Make(request, nil)
   124  					Expect(err).ToNot(HaveOccurred())
   125  				})
   126  			})
   127  
   128  			When("the wrapped connection returns an error", func() {
   129  				It("returns the error", func() {
   130  					innerError := errors.New("inner error")
   131  					fakeConnection.MakeReturns(innerError)
   132  
   133  					err := wrapper.Make(request, nil)
   134  					Expect(err).To(Equal(innerError))
   135  				})
   136  			})
   137  		})
   138  
   139  		When("the authorization header is already provided", func() {
   140  			var (
   141  				accessToken string
   142  			)
   143  
   144  			BeforeEach(func() {
   145  				var err error
   146  				accessToken, err = buildTokenString(time.Now().AddDate(0, 0, 1))
   147  				Expect(err).ToNot(HaveOccurred())
   148  				request.Header.Set("Authorization", accessToken)
   149  			})
   150  
   151  			It("does not overwrite the authentication headers", func() {
   152  				err := wrapper.Make(request, nil)
   153  				Expect(err).ToNot(HaveOccurred())
   154  
   155  				Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   156  				authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
   157  				headers := authenticatedRequest.Header
   158  				Expect(headers["Authorization"]).To(ConsistOf([]string{accessToken}))
   159  			})
   160  		})
   161  
   162  		When("the access token is expired", func() {
   163  			var (
   164  				expectedBody       string
   165  				request            *cloudcontroller.Request
   166  				executeErr         error
   167  				invalidAccessToken string
   168  				newAccessToken     string
   169  				newRefreshToken    string
   170  			)
   171  
   172  			BeforeEach(func() {
   173  				var err error
   174  				invalidAccessToken, err = buildTokenString(time.Time{})
   175  				Expect(err).ToNot(HaveOccurred())
   176  				newAccessToken, err = buildTokenString(time.Now().AddDate(0, 1, 1))
   177  				Expect(err).ToNot(HaveOccurred())
   178  				newRefreshToken = "newRefreshToken"
   179  
   180  				expectedBody = "this body content should be preserved"
   181  				body := strings.NewReader(expectedBody)
   182  				request = cloudcontroller.NewRequest(&http.Request{
   183  					Header: http.Header{},
   184  					Body:   ioutil.NopCloser(body),
   185  				}, body)
   186  
   187  				inMemoryCache.SetAccessToken(invalidAccessToken)
   188  
   189  				fakeClient.RefreshAccessTokenReturns(
   190  					uaa.RefreshedTokens{
   191  						AccessToken:  newAccessToken,
   192  						RefreshToken: newRefreshToken,
   193  						Type:         "bearer",
   194  					},
   195  					nil,
   196  				)
   197  			})
   198  
   199  			JustBeforeEach(func() {
   200  				executeErr = wrapper.Make(request, nil)
   201  			})
   202  
   203  			It("should refresh the token", func() {
   204  				Expect(executeErr).ToNot(HaveOccurred())
   205  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(1))
   206  			})
   207  
   208  			It("should save the refresh token", func() {
   209  				Expect(inMemoryCache.RefreshToken()).To(Equal(newRefreshToken))
   210  				Expect(inMemoryCache.AccessToken()).To(ContainSubstring(newAccessToken))
   211  			})
   212  
   213  			When("token cannot be refreshed", func() {
   214  				JustBeforeEach(func() {
   215  					fakeConnection.MakeReturns(ccerror.InvalidAuthTokenError{})
   216  				})
   217  
   218  				It("should not re-try the initial request", func() {
   219  					Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   220  				})
   221  			})
   222  
   223  		})
   224  	})
   225  })
   226  
   227  func buildTokenString(expiration time.Time) (string, error) {
   228  	c := jws.Claims{}
   229  	c.SetExpiration(expiration)
   230  	token := jws.NewJWT(c, crypto.Unsecured)
   231  	tokenBytes, err := token.Serialize(nil)
   232  	return string(tokenBytes), err
   233  }