github.com/sap/cf-mta-plugin@v2.6.3+incompatible/clients/csrf/csrf_token_manager_test.go (about)

     1  package csrf
     2  
     3  import (
     4  	"net/http"
     5  	"net/url"
     6  
     7  	"github.com/cloudfoundry-incubator/multiapps-cli-plugin/clients/csrf/fakes"
     8  	. "github.com/onsi/ginkgo"
     9  	. "github.com/onsi/gomega"
    10  )
    11  
    12  const testUrl = "http://localhost:1000"
    13  
    14  const csrfTokenNotSet = ""
    15  
    16  var _ = Describe("DefaultCsrfTokenUpdater", func() {
    17  	Context("", func() {
    18  		It("protection not needed", func() {
    19  			transport, request := createTransport(), createRequest(http.MethodGet)
    20  			csrfTokenManager := NewDefaultCsrfTokenManager(transport, request)
    21  			Expect(csrfTokenManager.isProtectionRequired(request, transport)).To(BeFalse())
    22  		})
    23  		It("protection not needed", func() {
    24  			transport, request := createTransport(), createRequest(http.MethodOptions)
    25  			csrfTokenManager := NewDefaultCsrfTokenManager(transport, request)
    26  			Expect(csrfTokenManager.isProtectionRequired(request, transport)).To(BeFalse())
    27  		})
    28  		It("protection not needed", func() {
    29  			transport, request := createTransport(), createRequest(http.MethodHead)
    30  			csrfTokenManager := NewDefaultCsrfTokenManager(transport, request)
    31  			Expect(csrfTokenManager.isProtectionRequired(request, transport)).To(BeFalse())
    32  		})
    33  		It("protection needed", func() {
    34  			transport, request := createTransport(), createRequest(http.MethodPost)
    35  			csrfTokenManager := NewDefaultCsrfTokenManager(transport, request)
    36  			Expect(csrfTokenManager.isProtectionRequired(request, transport)).To(BeTrue())
    37  		})
    38  		It("retry is not needed", func() {
    39  			transport, request := createTransport(), createRequest(http.MethodPost)
    40  			csrfTokenManager := NewDefaultCsrfTokenManager(transport, request)
    41  			Expect(csrfTokenManager.refreshTokenIfNeeded(createResponse(http.StatusOK, ""))).To(BeFalse())
    42  		})
    43  		It("retry is not needed", func() {
    44  			transport, request := createTransport(), createRequest(http.MethodPost)
    45  			csrfTokenManager := NewDefaultCsrfTokenManager(transport, request)
    46  			Expect(csrfTokenManager.refreshTokenIfNeeded(createResponse(http.StatusForbidden, CsrfTokenHeaderRequiredValue))).To(BeFalse())
    47  		})
    48  		It("retry is needed", func() {
    49  			transport := createTransport()
    50  			transport.Csrf.IsInitialized = true
    51  			request := createRequest(http.MethodPost)
    52  			csrfTokenManager := NewDefaultCsrfTokenManagerWithFetcher(transport, request, fakes.NewFakeCsrfTokenFetcher())
    53  			isRetryNeeded, err := csrfTokenManager.refreshTokenIfNeeded(createResponse(http.StatusForbidden, CsrfTokenHeaderRequiredValue))
    54  			Ω(err).ShouldNot(HaveOccurred())
    55  			Expect(isRetryNeeded).To(BeTrue())
    56  		})
    57  		It("initialize new token", func() {
    58  			transport := createTransport()
    59  			transport.Csrf.IsInitialized = true
    60  			request := createRequest(http.MethodPost)
    61  			csrfTokenManager := NewDefaultCsrfTokenManagerWithFetcher(transport, request, fakes.NewFakeCsrfTokenFetcher())
    62  			err := csrfTokenManager.initializeToken(true)
    63  			Ω(err).ShouldNot(HaveOccurred())
    64  			Expect(transport.Csrf.Header).To(Equal(fakes.FakeCsrfTokenHeader))
    65  			Expect(transport.Csrf.Token).To(Equal(fakes.FakeCsrfTokenValue))
    66  			Expect(transport.Csrf.IsInitialized).To(BeTrue())
    67  		})
    68  		It("update current csrf tokens", func() {
    69  			transport := createTransport()
    70  			request := createRequest(http.MethodGet)
    71  			csrfTokenManager := NewDefaultCsrfTokenManagerWithFetcher(transport, request, fakes.NewFakeCsrfTokenFetcher())
    72  			err := csrfTokenManager.initializeToken(true)
    73  			Ω(err).ShouldNot(HaveOccurred())
    74  			csrfTokenManager.updateTokenInRequest()
    75  			expectCsrfTokenIsProperlySet(request, fakes.FakeCsrfTokenHeader, fakes.FakeCsrfTokenValue)
    76  		})
    77  		It("should not update csrf tokens", func() {
    78  			transport, request := createTransport(), createRequest(http.MethodGet)
    79  			csrfTokenManager := NewDefaultCsrfTokenManagerWithFetcher(transport, request, fakes.NewFakeCsrfTokenFetcher())
    80  			err := csrfTokenManager.updateToken()
    81  			Ω(err).ShouldNot(HaveOccurred())
    82  			expectCsrfTokenIsProperlySet(request, csrfTokenNotSet, csrfTokenNotSet)
    83  		})
    84  		It("should not update csrf tokens", func() {
    85  			transport, request := createTransport(), createRequest(http.MethodPost)
    86  			transport.Csrf.IsInitialized = true
    87  			csrfTokenManager := NewDefaultCsrfTokenManagerWithFetcher(transport, request, fakes.NewFakeCsrfTokenFetcher())
    88  			err := csrfTokenManager.updateToken()
    89  			Ω(err).ShouldNot(HaveOccurred())
    90  			expectCsrfTokenIsProperlySet(request, csrfTokenNotSet, csrfTokenNotSet)
    91  		})
    92  		It("should not update csrf tokens", func() {
    93  			transport, request := createTransport(), createRequest(http.MethodGet)
    94  			csrfTokenManager := NewDefaultCsrfTokenManagerWithFetcher(transport, request, fakes.NewFakeCsrfTokenFetcher())
    95  			err := csrfTokenManager.updateToken()
    96  			Ω(err).ShouldNot(HaveOccurred())
    97  			expectCsrfTokenIsProperlySet(request, csrfTokenNotSet, csrfTokenNotSet)
    98  		})
    99  		It("should update csrf tokens", func() {
   100  			transport, request := createTransport(), createRequest(http.MethodPost)
   101  			csrfTokenManager := NewDefaultCsrfTokenManagerWithFetcher(transport, request, fakes.NewFakeCsrfTokenFetcher())
   102  			err := csrfTokenManager.updateToken()
   103  			Ω(err).ShouldNot(HaveOccurred())
   104  			expectCsrfTokenIsProperlySet(request, fakes.FakeCsrfTokenHeader, fakes.FakeCsrfTokenValue)
   105  		})
   106  		Context("set cookies in the request, valid cookies", func() {
   107  			It("should be equal", func() {
   108  				request := createRequest(http.MethodGet)
   109  				cookies := createValidCookies()
   110  				UpdateCookiesIfNeeded(cookies, request)
   111  				Expect(cookies).To(Equal(request.Cookies()))
   112  			})
   113  		})
   114  	})
   115  })
   116  
   117  func expectCsrfTokenIsProperlySet(request *http.Request, csrfTokenHeader, csrfTokenValue string) {
   118  	Expect(request.Header.Get(XCsrfHeader)).To(Equal(csrfTokenHeader))
   119  	Expect(request.Header.Get(XCsrfToken)).To(Equal(csrfTokenValue))
   120  }
   121  
   122  func createResponse(httpStatusCode int, csrfToken string) *http.Response {
   123  	response := &http.Response{}
   124  	response.Header = make(http.Header)
   125  	response.StatusCode = httpStatusCode
   126  	response.Header.Set(XCsrfToken, csrfToken)
   127  
   128  	return response
   129  }
   130  
   131  func createTransport() *Transport {
   132  	return &Transport{http.DefaultTransport.(*http.Transport),
   133  		&Csrf{"", "", false, getNonProtectedMethods()}, &Cookies{[]*http.Cookie{}}}
   134  }
   135  
   136  func getNonProtectedMethods() map[string]bool {
   137  	nonProtectedMethods := make(map[string]bool)
   138  
   139  	nonProtectedMethods[http.MethodGet] = true
   140  	nonProtectedMethods[http.MethodHead] = true
   141  	nonProtectedMethods[http.MethodOptions] = true
   142  
   143  	return nonProtectedMethods
   144  }
   145  
   146  func createValidCookies() []*http.Cookie {
   147  	var cookies []*http.Cookie
   148  	cookie1 := &http.Cookie{}
   149  	cookie1.Name = "JSESSION"
   150  	cookie1.Value = "123"
   151  	cookie2 := &http.Cookie{}
   152  	cookie2.Name = "__V_CAP__"
   153  	cookie2.Value = "321"
   154  	cookies = append(cookies, cookie1)
   155  	cookies = append(cookies, cookie2)
   156  
   157  	return cookies
   158  }
   159  
   160  func createRequest(method string) *http.Request {
   161  	request := &http.Request{}
   162  	requestUrl := &url.URL{}
   163  	requestUrl.Scheme = "http"
   164  	requestUrl.Host = "localhost:1000"
   165  	request.URL = requestUrl
   166  	request.Header = make(http.Header)
   167  	request.Method = method
   168  
   169  	return request
   170  }