github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/auth/csrf_validation_handler_test.go (about)

     1  package auth_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net/http"
     7  	"net/http/httptest"
     8  
     9  	"code.cloudfoundry.org/lager/lagertest"
    10  
    11  	"github.com/pf-qiu/concourse/v6/atc/api/auth"
    12  	"github.com/pf-qiu/concourse/v6/skymarshal/token/tokenfakes"
    13  
    14  	. "github.com/onsi/ginkgo"
    15  	. "github.com/onsi/gomega"
    16  )
    17  
    18  var _ = Describe("CsrfValidationHandler", func() {
    19  	var (
    20  		server                *httptest.Server
    21  		csrfValidationHandler http.Handler
    22  		request               *http.Request
    23  		response              *http.Response
    24  		delegateHandlerCalled bool
    25  		fakeMiddleware        *tokenfakes.FakeMiddleware
    26  		isCSRFRequired        bool
    27  		logger                *lagertest.TestLogger
    28  		isLoggerSet           bool
    29  	)
    30  
    31  	simpleHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  		delegateHandlerCalled = true
    33  	})
    34  
    35  	csrfRequiredWrapHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    36  		if isLoggerSet {
    37  			r = request.WithContext(context.WithValue(r.Context(), "logger", logger))
    38  		}
    39  		if isCSRFRequired {
    40  			r = request.WithContext(context.WithValue(r.Context(), auth.CSRFRequiredKey, true))
    41  		}
    42  		csrfValidationHandler.ServeHTTP(w, r)
    43  	})
    44  
    45  	BeforeEach(func() {
    46  		isLoggerSet = true
    47  		fakeMiddleware = new(tokenfakes.FakeMiddleware)
    48  		delegateHandlerCalled = false
    49  		isCSRFRequired = false
    50  		logger = lagertest.NewTestLogger("csrf-validation-test")
    51  
    52  		csrfValidationHandler = auth.CSRFValidationHandler(
    53  			simpleHandler,
    54  			fakeMiddleware,
    55  		)
    56  
    57  		server = httptest.NewServer(csrfRequiredWrapHandler)
    58  
    59  		var err error
    60  		request, err = http.NewRequest("POST", server.URL, bytes.NewBufferString("hello"))
    61  		Expect(err).NotTo(HaveOccurred())
    62  	})
    63  
    64  	JustBeforeEach(func() {
    65  		var err error
    66  		response, err = http.DefaultClient.Do(request)
    67  		Expect(err).NotTo(HaveOccurred())
    68  	})
    69  
    70  	AfterEach(func() {
    71  		server.Close()
    72  	})
    73  
    74  	Context("when request does not require CSRF validation", func() {
    75  		Context("when CSRF token is not provided", func() {
    76  			It("returns 200 OK", func() {
    77  				Expect(response.StatusCode).To(Equal(http.StatusOK))
    78  			})
    79  
    80  			It("calls delegate handler", func() {
    81  				Expect(delegateHandlerCalled).To(BeTrue())
    82  			})
    83  		})
    84  	})
    85  
    86  	Context("when request requires CSRF validation", func() {
    87  		BeforeEach(func() {
    88  			isCSRFRequired = true
    89  		})
    90  
    91  		Context("when GET request", func() {
    92  			BeforeEach(func() {
    93  				var err error
    94  				request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello"))
    95  				Expect(err).NotTo(HaveOccurred())
    96  
    97  				request.Header.Set(auth.CSRFHeaderName, "some-token")
    98  				fakeMiddleware.GetCSRFTokenReturns("some-token")
    99  			})
   100  
   101  			It("returns 200 OK", func() {
   102  				Expect(response.StatusCode).To(Equal(http.StatusOK))
   103  			})
   104  
   105  			It("calls delegate handler", func() {
   106  				Expect(delegateHandlerCalled).To(BeTrue())
   107  			})
   108  		})
   109  
   110  		Context("when CSRF token is not provided", func() {
   111  			It("returns 401 Bad Request", func() {
   112  				Expect(response.StatusCode).To(Equal(http.StatusUnauthorized))
   113  			})
   114  
   115  			It("does not call delegate handler", func() {
   116  				Expect(delegateHandlerCalled).To(BeFalse())
   117  			})
   118  		})
   119  
   120  		Context("when CSRF token is provided", func() {
   121  			BeforeEach(func() {
   122  				request.Header.Set(auth.CSRFHeaderName, "some-csrf-token")
   123  			})
   124  
   125  			Context("when auth token does not contain CSRF", func() {
   126  				BeforeEach(func() {
   127  					fakeMiddleware.GetCSRFTokenReturns("")
   128  				})
   129  
   130  				It("returns 401 Bad Request", func() {
   131  					Expect(response.StatusCode).To(Equal(http.StatusUnauthorized))
   132  				})
   133  
   134  				It("does not call delegate handler", func() {
   135  					Expect(delegateHandlerCalled).To(BeFalse())
   136  				})
   137  			})
   138  
   139  			Context("when auth token contains non-matching CSRF", func() {
   140  				BeforeEach(func() {
   141  					fakeMiddleware.GetCSRFTokenReturns("some-other-csrf")
   142  				})
   143  
   144  				It("returns 401 Not Authorized", func() {
   145  					Expect(response.StatusCode).To(Equal(http.StatusUnauthorized))
   146  				})
   147  
   148  				It("does not call delegate handler", func() {
   149  					Expect(delegateHandlerCalled).To(BeFalse())
   150  				})
   151  			})
   152  
   153  			Context("when auth token contains matching CSRF", func() {
   154  				BeforeEach(func() {
   155  					fakeMiddleware.GetCSRFTokenReturns("some-csrf-token")
   156  				})
   157  
   158  				It("returns 200 OK", func() {
   159  					Expect(response.StatusCode).To(Equal(http.StatusOK))
   160  				})
   161  
   162  				It("calls delegate handler", func() {
   163  					Expect(delegateHandlerCalled).To(BeTrue())
   164  				})
   165  			})
   166  		})
   167  	})
   168  })