github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/internal_util_test.go (about) 1 package providers 2 3 import ( 4 "errors" 5 "net/http" 6 "net/http/httptest" 7 "net/url" 8 "testing" 9 10 "github.com/stretchr/testify/assert" 11 ) 12 13 func updateURL(url *url.URL, hostname string) { 14 url.Scheme = "http" 15 url.Host = hostname 16 } 17 18 type ValidateSessionStateTestProvider struct { 19 *ProviderData 20 } 21 22 func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { 23 return "", errors.New("not implemented") 24 } 25 26 // Note that we're testing the internal validateToken() used to implement 27 // several Provider's ValidateSessionState() implementations 28 func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { 29 return false 30 } 31 32 type ValidateSessionStateTest struct { 33 backend *httptest.Server 34 responseCode int 35 provider *ValidateSessionStateTestProvider 36 header http.Header 37 } 38 39 func NewValidateSessionStateTest() *ValidateSessionStateTest { 40 var vtTest ValidateSessionStateTest 41 42 vtTest.backend = httptest.NewServer( 43 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 if r.URL.Path != "/oauth/tokeninfo" { 45 w.WriteHeader(500) 46 w.Write([]byte("unknown URL")) 47 } 48 tokenParam := r.FormValue("access_token") 49 if tokenParam == "" { 50 missing := false 51 receivedHeaders := r.Header 52 for k := range vtTest.header { 53 received := receivedHeaders.Get(k) 54 expected := vtTest.header.Get(k) 55 if received == "" || received != expected { 56 missing = true 57 } 58 } 59 if missing { 60 w.WriteHeader(500) 61 w.Write([]byte("no token param and missing or incorrect headers")) 62 } 63 } 64 w.WriteHeader(vtTest.responseCode) 65 w.Write([]byte("only code matters; contents disregarded")) 66 67 })) 68 backendURL, _ := url.Parse(vtTest.backend.URL) 69 vtTest.provider = &ValidateSessionStateTestProvider{ 70 ProviderData: &ProviderData{ 71 ValidateURL: &url.URL{ 72 Scheme: "http", 73 Host: backendURL.Host, 74 Path: "/oauth/tokeninfo", 75 }, 76 }, 77 } 78 vtTest.responseCode = 200 79 return &vtTest 80 } 81 82 func (vtTest *ValidateSessionStateTest) Close() { 83 vtTest.backend.Close() 84 } 85 86 func TestValidateSessionStateValidToken(t *testing.T) { 87 vtTest := NewValidateSessionStateTest() 88 defer vtTest.Close() 89 assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) 90 } 91 92 func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { 93 vtTest := NewValidateSessionStateTest() 94 defer vtTest.Close() 95 vtTest.header = make(http.Header) 96 vtTest.header.Set("Authorization", "Bearer foobar") 97 assert.Equal(t, true, 98 validateToken(vtTest.provider, "foobar", vtTest.header)) 99 } 100 101 func TestValidateSessionStateEmptyToken(t *testing.T) { 102 vtTest := NewValidateSessionStateTest() 103 defer vtTest.Close() 104 assert.Equal(t, false, validateToken(vtTest.provider, "", nil)) 105 } 106 107 func TestValidateSessionStateEmptyValidateURL(t *testing.T) { 108 vtTest := NewValidateSessionStateTest() 109 defer vtTest.Close() 110 vtTest.provider.Data().ValidateURL = nil 111 assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) 112 } 113 114 func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { 115 vtTest := NewValidateSessionStateTest() 116 // Close immediately to simulate a network failure 117 vtTest.Close() 118 assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) 119 } 120 121 func TestValidateSessionStateExpiredToken(t *testing.T) { 122 vtTest := NewValidateSessionStateTest() 123 defer vtTest.Close() 124 vtTest.responseCode = 401 125 assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) 126 } 127 128 func TestStripTokenNotPresent(t *testing.T) { 129 test := "http://local.test/api/test?a=1&b=2" 130 assert.Equal(t, test, stripToken(test)) 131 } 132 133 func TestStripToken(t *testing.T) { 134 test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2" 135 expected := "http://local.test/api/test?access_token=dead...&b=1&c=2" 136 assert.Equal(t, expected, stripToken(test)) 137 }