github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/skymarshal/skyserver/skyserver_test.go (about) 1 package skyserver_test 2 3 import ( 4 "encoding/base64" 5 "encoding/json" 6 "errors" 7 "io/ioutil" 8 "net/http" 9 "net/url" 10 "time" 11 12 . "github.com/onsi/ginkgo" 13 . "github.com/onsi/gomega" 14 15 "github.com/onsi/gomega/ghttp" 16 ) 17 18 var _ = Describe("Sky Server API", func() { 19 20 ExpectServerBehaviour := func() { 21 22 Describe("GET /sky/login", func() { 23 var ( 24 err error 25 request *http.Request 26 response *http.Response 27 ) 28 29 BeforeEach(func() { 30 request, err = http.NewRequest("GET", skyServer.URL+"/sky/login", nil) 31 Expect(err).NotTo(HaveOccurred()) 32 }) 33 34 JustBeforeEach(func() { 35 skyServer.Client().CheckRedirect = func(req *http.Request, via []*http.Request) error { 36 return http.ErrUseLastResponse 37 } 38 39 response, err = skyServer.Client().Do(request) 40 Expect(err).NotTo(HaveOccurred()) 41 }) 42 43 ExpectNewLogin := func() { 44 45 It("stores a state cookie", func() { 46 Expect(fakeTokenMiddleware.SetStateTokenCallCount()).To(Equal(1)) 47 _, state, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0) 48 Expect(state).NotTo(BeEmpty()) 49 }) 50 51 It("redirects the initial request to the oauthConfig.AuthURL", func() { 52 _, state, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0) 53 54 redirectURL, err := response.Location() 55 Expect(err).NotTo(HaveOccurred()) 56 Expect(redirectURL.Path).To(Equal("/auth")) 57 58 redirectValues := redirectURL.Query() 59 Expect(redirectValues.Get("access_type")).To(Equal("offline")) 60 Expect(redirectValues.Get("response_type")).To(Equal("code")) 61 Expect(redirectValues.Get("state")).To(Equal(state)) 62 Expect(redirectValues.Get("scope")).To(Equal("some-scope")) 63 }) 64 65 Context("when redirect_uri is provided", func() { 66 BeforeEach(func() { 67 request.URL.RawQuery = "redirect_uri=/redirect" 68 }) 69 70 It("stores redirect_uri in the state token cookie", func() { 71 _, raw, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0) 72 73 data, err := base64.StdEncoding.DecodeString(raw) 74 Expect(err).NotTo(HaveOccurred()) 75 76 var state map[string]string 77 json.Unmarshal(data, &state) 78 Expect(state["redirect_uri"]).To(Equal("/redirect")) 79 }) 80 }) 81 82 Context("when redirect_uri is NOT provided", func() { 83 BeforeEach(func() { 84 request.URL.RawQuery = "" 85 }) 86 87 It("stores / as the default redirect_uri in the state token cookie", func() { 88 _, raw, _ := fakeTokenMiddleware.SetStateTokenArgsForCall(0) 89 90 data, err := base64.StdEncoding.DecodeString(raw) 91 Expect(err).NotTo(HaveOccurred()) 92 93 var state map[string]string 94 json.Unmarshal(data, &state) 95 Expect(state["redirect_uri"]).To(Equal("/")) 96 }) 97 }) 98 } 99 100 Context("without an existing token", func() { 101 BeforeEach(func() { 102 fakeTokenMiddleware.GetAuthTokenReturns("") 103 }) 104 ExpectNewLogin() 105 }) 106 107 Context("when the token has no type", func() { 108 BeforeEach(func() { 109 fakeTokenMiddleware.GetAuthTokenReturns("some-token") 110 }) 111 ExpectNewLogin() 112 }) 113 114 Context("when the token is not a valid bearer token", func() { 115 BeforeEach(func() { 116 fakeTokenMiddleware.GetAuthTokenReturns("not-bearer some-token") 117 }) 118 ExpectNewLogin() 119 }) 120 121 Context("when parsing the expiry errors", func() { 122 BeforeEach(func() { 123 fakeTokenParser.ParseExpiryReturns(time.Time{}, errors.New("error")) 124 fakeTokenMiddleware.GetAuthTokenReturns("bearer some-token") 125 }) 126 ExpectNewLogin() 127 }) 128 129 Context("when the token is expired", func() { 130 BeforeEach(func() { 131 fakeTokenParser.ParseExpiryReturns(time.Now().Add(-time.Hour), nil) 132 fakeTokenMiddleware.GetAuthTokenReturns("bearer some-token") 133 }) 134 ExpectNewLogin() 135 }) 136 137 Context("when the token is valid", func() { 138 BeforeEach(func() { 139 fakeTokenParser.ParseExpiryReturns(time.Now().Add(time.Hour), nil) 140 fakeTokenMiddleware.GetAuthTokenReturns("bearer some-token") 141 }) 142 143 It("updates the auth token", func() { 144 Expect(fakeTokenMiddleware.SetAuthTokenCallCount()).To(Equal(1)) 145 _, tokenArg, _ := fakeTokenMiddleware.SetAuthTokenArgsForCall(0) 146 Expect(tokenArg).To(Equal("bearer some-token")) 147 }) 148 149 It("updates the csrf token", func() { 150 Expect(fakeTokenMiddleware.SetCSRFTokenCallCount()).To(Equal(1)) 151 _, tokenArg, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0) 152 Expect(tokenArg).NotTo(BeEmpty()) 153 }) 154 155 It("redirects the request to the provided redirect_uri", func() { 156 _, tokenArg, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0) 157 158 redirectURL, err := response.Location() 159 Expect(err).NotTo(HaveOccurred()) 160 161 atcURL, err := url.Parse(skyServer.URL) 162 Expect(err).NotTo(HaveOccurred()) 163 Expect(redirectURL.Host).To(Equal(atcURL.Host)) 164 165 redirectValues := redirectURL.Query() 166 Expect(redirectValues.Get("csrf_token")).To(Equal(tokenArg)) 167 }) 168 }) 169 }) 170 171 Describe("GET /sky/logout", func() { 172 var ( 173 err error 174 request *http.Request 175 response *http.Response 176 ) 177 178 BeforeEach(func() { 179 request, err = http.NewRequest("GET", skyServer.URL+"/sky/logout", nil) 180 Expect(err).NotTo(HaveOccurred()) 181 }) 182 183 JustBeforeEach(func() { 184 response, err = skyServer.Client().Do(request) 185 Expect(err).NotTo(HaveOccurred()) 186 }) 187 188 It("succeeds", func() { 189 Expect(response.StatusCode).To(Equal(http.StatusOK)) 190 }) 191 192 It("removes auth token and csrf token", func() { 193 Expect(fakeTokenMiddleware.UnsetAuthTokenCallCount()).To(Equal(1)) 194 Expect(fakeTokenMiddleware.UnsetCSRFTokenCallCount()).To(Equal(1)) 195 }) 196 }) 197 198 Describe("GET /sky/callback", func() { 199 var ( 200 err error 201 request *http.Request 202 response *http.Response 203 body []byte 204 ) 205 206 BeforeEach(func() { 207 request, err = http.NewRequest("GET", skyServer.URL+"/sky/callback", nil) 208 Expect(err).NotTo(HaveOccurred()) 209 }) 210 211 JustBeforeEach(func() { 212 response, err = skyServer.Client().Do(request) 213 Expect(err).NotTo(HaveOccurred()) 214 215 body, err = ioutil.ReadAll(response.Body) 216 Expect(err).NotTo(HaveOccurred()) 217 }) 218 219 Context("when there's an error param", func() { 220 BeforeEach(func() { 221 request.URL.RawQuery = "error=some-error" 222 }) 223 224 It("errors", func() { 225 Expect(response.StatusCode).To(Equal(http.StatusBadRequest)) 226 }) 227 228 It("shows the error message", func() { 229 Expect(string(body)).To(Equal("some-error\n")) 230 }) 231 }) 232 233 Context("when the state cookie doesn't exist", func() { 234 BeforeEach(func() { 235 fakeTokenMiddleware.GetStateTokenReturns("") 236 }) 237 238 It("errors", func() { 239 Expect(response.StatusCode).To(Equal(http.StatusBadRequest)) 240 }) 241 242 It("shows state cookie invalid message", func() { 243 Expect(string(body)).To(Equal("invalid state token\n")) 244 }) 245 }) 246 247 Context("when the cookie state doesn't match the form state", func() { 248 BeforeEach(func() { 249 fakeTokenMiddleware.GetStateTokenReturns("not-state") 250 request.URL.RawQuery = "state=some-state" 251 }) 252 253 It("errors", func() { 254 Expect(response.StatusCode).To(Equal(http.StatusBadRequest)) 255 }) 256 257 It("shows state cookie unexpected message", func() { 258 Expect(string(body)).To(Equal("unexpected state token\n")) 259 }) 260 }) 261 262 Context("when the cookie state matches the form state", func() { 263 BeforeEach(func() { 264 fakeTokenMiddleware.GetStateTokenReturns("some-state") 265 request.URL.RawQuery = "state=some-state" 266 }) 267 268 Context("when there is an authorization code", func() { 269 BeforeEach(func() { 270 request.URL.RawQuery = "code=some-code&state=some-state" 271 }) 272 273 Context("when requesting a token fails", func() { 274 BeforeEach(func() { 275 dexServer.AppendHandlers( 276 ghttp.CombineHandlers( 277 ghttp.VerifyRequest("POST", "/token"), 278 ghttp.VerifyHeaderKV("Authorization", "Basic ZGV4LWNsaWVudC1pZDpkZXgtY2xpZW50LXNlY3JldA=="), 279 ghttp.VerifyFormKV("grant_type", "authorization_code"), 280 ghttp.VerifyFormKV("code", "some-code"), 281 ghttp.RespondWith(http.StatusInternalServerError, "some-token-error"), 282 ), 283 ) 284 }) 285 286 It("errors", func() { 287 Expect(response.StatusCode).To(Equal(http.StatusInternalServerError)) 288 }) 289 290 It("shows the oauth2 retrieve error response", func() { 291 Expect(string(body)).To(Equal("some-token-error\n")) 292 }) 293 }) 294 295 Context("when requesting a token from dex fails with oauth error (dex 200 with no access_token returned)", func() { 296 BeforeEach(func() { 297 dexServer.AppendHandlers( 298 ghttp.CombineHandlers( 299 ghttp.VerifyRequest("POST", "/token"), 300 ghttp.RespondWithJSONEncoded(http.StatusOK, map[string]string{ 301 "token_type": "some-type", 302 "id_token": "some-id-token", 303 }), 304 ), 305 ) 306 }) 307 308 It("errors", func() { 309 Expect(response.StatusCode).To(Equal(http.StatusBadRequest)) 310 }) 311 312 It("shows oauth error", func() { 313 Expect(string(body)).To(Equal("oauth2: server response missing access_token\n")) 314 }) 315 }) 316 317 Context("when the server returns a token", func() { 318 319 BeforeEach(func() { 320 dexServer.AppendHandlers( 321 ghttp.CombineHandlers( 322 ghttp.VerifyRequest("POST", "/token"), 323 ghttp.VerifyHeaderKV("Authorization", "Basic ZGV4LWNsaWVudC1pZDpkZXgtY2xpZW50LXNlY3JldA=="), 324 ghttp.VerifyFormKV("grant_type", "authorization_code"), 325 ghttp.VerifyFormKV("code", "some-code"), 326 ghttp.RespondWithJSONEncoded(http.StatusOK, map[string]string{ 327 "token_type": "some-type", 328 "access_token": "some-token", 329 "id_token": "some-id-token", 330 }), 331 ), 332 ) 333 }) 334 335 Context("when redirect URI is http://example.com", func() { 336 BeforeEach(func() { 337 state, _ := json.Marshal(map[string]string{ 338 "redirect_uri": "http://example.com", 339 }) 340 341 stateToken := base64.StdEncoding.EncodeToString(state) 342 fakeTokenMiddleware.GetStateTokenReturns(stateToken) 343 344 request.URL.RawQuery = "code=some-code&state=" + stateToken 345 }) 346 347 It("errors", func() { 348 Expect(response.StatusCode).To(Equal(http.StatusBadRequest)) 349 }) 350 }) 351 352 Context("when redirect URI is https:example.com", func() { 353 BeforeEach(func() { 354 state, _ := json.Marshal(map[string]string{ 355 "redirect_uri": "https:google.com", 356 }) 357 358 stateToken := base64.StdEncoding.EncodeToString(state) 359 fakeTokenMiddleware.GetStateTokenReturns(stateToken) 360 361 request.URL.RawQuery = "code=some-code&state=" + stateToken 362 }) 363 364 It("doesn't error on Get https:google.com", func() { 365 Expect(response.StatusCode).To(Equal(http.StatusNotFound)) 366 }) 367 }) 368 369 Context("when redirect URI is example.com", func() { 370 BeforeEach(func() { 371 state, _ := json.Marshal(map[string]string{ 372 "redirect_uri": "example.com", 373 }) 374 375 stateToken := base64.StdEncoding.EncodeToString(state) 376 fakeTokenMiddleware.GetStateTokenReturns(stateToken) 377 378 request.URL.RawQuery = "code=some-code&state=" + stateToken 379 }) 380 381 It("errors", func() { 382 Expect(response.StatusCode).To(Equal(http.StatusBadRequest)) 383 }) 384 }) 385 386 Context("when redirecting to the ATC", func() { 387 BeforeEach(func() { 388 state, _ := json.Marshal(map[string]string{ 389 "redirect_uri": "/valid-redirect", 390 }) 391 392 stateToken := base64.StdEncoding.EncodeToString(state) 393 fakeTokenMiddleware.GetStateTokenReturns(stateToken) 394 395 request.URL.RawQuery = "code=some-code&state=" + stateToken 396 }) 397 398 Context("when setting the auth token fails", func() { 399 BeforeEach(func() { 400 fakeTokenMiddleware.SetAuthTokenReturns(errors.New("nope")) 401 }) 402 It("errors", func() { 403 Expect(response.StatusCode).To(Equal(http.StatusInternalServerError)) 404 }) 405 }) 406 407 Context("when setting the auth token succeeds", func() { 408 BeforeEach(func() { 409 fakeTokenMiddleware.SetAuthTokenReturns(nil) 410 }) 411 412 Context("when setting the csrf token fails", func() { 413 BeforeEach(func() { 414 fakeTokenMiddleware.SetCSRFTokenReturns(errors.New("nope")) 415 }) 416 It("errors", func() { 417 Expect(response.StatusCode).To(Equal(http.StatusInternalServerError)) 418 }) 419 }) 420 421 Context("when setting the csrf token succeeds", func() { 422 BeforeEach(func() { 423 fakeTokenMiddleware.SetCSRFTokenReturns(nil) 424 }) 425 426 It("unsets the cookie state", func() { 427 Expect(fakeTokenMiddleware.UnsetStateTokenCallCount()).To(Equal(1)) 428 }) 429 430 It("saves the access token from the response", func() { 431 Expect(fakeTokenMiddleware.SetAuthTokenCallCount()).To(Equal(1)) 432 _, tokenString, _ := fakeTokenMiddleware.SetAuthTokenArgsForCall(0) 433 Expect(tokenString).To(Equal("some-type some-token")) 434 }) 435 436 It("sets a new csrf token", func() { 437 Expect(fakeTokenMiddleware.SetCSRFTokenCallCount()).To(Equal(1)) 438 _, tokenString, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0) 439 Expect(tokenString).NotTo(BeEmpty()) 440 }) 441 442 It("redirects to redirect_uri from state token with the csrf_token", func() { 443 _, tokenArg, _ := fakeTokenMiddleware.SetCSRFTokenArgsForCall(0) 444 445 redirectResponse := response.Request.Response 446 Expect(redirectResponse).NotTo(BeNil()) 447 Expect(redirectResponse.StatusCode).To(Equal(http.StatusTemporaryRedirect)) 448 449 skyServerURL, err := url.Parse(skyServer.URL) 450 Expect(err).NotTo(HaveOccurred()) 451 452 locationURL, err := redirectResponse.Location() 453 Expect(err).NotTo(HaveOccurred()) 454 Expect(locationURL.Host).To(Equal(skyServerURL.Host)) 455 Expect(locationURL.Path).To(Equal("/valid-redirect")) 456 Expect(locationURL.Query().Get("csrf_token")).To(Equal(tokenArg)) 457 }) 458 }) 459 }) 460 }) 461 }) 462 }) 463 }) 464 }) 465 } 466 467 Describe("With TLS Server", func() { 468 BeforeEach(func() { 469 skyServer.StartTLS() 470 }) 471 472 ExpectServerBehaviour() 473 }) 474 475 Describe("Without TLS Server", func() { 476 BeforeEach(func() { 477 skyServer.Start() 478 }) 479 480 ExpectServerBehaviour() 481 }) 482 })