github.com/snowflakedb/gosnowflake@v1.9.0/util.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"database/sql/driver"
     8  	"fmt"
     9  	"io"
    10  	"math/rand"
    11  	"os"
    12  	"os/exec"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/apache/arrow/go/v15/arrow/memory"
    18  )
    19  
// contextKey is the type of the keys under which this package stores its
// per-call options in a context.Context. A dedicated named type keeps these
// keys from colliding with context keys defined by other packages.
type contextKey string

// Context keys written by the exported With* helpers in this file. Each key
// is paired with the helper of the corresponding name (for example queryTag
// is set by WithQueryTag and arrowAlloc by WithArrowAllocator).
const (
	multiStatementCount              contextKey = "MULTI_STATEMENT_COUNT"
	asyncMode                        contextKey = "ASYNC_MODE_QUERY"
	queryIDChannel                   contextKey = "QUERY_ID_CHANNEL"
	snowflakeRequestIDKey            contextKey = "SNOWFLAKE_REQUEST_ID"
	fetchResultByID                  contextKey = "SF_FETCH_RESULT_BY_ID"
	fileStreamFile                   contextKey = "STREAMING_PUT_FILE"
	fileTransferOptions              contextKey = "FILE_TRANSFER_OPTIONS"
	enableHigherPrecision            contextKey = "ENABLE_HIGHER_PRECISION"
	enableArrowBatchesUtf8Validation contextKey = "ENABLE_ARROW_BATCHES_UTF8_VALIDATION"
	arrowBatches                     contextKey = "ARROW_BATCHES"
	arrowAlloc                       contextKey = "ARROW_ALLOC"
	arrowBatchesTimestampOption      contextKey = "ARROW_BATCHES_TIMESTAMP_OPTION"
	queryTag                         contextKey = "QUERY_TAG"
)

// Further context keys. describeOnly is set by WithDescribeOnly and
// streamChunkDownload by WithStreamDownloader.
// NOTE(review): cancelRetry has no setter in this part of the file; it is
// presumably set internally - confirm against its callers.
const (
	describeOnly        contextKey = "DESCRIBE_ONLY"
	cancelRetry         contextKey = "CANCEL_RETRY"
	streamChunkDownload contextKey = "STREAM_CHUNK_DOWNLOAD"
)
    43  
