github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/auth/web_auth_handler_test.go (about) 1 package auth_test 2 3 import ( 4 "bytes" 5 "context" 6 "net/http" 7 "net/http/httptest" 8 9 "code.cloudfoundry.org/lager" 10 . "github.com/onsi/ginkgo" 11 . "github.com/onsi/gomega" 12 13 "github.com/pf-qiu/concourse/v6/atc/api/auth" 14 "github.com/pf-qiu/concourse/v6/atc/api/auth/authfakes" 15 "github.com/pf-qiu/concourse/v6/atc/api/buildserver" 16 "github.com/pf-qiu/concourse/v6/atc/db" 17 "github.com/pf-qiu/concourse/v6/atc/db/dbfakes" 18 "github.com/pf-qiu/concourse/v6/atc/event" 19 "github.com/pf-qiu/concourse/v6/skymarshal/token/tokenfakes" 20 ) 21 22 var _ = Describe("WebAuthHandler", func() { 23 var ( 24 fakeMiddleware *tokenfakes.FakeMiddleware 25 fakeHandler *authfakes.FakeHandler 26 ) 27 28 var server *httptest.Server 29 30 BeforeEach(func() { 31 fakeMiddleware = new(tokenfakes.FakeMiddleware) 32 fakeHandler = new(authfakes.FakeHandler) 33 34 server = httptest.NewServer(auth.WebAuthHandler{ 35 Handler: fakeHandler, 36 Middleware: fakeMiddleware, 37 }) 38 }) 39 40 AfterEach(func() { 41 server.Close() 42 }) 43 44 Describe("handling a request", func() { 45 var request *http.Request 46 var response *http.Response 47 48 BeforeEach(func() { 49 var err error 50 request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello")) 51 Expect(err).NotTo(HaveOccurred()) 52 }) 53 54 JustBeforeEach(func() { 55 var err error 56 response, err = http.DefaultClient.Do(request) 57 Expect(err).NotTo(HaveOccurred()) 58 defer response.Body.Close() 59 }) 60 61 Context("without the auth cookie", func() { 62 BeforeEach(func() { 63 fakeMiddleware.GetAuthTokenReturns("") 64 }) 65 66 It("does not set auth cookie in response", func() { 67 Expect(response.Cookies()).To(HaveLen(0)) 68 }) 69 70 It("proxies to the handler without setting the Authorization header", func() { 71 Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1)) 72 _, r := fakeHandler.ServeHTTPArgsForCall(0) 73 Expect(r.Header.Get("Authorization")).To(BeEmpty()) 74 }) 75 76 It("does not set CSRF required context in request", func() { 77 Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1)) 78 _, r := fakeHandler.ServeHTTPArgsForCall(0) 79 csrfRequiredContext := r.Context().Value(auth.CSRFRequiredKey) 80 Expect(csrfRequiredContext).To(BeNil()) 81 }) 82 83 Context("the nested handler returns unauthorized", func() { 84 BeforeEach(func() { 85 fakeHandler.ServeHTTPStub = func(w http.ResponseWriter, r *http.Request) { 86 w.WriteHeader(http.StatusUnauthorized) 87 } 88 }) 89 90 It("does not unset the auth cookie", func() { 91 Expect(fakeMiddleware.UnsetAuthTokenCallCount()).To(Equal(0)) 92 }) 93 94 It("does not unset the csrf cookie", func() { 95 Expect(fakeMiddleware.UnsetCSRFTokenCallCount()).To(Equal(0)) 96 }) 97 }) 98 }) 99 100 Context("with the auth cookie", func() { 101 BeforeEach(func() { 102 fakeMiddleware.GetAuthTokenReturns("username:password") 103 }) 104 105 It("sets the Authorization header with the value from the cookie", func() { 106 Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1)) 107 _, r := fakeHandler.ServeHTTPArgsForCall(0) 108 Expect(r.Header.Get("Authorization")).To(Equal("username:password")) 109 }) 110 111 It("sets CSRF required context in request", func() { 112 Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1)) 113 _, r := fakeHandler.ServeHTTPArgsForCall(0) 114 csrfRequiredContext := r.Context().Value(auth.CSRFRequiredKey) 115 Expect(csrfRequiredContext).NotTo(BeNil()) 116 117 boolCsrf := csrfRequiredContext.(bool) 118 Expect(boolCsrf).To(BeFalse()) 119 }) 120 121 Context("and the request also has an Authorization header", func() { 122 BeforeEach(func() { 123 request.Header.Set("Authorization", "foobar") 124 }) 125 126 It("does not override the Authorization header", func() { 127 Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1)) 128 _, r := fakeHandler.ServeHTTPArgsForCall(0) 129 Expect(r.Header.Get("Authorization")).To(Equal("foobar")) 130 }) 131 }) 132 133 Context("the nested handler returns an event stream", func() { 134 BeforeEach(func() { 135 build := new(dbfakes.FakeBuild) 136 fakeEventSource := new(dbfakes.FakeEventSource) 137 fakeEventSource.NextReturns(event.Envelope{}, db.ErrEndOfBuildEventStream) 138 build.EventsReturns(fakeEventSource, nil) 139 140 server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 141 defer GinkgoRecover() 142 auth.WebAuthHandler{ 143 Handler: buildserver.NewEventHandler(lager.NewLogger("test"), build), 144 Middleware: fakeMiddleware, 145 }.ServeHTTP(w, r) 146 })) 147 148 var err error 149 request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello")) 150 Expect(err).NotTo(HaveOccurred()) 151 }) 152 153 It("returns success", func() { 154 Expect(response.StatusCode).To(Equal(http.StatusOK)) 155 }) 156 }) 157 158 Context("the nested handler returns unauthorized", func() { 159 BeforeEach(func() { 160 fakeHandler.ServeHTTPStub = func(w http.ResponseWriter, r *http.Request) { 161 w.WriteHeader(http.StatusUnauthorized) 162 } 163 }) 164 165 It("unsets the auth cookie", func() { 166 Expect(fakeMiddleware.UnsetAuthTokenCallCount()).To(Equal(1)) 167 }) 168 169 It("unsets the csrf cookie", func() { 170 Expect(fakeMiddleware.UnsetCSRFTokenCallCount()).To(Equal(1)) 171 }) 172 }) 173 }) 174 }) 175 176 Describe("CSRF Required", func() { 177 var request *http.Request 178 var err error 179 Context("when CSRF context is set", func() { 180 BeforeEach(func() { 181 request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello")) 182 Expect(err).To(BeNil()) 183 184 ctx := context.WithValue(request.Context(), auth.CSRFRequiredKey, true) 185 request = request.WithContext(ctx) 186 187 }) 188 It("fetches the bool value", func() { 189 Expect(auth.IsCSRFRequired(request)).To(BeTrue()) 190 }) 191 }) 192 193 Context("when CSRF context is not set", func() { 194 BeforeEach(func() { 195 request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello")) 196 Expect(err).To(BeNil()) 197 }) 198 It("fetches the bool value", func() { 199 Expect(auth.IsCSRFRequired(request)).To(BeFalse()) 200 }) 201 202 }) 203 }) 204 })