github.com/snowflakedb/gosnowflake@v1.9.0/connection.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bufio"
     7  	"bytes"
     8  	"compress/gzip"
     9  	"context"
    10  	"database/sql"
    11  	"database/sql/driver"
    12  	"encoding/base64"
    13  	"encoding/json"
    14  	"io"
    15  	"net/http"
    16  	"net/url"
    17  	"os"
    18  	"regexp"
    19  	"strconv"
    20  	"strings"
    21  	"sync"
    22  	"sync/atomic"
    23  	"time"
    24  
    25  	"github.com/apache/arrow/go/v15/arrow/ipc"
    26  )
    27  
// HTTP header names and canned values used on requests to the Snowflake
// REST API.
const (
	httpHeaderContentType      = "Content-Type"
	httpHeaderAccept           = "accept"
	httpHeaderUserAgent        = "User-Agent"
	httpHeaderServiceName      = "X-Snowflake-Service"
	httpHeaderContentLength    = "Content-Length"
	httpHeaderHost             = "Host"
	httpHeaderValueOctetStream = "application/octet-stream"
	httpHeaderContentEncoding  = "Content-Encoding"
	httpClientAppID            = "CLIENT_APP_ID"
	httpClientAppVersion       = "CLIENT_APP_VERSION"
)
    40  
// Statement type IDs reported by the server in execResponse; used to
// classify a finished statement (DML, multi-table insert, multi-statement).
const (
	statementTypeIDSelect           = int64(0x1000)
	statementTypeIDDml              = int64(0x3000)
	statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500)
	statementTypeIDMultistatement   = int64(0xA000)
)
    47  
// Session parameter names exchanged with the server (keys into cfg.Params
// and the response parameter list).
const (
	sessionClientSessionKeepAlive          = "client_session_keep_alive"
	sessionClientValidateDefaultParameters = "CLIENT_VALIDATE_DEFAULT_PARAMETERS"
	sessionArrayBindStageThreshold         = "client_stage_array_binding_threshold"
	serviceName                            = "service_name"
)
    54  
// resultType distinguishes exec-style from query-style execution; it is
// stored in the context under the snowflakeResultType key (see setResultType).
type resultType string

const (
	snowflakeResultType contextKey = "snowflakeResultType"
	execResultType      resultType = "exec"
	queryResultType     resultType = "query"
)
    62  
// execKey is a context-key type marking how a statement is being executed.
type execKey string

const (
	executionType          execKey = "executionType"
	executionTypeStatement string  = "statement"
)
    69  
// privateLinkSuffix identifies PrivateLink deployments by hostname; such
// hosts get a dedicated OCSP cache setup in buildSnowflakeConn.
const privateLinkSuffix = "privatelink.snowflakecomputing.com"
    71  
// snowflakeConn is the driver.Conn implementation: a single authenticated
// session against a Snowflake deployment.
type snowflakeConn struct {
	ctx                 context.Context // connection-scoped context used by the non-Context API methods
	cfg                 *Config         // session configuration; set to nil by cleanup()
	rest                *snowflakeRestful // REST transport; nil signals a closed/bad connection
	SequenceCounter     uint64          // per-connection query sequence, advanced atomically in exec
	telemetry           *snowflakeTelemetry
	internal            InternalClient
	queryContextCache   *queryContextCache // cache of server-provided query context entries
	currentTimeProvider currentTimeProvider
}
    82  
