github.com/blend/go-sdk@v1.20220411.3/oauth/manager_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package oauth
     9  
    10  import (
    11  	"encoding/base64"
    12  	"encoding/json"
    13  	"fmt"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"net/url"
    17  	"testing"
    18  	"time"
    19  
    20  	"golang.org/x/oauth2"
    21  
    22  	"github.com/golang-jwt/jwt"
    23  
    24  	"github.com/blend/go-sdk/assert"
    25  	"github.com/blend/go-sdk/crypto"
    26  	"github.com/blend/go-sdk/jwk"
    27  	"github.com/blend/go-sdk/r2"
    28  	"github.com/blend/go-sdk/uuid"
    29  	"github.com/blend/go-sdk/webutil"
    30  )
    31  
    32  func Test_Manager_Finish(t *testing.T) {
    33  	it := assert.New(t)
    34  
    35  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
    36  	it.Nil(err)
    37  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
    38  	it.Nil(err)
    39  	keys := []jwk.JWK{
    40  		createJWK(pk0),
    41  		createJWK(pk1),
    42  	}
    43  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    44  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
    45  		rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control
    46  		rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat))    // set expires
    47  		rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))                        // set date
    48  		rw.WriteHeader(200)
    49  		_ = json.NewEncoder(rw).Encode(struct {
    50  			Keys []jwk.JWK `json:"keys"`
    51  		}{
    52  			Keys: keys,
    53  		})
    54  	}))
    55  	defer keysResponder.Close()
    56  
    57  	codeResponse, err := createCodeResponse("test_client_id", keys[1].KID, pk1)
    58  	it.Nil(err)
    59  
    60  	codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    61  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
    62  		rw.WriteHeader(200)
    63  		_, _ = rw.Write(codeResponse)
    64  	}))
    65  	defer codeResponder.Close()
    66  
    67  	profileResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    68  		if accessToken := req.Header.Get(webutil.HeaderAuthorization); accessToken != "Bearer test_access_token" {
    69  			http.Error(rw, "not authorized", http.StatusUnauthorized)
    70  			return
    71  		}
    72  
    73  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
    74  		rw.WriteHeader(200)
    75  		fmt.Fprintf(rw, `{
    76  			"id": "12012312390931",
    77  			"email": "example-string@test.blend.com",
    78  			"verified_email": true,
    79  			"name": "example-string Dog",
    80  			"given_name": "example-string",
    81  			"family_name": "Dog",
    82  			"picture": "https://example.com/example-string.jpg",
    83  			"locale": "en",
    84  			"hd": "test.blend.com"
    85  		  }`)
    86  	}))
    87  	defer profileResponder.Close()
    88  
    89  	mgr, err := New(
    90  		OptClientID("test_client_id"),
    91  		OptClientSecret(crypto.MustCreateKeyString(32)),
    92  		OptSecret(crypto.MustCreateKey(32)),
    93  		OptAllowedDomains("test.blend.com"),
    94  	)
    95  	it.Nil(err)
    96  	mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{
    97  		r2.OptURL(keysResponder.URL),
    98  	}
    99  	mgr.FetchProfileDefaults = []r2.Option{
   100  		r2.OptURL(profileResponder.URL),
   101  	}
   102  	mgr.Endpoint = oauth2.Endpoint{
   103  		AuthStyle: oauth2.AuthStyleInParams,
   104  		TokenURL:  codeResponder.URL,
   105  	}
   106  	finishRequest := &http.Request{
   107  		URL: &url.URL{
   108  			RawQuery: (url.Values{
   109  				"code":  []string{"test_code"},
   110  				"state": []string{MustSerializeState(mgr.CreateState())},
   111  			}).Encode(),
   112  		},
   113  	}
   114  
   115  	res, err := mgr.Finish(finishRequest)
   116  	it.Nil(err)
   117  	it.Equal("example-string@test.blend.com", res.Profile.Email)
   118  	it.Equal("example-string", res.Profile.GivenName)
   119  	it.Equal("Dog", res.Profile.FamilyName)
   120  	it.Equal("en", res.Profile.Locale)
   121  	it.Equal("https://example.com/example-string.jpg", res.Profile.PictureURL)
   122  }
   123  
   124  func Test_Manager_Finish_disallowedDomain(t *testing.T) {
   125  	it := assert.New(t)
   126  
   127  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
   128  	it.Nil(err)
   129  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
   130  	it.Nil(err)
   131  	keys := []jwk.JWK{
   132  		createJWK(pk0),
   133  		createJWK(pk1),
   134  	}
   135  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   136  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   137  		rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control
   138  		rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat))    // set expires
   139  		rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))                        // set date
   140  		rw.WriteHeader(200)
   141  		_ = json.NewEncoder(rw).Encode(struct {
   142  			Keys []jwk.JWK `json:"keys"`
   143  		}{
   144  			Keys: keys,
   145  		})
   146  	}))
   147  	defer keysResponder.Close()
   148  
   149  	codeResponse, err := createCodeResponse("test_client_id", keys[1].KID, pk1)
   150  	it.Nil(err)
   151  
   152  	codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   153  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   154  		rw.WriteHeader(200)
   155  		_, _ = rw.Write(codeResponse)
   156  	}))
   157  	defer codeResponder.Close()
   158  
   159  	profileResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   160  		if accessToken := req.URL.Query().Get("access_token"); accessToken != "test_access_token" {
   161  			http.Error(rw, "not authorized", http.StatusUnauthorized)
   162  			return
   163  		}
   164  
   165  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   166  		rw.WriteHeader(200)
   167  		fmt.Fprintf(rw, `{
   168  			"id": "12012312390931",
   169  			"email": "example-string@test.blend.com",
   170  			"verified_email": true,
   171  			"name": "example-string Dog",
   172  			"given_name": "example-string",
   173  			"family_name": "Dog",
   174  			"picture": "https://example.com/example-string.jpg",
   175  			"locale": "en",
   176  			"hd": "test.blend.com"
   177  		  }`)
   178  	}))
   179  	defer profileResponder.Close()
   180  
   181  	mgr, err := New(
   182  		OptClientID("test_client_id"),
   183  		OptClientSecret(crypto.MustCreateKeyString(32)),
   184  		OptSecret(crypto.MustCreateKey(32)),
   185  		OptAllowedDomains("blend.com"),
   186  	)
   187  	it.Nil(err)
   188  	mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{
   189  		r2.OptURL(keysResponder.URL),
   190  	}
   191  	mgr.FetchProfileDefaults = []r2.Option{
   192  		r2.OptURL(profileResponder.URL),
   193  	}
   194  	mgr.Endpoint = oauth2.Endpoint{
   195  		AuthStyle: oauth2.AuthStyleInParams,
   196  		TokenURL:  codeResponder.URL,
   197  	}
   198  	finishRequest := &http.Request{
   199  		URL: &url.URL{
   200  			RawQuery: (url.Values{
   201  				"code":  []string{"test_code"},
   202  				"state": []string{MustSerializeState(mgr.CreateState())},
   203  			}).Encode(),
   204  		},
   205  	}
   206  
   207  	res, err := mgr.Finish(finishRequest)
   208  	it.NotNil(err)
   209  	it.Empty(res.Profile.Email)
   210  }
   211  
   212  func Test_Manager_Finish_failsAudience(t *testing.T) {
   213  	it := assert.New(t)
   214  
   215  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
   216  	it.Nil(err)
   217  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
   218  	it.Nil(err)
   219  	keys := []jwk.JWK{
   220  		createJWK(pk0),
   221  		createJWK(pk1),
   222  	}
   223  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   224  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   225  		rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control
   226  		rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat))    // set expires
   227  		rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))                        // set date
   228  		rw.WriteHeader(200)
   229  		_ = json.NewEncoder(rw).Encode(struct {
   230  			Keys []jwk.JWK `json:"keys"`
   231  		}{
   232  			Keys: keys,
   233  		})
   234  	}))
   235  	defer keysResponder.Close()
   236  
   237  	codeResponse, err := createCodeResponse("not_test_client_id", keys[1].KID, pk1)
   238  	it.Nil(err)
   239  
   240  	codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   241  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   242  		rw.WriteHeader(200)
   243  		_, _ = rw.Write(codeResponse)
   244  	}))
   245  	defer codeResponder.Close()
   246  
   247  	mgr, err := New(
   248  		OptClientID("test_client_id"),
   249  		OptClientSecret(crypto.MustCreateKeyString(32)),
   250  		OptSecret(crypto.MustCreateKey(32)),
   251  		OptAllowedDomains("blend.com"),
   252  	)
   253  	it.Nil(err)
   254  	mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{
   255  		r2.OptURL(keysResponder.URL),
   256  	}
   257  	mgr.Endpoint = oauth2.Endpoint{
   258  		AuthStyle: oauth2.AuthStyleInParams,
   259  		TokenURL:  codeResponder.URL,
   260  	}
   261  	finishRequest := &http.Request{
   262  		URL: &url.URL{
   263  			RawQuery: (url.Values{
   264  				"code":  []string{"test_code"},
   265  				"state": []string{MustSerializeState(mgr.CreateState())},
   266  			}).Encode(),
   267  		},
   268  	}
   269  
   270  	res, err := mgr.Finish(finishRequest)
   271  	it.NotNil(err)
   272  	it.Empty(res.Profile.Email)
   273  }
   274  
   275  func Test_Manager_Finish_failsVerification(t *testing.T) {
   276  	it := assert.New(t)
   277  
   278  	pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem))
   279  	it.Nil(err)
   280  	pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem))
   281  	it.Nil(err)
   282  	pk2, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk2pem))
   283  	it.Nil(err)
   284  	keys := []jwk.JWK{
   285  		createJWK(pk0),
   286  		createJWK(pk1),
   287  	}
   288  	keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   289  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   290  		rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control
   291  		rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat))    // set expires
   292  		rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))                        // set date
   293  		rw.WriteHeader(200)
   294  		_ = json.NewEncoder(rw).Encode(struct {
   295  			Keys []jwk.JWK `json:"keys"`
   296  		}{
   297  			Keys: keys,
   298  		})
   299  	}))
   300  	defer keysResponder.Close()
   301  
   302  	codeResponse, err := createCodeResponse("test_client_id", uuid.V4().String(), pk2)
   303  	it.Nil(err)
   304  
   305  	codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   306  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   307  		rw.WriteHeader(200)
   308  		_, _ = rw.Write(codeResponse)
   309  	}))
   310  	defer codeResponder.Close()
   311  
   312  	profileResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   313  		if accessToken := req.URL.Query().Get("access_token"); accessToken != "test_access_token" {
   314  			http.Error(rw, "not authorized", http.StatusUnauthorized)
   315  			return
   316  		}
   317  
   318  		rw.Header().Set("Content-Type", "application/json; charset=UTF-8")
   319  		rw.WriteHeader(200)
   320  		fmt.Fprintf(rw, `{
   321  			"id": "12012312390931",
   322  			"email": "example-string@test.blend.com",
   323  			"verified_email": true,
   324  			"name": "example-string Dog",
   325  			"given_name": "example-string",
   326  			"family_name": "Dog",
   327  			"picture": "https://example.com/example-string.jpg",
   328  			"locale": "en",
   329  			"hd": "test.blend.com"
   330  		  }`)
   331  	}))
   332  	defer profileResponder.Close()
   333  
   334  	mgr, err := New(
   335  		OptClientID("test_client_id"),
   336  		OptClientSecret(crypto.MustCreateKeyString(32)),
   337  		OptSecret(crypto.MustCreateKey(32)),
   338  		OptAllowedDomains("test.blend.com"),
   339  	)
   340  	it.Nil(err)
   341  	mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{
   342  		r2.OptURL(keysResponder.URL),
   343  	}
   344  	mgr.FetchProfileDefaults = []r2.Option{
   345  		r2.OptURL(profileResponder.URL),
   346  	}
   347  	mgr.Endpoint = oauth2.Endpoint{
   348  		AuthStyle: oauth2.AuthStyleInParams,
   349  		TokenURL:  codeResponder.URL,
   350  	}
   351  	finishRequest := &http.Request{
   352  		URL: &url.URL{
   353  			RawQuery: (url.Values{
   354  				"code":  []string{"test_code"},
   355  				"state": []string{MustSerializeState(mgr.CreateState())},
   356  			}).Encode(),
   357  		},
   358  	}
   359  
   360  	res, err := mgr.Finish(finishRequest)
   361  	it.NotNil(err)
   362  	it.Empty(res.Profile.Email)
   363  }
   364  
   365  func Test_MustNew(t *testing.T) {
   366  	assert := assert.New(t)
   367  	assert.Empty(MustNew().Secret)
   368  	assert.NotEmpty(MustNew().Endpoint.AuthURL)
   369  	assert.NotEmpty(MustNew().Scopes)
   370  }
   371  
   372  func Test_NewFromConfig(t *testing.T) {
   373  	assert := assert.New(t)
   374  
   375  	m, err := New(OptConfig(Config{
   376  		RedirectURI:  "https://app.com/oauth/google",
   377  		HostedDomain: "foo.com",
   378  		ClientID:     "foo_client",
   379  		ClientSecret: "bar_secret",
   380  	}))
   381  
   382  	assert.Nil(err)
   383  	assert.Empty(m.Secret)
   384  	assert.Equal("https://app.com/oauth/google", m.RedirectURL)
   385  	assert.Equal("foo_client", m.ClientID)
   386  	assert.Equal("bar_secret", m.ClientSecret)
   387  }
   388  
   389  func Test_NewFromConfigWithSecret(t *testing.T) {
   390  	assert := assert.New(t)
   391  
   392  	m, err := New(OptConfig(Config{
   393  		Secret: base64.StdEncoding.EncodeToString([]byte("test string")),
   394  	}))
   395  
   396  	assert.Nil(err)
   397  	assert.NotEmpty(m.Secret)
   398  	assert.Equal("test string", string(m.Secret))
   399  }
   400  
   401  func Test_Manager_OAuthURL_FullyQualifiedRedirectURI(t *testing.T) {
   402  	assert := assert.New(t)
   403  
   404  	m, err := New()
   405  	assert.Nil(err)
   406  	m.ClientID = "test_client_id"
   407  	m.HostedDomain = "test.blend.com"
   408  	m.RedirectURL = "https://local.shortcut-service.centrio.com/oauth/google"
   409  
   410  	oauthURL, err := m.OAuthURL(nil)
   411  	assert.Nil(err)
   412  
   413  	parsed, err := url.Parse(oauthURL)
   414  	assert.Nil(err)
   415  	assert.Equal("test_client_id", parsed.Query().Get("client_id"))
   416  	assert.Equal("test.blend.com", parsed.Query().Get("hd"), "we should set the hosted domain if it's configured")
   417  }
   418  
   419  func Test_Manager_OAuthURL(t *testing.T) {
   420  	assert := assert.New(t)
   421  
   422  	m, err := New()
   423  	assert.Nil(err)
   424  	m.ClientID = "test_client_id"
   425  	m.RedirectURL = "/oauth/google"
   426  
   427  	oauthURL, err := m.OAuthURL(&http.Request{RequestURI: "https://test.blend.com/foo"})
   428  	assert.Nil(err)
   429  
   430  	_, err = url.Parse(oauthURL)
   431  	assert.Nil(err)
   432  }
   433  
   434  func Test_Manager_OAuthURLRedirect(t *testing.T) {
   435  	assert := assert.New(t)
   436  
   437  	m, err := New()
   438  	assert.Nil(err)
   439  	m.ClientID = "test_client_id"
   440  	m.RedirectURL = "https://local.shortcut-service.centrio.com/oauth/google"
   441  
   442  	urlFragment, err := m.OAuthURL(nil, OptStateRedirectURI("bar_foo"))
   443  	assert.Nil(err)
   444  
   445  	u, err := url.Parse(urlFragment)
   446  	assert.Nil(err)
   447  	assert.NotEmpty(u.Query().Get("state"))
   448  
   449  	state := u.Query().Get("state")
   450  	deserialized, err := DeserializeState(state)
   451  	assert.Nil(err)
   452  	assert.Nil(m.ValidateState(deserialized))
   453  	assert.Equal("bar_foo", deserialized.RedirectURI)
   454  }
   455  
   456  func Test_Manager_ValidateState(t *testing.T) {
   457  	assert := assert.New(t)
   458  
   459  	insecure := MustNew()
   460  	assert.Nil(insecure.ValidateState(insecure.CreateState()))
   461  
   462  	secure := MustNew()
   463  	secure.Secret = crypto.MustCreateKey(32)
   464  	assert.Nil(secure.ValidateState(secure.CreateState()))
   465  
   466  	wrongKey := MustNew()
   467  	wrongKey.Secret = crypto.MustCreateKey(32)
   468  
   469  	assert.NotNil(secure.ValidateState(wrongKey.CreateState()))
   470  }