github.com/pusher/oauth2_proxy@v3.2.0+incompatible/oauthproxy_test.go (about)

     1  package main
     2  
     3  import (
     4  	"crypto"
     5  	"encoding/base64"
     6  	"io"
     7  	"io/ioutil"
     8  	"log"
     9  	"net"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"net/url"
    13  	"regexp"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/mbland/hmacauth"
    19  	"github.com/pusher/oauth2_proxy/providers"
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/stretchr/testify/require"
    22  	"golang.org/x/net/websocket"
    23  )
    24  
    25  func init() {
    26  	log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
    27  
    28  }
    29  
    30  type WebSocketOrRestHandler struct {
    31  	restHandler http.Handler
    32  	wsHandler   http.Handler
    33  }
    34  
    35  func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    36  	if r.Header.Get("Upgrade") == "websocket" {
    37  		h.wsHandler.ServeHTTP(w, r)
    38  	} else {
    39  		h.restHandler.ServeHTTP(w, r)
    40  	}
    41  }
    42  
    43  func TestWebSocketProxy(t *testing.T) {
    44  	handler := WebSocketOrRestHandler{
    45  		restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    46  			w.WriteHeader(200)
    47  			hostname, _, _ := net.SplitHostPort(r.Host)
    48  			w.Write([]byte(hostname))
    49  		}),
    50  		wsHandler: websocket.Handler(func(ws *websocket.Conn) {
    51  			defer ws.Close()
    52  			var data []byte
    53  			err := websocket.Message.Receive(ws, &data)
    54  			if err != nil {
    55  				t.Fatalf("err %s", err)
    56  				return
    57  			}
    58  			err = websocket.Message.Send(ws, data)
    59  			if err != nil {
    60  				t.Fatalf("err %s", err)
    61  			}
    62  			return
    63  		}),
    64  	}
    65  	backend := httptest.NewServer(&handler)
    66  	defer backend.Close()
    67  
    68  	backendURL, _ := url.Parse(backend.URL)
    69  
    70  	options := NewOptions()
    71  	var auth hmacauth.HmacAuth
    72  	options.PassHostHeader = true
    73  	proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, options, auth)
    74  	frontend := httptest.NewServer(proxyHandler)
    75  	defer frontend.Close()
    76  
    77  	frontendURL, _ := url.Parse(frontend.URL)
    78  	frontendWSURL := "ws://" + frontendURL.Host + "/"
    79  
    80  	ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/")
    81  	if err != nil {
    82  		t.Fatalf("err %s", err)
    83  	}
    84  	request := []byte("hello, world!")
    85  	err = websocket.Message.Send(ws, request)
    86  	if err != nil {
    87  		t.Fatalf("err %s", err)
    88  	}
    89  	var response = make([]byte, 1024)
    90  	websocket.Message.Receive(ws, &response)
    91  	if err != nil {
    92  		t.Fatalf("err %s", err)
    93  	}
    94  	if g, e := string(request), string(response); g != e {
    95  		t.Errorf("got body %q; expected %q", g, e)
    96  	}
    97  
    98  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    99  	res, _ := http.DefaultClient.Do(getReq)
   100  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   101  	backendHostname, _, _ := net.SplitHostPort(backendURL.Host)
   102  	if g, e := string(bodyBytes), backendHostname; g != e {
   103  		t.Errorf("got body %q; expected %q", g, e)
   104  	}
   105  }
   106  
   107  func TestNewReverseProxy(t *testing.T) {
   108  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   109  		w.WriteHeader(200)
   110  		hostname, _, _ := net.SplitHostPort(r.Host)
   111  		w.Write([]byte(hostname))
   112  	}))
   113  	defer backend.Close()
   114  
   115  	backendURL, _ := url.Parse(backend.URL)
   116  	backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
   117  	backendHost := net.JoinHostPort(backendHostname, backendPort)
   118  	proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
   119  
   120  	proxyHandler := NewReverseProxy(proxyURL, time.Second)
   121  	setProxyUpstreamHostHeader(proxyHandler, proxyURL)
   122  	frontend := httptest.NewServer(proxyHandler)
   123  	defer frontend.Close()
   124  
   125  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   126  	res, _ := http.DefaultClient.Do(getReq)
   127  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   128  	if g, e := string(bodyBytes), backendHostname; g != e {
   129  		t.Errorf("got body %q; expected %q", g, e)
   130  	}
   131  }
   132  
   133  func TestEncodedSlashes(t *testing.T) {
   134  	var seen string
   135  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   136  		w.WriteHeader(200)
   137  		seen = r.RequestURI
   138  	}))
   139  	defer backend.Close()
   140  
   141  	b, _ := url.Parse(backend.URL)
   142  	proxyHandler := NewReverseProxy(b, time.Second)
   143  	setProxyDirector(proxyHandler)
   144  	frontend := httptest.NewServer(proxyHandler)
   145  	defer frontend.Close()
   146  
   147  	f, _ := url.Parse(frontend.URL)
   148  	encodedPath := "/a%2Fb/?c=1"
   149  	getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
   150  	_, err := http.DefaultClient.Do(getReq)
   151  	if err != nil {
   152  		t.Fatalf("err %s", err)
   153  	}
   154  	if seen != encodedPath {
   155  		t.Errorf("got bad request %q expected %q", seen, encodedPath)
   156  	}
   157  }
   158  
   159  func TestRobotsTxt(t *testing.T) {
   160  	opts := NewOptions()
   161  	opts.ClientID = "bazquux"
   162  	opts.ClientSecret = "foobar"
   163  	opts.CookieSecret = "xyzzyplugh"
   164  	opts.Validate()
   165  
   166  	proxy := NewOAuthProxy(opts, func(string) bool { return true })
   167  	rw := httptest.NewRecorder()
   168  	req, _ := http.NewRequest("GET", "/robots.txt", nil)
   169  	proxy.ServeHTTP(rw, req)
   170  	assert.Equal(t, 200, rw.Code)
   171  	assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
   172  }
   173  
   174  func TestIsValidRedirect(t *testing.T) {
   175  	opts := NewOptions()
   176  	opts.ClientID = "bazquux"
   177  	opts.ClientSecret = "foobar"
   178  	opts.CookieSecret = "xyzzyplugh"
   179  	// Should match domains that are exactly foo.bar and any subdomain of bar.foo
   180  	opts.WhitelistDomains = []string{"foo.bar", ".bar.foo"}
   181  	opts.Validate()
   182  
   183  	proxy := NewOAuthProxy(opts, func(string) bool { return true })
   184  
   185  	noRD := proxy.IsValidRedirect("")
   186  	assert.Equal(t, false, noRD)
   187  
   188  	singleSlash := proxy.IsValidRedirect("/redirect")
   189  	assert.Equal(t, true, singleSlash)
   190  
   191  	doubleSlash := proxy.IsValidRedirect("//redirect")
   192  	assert.Equal(t, false, doubleSlash)
   193  
   194  	validHTTP := proxy.IsValidRedirect("http://foo.bar/redirect")
   195  	assert.Equal(t, true, validHTTP)
   196  
   197  	validHTTPS := proxy.IsValidRedirect("https://foo.bar/redirect")
   198  	assert.Equal(t, true, validHTTPS)
   199  
   200  	invalidHTTPSubdomain := proxy.IsValidRedirect("http://baz.foo.bar/redirect")
   201  	assert.Equal(t, false, invalidHTTPSubdomain)
   202  
   203  	invalidHTTPSSubdomain := proxy.IsValidRedirect("https://baz.foo.bar/redirect")
   204  	assert.Equal(t, false, invalidHTTPSSubdomain)
   205  
   206  	validHTTPSubdomain := proxy.IsValidRedirect("http://baz.bar.foo/redirect")
   207  	assert.Equal(t, true, validHTTPSubdomain)
   208  
   209  	validHTTPSSubdomain := proxy.IsValidRedirect("https://baz.bar.foo/redirect")
   210  	assert.Equal(t, true, validHTTPSSubdomain)
   211  
   212  	invalidHTTP1 := proxy.IsValidRedirect("http://foo.bar.evil.corp/redirect")
   213  	assert.Equal(t, false, invalidHTTP1)
   214  
   215  	invalidHTTPS1 := proxy.IsValidRedirect("https://foo.bar.evil.corp/redirect")
   216  	assert.Equal(t, false, invalidHTTPS1)
   217  
   218  	invalidHTTP2 := proxy.IsValidRedirect("http://evil.corp/redirect?rd=foo.bar")
   219  	assert.Equal(t, false, invalidHTTP2)
   220  
   221  	invalidHTTPS2 := proxy.IsValidRedirect("https://evil.corp/redirect?rd=foo.bar")
   222  	assert.Equal(t, false, invalidHTTPS2)
   223  }
   224  
   225  type TestProvider struct {
   226  	*providers.ProviderData
   227  	EmailAddress string
   228  	ValidToken   bool
   229  }
   230  
   231  func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
   232  	return &TestProvider{
   233  		ProviderData: &providers.ProviderData{
   234  			ProviderName: "Test Provider",
   235  			LoginURL: &url.URL{
   236  				Scheme: "http",
   237  				Host:   providerURL.Host,
   238  				Path:   "/oauth/authorize",
   239  			},
   240  			RedeemURL: &url.URL{
   241  				Scheme: "http",
   242  				Host:   providerURL.Host,
   243  				Path:   "/oauth/token",
   244  			},
   245  			ProfileURL: &url.URL{
   246  				Scheme: "http",
   247  				Host:   providerURL.Host,
   248  				Path:   "/api/v1/profile",
   249  			},
   250  			Scope: "profile.email",
   251  		},
   252  		EmailAddress: emailAddress,
   253  	}
   254  }
   255  
   256  func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
   257  	return tp.EmailAddress, nil
   258  }
   259  
   260  func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
   261  	return tp.ValidToken
   262  }
   263  
   264  func TestBasicAuthPassword(t *testing.T) {
   265  	providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   266  		log.Printf("%#v", r)
   267  		var payload string
   268  		switch r.URL.Path {
   269  		case "/oauth/token":
   270  			payload = `{"access_token": "my_auth_token"}`
   271  		default:
   272  			payload = r.Header.Get("Authorization")
   273  			if payload == "" {
   274  				payload = "No Authorization header found."
   275  			}
   276  		}
   277  		w.WriteHeader(200)
   278  		w.Write([]byte(payload))
   279  	}))
   280  	opts := NewOptions()
   281  	opts.Upstreams = append(opts.Upstreams, providerServer.URL)
   282  	// The CookieSecret must be 32 bytes in order to create the AES
   283  	// cipher.
   284  	opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
   285  	opts.ClientID = "bazquux"
   286  	opts.ClientSecret = "foobar"
   287  	opts.CookieSecure = false
   288  	opts.PassBasicAuth = true
   289  	opts.PassUserHeaders = true
   290  	opts.BasicAuthPassword = "This is a secure password"
   291  	opts.Validate()
   292  
   293  	providerURL, _ := url.Parse(providerServer.URL)
   294  	const emailAddress = "michael.bland@gsa.gov"
   295  	const username = "michael.bland"
   296  
   297  	opts.provider = NewTestProvider(providerURL, emailAddress)
   298  	proxy := NewOAuthProxy(opts, func(email string) bool {
   299  		return email == emailAddress
   300  	})
   301  
   302  	rw := httptest.NewRecorder()
   303  	req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
   304  		strings.NewReader(""))
   305  	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
   306  	proxy.ServeHTTP(rw, req)
   307  	if rw.Code >= 400 {
   308  		t.Fatalf("expected 3xx got %d", rw.Code)
   309  	}
   310  	cookie := rw.HeaderMap["Set-Cookie"][1]
   311  
   312  	cookieName := proxy.CookieName
   313  	var value string
   314  	keyPrefix := cookieName + "="
   315  
   316  	for _, field := range strings.Split(cookie, "; ") {
   317  		value = strings.TrimPrefix(field, keyPrefix)
   318  		if value != field {
   319  			break
   320  		} else {
   321  			value = ""
   322  		}
   323  	}
   324  
   325  	req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
   326  	req.AddCookie(&http.Cookie{
   327  		Name:     cookieName,
   328  		Value:    value,
   329  		Path:     "/",
   330  		Expires:  time.Now().Add(time.Duration(24)),
   331  		HttpOnly: true,
   332  	})
   333  	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
   334  
   335  	rw = httptest.NewRecorder()
   336  	proxy.ServeHTTP(rw, req)
   337  
   338  	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword))
   339  	assert.Equal(t, expectedHeader, rw.Body.String())
   340  	providerServer.Close()
   341  }
   342  
