github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/auth/csrf_validation_handler_test.go (about) 1 package auth_test 2 3 import ( 4 "bytes" 5 "context" 6 "net/http" 7 "net/http/httptest" 8 9 "code.cloudfoundry.org/lager/lagertest" 10 11 "github.com/pf-qiu/concourse/v6/atc/api/auth" 12 "github.com/pf-qiu/concourse/v6/skymarshal/token/tokenfakes" 13 14 . "github.com/onsi/ginkgo" 15 . "github.com/onsi/gomega" 16 ) 17 18 var _ = Describe("CsrfValidationHandler", func() { 19 var ( 20 server *httptest.Server 21 csrfValidationHandler http.Handler 22 request *http.Request 23 response *http.Response 24 delegateHandlerCalled bool 25 fakeMiddleware *tokenfakes.FakeMiddleware 26 isCSRFRequired bool 27 logger *lagertest.TestLogger 28 isLoggerSet bool 29 ) 30 31 simpleHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 32 delegateHandlerCalled = true 33 }) 34 35 csrfRequiredWrapHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 36 if isLoggerSet { 37 r = request.WithContext(context.WithValue(r.Context(), "logger", logger)) 38 } 39 if isCSRFRequired { 40 r = request.WithContext(context.WithValue(r.Context(), auth.CSRFRequiredKey, true)) 41 } 42 csrfValidationHandler.ServeHTTP(w, r) 43 }) 44 45 BeforeEach(func() { 46 isLoggerSet = true 47 fakeMiddleware = new(tokenfakes.FakeMiddleware) 48 delegateHandlerCalled = false 49 isCSRFRequired = false 50 logger = lagertest.NewTestLogger("csrf-validation-test") 51 52 csrfValidationHandler = auth.CSRFValidationHandler( 53 simpleHandler, 54 fakeMiddleware, 55 ) 56 57 server = httptest.NewServer(csrfRequiredWrapHandler) 58 59 var err error 60 request, err = http.NewRequest("POST", server.URL, bytes.NewBufferString("hello")) 61 Expect(err).NotTo(HaveOccurred()) 62 }) 63 64 JustBeforeEach(func() { 65 var err error 66 response, err = http.DefaultClient.Do(request) 67 Expect(err).NotTo(HaveOccurred()) 68 }) 69 70 AfterEach(func() { 71 server.Close() 72 }) 73 74 Context("when request does not require CSRF validation", func() { 75 Context("when CSRF token is not provided", func() { 76 It("returns 200 OK", func() { 77 Expect(response.StatusCode).To(Equal(http.StatusOK)) 78 }) 79 80 It("calls delegate handler", func() { 81 Expect(delegateHandlerCalled).To(BeTrue()) 82 }) 83 }) 84 }) 85 86 Context("when request requires CSRF validation", func() { 87 BeforeEach(func() { 88 isCSRFRequired = true 89 }) 90 91 Context("when GET request", func() { 92 BeforeEach(func() { 93 var err error 94 request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello")) 95 Expect(err).NotTo(HaveOccurred()) 96 97 request.Header.Set(auth.CSRFHeaderName, "some-token") 98 fakeMiddleware.GetCSRFTokenReturns("some-token") 99 }) 100 101 It("returns 200 OK", func() { 102 Expect(response.StatusCode).To(Equal(http.StatusOK)) 103 }) 104 105 It("calls delegate handler", func() { 106 Expect(delegateHandlerCalled).To(BeTrue()) 107 }) 108 }) 109 110 Context("when CSRF token is not provided", func() { 111 It("returns 401 Bad Request", func() { 112 Expect(response.StatusCode).To(Equal(http.StatusUnauthorized)) 113 }) 114 115 It("does not call delegate handler", func() { 116 Expect(delegateHandlerCalled).To(BeFalse()) 117 }) 118 }) 119 120 Context("when CSRF token is provided", func() { 121 BeforeEach(func() { 122 request.Header.Set(auth.CSRFHeaderName, "some-csrf-token") 123 }) 124 125 Context("when auth token does not contain CSRF", func() { 126 BeforeEach(func() { 127 fakeMiddleware.GetCSRFTokenReturns("") 128 }) 129 130 It("returns 401 Bad Request", func() { 131 Expect(response.StatusCode).To(Equal(http.StatusUnauthorized)) 132 }) 133 134 It("does not call delegate handler", func() { 135 Expect(delegateHandlerCalled).To(BeFalse()) 136 }) 137 }) 138 139 Context("when auth token contains non-matching CSRF", func() { 140 BeforeEach(func() { 141 fakeMiddleware.GetCSRFTokenReturns("some-other-csrf") 142 }) 143 144 It("returns 401 Not Authorized", func() { 145 Expect(response.StatusCode).To(Equal(http.StatusUnauthorized)) 146 }) 147 148 It("does not call delegate handler", func() { 149 Expect(delegateHandlerCalled).To(BeFalse()) 150 }) 151 }) 152 153 Context("when auth token contains matching CSRF", func() { 154 BeforeEach(func() { 155 fakeMiddleware.GetCSRFTokenReturns("some-csrf-token") 156 }) 157 158 It("returns 200 OK", func() { 159 Expect(response.StatusCode).To(Equal(http.StatusOK)) 160 }) 161 162 It("calls delegate handler", func() { 163 Expect(delegateHandlerCalled).To(BeTrue()) 164 }) 165 }) 166 }) 167 }) 168 })