github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/internal_util_test.go (about)

     1  package providers
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  func updateURL(url *url.URL, hostname string) {
    14  	url.Scheme = "http"
    15  	url.Host = hostname
    16  }
    17  
    18  type ValidateSessionStateTestProvider struct {
    19  	*ProviderData
    20  }
    21  
    22  func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
    23  	return "", errors.New("not implemented")
    24  }
    25  
    26  // Note that we're testing the internal validateToken() used to implement
    27  // several Provider's ValidateSessionState() implementations
    28  func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
    29  	return false
    30  }
    31  
    32  type ValidateSessionStateTest struct {
    33  	backend      *httptest.Server
    34  	responseCode int
    35  	provider     *ValidateSessionStateTestProvider
    36  	header       http.Header
    37  }
    38  
    39  func NewValidateSessionStateTest() *ValidateSessionStateTest {
    40  	var vtTest ValidateSessionStateTest
    41  
    42  	vtTest.backend = httptest.NewServer(
    43  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    44  			if r.URL.Path != "/oauth/tokeninfo" {
    45  				w.WriteHeader(500)
    46  				w.Write([]byte("unknown URL"))
    47  			}
    48  			tokenParam := r.FormValue("access_token")
    49  			if tokenParam == "" {
    50  				missing := false
    51  				receivedHeaders := r.Header
    52  				for k := range vtTest.header {
    53  					received := receivedHeaders.Get(k)
    54  					expected := vtTest.header.Get(k)
    55  					if received == "" || received != expected {
    56  						missing = true
    57  					}
    58  				}
    59  				if missing {
    60  					w.WriteHeader(500)
    61  					w.Write([]byte("no token param and missing or incorrect headers"))
    62  				}
    63  			}
    64  			w.WriteHeader(vtTest.responseCode)
    65  			w.Write([]byte("only code matters; contents disregarded"))
    66  
    67  		}))
    68  	backendURL, _ := url.Parse(vtTest.backend.URL)
    69  	vtTest.provider = &ValidateSessionStateTestProvider{
    70  		ProviderData: &ProviderData{
    71  			ValidateURL: &url.URL{
    72  				Scheme: "http",
    73  				Host:   backendURL.Host,
    74  				Path:   "/oauth/tokeninfo",
    75  			},
    76  		},
    77  	}
    78  	vtTest.responseCode = 200
    79  	return &vtTest
    80  }
    81  
    82  func (vtTest *ValidateSessionStateTest) Close() {
    83  	vtTest.backend.Close()
    84  }
    85  
    86  func TestValidateSessionStateValidToken(t *testing.T) {
    87  	vtTest := NewValidateSessionStateTest()
    88  	defer vtTest.Close()
    89  	assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil))
    90  }
    91  
    92  func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
    93  	vtTest := NewValidateSessionStateTest()
    94  	defer vtTest.Close()
    95  	vtTest.header = make(http.Header)
    96  	vtTest.header.Set("Authorization", "Bearer foobar")
    97  	assert.Equal(t, true,
    98  		validateToken(vtTest.provider, "foobar", vtTest.header))
    99  }
   100  
   101  func TestValidateSessionStateEmptyToken(t *testing.T) {
   102  	vtTest := NewValidateSessionStateTest()
   103  	defer vtTest.Close()
   104  	assert.Equal(t, false, validateToken(vtTest.provider, "", nil))
   105  }
   106  
   107  func TestValidateSessionStateEmptyValidateURL(t *testing.T) {
   108  	vtTest := NewValidateSessionStateTest()
   109  	defer vtTest.Close()
   110  	vtTest.provider.Data().ValidateURL = nil
   111  	assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
   112  }
   113  
   114  func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
   115  	vtTest := NewValidateSessionStateTest()
   116  	// Close immediately to simulate a network failure
   117  	vtTest.Close()
   118  	assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
   119  }
   120  
   121  func TestValidateSessionStateExpiredToken(t *testing.T) {
   122  	vtTest := NewValidateSessionStateTest()
   123  	defer vtTest.Close()
   124  	vtTest.responseCode = 401
   125  	assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
   126  }
   127  
   128  func TestStripTokenNotPresent(t *testing.T) {
   129  	test := "http://local.test/api/test?a=1&b=2"
   130  	assert.Equal(t, test, stripToken(test))
   131  }
   132  
   133  func TestStripToken(t *testing.T) {
   134  	test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2"
   135  	expected := "http://local.test/api/test?access_token=dead...&b=1&c=2"
   136  	assert.Equal(t, expected, stripToken(test))
   137  }