github.com/snowflakedb/gosnowflake@v1.9.0/secure_storage_manager_test.go (about)

     1  // Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"errors"
     7  	"io"
     8  	"os"
     9  	"runtime"
    10  	"testing"
    11  )
    12  
    13  type tcTargets struct {
    14  	host     string
    15  	user     string
    16  	credType string
    17  	out      string
    18  }
    19  
    20  type tcCredentials struct {
    21  	credType string
    22  	token    string
    23  }
    24  
    25  func TestSetAndGetCredentialMfa(t *testing.T) {
    26  	if runtime.GOOS == "darwin" {
    27  		t.Skip("MacOS requires keychain password to be manually entered.")
    28  	} else {
    29  		fakeMfaToken := "fakeMfaToken"
    30  		expectedMfaToken := "fakeMfaToken"
    31  		sc := getDefaultSnowflakeConn()
    32  		sc.cfg.Host = "testhost"
    33  		setCredential(sc, mfaToken, fakeMfaToken)
    34  		getCredential(sc, mfaToken)
    35  
    36  		if sc.cfg.MfaToken != expectedMfaToken {
    37  			t.Fatalf("Expected mfa token %v but got %v", expectedMfaToken, sc.cfg.MfaToken)
    38  		}
    39  
    40  		// delete credential and check it no longer exists
    41  		deleteCredential(sc, mfaToken)
    42  		getCredential(sc, mfaToken)
    43  		if sc.cfg.MfaToken != "" {
    44  			t.Fatalf("Expected mfa token to be empty but got %v", sc.cfg.MfaToken)
    45  		}
    46  	}
    47  }
    48  
    49  func TestSetAndGetCredentialIdToken(t *testing.T) {
    50  	if runtime.GOOS == "darwin" {
    51  		t.Skip("MacOS requires keychain password to be manually entered.")
    52  	} else {
    53  		fakeIDToken := "fakeIDToken"
    54  		expectedIDToken := "fakeIDToken"
    55  		sc := getDefaultSnowflakeConn()
    56  		sc.cfg.Host = "testhost"
    57  		setCredential(sc, idToken, fakeIDToken)
    58  		getCredential(sc, idToken)
    59  
    60  		if sc.cfg.IDToken != expectedIDToken {
    61  			t.Fatalf("Expected id token %v but got %v", expectedIDToken, sc.cfg.IDToken)
    62  		}
    63  
    64  		// delete credential and check it no longer exists
    65  		deleteCredential(sc, idToken)
    66  		getCredential(sc, idToken)
    67  		if sc.cfg.IDToken != "" {
    68  			t.Fatalf("Expected id token to be empty but got %v", sc.cfg.IDToken)
    69  		}
    70  	}
    71  }
    72  func TestCreateCredentialCache(t *testing.T) {
    73  	skipOnJenkins(t, "cannot write to file system")
    74  	if runningOnGithubAction() {
    75  		t.Skip("cannot write to github file system")
    76  	}
    77  	dirName, err := os.UserHomeDir()
    78  	if err != nil {
    79  		t.Error(err)
    80  	}
    81  	srcFileName := dirName + "/.cache/snowflake/temporary_credential.json"
    82  	tmpFileName := srcFileName + "_tmp"
    83  	dst, err := os.Create(tmpFileName)
    84  	if err != nil {
    85  		t.Error(err)
    86  	}
    87  	defer dst.Close()
    88  
    89  	var src *os.File
    90  	if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) {
    91  		// file does not exist
    92  		if err = os.MkdirAll(dirName+"/.cache/snowflake/", os.ModePerm); err != nil {
    93  			t.Error(err)
    94  		}
    95  		if _, err = os.Create(srcFileName); err != nil {
    96  			t.Error(err)
    97  		}
    98  	} else if err != nil {
    99  		t.Error(err)
   100  	} else {
   101  		// file exists
   102  		src, err = os.Open(srcFileName)
   103  		if err != nil {
   104  			t.Error(err)
   105  		}
   106  		defer src.Close()
   107  		// copy original contents to temporary file
   108  		if _, err = io.Copy(dst, src); err != nil {
   109  			t.Error(err)
   110  		}
   111  		if err = os.Remove(srcFileName); err != nil {
   112  			t.Error(err)
   113  		}
   114  	}
   115  
   116  	createCredentialCacheDir()
   117  	if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) {
   118  		t.Error(err)
   119  	} else if err != nil {
   120  		t.Error(err)
   121  	}
   122  
   123  	// cleanup
   124  	src, _ = os.Open(tmpFileName)
   125  	defer src.Close()
   126  	dst, _ = os.OpenFile(srcFileName, os.O_WRONLY, readWriteFileMode)
   127  	defer dst.Close()
   128  	// copy temporary file contents back to original file
   129  	if _, err = io.Copy(dst, src); err != nil {
   130  		t.Fatal(err)
   131  	}
   132  	if err = os.Remove(tmpFileName); err != nil {
   133  		t.Error(err)
   134  	}
   135  }
   136  
   137  func TestStoreTemporaryCredental(t *testing.T) {
   138  	if runningOnGithubAction() {
   139  		t.Skip("cannot write to github file system")
   140  	}
   141  
   142  	testcases := []tcCredentials{
   143  		{"mfaToken", "598ghFnjfh8BBgmf45mmhgkfRR45mgkt5"},
   144  		{"IdToken", "090Arftf54Jk3gh57ggrVvf09lJa3DD"},
   145  	}
   146  	createCredentialCacheDir()
   147  	if credCache == "" {
   148  		t.Fatalf("failed to create credential cache")
   149  	}
   150  	sc := getDefaultSnowflakeConn()
   151  	for _, test := range testcases {
   152  		t.Run(test.token, func(t *testing.T) {
   153  			writeTemporaryCredential(sc, test.credType, test.token)
   154  			target := convertTarget(sc.cfg.Host, sc.cfg.User, test.credType)
   155  			_, ok := localCredCache[target]
   156  			if !ok {
   157  				t.Fatalf("failed to write credential to local cache")
   158  			}
   159  			tmpCred := readTemporaryCredential(sc, test.credType)
   160  			if tmpCred == "" {
   161  				t.Fatalf("failed to read credential from temporary cache")
   162  			} else {
   163  				deleteTemporaryCredential(sc, test.credType)
   164  			}
   165  		})
   166  	}
   167  }
   168  
   169  func TestConvertTarget(t *testing.T) {
   170  	testcases := []tcTargets{
   171  		{"testaccount.snowflakecomputing.com", "testuser", "mfaToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:MFATOKEN"},
   172  		{"testaccount.snowflakecomputing.com", "testuser", "IdToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:IDTOKEN"},
   173  	}
   174  	for _, test := range testcases {
   175  		target := convertTarget(test.host, test.user, test.credType)
   176  		if target != test.out {
   177  			t.Fatalf("failed to convert target. expected: %v, but got: %v", test.out, target)
   178  		}
   179  	}
   180  }