// queryIDPattern describes the characters allowed in a Snowflake query ID;
// queryIDRegexp is its compiled form.
// NOTE(review): the pattern is unanchored, so it matches a valid substring
// anywhere in the input rather than requiring the whole string to be valid.
var (
	queryIDPattern = `[\w\-_]+`
	queryIDRegexp  = regexp.MustCompile(queryIDPattern)
)
    87  
    88  func (sc *snowflakeConn) exec(
    89  	ctx context.Context,
    90  	query string,
    91  	noResult bool,
    92  	isInternal bool,
    93  	describeOnly bool,
    94  	bindings []driver.NamedValue) (
    95  	*execResponse, error) {
    96  	var err error
    97  	counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter
    98  
    99  	queryContext, err := buildQueryContext(sc.queryContextCache)
   100  	if err != nil {
   101  		logger.Errorf("error while building query context: %v", err)
   102  	}
   103  	req := execRequest{
   104  		SQLText:      query,
   105  		AsyncExec:    noResult,
   106  		Parameters:   map[string]interface{}{},
   107  		IsInternal:   isInternal,
   108  		DescribeOnly: describeOnly,
   109  		SequenceID:   counter,
   110  		QueryContext: queryContext,
   111  	}
   112  	if key := ctx.Value(multiStatementCount); key != nil {
   113  		req.Parameters[string(multiStatementCount)] = key
   114  	}
   115  	if tag := ctx.Value(queryTag); tag != nil {
   116  		req.Parameters[string(queryTag)] = tag
   117  	}
   118  	logger.WithContext(ctx).Infof("parameters: %v", req.Parameters)
   119  
   120  	// handle bindings, if required
   121  	requestID := getOrGenerateRequestIDFromContext(ctx)
   122  	if len(bindings) > 0 {
   123  		if err = sc.processBindings(ctx, bindings, describeOnly, requestID, &req); err != nil {
   124  			return nil, err
   125  		}
   126  	}
   127  	logger.WithContext(ctx).Infof("bindings: %v", req.Bindings)
   128  
   129  	// populate headers
   130  	headers := getHeaders()
   131  	if isFileTransfer(query) {
   132  		headers[httpHeaderAccept] = headerContentTypeApplicationJSON
   133  	}
   134  	paramsMutex.Lock()
   135  	if serviceName, ok := sc.cfg.Params[serviceName]; ok {
   136  		headers[httpHeaderServiceName] = *serviceName
   137  	}
   138  	paramsMutex.Unlock()
   139  
   140  	jsonBody, err := json.Marshal(req)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	data, err := sc.rest.FuncPostQuery(ctx, sc.rest, &url.Values{}, headers,
   146  		jsonBody, sc.rest.RequestTimeout, requestID, sc.cfg)
   147  	if err != nil {
   148  		return data, err
   149  	}
   150  	code := -1
   151  	if data.Code != "" {
   152  		code, err = strconv.Atoi(data.Code)
   153  		if err != nil {
   154  			return data, err
   155  		}
   156  	}
   157  	logger.WithContext(ctx).Infof("Success: %v, Code: %v", data.Success, code)
   158  	if !data.Success {
   159  		err = (populateErrorFields(code, data)).exceptionTelemetry(sc)
   160  		return nil, err
   161  	}
   162  
   163  	if !sc.cfg.DisableQueryContextCache && data.Data.QueryContext != nil {
   164  		queryContext, err := extractQueryContext(data)
   165  		if err != nil {
   166  			logger.Errorf("error while decoding query context: ", err)
   167  		} else {
   168  			sc.queryContextCache.add(sc, queryContext.Entries...)
   169  		}
   170  	}
   171  
   172  	// handle PUT/GET commands
   173  	if isFileTransfer(query) {
   174  		data, err = sc.processFileTransfer(ctx, data, query, isInternal)
   175  		if err != nil {
   176  			return nil, err
   177  		}
   178  	}
   179  
   180  	logger.WithContext(ctx).Info("Exec/Query SUCCESS")
   181  	if data.Data.FinalDatabaseName != "" {
   182  		sc.cfg.Database = data.Data.FinalDatabaseName
   183  	}
   184  	if data.Data.FinalSchemaName != "" {
   185  		sc.cfg.Schema = data.Data.FinalSchemaName
   186  	}
   187  	if data.Data.FinalWarehouseName != "" {
   188  		sc.cfg.Warehouse = data.Data.FinalWarehouseName
   189  	}
   190  	if data.Data.FinalRoleName != "" {
   191  		sc.cfg.Role = data.Data.FinalRoleName
   192  	}
   193  	sc.populateSessionParameters(data.Data.Parameters)
   194  	return data, err
   195  }
   196  
   197  func extractQueryContext(data *execResponse) (queryContext, error) {
   198  	var queryContext queryContext
   199  	err := json.Unmarshal(data.Data.QueryContext, &queryContext)
   200  	return queryContext, err
   201  }
   202  
   203  func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
   204  	rqc := requestQueryContext{}
   205  	if qcc == nil || len(qcc.entries) == 0 {
   206  		logger.Debugf("empty qcc")
   207  		return rqc, nil
   208  	}
   209  	for _, qce := range qcc.entries {
   210  		contextData := contextData{}
   211  		if qce.Context == "" {
   212  			contextData.Base64Data = qce.Context
   213  		}
   214  		rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
   215  			ID:        qce.ID,
   216  			Priority:  qce.Priority,
   217  			Timestamp: qce.Timestamp,
   218  			Context:   contextData,
   219  		})
   220  	}
   221  	return rqc, nil
   222  }
   223  
