github.com/snowflakedb/gosnowflake@v1.9.0/secure_storage_manager.go (about)

     1  // Copyright (c) 2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"encoding/json"
     7  	"os"
     8  	"path/filepath"
     9  	"runtime"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/99designs/keyring"
    15  )
    16  
    17  const (
    18  	driverName        = "SNOWFLAKE-GO-DRIVER"
    19  	credCacheDirEnv   = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"
    20  	credCacheFileName = "temporary_credential.json"
    21  )
    22  
    23  var (
    24  	credCacheDir   = ""
    25  	credCache      = ""
    26  	localCredCache = map[string]string{}
    27  )
    28  
    29  var (
    30  	temporaryCredCacheLock sync.RWMutex
    31  )
    32  
    33  func createCredentialCacheDir() {
    34  	credCacheDir = os.Getenv(credCacheDirEnv)
    35  	if credCacheDir == "" {
    36  		switch runtime.GOOS {
    37  		case "windows":
    38  			credCacheDir = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches")
    39  		case "darwin":
    40  			home := os.Getenv("HOME")
    41  			if home == "" {
    42  				logger.Info("HOME is blank.")
    43  			}
    44  			credCacheDir = filepath.Join(home, "Library", "Caches", "Snowflake")
    45  		default:
    46  			home := os.Getenv("HOME")
    47  			if home == "" {
    48  				logger.Info("HOME is blank")
    49  			}
    50  			credCacheDir = filepath.Join(home, ".cache", "snowflake")
    51  		}
    52  	}
    53  
    54  	if _, err := os.Stat(credCacheDir); os.IsNotExist(err) {
    55  		if err = os.MkdirAll(credCacheDir, os.ModePerm); err != nil {
    56  			logger.Debugf("Failed to create cache directory. %v, err: %v. ignored\n", credCacheDir, err)
    57  		}
    58  	}
    59  	credCache = filepath.Join(credCacheDir, credCacheFileName)
    60  	logger.Infof("Cache directory: %v", credCache)
    61  }
    62  
    63  func setCredential(sc *snowflakeConn, credType, token string) {
    64  	if token == "" {
    65  		logger.Debug("no token provided")
    66  	} else {
    67  		var target string
    68  		if runtime.GOOS == "windows" {
    69  			target = driverName + ":" + credType
    70  			ring, _ := keyring.Open(keyring.Config{
    71  				WinCredPrefix: strings.ToUpper(sc.cfg.Host),
    72  				ServiceName:   strings.ToUpper(sc.cfg.User),
    73  			})
    74  			item := keyring.Item{
    75  				Key:  target,
    76  				Data: []byte(token),
    77  			}
    78  			if err := ring.Set(item); err != nil {
    79  				logger.Debugf("Failed to write to Windows credential manager. Err: %v", err)
    80  			}
    81  		} else if runtime.GOOS == "darwin" {
    82  			target = convertTarget(sc.cfg.Host, sc.cfg.User, credType)
    83  			ring, _ := keyring.Open(keyring.Config{
    84  				ServiceName: target,
    85  			})
    86  			account := strings.ToUpper(sc.cfg.User)
    87  			item := keyring.Item{
    88  				Key:  account,
    89  				Data: []byte(token),
    90  			}
    91  			if err := ring.Set(item); err != nil {
    92  				logger.Debugf("Failed to write to keychain. Err: %v", err)
    93  			}
    94  		} else if runtime.GOOS == "linux" {
    95  			createCredentialCacheDir()
    96  			writeTemporaryCredential(sc, credType, token)
    97  		} else {
    98  			logger.Debug("OS not supported for Local Secure Storage")
    99  		}
   100  	}
   101  }
   102  
   103  func getCredential(sc *snowflakeConn, credType string) {
   104  	var target string
   105  	cred := ""
   106  	if runtime.GOOS == "windows" {
   107  		target = driverName + ":" + credType
   108  		ring, _ := keyring.Open(keyring.Config{
   109  			WinCredPrefix: strings.ToUpper(sc.cfg.Host),
   110  			ServiceName:   strings.ToUpper(sc.cfg.User),
   111  		})
   112  		i, err := ring.Get(target)
   113  		if err != nil {
   114  			logger.Debugf("Failed to read target or could not find it in Windows Credential Manager. Error: %v", err)
   115  		}
   116  		cred = string(i.Data)
   117  	} else if runtime.GOOS == "darwin" {
   118  		target = convertTarget(sc.cfg.Host, sc.cfg.User, credType)
   119  		ring, _ := keyring.Open(keyring.Config{
   120  			ServiceName: target,
   121  		})
   122  		account := strings.ToUpper(sc.cfg.User)
   123  		i, err := ring.Get(account)
   124  		if err != nil {
   125  			logger.Debugf("Failed to find the item in keychain or item does not exist. Error: %v", err)
   126  		}
   127  		cred = string(i.Data)
   128  		if cred == "" {
   129  			logger.Debug("Returned credential is empty")
   130  		} else {
   131  			logger.Debug("Successfully read token. Returning as string")
   132  		}
   133  	} else if runtime.GOOS == "linux" {
   134  		createCredentialCacheDir()
   135  		cred = readTemporaryCredential(sc, credType)
   136  	} else {
   137  		logger.Debug("OS not supported for Local Secure Storage")
   138  	}
   139  
   140  	if credType == idToken {
   141  		sc.cfg.IDToken = cred
   142  	} else if credType == mfaToken {
   143  		sc.cfg.MfaToken = cred
   144  	} else {
   145  		logger.Debugf("Unrecognized type %v for local cached credential", credType)
   146  	}
   147  }
   148  
   149  func deleteCredential(sc *snowflakeConn, credType string) {
   150  	target := driverName + ":" + credType
   151  	if runtime.GOOS == "windows" {
   152  		ring, _ := keyring.Open(keyring.Config{
   153  			WinCredPrefix: strings.ToUpper(sc.cfg.Host),
   154  			ServiceName:   strings.ToUpper(sc.cfg.User),
   155  		})
   156  		err := ring.Remove(target)
   157  		if err != nil {
   158  			logger.Debugf("Failed to delete target in Windows Credential Manager. Error: %v", err)
   159  		}
   160  	} else if runtime.GOOS == "darwin" {
   161  		target = convertTarget(sc.cfg.Host, sc.cfg.User, credType)
   162  		ring, _ := keyring.Open(keyring.Config{
   163  			ServiceName: target,
   164  		})
   165  		account := strings.ToUpper(sc.cfg.User)
   166  		err := ring.Remove(account)
   167  		if err != nil {
   168  			logger.Debugf("Failed to delete target in keychain. Error: %v", err)
   169  		}
   170  	} else if runtime.GOOS == "linux" {
   171  		deleteTemporaryCredential(sc, credType)
   172  	}
   173  }
   174  
   175  // Reads temporary credential file when OS is Linux.
   176  func readTemporaryCredential(sc *snowflakeConn, credType string) string {
   177  	target := convertTarget(sc.cfg.Host, sc.cfg.User, credType)
   178  	temporaryCredCacheLock.RLock()
   179  	localCredCache := readTemporaryCacheFile()
   180  	temporaryCredCacheLock.RUnlock()
   181  	cred := localCredCache[target]
   182  	if cred != "" {
   183  		logger.Debug("Successfully read token. Returning as string")
   184  	} else {
   185  		logger.Debug("Returned credential is empty")
   186  	}
   187  	return cred
   188  }
   189  
   190  // Writes to temporary credential file when OS is Linux.
   191  func writeTemporaryCredential(sc *snowflakeConn, credType, token string) {
   192  	target := convertTarget(sc.cfg.Host, sc.cfg.User, credType)
   193  	localCredCache[target] = token
   194  
   195  	j, err := json.Marshal(localCredCache)
   196  	if err != nil {
   197  		logger.Debugf("failed to convert credential to JSON.")
   198  	}
   199  	temporaryCredCacheLock.Lock()
   200  	writeTemporaryCacheFile(j)
   201  	temporaryCredCacheLock.Unlock()
   202  }
   203  
   204  func deleteTemporaryCredential(sc *snowflakeConn, credType string) {
   205  	if credCacheDir == "" {
   206  		logger.Debug("Cache file doesn't exist. Skipping deleting credential file.")
   207  	} else {
   208  		target := convertTarget(sc.cfg.Host, sc.cfg.User, credType)
   209  		delete(localCredCache, target)
   210  		j, err := json.Marshal(localCredCache)
   211  		if err != nil {
   212  			logger.Debugf("failed to convert credential to JSON.")
   213  		}
   214  		temporaryCredCacheLock.Lock()
   215  		writeTemporaryCacheFile(j)
   216  		temporaryCredCacheLock.Unlock()
   217  	}
   218  }
   219  
   220  func readTemporaryCacheFile() map[string]string {
   221  	if credCache == "" {
   222  		logger.Debug("Cache file doesn't exist. Skipping reading credential file.")
   223  		return nil
   224  	}
   225  	jsonData, err := os.ReadFile(credCache)
   226  	if err != nil {
   227  		logger.Debugf("Failed to read credential file: %v", err)
   228  		return nil
   229  	}
   230  	err = json.Unmarshal([]byte(jsonData), &localCredCache)
   231  	if err != nil {
   232  		logger.Debugf("failed to read JSON. Err: %v", err)
   233  		return nil
   234  	}
   235  
   236  	return localCredCache
   237  }
   238  
   239  func writeTemporaryCacheFile(input []byte) {
   240  	if credCache == "" {
   241  		logger.Debug("Cache file doesn't exist. Skipping writing temporary credential file.")
   242  	} else {
   243  		logger.Debugf("writing credential cache file. %v\n", credCache)
   244  		credCacheLockFileName := credCache + ".lck"
   245  		err := os.Mkdir(credCacheLockFileName, 0600)
   246  		logger.Debugf("Creating lock file. %v", credCacheLockFileName)
   247  
   248  		switch {
   249  		case os.IsExist(err):
   250  			statinfo, err := os.Stat(credCacheLockFileName)
   251  			if err != nil {
   252  				logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", credCache, err)
   253  				return
   254  			}
   255  			if time.Since(statinfo.ModTime()) < 15*time.Minute {
   256  				logger.Debugf("other process locks the cache file. %v. ignored.\n", credCache)
   257  				return
   258  			}
   259  			if err = os.Remove(credCacheLockFileName); err != nil {
   260  				logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err)
   261  				return
   262  			}
   263  			if err = os.Mkdir(credCacheLockFileName, 0600); err != nil {
   264  				logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err)
   265  				return
   266  			}
   267  		}
   268  		defer os.RemoveAll(credCacheLockFileName)
   269  
   270  		if err = os.WriteFile(credCache, input, 0644); err != nil {
   271  			logger.Debugf("Failed to write the cache file. File: %v err: %v.", credCache, err)
   272  		}
   273  	}
   274  }
   275  
   276  func convertTarget(host, user, credType string) string {
   277  	host = strings.ToUpper(host)
   278  	user = strings.ToUpper(user)
   279  	credType = strings.ToUpper(credType)
   280  	target := host + ":" + user + ":" + driverName + ":" + credType
   281  	return target
   282  }