// PassAccessTokenTest holds a stub provider/upstream server together with the
// OAuthProxy and Options that NewPassAccessTokenTest builds in front of it.
type PassAccessTokenTest struct {
	providerServer *httptest.Server
	proxy          *OAuthProxy
	opts           *Options
}

// PassAccessTokenTestOptions configures NewPassAccessTokenTest.
type PassAccessTokenTestOptions struct {
	// PassAccessToken is copied into the proxy's Options.PassAccessToken.
	PassAccessToken bool
}
   352  
   353  func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
   354  	t := &PassAccessTokenTest{}
   355  
   356  	t.providerServer = httptest.NewServer(
   357  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   358  			log.Printf("%#v", r)
   359  			var payload string
   360  			switch r.URL.Path {
   361  			case "/oauth/token":
   362  				payload = `{"access_token": "my_auth_token"}`
   363  			default:
   364  				payload = r.Header.Get("X-Forwarded-Access-Token")
   365  				if payload == "" {
   366  					payload = "No access token found."
   367  				}
   368  			}
   369  			w.WriteHeader(200)
   370  			w.Write([]byte(payload))
   371  		}))
   372  
   373  	t.opts = NewOptions()
   374  	t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
   375  	// The CookieSecret must be 32 bytes in order to create the AES
   376  	// cipher.
   377  	t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
   378  	t.opts.ClientID = "bazquux"
   379  	t.opts.ClientSecret = "foobar"
   380  	t.opts.CookieSecure = false
   381  	t.opts.PassAccessToken = opts.PassAccessToken
   382  	t.opts.Validate()
   383  
   384  	providerURL, _ := url.Parse(t.providerServer.URL)
   385  	const emailAddress = "michael.bland@gsa.gov"
   386  
   387  	t.opts.provider = NewTestProvider(providerURL, emailAddress)
   388  	t.proxy = NewOAuthProxy(t.opts, func(email string) bool {
   389  		return email == emailAddress
   390  	})
   391  	return t
   392  }
   393  
   394  func (patTest *PassAccessTokenTest) Close() {
   395  	patTest.providerServer.Close()
   396  }
   397  
   398  func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
   399  	cookie string) {
   400  	rw := httptest.NewRecorder()
   401  	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
   402  		strings.NewReader(""))
   403  	if err != nil {
   404  		return 0, ""
   405  	}
   406  	req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
   407  	patTest.proxy.ServeHTTP(rw, req)
   408  	return rw.Code, rw.HeaderMap["Set-Cookie"][1]
   409  }
   410  
   411  func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) {
   412  	cookieName := patTest.proxy.CookieName
   413  	var value string
   414  	keyPrefix := cookieName + "="
   415  
   416  	for _, field := range strings.Split(cookie, "; ") {
   417  		value = strings.TrimPrefix(field, keyPrefix)
   418  		if value != field {
   419  			break
   420  		} else {
   421  			value = ""
   422  		}
   423  	}
   424  	if value == "" {
   425  		return 0, ""
   426  	}
   427  
   428  	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
   429  	if err != nil {
   430  		return 0, ""
   431  	}
   432  	req.AddCookie(&http.Cookie{
   433  		Name:     cookieName,
   434  		Value:    value,
   435  		Path:     "/",
   436  		Expires:  time.Now().Add(time.Duration(24)),
   437  		HttpOnly: true,
   438  	})
   439  
   440  	rw := httptest.NewRecorder()
   441  	patTest.proxy.ServeHTTP(rw, req)
   442  	return rw.Code, rw.Body.String()
   443  }
   444  
   445  func TestForwardAccessTokenUpstream(t *testing.T) {
   446  	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
   447  		PassAccessToken: true,
   448  	})
   449  	defer patTest.Close()
   450  
   451  	// A successful validation will redirect and set the auth cookie.
   452  	code, cookie := patTest.getCallbackEndpoint()
   453  	if code != 302 {
   454  		t.Fatalf("expected 302; got %d", code)
   455  	}
   456  	assert.NotEqual(t, nil, cookie)
   457  
   458  	// Now we make a regular request; the access_token from the cookie is
   459  	// forwarded as the "X-Forwarded-Access-Token" header. The token is
   460  	// read by the test provider server and written in the response body.
   461  	code, payload := patTest.getRootEndpoint(cookie)
   462  	if code != 200 {
   463  		t.Fatalf("expected 200; got %d", code)
   464  	}
   465  	assert.Equal(t, "my_auth_token", payload)
   466  }
   467  
   468  func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
   469  	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
   470  		PassAccessToken: false,
   471  	})
   472  	defer patTest.Close()
   473  
   474  	// A successful validation will redirect and set the auth cookie.
   475  	code, cookie := patTest.getCallbackEndpoint()
   476  	if code != 302 {
   477  		t.Fatalf("expected 302; got %d", code)
   478  	}
   479  	assert.NotEqual(t, nil, cookie)
   480  
   481  	// Now we make a regular request, but the access token header should
   482  	// not be present.
   483  	code, payload := patTest.getRootEndpoint(cookie)
   484  	if code != 200 {
   485  		t.Fatalf("expected 200; got %d", code)
   486  	}
   487  	assert.Equal(t, "No access token found.", payload)
   488  }
   489  
   490  type SignInPageTest struct {
   491  	opts                 *Options
   492  	proxy                *OAuthProxy
   493  	signInRegexp         *regexp.Regexp
   494  	signInProviderRegexp *regexp.Regexp
   495  }
   496  
   497  const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
   498  const signInSkipProvider = `>Found<`
   499  
   500  func NewSignInPageTest(skipProvider bool) *SignInPageTest {
   501  	var sipTest SignInPageTest
   502  
   503  	sipTest.opts = NewOptions()
   504  	sipTest.opts.CookieSecret = "foobar"
   505  	sipTest.opts.ClientID = "bazquux"
   506  	sipTest.opts.ClientSecret = "xyzzyplugh"
   507  	sipTest.opts.SkipProviderButton = skipProvider
   508  	sipTest.opts.Validate()
   509  
   510  	sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool {
   511  		return true
   512  	})
   513  	sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
   514  	sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
   515  
   516  	return &sipTest
   517  }
   518  
   519  func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
   520  	rw := httptest.NewRecorder()
   521  	req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
   522  	sipTest.proxy.ServeHTTP(rw, req)
   523  	return rw.Code, rw.Body.String()
   524  }
   525  
   526  func TestSignInPageIncludesTargetRedirect(t *testing.T) {
   527  	sipTest := NewSignInPageTest(false)
   528  	const endpoint = "/some/random/endpoint"
   529  
   530  	code, body := sipTest.GetEndpoint(endpoint)
   531  	assert.Equal(t, 403, code)
   532  
   533  	match := sipTest.signInRegexp.FindStringSubmatch(body)
   534  	if match == nil {
   535  		t.Fatal("Did not find pattern in body: " +
   536  			signInRedirectPattern + "\nBody:\n" + body)
   537  	}
   538  	if match[1] != endpoint {
   539  		t.Fatal(`expected redirect to "` + endpoint +
   540  			`", but was "` + match[1] + `"`)
   541  	}
   542  }
   543  
   544  func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
   545  	sipTest := NewSignInPageTest(false)
   546  	code, body := sipTest.GetEndpoint("/oauth2/sign_in")
   547  	assert.Equal(t, 200, code)
   548  
   549  	match := sipTest.signInRegexp.FindStringSubmatch(body)
   550  	if match == nil {
   551  		t.Fatal("Did not find pattern in body: " +
   552  			signInRedirectPattern + "\nBody:\n" + body)
   553  	}
   554  	if match[1] != "/" {
   555  		t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
   556  	}
   557  }
   558  
   559  func TestSignInPageSkipProvider(t *testing.T) {
   560  	sipTest := NewSignInPageTest(true)
   561  	const endpoint = "/some/random/endpoint"
   562  
   563  	code, body := sipTest.GetEndpoint(endpoint)
   564  	assert.Equal(t, 302, code)
   565  
   566  	match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
   567  	if match == nil {
   568  		t.Fatal("Did not find pattern in body: " +
   569  			signInSkipProvider + "\nBody:\n" + body)
   570  	}
   571  }
   572  
   573  func TestSignInPageSkipProviderDirect(t *testing.T) {
   574  	sipTest := NewSignInPageTest(true)
   575  	const endpoint = "/sign_in"
   576  
   577  	code, body := sipTest.GetEndpoint(endpoint)
   578  	assert.Equal(t, 302, code)
   579  
   580  	match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
   581  	if match == nil {
   582  		t.Fatal("Did not find pattern in body: " +
   583  			signInSkipProvider + "\nBody:\n" + body)
   584  	}
   585  }
   586  