// Begin starts a transaction using the connection's stored context and
// default transaction options.
func (sc *snowflakeConn) Begin() (driver.Tx, error) {
	return sc.BeginTx(sc.ctx, driver.TxOptions{})
}
   227  
   228  func (sc *snowflakeConn) BeginTx(
   229  	ctx context.Context,
   230  	opts driver.TxOptions) (
   231  	driver.Tx, error) {
   232  	logger.WithContext(ctx).Info("BeginTx")
   233  	if opts.ReadOnly {
   234  		return nil, (&SnowflakeError{
   235  			Number:   ErrNoReadOnlyTransaction,
   236  			SQLState: SQLStateFeatureNotSupported,
   237  			Message:  errMsgNoReadOnlyTransaction,
   238  		}).exceptionTelemetry(sc)
   239  	}
   240  	if int(opts.Isolation) != int(sql.LevelDefault) {
   241  		return nil, (&SnowflakeError{
   242  			Number:   ErrNoDefaultTransactionIsolationLevel,
   243  			SQLState: SQLStateFeatureNotSupported,
   244  			Message:  errMsgNoDefaultTransactionIsolationLevel,
   245  		}).exceptionTelemetry(sc)
   246  	}
   247  	if sc.rest == nil {
   248  		return nil, driver.ErrBadConn
   249  	}
   250  	isDesc := isDescribeOnly(ctx)
   251  	if _, err := sc.exec(ctx, "BEGIN", false, /* noResult */
   252  		false /* isInternal */, isDesc, nil); err != nil {
   253  		return nil, err
   254  	}
   255  	return &snowflakeTx{sc, ctx}, nil
   256  }
   257  
   258  func (sc *snowflakeConn) cleanup() {
   259  	// must flush log buffer while the process is running.
   260  	if sc.rest != nil && sc.rest.Client != nil {
   261  		sc.rest.Client.CloseIdleConnections()
   262  	}
   263  	sc.rest = nil
   264  	sc.cfg = nil
   265  }
   266  
   267  func (sc *snowflakeConn) Close() (err error) {
   268  	logger.WithContext(sc.ctx).Infoln("Close")
   269  	sc.telemetry.sendBatch()
   270  	sc.stopHeartBeat()
   271  	defer sc.cleanup()
   272  
   273  	if sc.cfg != nil && !sc.cfg.KeepSessionAlive {
   274  		if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil {
   275  			logger.Error(err)
   276  		}
   277  	}
   278  	return nil
   279  }
   280  
   281  func (sc *snowflakeConn) PrepareContext(
   282  	ctx context.Context,
   283  	query string) (
   284  	driver.Stmt, error) {
   285  	logger.WithContext(sc.ctx).Infoln("Prepare")
   286  	if sc.rest == nil {
   287  		return nil, driver.ErrBadConn
   288  	}
   289  	stmt := &snowflakeStmt{
   290  		sc:    sc,
   291  		query: query,
   292  	}
   293  	return stmt, nil
   294  }
   295  
   296  func (sc *snowflakeConn) ExecContext(
   297  	ctx context.Context,
   298  	query string,
   299  	args []driver.NamedValue) (
   300  	driver.Result, error) {
   301  	logger.WithContext(ctx).Infof("Exec: %#v, %v", query, args)
   302  	if sc.rest == nil {
   303  		return nil, driver.ErrBadConn
   304  	}
   305  	noResult := isAsyncMode(ctx)
   306  	isDesc := isDescribeOnly(ctx)
   307  	// TODO handle isInternal
   308  	ctx = setResultType(ctx, execResultType)
   309  	data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args)
   310  	if err != nil {
   311  		logger.WithContext(ctx).Infof("error: %v", err)
   312  		if data != nil {
   313  			code, e := strconv.Atoi(data.Code)
   314  			if e != nil {
   315  				return nil, e
   316  			}
   317  			return nil, (&SnowflakeError{
   318  				Number:   code,
   319  				SQLState: data.Data.SQLState,
   320  				Message:  err.Error(),
   321  				QueryID:  data.Data.QueryID,
   322  			}).exceptionTelemetry(sc)
   323  		}
   324  		return nil, err
   325  	}
   326  
   327  	// if async exec, return result object right away
   328  	if noResult {
   329  		return data.Data.AsyncResult, nil
   330  	}
   331  
   332  	if isDml(data.Data.StatementTypeID) {
   333  		// collects all values from the returned row sets
   334  		updatedRows, err := updateRows(data.Data)
   335  		if err != nil {
   336  			return nil, err
   337  		}
   338  		logger.WithContext(ctx).Debugf("number of updated rows: %#v", updatedRows)
   339  		return &snowflakeResult{
   340  			affectedRows: updatedRows,
   341  			insertID:     -1,
   342  			queryID:      data.Data.QueryID,
   343  		}, nil // last insert id is not supported by Snowflake
   344  	} else if isMultiStmt(&data.Data) {
   345  		return sc.handleMultiExec(ctx, data.Data)
   346  	} else if isDql(&data.Data) {
   347  		logger.WithContext(ctx).Debugf("DQL")
   348  		if isStatementContext(ctx) {
   349  			return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
   350  		}
   351  		return driver.ResultNoRows, nil
   352  	}
   353  	logger.Debug("DDL")
   354  	if isStatementContext(ctx) {
   355  		return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
   356  	}
   357  	return driver.ResultNoRows, nil
   358  }
   359  
   360  func (sc *snowflakeConn) QueryContext(
   361  	ctx context.Context,
   362  	query string,
   363  	args []driver.NamedValue) (
   364  	driver.Rows, error) {
   365  	qid, err := getResumeQueryID(ctx)
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  	if qid == "" {
   370  		return sc.queryContextInternal(ctx, query, args)
   371  	}
   372  
   373  	// check the query status to find out if there is a result to fetch
   374  	_, err = sc.checkQueryStatus(ctx, qid)
   375  	snowflakeErr, isSnowflakeError := err.(*SnowflakeError)
   376  	if err == nil || (isSnowflakeError && snowflakeErr.Number == ErrQueryIsRunning) {
   377  		// the query is running. Rows object will be returned from here.
   378  		return sc.buildRowsForRunningQuery(ctx, qid)
   379  	}
   380  	return nil, err
   381  }
   382  
   383  func (sc *snowflakeConn) queryContextInternal(
   384  	ctx context.Context,
   385  	query string,
   386  	args []driver.NamedValue) (
   387  	driver.Rows, error) {
   388  	logger.WithContext(ctx).Infof("Query: %#v, %v", query, args)
   389  	if sc.rest == nil {
   390  		return nil, driver.ErrBadConn
   391  	}
   392  
   393  	noResult := isAsyncMode(ctx)
   394  	isDesc := isDescribeOnly(ctx)
   395  	ctx = setResultType(ctx, queryResultType)
   396  	// TODO: handle isInternal
   397  	data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args)
   398  	if err != nil {
   399  		logger.WithContext(ctx).Errorf("error: %v", err)
   400  		if data != nil {
   401  			code, e := strconv.Atoi(data.Code)
   402  			if e != nil {
   403  				return nil, e
   404  			}
   405  			return nil, (&SnowflakeError{
   406  				Number:   code,
   407  				SQLState: data.Data.SQLState,
   408  				Message:  err.Error(),
   409  				QueryID:  data.Data.QueryID,
   410  			}).exceptionTelemetry(sc)
   411  		}
   412  		return nil, err
   413  	}
   414  
   415  	// if async query, return row object right away
   416  	if noResult {
   417  		return data.Data.AsyncRows, nil
   418  	}
   419  
   420  	rows := new(snowflakeRows)
   421  	rows.sc = sc
   422  	rows.queryID = data.Data.QueryID
   423  
   424  	if isMultiStmt(&data.Data) {
   425  		// handleMultiQuery is responsible to fill rows with childResults
   426  		if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil {
   427  			return nil, err
   428  		}
   429  	} else {
   430  		rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data))
   431  	}
   432  
   433  	err = rows.ChunkDownloader.start()
   434  	return rows, err
   435  }
   436  
