github.com/snowflakedb/gosnowflake@v1.9.0/rows.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"database/sql/driver"
     7  	"io"
     8  	"reflect"
     9  	"strings"
    10  	"time"
    11  )
    12  
    13  const (
    14  	headerSseCAlgorithm = "x-amz-server-side-encryption-customer-algorithm"
    15  	headerSseCKey       = "x-amz-server-side-encryption-customer-key"
    16  	headerSseCAes       = "AES256"
    17  )
    18  
    19  var (
    20  	// MaxChunkDownloadWorkers specifies the maximum number of goroutines used to download chunks
    21  	MaxChunkDownloadWorkers = 10
    22  
    23  	// CustomJSONDecoderEnabled has the chunk downloader use the custom JSON decoder to reduce memory footprint.
    24  	CustomJSONDecoderEnabled = false
    25  )
    26  
    27  var (
    28  	maxChunkDownloaderErrorCounter = 5
    29  )
    30  
    31  // SnowflakeRows provides an API for methods exposed to the clients
    32  type SnowflakeRows interface {
    33  	GetQueryID() string
    34  	GetStatus() queryStatus
    35  	GetArrowBatches() ([]*ArrowBatch, error)
    36  }
    37  
// snowflakeRows implements the driver.Rows methods (Next, Columns, Close, ...)
// for a Snowflake result. Rows come from a chain of chunk downloaders linked by
// addDownloader; NextResultSet advances along that chain (presumably one
// downloader per result set — confirm).
type snowflakeRows struct {
	// sc is the connection; it may be nil (getLocation checks for that).
	sc                  *snowflakeConn
	// ChunkDownloader is the downloader currently producing rows; NextResultSet
	// replaces it with the next downloader in the chain.
	ChunkDownloader     chunkDownloader
	// tailChunkDownloader is the last downloader in the chain, so addDownloader
	// can append in O(1).
	tailChunkDownloader chunkDownloader
	// queryID is returned as-is by GetQueryID.
	queryID             string
	// status is updated by waitForAsyncQueryStatus when an asynchronous query
	// completes or fails.
	status              queryStatus
	// err holds the failure recorded by waitForAsyncQueryStatus (status == QueryFailed).
	err                 error
	// errChannel delivers the asynchronous query's outcome (nil on success);
	// it is read once, while status is QueryStatusInProgress.
	errChannel          chan error
	// location caches the time zone computed lazily by getLocation.
	location            *time.Location
}
    48  
    49  func (rows *snowflakeRows) getLocation() *time.Location {
    50  	if rows.location == nil && rows.sc != nil && rows.sc.cfg != nil {
    51  		rows.location = getCurrentLocation(rows.sc.cfg.Params)
    52  	}
    53  	return rows.location
    54  }
    55  
    56  type snowflakeValue interface{}
    57  
    58  type chunkRowType struct {
    59  	RowSet   []*string
    60  	ArrowRow []snowflakeValue
    61  }
    62  
    63  type rowSetType struct {
    64  	RowType      []execResponseRowType
    65  	JSON         [][]*string
    66  	RowSetBase64 string
    67  }
    68  
    69  type chunkError struct {
    70  	Index int
    71  	Error error
    72  }
    73  
    74  func (rows *snowflakeRows) Close() (err error) {
    75  	if err := rows.waitForAsyncQueryStatus(); err != nil {
    76  		return err
    77  	}
    78  	logger.WithContext(rows.sc.ctx).Debugln("Rows.Close")
    79  	return nil
    80  }
    81  
    82  // ColumnTypeDatabaseTypeName returns the database column name.
    83  func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string {
    84  	if err := rows.waitForAsyncQueryStatus(); err != nil {
    85  		return err.Error()
    86  	}
    87  	return strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type)
    88  }
    89  
    90  // ColumnTypeLength returns the length of the column
    91  func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) {
    92  	if err := rows.waitForAsyncQueryStatus(); err != nil {
    93  		return 0, false
    94  	}
    95  	if index < 0 || index > len(rows.ChunkDownloader.getRowType()) {
    96  		return 0, false
    97  	}
    98  	switch rows.ChunkDownloader.getRowType()[index].Type {
    99  	case "text", "variant", "object", "array", "binary":
   100  		return rows.ChunkDownloader.getRowType()[index].Length, true
   101  	}
   102  	return 0, false
   103  }
   104  
   105  func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) {
   106  	if err := rows.waitForAsyncQueryStatus(); err != nil {
   107  		return false, false
   108  	}
   109  	if index < 0 || index > len(rows.ChunkDownloader.getRowType()) {
   110  		return false, false
   111  	}
   112  	return rows.ChunkDownloader.getRowType()[index].Nullable, true
   113  }
   114  
   115  func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
   116  	if err := rows.waitForAsyncQueryStatus(); err != nil {
   117  		return 0, 0, false
   118  	}
   119  	rowType := rows.ChunkDownloader.getRowType()
   120  	if index < 0 || index > len(rowType) {
   121  		return 0, 0, false
   122  	}
   123  	switch rowType[index].Type {
   124  	case "fixed":
   125  		return rowType[index].Precision, rowType[index].Scale, true
   126  	case "time":
   127  		return rowType[index].Scale, 0, true
   128  	case "timestamp":
   129  		return rowType[index].Scale, 0, true
   130  	}
   131  	return 0, 0, false
   132  }
   133  
   134  func (rows *snowflakeRows) Columns() []string {
   135  	if err := rows.waitForAsyncQueryStatus(); err != nil {
   136  		return make([]string, 0)
   137  	}
   138  	logger.Debug("Rows.Columns")
   139  	ret := make([]string, len(rows.ChunkDownloader.getRowType()))
   140  	for i, n := 0, len(rows.ChunkDownloader.getRowType()); i < n; i++ {
   141  		ret[i] = rows.ChunkDownloader.getRowType()[i].Name
   142  	}
   143  	return ret
   144  }
   145  
   146  func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type {
   147  	if err := rows.waitForAsyncQueryStatus(); err != nil {
   148  		return nil
   149  	}
   150  	return snowflakeTypeToGo(
   151  		getSnowflakeType(rows.ChunkDownloader.getRowType()[index].Type),
   152  		rows.ChunkDownloader.getRowType()[index].Scale)
   153  }
   154  
   155  func (rows *snowflakeRows) GetQueryID() string {
   156  	return rows.queryID
   157  }
   158  
   159  func (rows *snowflakeRows) GetStatus() queryStatus {
   160  	return rows.status
   161  }
   162  
   163  // GetArrowBatches returns an array of ArrowBatch objects to retrieve data in arrow.Record format
   164  func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) {
   165  	// Wait for all arrow batches before fetching.
   166  	// Otherwise, a panic error "invalid memory address or nil pointer dereference" will be thrown.
   167  	if err := rows.waitForAsyncQueryStatus(); err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	return rows.ChunkDownloader.getArrowBatches(), nil
   172  }
   173  
   174  func (rows *snowflakeRows) Next(dest []driver.Value) (err error) {
   175  	if err = rows.waitForAsyncQueryStatus(); err != nil {
   176  		return err
   177  	}
   178  	row, err := rows.ChunkDownloader.next()
   179  	if err != nil {
   180  		// includes io.EOF
   181  		if err == io.EOF {
   182  			rows.ChunkDownloader.reset()
   183  		}
   184  		return err
   185  	}
   186  
   187  	if rows.ChunkDownloader.getQueryResultFormat() == arrowFormat {
   188  		for i, n := 0, len(row.ArrowRow); i < n; i++ {
   189  			dest[i] = row.ArrowRow[i]
   190  		}
   191  	} else {
   192  		for i, n := 0, len(row.RowSet); i < n; i++ {
   193  			// could move to chunk downloader so that each go routine
   194  			// can convert data
   195  			err = stringToValue(&dest[i], rows.ChunkDownloader.getRowType()[i], row.RowSet[i], rows.getLocation())
   196  			if err != nil {
   197  				return err
   198  			}
   199  		}
   200  	}
   201  	return err
   202  }
   203  
   204  func (rows *snowflakeRows) HasNextResultSet() bool {
   205  	if err := rows.waitForAsyncQueryStatus(); err != nil {
   206  		return false
   207  	}
   208  	return rows.ChunkDownloader.hasNextResultSet()
   209  }
   210  
   211  func (rows *snowflakeRows) NextResultSet() error {
   212  	if err := rows.waitForAsyncQueryStatus(); err != nil {
   213  		return err
   214  	}
   215  	if len(rows.ChunkDownloader.getChunkMetas()) == 0 {
   216  		if rows.ChunkDownloader.getNextChunkDownloader() == nil {
   217  			return io.EOF
   218  		}
   219  		rows.ChunkDownloader = rows.ChunkDownloader.getNextChunkDownloader()
   220  		if err := rows.ChunkDownloader.start(); err != nil {
   221  			return err
   222  		}
   223  	}
   224  	return rows.ChunkDownloader.nextResultSet()
   225  }
   226  
   227  func (rows *snowflakeRows) waitForAsyncQueryStatus() error {
   228  	// if async query, block until query is finished
   229  	if rows.status == QueryStatusInProgress {
   230  		err := <-rows.errChannel
   231  		rows.status = QueryStatusComplete
   232  		if err != nil {
   233  			rows.status = QueryFailed
   234  			rows.err = err
   235  			return rows.err
   236  		}
   237  	} else if rows.status == QueryFailed {
   238  		return rows.err
   239  	}
   240  	return nil
   241  }
   242  
   243  func (rows *snowflakeRows) addDownloader(newDL chunkDownloader) {
   244  	if rows.ChunkDownloader == nil {
   245  		rows.ChunkDownloader = newDL
   246  		rows.tailChunkDownloader = newDL
   247  		return
   248  	}
   249  	rows.tailChunkDownloader.setNextChunkDownloader(newDL)
   250  	rows.tailChunkDownloader = newDL
   251  }