github.com/snowflakedb/gosnowflake@v1.9.0/connection_util.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  )
    15  
    16  func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool {
    17  	paramsMutex.Lock()
    18  	v, ok := sc.cfg.Params[sessionClientSessionKeepAlive]
    19  	paramsMutex.Unlock()
    20  	if !ok {
    21  		return false
    22  	}
    23  	return strings.Compare(*v, "true") == 0
    24  }
    25  
    26  func (sc *snowflakeConn) startHeartBeat() {
    27  	if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() {
    28  		return
    29  	}
    30  	if sc.rest != nil {
    31  		sc.rest.HeartBeat = &heartbeat{
    32  			restful: sc.rest,
    33  		}
    34  		sc.rest.HeartBeat.start()
    35  	}
    36  }
    37  
    38  func (sc *snowflakeConn) stopHeartBeat() {
    39  	if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() {
    40  		return
    41  	}
    42  	if sc.rest != nil && sc.rest.HeartBeat != nil {
    43  		sc.rest.HeartBeat.stop()
    44  	}
    45  }
    46  
    47  func (sc *snowflakeConn) getArrayBindStageThreshold() int {
    48  	paramsMutex.Lock()
    49  	v, ok := sc.cfg.Params[sessionArrayBindStageThreshold]
    50  	paramsMutex.Unlock()
    51  	if !ok {
    52  		return 0
    53  	}
    54  	num, err := strconv.Atoi(*v)
    55  	if err != nil {
    56  		return 0
    57  	}
    58  	return num
    59  }
    60  
    61  func (sc *snowflakeConn) connectionTelemetry(cfg *Config) {
    62  	data := &telemetryData{
    63  		Message: map[string]string{
    64  			typeKey:          connectionParameters,
    65  			sourceKey:        telemetrySource,
    66  			driverTypeKey:    "Go",
    67  			driverVersionKey: SnowflakeGoDriverVersion,
    68  		},
    69  		Timestamp: time.Now().UnixNano() / int64(time.Millisecond),
    70  	}
    71  	paramsMutex.Lock()
    72  	for k, v := range cfg.Params {
    73  		data.Message[k] = *v
    74  	}
    75  	paramsMutex.Unlock()
    76  	sc.telemetry.addLog(data)
    77  	sc.telemetry.sendBatch()
    78  }
    79  
    80  // processFileTransfer creates a snowflakeFileTransferAgent object to process
    81  // any PUT/GET commands with their specified options
    82  func (sc *snowflakeConn) processFileTransfer(
    83  	ctx context.Context,
    84  	data *execResponse,
    85  	query string,
    86  	isInternal bool) (
    87  	*execResponse, error) {
    88  	sfa := snowflakeFileTransferAgent{
    89  		sc:      sc,
    90  		data:    &data.Data,
    91  		command: query,
    92  		options: new(SnowflakeFileTransferOptions),
    93  	}
    94  	if fs := getFileStream(ctx); fs != nil {
    95  		sfa.sourceStream = fs
    96  		if isInternal {
    97  			sfa.data.AutoCompress = false
    98  		}
    99  	}
   100  	if op := getFileTransferOptions(ctx); op != nil {
   101  		sfa.options = op
   102  	}
   103  	if sfa.options.MultiPartThreshold == 0 {
   104  		sfa.options.MultiPartThreshold = dataSizeThreshold
   105  	}
   106  	if err := sfa.execute(); err != nil {
   107  		return nil, err
   108  	}
   109  	data, err := sfa.result()
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	return data, nil
   114  }
   115  
   116  func getFileStream(ctx context.Context) *bytes.Buffer {
   117  	s := ctx.Value(fileStreamFile)
   118  	r, ok := s.(io.Reader)
   119  	if !ok {
   120  		return nil
   121  	}
   122  	buf := new(bytes.Buffer)
   123  	buf.ReadFrom(r)
   124  	return buf
   125  }
   126  
   127  func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions {
   128  	v := ctx.Value(fileTransferOptions)
   129  	if v == nil {
   130  		return nil
   131  	}
   132  	o, ok := v.(*SnowflakeFileTransferOptions)
   133  	if !ok {
   134  		return nil
   135  	}
   136  	return o
   137  }
   138  
   139  func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) {
   140  	// other session parameters (not all)
   141  	logger.WithContext(sc.ctx).Infof("params: %#v", parameters)
   142  	for _, param := range parameters {
   143  		v := ""
   144  		switch param.Value.(type) {
   145  		case int64:
   146  			if vv, ok := param.Value.(int64); ok {
   147  				v = strconv.FormatInt(vv, 10)
   148  			}
   149  		case float64:
   150  			if vv, ok := param.Value.(float64); ok {
   151  				v = strconv.FormatFloat(vv, 'g', -1, 64)
   152  			}
   153  		case bool:
   154  			if vv, ok := param.Value.(bool); ok {
   155  				v = strconv.FormatBool(vv)
   156  			}
   157  		default:
   158  			if vv, ok := param.Value.(string); ok {
   159  				v = vv
   160  			}
   161  		}
   162  		logger.Debugf("parameter. name: %v, value: %v", param.Name, v)
   163  		paramsMutex.Lock()
   164  		sc.cfg.Params[strings.ToLower(param.Name)] = &v
   165  		paramsMutex.Unlock()
   166  	}
   167  }
   168  
   169  func isAsyncMode(ctx context.Context) bool {
   170  	val := ctx.Value(asyncMode)
   171  	if val == nil {
   172  		return false
   173  	}
   174  	a, ok := val.(bool)
   175  	return ok && a
   176  }
   177  
   178  func isDescribeOnly(ctx context.Context) bool {
   179  	v := ctx.Value(describeOnly)
   180  	if v == nil {
   181  		return false
   182  	}
   183  	d, ok := v.(bool)
   184  	return ok && d
   185  }
   186  
   187  func setResultType(ctx context.Context, resType resultType) context.Context {
   188  	return context.WithValue(ctx, snowflakeResultType, resType)
   189  }
   190  
   191  func getResultType(ctx context.Context) resultType {
   192  	return ctx.Value(snowflakeResultType).(resultType)
   193  }
   194  
   195  // isDml returns true if the statement type code is in the range of DML.
   196  func isDml(v int64) bool {
   197  	return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert
   198  }
   199  
   200  func isDql(data *execResponseData) bool {
   201  	return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data)
   202  }
   203  
   204  func updateRows(data execResponseData) (int64, error) {
   205  	if data.RowSet == nil {
   206  		return 0, nil
   207  	}
   208  	var count int64
   209  	for i, n := 0, len(data.RowType); i < n; i++ {
   210  		v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64)
   211  		if err != nil {
   212  			return -1, err
   213  		}
   214  		count += v
   215  	}
   216  	return count, nil
   217  }
   218  
   219  // isMultiStmt returns true if the statement code is of type multistatement
   220  // Note that the statement type code is also equivalent to type INSERT, so an
   221  // additional check of the name is required
   222  func isMultiStmt(data *execResponseData) bool {
   223  	var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution"
   224  	return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement
   225  }
   226  
   227  func getResumeQueryID(ctx context.Context) (string, error) {
   228  	val := ctx.Value(fetchResultByID)
   229  	if val == nil {
   230  		return "", nil
   231  	}
   232  	strVal, ok := val.(string)
   233  	if !ok {
   234  		return "", fmt.Errorf("failed to cast val %+v to string", val)
   235  	}
   236  	// so there is a queryID in context for which we want to fetch the result
   237  	if !queryIDRegexp.MatchString(strVal) {
   238  		return strVal, &SnowflakeError{
   239  			Number:  ErrQueryIDFormat,
   240  			Message: "Invalid QID",
   241  			QueryID: strVal,
   242  		}
   243  	}
   244  	return strVal, nil
   245  }
   246  
   247  // returns snowflake chunk downloader by default or stream based chunk
   248  // downloader if option provided through context
   249  func populateChunkDownloader(
   250  	ctx context.Context,
   251  	sc *snowflakeConn,
   252  	data execResponseData) chunkDownloader {
   253  	if useStreamDownloader(ctx) && resultFormat(data.QueryResultFormat) == jsonFormat {
   254  		// stream chunk downloading only works for row based data formats, i.e. json
   255  		fetcher := &httpStreamChunkFetcher{
   256  			ctx:      ctx,
   257  			client:   sc.rest.Client,
   258  			clientIP: sc.cfg.ClientIP,
   259  			headers:  data.ChunkHeaders,
   260  			qrmk:     data.Qrmk,
   261  		}
   262  		return newStreamChunkDownloader(ctx, fetcher, data.Total, data.RowType,
   263  			data.RowSet, data.Chunks)
   264  	}
   265  
   266  	return &snowflakeChunkDownloader{
   267  		sc:                 sc,
   268  		ctx:                ctx,
   269  		pool:               getAllocator(ctx),
   270  		CurrentChunk:       make([]chunkRowType, len(data.RowSet)),
   271  		ChunkMetas:         data.Chunks,
   272  		Total:              data.Total,
   273  		TotalRowIndex:      int64(-1),
   274  		CellCount:          len(data.RowType),
   275  		Qrmk:               data.Qrmk,
   276  		QueryResultFormat:  data.QueryResultFormat,
   277  		ChunkHeader:        data.ChunkHeaders,
   278  		FuncDownload:       downloadChunk,
   279  		FuncDownloadHelper: downloadChunkHelper,
   280  		FuncGet:            getChunk,
   281  		RowSet: rowSetType{
   282  			RowType:      data.RowType,
   283  			JSON:         data.RowSet,
   284  			RowSetBase64: data.RowSetBase64,
   285  		},
   286  	}
   287  }
   288  
   289  func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error {
   290  	ocspCacheServer := fmt.Sprintf("http://ocsp.%v/ocsp_response_cache.json", host)
   291  	logger.Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer)
   292  	if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil {
   293  		return err
   294  	}
   295  	ocspRetryHostTemplate := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v"
   296  	logger.Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate)
   297  	if err := os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate); err != nil {
   298  		return err
   299  	}
   300  	return nil
   301  }
   302  
   303  func isStatementContext(ctx context.Context) bool {
   304  	v := ctx.Value(executionType)
   305  	return v == executionTypeStatement
   306  }