// Prepare delegates to PrepareContext using the connection's stored context.
func (sc *snowflakeConn) Prepare(query string) (driver.Stmt, error) {
	return sc.PrepareContext(sc.ctx, query)
}
   440  
// Exec delegates to ExecContext, converting positional values to named ones.
func (sc *snowflakeConn) Exec(
	query string,
	args []driver.Value) (
	driver.Result, error) {
	return sc.ExecContext(sc.ctx, query, toNamedValues(args))
}
   447  
// Query delegates to QueryContext, converting positional values to named ones.
func (sc *snowflakeConn) Query(
	query string,
	args []driver.Value) (
	driver.Rows, error) {
	return sc.QueryContext(sc.ctx, query, toNamedValues(args))
}
   454  
   455  func (sc *snowflakeConn) Ping(ctx context.Context) error {
   456  	logger.WithContext(ctx).Infoln("Ping")
   457  	if sc.rest == nil {
   458  		return driver.ErrBadConn
   459  	}
   460  	noResult := isAsyncMode(ctx)
   461  	isDesc := isDescribeOnly(ctx)
   462  	// TODO: handle isInternal
   463  	ctx = setResultType(ctx, execResultType)
   464  	_, err := sc.exec(ctx, "SELECT 1", noResult, false, /* isInternal */
   465  		isDesc, []driver.NamedValue{})
   466  	return err
   467  }
   468  
   469  // CheckNamedValue determines which types are handled by this driver aside from
   470  // the instances captured by driver.Value
   471  func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error {
   472  	if supportedNullBind(nv) || supportedArrayBind(nv) {
   473  		return nil
   474  	}
   475  	return driver.ErrSkip
   476  }
   477  
   478  func (sc *snowflakeConn) GetQueryStatus(
   479  	ctx context.Context,
   480  	queryID string) (
   481  	*SnowflakeQueryStatus, error) {
   482  	queryRet, err := sc.checkQueryStatus(ctx, queryID)
   483  	if err != nil {
   484  		return nil, err
   485  	}
   486  	return &SnowflakeQueryStatus{
   487  		queryRet.SQLText,
   488  		queryRet.StartTime,
   489  		queryRet.EndTime,
   490  		queryRet.ErrorCode,
   491  		queryRet.ErrorMessage,
   492  		queryRet.Stats.ScanBytes,
   493  		queryRet.Stats.ProducedRows,
   494  	}, nil
   495  }
   496  
   497  // QueryArrowStream returns batches which can be queried for their raw arrow
   498  // ipc stream of bytes. This way consumers don't need to be using the exact
   499  // same version of Arrow as the connection is using internally in order
   500  // to consume Arrow data.
   501  func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bindings ...driver.NamedValue) (ArrowStreamLoader, error) {
   502  	ctx = WithArrowBatches(context.WithValue(ctx, asyncMode, false))
   503  	ctx = setResultType(ctx, queryResultType)
   504  	isDesc := isDescribeOnly(ctx)
   505  	data, err := sc.exec(ctx, query, false, false /* isinternal */, isDesc, bindings)
   506  	if err != nil {
   507  		logger.WithContext(ctx).Errorf("error: %v", err)
   508  		if data != nil {
   509  			code, e := strconv.Atoi(data.Code)
   510  			if e != nil {
   511  				return nil, e
   512  			}
   513  			return nil, (&SnowflakeError{
   514  				Number:   code,
   515  				SQLState: data.Data.SQLState,
   516  				Message:  err.Error(),
   517  				QueryID:  data.Data.QueryID,
   518  			}).exceptionTelemetry(sc)
   519  		}
   520  		return nil, err
   521  	}
   522  
   523  	return &snowflakeArrowStreamChunkDownloader{
   524  		sc:          sc,
   525  		ChunkMetas:  data.Data.Chunks,
   526  		Total:       data.Data.Total,
   527  		Qrmk:        data.Data.Qrmk,
   528  		ChunkHeader: data.Data.ChunkHeaders,
   529  		FuncGet:     getChunk,
   530  		RowSet: rowSetType{
   531  			RowType:      data.Data.RowType,
   532  			JSON:         data.Data.RowSet,
   533  			RowSetBase64: data.Data.RowSetBase64,
   534  		},
   535  	}, nil
   536  }
   537  
