github.com/pusher/oauth2_proxy@v3.2.0+incompatible/oauthproxy_test.go (about) 1 package main 2 3 import ( 4 "crypto" 5 "encoding/base64" 6 "io" 7 "io/ioutil" 8 "log" 9 "net" 10 "net/http" 11 "net/http/httptest" 12 "net/url" 13 "regexp" 14 "strings" 15 "testing" 16 "time" 17 18 "github.com/mbland/hmacauth" 19 "github.com/pusher/oauth2_proxy/providers" 20 "github.com/stretchr/testify/assert" 21 "github.com/stretchr/testify/require" 22 "golang.org/x/net/websocket" 23 ) 24 25 func init() { 26 log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) 27 28 } 29 30 type WebSocketOrRestHandler struct { 31 restHandler http.Handler 32 wsHandler http.Handler 33 } 34 35 func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 36 if r.Header.Get("Upgrade") == "websocket" { 37 h.wsHandler.ServeHTTP(w, r) 38 } else { 39 h.restHandler.ServeHTTP(w, r) 40 } 41 } 42 43 func TestWebSocketProxy(t *testing.T) { 44 handler := WebSocketOrRestHandler{ 45 restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 46 w.WriteHeader(200) 47 hostname, _, _ := net.SplitHostPort(r.Host) 48 w.Write([]byte(hostname)) 49 }), 50 wsHandler: websocket.Handler(func(ws *websocket.Conn) { 51 defer ws.Close() 52 var data []byte 53 err := websocket.Message.Receive(ws, &data) 54 if err != nil { 55 t.Fatalf("err %s", err) 56 return 57 } 58 err = websocket.Message.Send(ws, data) 59 if err != nil { 60 t.Fatalf("err %s", err) 61 } 62 return 63 }), 64 } 65 backend := httptest.NewServer(&handler) 66 defer backend.Close() 67 68 backendURL, _ := url.Parse(backend.URL) 69 70 options := NewOptions() 71 var auth hmacauth.HmacAuth 72 options.PassHostHeader = true 73 proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, options, auth) 74 frontend := httptest.NewServer(proxyHandler) 75 defer frontend.Close() 76 77 frontendURL, _ := url.Parse(frontend.URL) 78 frontendWSURL := "ws://" + frontendURL.Host + "/" 79 80 ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/") 81 if err != nil { 82 t.Fatalf("err %s", err) 83 } 84 request := []byte("hello, world!") 85 err = websocket.Message.Send(ws, request) 86 if err != nil { 87 t.Fatalf("err %s", err) 88 } 89 var response = make([]byte, 1024) 90 websocket.Message.Receive(ws, &response) 91 if err != nil { 92 t.Fatalf("err %s", err) 93 } 94 if g, e := string(request), string(response); g != e { 95 t.Errorf("got body %q; expected %q", g, e) 96 } 97 98 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 99 res, _ := http.DefaultClient.Do(getReq) 100 bodyBytes, _ := ioutil.ReadAll(res.Body) 101 backendHostname, _, _ := net.SplitHostPort(backendURL.Host) 102 if g, e := string(bodyBytes), backendHostname; g != e { 103 t.Errorf("got body %q; expected %q", g, e) 104 } 105 } 106 107 func TestNewReverseProxy(t *testing.T) { 108 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 109 w.WriteHeader(200) 110 hostname, _, _ := net.SplitHostPort(r.Host) 111 w.Write([]byte(hostname)) 112 })) 113 defer backend.Close() 114 115 backendURL, _ := url.Parse(backend.URL) 116 backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) 117 backendHost := net.JoinHostPort(backendHostname, backendPort) 118 proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") 119 120 proxyHandler := NewReverseProxy(proxyURL, time.Second) 121 setProxyUpstreamHostHeader(proxyHandler, proxyURL) 122 frontend := httptest.NewServer(proxyHandler) 123 defer frontend.Close() 124 125 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 126 res, _ := http.DefaultClient.Do(getReq) 127 bodyBytes, _ := ioutil.ReadAll(res.Body) 128 if g, e := string(bodyBytes), backendHostname; g != e { 129 t.Errorf("got body %q; expected %q", g, e) 130 } 131 } 132 133 func TestEncodedSlashes(t *testing.T) { 134 var seen string 135 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 136 w.WriteHeader(200) 137 seen = r.RequestURI 138 })) 139 defer backend.Close() 140 141 b, _ := url.Parse(backend.URL) 142 proxyHandler := NewReverseProxy(b, time.Second) 143 setProxyDirector(proxyHandler) 144 frontend := httptest.NewServer(proxyHandler) 145 defer frontend.Close() 146 147 f, _ := url.Parse(frontend.URL) 148 encodedPath := "/a%2Fb/?c=1" 149 getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} 150 _, err := http.DefaultClient.Do(getReq) 151 if err != nil { 152 t.Fatalf("err %s", err) 153 } 154 if seen != encodedPath { 155 t.Errorf("got bad request %q expected %q", seen, encodedPath) 156 } 157 } 158 159 func TestRobotsTxt(t *testing.T) { 160 opts := NewOptions() 161 opts.ClientID = "bazquux" 162 opts.ClientSecret = "foobar" 163 opts.CookieSecret = "xyzzyplugh" 164 opts.Validate() 165 166 proxy := NewOAuthProxy(opts, func(string) bool { return true }) 167 rw := httptest.NewRecorder() 168 req, _ := http.NewRequest("GET", "/robots.txt", nil) 169 proxy.ServeHTTP(rw, req) 170 assert.Equal(t, 200, rw.Code) 171 assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) 172 } 173 174 func TestIsValidRedirect(t *testing.T) { 175 opts := NewOptions() 176 opts.ClientID = "bazquux" 177 opts.ClientSecret = "foobar" 178 opts.CookieSecret = "xyzzyplugh" 179 // Should match domains that are exactly foo.bar and any subdomain of bar.foo 180 opts.WhitelistDomains = []string{"foo.bar", ".bar.foo"} 181 opts.Validate() 182 183 proxy := NewOAuthProxy(opts, func(string) bool { return true }) 184 185 noRD := proxy.IsValidRedirect("") 186 assert.Equal(t, false, noRD) 187 188 singleSlash := proxy.IsValidRedirect("/redirect") 189 assert.Equal(t, true, singleSlash) 190 191 doubleSlash := proxy.IsValidRedirect("//redirect") 192 assert.Equal(t, false, doubleSlash) 193 194 validHTTP := proxy.IsValidRedirect("http://foo.bar/redirect") 195 assert.Equal(t, true, validHTTP) 196 197 validHTTPS := proxy.IsValidRedirect("https://foo.bar/redirect") 198 assert.Equal(t, true, validHTTPS) 199 200 invalidHTTPSubdomain := proxy.IsValidRedirect("http://baz.foo.bar/redirect") 201 assert.Equal(t, false, invalidHTTPSubdomain) 202 203 invalidHTTPSSubdomain := proxy.IsValidRedirect("https://baz.foo.bar/redirect") 204 assert.Equal(t, false, invalidHTTPSSubdomain) 205 206 validHTTPSubdomain := proxy.IsValidRedirect("http://baz.bar.foo/redirect") 207 assert.Equal(t, true, validHTTPSubdomain) 208 209 validHTTPSSubdomain := proxy.IsValidRedirect("https://baz.bar.foo/redirect") 210 assert.Equal(t, true, validHTTPSSubdomain) 211 212 invalidHTTP1 := proxy.IsValidRedirect("http://foo.bar.evil.corp/redirect") 213 assert.Equal(t, false, invalidHTTP1) 214 215 invalidHTTPS1 := proxy.IsValidRedirect("https://foo.bar.evil.corp/redirect") 216 assert.Equal(t, false, invalidHTTPS1) 217 218 invalidHTTP2 := proxy.IsValidRedirect("http://evil.corp/redirect?rd=foo.bar") 219 assert.Equal(t, false, invalidHTTP2) 220 221 invalidHTTPS2 := proxy.IsValidRedirect("https://evil.corp/redirect?rd=foo.bar") 222 assert.Equal(t, false, invalidHTTPS2) 223 } 224 225 type TestProvider struct { 226 *providers.ProviderData 227 EmailAddress string 228 ValidToken bool 229 } 230 231 func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { 232 return &TestProvider{ 233 ProviderData: &providers.ProviderData{ 234 ProviderName: "Test Provider", 235 LoginURL: &url.URL{ 236 Scheme: "http", 237 Host: providerURL.Host, 238 Path: "/oauth/authorize", 239 }, 240 RedeemURL: &url.URL{ 241 Scheme: "http", 242 Host: providerURL.Host, 243 Path: "/oauth/token", 244 }, 245 ProfileURL: &url.URL{ 246 Scheme: "http", 247 Host: providerURL.Host, 248 Path: "/api/v1/profile", 249 }, 250 Scope: "profile.email", 251 }, 252 EmailAddress: emailAddress, 253 } 254 } 255 256 func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { 257 return tp.EmailAddress, nil 258 } 259 260 func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { 261 return tp.ValidToken 262 } 263 264 func TestBasicAuthPassword(t *testing.T) { 265 providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 266 log.Printf("%#v", r) 267 var payload string 268 switch r.URL.Path { 269 case "/oauth/token": 270 payload = `{"access_token": "my_auth_token"}` 271 default: 272 payload = r.Header.Get("Authorization") 273 if payload == "" { 274 payload = "No Authorization header found." 275 } 276 } 277 w.WriteHeader(200) 278 w.Write([]byte(payload)) 279 })) 280 opts := NewOptions() 281 opts.Upstreams = append(opts.Upstreams, providerServer.URL) 282 // The CookieSecret must be 32 bytes in order to create the AES 283 // cipher. 284 opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" 285 opts.ClientID = "bazquux" 286 opts.ClientSecret = "foobar" 287 opts.CookieSecure = false 288 opts.PassBasicAuth = true 289 opts.PassUserHeaders = true 290 opts.BasicAuthPassword = "This is a secure password" 291 opts.Validate() 292 293 providerURL, _ := url.Parse(providerServer.URL) 294 const emailAddress = "michael.bland@gsa.gov" 295 const username = "michael.bland" 296 297 opts.provider = NewTestProvider(providerURL, emailAddress) 298 proxy := NewOAuthProxy(opts, func(email string) bool { 299 return email == emailAddress 300 }) 301 302 rw := httptest.NewRecorder() 303 req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", 304 strings.NewReader("")) 305 req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now())) 306 proxy.ServeHTTP(rw, req) 307 if rw.Code >= 400 { 308 t.Fatalf("expected 3xx got %d", rw.Code) 309 } 310 cookie := rw.HeaderMap["Set-Cookie"][1] 311 312 cookieName := proxy.CookieName 313 var value string 314 keyPrefix := cookieName + "=" 315 316 for _, field := range strings.Split(cookie, "; ") { 317 value = strings.TrimPrefix(field, keyPrefix) 318 if value != field { 319 break 320 } else { 321 value = "" 322 } 323 } 324 325 req, _ = http.NewRequest("GET", "/", strings.NewReader("")) 326 req.AddCookie(&http.Cookie{ 327 Name: cookieName, 328 Value: value, 329 Path: "/", 330 Expires: time.Now().Add(time.Duration(24)), 331 HttpOnly: true, 332 }) 333 req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now())) 334 335 rw = httptest.NewRecorder() 336 proxy.ServeHTTP(rw, req) 337 338 expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword)) 339 assert.Equal(t, expectedHeader, rw.Body.String()) 340 providerServer.Close() 341 } 342 343 type PassAccessTokenTest struct { 344 providerServer *httptest.Server 345 proxy *OAuthProxy 346 opts *Options 347 } 348 349 type PassAccessTokenTestOptions struct { 350 PassAccessToken bool 351 } 352 353 func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { 354 t := &PassAccessTokenTest{} 355 356 t.providerServer = httptest.NewServer( 357 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 358 log.Printf("%#v", r) 359 var payload string 360 switch r.URL.Path { 361 case "/oauth/token": 362 payload = `{"access_token": "my_auth_token"}` 363 default: 364 payload = r.Header.Get("X-Forwarded-Access-Token") 365 if payload == "" { 366 payload = "No access token found." 367 } 368 } 369 w.WriteHeader(200) 370 w.Write([]byte(payload)) 371 })) 372 373 t.opts = NewOptions() 374 t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) 375 // The CookieSecret must be 32 bytes in order to create the AES 376 // cipher. 377 t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" 378 t.opts.ClientID = "bazquux" 379 t.opts.ClientSecret = "foobar" 380 t.opts.CookieSecure = false 381 t.opts.PassAccessToken = opts.PassAccessToken 382 t.opts.Validate() 383 384 providerURL, _ := url.Parse(t.providerServer.URL) 385 const emailAddress = "michael.bland@gsa.gov" 386 387 t.opts.provider = NewTestProvider(providerURL, emailAddress) 388 t.proxy = NewOAuthProxy(t.opts, func(email string) bool { 389 return email == emailAddress 390 }) 391 return t 392 } 393 394 func (patTest *PassAccessTokenTest) Close() { 395 patTest.providerServer.Close() 396 } 397 398 func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, 399 cookie string) { 400 rw := httptest.NewRecorder() 401 req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", 402 strings.NewReader("")) 403 if err != nil { 404 return 0, "" 405 } 406 req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) 407 patTest.proxy.ServeHTTP(rw, req) 408 return rw.Code, rw.HeaderMap["Set-Cookie"][1] 409 } 410 411 func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) { 412 cookieName := patTest.proxy.CookieName 413 var value string 414 keyPrefix := cookieName + "=" 415 416 for _, field := range strings.Split(cookie, "; ") { 417 value = strings.TrimPrefix(field, keyPrefix) 418 if value != field { 419 break 420 } else { 421 value = "" 422 } 423 } 424 if value == "" { 425 return 0, "" 426 } 427 428 req, err := http.NewRequest("GET", "/", strings.NewReader("")) 429 if err != nil { 430 return 0, "" 431 } 432 req.AddCookie(&http.Cookie{ 433 Name: cookieName, 434 Value: value, 435 Path: "/", 436 Expires: time.Now().Add(time.Duration(24)), 437 HttpOnly: true, 438 }) 439 440 rw := httptest.NewRecorder() 441 patTest.proxy.ServeHTTP(rw, req) 442 return rw.Code, rw.Body.String() 443 } 444 445 func TestForwardAccessTokenUpstream(t *testing.T) { 446 patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ 447 PassAccessToken: true, 448 }) 449 defer patTest.Close() 450 451 // A successful validation will redirect and set the auth cookie. 452 code, cookie := patTest.getCallbackEndpoint() 453 if code != 302 { 454 t.Fatalf("expected 302; got %d", code) 455 } 456 assert.NotEqual(t, nil, cookie) 457 458 // Now we make a regular request; the access_token from the cookie is 459 // forwarded as the "X-Forwarded-Access-Token" header. The token is 460 // read by the test provider server and written in the response body. 461 code, payload := patTest.getRootEndpoint(cookie) 462 if code != 200 { 463 t.Fatalf("expected 200; got %d", code) 464 } 465 assert.Equal(t, "my_auth_token", payload) 466 } 467 468 func TestDoNotForwardAccessTokenUpstream(t *testing.T) { 469 patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ 470 PassAccessToken: false, 471 }) 472 defer patTest.Close() 473 474 // A successful validation will redirect and set the auth cookie. 475 code, cookie := patTest.getCallbackEndpoint() 476 if code != 302 { 477 t.Fatalf("expected 302; got %d", code) 478 } 479 assert.NotEqual(t, nil, cookie) 480 481 // Now we make a regular request, but the access token header should 482 // not be present. 483 code, payload := patTest.getRootEndpoint(cookie) 484 if code != 200 { 485 t.Fatalf("expected 200; got %d", code) 486 } 487 assert.Equal(t, "No access token found.", payload) 488 } 489 490 type SignInPageTest struct { 491 opts *Options 492 proxy *OAuthProxy 493 signInRegexp *regexp.Regexp 494 signInProviderRegexp *regexp.Regexp 495 } 496 497 const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` 498 const signInSkipProvider = `>Found<` 499 500 func NewSignInPageTest(skipProvider bool) *SignInPageTest { 501 var sipTest SignInPageTest 502 503 sipTest.opts = NewOptions() 504 sipTest.opts.CookieSecret = "foobar" 505 sipTest.opts.ClientID = "bazquux" 506 sipTest.opts.ClientSecret = "xyzzyplugh" 507 sipTest.opts.SkipProviderButton = skipProvider 508 sipTest.opts.Validate() 509 510 sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool { 511 return true 512 }) 513 sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern) 514 sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider) 515 516 return &sipTest 517 } 518 519 func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { 520 rw := httptest.NewRecorder() 521 req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) 522 sipTest.proxy.ServeHTTP(rw, req) 523 return rw.Code, rw.Body.String() 524 } 525 526 func TestSignInPageIncludesTargetRedirect(t *testing.T) { 527 sipTest := NewSignInPageTest(false) 528 const endpoint = "/some/random/endpoint" 529 530 code, body := sipTest.GetEndpoint(endpoint) 531 assert.Equal(t, 403, code) 532 533 match := sipTest.signInRegexp.FindStringSubmatch(body) 534 if match == nil { 535 t.Fatal("Did not find pattern in body: " + 536 signInRedirectPattern + "\nBody:\n" + body) 537 } 538 if match[1] != endpoint { 539 t.Fatal(`expected redirect to "` + endpoint + 540 `", but was "` + match[1] + `"`) 541 } 542 } 543 544 func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { 545 sipTest := NewSignInPageTest(false) 546 code, body := sipTest.GetEndpoint("/oauth2/sign_in") 547 assert.Equal(t, 200, code) 548 549 match := sipTest.signInRegexp.FindStringSubmatch(body) 550 if match == nil { 551 t.Fatal("Did not find pattern in body: " + 552 signInRedirectPattern + "\nBody:\n" + body) 553 } 554 if match[1] != "/" { 555 t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`) 556 } 557 } 558 559 func TestSignInPageSkipProvider(t *testing.T) { 560 sipTest := NewSignInPageTest(true) 561 const endpoint = "/some/random/endpoint" 562 563 code, body := sipTest.GetEndpoint(endpoint) 564 assert.Equal(t, 302, code) 565 566 match := sipTest.signInProviderRegexp.FindStringSubmatch(body) 567 if match == nil { 568 t.Fatal("Did not find pattern in body: " + 569 signInSkipProvider + "\nBody:\n" + body) 570 } 571 } 572 573 func TestSignInPageSkipProviderDirect(t *testing.T) { 574 sipTest := NewSignInPageTest(true) 575 const endpoint = "/sign_in" 576 577 code, body := sipTest.GetEndpoint(endpoint) 578 assert.Equal(t, 302, code) 579 580 match := sipTest.signInProviderRegexp.FindStringSubmatch(body) 581 if match == nil { 582 t.Fatal("Did not find pattern in body: " + 583 signInSkipProvider + "\nBody:\n" + body) 584 } 585 } 586 587 type ProcessCookieTest struct { 588 opts *Options 589 proxy *OAuthProxy 590 rw *httptest.ResponseRecorder 591 req *http.Request 592 provider TestProvider 593 responseCode int 594 validateUser bool 595 } 596 597 type ProcessCookieTestOpts struct { 598 providerValidateCookieResponse bool 599 } 600 601 func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { 602 var pcTest ProcessCookieTest 603 604 pcTest.opts = NewOptions() 605 pcTest.opts.ClientID = "bazquux" 606 pcTest.opts.ClientSecret = "xyzzyplugh" 607 pcTest.opts.CookieSecret = "0123456789abcdefabcd" 608 // First, set the CookieRefresh option so proxy.AesCipher is created, 609 // needed to encrypt the access_token. 610 pcTest.opts.CookieRefresh = time.Hour 611 pcTest.opts.Validate() 612 613 pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { 614 return pcTest.validateUser 615 }) 616 pcTest.proxy.provider = &TestProvider{ 617 ValidToken: opts.providerValidateCookieResponse, 618 } 619 620 // Now, zero-out proxy.CookieRefresh for the cases that don't involve 621 // access_token validation. 622 pcTest.proxy.CookieRefresh = time.Duration(0) 623 pcTest.rw = httptest.NewRecorder() 624 pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) 625 pcTest.validateUser = true 626 return &pcTest 627 } 628 629 func NewProcessCookieTestWithDefaults() *ProcessCookieTest { 630 return NewProcessCookieTest(ProcessCookieTestOpts{ 631 providerValidateCookieResponse: true, 632 }) 633 } 634 635 func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { 636 return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) 637 } 638 639 func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { 640 value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) 641 if err != nil { 642 return err 643 } 644 for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { 645 p.req.AddCookie(c) 646 } 647 return nil 648 } 649 650 func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { 651 return p.proxy.LoadCookiedSession(p.req) 652 } 653 654 func TestLoadCookiedSession(t *testing.T) { 655 pcTest := NewProcessCookieTestWithDefaults() 656 657 startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} 658 pcTest.SaveSession(startSession, time.Now()) 659 660 session, _, err := pcTest.LoadCookiedSession() 661 assert.Equal(t, nil, err) 662 assert.Equal(t, startSession.Email, session.Email) 663 assert.Equal(t, "michael.bland", session.User) 664 assert.Equal(t, startSession.AccessToken, session.AccessToken) 665 } 666 667 func TestProcessCookieNoCookieError(t *testing.T) { 668 pcTest := NewProcessCookieTestWithDefaults() 669 670 session, _, err := pcTest.LoadCookiedSession() 671 assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) 672 if session != nil { 673 t.Errorf("expected nil session. got %#v", session) 674 } 675 } 676 677 func TestProcessCookieRefreshNotSet(t *testing.T) { 678 pcTest := NewProcessCookieTestWithDefaults() 679 pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour 680 reference := time.Now().Add(time.Duration(-2) * time.Hour) 681 682 startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} 683 pcTest.SaveSession(startSession, reference) 684 685 session, age, err := pcTest.LoadCookiedSession() 686 assert.Equal(t, nil, err) 687 if age < time.Duration(-2)*time.Hour { 688 t.Errorf("cookie too young %v", age) 689 } 690 assert.Equal(t, startSession.Email, session.Email) 691 } 692 693 func TestProcessCookieFailIfCookieExpired(t *testing.T) { 694 pcTest := NewProcessCookieTestWithDefaults() 695 pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour 696 reference := time.Now().Add(time.Duration(25) * time.Hour * -1) 697 startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} 698 pcTest.SaveSession(startSession, reference) 699 700 session, _, err := pcTest.LoadCookiedSession() 701 assert.NotEqual(t, nil, err) 702 if session != nil { 703 t.Errorf("expected nil session %#v", session) 704 } 705 } 706 707 func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { 708 pcTest := NewProcessCookieTestWithDefaults() 709 pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour 710 reference := time.Now().Add(time.Duration(25) * time.Hour * -1) 711 startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} 712 pcTest.SaveSession(startSession, reference) 713 714 pcTest.proxy.CookieRefresh = time.Hour 715 session, _, err := pcTest.LoadCookiedSession() 716 assert.NotEqual(t, nil, err) 717 if session != nil { 718 t.Errorf("expected nil session %#v", session) 719 } 720 } 721 722 func NewAuthOnlyEndpointTest() *ProcessCookieTest { 723 pcTest := NewProcessCookieTestWithDefaults() 724 pcTest.req, _ = http.NewRequest("GET", 725 pcTest.opts.ProxyPrefix+"/auth", nil) 726 return pcTest 727 } 728 729 func TestAuthOnlyEndpointAccepted(t *testing.T) { 730 test := NewAuthOnlyEndpointTest() 731 startSession := &providers.SessionState{ 732 Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} 733 test.SaveSession(startSession, time.Now()) 734 735 test.proxy.ServeHTTP(test.rw, test.req) 736 assert.Equal(t, http.StatusAccepted, test.rw.Code) 737 bodyBytes, _ := ioutil.ReadAll(test.rw.Body) 738 assert.Equal(t, "", string(bodyBytes)) 739 } 740 741 func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { 742 test := NewAuthOnlyEndpointTest() 743 744 test.proxy.ServeHTTP(test.rw, test.req) 745 assert.Equal(t, http.StatusUnauthorized, test.rw.Code) 746 bodyBytes, _ := ioutil.ReadAll(test.rw.Body) 747 assert.Equal(t, "unauthorized request\n", string(bodyBytes)) 748 } 749 750 func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { 751 test := NewAuthOnlyEndpointTest() 752 test.proxy.CookieExpire = time.Duration(24) * time.Hour 753 reference := time.Now().Add(time.Duration(25) * time.Hour * -1) 754 startSession := &providers.SessionState{ 755 Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} 756 test.SaveSession(startSession, reference) 757 758 test.proxy.ServeHTTP(test.rw, test.req) 759 assert.Equal(t, http.StatusUnauthorized, test.rw.Code) 760 bodyBytes, _ := ioutil.ReadAll(test.rw.Body) 761 assert.Equal(t, "unauthorized request\n", string(bodyBytes)) 762 } 763 764 func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { 765 test := NewAuthOnlyEndpointTest() 766 startSession := &providers.SessionState{ 767 Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} 768 test.SaveSession(startSession, time.Now()) 769 test.validateUser = false 770 771 test.proxy.ServeHTTP(test.rw, test.req) 772 assert.Equal(t, http.StatusUnauthorized, test.rw.Code) 773 bodyBytes, _ := ioutil.ReadAll(test.rw.Body) 774 assert.Equal(t, "unauthorized request\n", string(bodyBytes)) 775 } 776 777 func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { 778 var pcTest ProcessCookieTest 779 780 pcTest.opts = NewOptions() 781 pcTest.opts.SetXAuthRequest = true 782 pcTest.opts.Validate() 783 784 pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { 785 return pcTest.validateUser 786 }) 787 pcTest.proxy.provider = &TestProvider{ 788 ValidToken: true, 789 } 790 791 pcTest.validateUser = true 792 793 pcTest.rw = httptest.NewRecorder() 794 pcTest.req, _ = http.NewRequest("GET", 795 pcTest.opts.ProxyPrefix+"/auth", nil) 796 797 startSession := &providers.SessionState{ 798 User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} 799 pcTest.SaveSession(startSession, time.Now()) 800 801 pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) 802 assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) 803 assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0]) 804 assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0]) 805 } 806 807 func TestAuthSkippedForPreflightRequests(t *testing.T) { 808 upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 809 w.WriteHeader(200) 810 w.Write([]byte("response")) 811 })) 812 defer upstream.Close() 813 814 opts := NewOptions() 815 opts.Upstreams = append(opts.Upstreams, upstream.URL) 816 opts.ClientID = "bazquux" 817 opts.ClientSecret = "foobar" 818 opts.CookieSecret = "xyzzyplugh" 819 opts.SkipAuthPreflight = true 820 opts.Validate() 821 822 upstreamURL, _ := url.Parse(upstream.URL) 823 opts.provider = NewTestProvider(upstreamURL, "") 824 825 proxy := NewOAuthProxy(opts, func(string) bool { return false }) 826 rw := httptest.NewRecorder() 827 req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) 828 proxy.ServeHTTP(rw, req) 829 830 assert.Equal(t, 200, rw.Code) 831 assert.Equal(t, "response", rw.Body.String()) 832 } 833 834 type SignatureAuthenticator struct { 835 auth hmacauth.HmacAuth 836 } 837 838 func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) { 839 result, headerSig, computedSig := v.auth.AuthenticateRequest(r) 840 if result == hmacauth.ResultNoSignature { 841 w.Write([]byte("no signature received")) 842 } else if result == hmacauth.ResultMatch { 843 w.Write([]byte("signatures match")) 844 } else if result == hmacauth.ResultMismatch { 845 w.Write([]byte("signatures do not match:" + 846 "\n received: " + headerSig + 847 "\n computed: " + computedSig)) 848 } else { 849 panic("Unknown result value: " + result.String()) 850 } 851 } 852 853 type SignatureTest struct { 854 opts *Options 855 upstream *httptest.Server 856 upstreamHost string 857 provider *httptest.Server 858 header http.Header 859 rw *httptest.ResponseRecorder 860 authenticator *SignatureAuthenticator 861 } 862 863 func NewSignatureTest() *SignatureTest { 864 opts := NewOptions() 865 opts.CookieSecret = "cookie secret" 866 opts.ClientID = "client ID" 867 opts.ClientSecret = "client secret" 868 opts.EmailDomains = []string{"acm.org"} 869 870 authenticator := &SignatureAuthenticator{} 871 upstream := httptest.NewServer( 872 http.HandlerFunc(authenticator.Authenticate)) 873 upstreamURL, _ := url.Parse(upstream.URL) 874 opts.Upstreams = append(opts.Upstreams, upstream.URL) 875 876 providerHandler := func(w http.ResponseWriter, r *http.Request) { 877 w.Write([]byte(`{"access_token": "my_auth_token"}`)) 878 } 879 provider := httptest.NewServer(http.HandlerFunc(providerHandler)) 880 providerURL, _ := url.Parse(provider.URL) 881 opts.provider = NewTestProvider(providerURL, "mbland@acm.org") 882 883 return &SignatureTest{ 884 opts, 885 upstream, 886 upstreamURL.Host, 887 provider, 888 make(http.Header), 889 httptest.NewRecorder(), 890 authenticator, 891 } 892 } 893 894 func (st *SignatureTest) Close() { 895 st.provider.Close() 896 st.upstream.Close() 897 } 898 899 // fakeNetConn simulates an http.Request.Body buffer that will be consumed 900 // when it is read by the hmacauth.HmacAuth if not handled properly. See: 901 // https://github.com/18F/hmacauth/pull/4 902 type fakeNetConn struct { 903 reqBody string 904 } 905 906 func (fnc *fakeNetConn) Read(p []byte) (n int, err error) { 907 if bodyLen := len(fnc.reqBody); bodyLen != 0 { 908 copy(p, fnc.reqBody) 909 fnc.reqBody = "" 910 return bodyLen, io.EOF 911 } 912 return 0, io.EOF 913 } 914 915 func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { 916 err := st.opts.Validate() 917 if err != nil { 918 panic(err) 919 } 920 proxy := NewOAuthProxy(st.opts, func(email string) bool { return true }) 921 922 var bodyBuf io.ReadCloser 923 if body != "" { 924 bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body}) 925 } 926 req := httptest.NewRequest(method, "/foo/bar", bodyBuf) 927 req.Header = st.header 928 929 state := &providers.SessionState{ 930 Email: "mbland@acm.org", AccessToken: "my_access_token"} 931 value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) 932 if err != nil { 933 panic(err) 934 } 935 for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) { 936 req.AddCookie(c) 937 } 938 // This is used by the upstream to validate the signature. 939 st.authenticator.auth = hmacauth.NewHmacAuth( 940 crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) 941 proxy.ServeHTTP(st.rw, req) 942 } 943 944 func TestNoRequestSignature(t *testing.T) { 945 st := NewSignatureTest() 946 defer st.Close() 947 st.MakeRequestWithExpectedKey("GET", "", "") 948 assert.Equal(t, 200, st.rw.Code) 949 assert.Equal(t, st.rw.Body.String(), "no signature received") 950 } 951 952 func TestRequestSignatureGetRequest(t *testing.T) { 953 st := NewSignatureTest() 954 defer st.Close() 955 st.opts.SignatureKey = "sha1:foobar" 956 st.MakeRequestWithExpectedKey("GET", "", "foobar") 957 assert.Equal(t, 200, st.rw.Code) 958 assert.Equal(t, st.rw.Body.String(), "signatures match") 959 } 960 961 func TestRequestSignaturePostRequest(t *testing.T) { 962 st := NewSignatureTest() 963 defer st.Close() 964 st.opts.SignatureKey = "sha1:foobar" 965 payload := `{ "hello": "world!" }` 966 st.MakeRequestWithExpectedKey("POST", payload, "foobar") 967 assert.Equal(t, 200, st.rw.Code) 968 assert.Equal(t, st.rw.Body.String(), "signatures match") 969 } 970 971 func TestGetRedirect(t *testing.T) { 972 options := NewOptions() 973 _ = options.Validate() 974 require.NotEmpty(t, options.ProxyPrefix) 975 proxy := NewOAuthProxy(options, func(s string) bool { return false }) 976 977 tests := []struct { 978 name string 979 url string 980 expectedRedirect string 981 }{ 982 { 983 name: "request outside of ProxyPrefix redirects to original URL", 984 url: "/foo/bar", 985 expectedRedirect: "/foo/bar", 986 }, 987 { 988 name: "request under ProxyPrefix redirects to root", 989 url: proxy.ProxyPrefix + "/foo/bar", 990 expectedRedirect: "/", 991 }, 992 } 993 for _, tt := range tests { 994 t.Run(tt.name, func(t *testing.T) { 995 req, _ := http.NewRequest("GET", tt.url, nil) 996 redirect, err := proxy.GetRedirect(req) 997 998 assert.NoError(t, err) 999 assert.Equal(t, tt.expectedRedirect, redirect) 1000 }) 1001 } 1002 } 1003 1004 type ajaxRequestTest struct { 1005 opts *Options 1006 proxy *OAuthProxy 1007 } 1008 1009 func newAjaxRequestTest() *ajaxRequestTest { 1010 test := &ajaxRequestTest{} 1011 test.opts = NewOptions() 1012 test.opts.CookieSecret = "foobar" 1013 test.opts.ClientID = "bazquux" 1014 test.opts.ClientSecret = "xyzzyplugh" 1015 test.opts.Validate() 1016 test.proxy = NewOAuthProxy(test.opts, func(email string) bool { 1017 return true 1018 }) 1019 return test 1020 } 1021 1022 func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) { 1023 rw := httptest.NewRecorder() 1024 req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader("")) 1025 if err != nil { 1026 return 0, nil, err 1027 } 1028 req.Header = header 1029 test.proxy.ServeHTTP(rw, req) 1030 return rw.Code, rw.Header(), nil 1031 } 1032 1033 func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { 1034 test := newAjaxRequestTest() 1035 endpoint := "/test" 1036 1037 code, rh, err := test.getEndpoint(endpoint, header) 1038 assert.NoError(t, err) 1039 assert.Equal(t, http.StatusUnauthorized, code) 1040 mime := rh.Get("Content-Type") 1041 assert.Equal(t, applicationJSON, mime) 1042 } 1043 func TestAjaxUnauthorizedRequest1(t *testing.T) { 1044 header := make(http.Header) 1045 header.Add("accept", applicationJSON) 1046 1047 testAjaxUnauthorizedRequest(t, header) 1048 } 1049 1050 func TestAjaxUnauthorizedRequest2(t *testing.T) { 1051 header := make(http.Header) 1052 header.Add("Accept", applicationJSON) 1053 1054 testAjaxUnauthorizedRequest(t, header) 1055 } 1056 1057 func TestAjaxForbiddendRequest(t *testing.T) { 1058 test := newAjaxRequestTest() 1059 endpoint := "/test" 1060 header := make(http.Header) 1061 code, rh, err := test.getEndpoint(endpoint, header) 1062 assert.NoError(t, err) 1063 assert.Equal(t, http.StatusForbidden, code) 1064 mime := rh.Get("Content-Type") 1065 assert.NotEqual(t, applicationJSON, mime) 1066 } 1067 1068 func TestClearSplitCookie(t *testing.T) { 1069 p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} 1070 var rw = httptest.NewRecorder() 1071 req := httptest.NewRequest("get", "/", nil) 1072 1073 req.AddCookie(&http.Cookie{ 1074 Name: "test1", 1075 Value: "test1", 1076 }) 1077 req.AddCookie(&http.Cookie{ 1078 Name: "oauth2_0", 1079 Value: "oauth2_0", 1080 }) 1081 req.AddCookie(&http.Cookie{ 1082 Name: "oauth2_1", 1083 Value: "oauth2_1", 1084 }) 1085 1086 p.ClearSessionCookie(rw, req) 1087 header := rw.Header() 1088 1089 assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries") 1090 } 1091 1092 func TestClearSingleCookie(t *testing.T) { 1093 p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} 1094 var rw = httptest.NewRecorder() 1095 req := httptest.NewRequest("get", "/", nil) 1096 1097 req.AddCookie(&http.Cookie{ 1098 Name: "test1", 1099 Value: "test1", 1100 }) 1101 req.AddCookie(&http.Cookie{ 1102 Name: "oauth2", 1103 Value: "oauth2", 1104 }) 1105 1106 p.ClearSessionCookie(rw, req) 1107 header := rw.Header() 1108 1109 assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries") 1110 }