github.com/orange-cloudfoundry/cli@v7.1.0+incompatible/api/cloudcontroller/wrapper/uaa_authentication_test.go (about)

     1  package wrapper_test
     2  
     3  import (
     4  	"errors"
     5  	"io/ioutil"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"code.cloudfoundry.org/cli/api/uaa"
    11  
    12  	"github.com/SermoDigital/jose/crypto"
    13  	"github.com/SermoDigital/jose/jws"
    14  
    15  	"code.cloudfoundry.org/cli/api/cloudcontroller/ccerror"
    16  
    17  	"code.cloudfoundry.org/cli/api/cloudcontroller"
    18  	"code.cloudfoundry.org/cli/api/cloudcontroller/cloudcontrollerfakes"
    19  	. "code.cloudfoundry.org/cli/api/cloudcontroller/wrapper"
    20  	"code.cloudfoundry.org/cli/api/cloudcontroller/wrapper/wrapperfakes"
    21  	"code.cloudfoundry.org/cli/api/uaa/wrapper/util"
    22  
    23  	. "github.com/onsi/ginkgo"
    24  	. "github.com/onsi/gomega"
    25  )
    26  
    27  var _ = Describe("UAA Authentication", func() {
    28  	var (
    29  		fakeConnection *cloudcontrollerfakes.FakeConnection
    30  		fakeClient     *wrapperfakes.FakeUAAClient
    31  		inMemoryCache  *util.InMemoryCache
    32  
    33  		wrapper cloudcontroller.Connection
    34  		request *cloudcontroller.Request
    35  		inner   *UAAAuthentication
    36  	)
    37  
    38  	BeforeEach(func() {
    39  		fakeConnection = new(cloudcontrollerfakes.FakeConnection)
    40  		fakeClient = new(wrapperfakes.FakeUAAClient)
    41  		inMemoryCache = util.NewInMemoryTokenCache()
    42  		inner = NewUAAAuthentication(fakeClient, inMemoryCache)
    43  		wrapper = inner.Wrap(fakeConnection)
    44  
    45  		request = &cloudcontroller.Request{
    46  			Request: &http.Request{
    47  				Header: http.Header{},
    48  			},
    49  		}
    50  	})
    51  
    52  	Describe("Make", func() {
    53  		When("the client is nil", func() {
    54  			BeforeEach(func() {
    55  				inner.SetClient(nil)
    56  
    57  				fakeConnection.MakeReturns(ccerror.InvalidAuthTokenError{})
    58  			})
    59  
    60  			It("calls the connection without any side effects", func() {
    61  				err := wrapper.Make(request, nil)
    62  				Expect(err).To(MatchError(ccerror.InvalidAuthTokenError{}))
    63  
    64  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
    65  				Expect(fakeConnection.MakeCallCount()).To(Equal(1))
    66  			})
    67  		})
    68  
    69  		When("no tokens are set", func() {
    70  			BeforeEach(func() {
    71  				inMemoryCache.SetAccessToken("")
    72  				inMemoryCache.SetRefreshToken("")
    73  			})
    74  
    75  			It("does not attempt to refresh the token", func() {
    76  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
    77  			})
    78  		})
    79  
    80  		When("the access token is invalid", func() {
    81  			var (
    82  				executeErr error
    83  			)
    84  			BeforeEach(func() {
    85  				inMemoryCache.SetAccessToken("Bearer some.invalid.token")
    86  				inMemoryCache.SetRefreshToken("some refresh token")
    87  				executeErr = wrapper.Make(request, nil)
    88  			})
    89  
    90  			It("should refresh the token", func() {
    91  				Expect(executeErr).ToNot(HaveOccurred())
    92  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(1))
    93  			})
    94  		})
    95  
    96  		When("the access token is valid", func() {
    97  			var (
    98  				accessToken string
    99  			)
   100  
   101  			BeforeEach(func() {
   102  				var err error
   103  				accessToken, err = buildTokenString(time.Now().AddDate(0, 0, 1))
   104  				Expect(err).ToNot(HaveOccurred())
   105  				inMemoryCache.SetAccessToken(accessToken)
   106  			})
   107  
   108  			It("does not refresh the token", func() {
   109  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
   110  			})
   111  
   112  			It("adds authentication headers", func() {
   113  				err := wrapper.Make(request, nil)
   114  				Expect(err).ToNot(HaveOccurred())
   115  
   116  				Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   117  				authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
   118  				headers := authenticatedRequest.Header
   119  				Expect(headers["Authorization"]).To(ConsistOf([]string{accessToken}))
   120  			})
   121  
   122  			When("the request already has headers", func() {
   123  				It("preserves existing headers", func() {
   124  					request.Header.Add("Existing", "header")
   125  					err := wrapper.Make(request, nil)
   126  					Expect(err).ToNot(HaveOccurred())
   127  
   128  					Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   129  					authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
   130  					headers := authenticatedRequest.Header
   131  					Expect(headers["Existing"]).To(ConsistOf([]string{"header"}))
   132  				})
   133  			})
   134  
   135  			When("the wrapped connection returns nil", func() {
   136  				It("returns nil", func() {
   137  					fakeConnection.MakeReturns(nil)
   138  
   139  					err := wrapper.Make(request, nil)
   140  					Expect(err).ToNot(HaveOccurred())
   141  				})
   142  			})
   143  
   144  			When("the wrapped connection returns an error", func() {
   145  				It("returns the error", func() {
   146  					innerError := errors.New("inner error")
   147  					fakeConnection.MakeReturns(innerError)
   148  
   149  					err := wrapper.Make(request, nil)
   150  					Expect(err).To(Equal(innerError))
   151  				})
   152  			})
   153  		})
   154  
   155  		When("the access token is expired", func() {
   156  			var (
   157  				expectedBody       string
   158  				request            *cloudcontroller.Request
   159  				executeErr         error
   160  				invalidAccessToken string
   161  				newAccessToken     string
   162  				newRefreshToken    string
   163  			)
   164  
   165  			BeforeEach(func() {
   166  				var err error
   167  				invalidAccessToken, err = buildTokenString(time.Time{})
   168  				Expect(err).ToNot(HaveOccurred())
   169  				newAccessToken, err = buildTokenString(time.Now().AddDate(0, 1, 1))
   170  				Expect(err).ToNot(HaveOccurred())
   171  				newRefreshToken = "newRefreshToken"
   172  
   173  				expectedBody = "this body content should be preserved"
   174  				body := strings.NewReader(expectedBody)
   175  				request = cloudcontroller.NewRequest(&http.Request{
   176  					Header: http.Header{},
   177  					Body:   ioutil.NopCloser(body),
   178  				}, body)
   179  
   180  				inMemoryCache.SetAccessToken(invalidAccessToken)
   181  
   182  				fakeClient.RefreshAccessTokenReturns(
   183  					uaa.RefreshedTokens{
   184  						AccessToken:  newAccessToken,
   185  						RefreshToken: newRefreshToken,
   186  						Type:         "bearer",
   187  					},
   188  					nil,
   189  				)
   190  			})
   191  
   192  			JustBeforeEach(func() {
   193  				executeErr = wrapper.Make(request, nil)
   194  			})
   195  
   196  			It("should refresh the token", func() {
   197  				Expect(executeErr).ToNot(HaveOccurred())
   198  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(1))
   199  			})
   200  
   201  			It("should save the refresh token", func() {
   202  				Expect(inMemoryCache.RefreshToken()).To(Equal(newRefreshToken))
   203  				Expect(inMemoryCache.AccessToken()).To(ContainSubstring(newAccessToken))
   204  			})
   205  
   206  			When("token cannot be refreshed", func() {
   207  				JustBeforeEach(func() {
   208  					fakeConnection.MakeReturns(ccerror.InvalidAuthTokenError{})
   209  				})
   210  
   211  				It("should not re-try the initial request", func() {
   212  					Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   213  				})
   214  			})
   215  
   216  		})
   217  	})
   218  })
   219  
   220  func buildTokenString(expiration time.Time) (string, error) {
   221  	c := jws.Claims{}
   222  	c.SetExpiration(expiration)
   223  	token := jws.NewJWT(c, crypto.Unsecured)
   224  	tokenBytes, err := token.Serialize(nil)
   225  	return string(tokenBytes), err
   226  }