github.com/snowflakedb/gosnowflake@v1.9.0/chunk_downloader.go

// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.

package gosnowflake

import (
	"bufio"
	"compress/gzip"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/apache/arrow/go/v15/arrow"
	"github.com/apache/arrow/go/v15/arrow/ipc"
	"github.com/apache/arrow/go/v15/arrow/memory"
)

// chunkDownloader abstracts downloading and decoding of result-set
// chunks, in either JSON or Arrow format.
type chunkDownloader interface {
	totalUncompressedSize() (acc int64)
	hasNextResultSet() bool
	nextResultSet() error
	start() error
	next() (chunkRowType, error)
	reset()
	getChunkMetas() []execResponseChunk
	getQueryResultFormat() resultFormat
	getRowType() []execResponseRowType
	setNextChunkDownloader(downloader chunkDownloader)
	getNextChunkDownloader() chunkDownloader
	getArrowBatches() []*ArrowBatch
}

// snowflakeChunkDownloader is the default chunkDownloader: it fans the
// chunk list out to a bounded pool of goroutines and hands decoded rows
// back to the caller in order.
type snowflakeChunkDownloader struct {
	sc                 *snowflakeConn
	ctx                context.Context
	pool               memory.Allocator
	Total              int64
	TotalRowIndex      int64
	CellCount          int
	CurrentChunk       []chunkRowType
	CurrentChunkIndex  int
	CurrentChunkSize   int
	CurrentIndex       int
	ChunkHeader        map[string]string
	ChunkMetas         []execResponseChunk
	Chunks             map[int][]chunkRowType
	ChunksChan         chan int
	ChunksError        chan *chunkError
	ChunksErrorCounter int
	ChunksFinalErrors  []*chunkError
	ChunksMutex        *sync.Mutex
	DoneDownloadCond   *sync.Cond
	FirstBatch         *ArrowBatch
	NextDownloader     chunkDownloader
	Qrmk               string
	QueryResultFormat  string
	ArrowBatches       []*ArrowBatch
	RowSet             rowSetType
	FuncDownload       func(context.Context, *snowflakeChunkDownloader, int)
	FuncDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error
	FuncGet            func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error)
}

func (scd *snowflakeChunkDownloader) totalUncompressedSize() (acc int64) {
	for _, c := range scd.ChunkMetas {
		acc += c.UncompressedSize
	}
	return
}

func (scd *snowflakeChunkDownloader) hasNextResultSet() bool {
	if len(scd.ChunkMetas) == 0 && scd.NextDownloader == nil {
		return false // no extra chunk
	}
	// a next result set exists if the current downloader still has chunks
	// to consume or another downloader is chained after it
	return scd.CurrentChunkIndex < len(scd.ChunkMetas) || scd.NextDownloader != nil
}

func (scd *snowflakeChunkDownloader) nextResultSet() error {
	// chunks are read automatically, so this never fails until every
	// chunk has been consumed, at which point it reports io.EOF
	if scd.CurrentChunkIndex < len(scd.ChunkMetas) {
		return nil
	}
	return io.EOF
}

func (scd *snowflakeChunkDownloader) start() error {
	if usesArrowBatches(scd.ctx) {
		return scd.startArrowBatches()
	}
	scd.CurrentChunkSize = len(scd.RowSet.JSON) // cache the size
	scd.CurrentIndex = -1                       // initial chunks idx
	scd.CurrentChunkIndex = -1                  // initial chunk

	scd.CurrentChunk = make([]chunkRowType, scd.CurrentChunkSize)
	populateJSONRowSet(scd.CurrentChunk, scd.RowSet.JSON)

	if scd.getQueryResultFormat() == arrowFormat && scd.RowSet.RowSetBase64 != "" {
		// decode the first chunk, which the server ships inline as base64;
		// when it is empty we fall through to downloading chunks instead
		var err error
		var loc *time.Location
		if scd.sc != nil && scd.sc.cfg != nil {
			loc = getCurrentLocation(scd.sc.cfg.Params)
		}
		firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
		if err != nil {
			return err
		}
		higherPrecision := higherPrecisionEnabled(scd.ctx)
		scd.CurrentChunk, err = firstArrowChunk.decodeArrowChunk(scd.RowSet.RowType, higherPrecision)
		if err != nil {
			return err
		}
		scd.CurrentChunkSize = firstArrowChunk.rowCount
	}

	// start downloading chunks, if any
	chunkMetaLen := len(scd.ChunkMetas)
	if chunkMetaLen > 0 {
		logger.Debugf("MaxChunkDownloadWorkers: %v", MaxChunkDownloadWorkers)
		logger.Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize())
		scd.ChunksMutex = &sync.Mutex{}
		scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex)
		scd.Chunks = make(map[int][]chunkRowType)
		scd.ChunksChan = make(chan int, chunkMetaLen)
		scd.ChunksError = make(chan *chunkError, MaxChunkDownloadWorkers)
		for i := 0; i < chunkMetaLen; i++ {
			chunk := scd.ChunkMetas[i]
			logger.Debugf("add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v",
				i+1, chunk.URL, chunk.RowCount, chunk.UncompressedSize, scd.QueryResultFormat)
			scd.ChunksChan <- i
		}
		for i := 0; i < intMin(MaxChunkDownloadWorkers, chunkMetaLen); i++ {
			scd.schedule()
		}
	}
	return nil
}

func (scd *snowflakeChunkDownloader) schedule() {
	select {
	case nextIdx := <-scd.ChunksChan:
		logger.Infof("schedule chunk: %v", nextIdx+1)
		go scd.FuncDownload(scd.ctx, scd, nextIdx)
	default:
		// no more chunks to download
		logger.Info("no more download")
	}
}

func (scd *snowflakeChunkDownloader) checkErrorRetry() (err error) {
	select {
	case errc := <-scd.ChunksError:
		if scd.ChunksErrorCounter < maxChunkDownloaderErrorCounter &&
			errc.Error != context.Canceled &&
			errc.Error != context.DeadlineExceeded {
			// relaunch the download for this index so it is retried
			go scd.FuncDownload(scd.ctx, scd, errc.Index)
			scd.ChunksErrorCounter++
			logger.Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...",
				errc.Index, errc.Error, scd.ChunksErrorCounter, maxChunkDownloaderErrorCounter)
		} else {
			scd.ChunksFinalErrors = append(scd.ChunksFinalErrors, errc)
			logger.Warningf("chunk idx: %v, err: %v. no further retry", errc.Index, errc.Error)
			return errc.Error
		}
	default:
		logger.Info("no error is detected.")
	}
	return nil
}

func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) {
	for {
		scd.CurrentIndex++
		if scd.CurrentIndex < scd.CurrentChunkSize {
			return scd.CurrentChunk[scd.CurrentIndex], nil
		}
		scd.CurrentChunkIndex++ // next chunk
		scd.CurrentIndex = -1   // reset
		if scd.CurrentChunkIndex >= len(scd.ChunkMetas) {
			break
		}

		scd.ChunksMutex.Lock()
		if scd.CurrentChunkIndex > 0 {
			scd.Chunks[scd.CurrentChunkIndex-1] = nil // detach the previously used chunk
		}

		for scd.Chunks[scd.CurrentChunkIndex] == nil {
			logger.Debugf("waiting for chunk idx: %v/%v",
				scd.CurrentChunkIndex+1, len(scd.ChunkMetas))

			if err := scd.checkErrorRetry(); err != nil {
				scd.ChunksMutex.Unlock()
				return chunkRowType{}, err
			}

			// wait for a downloader goroutine to broadcast the event:
			// either a chunk finished downloading or an error occurred
			scd.DoneDownloadCond.Wait()
		}
		logger.Debugf("ready: chunk %v", scd.CurrentChunkIndex+1)
		scd.CurrentChunk = scd.Chunks[scd.CurrentChunkIndex]
		scd.ChunksMutex.Unlock()
		scd.CurrentChunkSize = len(scd.CurrentChunk)

		// kick off the next download
		scd.schedule()
	}

	logger.Debugf("no more data")
	if len(scd.ChunkMetas) > 0 {
		close(scd.ChunksError)
		close(scd.ChunksChan)
	}
	return chunkRowType{}, io.EOF
}

func (scd *snowflakeChunkDownloader) reset() {
	scd.Chunks = nil // detach all chunks; there is no way to go backward without reinitializing
}

func (scd *snowflakeChunkDownloader) getChunkMetas() []execResponseChunk {
	return scd.ChunkMetas
}

func (scd *snowflakeChunkDownloader) getQueryResultFormat() resultFormat {
	return resultFormat(scd.QueryResultFormat)
}

func (scd *snowflakeChunkDownloader) setNextChunkDownloader(nextDownloader chunkDownloader) {
	scd.NextDownloader = nextDownloader
}

func (scd *snowflakeChunkDownloader) getNextChunkDownloader() chunkDownloader {
	return scd.NextDownloader
}

func (scd *snowflakeChunkDownloader) getRowType() []execResponseRowType {
	return scd.RowSet.RowType
}

func (scd *snowflakeChunkDownloader) getArrowBatches() []*ArrowBatch {
	if scd.FirstBatch == nil || scd.FirstBatch.rec == nil {
		return scd.ArrowBatches
	}
	return append([]*ArrowBatch{scd.FirstBatch}, scd.ArrowBatches...)
}

func getChunk(
	ctx context.Context,
	sc *snowflakeConn,
	fullURL string,
	headers map[string]string,
	timeout time.Duration) (
	*http.Response, error,
) {
	u, err := url.Parse(fullURL)
	if err != nil {
		return nil, err
	}
	return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.rest.MaxRetryCount, sc.currentTimeProvider, sc.cfg).execute()
}

func (scd *snowflakeChunkDownloader) startArrowBatches() error {
	var loc *time.Location
	if scd.sc != nil && scd.sc.cfg != nil {
		loc = getCurrentLocation(scd.sc.cfg.Params)
	}
	if scd.RowSet.RowSetBase64 != "" {
		firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
		if err != nil {
			return err
		}
		scd.FirstBatch = &ArrowBatch{
			idx:                0,
			scd:                scd,
			funcDownloadHelper: scd.FuncDownloadHelper,
			loc:                loc,
		}
		// decode the first chunk if possible
		if firstArrowChunk.allocator != nil {
			scd.FirstBatch.rec, err = firstArrowChunk.decodeArrowBatch(scd)
			if err != nil {
				return err
			}
		}
	}
	chunkMetaLen := len(scd.ChunkMetas)
	scd.ArrowBatches = make([]*ArrowBatch, chunkMetaLen)
	for i := range scd.ArrowBatches {
		scd.ArrowBatches[i] = &ArrowBatch{
			idx:                i,
			scd:                scd,
			funcDownloadHelper: scd.FuncDownloadHelper,
			loc:                loc,
		}
	}
	return nil
}

// largeResultSetReader wraps a large result set stream with the leading
// and trailing brackets needed to make it a valid JSON array.
type largeResultSetReader struct {
	status int
	body   io.Reader
}

func (r *largeResultSetReader) Read(p []byte) (n int, err error) {
	if r.status == 0 {
		p[0] = 0x5b // leading 0x5b ([)
		r.status = 1
		return 1, nil
	}
	if r.status == 1 {
		n, err = r.body.Read(p)
		if err == io.EOF {
			r.status = 2
			return n, nil
		}
		if err != nil {
			return 0, err
		}
		return n, nil
	}
	if r.status == 2 {
		p[0] = 0x5d // trailing 0x5d (])
		r.status = 3
		return 1, nil
	}
	// no more data; report EOF
	return 0, io.EOF
}

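// A chunk body arrives as comma-separated JSON rows without the
// enclosing array brackets, e.g. `["1","a"],["2","b"]`; wrapping the
// body in this reader turns it into decodable JSON. A minimal sketch
// (illustrative only, not part of the driver's API):
//
//	body := strings.NewReader(`["1","a"],["2","b"]`)
//	r := &largeResultSetReader{status: 0, body: body}
//	var rows [][]*string
//	_ = json.NewDecoder(r).Decode(&rows) // rows now holds both rows
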
func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int) {
	logger.Infof("download start chunk: %v", idx+1)
	defer scd.DoneDownloadCond.Broadcast()

	if err := scd.FuncDownloadHelper(ctx, scd, idx); err != nil {
		logger.Errorf(
			"failed to extract HTTP response body. URL: %v, err: %v", scd.ChunkMetas[idx].URL, err)
		scd.ChunksError <- &chunkError{Index: idx, Error: err}
	} else if scd.ctx.Err() == context.Canceled || scd.ctx.Err() == context.DeadlineExceeded {
		scd.ChunksError <- &chunkError{Index: idx, Error: scd.ctx.Err()}
	}
}

func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx int) error {
	headers := make(map[string]string)
	if len(scd.ChunkHeader) > 0 {
		logger.Debug("chunk header is provided.")
		for k, v := range scd.ChunkHeader {
			logger.Debugf("adding header: %v, value: %v", k, v)
			headers[k] = v
		}
	} else {
		headers[headerSseCAlgorithm] = headerSseCAes
		headers[headerSseCKey] = scd.Qrmk
	}

	resp, err := scd.FuncGet(ctx, scd.sc, scd.ChunkMetas[idx].URL, headers, scd.sc.rest.RequestTimeout)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	bufStream := bufio.NewReader(resp.Body)
	logger.Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL)
	if resp.StatusCode != http.StatusOK {
		b, err := io.ReadAll(bufStream)
		if err != nil {
			return err
		}
		logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, scd.ChunkMetas[idx].URL, string(b))
		logger.Infof("Header: %v", resp.Header)
		return &SnowflakeError{
			Number:      ErrFailedToGetChunk,
			SQLState:    SQLStateConnectionFailure,
			Message:     errMsgFailedToGetChunk,
			MessageArgs: []interface{}{idx},
		}
	}
	return decodeChunk(scd, idx, bufStream)
}

func decodeChunk(scd *snowflakeChunkDownloader, idx int, bufStream *bufio.Reader) (err error) {
	gzipMagic, err := bufStream.Peek(2)
	if err != nil {
		return err
	}
	start := time.Now()
	var source io.Reader
	if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b {
		// detect and decompress gzip-encoded data
		bufStream0, err := gzip.NewReader(bufStream)
		if err != nil {
			return err
		}
		defer bufStream0.Close()
		source = bufStream0
	} else {
		source = bufStream
	}
	st := &largeResultSetReader{
		status: 0,
		body:   source,
	}
	var respd []chunkRowType
	if scd.getQueryResultFormat() != arrowFormat {
		var decRespd [][]*string
		if !CustomJSONDecoderEnabled {
			dec := json.NewDecoder(st)
			for {
				if err = dec.Decode(&decRespd); err == io.EOF {
					break
				} else if err != nil {
					return err
				}
			}
		} else {
			decRespd, err = decodeLargeChunk(st, scd.ChunkMetas[idx].RowCount, scd.CellCount)
			if err != nil {
				return err
			}
		}
		respd = make([]chunkRowType, len(decRespd))
		populateJSONRowSet(respd, decRespd)
	} else {
		ipcReader, err := ipc.NewReader(source, ipc.WithAllocator(scd.pool))
		if err != nil {
			return err
		}
		var loc *time.Location
		if scd.sc != nil && scd.sc.cfg != nil {
			loc = getCurrentLocation(scd.sc.cfg.Params)
		}
		arc := arrowResultChunk{
			ipcReader,
			0,
			loc,
			scd.pool,
		}
		if usesArrowBatches(scd.ctx) {
			if scd.ArrowBatches[idx].rec, err = arc.decodeArrowBatch(scd); err != nil {
				return err
			}
			// update batch metadata with the decoded row count
			scd.ArrowBatches[idx].rowCount = countArrowBatchRows(scd.ArrowBatches[idx].rec)
			return nil
		}
		highPrec := higherPrecisionEnabled(scd.ctx)
		respd, err = arc.decodeArrowChunk(scd.RowSet.RowType, highPrec)
		if err != nil {
			return err
		}
	}
	logger.Debugf(
		"decoded %d rows w/ %d bytes in %s (chunk %v)",
		scd.ChunkMetas[idx].RowCount,
		scd.ChunkMetas[idx].UncompressedSize,
		time.Since(start), idx+1,
	)

	scd.ChunksMutex.Lock()
	defer scd.ChunksMutex.Unlock()
	scd.Chunks[idx] = respd
	return nil
}

func populateJSONRowSet(dst []chunkRowType, src [][]*string) {
	// copy each decoded source row into the RowSet field of the
	// corresponding chunkRowType entry
	for i, row := range src {
		dst[i].RowSet = row
	}
}

// streamChunkDownloader fetches chunks sequentially and sends decoded
// rows to rowStream one at a time instead of buffering whole chunks.
type streamChunkDownloader struct {
	ctx            context.Context
	id             int64
	fetcher        streamChunkFetcher
	readErr        error
	rowStream      chan []*string
	Total          int64
	ChunkMetas     []execResponseChunk
	NextDownloader chunkDownloader
	RowSet         rowSetType
}

func (scd *streamChunkDownloader) totalUncompressedSize() int64 {
	return -1
}

func (scd *streamChunkDownloader) hasNextResultSet() bool {
	return scd.readErr == nil
}

func (scd *streamChunkDownloader) nextResultSet() error {
	return scd.readErr
}

func (scd *streamChunkDownloader) start() error {
	go func() {
		readErr := io.EOF

		logger.WithContext(scd.ctx).Infof(
			"start downloading. downloader id: %v, %v/%v rows, %v chunks",
			scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
		t := time.Now()

		defer func() {
			// recover first so that a panic's error is captured in
			// readErr before it is published and the stream is closed
			if r := recover(); r != nil {
				if err, ok := r.(error); ok {
					readErr = err
				} else {
					readErr = fmt.Errorf("%v", r)
				}
			}
			if readErr == io.EOF {
				logger.WithContext(scd.ctx).Infof("downloading done. downloader id: %v", scd.id)
			} else {
				logger.WithContext(scd.ctx).Debugf("downloading error. downloader id: %v", scd.id)
			}
			scd.readErr = readErr
			close(scd.rowStream)
		}()

		logger.WithContext(scd.ctx).Infof("sending initial set of rows in %vµs", time.Since(t).Microseconds())
		t = time.Now()
		for _, row := range scd.RowSet.JSON {
			scd.rowStream <- row
		}
		scd.RowSet.JSON = nil

		// Download and parse one chunk at a time. The fetcher sends each
		// parsed row to the row stream. When an error occurs, the fetcher
		// stops writing to the row stream so processing stops immediately.
		for i, chunk := range scd.ChunkMetas {
			logger.WithContext(scd.ctx).Infof("starting chunk fetch %d (%d rows)", i, chunk.RowCount)
			if err := scd.fetcher.fetch(chunk.URL, scd.rowStream); err != nil {
				logger.WithContext(scd.ctx).Debugf(
					"failed chunk fetch %d: %#v, downloader id: %v, %v/%v rows, %v chunks",
					i, err, scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
				readErr = fmt.Errorf("chunk fetch: %w", err)
				break
			}
			logger.WithContext(scd.ctx).Infof("fetched chunk %d (%d rows) in %vµs", i, chunk.RowCount, time.Since(t).Microseconds())
			t = time.Now()
		}
	}()
	return nil
}

func (scd *streamChunkDownloader) next() (chunkRowType, error) {
	if row, ok := <-scd.rowStream; ok {
		return chunkRowType{RowSet: row}, nil
	}
	return chunkRowType{}, scd.readErr
}

func (scd *streamChunkDownloader) reset() {}

func (scd *streamChunkDownloader) getChunkMetas() []execResponseChunk {
	return scd.ChunkMetas
}

func (scd *streamChunkDownloader) getQueryResultFormat() resultFormat {
	return jsonFormat
}

func (scd *streamChunkDownloader) setNextChunkDownloader(nextDownloader chunkDownloader) {
	scd.NextDownloader = nextDownloader
}

func (scd *streamChunkDownloader) getNextChunkDownloader() chunkDownloader {
	return scd.NextDownloader
}

func (scd *streamChunkDownloader) getRowType() []execResponseRowType {
	return scd.RowSet.RowType
}

func (scd *streamChunkDownloader) getArrowBatches() []*ArrowBatch {
	return nil
}

func useStreamDownloader(ctx context.Context) bool {
	val := ctx.Value(streamChunkDownload)
	if val == nil {
		return false
	}
	s, ok := val.(bool)
	return ok && s
}

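// Callers opt in to the stream downloader through the exported
// WithStreamDownloader context helper. A hedged usage sketch (error
// handling elided; the query text is a placeholder):
//
//	ctx := gosnowflake.WithStreamDownloader(context.Background())
//	rows, _ := db.QueryContext(ctx, "SELECT ...")
//	defer rows.Close()
//	for rows.Next() {
//		// rows arrive one chunk at a time with minimal buffering
//	}
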
type streamChunkFetcher interface {
	fetch(url string, rows chan<- []*string) error
}

type httpStreamChunkFetcher struct {
	ctx      context.Context
	client   *http.Client
	clientIP net.IP
	headers  map[string]string
	qrmk     string
}

func newStreamChunkDownloader(
	ctx context.Context,
	fetcher streamChunkFetcher,
	total int64,
	rowType []execResponseRowType,
	firstRows [][]*string,
	chunks []execResponseChunk,
) *streamChunkDownloader {
	return &streamChunkDownloader{
		ctx:        ctx,
		id:         rand.Int63(),
		fetcher:    fetcher,
		readErr:    nil,
		rowStream:  make(chan []*string),
		Total:      total,
		ChunkMetas: chunks,
		RowSet:     rowSetType{RowType: rowType, JSON: firstRows},
	}
}

func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error {
	if len(f.headers) == 0 {
		f.headers = map[string]string{
			headerSseCAlgorithm: headerSseCAes,
			headerSseCKey:       f.qrmk,
		}
	}

	fullURL, err := url.Parse(URL)
	if err != nil {
		return err
	}
	res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, 0, defaultTimeProvider, nil).execute()
	if err != nil {
		return err
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		b, err := io.ReadAll(res.Body)
		if err != nil {
			return err
		}
		return fmt.Errorf("status (%d): %s", res.StatusCode, string(b))
	}
	if err = copyChunkStream(res.Body, rows); err != nil {
		return fmt.Errorf("read: %w", err)
	}
	return nil
}

func copyChunkStream(body io.Reader, rows chan<- []*string) error {
	bufStream := bufio.NewReader(body)
	gzipMagic, err := bufStream.Peek(2)
	if err != nil {
		return err
	}
	var source io.Reader
	if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b {
		// detect and decompress gzip-encoded data
		bufStream0, err := gzip.NewReader(bufStream)
		if err != nil {
			return err
		}
		defer bufStream0.Close()
		source = bufStream0
	} else {
		source = bufStream
	}
	r := io.MultiReader(strings.NewReader("["), source, strings.NewReader("]"))
	dec := json.NewDecoder(r)
	openToken := json.Delim('[')
	closeToken := json.Delim(']')
	for {
		if t, err := dec.Token(); err == io.EOF {
			break
		} else if err != nil {
			return fmt.Errorf("delim open: %w", err)
		} else if t != openToken {
			return fmt.Errorf("delim open: got %T", t)
		}
		for dec.More() {
			var row []*string
			if err = dec.Decode(&row); err != nil {
				return fmt.Errorf("decode: %w", err)
			}
			rows <- row
		}
		if t, err := dec.Token(); err != nil {
			return fmt.Errorf("delim close: %w", err)
		} else if t != closeToken {
			return fmt.Errorf("delim close: got %T", t)
		}
	}
	return nil
}

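// copyChunkStream wraps the bracket-less row fragment with "[" and "]"
// via io.MultiReader, the streaming counterpart of largeResultSetReader
// above. It can be exercised without HTTP (illustrative sketch; the
// buffered channel avoids blocking within a single goroutine):
//
//	rows := make(chan []*string, 2)
//	_ = copyChunkStream(strings.NewReader(`["1"],["2"]`), rows)
//	close(rows)
//	for row := range rows {
//		_ = row // two decoded rows
//	}
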
// ArrowBatch object represents a chunk of data, or subset of rows, retrievable in arrow.Record format
type ArrowBatch struct {
	rec                *[]arrow.Record
	idx                int
	rowCount           int
	scd                *snowflakeChunkDownloader
	funcDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error
	ctx                context.Context
	loc                *time.Location
}

// WithContext sets the context which will be used for this ArrowBatch.
func (rb *ArrowBatch) WithContext(ctx context.Context) *ArrowBatch {
	rb.ctx = ctx
	return rb
}

// Fetch returns an array of records representing a chunk in the query,
// downloading and decoding the chunk on first use
func (rb *ArrowBatch) Fetch() (*[]arrow.Record, error) {
	// the chunk has already been downloaded
	if rb.rec != nil {
		// update metadata with the decoded row count
		rb.rowCount = countArrowBatchRows(rb.rec)
		return rb.rec, nil
	}
	var ctx context.Context
	if rb.ctx != nil {
		ctx = rb.ctx
	} else {
		ctx = context.Background()
	}
	if err := rb.funcDownloadHelper(ctx, rb.scd, rb.idx); err != nil {
		return nil, err
	}
	return rb.rec, nil
}

// GetRowCount returns the number of rows in an arrow batch
func (rb *ArrowBatch) GetRowCount() int {
	return rb.rowCount
}

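// A hedged end-to-end sketch of the Arrow batch API, assuming the
// exported WithArrowBatches helper and the SnowflakeRows interface
// documented for this driver (error handling elided; the query text is
// a placeholder):
//
//	ctx := gosnowflake.WithArrowBatches(context.Background())
//	conn, _ := db.Conn(ctx)
//	_ = conn.Raw(func(x interface{}) error {
//		rows, _ := x.(driver.QueryerContext).QueryContext(ctx, "SELECT ...", nil)
//		batches, _ := rows.(gosnowflake.SnowflakeRows).GetArrowBatches()
//		for _, batch := range batches {
//			recs, _ := batch.Fetch() // downloads and decodes on demand
//			for _, rec := range *recs {
//				_ = rec // consume the arrow.Record
//			}
//		}
//		return rows.Close()
//	})
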
// getAllocator returns the Arrow memory allocator embedded in the
// context, falling back to the Arrow default allocator.
func getAllocator(ctx context.Context) memory.Allocator {
	pool, ok := ctx.Value(arrowAlloc).(memory.Allocator)
	if !ok {
		return memory.DefaultAllocator
	}
	return pool
}

func usesArrowBatches(ctx context.Context) bool {
	val := ctx.Value(arrowBatches)
	if val == nil {
		return false
	}
	a, ok := val.(bool)
	return ok && a
}

func countArrowBatchRows(recs *[]arrow.Record) int {
	var cnt int
	for _, r := range *recs {
		cnt += int(r.NumRows())
	}
	return cnt
}