// ArrowStreamBatch is a type describing a potentially yet-to-be-downloaded
// Arrow IPC stream. Call `GetStream` to download and retrieve an io.Reader
// that can be used with ipc.NewReader to get record batch results.
type ArrowStreamBatch struct {
	idx     int   // index into scd.ChunkMetas for this batch
	numrows int64 // row count stated by the metadata (see NumRows)
	scd     *snowflakeArrowStreamChunkDownloader
	Loc     *time.Location
	rr      io.ReadCloser // lazily populated by downloadChunkStreamHelper
}
   548  
// NumRows returns the total number of rows that the metadata stated should
// be in this stream of record batches.
func (asb *ArrowStreamBatch) NumRows() int64 { return asb.numrows }
   552  
// gzip.Reader.Close does NOT close the underlying reader, so we
// need to wrap with wrapReader so that closing will close the
// response body (or any other reader that we want to gzip uncompress)
type wrapReader struct {
	io.Reader                 // the reader actually consumed (possibly a gzip.Reader)
	wrapped   io.ReadCloser   // the underlying source (e.g. an HTTP response body)
}
   560  
   561  func (w *wrapReader) Close() error {
   562  	if cl, ok := w.Reader.(io.ReadCloser); ok {
   563  		if err := cl.Close(); err != nil {
   564  			return err
   565  		}
   566  	}
   567  	return w.wrapped.Close()
   568  }
   569  