// ProcessCookieTest bundles a proxy, its options, and a recorder/request pair
// for exercising session-cookie handling.
type ProcessCookieTest struct {
	opts  *Options
	proxy *OAuthProxy
	rw    *httptest.ResponseRecorder
	req   *http.Request
	// NOTE(review): provider and responseCode are not referenced anywhere in
	// the visible code; confirm whether they are still needed.
	provider     TestProvider
	responseCode int
	// validateUser is the value the proxy's email validator returns when the
	// proxy is built by NewProcessCookieTest.
	validateUser bool
}

// ProcessCookieTestOpts configures NewProcessCookieTest.
type ProcessCookieTestOpts struct {
	// providerValidateCookieResponse becomes the stub provider's ValidToken,
	// i.e. what its ValidateSessionState reports.
	providerValidateCookieResponse bool
}
   600  
   601  func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
   602  	var pcTest ProcessCookieTest
   603  
   604  	pcTest.opts = NewOptions()
   605  	pcTest.opts.ClientID = "bazquux"
   606  	pcTest.opts.ClientSecret = "xyzzyplugh"
   607  	pcTest.opts.CookieSecret = "0123456789abcdefabcd"
   608  	// First, set the CookieRefresh option so proxy.AesCipher is created,
   609  	// needed to encrypt the access_token.
   610  	pcTest.opts.CookieRefresh = time.Hour
   611  	pcTest.opts.Validate()
   612  
   613  	pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
   614  		return pcTest.validateUser
   615  	})
   616  	pcTest.proxy.provider = &TestProvider{
   617  		ValidToken: opts.providerValidateCookieResponse,
   618  	}
   619  
   620  	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
   621  	// access_token validation.
   622  	pcTest.proxy.CookieRefresh = time.Duration(0)
   623  	pcTest.rw = httptest.NewRecorder()
   624  	pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
   625  	pcTest.validateUser = true
   626  	return &pcTest
   627  }
   628  
   629  func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
   630  	return NewProcessCookieTest(ProcessCookieTestOpts{
   631  		providerValidateCookieResponse: true,
   632  	})
   633  }
   634  
   635  func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie {
   636  	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
   637  }
   638  
   639  func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
   640  	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
   641  	if err != nil {
   642  		return err
   643  	}
   644  	for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) {
   645  		p.req.AddCookie(c)
   646  	}
   647  	return nil
   648  }
   649  
   650  func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
   651  	return p.proxy.LoadCookiedSession(p.req)
   652  }
   653  
   654  func TestLoadCookiedSession(t *testing.T) {
   655  	pcTest := NewProcessCookieTestWithDefaults()
   656  
   657  	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
   658  	pcTest.SaveSession(startSession, time.Now())
   659  
   660  	session, _, err := pcTest.LoadCookiedSession()
   661  	assert.Equal(t, nil, err)
   662  	assert.Equal(t, startSession.Email, session.Email)
   663  	assert.Equal(t, "michael.bland", session.User)
   664  	assert.Equal(t, startSession.AccessToken, session.AccessToken)
   665  }
   666  
   667  func TestProcessCookieNoCookieError(t *testing.T) {
   668  	pcTest := NewProcessCookieTestWithDefaults()
   669  
   670  	session, _, err := pcTest.LoadCookiedSession()
   671  	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
   672  	if session != nil {
   673  		t.Errorf("expected nil session. got %#v", session)
   674  	}
   675  }
   676  
   677  func TestProcessCookieRefreshNotSet(t *testing.T) {
   678  	pcTest := NewProcessCookieTestWithDefaults()
   679  	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
   680  	reference := time.Now().Add(time.Duration(-2) * time.Hour)
   681  
   682  	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
   683  	pcTest.SaveSession(startSession, reference)
   684  
   685  	session, age, err := pcTest.LoadCookiedSession()
   686  	assert.Equal(t, nil, err)
   687  	if age < time.Duration(-2)*time.Hour {
   688  		t.Errorf("cookie too young %v", age)
   689  	}
   690  	assert.Equal(t, startSession.Email, session.Email)
   691  }
   692  
   693  func TestProcessCookieFailIfCookieExpired(t *testing.T) {
   694  	pcTest := NewProcessCookieTestWithDefaults()
   695  	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
   696  	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
   697  	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
   698  	pcTest.SaveSession(startSession, reference)
   699  
   700  	session, _, err := pcTest.LoadCookiedSession()
   701  	assert.NotEqual(t, nil, err)
   702  	if session != nil {
   703  		t.Errorf("expected nil session %#v", session)
   704  	}
   705  }
   706  
   707  func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
   708  	pcTest := NewProcessCookieTestWithDefaults()
   709  	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
   710  	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
   711  	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
   712  	pcTest.SaveSession(startSession, reference)
   713  
   714  	pcTest.proxy.CookieRefresh = time.Hour
   715  	session, _, err := pcTest.LoadCookiedSession()
   716  	assert.NotEqual(t, nil, err)
   717  	if session != nil {
   718  		t.Errorf("expected nil session %#v", session)
   719  	}
   720  }
   721  
   722  func NewAuthOnlyEndpointTest() *ProcessCookieTest {
   723  	pcTest := NewProcessCookieTestWithDefaults()
   724  	pcTest.req, _ = http.NewRequest("GET",
   725  		pcTest.opts.ProxyPrefix+"/auth", nil)
   726  	return pcTest
   727  }
   728  
   729  func TestAuthOnlyEndpointAccepted(t *testing.T) {
   730  	test := NewAuthOnlyEndpointTest()
   731  	startSession := &providers.SessionState{
   732  		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
   733  	test.SaveSession(startSession, time.Now())
   734  
   735  	test.proxy.ServeHTTP(test.rw, test.req)
   736  	assert.Equal(t, http.StatusAccepted, test.rw.Code)
   737  	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
   738  	assert.Equal(t, "", string(bodyBytes))
   739  }
   740  
   741  func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
   742  	test := NewAuthOnlyEndpointTest()
   743  
   744  	test.proxy.ServeHTTP(test.rw, test.req)
   745  	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
   746  	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
   747  	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
   748  }
   749  
   750  func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
   751  	test := NewAuthOnlyEndpointTest()
   752  	test.proxy.CookieExpire = time.Duration(24) * time.Hour
   753  	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
   754  	startSession := &providers.SessionState{
   755  		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
   756  	test.SaveSession(startSession, reference)
   757  
   758  	test.proxy.ServeHTTP(test.rw, test.req)
   759  	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
   760  	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
   761  	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
   762  }
   763  
   764  func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
   765  	test := NewAuthOnlyEndpointTest()
   766  	startSession := &providers.SessionState{
   767  		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
   768  	test.SaveSession(startSession, time.Now())
   769  	test.validateUser = false
   770  
   771  	test.proxy.ServeHTTP(test.rw, test.req)
   772  	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
   773  	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
   774  	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
   775  }
   776  
   777  func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
   778  	var pcTest ProcessCookieTest
   779  
   780  	pcTest.opts = NewOptions()
   781  	pcTest.opts.SetXAuthRequest = true
   782  	pcTest.opts.Validate()
   783  
   784  	pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
   785  		return pcTest.validateUser
   786  	})
   787  	pcTest.proxy.provider = &TestProvider{
   788  		ValidToken: true,
   789  	}
   790  
   791  	pcTest.validateUser = true
   792  
   793  	pcTest.rw = httptest.NewRecorder()
   794  	pcTest.req, _ = http.NewRequest("GET",
   795  		pcTest.opts.ProxyPrefix+"/auth", nil)
   796  
   797  	startSession := &providers.SessionState{
   798  		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
   799  	pcTest.SaveSession(startSession, time.Now())
   800  
   801  	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
   802  	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
   803  	assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0])
   804  	assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0])
   805  }
   806  
   807  func TestAuthSkippedForPreflightRequests(t *testing.T) {
   808  	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   809  		w.WriteHeader(200)
   810  		w.Write([]byte("response"))
   811  	}))
   812  	defer upstream.Close()
   813  
   814  	opts := NewOptions()
   815  	opts.Upstreams = append(opts.Upstreams, upstream.URL)
   816  	opts.ClientID = "bazquux"
   817  	opts.ClientSecret = "foobar"
   818  	opts.CookieSecret = "xyzzyplugh"
   819  	opts.SkipAuthPreflight = true
   820  	opts.Validate()
   821  
   822  	upstreamURL, _ := url.Parse(upstream.URL)
   823  	opts.provider = NewTestProvider(upstreamURL, "")
   824  
   825  	proxy := NewOAuthProxy(opts, func(string) bool { return false })
   826  	rw := httptest.NewRecorder()
   827  	req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
   828  	proxy.ServeHTTP(rw, req)
   829  
   830  	assert.Equal(t, 200, rw.Code)
   831  	assert.Equal(t, "response", rw.Body.String())
   832  }
   833  
   834  type SignatureAuthenticator struct {
   835  	auth hmacauth.HmacAuth
   836  }
   837  
   838  func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
   839  	result, headerSig, computedSig := v.auth.AuthenticateRequest(r)
   840  	if result == hmacauth.ResultNoSignature {
   841  		w.Write([]byte("no signature received"))
   842  	} else if result == hmacauth.ResultMatch {
   843  		w.Write([]byte("signatures match"))
   844  	} else if result == hmacauth.ResultMismatch {
   845  		w.Write([]byte("signatures do not match:" +
   846  			"\n  received: " + headerSig +
   847  			"\n  computed: " + computedSig))
   848  	} else {
   849  		panic("Unknown result value: " + result.String())
   850  	}
   851  }
   852  