var (
	// defaultTimeProvider is a package-level *unixTimeProvider; its
	// currentTime method reports time.Now() in Unix milliseconds.
	defaultTimeProvider = &unixTimeProvider{}
)
    47  
    48  // WithMultiStatement returns a context that allows the user to execute the desired number of sql queries in one query
    49  func WithMultiStatement(ctx context.Context, num int) (context.Context, error) {
    50  	return context.WithValue(ctx, multiStatementCount, num), nil
    51  }
    52  
    53  // WithAsyncMode returns a context that allows execution of query in async mode
    54  func WithAsyncMode(ctx context.Context) context.Context {
    55  	return context.WithValue(ctx, asyncMode, true)
    56  }
    57  
    58  // WithQueryIDChan returns a context that contains the channel to receive the query ID
    59  func WithQueryIDChan(ctx context.Context, c chan<- string) context.Context {
    60  	return context.WithValue(ctx, queryIDChannel, c)
    61  }
    62  
    63  // WithRequestID returns a new context with the specified snowflake request id
    64  func WithRequestID(ctx context.Context, requestID UUID) context.Context {
    65  	return context.WithValue(ctx, snowflakeRequestIDKey, requestID)
    66  }
    67  
    68  // WithStreamDownloader returns a context that allows the use of a stream based chunk downloader
    69  func WithStreamDownloader(ctx context.Context) context.Context {
    70  	return context.WithValue(ctx, streamChunkDownload, true)
    71  }
    72  
    73  // WithFetchResultByID returns a context that allows retrieving the result by query ID
    74  func WithFetchResultByID(ctx context.Context, queryID string) context.Context {
    75  	return context.WithValue(ctx, fetchResultByID, queryID)
    76  }
    77  
    78  // WithFileStream returns a context that contains the address of the file stream to be PUT
    79  func WithFileStream(ctx context.Context, reader io.Reader) context.Context {
    80  	return context.WithValue(ctx, fileStreamFile, reader)
    81  }
    82  
    83  // WithFileTransferOptions returns a context that contains the address of file transfer options
    84  func WithFileTransferOptions(ctx context.Context, options *SnowflakeFileTransferOptions) context.Context {
    85  	return context.WithValue(ctx, fileTransferOptions, options)
    86  }
    87  
    88  // WithDescribeOnly returns a context that enables a describe only query
    89  func WithDescribeOnly(ctx context.Context) context.Context {
    90  	return context.WithValue(ctx, describeOnly, true)
    91  }
    92  
    93  // WithHigherPrecision returns a context that enables higher precision by
    94  // returning a *big.Int or *big.Float variable when querying rows for column
    95  // types with numbers that don't fit into its native Golang counterpart
    96  // When used in combination with WithArrowBatches, original BigDecimal in arrow batches will be preserved.
    97  func WithHigherPrecision(ctx context.Context) context.Context {
    98  	return context.WithValue(ctx, enableHigherPrecision, true)
    99  }
   100  
   101  // WithArrowBatches returns a context that allows users to retrieve
   102  // arrow.Record download workers upon querying
   103  func WithArrowBatches(ctx context.Context) context.Context {
   104  	return context.WithValue(ctx, arrowBatches, true)
   105  }
   106  
   107  // WithArrowAllocator returns a context embedding the provided allocator
   108  // which will be utilized by chunk downloaders when constructing Arrow
   109  // objects.
   110  func WithArrowAllocator(ctx context.Context, pool memory.Allocator) context.Context {
   111  	return context.WithValue(ctx, arrowAlloc, pool)
   112  }
   113  
   114  // WithOriginalTimestamp in combination with WithArrowBatches returns a context
   115  // that allows users to retrieve arrow.Record with original timestamp struct returned by Snowflake.
   116  // It can be used in case arrow.Timestamp cannot fit original timestamp values.
   117  //
   118  // Deprecated: please use WithArrowBatchesTimestampOption instead.
   119  func WithOriginalTimestamp(ctx context.Context) context.Context {
   120  	return context.WithValue(ctx, arrowBatchesTimestampOption, UseOriginalTimestamp)
   121  }
   122  
   123  // WithArrowBatchesTimestampOption in combination with WithArrowBatches returns a context
   124  // that allows users to retrieve arrow.Record with different timestamp options.
   125  // UseNanosecondTimestamp: arrow.Timestamp in nanosecond precision, could cause ErrTooHighTimestampPrecision if arrow.Timestamp cannot fit original timestamp values.
   126  // UseMicrosecondTimestamp: arrow.Timestamp in microsecond precision
   127  // UseMillisecondTimestamp: arrow.Timestamp in millisecond precision
   128  // UseSecondTimestamp: arrow.Timestamp in second precision
   129  // UseOriginalTimestamp: original timestamp struct returned by Snowflake. It can be used in case arrow.Timestamp cannot fit original timestamp values.
   130  func WithArrowBatchesTimestampOption(ctx context.Context, option snowflakeArrowBatchesTimestampOption) context.Context {
   131  	return context.WithValue(ctx, arrowBatchesTimestampOption, option)
   132  }
   133  
   134  // WithArrowBatchesUtf8Validation in combination with WithArrowBatches returns a context that
   135  // will validate and replace invalid UTF-8 characters in string columns with the replacement character
   136  // Theoretically, this should not be necessary, because arrow string column is only intended to contain valid UTF-8 characters.
   137  // However, in practice, it is possible that the data in the string column is not valid UTF-8.
   138  func WithArrowBatchesUtf8Validation(ctx context.Context) context.Context {
   139  	return context.WithValue(ctx, enableArrowBatchesUtf8Validation, true)
   140  
   141  }
   142  
   143  // WithQueryTag returns a context that will set the given tag as the QUERY_TAG
   144  // parameter on any queries that are run
   145  func WithQueryTag(ctx context.Context, tag string) context.Context {
   146  	return context.WithValue(ctx, queryTag, tag)
   147  }
   148  
   149  // Get the request ID from the context if specified, otherwise generate one
   150  func getOrGenerateRequestIDFromContext(ctx context.Context) UUID {
   151  	requestID, ok := ctx.Value(snowflakeRequestIDKey).(UUID)
   152  	if ok && requestID != nilUUID {
   153  		return requestID
   154  	}
   155  	return NewUUID()
   156  }
   157  
   158  // integer min
   159  func intMin(a, b int) int {
   160  	if a < b {
   161  		return a
   162  	}
   163  	return b
   164  }
   165  
   166  // integer max
   167  func intMax(a, b int) int {
   168  	if a > b {
   169  		return a
   170  	}
   171  	return b
   172  }
   173  
   174  func int64Max(a, b int64) int64 {
   175  	if a > b {
   176  		return a
   177  	}
   178  	return b
   179  }
   180  
   181  func getMin(arr []int) int {
   182  	if len(arr) == 0 {
   183  		return -1
   184  	}
   185  	min := arr[0]
   186  	for _, v := range arr {
   187  		if v <= min {
   188  			min = v
   189  		}
   190  	}
   191  	return min
   192  }
   193  
   194  // time.Duration max
   195  func durationMax(d1, d2 time.Duration) time.Duration {
   196  	if d1-d2 > 0 {
   197  		return d1
   198  	}
   199  	return d2
   200  }
   201  
   202  // time.Duration min
   203  func durationMin(d1, d2 time.Duration) time.Duration {
   204  	if d1-d2 < 0 {
   205  		return d1
   206  	}
   207  	return d2
   208  }
   209  
   210  // toNamedValues converts a slice of driver.Value to a slice of driver.NamedValue for Go 1.8 SQL package
   211  func toNamedValues(values []driver.Value) []driver.NamedValue {
   212  	namedValues := make([]driver.NamedValue, len(values))
   213  	for idx, value := range values {
   214  		namedValues[idx] = driver.NamedValue{Name: "", Ordinal: idx + 1, Value: value}
   215  	}
   216  	return namedValues
   217  }
   218  
   219  // TokenAccessor manages the session token and master token
   220  type TokenAccessor interface {
   221  	GetTokens() (token string, masterToken string, sessionID int64)
   222  	SetTokens(token string, masterToken string, sessionID int64)
   223  	Lock() error
   224  	Unlock()
   225  }
   226  
   227  type simpleTokenAccessor struct {
   228  	token        string
   229  	masterToken  string
   230  	sessionID    int64
   231  	accessorLock sync.Mutex   // Used to implement accessor's Lock and Unlock
   232  	tokenLock    sync.RWMutex // Used to synchronize SetTokens and GetTokens
   233  }
   234  
   235  func getSimpleTokenAccessor() TokenAccessor {
   236  	return &simpleTokenAccessor{sessionID: -1}
   237  }
   238  
   239  func (sta *simpleTokenAccessor) Lock() error {
   240  	sta.accessorLock.Lock()
   241  	return nil
   242  }
   243  
   244  func (sta *simpleTokenAccessor) Unlock() {
   245  	sta.accessorLock.Unlock()
   246  }
   247  
   248  func (sta *simpleTokenAccessor) GetTokens() (token string, masterToken string, sessionID int64) {
   249  	sta.tokenLock.RLock()
   250  	defer sta.tokenLock.RUnlock()
   251  	return sta.token, sta.masterToken, sta.sessionID
   252  }
   253  
   254  func (sta *simpleTokenAccessor) SetTokens(token string, masterToken string, sessionID int64) {
   255  	sta.tokenLock.Lock()
   256  	defer sta.tokenLock.Unlock()
   257  	sta.token = token
   258  	sta.masterToken = masterToken
   259  	sta.sessionID = sessionID
   260  }
   261  
   262  func escapeForCSV(value string) string {
   263  	if value == "" {
   264  		return "\"\""
   265  	}
   266  	if strings.Contains(value, "\"") || strings.Contains(value, "\n") ||
   267  		strings.Contains(value, ",") || strings.Contains(value, "\\") {
   268  		return "\"" + strings.ReplaceAll(value, "\"", "\"\"") + "\""
   269  	}
   270  	return value
   271  }
   272  
   273  // GetFromEnv is used to get the value of an environment variable from the system
   274  func GetFromEnv(name string, failOnMissing bool) (string, error) {
   275  	if value := os.Getenv(name); value != "" {
   276  		return value, nil
   277  	}
   278  	if failOnMissing {
   279  		return "", fmt.Errorf("%v environment variable is not set", name)
   280  	}
   281  	return "", nil
   282  }
   283  
   284  type currentTimeProvider interface {
   285  	currentTime() int64
   286  }
   287  
   288  type unixTimeProvider struct {
   289  }
   290  
   291  func (utp *unixTimeProvider) currentTime() int64 {
   292  	return time.Now().UnixMilli()
   293  }
   294  
   295  func contains[T comparable](s []T, e T) bool {
   296  	for _, v := range s {
   297  		if v == e {
   298  			return true
   299  		}
   300  	}
   301  	return false
   302  }
   303  
   304  func chooseRandomFromRange(min float64, max float64) float64 {
   305  	return rand.Float64()*(max-min) + min
   306  }
   307  
   308  func isDbusDaemonRunning() bool {
   309  	// TODO: delete this once we replaced 99designs/keyring (SNOW-1017659) and/or keyring#103 is resolved
   310  	cmd := exec.Command("pidof", "dbus-daemon")
   311  	_, err := cmd.Output()
   312  
   313  	// false: process not running, pidof not available (sysvinit-tools, busybox, etc missing)
   314  	return err == nil
   315  }
   316  
   317  func canDbusLeakProcesses() (bool, string) {
   318  	// TODO: delete this once we replaced 99designs/keyring (SNOW-1017659) and/or keyring#103 is resolved
   319  	leak := false
   320  	message := ""
   321  
   322  	valDbus, haveDbus := os.LookupEnv("DBUS_SESSION_BUS_ADDRESS")
   323  	if !haveDbus || strings.Contains(valDbus, "unix:abstract") {
   324  		// if DBUS_SESSION_BUS_ADDRESS is not set or set to an abstract socket, it's not necessarily a problem, only if dbus-daemon is running
   325  		if isDbusDaemonRunning() {
   326  			// we're probably susceptible to https://github.com/99designs/keyring/issues/103 here
   327  			leak = true
   328  			message += "DBUS_SESSION_BUS_ADDRESS envvar looks to be not set, this can lead to runaway dbus-daemon processes. " +
   329  				"To avoid this, set envvar DBUS_SESSION_BUS_ADDRESS=$XDG_RUNTIME_DIR/bus (if it exists) or DBUS_SESSION_BUS_ADDRESS=/dev/null."
   330  		}
   331  	}
   332  	return leak, message
   333  }