// github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/auth/web_auth_handler_test.go

     1  package auth_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net/http"
     7  	"net/http/httptest"
     8  
     9  	"code.cloudfoundry.org/lager"
    10  	. "github.com/onsi/ginkgo"
    11  	. "github.com/onsi/gomega"
    12  
    13  	"github.com/pf-qiu/concourse/v6/atc/api/auth"
    14  	"github.com/pf-qiu/concourse/v6/atc/api/auth/authfakes"
    15  	"github.com/pf-qiu/concourse/v6/atc/api/buildserver"
    16  	"github.com/pf-qiu/concourse/v6/atc/db"
    17  	"github.com/pf-qiu/concourse/v6/atc/db/dbfakes"
    18  	"github.com/pf-qiu/concourse/v6/atc/event"
    19  	"github.com/pf-qiu/concourse/v6/skymarshal/token/tokenfakes"
    20  )
    21  
    22  var _ = Describe("WebAuthHandler", func() {
    23  	var (
    24  		fakeMiddleware *tokenfakes.FakeMiddleware
    25  		fakeHandler    *authfakes.FakeHandler
    26  	)
    27  
    28  	var server *httptest.Server
    29  
    30  	BeforeEach(func() {
    31  		fakeMiddleware = new(tokenfakes.FakeMiddleware)
    32  		fakeHandler = new(authfakes.FakeHandler)
    33  
    34  		server = httptest.NewServer(auth.WebAuthHandler{
    35  			Handler:    fakeHandler,
    36  			Middleware: fakeMiddleware,
    37  		})
    38  	})
    39  
    40  	AfterEach(func() {
    41  		server.Close()
    42  	})
    43  
    44  	Describe("handling a request", func() {
    45  		var request *http.Request
    46  		var response *http.Response
    47  
    48  		BeforeEach(func() {
    49  			var err error
    50  			request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello"))
    51  			Expect(err).NotTo(HaveOccurred())
    52  		})
    53  
    54  		JustBeforeEach(func() {
    55  			var err error
    56  			response, err = http.DefaultClient.Do(request)
    57  			Expect(err).NotTo(HaveOccurred())
    58  			defer response.Body.Close()
    59  		})
    60  
    61  		Context("without the auth cookie", func() {
    62  			BeforeEach(func() {
    63  				fakeMiddleware.GetAuthTokenReturns("")
    64  			})
    65  
    66  			It("does not set auth cookie in response", func() {
    67  				Expect(response.Cookies()).To(HaveLen(0))
    68  			})
    69  
    70  			It("proxies to the handler without setting the Authorization header", func() {
    71  				Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1))
    72  				_, r := fakeHandler.ServeHTTPArgsForCall(0)
    73  				Expect(r.Header.Get("Authorization")).To(BeEmpty())
    74  			})
    75  
    76  			It("does not set CSRF required context in request", func() {
    77  				Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1))
    78  				_, r := fakeHandler.ServeHTTPArgsForCall(0)
    79  				csrfRequiredContext := r.Context().Value(auth.CSRFRequiredKey)
    80  				Expect(csrfRequiredContext).To(BeNil())
    81  			})
    82  
    83  			Context("the nested handler returns unauthorized", func() {
    84  				BeforeEach(func() {
    85  					fakeHandler.ServeHTTPStub = func(w http.ResponseWriter, r *http.Request) {
    86  						w.WriteHeader(http.StatusUnauthorized)
    87  					}
    88  				})
    89  
    90  				It("does not unset the auth cookie", func() {
    91  					Expect(fakeMiddleware.UnsetAuthTokenCallCount()).To(Equal(0))
    92  				})
    93  
    94  				It("does not unset the csrf cookie", func() {
    95  					Expect(fakeMiddleware.UnsetCSRFTokenCallCount()).To(Equal(0))
    96  				})
    97  			})
    98  		})
    99  
   100  		Context("with the auth cookie", func() {
   101  			BeforeEach(func() {
   102  				fakeMiddleware.GetAuthTokenReturns("username:password")
   103  			})
   104  
   105  			It("sets the Authorization header with the value from the cookie", func() {
   106  				Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1))
   107  				_, r := fakeHandler.ServeHTTPArgsForCall(0)
   108  				Expect(r.Header.Get("Authorization")).To(Equal("username:password"))
   109  			})
   110  
   111  			It("sets CSRF required context in request", func() {
   112  				Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1))
   113  				_, r := fakeHandler.ServeHTTPArgsForCall(0)
   114  				csrfRequiredContext := r.Context().Value(auth.CSRFRequiredKey)
   115  				Expect(csrfRequiredContext).NotTo(BeNil())
   116  
   117  				boolCsrf := csrfRequiredContext.(bool)
   118  				Expect(boolCsrf).To(BeFalse())
   119  			})
   120  
   121  			Context("and the request also has an Authorization header", func() {
   122  				BeforeEach(func() {
   123  					request.Header.Set("Authorization", "foobar")
   124  				})
   125  
   126  				It("does not override the Authorization header", func() {
   127  					Expect(fakeHandler.ServeHTTPCallCount()).To(Equal(1))
   128  					_, r := fakeHandler.ServeHTTPArgsForCall(0)
   129  					Expect(r.Header.Get("Authorization")).To(Equal("foobar"))
   130  				})
   131  			})
   132  
   133  			Context("the nested handler returns an event stream", func() {
   134  				BeforeEach(func() {
   135  					build := new(dbfakes.FakeBuild)
   136  					fakeEventSource := new(dbfakes.FakeEventSource)
   137  					fakeEventSource.NextReturns(event.Envelope{}, db.ErrEndOfBuildEventStream)
   138  					build.EventsReturns(fakeEventSource, nil)
   139  
   140  					server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   141  						defer GinkgoRecover()
   142  						auth.WebAuthHandler{
   143  							Handler:    buildserver.NewEventHandler(lager.NewLogger("test"), build),
   144  							Middleware: fakeMiddleware,
   145  						}.ServeHTTP(w, r)
   146  					}))
   147  
   148  					var err error
   149  					request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello"))
   150  					Expect(err).NotTo(HaveOccurred())
   151  				})
   152  
   153  				It("returns success", func() {
   154  					Expect(response.StatusCode).To(Equal(http.StatusOK))
   155  				})
   156  			})
   157  
   158  			Context("the nested handler returns unauthorized", func() {
   159  				BeforeEach(func() {
   160  					fakeHandler.ServeHTTPStub = func(w http.ResponseWriter, r *http.Request) {
   161  						w.WriteHeader(http.StatusUnauthorized)
   162  					}
   163  				})
   164  
   165  				It("unsets the auth cookie", func() {
   166  					Expect(fakeMiddleware.UnsetAuthTokenCallCount()).To(Equal(1))
   167  				})
   168  
   169  				It("unsets the csrf cookie", func() {
   170  					Expect(fakeMiddleware.UnsetCSRFTokenCallCount()).To(Equal(1))
   171  				})
   172  			})
   173  		})
   174  	})
   175  
   176  	Describe("CSRF Required", func() {
   177  		var request *http.Request
   178  		var err error
   179  		Context("when CSRF context is set", func() {
   180  			BeforeEach(func() {
   181  				request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello"))
   182  				Expect(err).To(BeNil())
   183  
   184  				ctx := context.WithValue(request.Context(), auth.CSRFRequiredKey, true)
   185  				request = request.WithContext(ctx)
   186  
   187  			})
   188  			It("fetches the bool value", func() {
   189  				Expect(auth.IsCSRFRequired(request)).To(BeTrue())
   190  			})
   191  		})
   192  
   193  		Context("when CSRF context is not set", func() {
   194  			BeforeEach(func() {
   195  				request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello"))
   196  				Expect(err).To(BeNil())
   197  			})
   198  			It("fetches the bool value", func() {
   199  				Expect(auth.IsCSRFRequired(request)).To(BeFalse())
   200  			})
   201  
   202  		})
   203  	})
   204  })