github.com/blend/go-sdk@v1.20220411.3/oauth/manager_test.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package oauth 9 10 import ( 11 "encoding/base64" 12 "encoding/json" 13 "fmt" 14 "net/http" 15 "net/http/httptest" 16 "net/url" 17 "testing" 18 "time" 19 20 "golang.org/x/oauth2" 21 22 "github.com/golang-jwt/jwt" 23 24 "github.com/blend/go-sdk/assert" 25 "github.com/blend/go-sdk/crypto" 26 "github.com/blend/go-sdk/jwk" 27 "github.com/blend/go-sdk/r2" 28 "github.com/blend/go-sdk/uuid" 29 "github.com/blend/go-sdk/webutil" 30 ) 31 32 func Test_Manager_Finish(t *testing.T) { 33 it := assert.New(t) 34 35 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 36 it.Nil(err) 37 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 38 it.Nil(err) 39 keys := []jwk.JWK{ 40 createJWK(pk0), 41 createJWK(pk1), 42 } 43 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 44 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 45 rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control 46 rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat)) // set expires 47 rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat)) // set date 48 rw.WriteHeader(200) 49 _ = json.NewEncoder(rw).Encode(struct { 50 Keys []jwk.JWK `json:"keys"` 51 }{ 52 Keys: keys, 53 }) 54 })) 55 defer keysResponder.Close() 56 57 codeResponse, err := createCodeResponse("test_client_id", keys[1].KID, pk1) 58 it.Nil(err) 59 60 codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 61 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 62 rw.WriteHeader(200) 63 _, _ = rw.Write(codeResponse) 64 })) 65 defer codeResponder.Close() 66 67 profileResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 68 if accessToken := req.Header.Get(webutil.HeaderAuthorization); accessToken != "Bearer test_access_token" { 69 http.Error(rw, "not authorized", http.StatusUnauthorized) 70 return 71 } 72 73 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 74 rw.WriteHeader(200) 75 fmt.Fprintf(rw, `{ 76 "id": "12012312390931", 77 "email": "example-string@test.blend.com", 78 "verified_email": true, 79 "name": "example-string Dog", 80 "given_name": "example-string", 81 "family_name": "Dog", 82 "picture": "https://example.com/example-string.jpg", 83 "locale": "en", 84 "hd": "test.blend.com" 85 }`) 86 })) 87 defer profileResponder.Close() 88 89 mgr, err := New( 90 OptClientID("test_client_id"), 91 OptClientSecret(crypto.MustCreateKeyString(32)), 92 OptSecret(crypto.MustCreateKey(32)), 93 OptAllowedDomains("test.blend.com"), 94 ) 95 it.Nil(err) 96 mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{ 97 r2.OptURL(keysResponder.URL), 98 } 99 mgr.FetchProfileDefaults = []r2.Option{ 100 r2.OptURL(profileResponder.URL), 101 } 102 mgr.Endpoint = oauth2.Endpoint{ 103 AuthStyle: oauth2.AuthStyleInParams, 104 TokenURL: codeResponder.URL, 105 } 106 finishRequest := &http.Request{ 107 URL: &url.URL{ 108 RawQuery: (url.Values{ 109 "code": []string{"test_code"}, 110 "state": []string{MustSerializeState(mgr.CreateState())}, 111 }).Encode(), 112 }, 113 } 114 115 res, err := mgr.Finish(finishRequest) 116 it.Nil(err) 117 it.Equal("example-string@test.blend.com", res.Profile.Email) 118 it.Equal("example-string", res.Profile.GivenName) 119 it.Equal("Dog", res.Profile.FamilyName) 120 it.Equal("en", res.Profile.Locale) 121 it.Equal("https://example.com/example-string.jpg", res.Profile.PictureURL) 122 } 123 124 func Test_Manager_Finish_disallowedDomain(t *testing.T) { 125 it := assert.New(t) 126 127 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 128 it.Nil(err) 129 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 130 it.Nil(err) 131 keys := []jwk.JWK{ 132 createJWK(pk0), 133 createJWK(pk1), 134 } 135 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 136 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 137 rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control 138 rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat)) // set expires 139 rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat)) // set date 140 rw.WriteHeader(200) 141 _ = json.NewEncoder(rw).Encode(struct { 142 Keys []jwk.JWK `json:"keys"` 143 }{ 144 Keys: keys, 145 }) 146 })) 147 defer keysResponder.Close() 148 149 codeResponse, err := createCodeResponse("test_client_id", keys[1].KID, pk1) 150 it.Nil(err) 151 152 codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 153 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 154 rw.WriteHeader(200) 155 _, _ = rw.Write(codeResponse) 156 })) 157 defer codeResponder.Close() 158 159 profileResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 160 if accessToken := req.URL.Query().Get("access_token"); accessToken != "test_access_token" { 161 http.Error(rw, "not authorized", http.StatusUnauthorized) 162 return 163 } 164 165 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 166 rw.WriteHeader(200) 167 fmt.Fprintf(rw, `{ 168 "id": "12012312390931", 169 "email": "example-string@test.blend.com", 170 "verified_email": true, 171 "name": "example-string Dog", 172 "given_name": "example-string", 173 "family_name": "Dog", 174 "picture": "https://example.com/example-string.jpg", 175 "locale": "en", 176 "hd": "test.blend.com" 177 }`) 178 })) 179 defer profileResponder.Close() 180 181 mgr, err := New( 182 OptClientID("test_client_id"), 183 OptClientSecret(crypto.MustCreateKeyString(32)), 184 OptSecret(crypto.MustCreateKey(32)), 185 OptAllowedDomains("blend.com"), 186 ) 187 it.Nil(err) 188 mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{ 189 r2.OptURL(keysResponder.URL), 190 } 191 mgr.FetchProfileDefaults = []r2.Option{ 192 r2.OptURL(profileResponder.URL), 193 } 194 mgr.Endpoint = oauth2.Endpoint{ 195 AuthStyle: oauth2.AuthStyleInParams, 196 TokenURL: codeResponder.URL, 197 } 198 finishRequest := &http.Request{ 199 URL: &url.URL{ 200 RawQuery: (url.Values{ 201 "code": []string{"test_code"}, 202 "state": []string{MustSerializeState(mgr.CreateState())}, 203 }).Encode(), 204 }, 205 } 206 207 res, err := mgr.Finish(finishRequest) 208 it.NotNil(err) 209 it.Empty(res.Profile.Email) 210 } 211 212 func Test_Manager_Finish_failsAudience(t *testing.T) { 213 it := assert.New(t) 214 215 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 216 it.Nil(err) 217 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 218 it.Nil(err) 219 keys := []jwk.JWK{ 220 createJWK(pk0), 221 createJWK(pk1), 222 } 223 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 224 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 225 rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control 226 rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat)) // set expires 227 rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat)) // set date 228 rw.WriteHeader(200) 229 _ = json.NewEncoder(rw).Encode(struct { 230 Keys []jwk.JWK `json:"keys"` 231 }{ 232 Keys: keys, 233 }) 234 })) 235 defer keysResponder.Close() 236 237 codeResponse, err := createCodeResponse("not_test_client_id", keys[1].KID, pk1) 238 it.Nil(err) 239 240 codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 241 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 242 rw.WriteHeader(200) 243 _, _ = rw.Write(codeResponse) 244 })) 245 defer codeResponder.Close() 246 247 mgr, err := New( 248 OptClientID("test_client_id"), 249 OptClientSecret(crypto.MustCreateKeyString(32)), 250 OptSecret(crypto.MustCreateKey(32)), 251 OptAllowedDomains("blend.com"), 252 ) 253 it.Nil(err) 254 mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{ 255 r2.OptURL(keysResponder.URL), 256 } 257 mgr.Endpoint = oauth2.Endpoint{ 258 AuthStyle: oauth2.AuthStyleInParams, 259 TokenURL: codeResponder.URL, 260 } 261 finishRequest := &http.Request{ 262 URL: &url.URL{ 263 RawQuery: (url.Values{ 264 "code": []string{"test_code"}, 265 "state": []string{MustSerializeState(mgr.CreateState())}, 266 }).Encode(), 267 }, 268 } 269 270 res, err := mgr.Finish(finishRequest) 271 it.NotNil(err) 272 it.Empty(res.Profile.Email) 273 } 274 275 func Test_Manager_Finish_failsVerification(t *testing.T) { 276 it := assert.New(t) 277 278 pk0, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk0pem)) 279 it.Nil(err) 280 pk1, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk1pem)) 281 it.Nil(err) 282 pk2, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(pk2pem)) 283 it.Nil(err) 284 keys := []jwk.JWK{ 285 createJWK(pk0), 286 createJWK(pk1), 287 } 288 keysResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 289 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 290 rw.Header().Set("Cache-Control", "public, max-age=23196, must-revalidate, no-transform") // set cache control 291 rw.Header().Set("Expires", time.Now().UTC().AddDate(0, 1, 0).Format(http.TimeFormat)) // set expires 292 rw.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat)) // set date 293 rw.WriteHeader(200) 294 _ = json.NewEncoder(rw).Encode(struct { 295 Keys []jwk.JWK `json:"keys"` 296 }{ 297 Keys: keys, 298 }) 299 })) 300 defer keysResponder.Close() 301 302 codeResponse, err := createCodeResponse("test_client_id", uuid.V4().String(), pk2) 303 it.Nil(err) 304 305 codeResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 306 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 307 rw.WriteHeader(200) 308 _, _ = rw.Write(codeResponse) 309 })) 310 defer codeResponder.Close() 311 312 profileResponder := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 313 if accessToken := req.URL.Query().Get("access_token"); accessToken != "test_access_token" { 314 http.Error(rw, "not authorized", http.StatusUnauthorized) 315 return 316 } 317 318 rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 319 rw.WriteHeader(200) 320 fmt.Fprintf(rw, `{ 321 "id": "12012312390931", 322 "email": "example-string@test.blend.com", 323 "verified_email": true, 324 "name": "example-string Dog", 325 "given_name": "example-string", 326 "family_name": "Dog", 327 "picture": "https://example.com/example-string.jpg", 328 "locale": "en", 329 "hd": "test.blend.com" 330 }`) 331 })) 332 defer profileResponder.Close() 333 334 mgr, err := New( 335 OptClientID("test_client_id"), 336 OptClientSecret(crypto.MustCreateKeyString(32)), 337 OptSecret(crypto.MustCreateKey(32)), 338 OptAllowedDomains("test.blend.com"), 339 ) 340 it.Nil(err) 341 mgr.PublicKeyCache.FetchPublicKeysDefaults = []r2.Option{ 342 r2.OptURL(keysResponder.URL), 343 } 344 mgr.FetchProfileDefaults = []r2.Option{ 345 r2.OptURL(profileResponder.URL), 346 } 347 mgr.Endpoint = oauth2.Endpoint{ 348 AuthStyle: oauth2.AuthStyleInParams, 349 TokenURL: codeResponder.URL, 350 } 351 finishRequest := &http.Request{ 352 URL: &url.URL{ 353 RawQuery: (url.Values{ 354 "code": []string{"test_code"}, 355 "state": []string{MustSerializeState(mgr.CreateState())}, 356 }).Encode(), 357 }, 358 } 359 360 res, err := mgr.Finish(finishRequest) 361 it.NotNil(err) 362 it.Empty(res.Profile.Email) 363 } 364 365 func Test_MustNew(t *testing.T) { 366 assert := assert.New(t) 367 assert.Empty(MustNew().Secret) 368 assert.NotEmpty(MustNew().Endpoint.AuthURL) 369 assert.NotEmpty(MustNew().Scopes) 370 } 371 372 func Test_NewFromConfig(t *testing.T) { 373 assert := assert.New(t) 374 375 m, err := New(OptConfig(Config{ 376 RedirectURI: "https://app.com/oauth/google", 377 HostedDomain: "foo.com", 378 ClientID: "foo_client", 379 ClientSecret: "bar_secret", 380 })) 381 382 assert.Nil(err) 383 assert.Empty(m.Secret) 384 assert.Equal("https://app.com/oauth/google", m.RedirectURL) 385 assert.Equal("foo_client", m.ClientID) 386 assert.Equal("bar_secret", m.ClientSecret) 387 } 388 389 func Test_NewFromConfigWithSecret(t *testing.T) { 390 assert := assert.New(t) 391 392 m, err := New(OptConfig(Config{ 393 Secret: base64.StdEncoding.EncodeToString([]byte("test string")), 394 })) 395 396 assert.Nil(err) 397 assert.NotEmpty(m.Secret) 398 assert.Equal("test string", string(m.Secret)) 399 } 400 401 func Test_Manager_OAuthURL_FullyQualifiedRedirectURI(t *testing.T) { 402 assert := assert.New(t) 403 404 m, err := New() 405 assert.Nil(err) 406 m.ClientID = "test_client_id" 407 m.HostedDomain = "test.blend.com" 408 m.RedirectURL = "https://local.shortcut-service.centrio.com/oauth/google" 409 410 oauthURL, err := m.OAuthURL(nil) 411 assert.Nil(err) 412 413 parsed, err := url.Parse(oauthURL) 414 assert.Nil(err) 415 assert.Equal("test_client_id", parsed.Query().Get("client_id")) 416 assert.Equal("test.blend.com", parsed.Query().Get("hd"), "we should set the hosted domain if it's configured") 417 } 418 419 func Test_Manager_OAuthURL(t *testing.T) { 420 assert := assert.New(t) 421 422 m, err := New() 423 assert.Nil(err) 424 m.ClientID = "test_client_id" 425 m.RedirectURL = "/oauth/google" 426 427 oauthURL, err := m.OAuthURL(&http.Request{RequestURI: "https://test.blend.com/foo"}) 428 assert.Nil(err) 429 430 _, err = url.Parse(oauthURL) 431 assert.Nil(err) 432 } 433 434 func Test_Manager_OAuthURLRedirect(t *testing.T) { 435 assert := assert.New(t) 436 437 m, err := New() 438 assert.Nil(err) 439 m.ClientID = "test_client_id" 440 m.RedirectURL = "https://local.shortcut-service.centrio.com/oauth/google" 441 442 urlFragment, err := m.OAuthURL(nil, OptStateRedirectURI("bar_foo")) 443 assert.Nil(err) 444 445 u, err := url.Parse(urlFragment) 446 assert.Nil(err) 447 assert.NotEmpty(u.Query().Get("state")) 448 449 state := u.Query().Get("state") 450 deserialized, err := DeserializeState(state) 451 assert.Nil(err) 452 assert.Nil(m.ValidateState(deserialized)) 453 assert.Equal("bar_foo", deserialized.RedirectURI) 454 } 455 456 func Test_Manager_ValidateState(t *testing.T) { 457 assert := assert.New(t) 458 459 insecure := MustNew() 460 assert.Nil(insecure.ValidateState(insecure.CreateState())) 461 462 secure := MustNew() 463 secure.Secret = crypto.MustCreateKey(32) 464 assert.Nil(secure.ValidateState(secure.CreateState())) 465 466 wrongKey := MustNew() 467 wrongKey.Secret = crypto.MustCreateKey(32) 468 469 assert.NotNil(secure.ValidateState(wrongKey.CreateState())) 470 }