// SignatureTest is the fixture shared by the request-signing tests.
//
// Fields:
//   - opts: proxy options pointing at upstream and provider.
//   - upstream: test server whose handler is authenticator.Authenticate.
//   - upstreamHost: the host part of upstream's URL.
//   - provider: fake provider server that always returns a fixed access token.
//   - header: request headers installed on each request sent through the proxy.
//   - rw: recorder that captures the proxy's response.
//   - authenticator: verifies signatures on the upstream side.
//
// NewSignatureTest populates these fields positionally, so their order matters.
type SignatureTest struct {
	opts          *Options
	upstream      *httptest.Server
	upstreamHost  string
	provider      *httptest.Server
	header        http.Header
	rw            *httptest.ResponseRecorder
	authenticator *SignatureAuthenticator
}
   862  
   863  func NewSignatureTest() *SignatureTest {
   864  	opts := NewOptions()
   865  	opts.CookieSecret = "cookie secret"
   866  	opts.ClientID = "client ID"
   867  	opts.ClientSecret = "client secret"
   868  	opts.EmailDomains = []string{"acm.org"}
   869  
   870  	authenticator := &SignatureAuthenticator{}
   871  	upstream := httptest.NewServer(
   872  		http.HandlerFunc(authenticator.Authenticate))
   873  	upstreamURL, _ := url.Parse(upstream.URL)
   874  	opts.Upstreams = append(opts.Upstreams, upstream.URL)
   875  
   876  	providerHandler := func(w http.ResponseWriter, r *http.Request) {
   877  		w.Write([]byte(`{"access_token": "my_auth_token"}`))
   878  	}
   879  	provider := httptest.NewServer(http.HandlerFunc(providerHandler))
   880  	providerURL, _ := url.Parse(provider.URL)
   881  	opts.provider = NewTestProvider(providerURL, "mbland@acm.org")
   882  
   883  	return &SignatureTest{
   884  		opts,
   885  		upstream,
   886  		upstreamURL.Host,
   887  		provider,
   888  		make(http.Header),
   889  		httptest.NewRecorder(),
   890  		authenticator,
   891  	}
   892  }
   893  
   894  func (st *SignatureTest) Close() {
   895  	st.provider.Close()
   896  	st.upstream.Close()
   897  }
   898  
   899  // fakeNetConn simulates an http.Request.Body buffer that will be consumed
   900  // when it is read by the hmacauth.HmacAuth if not handled properly. See:
   901  //   https://github.com/18F/hmacauth/pull/4
   902  type fakeNetConn struct {
   903  	reqBody string
   904  }
   905  
   906  func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
   907  	if bodyLen := len(fnc.reqBody); bodyLen != 0 {
   908  		copy(p, fnc.reqBody)
   909  		fnc.reqBody = ""
   910  		return bodyLen, io.EOF
   911  	}
   912  	return 0, io.EOF
   913  }
   914  
   915  func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
   916  	err := st.opts.Validate()
   917  	if err != nil {
   918  		panic(err)
   919  	}
   920  	proxy := NewOAuthProxy(st.opts, func(email string) bool { return true })
   921  
   922  	var bodyBuf io.ReadCloser
   923  	if body != "" {
   924  		bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body})
   925  	}
   926  	req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
   927  	req.Header = st.header
   928  
   929  	state := &providers.SessionState{
   930  		Email: "mbland@acm.org", AccessToken: "my_access_token"}
   931  	value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
   932  	if err != nil {
   933  		panic(err)
   934  	}
   935  	for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) {
   936  		req.AddCookie(c)
   937  	}
   938  	// This is used by the upstream to validate the signature.
   939  	st.authenticator.auth = hmacauth.NewHmacAuth(
   940  		crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
   941  	proxy.ServeHTTP(st.rw, req)
   942  }
   943  
   944  func TestNoRequestSignature(t *testing.T) {
   945  	st := NewSignatureTest()
   946  	defer st.Close()
   947  	st.MakeRequestWithExpectedKey("GET", "", "")
   948  	assert.Equal(t, 200, st.rw.Code)
   949  	assert.Equal(t, st.rw.Body.String(), "no signature received")
   950  }
   951  
   952  func TestRequestSignatureGetRequest(t *testing.T) {
   953  	st := NewSignatureTest()
   954  	defer st.Close()
   955  	st.opts.SignatureKey = "sha1:foobar"
   956  	st.MakeRequestWithExpectedKey("GET", "", "foobar")
   957  	assert.Equal(t, 200, st.rw.Code)
   958  	assert.Equal(t, st.rw.Body.String(), "signatures match")
   959  }
   960  
   961  func TestRequestSignaturePostRequest(t *testing.T) {
   962  	st := NewSignatureTest()
   963  	defer st.Close()
   964  	st.opts.SignatureKey = "sha1:foobar"
   965  	payload := `{ "hello": "world!" }`
   966  	st.MakeRequestWithExpectedKey("POST", payload, "foobar")
   967  	assert.Equal(t, 200, st.rw.Code)
   968  	assert.Equal(t, st.rw.Body.String(), "signatures match")
   969  }
   970  
   971  func TestGetRedirect(t *testing.T) {
   972  	options := NewOptions()
   973  	_ = options.Validate()
   974  	require.NotEmpty(t, options.ProxyPrefix)
   975  	proxy := NewOAuthProxy(options, func(s string) bool { return false })
   976  
   977  	tests := []struct {
   978  		name             string
   979  		url              string
   980  		expectedRedirect string
   981  	}{
   982  		{
   983  			name:             "request outside of ProxyPrefix redirects to original URL",
   984  			url:              "/foo/bar",
   985  			expectedRedirect: "/foo/bar",
   986  		},
   987  		{
   988  			name:             "request under ProxyPrefix redirects to root",
   989  			url:              proxy.ProxyPrefix + "/foo/bar",
   990  			expectedRedirect: "/",
   991  		},
   992  	}
   993  	for _, tt := range tests {
   994  		t.Run(tt.name, func(t *testing.T) {
   995  			req, _ := http.NewRequest("GET", tt.url, nil)
   996  			redirect, err := proxy.GetRedirect(req)
   997  
   998  			assert.NoError(t, err)
   999  			assert.Equal(t, tt.expectedRedirect, redirect)
  1000  		})
  1001  	}
  1002  }
  1003  
  1004  type ajaxRequestTest struct {
  1005  	opts  *Options
  1006  	proxy *OAuthProxy
  1007  }
  1008  
  1009  func newAjaxRequestTest() *ajaxRequestTest {
  1010  	test := &ajaxRequestTest{}
  1011  	test.opts = NewOptions()
  1012  	test.opts.CookieSecret = "foobar"
  1013  	test.opts.ClientID = "bazquux"
  1014  	test.opts.ClientSecret = "xyzzyplugh"
  1015  	test.opts.Validate()
  1016  	test.proxy = NewOAuthProxy(test.opts, func(email string) bool {
  1017  		return true
  1018  	})
  1019  	return test
  1020  }
  1021  
  1022  func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) {
  1023  	rw := httptest.NewRecorder()
  1024  	req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
  1025  	if err != nil {
  1026  		return 0, nil, err
  1027  	}
  1028  	req.Header = header
  1029  	test.proxy.ServeHTTP(rw, req)
  1030  	return rw.Code, rw.Header(), nil
  1031  }
  1032  
  1033  func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) {
  1034  	test := newAjaxRequestTest()
  1035  	endpoint := "/test"
  1036  
  1037  	code, rh, err := test.getEndpoint(endpoint, header)
  1038  	assert.NoError(t, err)
  1039  	assert.Equal(t, http.StatusUnauthorized, code)
  1040  	mime := rh.Get("Content-Type")
  1041  	assert.Equal(t, applicationJSON, mime)
  1042  }
  1043  func TestAjaxUnauthorizedRequest1(t *testing.T) {
  1044  	header := make(http.Header)
  1045  	header.Add("accept", applicationJSON)
  1046  
  1047  	testAjaxUnauthorizedRequest(t, header)
  1048  }
  1049  
  1050  func TestAjaxUnauthorizedRequest2(t *testing.T) {
  1051  	header := make(http.Header)
  1052  	header.Add("Accept", applicationJSON)
  1053  
  1054  	testAjaxUnauthorizedRequest(t, header)
  1055  }
  1056  
  1057  func TestAjaxForbiddendRequest(t *testing.T) {
  1058  	test := newAjaxRequestTest()
  1059  	endpoint := "/test"
  1060  	header := make(http.Header)
  1061  	code, rh, err := test.getEndpoint(endpoint, header)
  1062  	assert.NoError(t, err)
  1063  	assert.Equal(t, http.StatusForbidden, code)
  1064  	mime := rh.Get("Content-Type")
  1065  	assert.NotEqual(t, applicationJSON, mime)
  1066  }
  1067  
  1068  func TestClearSplitCookie(t *testing.T) {
  1069  	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
  1070  	var rw = httptest.NewRecorder()
  1071  	req := httptest.NewRequest("get", "/", nil)
  1072  
  1073  	req.AddCookie(&http.Cookie{
  1074  		Name:  "test1",
  1075  		Value: "test1",
  1076  	})
  1077  	req.AddCookie(&http.Cookie{
  1078  		Name:  "oauth2_0",
  1079  		Value: "oauth2_0",
  1080  	})
  1081  	req.AddCookie(&http.Cookie{
  1082  		Name:  "oauth2_1",
  1083  		Value: "oauth2_1",
  1084  	})
  1085  
  1086  	p.ClearSessionCookie(rw, req)
  1087  	header := rw.Header()
  1088  
  1089  	assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries")
  1090  }
  1091  
  1092  func TestClearSingleCookie(t *testing.T) {
  1093  	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
  1094  	var rw = httptest.NewRecorder()
  1095  	req := httptest.NewRequest("get", "/", nil)
  1096  
  1097  	req.AddCookie(&http.Cookie{
  1098  		Name:  "test1",
  1099  		Value: "test1",
  1100  	})
  1101  	req.AddCookie(&http.Cookie{
  1102  		Name:  "oauth2",
  1103  		Value: "oauth2",
  1104  	})
  1105  
  1106  	p.ClearSessionCookie(rw, req)
  1107  	header := rw.Header()
  1108  
  1109  	assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
  1110  }