// downloadChunkStreamHelper fetches the chunk at asb.idx via scd.FuncGet and
// leaves asb.rr holding a ReadCloser over the (possibly gzip-compressed)
// response body. On success, closing asb.rr closes the response body; on any
// failure the body is closed here before returning.
func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error {
	headers := make(map[string]string)
	if len(asb.scd.ChunkHeader) > 0 {
		logger.Debug("chunk header is provided")
		for k, v := range asb.scd.ChunkHeader {
			logger.Debugf("adding header: %v, value: %v", k, v)

			headers[k] = v
		}
	} else {
		// no per-chunk headers from the server: fall back to SSE-C headers
		// built from the query result master key (Qrmk)
		headers[headerSseCAlgorithm] = headerSseCAes
		headers[headerSseCKey] = asb.scd.Qrmk
	}

	resp, err := asb.scd.FuncGet(ctx, asb.scd.sc, asb.scd.ChunkMetas[asb.idx].URL, headers, asb.scd.sc.rest.RequestTimeout)
	if err != nil {
		return err
	}
	logger.Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL)
	if resp.StatusCode != http.StatusOK {
		// non-200: log the body for diagnostics and report a chunk error
		defer resp.Body.Close()
		b, err := io.ReadAll(resp.Body)
		if err != nil {
			return err
		}

		logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b)
		logger.Infof("Header: %v", resp.Header)
		return &SnowflakeError{
			Number:      ErrFailedToGetChunk,
			SQLState:    SQLStateConnectionFailure,
			Message:     errMsgFailedToGetChunk,
			MessageArgs: []interface{}{asb.idx},
		}
	}

	// close the body only if we return below without handing ownership of
	// it to asb.rr
	defer func() {
		if asb.rr == nil {
			resp.Body.Close()
		}
	}()

	bufStream := bufio.NewReader(resp.Body)
	// sniff the first two bytes for the gzip magic number without consuming
	gzipMagic, err := bufStream.Peek(2)
	if err != nil {
		return err
	}

	if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b {
		// detect and uncompress gzip
		bufStream0, err := gzip.NewReader(bufStream)
		if err != nil {
			return err
		}
		// gzip.Reader.Close() does NOT close the underlying
		// reader, so we need to wrap it and ensure close will
		// close the response body. Otherwise we'll leak it.
		asb.rr = &wrapReader{Reader: bufStream0, wrapped: resp.Body}
	} else {
		asb.rr = &wrapReader{Reader: bufStream, wrapped: resp.Body}
	}
	return nil
}
   633  
   634  // GetStream returns a stream of bytes consisting of an Arrow IPC Record
   635  // batch stream. Close should be called on the returned stream when done
   636  // to ensure no leaked memory.
   637  func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) {
   638  	if asb.rr == nil {
   639  		if err := asb.downloadChunkStreamHelper(ctx); err != nil {
   640  			return nil, err
   641  		}
   642  	}
   643  
   644  	return asb.rr, nil
   645  }
   646  
