github.com/snowflakedb/gosnowflake@v1.9.0/auth_test.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "crypto/rand" 8 "crypto/rsa" 9 "database/sql" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "net/http" 14 "net/url" 15 "os" 16 "runtime" 17 "testing" 18 "time" 19 20 "github.com/form3tech-oss/jwt-go" 21 ) 22 23 func TestUnitPostAuth(t *testing.T) { 24 sr := &snowflakeRestful{ 25 TokenAccessor: getSimpleTokenAccessor(), 26 FuncAuthPost: postAuthTestAfterRenew, 27 } 28 var err error 29 bodyCreator := func() ([]byte, error) { 30 return []byte{0x12, 0x34}, nil 31 } 32 _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) 33 if err != nil { 34 t.Fatalf("err: %v", err) 35 } 36 sr.FuncAuthPost = postAuthTestError 37 _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) 38 if err == nil { 39 t.Fatal("should have failed to auth for unknown reason") 40 } 41 sr.FuncAuthPost = postAuthTestAppBadGatewayError 42 _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) 43 if err == nil { 44 t.Fatal("should have failed to auth for unknown reason") 45 } 46 sr.FuncAuthPost = postAuthTestAppForbiddenError 47 _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) 48 if err == nil { 49 t.Fatal("should have failed to auth for unknown reason") 50 } 51 sr.FuncAuthPost = postAuthTestAppUnexpectedError 52 _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) 53 if err == nil { 54 t.Fatal("should have failed to auth for unknown reason") 55 } 56 } 57 58 func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { 59 return nil, &SnowflakeError{ 60 Number: ErrCodeServiceUnavailable, 61 } 62 } 63 64 func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { 65 return nil, &SnowflakeError{ 66 Number: ErrCodeFailedToConnect, 67 } 68 } 69 70 func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { 71 return nil, &SnowflakeError{ 72 Number: ErrFailedToAuth, 73 } 74 } 75 76 func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { 77 return &authResponse{ 78 Success: false, 79 Code: "98765", 80 Message: "wrong!", 81 }, nil 82 } 83 84 func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { 85 return &authResponse{ 86 Success: false, 87 Code: "abcdef", 88 Message: "wrong!", 89 }, nil 90 } 91 92 func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { 93 return &authResponse{ 94 Success: true, 95 Data: authResponseMain{ 96 Token: "t", 97 MasterToken: "m", 98 SessionInfo: authResponseSessionInfo{ 99 DatabaseName: "dbn", 100 }, 101 }, 102 }, nil 103 } 104 105 func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 106 var ar authRequest 107 jsonBody, err := bodyCreator() 108 if err != nil { 109 return nil, err 110 } 111 if err = json.Unmarshal(jsonBody, &ar); err != nil { 112 return nil, err 113 } 114 if ar.Data.RawSAMLResponse == "" { 115 return nil, errors.New("SAML response is empty") 116 } 117 return &authResponse{ 118 Success: true, 119 Data: authResponseMain{ 120 Token: "t", 121 MasterToken: "m", 122 SessionInfo: authResponseSessionInfo{ 123 DatabaseName: "dbn", 124 }, 125 }, 126 }, nil 127 } 128 129 // Checks that the request body generated when authenticating with OAuth 130 // contains all the necessary values. 131 func postAuthCheckOAuth( 132 _ context.Context, 133 _ *snowflakeRestful, 134 _ *http.Client, 135 _ *url.Values, _ map[string]string, 136 bodyCreator bodyCreatorType, 137 _ time.Duration, 138 ) (*authResponse, error) { 139 var ar authRequest 140 jsonBody, _ := bodyCreator() 141 if err := json.Unmarshal(jsonBody, &ar); err != nil { 142 return nil, err 143 } 144 if ar.Data.Authenticator != AuthTypeOAuth.String() { 145 return nil, errors.New("Authenticator is not OAUTH") 146 } 147 if ar.Data.Token == "" { 148 return nil, errors.New("Token is empty") 149 } 150 if ar.Data.LoginName == "" { 151 return nil, errors.New("Login name is empty") 152 } 153 return &authResponse{ 154 Success: true, 155 Data: authResponseMain{ 156 Token: "t", 157 MasterToken: "m", 158 SessionInfo: authResponseSessionInfo{ 159 DatabaseName: "dbn", 160 }, 161 }, 162 }, nil 163 } 164 165 func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 166 var ar authRequest 167 jsonBody, _ := bodyCreator() 168 if err := json.Unmarshal(jsonBody, &ar); err != nil { 169 return nil, err 170 } 171 if ar.Data.Passcode != "987654321" || ar.Data.ExtAuthnDuoMethod != "passcode" { 172 return nil, fmt.Errorf("passcode didn't match. expected: 987654321, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod) 173 } 174 return &authResponse{ 175 Success: true, 176 Data: authResponseMain{ 177 Token: "t", 178 MasterToken: "m", 179 SessionInfo: authResponseSessionInfo{ 180 DatabaseName: "dbn", 181 }, 182 }, 183 }, nil 184 } 185 186 func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 187 var ar authRequest 188 jsonBody, _ := bodyCreator() 189 if err := json.Unmarshal(jsonBody, &ar); err != nil { 190 return nil, err 191 } 192 if ar.Data.Passcode != "" || ar.Data.ExtAuthnDuoMethod != "passcode" { 193 return nil, fmt.Errorf("passcode must be empty, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod) 194 } 195 return &authResponse{ 196 Success: true, 197 Data: authResponseMain{ 198 Token: "t", 199 MasterToken: "m", 200 SessionInfo: authResponseSessionInfo{ 201 DatabaseName: "dbn", 202 }, 203 }, 204 }, nil 205 } 206 207 // JWT token validate callback function to check the JWT token 208 // It uses the public key paired with the testPrivKey 209 func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 210 var ar authRequest 211 jsonBody, _ := bodyCreator() 212 if err := json.Unmarshal(jsonBody, &ar); err != nil { 213 return nil, err 214 } 215 if ar.Data.Authenticator != AuthTypeJwt.String() { 216 return nil, errors.New("Authenticator is not JWT") 217 } 218 219 tokenString := ar.Data.Token 220 221 // Validate token 222 _, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { 223 // Don't forget to validate the alg is what you expect: 224 if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { 225 return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) 226 } 227 228 return testPrivKey.Public(), nil 229 }) 230 if err != nil { 231 return nil, err 232 } 233 234 return &authResponse{ 235 Success: true, 236 Data: authResponseMain{ 237 Token: "t", 238 MasterToken: "m", 239 SessionInfo: authResponseSessionInfo{ 240 DatabaseName: "dbn", 241 }, 242 }, 243 }, nil 244 } 245 246 func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 247 var ar authRequest 248 jsonBody, _ := bodyCreator() 249 if err := json.Unmarshal(jsonBody, &ar); err != nil { 250 return nil, err 251 } 252 253 if ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"] != true { 254 return nil, fmt.Errorf("expected client_request_mfa_token to be true but was %v", ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"]) 255 } 256 return &authResponse{ 257 Success: true, 258 Data: authResponseMain{ 259 Token: "t", 260 MasterToken: "m", 261 MfaToken: "mockedMfaToken", 262 SessionInfo: authResponseSessionInfo{ 263 DatabaseName: "dbn", 264 }, 265 }, 266 }, nil 267 } 268 269 func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 270 var ar authRequest 271 jsonBody, _ := bodyCreator() 272 if err := json.Unmarshal(jsonBody, &ar); err != nil { 273 return nil, err 274 } 275 276 if ar.Data.Token != "mockedMfaToken" { 277 return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token) 278 } 279 return &authResponse{ 280 Success: true, 281 Data: authResponseMain{ 282 Token: "t", 283 MasterToken: "m", 284 MfaToken: "mockedMfaToken", 285 SessionInfo: authResponseSessionInfo{ 286 DatabaseName: "dbn", 287 }, 288 }, 289 }, nil 290 } 291 292 func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 293 var ar authRequest 294 jsonBody, _ := bodyCreator() 295 if err := json.Unmarshal(jsonBody, &ar); err != nil { 296 return nil, err 297 } 298 299 if ar.Data.Token != "mockedMfaToken" { 300 return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token) 301 } 302 return &authResponse{ 303 Success: false, 304 Data: authResponseMain{}, 305 Message: "auth failed", 306 Code: "260008", 307 }, nil 308 } 309 310 func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 311 var ar authRequest 312 jsonBody, _ := bodyCreator() 313 if err := json.Unmarshal(jsonBody, &ar); err != nil { 314 return nil, err 315 } 316 317 if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true { 318 return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"]) 319 } 320 return &authResponse{ 321 Success: true, 322 Data: authResponseMain{ 323 Token: "t", 324 MasterToken: "m", 325 IDToken: "mockedIDToken", 326 SessionInfo: authResponseSessionInfo{ 327 DatabaseName: "dbn", 328 }, 329 }, 330 }, nil 331 } 332 333 func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 334 var ar authRequest 335 jsonBody, _ := bodyCreator() 336 if err := json.Unmarshal(jsonBody, &ar); err != nil { 337 return nil, err 338 } 339 340 if ar.Data.Token != "mockedIDToken" { 341 return nil, fmt.Errorf("unexpected mfatoken: %v", ar.Data.Token) 342 } 343 return &authResponse{ 344 Success: true, 345 Data: authResponseMain{ 346 Token: "t", 347 MasterToken: "m", 348 IDToken: "mockedIDToken", 349 SessionInfo: authResponseSessionInfo{ 350 DatabaseName: "dbn", 351 }, 352 }, 353 }, nil 354 } 355 356 func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 357 var ar authRequest 358 jsonBody, _ := bodyCreator() 359 if err := json.Unmarshal(jsonBody, &ar); err != nil { 360 return nil, err 361 } 362 363 if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true { 364 return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"]) 365 } 366 return &authResponse{ 367 Success: false, 368 Data: authResponseMain{}, 369 Message: "auth failed", 370 Code: "260008", 371 }, nil 372 } 373 374 func postAuthOktaWithNewToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { 375 var ar authRequest 376 377 cfg := &Config{ 378 Authenticator: AuthTypeOkta, 379 } 380 381 // Retry 3 times and success 382 client := &fakeHTTPClient{ 383 cnt: 3, 384 success: true, 385 statusCode: 429, 386 } 387 388 urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_guid=testguid") 389 if err != nil { 390 return &authResponse{}, err 391 } 392 393 body := func() ([]byte, error) { 394 jsonBody, _ := bodyCreator() 395 if err := json.Unmarshal(jsonBody, &ar); err != nil { 396 return nil, err 397 } 398 return jsonBody, err 399 } 400 401 _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, cfg).doPost().setBodyCreator(body).execute() 402 if err != nil { 403 return &authResponse{}, err 404 } 405 406 return &authResponse{ 407 Success: true, 408 Data: authResponseMain{ 409 Token: "t", 410 MasterToken: "m", 411 MfaToken: "mockedMfaToken", 412 SessionInfo: authResponseSessionInfo{ 413 DatabaseName: "dbn", 414 }, 415 }, 416 }, nil 417 } 418 419 func getDefaultSnowflakeConn() *snowflakeConn { 420 sc := &snowflakeConn{ 421 rest: &snowflakeRestful{ 422 TokenAccessor: getSimpleTokenAccessor(), 423 }, 424 cfg: &Config{ 425 Account: "a", 426 User: "u", 427 Password: "p", 428 Database: "d", 429 Schema: "s", 430 Warehouse: "w", 431 Role: "r", 432 Region: "", 433 Params: make(map[string]*string), 434 PasscodeInPassword: false, 435 Passcode: "", 436 Application: "testapp", 437 }, 438 telemetry: &snowflakeTelemetry{enabled: false}, 439 } 440 return sc 441 } 442 443 func TestUnitAuthenticateWithTokenAccessor(t *testing.T) { 444 expectedSessionID := int64(123) 445 expectedMasterToken := "master_token" 446 expectedToken := "auth_token" 447 448 ta := getSimpleTokenAccessor() 449 ta.SetTokens(expectedToken, expectedMasterToken, expectedSessionID) 450 sc := getDefaultSnowflakeConn() 451 sc.cfg.Authenticator = AuthTypeTokenAccessor 452 sc.cfg.TokenAccessor = ta 453 sr := &snowflakeRestful{ 454 FuncPostAuth: postAuthFailServiceIssue, 455 TokenAccessor: ta, 456 } 457 sc.rest = sr 458 459 // FuncPostAuth is set to fail, but AuthTypeTokenAccessor should not even make a call to FuncPostAuth 460 resp, err := authenticate(context.Background(), sc, []byte{}, []byte{}) 461 if err != nil { 462 t.Fatalf("should not have failed, err %v", err) 463 } 464 465 if resp.SessionID != expectedSessionID { 466 t.Fatalf("Expected session id %v but got %v", expectedSessionID, resp.SessionID) 467 } 468 if resp.Token != expectedToken { 469 t.Fatalf("Expected token %v but got %v", expectedToken, resp.Token) 470 } 471 if resp.MasterToken != expectedMasterToken { 472 t.Fatalf("Expected master token %v but got %v", expectedMasterToken, resp.MasterToken) 473 } 474 if resp.SessionInfo.DatabaseName != sc.cfg.Database { 475 t.Fatalf("Expected database %v but got %v", sc.cfg.Database, resp.SessionInfo.DatabaseName) 476 } 477 if resp.SessionInfo.WarehouseName != sc.cfg.Warehouse { 478 t.Fatalf("Expected warehouse %v but got %v", sc.cfg.Warehouse, resp.SessionInfo.WarehouseName) 479 } 480 if resp.SessionInfo.RoleName != sc.cfg.Role { 481 t.Fatalf("Expected role %v but got %v", sc.cfg.Role, resp.SessionInfo.RoleName) 482 } 483 if resp.SessionInfo.SchemaName != sc.cfg.Schema { 484 t.Fatalf("Expected schema %v but got %v", sc.cfg.Schema, resp.SessionInfo.SchemaName) 485 } 486 } 487 488 func TestUnitAuthenticate(t *testing.T) { 489 var err error 490 var driverErr *SnowflakeError 491 var ok bool 492 493 ta := getSimpleTokenAccessor() 494 sc := getDefaultSnowflakeConn() 495 sr := &snowflakeRestful{ 496 FuncPostAuth: postAuthFailServiceIssue, 497 TokenAccessor: ta, 498 } 499 sc.rest = sr 500 501 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 502 if err == nil { 503 t.Fatal("should have failed.") 504 } 505 driverErr, ok = err.(*SnowflakeError) 506 if !ok || driverErr.Number != ErrCodeServiceUnavailable { 507 t.Fatalf("Snowflake error is expected. err: %v", driverErr) 508 } 509 sr.FuncPostAuth = postAuthFailWrongAccount 510 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 511 if err == nil { 512 t.Fatal("should have failed.") 513 } 514 driverErr, ok = err.(*SnowflakeError) 515 if !ok || driverErr.Number != ErrCodeFailedToConnect { 516 t.Fatalf("Snowflake error is expected. err: %v", driverErr) 517 } 518 sr.FuncPostAuth = postAuthFailUnknown 519 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 520 if err == nil { 521 t.Fatal("should have failed.") 522 } 523 driverErr, ok = err.(*SnowflakeError) 524 if !ok || driverErr.Number != ErrFailedToAuth { 525 t.Fatalf("Snowflake error is expected. err: %v", driverErr) 526 } 527 ta.SetTokens("bad-token", "bad-master-token", 1) 528 sr.FuncPostAuth = postAuthSuccessWithErrorCode 529 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 530 if err == nil { 531 t.Fatal("should have failed.") 532 } 533 newToken, newMasterToken, newSessionID := ta.GetTokens() 534 if newToken != "" || newMasterToken != "" || newSessionID != -1 { 535 t.Fatalf("failed auth should have reset tokens: %v %v %v", newToken, newMasterToken, newSessionID) 536 } 537 driverErr, ok = err.(*SnowflakeError) 538 if !ok || driverErr.Number != 98765 { 539 t.Fatalf("Snowflake error is expected. err: %v", driverErr) 540 } 541 ta.SetTokens("bad-token", "bad-master-token", 1) 542 sr.FuncPostAuth = postAuthSuccessWithInvalidErrorCode 543 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 544 if err == nil { 545 t.Fatal("should have failed.") 546 } 547 oldToken, oldMasterToken, oldSessionID := ta.GetTokens() 548 if oldToken != "" || oldMasterToken != "" || oldSessionID != -1 { 549 t.Fatalf("failed auth should have reset tokens: %v %v %v", oldToken, oldMasterToken, oldSessionID) 550 } 551 sr.FuncPostAuth = postAuthSuccess 552 var resp *authResponseMain 553 resp, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 554 if err != nil { 555 t.Fatalf("failed to auth. err: %v", err) 556 } 557 if resp.SessionInfo.DatabaseName != "dbn" { 558 t.Fatalf("failed to get response from auth") 559 } 560 newToken, newMasterToken, newSessionID = ta.GetTokens() 561 if newToken == oldToken { 562 t.Fatalf("new token was not set: %v", newToken) 563 } 564 if newMasterToken == oldMasterToken { 565 t.Fatalf("new master token was not set: %v", newMasterToken) 566 } 567 if newSessionID == oldSessionID { 568 t.Fatalf("new session id was not set: %v", newSessionID) 569 } 570 } 571 572 func TestUnitAuthenticateSaml(t *testing.T) { 573 var err error 574 sr := &snowflakeRestful{ 575 Protocol: "https", 576 Host: "abc.com", 577 Port: 443, 578 FuncPostAuthSAML: postAuthSAMLAuthSuccess, 579 FuncPostAuthOKTA: postAuthOKTASuccess, 580 FuncGetSSO: getSSOSuccess, 581 FuncPostAuth: postAuthCheckSAMLResponse, 582 TokenAccessor: getSimpleTokenAccessor(), 583 } 584 sc := getDefaultSnowflakeConn() 585 sc.cfg.Authenticator = AuthTypeOkta 586 sc.cfg.OktaURL = &url.URL{ 587 Scheme: "https", 588 Host: "abc.com", 589 } 590 sc.rest = sr 591 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 592 assertNilF(t, err, "failed to run.") 593 } 594 595 // Unit test for OAuth. 596 func TestUnitAuthenticateOAuth(t *testing.T) { 597 var err error 598 sr := &snowflakeRestful{ 599 FuncPostAuth: postAuthCheckOAuth, 600 TokenAccessor: getSimpleTokenAccessor(), 601 } 602 sc := getDefaultSnowflakeConn() 603 sc.cfg.Token = "oauthToken" 604 sc.cfg.Authenticator = AuthTypeOAuth 605 sc.rest = sr 606 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 607 if err != nil { 608 t.Fatalf("failed to run. err: %v", err) 609 } 610 } 611 612 func TestUnitAuthenticatePasscode(t *testing.T) { 613 var err error 614 sr := &snowflakeRestful{ 615 FuncPostAuth: postAuthCheckPasscode, 616 TokenAccessor: getSimpleTokenAccessor(), 617 } 618 sc := getDefaultSnowflakeConn() 619 sc.cfg.Passcode = "987654321" 620 sc.rest = sr 621 622 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 623 if err != nil { 624 t.Fatalf("failed to run. err: %v", err) 625 } 626 sr.FuncPostAuth = postAuthCheckPasscodeInPassword 627 sc.rest = sr 628 sc.cfg.PasscodeInPassword = true 629 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 630 if err != nil { 631 t.Fatalf("failed to run. err: %v", err) 632 } 633 } 634 635 // Test JWT function in the local environment against the validation function in go 636 func TestUnitAuthenticateJWT(t *testing.T) { 637 var err error 638 639 sr := &snowflakeRestful{ 640 FuncPostAuth: postAuthCheckJWTToken, 641 TokenAccessor: getSimpleTokenAccessor(), 642 } 643 sc := getDefaultSnowflakeConn() 644 sc.cfg.Authenticator = AuthTypeJwt 645 sc.cfg.JWTExpireTimeout = defaultJWTTimeout 646 sc.cfg.PrivateKey = testPrivKey 647 sc.rest = sr 648 649 // A valid JWT token should pass 650 if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil { 651 t.Fatalf("failed to run. err: %v", err) 652 } 653 654 // An invalid JWT token should not pass 655 invalidPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) 656 if err != nil { 657 t.Error(err) 658 } 659 sc.cfg.PrivateKey = invalidPrivateKey 660 if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err == nil { 661 t.Fatalf("invalid token passed") 662 } 663 } 664 665 func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) { 666 var err error 667 sr := &snowflakeRestful{ 668 FuncPostAuth: postAuthCheckUsernamePasswordMfa, 669 TokenAccessor: getSimpleTokenAccessor(), 670 } 671 sc := getDefaultSnowflakeConn() 672 sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA 673 sc.cfg.ClientRequestMfaToken = ConfigBoolTrue 674 sc.rest = sr 675 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 676 if err != nil { 677 t.Fatalf("failed to run. err: %v", err) 678 } 679 680 sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken 681 sc.cfg.MfaToken = "mockedMfaToken" 682 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 683 if err != nil { 684 t.Fatalf("failed to run. err: %v", err) 685 } 686 687 sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed 688 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 689 if err == nil { 690 t.Fatal("should have failed") 691 } 692 } 693 694 func TestUnitAuthenticateWithConfigMFA(t *testing.T) { 695 var err error 696 sr := &snowflakeRestful{ 697 FuncPostAuth: postAuthCheckUsernamePasswordMfa, 698 TokenAccessor: getSimpleTokenAccessor(), 699 } 700 sc := getDefaultSnowflakeConn() 701 sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA 702 sc.cfg.ClientRequestMfaToken = ConfigBoolTrue 703 sc.rest = sr 704 sc.ctx = context.Background() 705 err = authenticateWithConfig(sc) 706 if err != nil { 707 t.Fatalf("failed to run. err: %v", err) 708 } 709 } 710 711 func TestUnitAuthenticateWithConfigOkta(t *testing.T) { 712 var err error 713 sr := &snowflakeRestful{ 714 Protocol: "https", 715 Host: "abc.com", 716 Port: 443, 717 FuncPostAuthSAML: postAuthSAMLAuthSuccess, 718 FuncPostAuthOKTA: postAuthOKTASuccess, 719 FuncGetSSO: getSSOSuccess, 720 FuncPostAuth: postAuthCheckSAMLResponse, 721 TokenAccessor: getSimpleTokenAccessor(), 722 } 723 sc := getDefaultSnowflakeConn() 724 sc.cfg.Authenticator = AuthTypeOkta 725 sc.cfg.OktaURL = &url.URL{ 726 Scheme: "https", 727 Host: "abc.com", 728 } 729 sc.rest = sr 730 sc.ctx = context.Background() 731 732 err = authenticateWithConfig(sc) 733 assertNilE(t, err, "expected to have no error.") 734 735 sr.FuncPostAuthSAML = postAuthSAMLError 736 err = authenticateWithConfig(sc) 737 assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") 738 assertEqualE(t, err.Error(), "failed to get SAML response") 739 } 740 741 func TestUnitAuthenticateWithConfigExternalBrowser(t *testing.T) { 742 var err error 743 sr := &snowflakeRestful{ 744 FuncPostAuthSAML: postAuthSAMLError, 745 TokenAccessor: getSimpleTokenAccessor(), 746 } 747 sc := getDefaultSnowflakeConn() 748 sc.cfg.Authenticator = AuthTypeExternalBrowser 749 sc.cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout 750 sc.rest = sr 751 sc.ctx = context.Background() 752 err = authenticateWithConfig(sc) 753 assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") 754 assertEqualE(t, err.Error(), "failed to get SAML response") 755 } 756 757 func TestUnitAuthenticateExternalBrowser(t *testing.T) { 758 var err error 759 sr := &snowflakeRestful{ 760 FuncPostAuth: postAuthCheckExternalBrowser, 761 TokenAccessor: getSimpleTokenAccessor(), 762 } 763 sc := getDefaultSnowflakeConn() 764 sc.cfg.Authenticator = AuthTypeExternalBrowser 765 sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue 766 sc.rest = sr 767 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 768 if err != nil { 769 t.Fatalf("failed to run. err: %v", err) 770 } 771 772 sr.FuncPostAuth = postAuthCheckExternalBrowserToken 773 sc.cfg.IDToken = "mockedIDToken" 774 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 775 if err != nil { 776 t.Fatalf("failed to run. err: %v", err) 777 } 778 779 sr.FuncPostAuth = postAuthCheckExternalBrowserFailed 780 _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) 781 if err == nil { 782 t.Fatal("should have failed") 783 } 784 } 785 786 // To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled 787 // Set any other snowflake_test variables needed for database, schema, role for this user 788 func TestUsernamePasswordMfaCaching(t *testing.T) { 789 t.Skip("manual test for MFA token caching") 790 791 config, err := ParseDSN(dsn) 792 if err != nil { 793 t.Fatal("Failed to parse dsn") 794 } 795 // connect with MFA authentication 796 user := os.Getenv("SNOWFLAKE_TEST_MFA_USER") 797 password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD") 798 config.User = user 799 config.Password = password 800 config.Authenticator = AuthTypeUsernamePasswordMFA 801 if runtime.GOOS == "linux" { 802 config.ClientRequestMfaToken = ConfigBoolTrue 803 } 804 connector := NewConnector(SnowflakeDriver{}, *config) 805 db := sql.OpenDB(connector) 806 for i := 0; i < 3; i++ { 807 // should only be prompted to authenticate first time around. 808 _, err := db.Query("select current_user()") 809 if err != nil { 810 t.Fatal(err) 811 } 812 } 813 } 814 815 // To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled 816 // Set any other snowflake_test variables needed for database, schema, role for this user 817 func TestDisableUsernamePasswordMfaCaching(t *testing.T) { 818 t.Skip("manual test for disabling MFA token caching") 819 820 config, err := ParseDSN(dsn) 821 if err != nil { 822 t.Fatal("Failed to parse dsn") 823 } 824 // connect with MFA authentication 825 user := os.Getenv("SNOWFLAKE_TEST_MFA_USER") 826 password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD") 827 config.User = user 828 config.Password = password 829 config.Authenticator = AuthTypeUsernamePasswordMFA 830 // disable MFA token caching 831 config.ClientRequestMfaToken = ConfigBoolFalse 832 connector := NewConnector(SnowflakeDriver{}, *config) 833 db := sql.OpenDB(connector) 834 for i := 0; i < 3; i++ { 835 // should be prompted to authenticate 3 times. 836 _, err := db.Query("select current_user()") 837 if err != nil { 838 t.Fatal(err) 839 } 840 } 841 } 842 843 // To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user 844 // Set any other snowflake_test variables needed for database, schema, role for this user 845 func TestExternalBrowserCaching(t *testing.T) { 846 t.Skip("manual test for external browser token caching") 847 848 config, err := ParseDSN(dsn) 849 if err != nil { 850 t.Fatal("Failed to parse dsn") 851 } 852 // connect with external browser authentication 853 user := os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER") 854 config.User = user 855 config.Authenticator = AuthTypeExternalBrowser 856 if runtime.GOOS == "linux" { 857 config.ClientStoreTemporaryCredential = ConfigBoolTrue 858 } 859 connector := NewConnector(SnowflakeDriver{}, *config) 860 db := sql.OpenDB(connector) 861 for i := 0; i < 3; i++ { 862 // should only be prompted to authenticate first time around. 863 _, err := db.Query("select current_user()") 864 if err != nil { 865 t.Fatal(err) 866 } 867 } 868 } 869 870 // To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user 871 // Set any other snowflake_test variables needed for database, schema, role for this user 872 func TestDisableExternalBrowserCaching(t *testing.T) { 873 t.Skip("manual test for disabling external browser token caching") 874 875 config, err := ParseDSN(dsn) 876 if err != nil { 877 t.Fatal("Failed to parse dsn") 878 } 879 // connect with external browser authentication 880 user := os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER") 881 config.User = user 882 config.Authenticator = AuthTypeExternalBrowser 883 // disable external browser token caching 884 config.ClientStoreTemporaryCredential = ConfigBoolFalse 885 connector := NewConnector(SnowflakeDriver{}, *config) 886 db := sql.OpenDB(connector) 887 for i := 0; i < 3; i++ { 888 // should be prompted to authenticate 3 times. 889 _, err := db.Query("select current_user()") 890 if err != nil { 891 t.Fatal(err) 892 } 893 } 894 } 895 896 func TestOktaRetryWithNewToken(t *testing.T) { 897 expectedMasterToken := "m" 898 expectedToken := "t" 899 expectedMfaToken := "mockedMfaToken" 900 expectedDatabaseName := "dbn" 901 902 sr := &snowflakeRestful{ 903 Protocol: "https", 904 Host: "abc.com", 905 Port: 443, 906 FuncPostAuthSAML: postAuthSAMLAuthSuccess, 907 FuncPostAuthOKTA: postAuthOKTASuccess, 908 FuncGetSSO: getSSOSuccess, 909 FuncPostAuth: postAuthOktaWithNewToken, 910 TokenAccessor: getSimpleTokenAccessor(), 911 } 912 sc := getDefaultSnowflakeConn() 913 sc.cfg.Authenticator = AuthTypeOkta 914 sc.cfg.OktaURL = &url.URL{ 915 Scheme: "https", 916 Host: "abc.com", 917 } 918 sc.rest = sr 919 sc.ctx = context.Background() 920 921 authResponse, err := authenticate(context.Background(), sc, []byte{0x12, 0x34}, []byte{0x56, 0x78}) 922 assertNilF(t, err, "should not have failed to run authenticate()") 923 assertEqualF(t, authResponse.MasterToken, expectedMasterToken) 924 assertEqualF(t, authResponse.Token, expectedToken) 925 assertEqualF(t, authResponse.MfaToken, expectedMfaToken) 926 assertEqualF(t, authResponse.SessionInfo.DatabaseName, expectedDatabaseName) 927 }