github.com/jenspinney/cli@v6.42.1-0.20190207184520-7450c600020e+incompatible/api/uaa/wrapper/uaa_authentication_test.go (about)

     1  package wrapper_test
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/url"
     9  	"strings"
    10  
    11  	"code.cloudfoundry.org/cli/api/uaa"
    12  	"code.cloudfoundry.org/cli/api/uaa/uaafakes"
    13  	. "code.cloudfoundry.org/cli/api/uaa/wrapper"
    14  	"code.cloudfoundry.org/cli/api/uaa/wrapper/util"
    15  	"code.cloudfoundry.org/cli/api/uaa/wrapper/wrapperfakes"
    16  	. "github.com/onsi/ginkgo"
    17  	. "github.com/onsi/gomega"
    18  )
    19  
    20  var _ = Describe("UAA Authentication", func() {
    21  	var (
    22  		fakeConnection *uaafakes.FakeConnection
    23  		fakeClient     *wrapperfakes.FakeUAAClient
    24  		inMemoryCache  *util.InMemoryCache
    25  
    26  		wrapper uaa.Connection
    27  		request *http.Request
    28  		inner   *UAAAuthentication
    29  	)
    30  
    31  	BeforeEach(func() {
    32  		fakeConnection = new(uaafakes.FakeConnection)
    33  		fakeClient = new(wrapperfakes.FakeUAAClient)
    34  		inMemoryCache = util.NewInMemoryTokenCache()
    35  
    36  		inner = NewUAAAuthentication(fakeClient, inMemoryCache)
    37  		wrapper = inner.Wrap(fakeConnection)
    38  	})
    39  
    40  	Describe("Make", func() {
    41  		When("the client is nil", func() {
    42  			BeforeEach(func() {
    43  				inner.SetClient(nil)
    44  
    45  				fakeConnection.MakeReturns(uaa.InvalidAuthTokenError{})
    46  			})
    47  
    48  			It("calls the connection without any side effects", func() {
    49  				err := wrapper.Make(request, nil)
    50  				Expect(err).To(MatchError(uaa.InvalidAuthTokenError{}))
    51  
    52  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
    53  				Expect(fakeConnection.MakeCallCount()).To(Equal(1))
    54  			})
    55  		})
    56  
    57  		When("the token is valid", func() {
    58  			BeforeEach(func() {
    59  				request = &http.Request{
    60  					Header: http.Header{},
    61  				}
    62  				inMemoryCache.SetAccessToken("a-ok")
    63  			})
    64  
    65  			It("adds authentication headers", func() {
    66  				err := wrapper.Make(request, nil)
    67  				Expect(err).ToNot(HaveOccurred())
    68  
    69  				Expect(fakeConnection.MakeCallCount()).To(Equal(1))
    70  				authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
    71  				headers := authenticatedRequest.Header
    72  				Expect(headers["Authorization"]).To(ConsistOf([]string{"a-ok"}))
    73  			})
    74  
    75  			When("the request already has headers", func() {
    76  				It("preserves existing headers", func() {
    77  					request.Header.Add("Existing", "header")
    78  					err := wrapper.Make(request, nil)
    79  					Expect(err).ToNot(HaveOccurred())
    80  
    81  					Expect(fakeConnection.MakeCallCount()).To(Equal(1))
    82  					authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
    83  					headers := authenticatedRequest.Header
    84  					Expect(headers["Existing"]).To(ConsistOf([]string{"header"}))
    85  				})
    86  			})
    87  
    88  			When("the wrapped connection returns nil", func() {
    89  				It("returns nil", func() {
    90  					fakeConnection.MakeReturns(nil)
    91  
    92  					err := wrapper.Make(request, nil)
    93  					Expect(err).ToNot(HaveOccurred())
    94  				})
    95  			})
    96  
    97  			When("the wrapped connection returns an error", func() {
    98  				It("returns the error", func() {
    99  					innerError := errors.New("inner error")
   100  					fakeConnection.MakeReturns(innerError)
   101  
   102  					err := wrapper.Make(request, nil)
   103  					Expect(err).To(Equal(innerError))
   104  				})
   105  			})
   106  		})
   107  
   108  		When("the token is invalid", func() {
   109  			var expectedBody string
   110  
   111  			BeforeEach(func() {
   112  				expectedBody = "this body content should be preserved"
   113  				request, err := http.NewRequest(
   114  					http.MethodGet,
   115  					server.URL(),
   116  					ioutil.NopCloser(strings.NewReader(expectedBody)),
   117  				)
   118  				Expect(err).NotTo(HaveOccurred())
   119  
   120  				makeCount := 0
   121  				fakeConnection.MakeStub = func(request *http.Request, response *uaa.Response) error {
   122  					body, readErr := ioutil.ReadAll(request.Body)
   123  					Expect(readErr).NotTo(HaveOccurred())
   124  					Expect(string(body)).To(Equal(expectedBody))
   125  
   126  					if makeCount == 0 {
   127  						makeCount++
   128  						return uaa.InvalidAuthTokenError{}
   129  					} else {
   130  						return nil
   131  					}
   132  				}
   133  
   134  				fakeClient.RefreshAccessTokenReturns(
   135  					uaa.RefreshedTokens{
   136  						AccessToken:  "foobar-2",
   137  						RefreshToken: "bananananananana",
   138  						Type:         "bearer",
   139  					},
   140  					nil,
   141  				)
   142  
   143  				inMemoryCache.SetAccessToken("what")
   144  
   145  				err = wrapper.Make(request, nil)
   146  				Expect(err).ToNot(HaveOccurred())
   147  			})
   148  
   149  			It("should refresh the token", func() {
   150  				Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(1))
   151  			})
   152  
   153  			It("should resend the request", func() {
   154  				Expect(fakeConnection.MakeCallCount()).To(Equal(2))
   155  
   156  				request, _ := fakeConnection.MakeArgsForCall(1)
   157  				Expect(request.Header.Get("Authorization")).To(Equal("bearer foobar-2"))
   158  			})
   159  
   160  			It("should save the refresh token", func() {
   161  				Expect(inMemoryCache.RefreshToken()).To(Equal("bananananananana"))
   162  			})
   163  		})
   164  
   165  		When("refreshing the token", func() {
   166  			var originalAuthHeader string
   167  			BeforeEach(func() {
   168  				body := strings.NewReader(url.Values{
   169  					"grant_type": {"refresh_token"},
   170  				}.Encode())
   171  
   172  				request, err := http.NewRequest("POST", fmt.Sprintf("%s/oauth/token", server.URL()), body)
   173  				Expect(err).NotTo(HaveOccurred())
   174  				request.SetBasicAuth("some-user", "some-password")
   175  				originalAuthHeader = request.Header.Get("Authorization")
   176  
   177  				inMemoryCache.SetAccessToken("some-access-token")
   178  
   179  				err = wrapper.Make(request, nil)
   180  				Expect(err).ToNot(HaveOccurred())
   181  			})
   182  
   183  			It("does not change the 'Authorization' header", func() {
   184  				Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   185  
   186  				request, _ := fakeConnection.MakeArgsForCall(0)
   187  				Expect(request.Header.Get("Authorization")).To(Equal(originalAuthHeader))
   188  			})
   189  		})
   190  
   191  		When("logging in", func() {
   192  			Context("with password grant_type", func() {
   193  				var originalAuthHeader string
   194  				BeforeEach(func() {
   195  					body := strings.NewReader(url.Values{
   196  						"grant_type": {"password"},
   197  					}.Encode())
   198  
   199  					request, err := http.NewRequest("POST", fmt.Sprintf("%s/oauth/token", server.URL()), body)
   200  					Expect(err).NotTo(HaveOccurred())
   201  					request.SetBasicAuth("some-user", "some-password")
   202  					originalAuthHeader = request.Header.Get("Authorization")
   203  
   204  					inMemoryCache.SetAccessToken("some-access-token")
   205  
   206  					err = wrapper.Make(request, nil)
   207  					Expect(err).ToNot(HaveOccurred())
   208  				})
   209  
   210  				It("does not change the 'Authorization' header", func() {
   211  					Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   212  
   213  					request, _ := fakeConnection.MakeArgsForCall(0)
   214  					Expect(request.Header.Get("Authorization")).To(Equal(originalAuthHeader))
   215  				})
   216  			})
   217  			Context("with client_credentials grant_type", func() {
   218  				var originalAuthHeader string
   219  				BeforeEach(func() {
   220  					body := strings.NewReader(url.Values{
   221  						"grant_type": {"client_credentials"},
   222  					}.Encode())
   223  
   224  					request, err := http.NewRequest("POST", fmt.Sprintf("%s/oauth/token", server.URL()), body)
   225  					Expect(err).NotTo(HaveOccurred())
   226  					request.SetBasicAuth("some-user", "some-password")
   227  					originalAuthHeader = request.Header.Get("Authorization")
   228  
   229  					inMemoryCache.SetAccessToken("some-access-token")
   230  
   231  					err = wrapper.Make(request, nil)
   232  					Expect(err).ToNot(HaveOccurred())
   233  				})
   234  
   235  				It("does not change the 'Authorization' header", func() {
   236  					Expect(fakeConnection.MakeCallCount()).To(Equal(1))
   237  
   238  					request, _ := fakeConnection.MakeArgsForCall(0)
   239  					Expect(request.Header.Get("Authorization")).To(Equal(originalAuthHeader))
   240  				})
   241  			})
   242  		})
   243  	})
   244  })