// ArrowStreamLoader is a convenience interface for downloading
// Snowflake results via multiple Arrow Record Batch streams.
//
// Some queries from Snowflake do not return Arrow data regardless
// of the settings, such as "SHOW WAREHOUSES". In these cases,
// you'll find TotalRows() > 0 but GetBatches returns no batches
// and no errors. In this case, the data is accessible via JSONData
// with the actual types matching up to the metadata in RowTypes.
type ArrowStreamLoader interface {
	// GetBatches returns one ArrowStreamBatch per downloadable chunk.
	GetBatches() ([]ArrowStreamBatch, error)
	// TotalRows reports the total row count from the query metadata.
	TotalRows() int64
	// RowTypes returns the column metadata for the result set.
	RowTypes() []execResponseRowType
	// Location returns the session timezone, if known.
	Location() *time.Location
	// JSONData returns inline JSON rows for non-Arrow results.
	JSONData() [][]*string
}
   662  
// snowflakeArrowStreamChunkDownloader implements ArrowStreamLoader over the
// chunk metadata of a single query response.
type snowflakeArrowStreamChunkDownloader struct {
	sc          *snowflakeConn
	ChunkMetas  []execResponseChunk // per-chunk URL and row-count metadata
	Total       int64               // total row count stated by the response
	Qrmk        string              // query result master key for SSE-C chunk fetches
	ChunkHeader map[string]string   // headers to send when fetching chunks, if provided
	FuncGet     func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error)
	RowSet      rowSetType // inline first batch / JSON rows from the response
}
   672  
// Location returns the session timezone from the connection parameters, or
// nil when the downloader has no connection.
func (scd *snowflakeArrowStreamChunkDownloader) Location() *time.Location {
	if scd.sc != nil {
		return getCurrentLocation(scd.sc.cfg.Params)
	}
	return nil
}

// TotalRows reports the total row count stated by the query metadata.
func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.Total }

// RowTypes returns the column metadata for the result set.
func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []execResponseRowType {
	return scd.RowSet.RowType
}

// JSONData returns the inline JSON rows for results not delivered as Arrow.
func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string {
	return scd.RowSet.JSON
}
   686  
   687  // the server might have had an empty first batch, check if we can decode
   688  // that first batch, if not we skip it.
   689  func (scd *snowflakeArrowStreamChunkDownloader) maybeFirstBatch() []byte {
   690  	if scd.RowSet.RowSetBase64 == "" {
   691  		return nil
   692  	}
   693  
   694  	// first batch
   695  	rowSetBytes, err := base64.StdEncoding.DecodeString(scd.RowSet.RowSetBase64)
   696  	if err != nil {
   697  		// match logic in buildFirstArrowChunk
   698  		// assume there's no first chunk if we can't decode the base64 string
   699  		return nil
   700  	}
   701  
   702  	// verify it's a valid ipc stream, otherwise skip it
   703  	rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes))
   704  	if err != nil {
   705  		return nil
   706  	}
   707  	rr.Release()
   708  
   709  	return rowSetBytes
   710  }
   711  
// GetBatches builds one ArrowStreamBatch per chunk, plus an extra leading
// batch when the response carried a valid inline first batch. The first
// batch's row count is derived as Total minus the counts of all metadata
// chunks, since the metadata does not state it directly.
func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamBatch, err error) {
	chunkMetaLen := len(scd.ChunkMetas)
	loc := scd.Location()

	// reserve capacity for one extra slot in case there is a first batch
	out = make([]ArrowStreamBatch, chunkMetaLen, chunkMetaLen+1)
	toFill := out
	rowSetBytes := scd.maybeFirstBatch()
	// if there was no first batch in the response from the server,
	// skip it and move on. toFill == out
	// otherwise expand out by one to account for the first batch
	// and fill it in. have toFill refer to the slice of out excluding
	// the first batch.
	if len(rowSetBytes) > 0 {
		out = out[:chunkMetaLen+1]
		out[0] = ArrowStreamBatch{
			scd: scd,
			Loc: loc,
			rr:  io.NopCloser(bytes.NewReader(rowSetBytes)),
		}
		toFill = out[1:]
	}

	var totalCounted int64
	for i := range toFill {
		toFill[i] = ArrowStreamBatch{
			idx:     i,
			numrows: int64(scd.ChunkMetas[i].RowCount),
			Loc:     loc,
			scd:     scd,
		}
		totalCounted += int64(scd.ChunkMetas[i].RowCount)
	}

	if len(rowSetBytes) > 0 {
		// if we had a first batch, fill in the numrows
		out[0].numrows = scd.Total - totalCounted
	}
	return
}
   751  
