github.com/snowflakedb/gosnowflake@v1.9.0/auth_test.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"crypto/rand"
     8  	"crypto/rsa"
     9  	"database/sql"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"net/http"
    14  	"net/url"
    15  	"os"
    16  	"runtime"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/form3tech-oss/jwt-go"
    21  )
    22  
    23  func TestUnitPostAuth(t *testing.T) {
    24  	sr := &snowflakeRestful{
    25  		TokenAccessor: getSimpleTokenAccessor(),
    26  		FuncAuthPost:  postAuthTestAfterRenew,
    27  	}
    28  	var err error
    29  	bodyCreator := func() ([]byte, error) {
    30  		return []byte{0x12, 0x34}, nil
    31  	}
    32  	_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
    33  	if err != nil {
    34  		t.Fatalf("err: %v", err)
    35  	}
    36  	sr.FuncAuthPost = postAuthTestError
    37  	_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
    38  	if err == nil {
    39  		t.Fatal("should have failed to auth for unknown reason")
    40  	}
    41  	sr.FuncAuthPost = postAuthTestAppBadGatewayError
    42  	_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
    43  	if err == nil {
    44  		t.Fatal("should have failed to auth for unknown reason")
    45  	}
    46  	sr.FuncAuthPost = postAuthTestAppForbiddenError
    47  	_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
    48  	if err == nil {
    49  		t.Fatal("should have failed to auth for unknown reason")
    50  	}
    51  	sr.FuncAuthPost = postAuthTestAppUnexpectedError
    52  	_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
    53  	if err == nil {
    54  		t.Fatal("should have failed to auth for unknown reason")
    55  	}
    56  }
    57  
    58  func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
    59  	return nil, &SnowflakeError{
    60  		Number: ErrCodeServiceUnavailable,
    61  	}
    62  }
    63  
    64  func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
    65  	return nil, &SnowflakeError{
    66  		Number: ErrCodeFailedToConnect,
    67  	}
    68  }
    69  
    70  func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
    71  	return nil, &SnowflakeError{
    72  		Number: ErrFailedToAuth,
    73  	}
    74  }
    75  
    76  func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
    77  	return &authResponse{
    78  		Success: false,
    79  		Code:    "98765",
    80  		Message: "wrong!",
    81  	}, nil
    82  }
    83  
    84  func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
    85  	return &authResponse{
    86  		Success: false,
    87  		Code:    "abcdef",
    88  		Message: "wrong!",
    89  	}, nil
    90  }
    91  
    92  func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
    93  	return &authResponse{
    94  		Success: true,
    95  		Data: authResponseMain{
    96  			Token:       "t",
    97  			MasterToken: "m",
    98  			SessionInfo: authResponseSessionInfo{
    99  				DatabaseName: "dbn",
   100  			},
   101  		},
   102  	}, nil
   103  }
   104  
   105  func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   106  	var ar authRequest
   107  	jsonBody, err := bodyCreator()
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	if err = json.Unmarshal(jsonBody, &ar); err != nil {
   112  		return nil, err
   113  	}
   114  	if ar.Data.RawSAMLResponse == "" {
   115  		return nil, errors.New("SAML response is empty")
   116  	}
   117  	return &authResponse{
   118  		Success: true,
   119  		Data: authResponseMain{
   120  			Token:       "t",
   121  			MasterToken: "m",
   122  			SessionInfo: authResponseSessionInfo{
   123  				DatabaseName: "dbn",
   124  			},
   125  		},
   126  	}, nil
   127  }
   128  
   129  // Checks that the request body generated when authenticating with OAuth
   130  // contains all the necessary values.
   131  func postAuthCheckOAuth(
   132  	_ context.Context,
   133  	_ *snowflakeRestful,
   134  	_ *http.Client,
   135  	_ *url.Values, _ map[string]string,
   136  	bodyCreator bodyCreatorType,
   137  	_ time.Duration,
   138  ) (*authResponse, error) {
   139  	var ar authRequest
   140  	jsonBody, _ := bodyCreator()
   141  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   142  		return nil, err
   143  	}
   144  	if ar.Data.Authenticator != AuthTypeOAuth.String() {
   145  		return nil, errors.New("Authenticator is not OAUTH")
   146  	}
   147  	if ar.Data.Token == "" {
   148  		return nil, errors.New("Token is empty")
   149  	}
   150  	if ar.Data.LoginName == "" {
   151  		return nil, errors.New("Login name is empty")
   152  	}
   153  	return &authResponse{
   154  		Success: true,
   155  		Data: authResponseMain{
   156  			Token:       "t",
   157  			MasterToken: "m",
   158  			SessionInfo: authResponseSessionInfo{
   159  				DatabaseName: "dbn",
   160  			},
   161  		},
   162  	}, nil
   163  }
   164  
   165  func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   166  	var ar authRequest
   167  	jsonBody, _ := bodyCreator()
   168  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   169  		return nil, err
   170  	}
   171  	if ar.Data.Passcode != "987654321" || ar.Data.ExtAuthnDuoMethod != "passcode" {
   172  		return nil, fmt.Errorf("passcode didn't match. expected: 987654321, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod)
   173  	}
   174  	return &authResponse{
   175  		Success: true,
   176  		Data: authResponseMain{
   177  			Token:       "t",
   178  			MasterToken: "m",
   179  			SessionInfo: authResponseSessionInfo{
   180  				DatabaseName: "dbn",
   181  			},
   182  		},
   183  	}, nil
   184  }
   185  
   186  func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   187  	var ar authRequest
   188  	jsonBody, _ := bodyCreator()
   189  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   190  		return nil, err
   191  	}
   192  	if ar.Data.Passcode != "" || ar.Data.ExtAuthnDuoMethod != "passcode" {
   193  		return nil, fmt.Errorf("passcode must be empty, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod)
   194  	}
   195  	return &authResponse{
   196  		Success: true,
   197  		Data: authResponseMain{
   198  			Token:       "t",
   199  			MasterToken: "m",
   200  			SessionInfo: authResponseSessionInfo{
   201  				DatabaseName: "dbn",
   202  			},
   203  		},
   204  	}, nil
   205  }
   206  
   207  // JWT token validate callback function to check the JWT token
   208  // It uses the public key paired with the testPrivKey
   209  func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   210  	var ar authRequest
   211  	jsonBody, _ := bodyCreator()
   212  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   213  		return nil, err
   214  	}
   215  	if ar.Data.Authenticator != AuthTypeJwt.String() {
   216  		return nil, errors.New("Authenticator is not JWT")
   217  	}
   218  
   219  	tokenString := ar.Data.Token
   220  
   221  	// Validate token
   222  	_, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
   223  		// Don't forget to validate the alg is what you expect:
   224  		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
   225  			return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
   226  		}
   227  
   228  		return testPrivKey.Public(), nil
   229  	})
   230  	if err != nil {
   231  		return nil, err
   232  	}
   233  
   234  	return &authResponse{
   235  		Success: true,
   236  		Data: authResponseMain{
   237  			Token:       "t",
   238  			MasterToken: "m",
   239  			SessionInfo: authResponseSessionInfo{
   240  				DatabaseName: "dbn",
   241  			},
   242  		},
   243  	}, nil
   244  }
   245  
   246  func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   247  	var ar authRequest
   248  	jsonBody, _ := bodyCreator()
   249  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   250  		return nil, err
   251  	}
   252  
   253  	if ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"] != true {
   254  		return nil, fmt.Errorf("expected client_request_mfa_token to be true but was %v", ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"])
   255  	}
   256  	return &authResponse{
   257  		Success: true,
   258  		Data: authResponseMain{
   259  			Token:       "t",
   260  			MasterToken: "m",
   261  			MfaToken:    "mockedMfaToken",
   262  			SessionInfo: authResponseSessionInfo{
   263  				DatabaseName: "dbn",
   264  			},
   265  		},
   266  	}, nil
   267  }
   268  
   269  func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   270  	var ar authRequest
   271  	jsonBody, _ := bodyCreator()
   272  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   273  		return nil, err
   274  	}
   275  
   276  	if ar.Data.Token != "mockedMfaToken" {
   277  		return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
   278  	}
   279  	return &authResponse{
   280  		Success: true,
   281  		Data: authResponseMain{
   282  			Token:       "t",
   283  			MasterToken: "m",
   284  			MfaToken:    "mockedMfaToken",
   285  			SessionInfo: authResponseSessionInfo{
   286  				DatabaseName: "dbn",
   287  			},
   288  		},
   289  	}, nil
   290  }
   291  
   292  func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   293  	var ar authRequest
   294  	jsonBody, _ := bodyCreator()
   295  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   296  		return nil, err
   297  	}
   298  
   299  	if ar.Data.Token != "mockedMfaToken" {
   300  		return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
   301  	}
   302  	return &authResponse{
   303  		Success: false,
   304  		Data:    authResponseMain{},
   305  		Message: "auth failed",
   306  		Code:    "260008",
   307  	}, nil
   308  }
   309  
   310  func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   311  	var ar authRequest
   312  	jsonBody, _ := bodyCreator()
   313  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   314  		return nil, err
   315  	}
   316  
   317  	if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
   318  		return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
   319  	}
   320  	return &authResponse{
   321  		Success: true,
   322  		Data: authResponseMain{
   323  			Token:       "t",
   324  			MasterToken: "m",
   325  			IDToken:     "mockedIDToken",
   326  			SessionInfo: authResponseSessionInfo{
   327  				DatabaseName: "dbn",
   328  			},
   329  		},
   330  	}, nil
   331  }
   332  
   333  func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   334  	var ar authRequest
   335  	jsonBody, _ := bodyCreator()
   336  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   337  		return nil, err
   338  	}
   339  
   340  	if ar.Data.Token != "mockedIDToken" {
   341  		return nil, fmt.Errorf("unexpected mfatoken: %v", ar.Data.Token)
   342  	}
   343  	return &authResponse{
   344  		Success: true,
   345  		Data: authResponseMain{
   346  			Token:       "t",
   347  			MasterToken: "m",
   348  			IDToken:     "mockedIDToken",
   349  			SessionInfo: authResponseSessionInfo{
   350  				DatabaseName: "dbn",
   351  			},
   352  		},
   353  	}, nil
   354  }
   355  
   356  func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   357  	var ar authRequest
   358  	jsonBody, _ := bodyCreator()
   359  	if err := json.Unmarshal(jsonBody, &ar); err != nil {
   360  		return nil, err
   361  	}
   362  
   363  	if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
   364  		return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
   365  	}
   366  	return &authResponse{
   367  		Success: false,
   368  		Data:    authResponseMain{},
   369  		Message: "auth failed",
   370  		Code:    "260008",
   371  	}, nil
   372  }
   373  
   374  func postAuthOktaWithNewToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
   375  	var ar authRequest
   376  
   377  	cfg := &Config{
   378  		Authenticator: AuthTypeOkta,
   379  	}
   380  
   381  	// Retry 3 times and success
   382  	client := &fakeHTTPClient{
   383  		cnt:        3,
   384  		success:    true,
   385  		statusCode: 429,
   386  	}
   387  
   388  	urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_guid=testguid")
   389  	if err != nil {
   390  		return &authResponse{}, err
   391  	}
   392  
   393  	body := func() ([]byte, error) {
   394  		jsonBody, _ := bodyCreator()
   395  		if err := json.Unmarshal(jsonBody, &ar); err != nil {
   396  			return nil, err
   397  		}
   398  		return jsonBody, err
   399  	}
   400  
   401  	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, cfg).doPost().setBodyCreator(body).execute()
   402  	if err != nil {
   403  		return &authResponse{}, err
   404  	}
   405  
   406  	return &authResponse{
   407  		Success: true,
   408  		Data: authResponseMain{
   409  			Token:       "t",
   410  			MasterToken: "m",
   411  			MfaToken:    "mockedMfaToken",
   412  			SessionInfo: authResponseSessionInfo{
   413  				DatabaseName: "dbn",
   414  			},
   415  		},
   416  	}, nil
   417  }
   418  
   419  func getDefaultSnowflakeConn() *snowflakeConn {
   420  	sc := &snowflakeConn{
   421  		rest: &snowflakeRestful{
   422  			TokenAccessor: getSimpleTokenAccessor(),
   423  		},
   424  		cfg: &Config{
   425  			Account:            "a",
   426  			User:               "u",
   427  			Password:           "p",
   428  			Database:           "d",
   429  			Schema:             "s",
   430  			Warehouse:          "w",
   431  			Role:               "r",
   432  			Region:             "",
   433  			Params:             make(map[string]*string),
   434  			PasscodeInPassword: false,
   435  			Passcode:           "",
   436  			Application:        "testapp",
   437  		},
   438  		telemetry: &snowflakeTelemetry{enabled: false},
   439  	}
   440  	return sc
   441  }
   442  
   443  func TestUnitAuthenticateWithTokenAccessor(t *testing.T) {
   444  	expectedSessionID := int64(123)
   445  	expectedMasterToken := "master_token"
   446  	expectedToken := "auth_token"
   447  
   448  	ta := getSimpleTokenAccessor()
   449  	ta.SetTokens(expectedToken, expectedMasterToken, expectedSessionID)
   450  	sc := getDefaultSnowflakeConn()
   451  	sc.cfg.Authenticator = AuthTypeTokenAccessor
   452  	sc.cfg.TokenAccessor = ta
   453  	sr := &snowflakeRestful{
   454  		FuncPostAuth:  postAuthFailServiceIssue,
   455  		TokenAccessor: ta,
   456  	}
   457  	sc.rest = sr
   458  
   459  	// FuncPostAuth is set to fail, but AuthTypeTokenAccessor should not even make a call to FuncPostAuth
   460  	resp, err := authenticate(context.Background(), sc, []byte{}, []byte{})
   461  	if err != nil {
   462  		t.Fatalf("should not have failed, err %v", err)
   463  	}
   464  
   465  	if resp.SessionID != expectedSessionID {
   466  		t.Fatalf("Expected session id %v but got %v", expectedSessionID, resp.SessionID)
   467  	}
   468  	if resp.Token != expectedToken {
   469  		t.Fatalf("Expected token %v but got %v", expectedToken, resp.Token)
   470  	}
   471  	if resp.MasterToken != expectedMasterToken {
   472  		t.Fatalf("Expected master token %v but got %v", expectedMasterToken, resp.MasterToken)
   473  	}
   474  	if resp.SessionInfo.DatabaseName != sc.cfg.Database {
   475  		t.Fatalf("Expected database %v but got %v", sc.cfg.Database, resp.SessionInfo.DatabaseName)
   476  	}
   477  	if resp.SessionInfo.WarehouseName != sc.cfg.Warehouse {
   478  		t.Fatalf("Expected warehouse %v but got %v", sc.cfg.Warehouse, resp.SessionInfo.WarehouseName)
   479  	}
   480  	if resp.SessionInfo.RoleName != sc.cfg.Role {
   481  		t.Fatalf("Expected role %v but got %v", sc.cfg.Role, resp.SessionInfo.RoleName)
   482  	}
   483  	if resp.SessionInfo.SchemaName != sc.cfg.Schema {
   484  		t.Fatalf("Expected schema %v but got %v", sc.cfg.Schema, resp.SessionInfo.SchemaName)
   485  	}
   486  }
   487  
   488  func TestUnitAuthenticate(t *testing.T) {
   489  	var err error
   490  	var driverErr *SnowflakeError
   491  	var ok bool
   492  
   493  	ta := getSimpleTokenAccessor()
   494  	sc := getDefaultSnowflakeConn()
   495  	sr := &snowflakeRestful{
   496  		FuncPostAuth:  postAuthFailServiceIssue,
   497  		TokenAccessor: ta,
   498  	}
   499  	sc.rest = sr
   500  
   501  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   502  	if err == nil {
   503  		t.Fatal("should have failed.")
   504  	}
   505  	driverErr, ok = err.(*SnowflakeError)
   506  	if !ok || driverErr.Number != ErrCodeServiceUnavailable {
   507  		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
   508  	}
   509  	sr.FuncPostAuth = postAuthFailWrongAccount
   510  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   511  	if err == nil {
   512  		t.Fatal("should have failed.")
   513  	}
   514  	driverErr, ok = err.(*SnowflakeError)
   515  	if !ok || driverErr.Number != ErrCodeFailedToConnect {
   516  		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
   517  	}
   518  	sr.FuncPostAuth = postAuthFailUnknown
   519  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   520  	if err == nil {
   521  		t.Fatal("should have failed.")
   522  	}
   523  	driverErr, ok = err.(*SnowflakeError)
   524  	if !ok || driverErr.Number != ErrFailedToAuth {
   525  		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
   526  	}
   527  	ta.SetTokens("bad-token", "bad-master-token", 1)
   528  	sr.FuncPostAuth = postAuthSuccessWithErrorCode
   529  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   530  	if err == nil {
   531  		t.Fatal("should have failed.")
   532  	}
   533  	newToken, newMasterToken, newSessionID := ta.GetTokens()
   534  	if newToken != "" || newMasterToken != "" || newSessionID != -1 {
   535  		t.Fatalf("failed auth should have reset tokens: %v %v %v", newToken, newMasterToken, newSessionID)
   536  	}
   537  	driverErr, ok = err.(*SnowflakeError)
   538  	if !ok || driverErr.Number != 98765 {
   539  		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
   540  	}
   541  	ta.SetTokens("bad-token", "bad-master-token", 1)
   542  	sr.FuncPostAuth = postAuthSuccessWithInvalidErrorCode
   543  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   544  	if err == nil {
   545  		t.Fatal("should have failed.")
   546  	}
   547  	oldToken, oldMasterToken, oldSessionID := ta.GetTokens()
   548  	if oldToken != "" || oldMasterToken != "" || oldSessionID != -1 {
   549  		t.Fatalf("failed auth should have reset tokens: %v %v %v", oldToken, oldMasterToken, oldSessionID)
   550  	}
   551  	sr.FuncPostAuth = postAuthSuccess
   552  	var resp *authResponseMain
   553  	resp, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   554  	if err != nil {
   555  		t.Fatalf("failed to auth. err: %v", err)
   556  	}
   557  	if resp.SessionInfo.DatabaseName != "dbn" {
   558  		t.Fatalf("failed to get response from auth")
   559  	}
   560  	newToken, newMasterToken, newSessionID = ta.GetTokens()
   561  	if newToken == oldToken {
   562  		t.Fatalf("new token was not set: %v", newToken)
   563  	}
   564  	if newMasterToken == oldMasterToken {
   565  		t.Fatalf("new master token was not set: %v", newMasterToken)
   566  	}
   567  	if newSessionID == oldSessionID {
   568  		t.Fatalf("new session id was not set: %v", newSessionID)
   569  	}
   570  }
   571  
   572  func TestUnitAuthenticateSaml(t *testing.T) {
   573  	var err error
   574  	sr := &snowflakeRestful{
   575  		Protocol:         "https",
   576  		Host:             "abc.com",
   577  		Port:             443,
   578  		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
   579  		FuncPostAuthOKTA: postAuthOKTASuccess,
   580  		FuncGetSSO:       getSSOSuccess,
   581  		FuncPostAuth:     postAuthCheckSAMLResponse,
   582  		TokenAccessor:    getSimpleTokenAccessor(),
   583  	}
   584  	sc := getDefaultSnowflakeConn()
   585  	sc.cfg.Authenticator = AuthTypeOkta
   586  	sc.cfg.OktaURL = &url.URL{
   587  		Scheme: "https",
   588  		Host:   "abc.com",
   589  	}
   590  	sc.rest = sr
   591  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   592  	assertNilF(t, err, "failed to run.")
   593  }
   594  
   595  // Unit test for OAuth.
   596  func TestUnitAuthenticateOAuth(t *testing.T) {
   597  	var err error
   598  	sr := &snowflakeRestful{
   599  		FuncPostAuth:  postAuthCheckOAuth,
   600  		TokenAccessor: getSimpleTokenAccessor(),
   601  	}
   602  	sc := getDefaultSnowflakeConn()
   603  	sc.cfg.Token = "oauthToken"
   604  	sc.cfg.Authenticator = AuthTypeOAuth
   605  	sc.rest = sr
   606  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   607  	if err != nil {
   608  		t.Fatalf("failed to run. err: %v", err)
   609  	}
   610  }
   611  
   612  func TestUnitAuthenticatePasscode(t *testing.T) {
   613  	var err error
   614  	sr := &snowflakeRestful{
   615  		FuncPostAuth:  postAuthCheckPasscode,
   616  		TokenAccessor: getSimpleTokenAccessor(),
   617  	}
   618  	sc := getDefaultSnowflakeConn()
   619  	sc.cfg.Passcode = "987654321"
   620  	sc.rest = sr
   621  
   622  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   623  	if err != nil {
   624  		t.Fatalf("failed to run. err: %v", err)
   625  	}
   626  	sr.FuncPostAuth = postAuthCheckPasscodeInPassword
   627  	sc.rest = sr
   628  	sc.cfg.PasscodeInPassword = true
   629  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   630  	if err != nil {
   631  		t.Fatalf("failed to run. err: %v", err)
   632  	}
   633  }
   634  
   635  // Test JWT function in the local environment against the validation function in go
   636  func TestUnitAuthenticateJWT(t *testing.T) {
   637  	var err error
   638  
   639  	sr := &snowflakeRestful{
   640  		FuncPostAuth:  postAuthCheckJWTToken,
   641  		TokenAccessor: getSimpleTokenAccessor(),
   642  	}
   643  	sc := getDefaultSnowflakeConn()
   644  	sc.cfg.Authenticator = AuthTypeJwt
   645  	sc.cfg.JWTExpireTimeout = defaultJWTTimeout
   646  	sc.cfg.PrivateKey = testPrivKey
   647  	sc.rest = sr
   648  
   649  	// A valid JWT token should pass
   650  	if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil {
   651  		t.Fatalf("failed to run. err: %v", err)
   652  	}
   653  
   654  	// An invalid JWT token should not pass
   655  	invalidPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
   656  	if err != nil {
   657  		t.Error(err)
   658  	}
   659  	sc.cfg.PrivateKey = invalidPrivateKey
   660  	if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err == nil {
   661  		t.Fatalf("invalid token passed")
   662  	}
   663  }
   664  
   665  func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) {
   666  	var err error
   667  	sr := &snowflakeRestful{
   668  		FuncPostAuth:  postAuthCheckUsernamePasswordMfa,
   669  		TokenAccessor: getSimpleTokenAccessor(),
   670  	}
   671  	sc := getDefaultSnowflakeConn()
   672  	sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
   673  	sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
   674  	sc.rest = sr
   675  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   676  	if err != nil {
   677  		t.Fatalf("failed to run. err: %v", err)
   678  	}
   679  
   680  	sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken
   681  	sc.cfg.MfaToken = "mockedMfaToken"
   682  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   683  	if err != nil {
   684  		t.Fatalf("failed to run. err: %v", err)
   685  	}
   686  
   687  	sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed
   688  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   689  	if err == nil {
   690  		t.Fatal("should have failed")
   691  	}
   692  }
   693  
   694  func TestUnitAuthenticateWithConfigMFA(t *testing.T) {
   695  	var err error
   696  	sr := &snowflakeRestful{
   697  		FuncPostAuth:  postAuthCheckUsernamePasswordMfa,
   698  		TokenAccessor: getSimpleTokenAccessor(),
   699  	}
   700  	sc := getDefaultSnowflakeConn()
   701  	sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
   702  	sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
   703  	sc.rest = sr
   704  	sc.ctx = context.Background()
   705  	err = authenticateWithConfig(sc)
   706  	if err != nil {
   707  		t.Fatalf("failed to run. err: %v", err)
   708  	}
   709  }
   710  
   711  func TestUnitAuthenticateWithConfigOkta(t *testing.T) {
   712  	var err error
   713  	sr := &snowflakeRestful{
   714  		Protocol:         "https",
   715  		Host:             "abc.com",
   716  		Port:             443,
   717  		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
   718  		FuncPostAuthOKTA: postAuthOKTASuccess,
   719  		FuncGetSSO:       getSSOSuccess,
   720  		FuncPostAuth:     postAuthCheckSAMLResponse,
   721  		TokenAccessor:    getSimpleTokenAccessor(),
   722  	}
   723  	sc := getDefaultSnowflakeConn()
   724  	sc.cfg.Authenticator = AuthTypeOkta
   725  	sc.cfg.OktaURL = &url.URL{
   726  		Scheme: "https",
   727  		Host:   "abc.com",
   728  	}
   729  	sc.rest = sr
   730  	sc.ctx = context.Background()
   731  
   732  	err = authenticateWithConfig(sc)
   733  	assertNilE(t, err, "expected to have no error.")
   734  
   735  	sr.FuncPostAuthSAML = postAuthSAMLError
   736  	err = authenticateWithConfig(sc)
   737  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   738  	assertEqualE(t, err.Error(), "failed to get SAML response")
   739  }
   740  
   741  func TestUnitAuthenticateWithConfigExternalBrowser(t *testing.T) {
   742  	var err error
   743  	sr := &snowflakeRestful{
   744  		FuncPostAuthSAML: postAuthSAMLError,
   745  		TokenAccessor:    getSimpleTokenAccessor(),
   746  	}
   747  	sc := getDefaultSnowflakeConn()
   748  	sc.cfg.Authenticator = AuthTypeExternalBrowser
   749  	sc.cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout
   750  	sc.rest = sr
   751  	sc.ctx = context.Background()
   752  	err = authenticateWithConfig(sc)
   753  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   754  	assertEqualE(t, err.Error(), "failed to get SAML response")
   755  }
   756  
   757  func TestUnitAuthenticateExternalBrowser(t *testing.T) {
   758  	var err error
   759  	sr := &snowflakeRestful{
   760  		FuncPostAuth:  postAuthCheckExternalBrowser,
   761  		TokenAccessor: getSimpleTokenAccessor(),
   762  	}
   763  	sc := getDefaultSnowflakeConn()
   764  	sc.cfg.Authenticator = AuthTypeExternalBrowser
   765  	sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
   766  	sc.rest = sr
   767  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   768  	if err != nil {
   769  		t.Fatalf("failed to run. err: %v", err)
   770  	}
   771  
   772  	sr.FuncPostAuth = postAuthCheckExternalBrowserToken
   773  	sc.cfg.IDToken = "mockedIDToken"
   774  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   775  	if err != nil {
   776  		t.Fatalf("failed to run. err: %v", err)
   777  	}
   778  
   779  	sr.FuncPostAuth = postAuthCheckExternalBrowserFailed
   780  	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
   781  	if err == nil {
   782  		t.Fatal("should have failed")
   783  	}
   784  }
   785  
   786  // To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
   787  // Set any other snowflake_test variables needed for database, schema, role for this user
   788  func TestUsernamePasswordMfaCaching(t *testing.T) {
   789  	t.Skip("manual test for MFA token caching")
   790  
   791  	config, err := ParseDSN(dsn)
   792  	if err != nil {
   793  		t.Fatal("Failed to parse dsn")
   794  	}
   795  	// connect with MFA authentication
   796  	user := os.Getenv("SNOWFLAKE_TEST_MFA_USER")
   797  	password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
   798  	config.User = user
   799  	config.Password = password
   800  	config.Authenticator = AuthTypeUsernamePasswordMFA
   801  	if runtime.GOOS == "linux" {
   802  		config.ClientRequestMfaToken = ConfigBoolTrue
   803  	}
   804  	connector := NewConnector(SnowflakeDriver{}, *config)
   805  	db := sql.OpenDB(connector)
   806  	for i := 0; i < 3; i++ {
   807  		// should only be prompted to authenticate first time around.
   808  		_, err := db.Query("select current_user()")
   809  		if err != nil {
   810  			t.Fatal(err)
   811  		}
   812  	}
   813  }
   814  
   815  // To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
   816  // Set any other snowflake_test variables needed for database, schema, role for this user
   817  func TestDisableUsernamePasswordMfaCaching(t *testing.T) {
   818  	t.Skip("manual test for disabling MFA token caching")
   819  
   820  	config, err := ParseDSN(dsn)
   821  	if err != nil {
   822  		t.Fatal("Failed to parse dsn")
   823  	}
   824  	// connect with MFA authentication
   825  	user := os.Getenv("SNOWFLAKE_TEST_MFA_USER")
   826  	password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
   827  	config.User = user
   828  	config.Password = password
   829  	config.Authenticator = AuthTypeUsernamePasswordMFA
   830  	// disable MFA token caching
   831  	config.ClientRequestMfaToken = ConfigBoolFalse
   832  	connector := NewConnector(SnowflakeDriver{}, *config)
   833  	db := sql.OpenDB(connector)
   834  	for i := 0; i < 3; i++ {
   835  		// should be prompted to authenticate 3 times.
   836  		_, err := db.Query("select current_user()")
   837  		if err != nil {
   838  			t.Fatal(err)
   839  		}
   840  	}
   841  }
   842  
   843  // To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user
   844  // Set any other snowflake_test variables needed for database, schema, role for this user
   845  func TestExternalBrowserCaching(t *testing.T) {
   846  	t.Skip("manual test for external browser token caching")
   847  
   848  	config, err := ParseDSN(dsn)
   849  	if err != nil {
   850  		t.Fatal("Failed to parse dsn")
   851  	}
   852  	// connect with external browser authentication
   853  	user := os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER")
   854  	config.User = user
   855  	config.Authenticator = AuthTypeExternalBrowser
   856  	if runtime.GOOS == "linux" {
   857  		config.ClientStoreTemporaryCredential = ConfigBoolTrue
   858  	}
   859  	connector := NewConnector(SnowflakeDriver{}, *config)
   860  	db := sql.OpenDB(connector)
   861  	for i := 0; i < 3; i++ {
   862  		// should only be prompted to authenticate first time around.
   863  		_, err := db.Query("select current_user()")
   864  		if err != nil {
   865  			t.Fatal(err)
   866  		}
   867  	}
   868  }
   869  
   870  // To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user
   871  // Set any other snowflake_test variables needed for database, schema, role for this user
   872  func TestDisableExternalBrowserCaching(t *testing.T) {
   873  	t.Skip("manual test for disabling external browser token caching")
   874  
   875  	config, err := ParseDSN(dsn)
   876  	if err != nil {
   877  		t.Fatal("Failed to parse dsn")
   878  	}
   879  	// connect with external browser authentication
   880  	user := os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER")
   881  	config.User = user
   882  	config.Authenticator = AuthTypeExternalBrowser
   883  	// disable external browser token caching
   884  	config.ClientStoreTemporaryCredential = ConfigBoolFalse
   885  	connector := NewConnector(SnowflakeDriver{}, *config)
   886  	db := sql.OpenDB(connector)
   887  	for i := 0; i < 3; i++ {
   888  		// should be prompted to authenticate 3 times.
   889  		_, err := db.Query("select current_user()")
   890  		if err != nil {
   891  			t.Fatal(err)
   892  		}
   893  	}
   894  }
   895  
   896  func TestOktaRetryWithNewToken(t *testing.T) {
   897  	expectedMasterToken := "m"
   898  	expectedToken := "t"
   899  	expectedMfaToken := "mockedMfaToken"
   900  	expectedDatabaseName := "dbn"
   901  
   902  	sr := &snowflakeRestful{
   903  		Protocol:         "https",
   904  		Host:             "abc.com",
   905  		Port:             443,
   906  		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
   907  		FuncPostAuthOKTA: postAuthOKTASuccess,
   908  		FuncGetSSO:       getSSOSuccess,
   909  		FuncPostAuth:     postAuthOktaWithNewToken,
   910  		TokenAccessor:    getSimpleTokenAccessor(),
   911  	}
   912  	sc := getDefaultSnowflakeConn()
   913  	sc.cfg.Authenticator = AuthTypeOkta
   914  	sc.cfg.OktaURL = &url.URL{
   915  		Scheme: "https",
   916  		Host:   "abc.com",
   917  	}
   918  	sc.rest = sr
   919  	sc.ctx = context.Background()
   920  
   921  	authResponse, err := authenticate(context.Background(), sc, []byte{0x12, 0x34}, []byte{0x56, 0x78})
   922  	assertNilF(t, err, "should not have failed to run authenticate()")
   923  	assertEqualF(t, authResponse.MasterToken, expectedMasterToken)
   924  	assertEqualF(t, authResponse.Token, expectedToken)
   925  	assertEqualF(t, authResponse.MfaToken, expectedMfaToken)
   926  	assertEqualF(t, authResponse.SessionInfo.DatabaseName, expectedDatabaseName)
   927  }