github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/skymarshal/skyserver/skyserver_test.go (about)

     1  package skyserver_test
     2  
     3  import (
     4  	"encoding/base64"
     5  	"encoding/json"
     6  	"errors"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/url"
    10  	"time"
    11  
    12  	. "github.com/onsi/ginkgo"
    13  	. "github.com/onsi/gomega"
    14  
    15  	"github.com/onsi/gomega/ghttp"
    16  )
    17  
    18  var _ = Describe("Sky Server API", func() {
    19  
    20  	ExpectServerBehaviour := func() {
    21  
    22  		Describe("GET /sky/login", func() {
    23  			var (
    24  				err      error
    25  				request  *http.Request
    26  				response *http.Response
    27  			)
    28  
    29  			BeforeEach(func() {
    30  				request, err = http.NewRequest("GET", skyServer.URL+"/sky/login", nil)
    31  				Expect(err).NotTo(HaveOccurred())
    32  			})
    33  
    34  			JustBeforeEach(func() {
    35  				skyServer.Client().CheckRedirect = func(req *http.Request, via []*http.Request) error {
    36  					return http.ErrUseLastResponse
    37  				}
    38  
    39  				response, err = skyServer.Client().Do(request)
    40  				Expect(err).NotTo(HaveOccurred())
    41  			})
    42  
    43  			ExpectNewLogin := func() {
    44  
    45  				It("stores a state cookie", func() {
    46  					Expect(fakeTokenMiddleware.SetStateTokenCallCount()).To(Equal(1))
    47  					_, state, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0)
    48  					Expect(state).NotTo(BeEmpty())
    49  				})
    50  
    51  				It("redirects the initial request to the oauthConfig.AuthURL", func() {
    52  					_, state, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0)
    53  
    54  					redirectURL, err := response.Location()
    55  					Expect(err).NotTo(HaveOccurred())
    56  					Expect(redirectURL.Path).To(Equal("/auth"))
    57  
    58  					redirectValues := redirectURL.Query()
    59  					Expect(redirectValues.Get("access_type")).To(Equal("offline"))
    60  					Expect(redirectValues.Get("response_type")).To(Equal("code"))
    61  					Expect(redirectValues.Get("state")).To(Equal(state))
    62  					Expect(redirectValues.Get("scope")).To(Equal("some-scope"))
    63  				})
    64  
    65  				Context("when redirect_uri is provided", func() {
    66  					BeforeEach(func() {
    67  						request.URL.RawQuery = "redirect_uri=/redirect"
    68  					})
    69  
    70  					It("stores redirect_uri in the state token cookie", func() {
    71  						_, raw, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0)
    72  
    73  						data, err := base64.StdEncoding.DecodeString(raw)
    74  						Expect(err).NotTo(HaveOccurred())
    75  
    76  						var state map[string]string
    77  						json.Unmarshal(data, &state)
    78  						Expect(state["redirect_uri"]).To(Equal("/redirect"))
    79  					})
    80  				})
    81  
    82  				Context("when redirect_uri is NOT provided", func() {
    83  					BeforeEach(func() {
    84  						request.URL.RawQuery = ""
    85  					})
    86  
    87  					It("stores / as the default redirect_uri in the state token cookie", func() {
    88  						_, raw, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0)
    89  
    90  						data, err := base64.StdEncoding.DecodeString(raw)
    91  						Expect(err).NotTo(HaveOccurred())
    92  
    93  						var state map[string]string
    94  						json.Unmarshal(data, &state)
    95  						Expect(state["redirect_uri"]).To(Equal("/"))
    96  					})
    97  				})
    98  			}
    99  
   100  			Context("without an existing token", func() {
   101  				BeforeEach(func() {
   102  					fakeTokenMiddleware.GetAuthTokenReturns("")
   103  				})
   104  				ExpectNewLogin()
   105  			})
   106  
   107  			Context("when the token has no type", func() {
   108  				BeforeEach(func() {
   109  					fakeTokenMiddleware.GetAuthTokenReturns("some-token")
   110  				})
   111  				ExpectNewLogin()
   112  			})
   113  
   114  			Context("when the token is not a valid bearer token", func() {
   115  				BeforeEach(func() {
   116  					fakeTokenMiddleware.GetAuthTokenReturns("not-bearer some-token")
   117  				})
   118  				ExpectNewLogin()
   119  			})
   120  
   121  			Context("when parsing the expiry errors", func() {
   122  				BeforeEach(func() {
   123  					fakeTokenParser.ParseExpiryReturns(time.Time{}, errors.New("error"))
   124  					fakeTokenMiddleware.GetAuthTokenReturns("bearer some-token")
   125  				})
   126  				ExpectNewLogin()
   127  			})
   128  
   129  			Context("when the token is expired", func() {
   130  				BeforeEach(func() {
   131  					fakeTokenParser.ParseExpiryReturns(time.Now().Add(-time.Hour), nil)
   132  					fakeTokenMiddleware.GetAuthTokenReturns("bearer some-token")
   133  				})
   134  				ExpectNewLogin()
   135  			})
   136  
   137  			Context("when the token is valid", func() {
   138  				BeforeEach(func() {
   139  					fakeTokenParser.ParseExpiryReturns(time.Now().Add(time.Hour), nil)
   140  					fakeTokenMiddleware.GetAuthTokenReturns("bearer some-token")
   141  				})
   142  
   143  				It("updates the auth token", func() {
   144  					Expect(fakeTokenMiddleware.SetAuthTokenCallCount()).To(Equal(1))
   145  					_, tokenArg, _ := fakeTokenMiddleware.SetAuthTokenArgsForCall(0)
   146  					Expect(tokenArg).To(Equal("bearer some-token"))
   147  				})
   148  
   149  				It("updates the csrf token", func() {
   150  					Expect(fakeTokenMiddleware.SetCSRFTokenCallCount()).To(Equal(1))
   151  					_, tokenArg, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0)
   152  					Expect(tokenArg).NotTo(BeEmpty())
   153  				})
   154  
   155  				It("redirects the request to the provided redirect_uri", func() {
   156  					_, tokenArg, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0)
   157  
   158  					redirectURL, err := response.Location()
   159  					Expect(err).NotTo(HaveOccurred())
   160  
   161  					atcURL, err := url.Parse(skyServer.URL)
   162  					Expect(err).NotTo(HaveOccurred())
   163  					Expect(redirectURL.Host).To(Equal(atcURL.Host))
   164  
   165  					redirectValues := redirectURL.Query()
   166  					Expect(redirectValues.Get("csrf_token")).To(Equal(tokenArg))
   167  				})
   168  			})
   169  		})
   170  
   171  		Describe("GET /sky/logout", func() {
   172  			var (
   173  				err      error
   174  				request  *http.Request
   175  				response *http.Response
   176  			)
   177  
   178  			BeforeEach(func() {
   179  				request, err = http.NewRequest("GET", skyServer.URL+"/sky/logout", nil)
   180  				Expect(err).NotTo(HaveOccurred())
   181  			})
   182  
   183  			JustBeforeEach(func() {
   184  				response, err = skyServer.Client().Do(request)
   185  				Expect(err).NotTo(HaveOccurred())
   186  			})
   187  
   188  			It("succeeds", func() {
   189  				Expect(response.StatusCode).To(Equal(http.StatusOK))
   190  			})
   191  
   192  			It("removes auth token and csrf token", func() {
   193  				Expect(fakeTokenMiddleware.UnsetAuthTokenCallCount()).To(Equal(1))
   194  				Expect(fakeTokenMiddleware.UnsetCSRFTokenCallCount()).To(Equal(1))
   195  			})
   196  		})
   197  
   198  		Describe("GET /sky/callback", func() {
   199  			var (
   200  				err      error
   201  				request  *http.Request
   202  				response *http.Response
   203  				body     []byte
   204  			)
   205  
   206  			BeforeEach(func() {
   207  				request, err = http.NewRequest("GET", skyServer.URL+"/sky/callback", nil)
   208  				Expect(err).NotTo(HaveOccurred())
   209  			})
   210  
   211  			JustBeforeEach(func() {
   212  				response, err = skyServer.Client().Do(request)
   213  				Expect(err).NotTo(HaveOccurred())
   214  
   215  				body, err = ioutil.ReadAll(response.Body)
   216  				Expect(err).NotTo(HaveOccurred())
   217  			})
   218  
   219  			Context("when there's an error param", func() {
   220  				BeforeEach(func() {
   221  					request.URL.RawQuery = "error=some-error"
   222  				})
   223  
   224  				It("errors", func() {
   225  					Expect(response.StatusCode).To(Equal(http.StatusBadRequest))
   226  				})
   227  
   228  				It("shows the error message", func() {
   229  					Expect(string(body)).To(Equal("some-error\n"))
   230  				})
   231  			})
   232  
   233  			Context("when the state cookie doesn't exist", func() {
   234  				BeforeEach(func() {
   235  					fakeTokenMiddleware.GetStateTokenReturns("")
   236  				})
   237  
   238  				It("errors", func() {
   239  					Expect(response.StatusCode).To(Equal(http.StatusBadRequest))
   240  				})
   241  
   242  				It("shows state cookie invalid message", func() {
   243  					Expect(string(body)).To(Equal("invalid state token\n"))
   244  				})
   245  			})
   246  
   247  			Context("when the cookie state doesn't match the form state", func() {
   248  				BeforeEach(func() {
   249  					fakeTokenMiddleware.GetStateTokenReturns("not-state")
   250  					request.URL.RawQuery = "state=some-state"
   251  				})
   252  
   253  				It("errors", func() {
   254  					Expect(response.StatusCode).To(Equal(http.StatusBadRequest))
   255  				})
   256  
   257  				It("shows state cookie unexpected message", func() {
   258  					Expect(string(body)).To(Equal("unexpected state token\n"))
   259  				})
   260  			})
   261  
   262  			Context("when the cookie state matches the form state", func() {
   263  				BeforeEach(func() {
   264  					fakeTokenMiddleware.GetStateTokenReturns("some-state")
   265  					request.URL.RawQuery = "state=some-state"
   266  				})
   267  
   268  				Context("when there is an authorization code", func() {
   269  					BeforeEach(func() {
   270  						request.URL.RawQuery = "code=some-code&state=some-state"
   271  					})
   272  
   273  					Context("when requesting a token fails", func() {
   274  						BeforeEach(func() {
   275  							dexServer.AppendHandlers(
   276  								ghttp.CombineHandlers(
   277  									ghttp.VerifyRequest("POST", "/token"),
   278  									ghttp.VerifyHeaderKV("Authorization", "Basic ZGV4LWNsaWVudC1pZDpkZXgtY2xpZW50LXNlY3JldA=="),
   279  									ghttp.VerifyFormKV("grant_type", "authorization_code"),
   280  									ghttp.VerifyFormKV("code", "some-code"),
   281  									ghttp.RespondWith(http.StatusInternalServerError, "some-token-error"),
   282  								),
   283  							)
   284  						})
   285  
   286  						It("errors", func() {
   287  							Expect(response.StatusCode).To(Equal(http.StatusInternalServerError))
   288  						})
   289  
   290  						It("shows the oauth2 retrieve error response", func() {
   291  							Expect(string(body)).To(Equal("some-token-error\n"))
   292  						})
   293  					})
   294  
   295  					Context("when requesting a token from dex fails with oauth error (dex 200 with no access_token returned)", func() {
   296  						BeforeEach(func() {
   297  							dexServer.AppendHandlers(
   298  								ghttp.CombineHandlers(
   299  									ghttp.VerifyRequest("POST", "/token"),
   300  									ghttp.RespondWithJSONEncoded(http.StatusOK, map[string]string{
   301  										"token_type": "some-type",
   302  										"id_token":   "some-id-token",
   303  									}),
   304  								),
   305  							)
   306  						})
   307  
   308  						It("errors", func() {
   309  							Expect(response.StatusCode).To(Equal(http.StatusBadRequest))
   310  						})
   311  
   312  						It("shows oauth error", func() {
   313  							Expect(string(body)).To(Equal("oauth2: server response missing access_token\n"))
   314  						})
   315  					})
   316  
   317  					Context("when the server returns a token", func() {
   318  
   319  						BeforeEach(func() {
   320  							dexServer.AppendHandlers(
   321  								ghttp.CombineHandlers(
   322  									ghttp.VerifyRequest("POST", "/token"),
   323  									ghttp.VerifyHeaderKV("Authorization", "Basic ZGV4LWNsaWVudC1pZDpkZXgtY2xpZW50LXNlY3JldA=="),
   324  									ghttp.VerifyFormKV("grant_type", "authorization_code"),
   325  									ghttp.VerifyFormKV("code", "some-code"),
   326  									ghttp.RespondWithJSONEncoded(http.StatusOK, map[string]string{
   327  										"token_type":   "some-type",
   328  										"access_token": "some-token",
   329  										"id_token":     "some-id-token",
   330  									}),
   331  								),
   332  							)
   333  						})
   334  
   335  						Context("when redirect URI is http://example.com", func() {
   336  							BeforeEach(func() {
   337  								state, _ := json.Marshal(map[string]string{
   338  									"redirect_uri": "http://example.com",
   339  								})
   340  
   341  								stateToken := base64.StdEncoding.EncodeToString(state)
   342  								fakeTokenMiddleware.GetStateTokenReturns(stateToken)
   343  
   344  								request.URL.RawQuery = "code=some-code&state=" + stateToken
   345  							})
   346  
   347  							It("errors", func() {
   348  								Expect(response.StatusCode).To(Equal(http.StatusBadRequest))
   349  							})
   350  						})
   351  
   352  						Context("when redirect URI is https:example.com", func() {
   353  							BeforeEach(func() {
   354  								state, _ := json.Marshal(map[string]string{
   355  									"redirect_uri": "https:google.com",
   356  								})
   357  
   358  								stateToken := base64.StdEncoding.EncodeToString(state)
   359  								fakeTokenMiddleware.GetStateTokenReturns(stateToken)
   360  
   361  								request.URL.RawQuery = "code=some-code&state=" + stateToken
   362  							})
   363  
   364  							It("doesn't error on Get https:google.com", func() {
   365  								Expect(response.StatusCode).To(Equal(http.StatusNotFound))
   366  							})
   367  						})
   368  
   369  						Context("when redirect URI is example.com", func() {
   370  							BeforeEach(func() {
   371  								state, _ := json.Marshal(map[string]string{
   372  									"redirect_uri": "example.com",
   373  								})
   374  
   375  								stateToken := base64.StdEncoding.EncodeToString(state)
   376  								fakeTokenMiddleware.GetStateTokenReturns(stateToken)
   377  
   378  								request.URL.RawQuery = "code=some-code&state=" + stateToken
   379  							})
   380  
   381  							It("errors", func() {
   382  								Expect(response.StatusCode).To(Equal(http.StatusBadRequest))
   383  							})
   384  						})
   385  
   386  						Context("when redirecting to the ATC", func() {
   387  							BeforeEach(func() {
   388  								state, _ := json.Marshal(map[string]string{
   389  									"redirect_uri": "/valid-redirect",
   390  								})
   391  
   392  								stateToken := base64.StdEncoding.EncodeToString(state)
   393  								fakeTokenMiddleware.GetStateTokenReturns(stateToken)
   394  
   395  								request.URL.RawQuery = "code=some-code&state=" + stateToken
   396  							})
   397  
   398  							Context("when setting the auth token fails", func() {
   399  								BeforeEach(func() {
   400  									fakeTokenMiddleware.SetAuthTokenReturns(errors.New("nope"))
   401  								})
   402  								It("errors", func() {
   403  									Expect(response.StatusCode).To(Equal(http.StatusInternalServerError))
   404  								})
   405  							})
   406  
   407  							Context("when setting the auth token succeeds", func() {
   408  								BeforeEach(func() {
   409  									fakeTokenMiddleware.SetAuthTokenReturns(nil)
   410  								})
   411  
   412  								Context("when setting the csrf token fails", func() {
   413  									BeforeEach(func() {
   414  										fakeTokenMiddleware.SetCSRFTokenReturns(errors.New("nope"))
   415  									})
   416  									It("errors", func() {
   417  										Expect(response.StatusCode).To(Equal(http.StatusInternalServerError))
   418  									})
   419  								})
   420  
   421  								Context("when setting the csrf token succeeds", func() {
   422  									BeforeEach(func() {
   423  										fakeTokenMiddleware.SetCSRFTokenReturns(nil)
   424  									})
   425  
   426  									It("unsets the cookie state", func() {
   427  										Expect(fakeTokenMiddleware.UnsetStateTokenCallCount()).To(Equal(1))
   428  									})
   429  
   430  									It("saves the access token from the response", func() {
   431  										Expect(fakeTokenMiddleware.SetAuthTokenCallCount()).To(Equal(1))
   432  										_, tokenString, _ := fakeTokenMiddleware.SetAuthTokenArgsForCall(0)
   433  										Expect(tokenString).To(Equal("some-type some-token"))
   434  									})
   435  
   436  									It("sets a new csrf token", func() {
   437  										Expect(fakeTokenMiddleware.SetCSRFTokenCallCount()).To(Equal(1))
   438  										_, tokenString, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0)
   439  										Expect(tokenString).NotTo(BeEmpty())
   440  									})
   441  
   442  									It("redirects to redirect_uri from state token with the csrf_token", func() {
   443  										_, tokenArg, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0)
   444  
   445  										redirectResponse := response.Request.Response
   446  										Expect(redirectResponse).NotTo(BeNil())
   447  										Expect(redirectResponse.StatusCode).To(Equal(http.StatusTemporaryRedirect))
   448  
   449  										skyServerURL, err := url.Parse(skyServer.URL)
   450  										Expect(err).NotTo(HaveOccurred())
   451  
   452  										locationURL, err := redirectResponse.Location()
   453  										Expect(err).NotTo(HaveOccurred())
   454  										Expect(locationURL.Host).To(Equal(skyServerURL.Host))
   455  										Expect(locationURL.Path).To(Equal("/valid-redirect"))
   456  										Expect(locationURL.Query().Get("csrf_token")).To(Equal(tokenArg))
   457  									})
   458  								})
   459  							})
   460  						})
   461  					})
   462  				})
   463  			})
   464  		})
   465  	}
   466  
   467  	Describe("With TLS Server", func() {
   468  		BeforeEach(func() {
   469  			skyServer.StartTLS()
   470  		})
   471  
   472  		ExpectServerBehaviour()
   473  	})
   474  
   475  	Describe("Without TLS Server", func() {
   476  		BeforeEach(func() {
   477  			skyServer.Start()
   478  		})
   479  
   480  		ExpectServerBehaviour()
   481  	})
   482  })