// buildSnowflakeConn constructs an unauthenticated snowflakeConn from config:
// it initializes logging, selects the HTTP transport (custom, insecure, or
// OCSP-checked), configures OCSP handling for PrivateLink hosts, wires up the
// REST client with its function table, and sets up telemetry. Authentication
// happens elsewhere, after this returns.
func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) {
	sc := &snowflakeConn{
		SequenceCounter:     0,
		ctx:                 ctx,
		cfg:                 &config,
		queryContextCache:   (&queryContextCache{}).init(),
		currentTimeProvider: defaultTimeProvider,
	}
	err := initEasyLogging(config.ClientConfigFile)
	if err != nil {
		return nil, err
	}
	var st http.RoundTripper = SnowflakeTransport
	if sc.cfg.Transporter == nil {
		if sc.cfg.InsecureMode {
			// no revocation check with OCSP. Think twice when you want to enable this option.
			st = snowflakeInsecureTransport
		} else {
			// set OCSP fail open mode
			ocspResponseCacheLock.Lock()
			atomic.StoreUint32((*uint32)(&ocspFailOpen), uint32(sc.cfg.OCSPFailOpen))
			ocspResponseCacheLock.Unlock()
		}
	} else {
		// use the custom transport
		st = sc.cfg.Transporter
	}
	if strings.HasSuffix(sc.cfg.Host, privateLinkSuffix) {
		// PrivateLink hosts need a dedicated OCSP cache server
		if err := sc.setupOCSPPrivatelink(sc.cfg.Application, sc.cfg.Host); err != nil {
			return nil, err
		}
	} else {
		// non-PrivateLink: drop any leftover PrivateLink OCSP cache URL
		if _, set := os.LookupEnv(cacheServerURLEnv); set {
			os.Unsetenv(cacheServerURLEnv)
		}
	}
	var tokenAccessor TokenAccessor
	if sc.cfg.TokenAccessor != nil {
		tokenAccessor = sc.cfg.TokenAccessor
	} else {
		tokenAccessor = getSimpleTokenAccessor()
	}

	// authenticate
	sc.rest = &snowflakeRestful{
		Host:     sc.cfg.Host,
		Port:     sc.cfg.Port,
		Protocol: sc.cfg.Protocol,
		Client: &http.Client{
			// request timeout including reading response body
			Timeout:   sc.cfg.ClientTimeout,
			Transport: st,
		},
		JWTClient: &http.Client{
			Timeout:   sc.cfg.JWTClientTimeout,
			Transport: st,
		},
		TokenAccessor:       tokenAccessor,
		LoginTimeout:        sc.cfg.LoginTimeout,
		RequestTimeout:      sc.cfg.RequestTimeout,
		MaxRetryCount:       sc.cfg.MaxRetryCount,
		FuncPost:            postRestful,
		FuncGet:             getRestful,
		FuncAuthPost:        postAuthRestful,
		FuncPostQuery:       postRestfulQuery,
		FuncPostQueryHelper: postRestfulQueryHelper,
		FuncRenewSession:    renewRestfulSession,
		FuncPostAuth:        postAuth,
		FuncCloseSession:    closeSession,
		FuncCancelQuery:     cancelQuery,
		FuncPostAuthSAML:    postAuthSAML,
		FuncPostAuthOKTA:    postAuthOKTA,
		FuncGetSSO:          getSSO,
	}

	if sc.cfg.DisableTelemetry {
		sc.telemetry = &snowflakeTelemetry{enabled: false}
	} else {
		sc.telemetry = &snowflakeTelemetry{
			flushSize: defaultFlushSize,
			sr:        sc.rest,
			mutex:     &sync.Mutex{},
			enabled:   true,
		}
	}

	return sc, nil
}